mirror of
https://github.com/apache/superset.git
synced 2026-05-03 06:54:19 +00:00
Compare commits
1 Commits
feat/toolt
...
mcp-loggin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dface257b8 |
@@ -128,8 +128,18 @@ class LoggingMiddleware(Middleware):
|
||||
Tool calls are handled in on_call_tool() which wraps execution to capture
|
||||
duration_ms. Non-tool messages (resource reads, prompts, etc.) are handled
|
||||
in on_message().
|
||||
|
||||
When tool search is enabled (progressive discovery), the MCP client calls
|
||||
``call_tool`` / ``search_tools`` proxies instead of individual tools. This
|
||||
middleware resolves the underlying tool name from ``call_tool`` arguments so
|
||||
that analytics queries can filter by the actual tool (stored as ``mcp_tool``
|
||||
in the curated payload).
|
||||
"""
|
||||
|
||||
#: Default proxy names used by FastMCP tool-search transforms.
|
||||
_CALL_TOOL_PROXY = "call_tool"
|
||||
_SEARCH_TOOLS_PROXY = "search_tools"
|
||||
|
||||
def _extract_context_info(
|
||||
self, context: MiddlewareContext
|
||||
) -> tuple[
|
||||
@@ -156,6 +166,28 @@ class LoggingMiddleware(Middleware):
|
||||
dataset_id = params.get("dataset_id")
|
||||
return agent_id, user_id, dashboard_id, slice_id, dataset_id, params
|
||||
|
||||
@staticmethod
|
||||
def _resolve_tool_name(tool_name: str | None, params: dict[str, Any]) -> str | None:
|
||||
"""Resolve the underlying tool name from call_tool proxy arguments.
|
||||
|
||||
When tool search is enabled, the MCP client uses the ``call_tool``
|
||||
proxy and passes the real tool name as the ``name`` argument. This
|
||||
helper extracts that value so we can log which tool was actually
|
||||
executed rather than just ``"call_tool"``.
|
||||
|
||||
Returns:
|
||||
The resolved tool name if *tool_name* is the call_tool proxy and
|
||||
``params["name"]`` is a non-empty string, otherwise ``None``.
|
||||
"""
|
||||
if (
|
||||
tool_name == LoggingMiddleware._CALL_TOOL_PROXY
|
||||
and isinstance(params, dict)
|
||||
and isinstance(params.get("name"), str)
|
||||
and params["name"]
|
||||
):
|
||||
return params["name"]
|
||||
return None
|
||||
|
||||
async def on_call_tool(
|
||||
self,
|
||||
context: MiddlewareContext,
|
||||
@@ -166,15 +198,34 @@ class LoggingMiddleware(Middleware):
|
||||
self._extract_context_info(context)
|
||||
)
|
||||
tool_name = getattr(context.message, "name", None)
|
||||
mcp_tool = self._resolve_tool_name(tool_name, params)
|
||||
|
||||
start_time = time.time()
|
||||
success = False
|
||||
error_type: str | None = None
|
||||
try:
|
||||
result = await call_next(context)
|
||||
success = True
|
||||
return result
|
||||
except Exception as exc:
|
||||
error_type = type(exc).__name__
|
||||
raise
|
||||
finally:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
payload: dict[str, Any] = {
|
||||
"tool": tool_name,
|
||||
"agent_id": agent_id,
|
||||
"params": _sanitize_params(params),
|
||||
"method": context.method,
|
||||
"dashboard_id": dashboard_id,
|
||||
"slice_id": slice_id,
|
||||
"dataset_id": dataset_id,
|
||||
"success": success,
|
||||
}
|
||||
if mcp_tool is not None:
|
||||
payload["mcp_tool"] = mcp_tool
|
||||
if error_type is not None:
|
||||
payload["error_type"] = error_type
|
||||
if has_app_context():
|
||||
event_logger.log(
|
||||
user_id=user_id,
|
||||
@@ -183,22 +234,14 @@ class LoggingMiddleware(Middleware):
|
||||
duration_ms=duration_ms,
|
||||
slice_id=slice_id,
|
||||
referrer=None,
|
||||
curated_payload={
|
||||
"tool": tool_name,
|
||||
"agent_id": agent_id,
|
||||
"params": _sanitize_params(params),
|
||||
"method": context.method,
|
||||
"dashboard_id": dashboard_id,
|
||||
"slice_id": slice_id,
|
||||
"dataset_id": dataset_id,
|
||||
"success": success,
|
||||
},
|
||||
curated_payload=payload,
|
||||
)
|
||||
logger.info(
|
||||
"MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, "
|
||||
"dashboard_id=%s, slice_id=%s, dataset_id=%s, duration_ms=%s, "
|
||||
"success=%s",
|
||||
"MCP tool call: tool=%s, mcp_tool=%s, agent_id=%s, user_id=%s, "
|
||||
"method=%s, dashboard_id=%s, slice_id=%s, dataset_id=%s, "
|
||||
"duration_ms=%s, success=%s, error_type=%s",
|
||||
tool_name,
|
||||
mcp_tool,
|
||||
agent_id,
|
||||
user_id,
|
||||
context.method,
|
||||
@@ -207,6 +250,7 @@ class LoggingMiddleware(Middleware):
|
||||
dataset_id,
|
||||
duration_ms,
|
||||
success,
|
||||
error_type,
|
||||
)
|
||||
|
||||
async def on_message(
|
||||
|
||||
@@ -20,6 +20,8 @@ Unit tests for LoggingMiddleware on_call_tool() and on_message() methods.
|
||||
|
||||
Tests verify that:
|
||||
- on_call_tool() captures duration_ms and success status
|
||||
- on_call_tool() resolves call_tool proxy to actual tool name (mcp_tool)
|
||||
- on_call_tool() captures error_type on failure
|
||||
- on_message() logs non-tool messages without duration
|
||||
- _extract_context_info() extracts entity IDs from params
|
||||
"""
|
||||
@@ -88,7 +90,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
async def test_on_call_tool_logs_failure_on_exception(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""on_call_tool records success=False when tool raises."""
|
||||
"""on_call_tool records success=False and error_type when tool raises."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="execute_sql")
|
||||
call_next = AsyncMock(side_effect=ValueError("boom"))
|
||||
@@ -100,8 +102,25 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
mock_event_logger.log.assert_called_once()
|
||||
call_kwargs = mock_event_logger.log.call_args[1]
|
||||
assert call_kwargs["curated_payload"]["success"] is False
|
||||
assert call_kwargs["curated_payload"]["error_type"] == "ValueError"
|
||||
assert call_kwargs["duration_ms"] >= 0
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_no_error_type_on_success(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""on_call_tool omits error_type from payload on success."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
call_next = AsyncMock(return_value="ok")
|
||||
|
||||
await middleware.on_call_tool(ctx, call_next)
|
||||
|
||||
payload = mock_event_logger.log.call_args[1]["curated_payload"]
|
||||
assert "error_type" not in payload
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@pytest.mark.asyncio
|
||||
@@ -127,6 +146,66 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
assert call_kwargs["slice_id"] == 20
|
||||
assert call_kwargs["curated_payload"]["dataset_id"] == 30
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_resolves_call_tool_proxy(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""call_tool proxy is resolved to the actual tool name via mcp_tool."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(
|
||||
name="call_tool",
|
||||
params={"name": "list_datasets", "arguments": {"page": 1}},
|
||||
)
|
||||
call_next = AsyncMock(return_value="datasets")
|
||||
|
||||
await middleware.on_call_tool(ctx, call_next)
|
||||
|
||||
payload = mock_event_logger.log.call_args[1]["curated_payload"]
|
||||
assert payload["tool"] == "call_tool"
|
||||
assert payload["mcp_tool"] == "list_datasets"
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_no_mcp_tool_for_direct_calls(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""Direct tool calls (not via proxy) omit mcp_tool from payload."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
call_next = AsyncMock(return_value="charts")
|
||||
|
||||
await middleware.on_call_tool(ctx, call_next)
|
||||
|
||||
payload = mock_event_logger.log.call_args[1]["curated_payload"]
|
||||
assert payload["tool"] == "list_charts"
|
||||
assert "mcp_tool" not in payload
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_proxy_failure_captures_both_fields(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""call_tool proxy failure captures mcp_tool and error_type."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(
|
||||
name="call_tool",
|
||||
params={"name": "get_chart_data", "arguments": {"chart_id": 1}},
|
||||
)
|
||||
call_next = AsyncMock(side_effect=PermissionError("access denied"))
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
await middleware.on_call_tool(ctx, call_next)
|
||||
|
||||
payload = mock_event_logger.log.call_args[1]["curated_payload"]
|
||||
assert payload["tool"] == "call_tool"
|
||||
assert payload["mcp_tool"] == "get_chart_data"
|
||||
assert payload["success"] is False
|
||||
assert payload["error_type"] == "PermissionError"
|
||||
|
||||
|
||||
class TestLoggingMiddlewareOnMessage:
|
||||
"""Tests for LoggingMiddleware.on_message()."""
|
||||
@@ -155,6 +234,42 @@ class TestLoggingMiddlewareOnMessage:
|
||||
assert "success" not in call_kwargs["curated_payload"]
|
||||
|
||||
|
||||
class TestResolveToolName:
|
||||
"""Tests for LoggingMiddleware._resolve_tool_name()."""
|
||||
|
||||
def test_resolves_call_tool_proxy(self):
|
||||
"""Returns the real tool name when call_tool proxy is used."""
|
||||
assert (
|
||||
LoggingMiddleware._resolve_tool_name(
|
||||
"call_tool", {"name": "list_datasets", "arguments": {}}
|
||||
)
|
||||
== "list_datasets"
|
||||
)
|
||||
|
||||
def test_returns_none_for_direct_tool(self):
|
||||
"""Returns None for direct tool calls (not via proxy)."""
|
||||
assert LoggingMiddleware._resolve_tool_name("list_charts", {"page": 1}) is None
|
||||
|
||||
def test_returns_none_when_name_missing(self):
|
||||
"""Returns None when call_tool params lack 'name'."""
|
||||
assert LoggingMiddleware._resolve_tool_name("call_tool", {"foo": "bar"}) is None
|
||||
|
||||
def test_returns_none_for_empty_name(self):
|
||||
"""Returns None when call_tool params have empty 'name'."""
|
||||
assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": ""}) is None
|
||||
|
||||
def test_returns_none_for_non_string_name(self):
|
||||
"""Returns None when call_tool name param is not a string."""
|
||||
assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": 123}) is None
|
||||
|
||||
def test_returns_none_for_search_tools(self):
|
||||
"""search_tools proxy is not resolved (no underlying tool name)."""
|
||||
assert (
|
||||
LoggingMiddleware._resolve_tool_name("search_tools", {"query": "datasets"})
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
class TestExtractContextInfo:
|
||||
"""Tests for LoggingMiddleware._extract_context_info()."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user