mirror of
https://github.com/apache/superset.git
synced 2026-05-07 17:04:58 +00:00
834 lines · 31 KiB · Python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
|
"""
|
|
MCP server for Apache Superset
|
|
|
|
Supports both single-pod (in-memory) and multi-pod (Redis) deployments.
|
|
For multi-pod deployments, configure MCP_EVENT_STORE_CONFIG with Redis URL.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
from collections.abc import Sequence
|
|
from typing import Annotated, Any, Callable
|
|
|
|
import uvicorn
|
|
from fastmcp.exceptions import ToolError
|
|
from fastmcp.server.middleware import Middleware
|
|
|
|
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
|
|
from superset.mcp_service.mcp_config import (
|
|
get_mcp_factory_config,
|
|
MCP_STORE_CONFIG,
|
|
MCP_TOOL_SEARCH_CONFIG,
|
|
)
|
|
from superset.mcp_service.middleware import (
|
|
create_response_size_guard_middleware,
|
|
GlobalErrorHandlerMiddleware,
|
|
LoggingMiddleware,
|
|
StructuredContentStripperMiddleware,
|
|
)
|
|
from superset.mcp_service.privacy import (
|
|
tool_requires_data_model_metadata_access,
|
|
user_can_view_data_model_metadata,
|
|
)
|
|
from superset.mcp_service.storage import _create_redis_store
|
|
from superset.utils import json
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _suppress_third_party_warnings() -> None:
|
|
"""Suppress known third-party deprecation warnings from MCP responses.
|
|
|
|
The MCP SDK captures Python warnings and forwards them to clients via
|
|
``mcp.server.lowlevel.server:Warning:`` log entries. This wastes LLM
|
|
tokens and causes clients to try to "fix" irrelevant internal warnings.
|
|
|
|
Suppressed warnings:
|
|
- marshmallow ``RemovedInMarshmallow4Warning`` (triggered during
|
|
database engine schema instantiation)
|
|
- google.api_core ``FutureWarning`` (Python version support notices)
|
|
"""
|
|
import warnings
|
|
|
|
warnings.filterwarnings(
|
|
"ignore",
|
|
category=DeprecationWarning,
|
|
module=r"marshmallow\..*",
|
|
)
|
|
warnings.filterwarnings(
|
|
"ignore",
|
|
category=FutureWarning,
|
|
module=r"google\..*",
|
|
)
|
|
|
|
|
|
class FastMCPValidationFilter(logging.Filter):
    """Logging filter that downgrades FastMCP user-error logs to WARNING.

    FastMCP's server.py calls ``logger.exception()`` for ValidationError and
    ToolError before our GlobalErrorHandlerMiddleware runs, emitting them at
    ERROR level. Those are user errors (bad params from the LLM, access
    denied, not found) and are expected during normal MCP operation, so they
    should not show up as ERRORs in Datadog.

    Only the "Error validating tool" message is downgraded: it always wraps
    a Pydantic ValidationError (bad LLM params). "Error calling tool" stays
    at ERROR because our middleware wraps both user and system errors in
    ToolError, so exception type alone can't tell them apart.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # NOTE: Matches the literal log message emitted by FastMCP's
        # server.py (fastmcp/server/server.py, around line 1245). If FastMCP
        # ever rewords it, this filter silently stops downgrading.
        is_validation_error = (
            record.levelno == logging.ERROR
            and "Error validating tool" in record.getMessage()
        )
        if is_validation_error:
            record.levelno = logging.WARNING
            record.levelname = "WARNING"
        # Never drop records — this filter only rewrites severity.
        return True
|
|
|
|
|
def configure_logging(debug: bool = False) -> None:
    """Set up logging for the MCP service process.

    When *debug* is true (or ``SQLALCHEMY_DEBUG`` is set in the
    environment), enables INFO-level SQL logging without clobbering any
    handler configuration that a ``logging.ini`` already installed.
    Always attaches :class:`FastMCPValidationFilter` to FastMCP's server
    logger so user-input validation failures log as WARNING, not ERROR.
    """
    import sys

    if debug or os.environ.get("SQLALCHEMY_DEBUG"):
        # Respect an existing logging.ini: only install a basic handler
        # when the root logger has none configured yet.
        if not logging.getLogger().handlers:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                stream=sys.stderr,  # stderr only — stdout must stay clean
            )

        # Bump SQLAlchemy loggers to INFO unless something already tuned
        # them away from the defaults (WARNING or unset).
        for name in (
            "sqlalchemy.engine",
            "sqlalchemy.pool",
            "sqlalchemy.dialects",
        ):
            sa_logger = logging.getLogger(name)
            if sa_logger.level in (logging.WARNING, logging.NOTSET):
                sa_logger.setLevel(logging.INFO)

        # Use logging instead of print to avoid stdout contamination
        logging.info("🔍 SQL Debug logging enabled")

    # FastMCP's server.py logs ValidationError/ToolError at ERROR via
    # logger.exception() before our middleware sees it. These are user
    # errors (bad params from LLM); downgrade them from ERROR to WARNING.
    logging.getLogger("fastmcp.server.server").addFilter(FastMCPValidationFilter())
|
|
|
|
|
|
def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
|
|
"""
|
|
Create an EventStore for MCP session management.
|
|
|
|
For multi-pod deployments, uses Redis-backed storage to share session state
|
|
across pods. For single-pod deployments, returns None (uses in-memory).
|
|
|
|
Args:
|
|
config: Optional config dict. If None, reads from MCP_STORE_CONFIG.
|
|
|
|
Returns:
|
|
EventStore instance if Redis URL is configured, None otherwise.
|
|
"""
|
|
if config is None:
|
|
config = MCP_STORE_CONFIG
|
|
|
|
if not config.get("CACHE_REDIS_URL"):
|
|
logging.info("EventStore: Using in-memory storage (single-pod mode)")
|
|
return None
|
|
|
|
try:
|
|
from fastmcp.server.event_store import EventStore
|
|
|
|
# Get prefix from config (allows Preset to customize for multi-tenancy)
|
|
# Default prefix prevents key collisions in shared Redis environments
|
|
prefix = config.get("event_store_prefix", "mcp_events_")
|
|
|
|
# Create wrapped Redis store with prefix for key namespacing
|
|
redis_store = _create_redis_store(config, prefix=prefix, wrap=True)
|
|
if redis_store is None:
|
|
logging.warning("Failed to create Redis store, falling back to in-memory")
|
|
return None
|
|
|
|
# Create EventStore with Redis backend
|
|
event_store = EventStore(
|
|
storage=redis_store,
|
|
max_events_per_stream=config.get("event_store_max_events", 100),
|
|
ttl=config.get("event_store_ttl", 3600),
|
|
)
|
|
|
|
logging.info("EventStore: Using Redis storage (multi-pod mode)")
|
|
return event_store
|
|
|
|
except ImportError as e:
|
|
logging.error(
|
|
"Failed to import EventStore dependencies: %s. "
|
|
"Ensure fastmcp package is installed.",
|
|
e,
|
|
)
|
|
return None
|
|
except Exception as e:
|
|
logging.error("Failed to create Redis EventStore: %s", e)
|
|
return None
|
|
|
|
|
|
def _strip_titles(obj: Any, in_properties_map: bool = False) -> Any:
|
|
"""Recursively strip schema metadata ``title`` keys.
|
|
|
|
Keeps real field names inside ``properties`` (e.g. a property literally
|
|
named ``title``), while removing auto-generated schema title metadata.
|
|
"""
|
|
if isinstance(obj, dict):
|
|
result: dict[str, Any] = {}
|
|
for key, value in obj.items():
|
|
if key == "title" and not in_properties_map:
|
|
continue
|
|
result[key] = _strip_titles(value, in_properties_map=(key == "properties"))
|
|
return result
|
|
if isinstance(obj, list):
|
|
return [_strip_titles(item, in_properties_map=False) for item in obj]
|
|
return obj
|
|
|
|
|
|
def _simplify_optional_union(result: dict[str, Any]) -> dict[str, Any]:
|
|
"""Collapse ``anyOf``/``oneOf`` with exactly one non-null variant.
|
|
|
|
Pydantic encodes ``Optional[X]`` as ``{"anyOf": [<X>, {"type": "null"}]}``.
|
|
This replaces the union with the non-null variant while preserving any
|
|
``description`` or ``default`` from the parent node.
|
|
"""
|
|
for union_key in ("anyOf", "oneOf"):
|
|
variants = result.get(union_key)
|
|
if not isinstance(variants, list) or len(variants) != 2:
|
|
continue
|
|
non_null = [v for v in variants if v.get("type") != "null"]
|
|
if len(non_null) != 1:
|
|
continue
|
|
simplified = dict(non_null[0])
|
|
for keep in ("description", "default"):
|
|
if keep in result and keep not in simplified:
|
|
simplified[keep] = result[keep]
|
|
result.pop(union_key)
|
|
result.pop("description", None)
|
|
result.pop("default", None)
|
|
result.update(simplified)
|
|
return result
|
|
|
|
|
|
def _resolve_ref(
    obj: dict[str, Any],
    defs: dict[str, Any],
    resolving: frozenset[str],
) -> Any:
    """Inline the definition a ``$ref`` pointer targets, from *defs*.

    When the target definition is missing, or inlining it would recurse
    into a definition already being resolved (circular reference), a
    generic ``{"type": "object"}`` stands in. Any ``description`` on the
    referencing node is carried over either way.
    """
    pointer: str = obj["$ref"]
    target = pointer.rsplit("/", 1)[-1] if "/" in pointer else ""
    parent_desc = obj.get("description")

    definition = defs.get(target) if target else None
    if definition is not None and target not in resolving:
        # Compact the referenced definition itself, tracking it in the
        # resolving set so a self-referential schema cannot loop forever.
        expanded = _compact_schema(
            definition,
            _defs=defs,
            _resolving=resolving | {target},
        )
        if isinstance(expanded, dict):
            if parent_desc:
                expanded.setdefault("description", parent_desc)
            return expanded

    fallback: dict[str, Any] = {"type": "object"}
    if parent_desc:
        fallback["description"] = parent_desc
    return fallback
|
|
|
|
|
|
def _compact_schema(
    obj: Any,
    *,
    _defs: dict[str, Any] | None = None,
    _resolving: frozenset[str] | None = None,
) -> Any:
    """Flatten ``$defs``/``$ref`` structure out of a JSON Schema.

    Search results only need enough schema detail for the LLM to pick a
    tool and build a basic invocation; full schemas are still served when
    the tool is invoked through ``call_tool``.

    Transformations:

    * ``$defs`` sections are dropped.
    * ``$ref`` pointers are replaced with their inlined definition, or
      with ``{"type": "object"}`` when missing/circular.
    * Pydantic Optional unions (``anyOf``/``oneOf`` of X and null) are
      collapsed to the non-null variant.
    """
    if isinstance(obj, list):
        return [
            _compact_schema(entry, _defs=_defs, _resolving=_resolving)
            for entry in obj
        ]
    if not isinstance(obj, dict):
        return obj

    # First (top-level) dict call: capture $defs for later $ref resolution.
    defs = obj.get("$defs", {}) if _defs is None else _defs
    resolving = frozenset() if _resolving is None else _resolving

    if "$ref" in obj:
        return _resolve_ref(obj, defs, resolving)

    compacted = {
        key: _compact_schema(value, _defs=defs, _resolving=resolving)
        for key, value in obj.items()
        if key != "$defs"
    }
    return _simplify_optional_union(compacted)
|
|
|
|
|
|
def _truncate_description(text: str, max_length: int) -> str:
|
|
"""Truncate a tool description for search results.
|
|
|
|
Cuts at the last sentence boundary before *max_length*, or at
|
|
*max_length* with an ellipsis if no sentence boundary is found.
|
|
"""
|
|
if not text or len(text) <= max_length:
|
|
return text
|
|
# Try to cut at the last sentence boundary
|
|
truncated = text[:max_length]
|
|
last_period = truncated.rfind(". ")
|
|
if last_period > max_length // 2:
|
|
return truncated[: last_period + 1]
|
|
return truncated.rstrip() + "..."
|
|
|
|
|
|
def _extract_parameter_names(input_schema: dict[str, Any]) -> str:
|
|
"""Extract top-level parameter names from a JSON Schema as a hint string.
|
|
|
|
Returns a comma-separated string of property names from the schema's
|
|
``properties`` key, or an empty string if none are found.
|
|
|
|
Example: ``"page, page_size, search, filters, select_columns"``
|
|
"""
|
|
properties = input_schema.get("properties", {})
|
|
if not properties:
|
|
return ""
|
|
return ", ".join(properties.keys())
|
|
|
|
|
|
def _serialize_tools_without_output_schema(
    tools: Sequence[Any],
) -> list[dict[str, Any]]:
    """Serialize tools to JSON dicts without outputSchema or title bloat.

    LLMs only need ``inputSchema`` to call a tool; ``outputSchema`` is
    50-80% of each tool's schema payload and auto-generated ``title``
    fields add roughly another 12%, so both are stripped to cut search
    result token usage.
    """
    serialized: list[dict[str, Any]] = []
    for tool in tools:
        payload = tool.to_mcp_tool().model_dump(
            mode="json", exclude_none=True, exclude={"outputSchema"}
        )
        # Belt-and-suspenders: drop outputSchema even if exclude missed it.
        payload.pop("outputSchema", None)
        schema = payload.get("inputSchema")
        if schema:
            payload["inputSchema"] = _strip_titles(schema)
        serialized.append(payload)
    return serialized
|
|
|
|
|
|
def _build_summary_serializer(max_desc: int) -> Any:
    """Return a summary-mode tool serializer with no ``inputSchema``.

    The returned callable reduces each tool to ``name``, an optionally
    truncated ``description``, and a ``parameters_hint`` string of its
    top-level parameter names; both ``inputSchema`` and ``outputSchema``
    are removed entirely.
    """

    def _summary_serializer(tools: Sequence[Any]) -> list[dict[str, Any]]:
        summaries: list[dict[str, Any]] = []
        for tool in tools:
            entry = tool.to_mcp_tool().model_dump(
                mode="json", exclude_none=True, exclude={"outputSchema"}
            )
            # Belt-and-suspenders: outputSchema should already be excluded.
            entry.pop("outputSchema", None)
            # Pop inputSchema unconditionally; keep only a names hint.
            schema = entry.pop("inputSchema", None)
            if schema:
                param_hint = _extract_parameter_names(schema)
                if param_hint:
                    entry["parameters_hint"] = param_hint
            description = entry.get("description")
            if max_desc and description:
                entry["description"] = _truncate_description(description, max_desc)
            summaries.append(entry)
        return summaries

    return _summary_serializer
|
|
|
|
|
|
def _tool_allowed_for_current_user(tool: Any) -> bool:
    """Return whether the current Flask user can see this tool in search results."""
    # Deny-by-default: any failure to evaluate permissions (no Flask app
    # context, missing attributes, bad auth) hides the tool from search.
    try:
        from flask import current_app, g

        # RBAC disabled entirely — every tool is visible.
        if not current_app.config.get("MCP_RBAC_ENABLED", True):
            return True

        from superset import security_manager
        from superset.mcp_service.auth import (
            CLASS_PERMISSION_ATTR,
            get_user_from_request,
            METHOD_PERMISSION_ATTR,
            PERMISSION_PREFIX,
        )

        # The underlying Python function carries the permission attributes
        # (set by the auth decorators); fall back to None if absent.
        tool_func = getattr(tool, "fn", None)
        # Tools touching data-model metadata need the dedicated privacy check.
        if tool_requires_data_model_metadata_access(tool_func) and not (
            user_can_view_data_model_metadata()
        ):
            return False

        # No class-level permission declared → tool is unrestricted.
        class_permission_name = getattr(tool_func, CLASS_PERMISSION_ATTR, None)
        if not class_permission_name:
            return True

        # Lazily resolve the user from the request when Flask's g has none;
        # an unauthenticated request means the tool stays hidden.
        if not getattr(g, "user", None):
            try:
                g.user = get_user_from_request()
            except ValueError:
                return False

        # Method-level permission defaults to "read" when not declared.
        method_permission_name = getattr(tool_func, METHOD_PERMISSION_ATTR, "read")
        permission_name = f"{PERMISSION_PREFIX}{method_permission_name}"
        return security_manager.can_access(permission_name, class_permission_name)
    except (AttributeError, RuntimeError, ValueError):
        # RuntimeError typically means no Flask app/request context.
        logger.debug("Could not evaluate tool search permission", exc_info=True)
        return False
|
|
|
|
|
|
def _filter_tools_by_current_user_permission(tools: Sequence[Any]) -> list[Any]:
    """Keep only the search candidates the current user may execute."""
    return list(filter(_tool_allowed_for_current_user, tools))
|
|
|
|
|
|
def _create_search_result_serializer(
    config: dict[str, Any],
) -> Any:
    """Pick and build a search-result serializer from the tool-search config.

    ``include_schemas`` False (the default) yields the summary serializer
    from :func:`_build_summary_serializer`: no ``inputSchema`` at all, just
    a ``parameters_hint`` of comma-separated top-level parameter names.
    That cuts per-search token cost by ~80% versus compact mode while still
    conveying what a tool accepts.

    ``include_schemas`` True keeps the original ``compact_schemas`` /
    ``max_description_length`` pipeline:

    * ``compact_schemas`` True collapses ``$defs``/``$ref`` via
      :func:`_compact_schema`.
    * Descriptions are truncated to ``max_description_length`` characters.

    Full schemas always remain available through ``call_tool``.
    """
    if not config.get("include_schemas", False):
        return _build_summary_serializer(config.get("max_description_length", 300))

    # include_schemas=True: full compact/truncate pipeline.
    compact = config.get("compact_schemas", True)
    # Truncation defaults to 300 chars when compaction is on, and is off
    # otherwise unless the config sets it explicitly.
    max_desc = config.get("max_description_length", 300 if compact else 0)

    if not compact and not max_desc:
        # Nothing to post-process — use the base serializer directly.
        return _serialize_tools_without_output_schema

    def _serializer(tools: Sequence[Any]) -> list[dict[str, Any]]:
        entries = _serialize_tools_without_output_schema(tools)
        for entry in entries:
            if compact:
                schema = entry.get("inputSchema")
                if schema:
                    entry["inputSchema"] = _compact_schema(schema)
            description = entry.get("description")
            if max_desc and description:
                entry["description"] = _truncate_description(description, max_desc)
        return entries

    return _serializer
|
|
|
|
|
|
def _fix_call_tool_arguments(tool: Any) -> Any:
|
|
"""Fix anyOf schema in call_tool ``arguments`` for MCP bridge compatibility.
|
|
|
|
FastMCP's BaseSearchTransform defines ``arguments`` as
|
|
``dict[str, Any] | None`` which emits an ``anyOf`` JSON Schema.
|
|
Some MCP bridges (mcp-remote, Claude Desktop) don't handle ``anyOf``
|
|
and strip it, leaving the field without a ``type`` — causing all
|
|
call_tool invocations to fail with "Input should be a valid dictionary".
|
|
|
|
Replaces the ``anyOf`` with a flat ``type: object``.
|
|
"""
|
|
if "arguments" in (props := (tool.parameters or {}).get("properties", {})):
|
|
props["arguments"] = {
|
|
"additionalProperties": True,
|
|
"default": None,
|
|
"description": "Arguments to pass to the tool",
|
|
"type": "object",
|
|
}
|
|
return tool
|
|
|
|
|
|
def _normalize_call_tool_arguments(
|
|
arguments: dict[str, Any] | None,
|
|
tool_schema: dict[str, Any] | None,
|
|
) -> dict[str, Any] | None:
|
|
"""JSON-serialize dict/list values when the tool schema accepts both
|
|
string and object variants (anyOf or oneOf with a string type).
|
|
|
|
When the BM25/regex ``call_tool`` proxy forwards arguments to the
|
|
actual tool, dict/list values must be serialized if the tool's schema
|
|
declares ``anyOf``/``oneOf`` with a string variant
|
|
(e.g. ``request: str | RequestModel``).
|
|
|
|
Without this, the MCP transport calls ``bytes(dict, 'utf-8')``
|
|
which raises ``TypeError: encoding without a string argument``.
|
|
"""
|
|
if not arguments or not isinstance(tool_schema, dict):
|
|
return arguments
|
|
|
|
properties = tool_schema.get("properties", {})
|
|
result = dict(arguments)
|
|
for key, value in result.items():
|
|
if not isinstance(value, (dict, list)) or key not in properties:
|
|
continue
|
|
prop_schema = properties[key]
|
|
variants = prop_schema.get("oneOf") or prop_schema.get("anyOf") or []
|
|
has_string = any(v.get("type") == "string" for v in variants)
|
|
if has_string:
|
|
result[key] = json.dumps(value)
|
|
return result
|
|
|
|
|
|
def _apply_tool_search_transform(mcp_instance: Any, config: dict[str, Any]) -> None:
    """Apply tool search transform to reduce initial context size.

    When enabled, replaces the full tool catalog with a search interface.
    LLMs see only synthetic search/call tools plus pinned tools, and
    discover other tools on-demand via natural language search.

    Uses subclassing (not monkey-patching) to override ``_make_call_tool``
    and fix the ``arguments`` schema for MCP bridge compatibility, and
    normalize forwarded arguments to prevent encoding errors.

    NOTE: ``_make_call_tool`` is a private API in FastMCP 3.x
    (fastmcp>=3.1.0,<4.0). If FastMCP changes or removes this method
    in a future major version, these subclasses will need to be updated.
    """
    from fastmcp.server.context import Context
    from fastmcp.tools.tool import Tool, ToolResult

    strategy = config.get("strategy", "bm25")
    # Keyword arguments forwarded verbatim to the transform constructor.
    kwargs: dict[str, Any] = {
        "max_results": config.get("max_results", 5),
        "always_visible": config.get("always_visible", []),
        "search_tool_name": config.get("search_tool_name", "search_tools"),
        "call_tool_name": config.get("call_tool_name", "call_tool"),
        "search_result_serializer": _create_search_result_serializer(config),
    }

    def _make_normalizing_call_tool(transform: Any) -> Tool:
        """Create a call_tool proxy that normalizes arguments before forwarding.

        This fixes two issues:
        1. anyOf schema incompatibility with MCP bridges (schema fix).
        2. ``encoding without a string argument`` TypeError when dict/list
           values are forwarded for parameters declared as
           ``str | SomeModel`` (argument normalization).
        """

        # NOTE(review): the annotations below are introspected by
        # Tool.from_function to build the tool schema — do not change them
        # without checking the generated schema.
        async def call_tool(
            name: Annotated[str, "The name of the tool to call"],
            arguments: Annotated[
                dict[str, Any] | None, "Arguments to pass to the tool"
            ] = None,
            ctx: Context = None,
        ) -> ToolResult:
            """Call a tool by name with the given arguments.

            Use this to execute tools discovered via search_tools.
            """
            # Refuse recursion into the synthetic search/call tools.
            if name in {transform._call_tool_name, transform._search_tool_name}:
                raise ToolError(
                    f"'{name}' is a synthetic search tool and cannot be "
                    f"called via the call_tool proxy"
                )
            if arguments:
                # Look up the target tool's schema so string-or-object
                # parameters can be serialized before forwarding.
                target_tool = await ctx.fastmcp.get_tool(name)
                if target_tool is not None:
                    arguments = _normalize_call_tool_arguments(
                        arguments, target_tool.parameters
                    )
            return await ctx.fastmcp.call_tool(name, arguments)

        tool = Tool.from_function(fn=call_tool, name=transform._call_tool_name)
        # Flatten the anyOf schema on "arguments" for bridge compatibility.
        return _fix_call_tool_arguments(tool)

    transform = _create_search_transform(
        strategy=strategy,
        kwargs=kwargs,
        make_normalizing_call_tool=_make_normalizing_call_tool,
    )

    mcp_instance.add_transform(transform)
    logger.info(
        "Tool search transform enabled (strategy=%s, max_results=%d, pinned=%s)",
        strategy,
        kwargs["max_results"],
        kwargs["always_visible"],
    )
|
|
|
|
|
|
def _create_search_transform(
    *,
    strategy: str,
    kwargs: dict[str, Any],
    make_normalizing_call_tool: Callable[[Any], Any],
) -> Any:
    """Create the configured search transform with tool-permission filtering.

    Args:
        strategy: "regex" selects RegexSearchTransform; anything else
            falls through to BM25SearchTransform.
        kwargs: Constructor keyword arguments for the transform.
        make_normalizing_call_tool: Factory producing the fixed call_tool
            proxy, given the transform instance.
    """
    from fastmcp.server.context import Context

    if strategy == "regex":
        from fastmcp.server.transforms.search import RegexSearchTransform

        class _FixedRegexSearchTransform(RegexSearchTransform):
            """Regex search with fixed call_tool schema and arg normalization."""

            # Overrides a private FastMCP hook so search candidates are
            # RBAC-filtered per request.
            async def _get_visible_tools(self, ctx: Context) -> Sequence[Any]:
                """Return only tools visible to the current authenticated user."""
                tools = await super()._get_visible_tools(ctx)
                return _filter_tools_by_current_user_permission(tools)

            # Private FastMCP API (see _apply_tool_search_transform note).
            def _make_call_tool(self) -> Any:
                """Build the normalized ``call_tool`` proxy for regex search."""
                return make_normalizing_call_tool(self)

        return _FixedRegexSearchTransform(**kwargs)

    # Default strategy: BM25 ranking.
    from fastmcp.server.transforms.search import BM25SearchTransform

    class _FixedBM25SearchTransform(BM25SearchTransform):
        """BM25 search with fixed call_tool schema and arg normalization."""

        # Overrides a private FastMCP hook so search candidates are
        # RBAC-filtered per request.
        async def _get_visible_tools(self, ctx: Context) -> Sequence[Any]:
            """Return only tools visible to the current authenticated user."""
            tools = await super()._get_visible_tools(ctx)
            return _filter_tools_by_current_user_permission(tools)

        # Private FastMCP API (see _apply_tool_search_transform note).
        def _make_call_tool(self) -> Any:
            """Build the normalized ``call_tool`` proxy for BM25 search."""
            return make_normalizing_call_tool(self)

    return _FixedBM25SearchTransform(**kwargs)
|
|
|
|
|
|
def _create_auth_provider(flask_app: Any) -> Any | None:
    """Build an auth provider from the Flask app's configuration.

    ``MCP_AUTH_FACTORY`` takes precedence; otherwise the default factory
    is used when ``MCP_AUTH_ENABLED`` is set. Returns ``None`` when auth
    is not configured or the chosen factory fails.
    """
    auth_factory = flask_app.config.get("MCP_AUTH_FACTORY")
    if auth_factory:
        try:
            provider = auth_factory(flask_app)
            logger.info(
                "Auth provider created from MCP_AUTH_FACTORY: %s",
                type(provider).__name__ if provider else "None",
            )
            return provider
        except Exception:
            # Do not log the exception — it may contain secrets
            logger.error("Failed to create auth provider from MCP_AUTH_FACTORY")
            return None

    if not flask_app.config.get("MCP_AUTH_ENABLED", False):
        return None

    from superset.mcp_service.mcp_config import (
        create_default_mcp_auth_factory,
    )

    try:
        provider = create_default_mcp_auth_factory(flask_app)
        logger.info(
            "Auth provider created from default factory: %s",
            type(provider).__name__ if provider else "None",
        )
        return provider
    except Exception:
        # Do not log the exception — it may contain secrets
        logger.error("Failed to create auth provider from default factory")
        return None
|
|
|
|
|
|
def build_middleware_list() -> list[Middleware]:
    """Assemble the core MCP middleware chain in the required order.

    FastMCP wraps handlers so that the FIRST-added middleware ends up
    outermost. Ordered outermost → innermost:

    1. StructuredContentStripper — safety net, converts exceptions
       to safe ToolResult text for transports that can't encode errors
    2. LoggingMiddleware — logs tool calls with success/failure status
    3. GlobalErrorHandler — catches tool exceptions, raises ToolError
    """
    middleware_classes = (
        StructuredContentStripperMiddleware,
        LoggingMiddleware,
        GlobalErrorHandlerMiddleware,
    )
    return [middleware_cls() for middleware_cls in middleware_classes]
|
|
|
|
|
|
def run_server(
    host: str = "127.0.0.1",
    port: int = 5008,
    debug: bool = False,
    use_factory_config: bool = False,
    event_store_config: dict[str, Any] | None = None,
) -> None:
    """
    Run the MCP service server with FastMCP endpoints.
    Uses streamable-http transport for HTTP server mode.

    For multi-pod deployments, configure MCP_EVENT_STORE_CONFIG with Redis URL
    to share session state across pods.

    Args:
        host: Host to bind to
        port: Port to bind to
        debug: Enable debug logging
        use_factory_config: Use configuration from get_mcp_factory_config()
        event_store_config: Optional EventStore configuration dict.
            If None, reads from MCP_EVENT_STORE_CONFIG.
    """

    configure_logging(debug)
    _suppress_third_party_warnings()

    # DO NOT IMPORT TOOLS HERE!! IMPORT THEM IN app.py!!!!!

    if use_factory_config:
        # Use factory configuration for customization
        logging.info("Creating MCP app from factory configuration...")
        factory_config = get_mcp_factory_config()
        mcp_instance = create_mcp_app(**factory_config)

        # Apply tool search transform if configured
        tool_search_config = MCP_TOOL_SEARCH_CONFIG
        if tool_search_config.get("enabled", False):
            _apply_tool_search_transform(mcp_instance, tool_search_config)
    else:
        # Use default initialization with auth from Flask config
        logging.info("Creating MCP app with default configuration...")
        # Imported lazily to avoid pulling Flask-dependent modules in the
        # factory-config path.
        from superset.mcp_service.caching import create_response_caching_middleware
        from superset.mcp_service.flask_singleton import get_flask_app

        flask_app = get_flask_app()
        auth_provider = _create_auth_provider(flask_app)

        middleware_list = build_middleware_list()

        # Add optional middleware (innermost, closest to tool)
        size_guard_middleware = create_response_size_guard_middleware()
        if size_guard_middleware:
            middleware_list.append(size_guard_middleware)

        if caching_middleware := create_response_caching_middleware():
            middleware_list.append(caching_middleware)

        mcp_instance = init_fastmcp_server(
            auth=auth_provider,
            middleware=middleware_list or None,
        )

        # Apply tool search transform if configured; Flask config overrides
        # the module-level default.
        tool_search_config = flask_app.config.get(
            "MCP_TOOL_SEARCH_CONFIG", MCP_TOOL_SEARCH_CONFIG
        )
        if tool_search_config.get("enabled", False):
            _apply_tool_search_transform(mcp_instance, tool_search_config)
            # Ensure the configured search tool name is excluded from the
            # response size guard (search results are intentionally large)
            if size_guard_middleware:
                search_name = tool_search_config.get("search_tool_name", "search_tools")
                size_guard_middleware.excluded_tools.add(search_name)

    # Create EventStore for session management (Redis for multi-pod, None for in-memory)
    event_store = create_event_store(event_store_config)

    # Per-port environment flag guards against starting the server twice
    # in the same process (e.g. a re-entrant call during reloads).
    env_key = f"FASTMCP_RUNNING_{port}"
    if not os.environ.get(env_key):
        os.environ[env_key] = "1"
        try:
            logging.info("Starting FastMCP on %s:%s", host, port)

            if event_store is not None:
                # Multi-pod: Use http_app with Redis EventStore, run with uvicorn
                logging.info("Running in multi-pod mode with Redis EventStore")
                app = mcp_instance.http_app(
                    transport="streamable-http",
                    event_store=event_store,
                    stateless_http=True,
                )
                uvicorn.run(app, host=host, port=port)
            else:
                # Single-pod mode: Use built-in run() with in-memory sessions
                logging.info("Running in single-pod mode with in-memory sessions")
                mcp_instance.run(
                    transport="streamable-http",
                    host=host,
                    port=port,
                    stateless_http=True,
                )
        except Exception as e:
            logging.error("FastMCP failed: %s", e)
            # Clear the flag so a subsequent start attempt is not blocked.
            os.environ.pop(env_key, None)
    else:
        logging.info("FastMCP already running on %s:%s", host, port)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: start the MCP server with defaults (127.0.0.1:5008).
    run_server()
|