diff --git a/requirements/development.txt b/requirements/development.txt index af663457ea0..f2844f60be7 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -58,8 +58,6 @@ backoff==2.2.1 # via # -c requirements/base-constraint.txt # apache-superset -backports-tarfile==1.2.0 - # via jaraco-context bcrypt==4.3.0 # via # -c requirements/base-constraint.txt @@ -176,7 +174,6 @@ cryptography==44.0.3 # paramiko # pyjwt # pyopenssl - # secretstorage cycler==0.12.1 # via matplotlib cyclopts==4.2.4 @@ -228,7 +225,7 @@ et-xmlfile==2.0.0 # openpyxl exceptiongroup==1.3.0 # via fastmcp -fastmcp==2.13.1 +fastmcp==2.13.3 # via apache-superset filelock==3.12.2 # via virtualenv @@ -418,8 +415,6 @@ idna==3.10 # requests # trio # url-normalize -importlib-metadata==8.7.0 - # via keyring importlib-resources==6.5.2 # via prophet iniconfig==2.0.0 @@ -435,16 +430,6 @@ itsdangerous==2.2.0 # -c requirements/base-constraint.txt # flask # flask-wtf -jaraco-classes==3.4.0 - # via keyring -jaraco-context==6.0.1 - # via keyring -jaraco-functools==4.3.0 - # via keyring -jeepney==0.9.0 - # via - # keyring - # secretstorage jinja2==3.1.6 # via # -c requirements/base-constraint.txt @@ -471,8 +456,6 @@ jsonschema-specifications==2025.4.1 # -c requirements/base-constraint.txt # jsonschema # openapi-schema-validator -keyring==25.6.0 - # via py-key-value-aio kiwisolver==1.4.7 # via matplotlib kombu==5.5.3 @@ -530,10 +513,6 @@ mdurl==0.1.2 # via # -c requirements/base-constraint.txt # markdown-it-py -more-itertools==10.8.0 - # via - # jaraco-classes - # jaraco-functools msgpack==1.0.8 # via # -c requirements/base-constraint.txt @@ -888,8 +867,6 @@ rsa==4.9.1 # google-auth ruff==0.9.7 # via apache-superset -secretstorage==3.4.1 - # via keyring selenium==4.32.0 # via # -c requirements/base-constraint.txt @@ -1095,8 +1072,6 @@ xlsxwriter==3.0.9 # -c requirements/base-constraint.txt # apache-superset # pandas -zipp==3.23.0 - # via importlib-metadata zope-event==5.0 # via gevent zope-interface==5.4.0 diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 0cfb8f299ee..06f68a33ea8 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -23,9 +23,10 @@ mcp from here and use @mcp.tool decorators. """ import logging -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable, Dict, List, Sequence, Set from fastmcp import FastMCP +from fastmcp.server.middleware import Middleware logger = logging.getLogger(__name__) @@ -146,6 +147,7 @@ def _build_mcp_kwargs( tools: List[Any] | None, include_tags: Set[str] | None, exclude_tags: Set[str] | None, + middleware: Sequence[Middleware] | None = None, **kwargs: Any, ) -> Dict[str, Any]: """Build FastMCP constructor arguments.""" @@ -165,6 +167,8 @@ def _build_mcp_kwargs( mcp_kwargs["include_tags"] = include_tags if exclude_tags is not None: mcp_kwargs["exclude_tags"] = exclude_tags + if middleware is not None: + mcp_kwargs["middleware"] = middleware # Add any additional kwargs mcp_kwargs.update(kwargs) @@ -206,6 +210,7 @@ def create_mcp_app( include_tags: Set[str] | None = None, exclude_tags: Set[str] | None = None, config: Dict[str, Any] | None = None, + middleware: Sequence[Middleware] | None = None, **kwargs: Any, ) -> FastMCP: """ @@ -225,6 +230,7 @@ def create_mcp_app( include_tags: Set of tags to include (whitelist) exclude_tags: Set of tags to exclude (blacklist) config: Additional configuration dictionary + middleware: Sequence of middleware to apply to the server **kwargs: Additional FastMCP constructor arguments Returns: @@ -244,7 +250,15 @@ def create_mcp_app( # Build FastMCP constructor arguments mcp_kwargs = _build_mcp_kwargs( - name, instructions, auth, lifespan, tools, include_tags, exclude_tags, **kwargs + name, + instructions, + auth, + lifespan, + tools, + include_tags, + exclude_tags, + middleware, + **kwargs, ) # Create the FastMCP instance @@ -321,6 +335,7 @@ def init_fastmcp_server( include_tags: Set[str] | None = None, exclude_tags: Set[str] | None = None, config: Dict[str, Any] | None = None, + middleware: Sequence[Middleware] | None = None, **kwargs: Any, ) -> FastMCP: """ @@ -334,6 +349,7 @@ def init_fastmcp_server( name: Server name (defaults to "{APP_NAME} MCP Server") instructions: Custom instructions (defaults to branded with APP_NAME) auth, lifespan, tools, include_tags, exclude_tags, config: FastMCP configuration + middleware: Sequence of middleware to apply to the server **kwargs: Additional FastMCP configuration Returns: @@ -364,6 +380,7 @@ def init_fastmcp_server( include_tags is not None, exclude_tags is not None, config is not None, + middleware is not None, kwargs, ] ) @@ -379,6 +396,7 @@ def init_fastmcp_server( include_tags=include_tags, exclude_tags=exclude_tags, config=config, + middleware=middleware, **kwargs, ) else: diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py index f70b1623d65..129425eb04c 100644 --- a/superset/mcp_service/auth.py +++ b/superset/mcp_service/auth.py @@ -186,6 +186,7 @@ def mcp_auth_hook(tool_func: F) -> F: """ import functools import inspect + import types is_async = inspect.iscoroutinefunction(tool_func) @@ -209,7 +210,7 @@ def mcp_auth_hook(tool_func: F) -> F: finally: _cleanup_session_finally() - return async_wrapper # type: ignore[return-value] + wrapper = async_wrapper else: @@ -231,4 +232,65 @@ def mcp_auth_hook(tool_func: F) -> F: finally: _cleanup_session_finally() - return sync_wrapper # type: ignore[return-value] + wrapper = sync_wrapper + + # Merge original function's __globals__ into wrapper's __globals__ + # This allows get_type_hints() to resolve type annotations from the + # original module (e.g., Context from fastmcp) + # FastMCP 2.13.2+ uses get_type_hints() which needs access to these types + merged_globals = {**wrapper.__globals__, **tool_func.__globals__} # type: ignore[attr-defined] + new_wrapper = types.FunctionType( + wrapper.__code__, # type: ignore[attr-defined] + merged_globals, + wrapper.__name__, + wrapper.__defaults__, # type: ignore[attr-defined] + wrapper.__closure__, # type: ignore[attr-defined] + ) + # Copy __dict__ but exclude __wrapped__ + # NOTE: We intentionally do NOT preserve __wrapped__ here. + # Setting __wrapped__ causes inspect.signature() to follow the chain + # and find 'ctx' in the original function's signature, even after + # FastMCP's create_function_without_params removes it from annotations. + # This breaks Pydantic's TypeAdapter which expects signature params + # to match type_hints. + new_wrapper.__dict__.update( + {k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"} + ) + new_wrapper.__module__ = wrapper.__module__ + new_wrapper.__qualname__ = wrapper.__qualname__ + new_wrapper.__annotations__ = wrapper.__annotations__ + # Copy docstring from original function (not wrapper, which may have lost it) + new_wrapper.__doc__ = tool_func.__doc__ + + # Set __signature__ from the original function, but: + # 1. Remove ctx parameter - FastMCP tools don't expose it to clients + # 2. Skip if original has *args (parse_request output has its own handling) + from fastmcp import Context as FMContext + + tool_sig = inspect.signature(tool_func) + has_var_positional = any( + p.kind == inspect.Parameter.VAR_POSITIONAL for p in tool_sig.parameters.values() + ) + + if not has_var_positional: + # For functions without *args, preserve signature but remove ctx + new_params = [] + for _name, param in tool_sig.parameters.items(): + # Skip ctx parameter - FastMCP tools don't expose it to clients + if param.annotation is FMContext or ( + hasattr(param.annotation, "__name__") + and param.annotation.__name__ == "Context" + ): + continue + new_params.append(param) + new_wrapper.__signature__ = tool_sig.replace( # type: ignore[attr-defined] + parameters=new_params + ) + + # Also remove ctx from annotations to match signature + if "ctx" in new_wrapper.__annotations__: + del new_wrapper.__annotations__["ctx"] + # For functions with *args (parse_request output), the signature + # is already set by parse_request without ctx. + + return new_wrapper # type: ignore[return-value] diff --git a/superset/mcp_service/caching.py b/superset/mcp_service/caching.py new file mode 100644 index 00000000000..d50f2761ab4 --- /dev/null +++ b/superset/mcp_service/caching.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP response caching using FastMCP's native ResponseCachingMiddleware. +""" + +import logging +from typing import Any, Dict + +from superset.mcp_service.storage import get_mcp_store + +logger = logging.getLogger(__name__) + + +def _build_caching_settings(cache_config: Dict[str, Any]) -> Dict[str, Any]: + """ + Build FastMCP caching settings from MCP_CACHE_CONFIG. + + Maps our config format to FastMCP's settings objects: + - list_tools_ttl -> list_tools_settings + - list_resources_ttl -> list_resources_settings + - list_prompts_ttl -> list_prompts_settings + - read_resource_ttl -> read_resource_settings + - get_prompt_ttl -> get_prompt_settings + - call_tool_ttl + excluded_tools -> call_tool_settings + + TTL values are already integers (Python evaluates '60 * 5' at config load time). + + Args: + cache_config: MCP_CACHE_CONFIG dict + + Returns: + Dict of settings kwargs for ResponseCachingMiddleware + """ + settings: Dict[str, Any] = {} + + # List operations (default 5 min) + if "list_tools_ttl" in cache_config: + settings["list_tools_settings"] = {"ttl": cache_config["list_tools_ttl"]} + if "list_resources_ttl" in cache_config: + settings["list_resources_settings"] = { + "ttl": cache_config["list_resources_ttl"] + } + if "list_prompts_ttl" in cache_config: + settings["list_prompts_settings"] = {"ttl": cache_config["list_prompts_ttl"]} + + # Individual item operations (default 1 hour) + if "read_resource_ttl" in cache_config: + settings["read_resource_settings"] = {"ttl": cache_config["read_resource_ttl"]} + if "get_prompt_ttl" in cache_config: + settings["get_prompt_settings"] = {"ttl": cache_config["get_prompt_ttl"]} + + # Tool calls with exclusions + call_tool_settings: Dict[str, Any] = {} + if "call_tool_ttl" in cache_config: + call_tool_settings["ttl"] = cache_config["call_tool_ttl"] + if "excluded_tools" in cache_config: + call_tool_settings["excluded_tools"] = cache_config["excluded_tools"] + if call_tool_settings: + settings["call_tool_settings"] = call_tool_settings + + return settings + + +def create_response_caching_middleware() -> Any | None: + """ + Create ResponseCachingMiddleware with RedisStore backend. + + Uses MCP_CACHE_CONFIG for caching settings and prefix. + Uses get_mcp_store() factory for store creation. + + Returns: + ResponseCachingMiddleware instance or None if not configured/disabled + """ + from flask import has_app_context + + from superset.mcp_service.flask_singleton import get_flask_app + + flask_app = get_flask_app() + + def _create_middleware() -> Any | None: + cache_config = flask_app.config.get("MCP_CACHE_CONFIG", {}) + + if not cache_config.get("enabled", False): + logger.debug("MCP response caching disabled") + return None + + # Get cache-specific prefix from MCP_CACHE_CONFIG + cache_prefix = cache_config.get("CACHE_KEY_PREFIX") + if not cache_prefix: + logger.warning("MCP caching enabled but no CACHE_KEY_PREFIX configured") + return None + + # Create store with cache-specific prefix + store = get_mcp_store(prefix=cache_prefix) + if store is None: + return None + + try: + from fastmcp.server.middleware.caching import ResponseCachingMiddleware + except ImportError: + logger.warning( + "ResponseCachingMiddleware not available. Requires FastMCP >= 2.13.0" + ) + return None + + # Build per-operation settings from config + settings = _build_caching_settings(cache_config) + + middleware = ResponseCachingMiddleware( + cache_storage=store, + **settings, + ) + logger.info("MCP caching middleware enabled") + return middleware + + # Use existing app context if available, otherwise push one + if has_app_context(): + return _create_middleware() + else: + with flask_app.app_context(): + return _create_middleware() diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py index bf4f3ff88dc..624c60e47af 100644 --- a/superset/mcp_service/mcp_config.py +++ b/superset/mcp_service/mcp_config.py @@ -73,6 +73,33 @@ MCP_FACTORY_CONFIG = { "config": None, # No additional config } +# MCP Store Configuration - shared infrastructure for all MCP storage needs +# (caching, auth, events, etc.) +MCP_STORE_CONFIG: Dict[str, Any] = { + "enabled": False, # Disabled by default in OSS + "CACHE_REDIS_URL": None, # Must be configured to enable + "WRAPPER_TYPE": "key_value.aio.wrappers.prefix_keys.PrefixKeysWrapper", +} + +# MCP Response Caching Configuration - feature-specific settings +MCP_CACHE_CONFIG: Dict[str, Any] = { + "enabled": False, # Disabled by default in OSS + "CACHE_KEY_PREFIX": "mcp_cache_v1_", # Static prefix for OSS + "list_tools_ttl": 60 * 5, # 5 minutes + "list_resources_ttl": 60 * 5, # 5 minutes + "list_prompts_ttl": 60 * 5, # 5 minutes + "read_resource_ttl": 60 * 60, # 1 hour + "get_prompt_ttl": 60 * 60, # 1 hour + "call_tool_ttl": 60 * 60, # 1 hour + "max_item_size": 1024 * 1024, # 1MB + "excluded_tools": [ + "execute_sql", + "generate_dashboard", + "generate_chart", + "update_chart", + ], +} + def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]: """Default MCP auth factory using app.config values.""" diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 3fe0907651c..7d83aa6ecb7 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -84,6 +84,7 @@ def run_server( else: # Use default initialization with auth from Flask config logging.info("Creating MCP app with default configuration...") + from superset.mcp_service.caching import create_response_caching_middleware from superset.mcp_service.flask_singleton import get_flask_app flask_app = get_flask_app() @@ -101,7 +102,16 @@ def run_server( except Exception as e: logging.error("Failed to create auth provider: %s", e) - mcp_instance = init_fastmcp_server(auth=auth_provider) + # Build middleware list + middleware_list = [] + caching_middleware = create_response_caching_middleware() + if caching_middleware: + middleware_list.append(caching_middleware) + + mcp_instance = init_fastmcp_server( + auth=auth_provider, + middleware=middleware_list or None, + ) env_key = f"FASTMCP_RUNNING_{port}" if not os.environ.get(env_key): diff --git a/superset/mcp_service/storage.py b/superset/mcp_service/storage.py new file mode 100644 index 00000000000..47d19b0eb7a --- /dev/null +++ b/superset/mcp_service/storage.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP Redis storage factory. + +Provides get_mcp_store(prefix) factory for creating stores with feature-specific +prefixes. Uses shared MCP_STORE_CONFIG for Redis URL and wrapper type. + +Reusable across caching middleware, OAuth providers, EventStore, etc. +""" + +import logging +from importlib import import_module +from typing import Any, Callable, Dict + +logger = logging.getLogger(__name__) + + +def get_mcp_store( + prefix: str | Callable[[], str], +) -> Any | None: + """ + Create a store instance with the specified prefix. + + Uses shared MCP_STORE_CONFIG for Redis URL and wrapper type. + Each caller provides their own prefix (cache, auth, events, etc.). + + Args: + prefix: Feature-specific prefix (string or callable for multi-tenancy) + + Returns: + Wrapped RedisStore instance or None if not configured/disabled + + Examples: + # Caching + cache_store = get_mcp_store(prefix=cache_prefix_lambda) + + # Auth (future) + auth_store = get_mcp_store(prefix="mcp_auth_v1_") + + # EventStore (future) + event_store = get_mcp_store(prefix=event_prefix_lambda) + """ + from flask import has_app_context + + from superset.mcp_service.flask_singleton import get_flask_app + + flask_app = get_flask_app() + + def _get_store() -> Any | None: + store_config = flask_app.config.get("MCP_STORE_CONFIG", {}) + + # Check if store is enabled + if not store_config.get("enabled", False): + logger.debug("MCP store disabled via config") + return None + + return _create_redis_store(store_config, prefix) + + # Use existing app context if available, otherwise push one + if has_app_context(): + return _get_store() + else: + with flask_app.app_context(): + return _get_store() + + +def _create_redis_store( + store_config: Dict[str, Any], + prefix: str | Callable[[], str], +) -> Any | None: + """ + Create a RedisStore with the given prefix. + + Args: + store_config: MCP_STORE_CONFIG dict (Redis URL, wrapper type) + prefix: Feature-specific prefix + + Returns: + Wrapped RedisStore instance or None if not configured + """ + redis_url = store_config.get("CACHE_REDIS_URL") + if not redis_url: + logger.debug("MCP storage disabled - no CACHE_REDIS_URL configured") + return None + + try: + from key_value.aio.stores.redis import RedisStore + except ImportError: + logger.warning( + "key_value package not available for Redis storage. " + "Install with: pip install py-key-value-aio[redis]" + ) + return None + + try: + wrapper_type = store_config.get("WRAPPER_TYPE") + if not wrapper_type: + logger.error("MCP store WRAPPER_TYPE not configured") + return None + + wrapper_class = _import_wrapper_class(wrapper_type) + redis_store = RedisStore(url=redis_url) + store = wrapper_class(key_value=redis_store, prefix=prefix) + logger.info("✅ MCP RedisStore created") + return store + except Exception as e: + logger.error("Failed to create MCP store: %s", e) + return None + + +def _import_wrapper_class(class_path: str) -> type: + """ + Import a wrapper class from a dotted path. + + Args: + class_path: Dotted path like + 'key_value.aio.wrappers.prefix_keys.PrefixKeysWrapper' + + Returns: + The imported class + + Raises: + ImportError: If the class cannot be imported + """ + module_path, class_name = class_path.rsplit(".", 1) + module = import_module(module_path) + return getattr(module, class_name) diff --git a/superset/mcp_service/utils/schema_utils.py b/superset/mcp_service/utils/schema_utils.py index 4f2a588cce9..4e97abf4055 100644 --- a/superset/mcp_service/utils/schema_utils.py +++ b/superset/mcp_service/utils/schema_utils.py @@ -422,6 +422,8 @@ def parse_request( """ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + import types + if asyncio.iscoroutinefunction(func): @wraps(func) @@ -429,7 +431,12 @@ def parse_request( # Parse if string, otherwise pass through # (parse_json_or_model handles both) parsed_request = parse_json_or_model(request, request_class, "request") - return await func(parsed_request, *args, **kwargs) + # Get ctx from FastMCP's dependency injection + # (we stripped it from signature) + from fastmcp.server.dependencies import get_context + + ctx = get_context() + return await func(parsed_request, ctx, *args, **kwargs) wrapper = async_wrapper else: @@ -439,17 +446,81 @@ def parse_request( # Parse if string, otherwise pass through # (parse_json_or_model handles both) parsed_request = parse_json_or_model(request, request_class, "request") - return func(parsed_request, *args, **kwargs) + # Get ctx from FastMCP's dependency injection + # (we stripped it from signature) + from fastmcp.server.dependencies import get_context + + ctx = get_context() + return func(parsed_request, ctx, *args, **kwargs) wrapper = sync_wrapper - # Modify the wrapper's annotations to accept str | RequestModel - # This allows FastMCP to accept string inputs while keeping the - # original function's type hints clean - if hasattr(wrapper, "__annotations__"): - # Create union type: str | RequestModel - wrapper.__annotations__["request"] = str | request_class + # Merge original function's __globals__ into wrapper's __globals__ + # This allows get_type_hints() to resolve type annotations from the + # original module (e.g., Context from fastmcp) + # FastMCP 2.13.2+ uses get_type_hints() which needs access to these types + merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined] + new_wrapper = types.FunctionType( + wrapper.__code__, # type: ignore[attr-defined] + merged_globals, + wrapper.__name__, + wrapper.__defaults__, # type: ignore[attr-defined] + wrapper.__closure__, # type: ignore[attr-defined] + ) + # Copy __dict__ but exclude __wrapped__ + # NOTE: We intentionally do NOT preserve __wrapped__ here. + # Setting __wrapped__ causes inspect.signature() to follow the chain + # and find 'ctx' in the original function's signature, even after + # FastMCP's create_function_without_params removes it from annotations. + # This breaks Pydantic's TypeAdapter which expects signature params + # to match type_hints. + new_wrapper.__dict__.update( + {k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"} + ) + new_wrapper.__module__ = wrapper.__module__ + new_wrapper.__qualname__ = wrapper.__qualname__ + # Copy docstring from original function (not wrapper, which has no docstring) + new_wrapper.__doc__ = func.__doc__ - return wrapper + # Copy annotations from original function and modify request type + # Also remove ctx annotation - FastMCP strips it, and having it in + # annotations but not signature breaks Pydantic's TypeAdapter + if hasattr(func, "__annotations__"): + new_wrapper.__annotations__ = { + k: v + for k, v in func.__annotations__.items() + if k != "ctx" # Skip ctx - will be removed from signature too + } + # Modify request annotation to accept str | RequestModel + new_wrapper.__annotations__["request"] = str | request_class + else: + new_wrapper.__annotations__ = {"request": str | request_class} + + # Set __signature__ from original function, but modify for FastMCP: + # 1. Modify request annotation to accept str | RequestModel + # 2. Do NOT include ctx parameter - FastMCP will strip it anyway, and + # having it in __signature__ but not __annotations__ breaks Pydantic + import inspect as sig_inspect + + from fastmcp import Context as FMContext + + orig_sig = sig_inspect.signature(func) + new_params = [] + for name, param in orig_sig.parameters.items(): + # Skip ctx parameter - FastMCP tools don't expose it to clients + if param.annotation is FMContext or ( + hasattr(param.annotation, "__name__") + and param.annotation.__name__ == "Context" + ): + continue + if name == "request": + new_params.append(param.replace(annotation=str | request_class)) + else: + new_params.append(param) + new_wrapper.__signature__ = orig_sig.replace( # type: ignore[attr-defined] + parameters=new_params + ) + + return new_wrapper return decorator diff --git a/tests/unit_tests/mcp_service/test_mcp_caching.py b/tests/unit_tests/mcp_service/test_mcp_caching.py new file mode 100644 index 00000000000..ed1d33345ae --- /dev/null +++ b/tests/unit_tests/mcp_service/test_mcp_caching.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for MCP response caching middleware.""" + +from unittest.mock import MagicMock, patch + +from superset.mcp_service.caching import _build_caching_settings + + +def test_build_caching_settings_empty_config(): + """Empty config returns empty settings.""" + result = _build_caching_settings({}) + assert result == {} + + +def test_build_caching_settings_list_ttls(): + """List operation TTLs are mapped to settings.""" + config = { + "list_tools_ttl": 300, + "list_resources_ttl": 300, + "list_prompts_ttl": 300, + } + result = _build_caching_settings(config) + + assert result["list_tools_settings"] == {"ttl": 300} + assert result["list_resources_settings"] == {"ttl": 300} + assert result["list_prompts_settings"] == {"ttl": 300} + + +def test_build_caching_settings_item_ttls(): + """Individual item TTLs are mapped to settings.""" + config = { + "read_resource_ttl": 3600, + "get_prompt_ttl": 3600, + } + result = _build_caching_settings(config) + + assert result["read_resource_settings"] == {"ttl": 3600} + assert result["get_prompt_settings"] == {"ttl": 3600} + + +def test_build_caching_settings_call_tool_with_exclusions(): + """Call tool settings include TTL and exclusions.""" + config = { + "call_tool_ttl": 3600, + "excluded_tools": ["execute_sql", "generate_chart"], + } + result = _build_caching_settings(config) + + assert result["call_tool_settings"] == { + "ttl": 3600, + "excluded_tools": ["execute_sql", "generate_chart"], + } + + +def test_create_response_caching_middleware_returns_none_when_disabled(): + """Caching middleware returns None when disabled in config.""" + mock_flask_app = MagicMock() + mock_flask_app.config.get.return_value = {"enabled": False} + + with patch( + "superset.mcp_service.flask_singleton.get_flask_app", + return_value=mock_flask_app, + ): + with patch("flask.has_app_context", return_value=True): + from superset.mcp_service.caching import create_response_caching_middleware + + result = create_response_caching_middleware() + + assert result is None + + +def test_create_response_caching_middleware_returns_none_when_no_prefix(): + """Caching middleware returns None when CACHE_KEY_PREFIX is not set.""" + mock_flask_app = MagicMock() + mock_flask_app.config.get.return_value = { + "enabled": True, + "CACHE_KEY_PREFIX": None, + } + + with patch( + "superset.mcp_service.flask_singleton.get_flask_app", + return_value=mock_flask_app, + ): + with patch("flask.has_app_context", return_value=True): + from superset.mcp_service.caching import create_response_caching_middleware + + result = create_response_caching_middleware() + + assert result is None + + +def test_create_response_caching_middleware_creates_middleware(): + """Caching middleware is created when properly configured.""" + mock_flask_app = MagicMock() + mock_flask_app.config.get.return_value = { + "enabled": True, + "CACHE_KEY_PREFIX": "mcp_cache_v1_", + "list_tools_ttl": 300, + } + + mock_store = MagicMock() + mock_middleware = MagicMock() + + with patch( + "superset.mcp_service.flask_singleton.get_flask_app", + return_value=mock_flask_app, + ): + with patch("flask.has_app_context", return_value=True): + with patch( + "superset.mcp_service.caching.get_mcp_store", return_value=mock_store + ): + with patch( + "fastmcp.server.middleware.caching.ResponseCachingMiddleware", + return_value=mock_middleware, + ) as mock_middleware_class: + from superset.mcp_service.caching import ( + create_response_caching_middleware, + ) + + result = create_response_caching_middleware() + + assert result is mock_middleware + # Verify middleware was created with store and settings + mock_middleware_class.assert_called_once() + call_kwargs = mock_middleware_class.call_args[1] + assert call_kwargs["cache_storage"] is mock_store + assert call_kwargs["list_tools_settings"] == {"ttl": 300} diff --git a/tests/unit_tests/mcp_service/test_mcp_storage.py b/tests/unit_tests/mcp_service/test_mcp_storage.py new file mode 100644 index 00000000000..b61e99e62bd --- /dev/null +++ b/tests/unit_tests/mcp_service/test_mcp_storage.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for MCP storage factory.""" + +from unittest.mock import MagicMock, patch + + +def test_get_mcp_store_returns_none_when_disabled(): + """Storage returns None when MCP_STORE_CONFIG.enabled is False.""" + mock_flask_app = MagicMock() + mock_flask_app.config.get.return_value = {"enabled": False} + + with patch( + "superset.mcp_service.flask_singleton.get_flask_app", + return_value=mock_flask_app, + ): + with patch("flask.has_app_context", return_value=True): + from superset.mcp_service.storage import get_mcp_store + + result = get_mcp_store(prefix="test_") + + assert result is None + + +def test_get_mcp_store_returns_none_when_no_redis_url(): + """Storage returns None when CACHE_REDIS_URL is not configured.""" + mock_flask_app = MagicMock() + mock_flask_app.config.get.return_value = { + "enabled": True, + "CACHE_REDIS_URL": None, + "WRAPPER_TYPE": "key_value.aio.wrappers.prefix_keys.PrefixKeysWrapper", + } + + with patch( + "superset.mcp_service.flask_singleton.get_flask_app", + return_value=mock_flask_app, + ): + with patch("flask.has_app_context", return_value=True): + from superset.mcp_service.storage import get_mcp_store + + result = get_mcp_store(prefix="test_") + + assert result is None + + +def test_get_mcp_store_creates_store_when_enabled(): + """Storage creates wrapped RedisStore when properly configured.""" + mock_flask_app = MagicMock() + mock_flask_app.config.get.return_value = { + "enabled": True, + "CACHE_REDIS_URL": "redis://localhost:6379/0", + "WRAPPER_TYPE": "key_value.aio.wrappers.prefix_keys.PrefixKeysWrapper", + } + + mock_redis_store = MagicMock() + mock_wrapper_instance = MagicMock() + mock_wrapper_class = MagicMock(return_value=mock_wrapper_instance) + + with patch( + "superset.mcp_service.flask_singleton.get_flask_app", + return_value=mock_flask_app, + ): + with patch("flask.has_app_context", return_value=True): + with patch( + "superset.mcp_service.storage._import_wrapper_class", + return_value=mock_wrapper_class, + ): + with patch( + "key_value.aio.stores.redis.RedisStore", + return_value=mock_redis_store, + ): + from superset.mcp_service.storage import get_mcp_store + + result = get_mcp_store(prefix="test_prefix_") + + # Verify store was created + assert result is mock_wrapper_instance + # Verify wrapper was called with correct args + mock_wrapper_class.assert_called_once_with( + key_value=mock_redis_store, prefix="test_prefix_" + ) diff --git a/tests/unit_tests/mcp_service/utils/test_schema_utils.py b/tests/unit_tests/mcp_service/utils/test_schema_utils.py index 2c28f40782b..4e592159c07 100644 --- a/tests/unit_tests/mcp_service/utils/test_schema_utils.py +++ b/tests/unit_tests/mcp_service/utils/test_schema_utils.py @@ -356,6 +356,7 @@ class TestParseRequestDecorator: def test_decorator_with_json_string_async(self): """Should parse JSON string request in async function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) async def async_tool(request, ctx=None): @@ -363,21 +364,27 @@ class TestParseRequestDecorator: import asyncio - result = asyncio.run(async_tool('{"name": "test", "count": 5}')) + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = asyncio.run(async_tool('{"name": "test", "count": 5}')) assert result == "test:5" def test_decorator_with_json_string_sync(self): """Should parse JSON string request in sync function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) def sync_tool(request, ctx=None): return f"{request.name}:{request.count}" - result = sync_tool('{"name": "test", "count": 5}') + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = sync_tool('{"name": "test", "count": 5}') assert result == "test:5" def test_decorator_with_dict_async(self): """Should handle dict request in async function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) async def async_tool(request, ctx=None): @@ -385,21 +392,27 @@ class TestParseRequestDecorator: import asyncio - result = asyncio.run(async_tool({"name": "test", "count": 5})) + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = asyncio.run(async_tool({"name": "test", "count": 5})) assert result == "test:5" def test_decorator_with_dict_sync(self): """Should handle dict request in sync function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) def sync_tool(request, ctx=None): return f"{request.name}:{request.count}" - result = sync_tool({"name": "test", "count": 5}) + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = sync_tool({"name": "test", "count": 5}) assert result == "test:5" def test_decorator_with_model_instance_async(self): """Should pass through model instance in async function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) async def async_tool(request, ctx=None): @@ -407,23 +420,29 @@ class TestParseRequestDecorator: import asyncio + mock_ctx = MagicMock() instance = self.RequestModel(name="test", count=5) - result = asyncio.run(async_tool(instance)) + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = asyncio.run(async_tool(instance)) assert result == "test:5" def test_decorator_with_model_instance_sync(self): """Should pass through model instance in sync function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) def sync_tool(request, ctx=None): return f"{request.name}:{request.count}" + mock_ctx = MagicMock() instance = self.RequestModel(name="test", count=5) - result = sync_tool(instance) + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = sync_tool(instance) assert result == "test:5" def test_decorator_preserves_function_signature_async(self): """Should preserve original async function signature.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) async def async_tool(request, ctx=None, extra=None): @@ -431,23 +450,29 @@ class TestParseRequestDecorator: import asyncio - result = asyncio.run( - async_tool('{"name": "test", "count": 5}', ctx=None, extra="data") - ) + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = asyncio.run( + async_tool('{"name": "test", "count": 5}', extra="data") + ) assert result == "test:5:data" def test_decorator_preserves_function_signature_sync(self): """Should preserve original sync function signature.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) def sync_tool(request, ctx=None, extra=None): return f"{request.name}:{request.count}:{extra}" - result = sync_tool('{"name": "test", "count": 5}', ctx=None, extra="data") + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = sync_tool('{"name": "test", "count": 5}', extra="data") assert result == "test:5:data" def test_decorator_raises_validation_error_async(self): """Should raise ValidationError for invalid data in async function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) async def async_tool(request, ctx=None): @@ -455,21 +480,27 @@ class TestParseRequestDecorator: import asyncio - with pytest.raises(ValidationError): - asyncio.run(async_tool('{"name": "test"}')) # Missing count + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + with pytest.raises(ValidationError): + asyncio.run(async_tool('{"name": "test"}')) # Missing count def test_decorator_raises_validation_error_sync(self): """Should raise ValidationError for invalid data in sync function.""" + from unittest.mock import MagicMock, patch @parse_request(self.RequestModel) def sync_tool(request, ctx=None): return f"{request.name}:{request.count}" - with pytest.raises(ValidationError): - sync_tool('{"name": "test"}') # Missing count + mock_ctx = MagicMock() + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + with pytest.raises(ValidationError): + sync_tool('{"name": "test"}') # Missing count def test_decorator_with_complex_model_async(self): """Should handle complex nested models in async function.""" + from unittest.mock import MagicMock, patch class NestedModel(BaseModel): """Nested model.""" @@ -488,12 +519,15 @@ class TestParseRequestDecorator: import asyncio + mock_ctx = MagicMock() json_str = '{"name": "test", "nested": {"value": 42}}' - result = asyncio.run(async_tool(json_str)) + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = asyncio.run(async_tool(json_str)) assert result == "test:42" def test_decorator_with_complex_model_sync(self): """Should handle complex nested models in sync function.""" + from unittest.mock import MagicMock, patch class NestedModel(BaseModel): """Nested model.""" @@ -510,6 +544,8 @@ class TestParseRequestDecorator: def sync_tool(request, ctx=None): return f"{request.name}:{request.nested.value}" + mock_ctx = MagicMock() json_str = '{"name": "test", "nested": {"value": 42}}' - result = sync_tool(json_str) + with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx): + result = sync_tool(json_str) assert result == "test:42"