mirror of
https://github.com/apache/superset.git
synced 2026-04-20 08:34:37 +00:00
feat(mcp): Add ResponseCachingMiddleware and Storage (#36497)
This commit is contained in:
143
tests/unit_tests/mcp_service/test_mcp_caching.py
Normal file
143
tests/unit_tests/mcp_service/test_mcp_caching.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for MCP response caching middleware."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.mcp_service.caching import _build_caching_settings
|
||||
|
||||
|
||||
def test_build_caching_settings_empty_config():
|
||||
"""Empty config returns empty settings."""
|
||||
result = _build_caching_settings({})
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_build_caching_settings_list_ttls():
|
||||
"""List operation TTLs are mapped to settings."""
|
||||
config = {
|
||||
"list_tools_ttl": 300,
|
||||
"list_resources_ttl": 300,
|
||||
"list_prompts_ttl": 300,
|
||||
}
|
||||
result = _build_caching_settings(config)
|
||||
|
||||
assert result["list_tools_settings"] == {"ttl": 300}
|
||||
assert result["list_resources_settings"] == {"ttl": 300}
|
||||
assert result["list_prompts_settings"] == {"ttl": 300}
|
||||
|
||||
|
||||
def test_build_caching_settings_item_ttls():
|
||||
"""Individual item TTLs are mapped to settings."""
|
||||
config = {
|
||||
"read_resource_ttl": 3600,
|
||||
"get_prompt_ttl": 3600,
|
||||
}
|
||||
result = _build_caching_settings(config)
|
||||
|
||||
assert result["read_resource_settings"] == {"ttl": 3600}
|
||||
assert result["get_prompt_settings"] == {"ttl": 3600}
|
||||
|
||||
|
||||
def test_build_caching_settings_call_tool_with_exclusions():
|
||||
"""Call tool settings include TTL and exclusions."""
|
||||
config = {
|
||||
"call_tool_ttl": 3600,
|
||||
"excluded_tools": ["execute_sql", "generate_chart"],
|
||||
}
|
||||
result = _build_caching_settings(config)
|
||||
|
||||
assert result["call_tool_settings"] == {
|
||||
"ttl": 3600,
|
||||
"excluded_tools": ["execute_sql", "generate_chart"],
|
||||
}
|
||||
|
||||
|
||||
def test_create_response_caching_middleware_returns_none_when_disabled():
|
||||
"""Caching middleware returns None when disabled in config."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = {"enabled": False}
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.flask_singleton.get_flask_app",
|
||||
return_value=mock_flask_app,
|
||||
):
|
||||
with patch("flask.has_app_context", return_value=True):
|
||||
from superset.mcp_service.caching import create_response_caching_middleware
|
||||
|
||||
result = create_response_caching_middleware()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_create_response_caching_middleware_returns_none_when_no_prefix():
|
||||
"""Caching middleware returns None when CACHE_KEY_PREFIX is not set."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = {
|
||||
"enabled": True,
|
||||
"CACHE_KEY_PREFIX": None,
|
||||
}
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.flask_singleton.get_flask_app",
|
||||
return_value=mock_flask_app,
|
||||
):
|
||||
with patch("flask.has_app_context", return_value=True):
|
||||
from superset.mcp_service.caching import create_response_caching_middleware
|
||||
|
||||
result = create_response_caching_middleware()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_create_response_caching_middleware_creates_middleware():
|
||||
"""Caching middleware is created when properly configured."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = {
|
||||
"enabled": True,
|
||||
"CACHE_KEY_PREFIX": "mcp_cache_v1_",
|
||||
"list_tools_ttl": 300,
|
||||
}
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_middleware = MagicMock()
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.flask_singleton.get_flask_app",
|
||||
return_value=mock_flask_app,
|
||||
):
|
||||
with patch("flask.has_app_context", return_value=True):
|
||||
with patch(
|
||||
"superset.mcp_service.caching.get_mcp_store", return_value=mock_store
|
||||
):
|
||||
with patch(
|
||||
"fastmcp.server.middleware.caching.ResponseCachingMiddleware",
|
||||
return_value=mock_middleware,
|
||||
) as mock_middleware_class:
|
||||
from superset.mcp_service.caching import (
|
||||
create_response_caching_middleware,
|
||||
)
|
||||
|
||||
result = create_response_caching_middleware()
|
||||
|
||||
assert result is mock_middleware
|
||||
# Verify middleware was created with store and settings
|
||||
mock_middleware_class.assert_called_once()
|
||||
call_kwargs = mock_middleware_class.call_args[1]
|
||||
assert call_kwargs["cache_storage"] is mock_store
|
||||
assert call_kwargs["list_tools_settings"] == {"ttl": 300}
|
||||
96
tests/unit_tests/mcp_service/test_mcp_storage.py
Normal file
96
tests/unit_tests/mcp_service/test_mcp_storage.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for MCP storage factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_get_mcp_store_returns_none_when_disabled():
|
||||
"""Storage returns None when MCP_STORE_CONFIG.enabled is False."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = {"enabled": False}
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.flask_singleton.get_flask_app",
|
||||
return_value=mock_flask_app,
|
||||
):
|
||||
with patch("flask.has_app_context", return_value=True):
|
||||
from superset.mcp_service.storage import get_mcp_store
|
||||
|
||||
result = get_mcp_store(prefix="test_")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_mcp_store_returns_none_when_no_redis_url():
|
||||
"""Storage returns None when CACHE_REDIS_URL is not configured."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = {
|
||||
"enabled": True,
|
||||
"CACHE_REDIS_URL": None,
|
||||
"WRAPPER_TYPE": "key_value.aio.wrappers.prefix_keys.PrefixKeysWrapper",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.flask_singleton.get_flask_app",
|
||||
return_value=mock_flask_app,
|
||||
):
|
||||
with patch("flask.has_app_context", return_value=True):
|
||||
from superset.mcp_service.storage import get_mcp_store
|
||||
|
||||
result = get_mcp_store(prefix="test_")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_mcp_store_creates_store_when_enabled():
|
||||
"""Storage creates wrapped RedisStore when properly configured."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = {
|
||||
"enabled": True,
|
||||
"CACHE_REDIS_URL": "redis://localhost:6379/0",
|
||||
"WRAPPER_TYPE": "key_value.aio.wrappers.prefix_keys.PrefixKeysWrapper",
|
||||
}
|
||||
|
||||
mock_redis_store = MagicMock()
|
||||
mock_wrapper_instance = MagicMock()
|
||||
mock_wrapper_class = MagicMock(return_value=mock_wrapper_instance)
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.flask_singleton.get_flask_app",
|
||||
return_value=mock_flask_app,
|
||||
):
|
||||
with patch("flask.has_app_context", return_value=True):
|
||||
with patch(
|
||||
"superset.mcp_service.storage._import_wrapper_class",
|
||||
return_value=mock_wrapper_class,
|
||||
):
|
||||
with patch(
|
||||
"key_value.aio.stores.redis.RedisStore",
|
||||
return_value=mock_redis_store,
|
||||
):
|
||||
from superset.mcp_service.storage import get_mcp_store
|
||||
|
||||
result = get_mcp_store(prefix="test_prefix_")
|
||||
|
||||
# Verify store was created
|
||||
assert result is mock_wrapper_instance
|
||||
# Verify wrapper was called with correct args
|
||||
mock_wrapper_class.assert_called_once_with(
|
||||
key_value=mock_redis_store, prefix="test_prefix_"
|
||||
)
|
||||
@@ -356,6 +356,7 @@ class TestParseRequestDecorator:
|
||||
|
||||
def test_decorator_with_json_string_async(self):
|
||||
"""Should parse JSON string request in async function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
async def async_tool(request, ctx=None):
|
||||
@@ -363,21 +364,27 @@ class TestParseRequestDecorator:
|
||||
|
||||
import asyncio
|
||||
|
||||
result = asyncio.run(async_tool('{"name": "test", "count": 5}'))
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(async_tool('{"name": "test", "count": 5}'))
|
||||
assert result == "test:5"
|
||||
|
||||
def test_decorator_with_json_string_sync(self):
|
||||
"""Should parse JSON string request in sync function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
def sync_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}"
|
||||
|
||||
result = sync_tool('{"name": "test", "count": 5}')
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = sync_tool('{"name": "test", "count": 5}')
|
||||
assert result == "test:5"
|
||||
|
||||
def test_decorator_with_dict_async(self):
|
||||
"""Should handle dict request in async function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
async def async_tool(request, ctx=None):
|
||||
@@ -385,21 +392,27 @@ class TestParseRequestDecorator:
|
||||
|
||||
import asyncio
|
||||
|
||||
result = asyncio.run(async_tool({"name": "test", "count": 5}))
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(async_tool({"name": "test", "count": 5}))
|
||||
assert result == "test:5"
|
||||
|
||||
def test_decorator_with_dict_sync(self):
|
||||
"""Should handle dict request in sync function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
def sync_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}"
|
||||
|
||||
result = sync_tool({"name": "test", "count": 5})
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = sync_tool({"name": "test", "count": 5})
|
||||
assert result == "test:5"
|
||||
|
||||
def test_decorator_with_model_instance_async(self):
|
||||
"""Should pass through model instance in async function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
async def async_tool(request, ctx=None):
|
||||
@@ -407,23 +420,29 @@ class TestParseRequestDecorator:
|
||||
|
||||
import asyncio
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
instance = self.RequestModel(name="test", count=5)
|
||||
result = asyncio.run(async_tool(instance))
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(async_tool(instance))
|
||||
assert result == "test:5"
|
||||
|
||||
def test_decorator_with_model_instance_sync(self):
|
||||
"""Should pass through model instance in sync function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
def sync_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
instance = self.RequestModel(name="test", count=5)
|
||||
result = sync_tool(instance)
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = sync_tool(instance)
|
||||
assert result == "test:5"
|
||||
|
||||
def test_decorator_preserves_function_signature_async(self):
|
||||
"""Should preserve original async function signature."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
async def async_tool(request, ctx=None, extra=None):
|
||||
@@ -431,23 +450,29 @@ class TestParseRequestDecorator:
|
||||
|
||||
import asyncio
|
||||
|
||||
result = asyncio.run(
|
||||
async_tool('{"name": "test", "count": 5}', ctx=None, extra="data")
|
||||
)
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(
|
||||
async_tool('{"name": "test", "count": 5}', extra="data")
|
||||
)
|
||||
assert result == "test:5:data"
|
||||
|
||||
def test_decorator_preserves_function_signature_sync(self):
|
||||
"""Should preserve original sync function signature."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
def sync_tool(request, ctx=None, extra=None):
|
||||
return f"{request.name}:{request.count}:{extra}"
|
||||
|
||||
result = sync_tool('{"name": "test", "count": 5}', ctx=None, extra="data")
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = sync_tool('{"name": "test", "count": 5}', extra="data")
|
||||
assert result == "test:5:data"
|
||||
|
||||
def test_decorator_raises_validation_error_async(self):
|
||||
"""Should raise ValidationError for invalid data in async function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
async def async_tool(request, ctx=None):
|
||||
@@ -455,21 +480,27 @@ class TestParseRequestDecorator:
|
||||
|
||||
import asyncio
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
asyncio.run(async_tool('{"name": "test"}')) # Missing count
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
with pytest.raises(ValidationError):
|
||||
asyncio.run(async_tool('{"name": "test"}')) # Missing count
|
||||
|
||||
def test_decorator_raises_validation_error_sync(self):
|
||||
"""Should raise ValidationError for invalid data in sync function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
def sync_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
sync_tool('{"name": "test"}') # Missing count
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
with pytest.raises(ValidationError):
|
||||
sync_tool('{"name": "test"}') # Missing count
|
||||
|
||||
def test_decorator_with_complex_model_async(self):
|
||||
"""Should handle complex nested models in async function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
class NestedModel(BaseModel):
|
||||
"""Nested model."""
|
||||
@@ -488,12 +519,15 @@ class TestParseRequestDecorator:
|
||||
|
||||
import asyncio
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
json_str = '{"name": "test", "nested": {"value": 42}}'
|
||||
result = asyncio.run(async_tool(json_str))
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(async_tool(json_str))
|
||||
assert result == "test:42"
|
||||
|
||||
def test_decorator_with_complex_model_sync(self):
|
||||
"""Should handle complex nested models in sync function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
class NestedModel(BaseModel):
|
||||
"""Nested model."""
|
||||
@@ -510,6 +544,8 @@ class TestParseRequestDecorator:
|
||||
def sync_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.nested.value}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
json_str = '{"name": "test", "nested": {"value": 42}}'
|
||||
result = sync_tool(json_str)
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = sync_tool(json_str)
|
||||
assert result == "test:42"
|
||||
|
||||
Reference in New Issue
Block a user