mirror of
https://github.com/apache/superset.git
synced 2026-04-11 04:15:33 +00:00
527 lines
18 KiB
Python
527 lines
18 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""
|
|
Generic utilities for flexible schema input handling in MCP tools.
|
|
|
|
This module provides utilities to accept both JSON string and object formats
|
|
for input parameters, making MCP tools more flexible for different clients.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from functools import wraps
|
|
from typing import Any, Callable, List, Type, TypeVar
|
|
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
class JSONParseError(ValueError):
|
|
"""Raised when JSON parsing fails with helpful context."""
|
|
|
|
def __init__(self, value: Any, error: Exception, param_name: str = "parameter"):
|
|
self.value = value
|
|
self.original_error = error
|
|
self.param_name = param_name
|
|
super().__init__(
|
|
f"Failed to parse {param_name} from JSON string: {error}. "
|
|
f"Received value: {value!r}"
|
|
)
|
|
|
|
|
|
def parse_json_or_passthrough(
|
|
value: Any, param_name: str = "parameter", strict: bool = False
|
|
) -> Any:
|
|
"""
|
|
Parse a value that can be either a JSON string or a native Python object.
|
|
|
|
This function handles the common pattern where API parameters can be provided
|
|
as either:
|
|
- A JSON string (e.g., from CLI tools or tests): '{"key": "value"}'
|
|
- A native Python object (e.g., from LLM clients): {"key": "value"}
|
|
|
|
Args:
|
|
value: The input value to parse. Can be a string, list, dict, or any JSON-
|
|
serializable type.
|
|
param_name: Name of the parameter for error messages (default: "parameter")
|
|
strict: If True, raises JSONParseError on parse failures. If False, logs
|
|
warning and returns original value (default: False)
|
|
|
|
Returns:
|
|
Parsed Python object if value was a JSON string, otherwise returns value
|
|
unchanged.
|
|
|
|
Raises:
|
|
JSONParseError: If strict=True and JSON parsing fails.
|
|
|
|
Examples:
|
|
>>> parse_json_or_passthrough('[1, 2, 3]', 'numbers')
|
|
[1, 2, 3]
|
|
|
|
>>> parse_json_or_passthrough([1, 2, 3], 'numbers')
|
|
[1, 2, 3]
|
|
|
|
>>> parse_json_or_passthrough('{"key": "value"}', 'config')
|
|
{'key': 'value'}
|
|
|
|
>>> parse_json_or_passthrough({'key': 'value'}, 'config')
|
|
{'key': 'value'}
|
|
"""
|
|
# If not a string, return as-is (already in object form)
|
|
if not isinstance(value, str):
|
|
return value
|
|
|
|
# Try to parse as JSON
|
|
try:
|
|
from superset.utils import json
|
|
|
|
parsed = json.loads(value)
|
|
logger.debug("Successfully parsed %s from JSON string", param_name)
|
|
return parsed
|
|
except (ValueError, TypeError) as e:
|
|
error_msg = (
|
|
f"Failed to parse {param_name} from JSON string: {e}. Received: {value!r}"
|
|
)
|
|
|
|
if strict:
|
|
raise JSONParseError(value, e, param_name) from None
|
|
|
|
logger.warning("%s. Returning original value.", error_msg)
|
|
return value
|
|
|
|
|
|
def parse_json_or_list(
|
|
value: Any, param_name: str = "parameter", item_separator: str = ","
|
|
) -> List[Any]:
|
|
"""
|
|
Parse a value into a list, accepting JSON string, list, or comma-separated string.
|
|
|
|
This function provides maximum flexibility for list parameters by accepting:
|
|
- JSON array string: '["item1", "item2"]'
|
|
- Python list: ["item1", "item2"]
|
|
- Comma-separated string: "item1, item2, item3"
|
|
- Empty/None: returns empty list
|
|
|
|
Args:
|
|
value: Input value to parse into a list
|
|
param_name: Name of the parameter for error messages
|
|
item_separator: Separator for comma-separated strings (default: ",")
|
|
|
|
Returns:
|
|
List of items. Returns empty list if value is None or empty.
|
|
|
|
Examples:
|
|
>>> parse_json_or_list('["a", "b"]', 'items')
|
|
['a', 'b']
|
|
|
|
>>> parse_json_or_list(['a', 'b'], 'items')
|
|
['a', 'b']
|
|
|
|
>>> parse_json_or_list('a, b, c', 'items')
|
|
['a', 'b', 'c']
|
|
|
|
>>> parse_json_or_list(None, 'items')
|
|
[]
|
|
"""
|
|
# Handle None and empty values
|
|
if value is None or value == "":
|
|
return []
|
|
|
|
# Already a list, return as-is
|
|
if isinstance(value, list):
|
|
return value
|
|
|
|
# Try to parse as JSON
|
|
if isinstance(value, str):
|
|
try:
|
|
from superset.utils import json
|
|
|
|
parsed = json.loads(value)
|
|
# If successfully parsed and it's a list, return it
|
|
if isinstance(parsed, list):
|
|
logger.debug("Successfully parsed %s from JSON string", param_name)
|
|
return parsed
|
|
|
|
# If parsed to non-list (e.g., single value), wrap in list
|
|
logger.debug(
|
|
"Parsed %s from JSON to non-list, wrapping in list", param_name
|
|
)
|
|
return [parsed]
|
|
except (ValueError, TypeError):
|
|
# Not valid JSON, try comma-separated parsing
|
|
logger.debug(
|
|
"Could not parse %s as JSON, trying comma-separated", param_name
|
|
)
|
|
items = [
|
|
item.strip() for item in value.split(item_separator) if item.strip()
|
|
]
|
|
return items
|
|
|
|
# For any other type, wrap in a list
|
|
logger.debug("Wrapping %s value in list", param_name)
|
|
return [value]
|
|
|
|
|
|
def parse_json_or_model(
|
|
value: Any, model_class: Type[BaseModel], param_name: str = "parameter"
|
|
) -> BaseModel:
|
|
"""
|
|
Parse a value into a Pydantic model, accepting JSON string or dict.
|
|
|
|
Args:
|
|
value: Input value to parse (JSON string, dict, or model instance)
|
|
model_class: Pydantic model class to validate against
|
|
param_name: Name of the parameter for error messages
|
|
|
|
Returns:
|
|
Validated Pydantic model instance
|
|
|
|
Raises:
|
|
ValidationError: If the value cannot be parsed or validated
|
|
|
|
Examples:
|
|
>>> class MyModel(BaseModel):
|
|
... name: str
|
|
... value: int
|
|
|
|
>>> parse_json_or_model('{"name": "test", "value": 42}', MyModel)
|
|
MyModel(name='test', value=42)
|
|
|
|
>>> parse_json_or_model({"name": "test", "value": 42}, MyModel)
|
|
MyModel(name='test', value=42)
|
|
"""
|
|
# If already an instance of the model, return as-is
|
|
if isinstance(value, model_class):
|
|
return value
|
|
|
|
# Parse JSON string if needed
|
|
parsed_value = parse_json_or_passthrough(value, param_name, strict=True)
|
|
|
|
# Validate and construct the model
|
|
try:
|
|
return model_class.model_validate(parsed_value)
|
|
except ValidationError:
|
|
logger.error(
|
|
"Failed to validate %s against %s", param_name, model_class.__name__
|
|
)
|
|
raise
|
|
|
|
|
|
def parse_json_or_model_list(
|
|
value: Any,
|
|
model_class: Type[BaseModel],
|
|
param_name: str = "parameter",
|
|
) -> List[BaseModel]:
|
|
"""
|
|
Parse a value into a list of Pydantic models, accepting JSON string or list.
|
|
|
|
Args:
|
|
value: Input value to parse (JSON string, list of dicts, or list of models)
|
|
model_class: Pydantic model class for list items
|
|
param_name: Name of the parameter for error messages
|
|
|
|
Returns:
|
|
List of validated Pydantic model instances
|
|
|
|
Raises:
|
|
ValidationError: If any item cannot be parsed or validated
|
|
|
|
Examples:
|
|
>>> class Item(BaseModel):
|
|
... name: str
|
|
|
|
>>> parse_json_or_model_list('[{"name": "a"}, {"name": "b"}]', Item)
|
|
[Item(name='a'), Item(name='b')]
|
|
|
|
>>> parse_json_or_model_list([{"name": "a"}, {"name": "b"}], Item)
|
|
[Item(name='a'), Item(name='b')]
|
|
"""
|
|
# Handle None and empty
|
|
if value is None or value == "":
|
|
return []
|
|
|
|
# Parse to list first
|
|
items = parse_json_or_list(value, param_name)
|
|
|
|
# Validate each item against the model
|
|
validated_items = []
|
|
for i, item in enumerate(items):
|
|
try:
|
|
if isinstance(item, model_class):
|
|
validated_items.append(item)
|
|
else:
|
|
validated_items.append(model_class.model_validate(item))
|
|
except ValidationError:
|
|
logger.error(
|
|
"Failed to validate %s[%s] against %s",
|
|
param_name,
|
|
i,
|
|
model_class.__name__,
|
|
)
|
|
# Re-raise original validation error
|
|
raise
|
|
|
|
return validated_items
|
|
|
|
|
|
# Pydantic validator decorators for common use cases
|
|
def json_or_passthrough_validator(
|
|
param_name: str | None = None, strict: bool = False
|
|
) -> Callable[[Type[BaseModel], Any, Any], Any]:
|
|
"""
|
|
Decorator factory for Pydantic field validators that accept JSON or objects.
|
|
|
|
This creates a validator that can be used with Pydantic's @field_validator
|
|
decorator to automatically parse JSON strings.
|
|
|
|
Args:
|
|
param_name: Parameter name for error messages (uses field name if None)
|
|
strict: Whether to raise errors on parse failures
|
|
|
|
Returns:
|
|
Validator function compatible with @field_validator
|
|
|
|
Example:
|
|
>>> class MySchema(BaseModel):
|
|
... config: dict
|
|
...
|
|
... @field_validator('config', mode='before')
|
|
... @classmethod
|
|
... def parse_config(cls, v):
|
|
... return parse_json_or_passthrough(v, 'config')
|
|
"""
|
|
|
|
def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> Any:
|
|
# Use field name from validation info if param_name not provided
|
|
field_name = param_name or (info.field_name if info else "field")
|
|
return parse_json_or_passthrough(v, field_name, strict)
|
|
|
|
return validator
|
|
|
|
|
|
def json_or_list_validator(
|
|
param_name: str | None = None, item_separator: str = ","
|
|
) -> Callable[[Type[BaseModel], Any, Any], List[Any]]:
|
|
"""
|
|
Decorator factory for Pydantic validators that parse values into lists.
|
|
|
|
Args:
|
|
param_name: Parameter name for error messages
|
|
item_separator: Separator for comma-separated strings
|
|
|
|
Returns:
|
|
Validator function compatible with @field_validator
|
|
|
|
Example:
|
|
>>> class MySchema(BaseModel):
|
|
... items: List[str]
|
|
...
|
|
... @field_validator('items', mode='before')
|
|
... @classmethod
|
|
... def parse_items(cls, v):
|
|
... return parse_json_or_list(v, 'items')
|
|
"""
|
|
|
|
def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> List[Any]:
|
|
field_name = param_name or (info.field_name if info else "field")
|
|
return parse_json_or_list(v, field_name, item_separator)
|
|
|
|
return validator
|
|
|
|
|
|
def json_or_model_list_validator(
|
|
model_class: Type[BaseModel], param_name: str | None = None
|
|
) -> Callable[[Type[BaseModel], Any, Any], List[BaseModel]]:
|
|
"""
|
|
Decorator factory for Pydantic validators that parse lists of models.
|
|
|
|
Args:
|
|
model_class: Pydantic model class for list items
|
|
param_name: Parameter name for error messages
|
|
|
|
Returns:
|
|
Validator function compatible with @field_validator
|
|
|
|
Example:
|
|
>>> class FilterModel(BaseModel):
|
|
... col: str
|
|
... value: str
|
|
...
|
|
>>> class MySchema(BaseModel):
|
|
... filters: List[FilterModel]
|
|
...
|
|
... @field_validator('filters', mode='before')
|
|
... @classmethod
|
|
... def parse_filters(cls, v):
|
|
... return parse_json_or_model_list(v, FilterModel, 'filters')
|
|
"""
|
|
|
|
def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> List[BaseModel]:
|
|
field_name = param_name or (info.field_name if info else "field")
|
|
return parse_json_or_model_list(v, model_class, field_name)
|
|
|
|
return validator
|
|
|
|
|
|
def parse_request(
|
|
request_class: Type[BaseModel],
|
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
"""
|
|
Decorator to handle Claude Code bug where requests are double-serialized as strings.
|
|
|
|
Automatically parses string requests to Pydantic models before calling
|
|
the tool function. Also modifies the function's type annotations to accept
|
|
str | RequestModel to pass FastMCP validation.
|
|
|
|
See: https://github.com/anthropics/claude-code/issues/5504
|
|
|
|
Args:
|
|
request_class: The Pydantic model class for the request
|
|
|
|
Returns:
|
|
Decorator function that wraps the tool function
|
|
|
|
Usage:
|
|
@mcp.tool
|
|
@mcp_auth_hook
|
|
@parse_request(ListChartsRequest)
|
|
async def list_charts(
|
|
request: ListChartsRequest, ctx: Context # Keep clean type hint
|
|
) -> ChartList:
|
|
# Decorator handles string conversion and type annotation
|
|
await ctx.info(f"Listing charts: page={request.page}")
|
|
...
|
|
|
|
Note:
|
|
- Works with both async and sync functions
|
|
- Request must be the first positional argument
|
|
- Modifies __annotations__ to accept str | RequestModel for FastMCP
|
|
- Function implementation can use clean RequestModel type hint
|
|
- If request is already a model instance, it passes through unchanged
|
|
- Handles JSON string parsing with helpful error messages
|
|
"""
|
|
|
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
import types
|
|
|
|
if asyncio.iscoroutinefunction(func):
|
|
|
|
@wraps(func)
|
|
async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
|
|
# Parse if string, otherwise pass through
|
|
# (parse_json_or_model handles both)
|
|
parsed_request = parse_json_or_model(request, request_class, "request")
|
|
# Get ctx from FastMCP's dependency injection
|
|
# (we stripped it from signature)
|
|
from fastmcp.server.dependencies import get_context
|
|
|
|
ctx = get_context()
|
|
return await func(parsed_request, ctx, *args, **kwargs)
|
|
|
|
wrapper = async_wrapper
|
|
else:
|
|
|
|
@wraps(func)
|
|
def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
|
|
# Parse if string, otherwise pass through
|
|
# (parse_json_or_model handles both)
|
|
parsed_request = parse_json_or_model(request, request_class, "request")
|
|
# Get ctx from FastMCP's dependency injection
|
|
# (we stripped it from signature)
|
|
from fastmcp.server.dependencies import get_context
|
|
|
|
ctx = get_context()
|
|
return func(parsed_request, ctx, *args, **kwargs)
|
|
|
|
wrapper = sync_wrapper
|
|
|
|
# Merge original function's __globals__ into wrapper's __globals__
|
|
# This allows get_type_hints() to resolve type annotations from the
|
|
# original module (e.g., Context from fastmcp)
|
|
# FastMCP 2.13.2+ uses get_type_hints() which needs access to these types
|
|
merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined]
|
|
new_wrapper = types.FunctionType(
|
|
wrapper.__code__, # type: ignore[attr-defined]
|
|
merged_globals,
|
|
wrapper.__name__,
|
|
wrapper.__defaults__, # type: ignore[attr-defined]
|
|
wrapper.__closure__, # type: ignore[attr-defined]
|
|
)
|
|
# Copy __dict__ but exclude __wrapped__
|
|
# NOTE: We intentionally do NOT preserve __wrapped__ here.
|
|
# Setting __wrapped__ causes inspect.signature() to follow the chain
|
|
# and find 'ctx' in the original function's signature, even after
|
|
# FastMCP's create_function_without_params removes it from annotations.
|
|
# This breaks Pydantic's TypeAdapter which expects signature params
|
|
# to match type_hints.
|
|
new_wrapper.__dict__.update(
|
|
{k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
|
|
)
|
|
new_wrapper.__module__ = wrapper.__module__
|
|
new_wrapper.__qualname__ = wrapper.__qualname__
|
|
# Copy docstring from original function (not wrapper, which has no docstring)
|
|
new_wrapper.__doc__ = func.__doc__
|
|
|
|
# Copy annotations from original function and modify request type
|
|
# Also remove ctx annotation - FastMCP strips it, and having it in
|
|
# annotations but not signature breaks Pydantic's TypeAdapter
|
|
if hasattr(func, "__annotations__"):
|
|
new_wrapper.__annotations__ = {
|
|
k: v
|
|
for k, v in func.__annotations__.items()
|
|
if k != "ctx" # Skip ctx - will be removed from signature too
|
|
}
|
|
# Modify request annotation to accept str | RequestModel
|
|
new_wrapper.__annotations__["request"] = str | request_class
|
|
else:
|
|
new_wrapper.__annotations__ = {"request": str | request_class}
|
|
|
|
# Set __signature__ from original function, but modify for FastMCP:
|
|
# 1. Modify request annotation to accept str | RequestModel
|
|
# 2. Do NOT include ctx parameter - FastMCP will strip it anyway, and
|
|
# having it in __signature__ but not __annotations__ breaks Pydantic
|
|
import inspect as sig_inspect
|
|
|
|
from fastmcp import Context as FMContext
|
|
|
|
orig_sig = sig_inspect.signature(func)
|
|
new_params = []
|
|
for name, param in orig_sig.parameters.items():
|
|
# Skip ctx parameter - FastMCP tools don't expose it to clients
|
|
if param.annotation is FMContext or (
|
|
hasattr(param.annotation, "__name__")
|
|
and param.annotation.__name__ == "Context"
|
|
):
|
|
continue
|
|
if name == "request":
|
|
new_params.append(param.replace(annotation=str | request_class))
|
|
else:
|
|
new_params.append(param)
|
|
new_wrapper.__signature__ = orig_sig.replace( # type: ignore[attr-defined]
|
|
parameters=new_params
|
|
)
|
|
|
|
return new_wrapper
|
|
|
|
return decorator
|