feat(mcp): Add flexible input parsing to handle double-serialized requests (#36249)

This commit is contained in:
Amin Ghadersohi
2025-11-26 04:21:04 +11:00
committed by GitHub
parent cd36845d56
commit 8d5d71199a
25 changed files with 1177 additions and 12 deletions

View File

@@ -293,7 +293,113 @@ def my_function(
- Still import `List`, `Dict`, `Any`, etc. from typing (for now)
- All new code must follow this pattern
### 6. Error Handling
### 6. Flexible Input Parsing (JSON String or Object)
**MCP tools accept both JSON string and native object formats for parameters** using utilities from `superset.mcp_service.utils.schema_utils`. This makes tools flexible for different client types (LLM clients send objects, CLI tools send JSON strings).
**PREFERRED: Use the `@parse_request` decorator** for tool functions to automatically handle request parsing:
```python
from superset.mcp_service.utils.schema_utils import parse_request
@mcp.tool
@mcp_auth_hook
@parse_request(ListChartsRequest) # Automatically parses string requests!
async def list_charts(request: ListChartsRequest | str, ctx: Context) -> ChartList:
"""List charts with filtering and search."""
# request is guaranteed to be ListChartsRequest here - no manual parsing needed!
await ctx.info(f"Listing charts: page={request.page}")
...
```
**Benefits:**
- Eliminates 5 lines of boilerplate code per tool
- Handles both async and sync functions automatically
- Works around the Claude Code double-serialization bug (GitHub issue #5504)
- Cleaner, more maintainable code
**Available utilities for other use cases:**
#### parse_json_or_passthrough
Parse JSON string or return object as-is:
```python
from superset.mcp_service.utils.schema_utils import parse_json_or_passthrough
# Accepts both formats
config = parse_json_or_passthrough(value, param_name="config")
# value can be: '{"key": "value"}' (JSON string) OR {"key": "value"} (dict)
```
#### parse_json_or_list
Parse to list from JSON, list, or comma-separated string:
```python
from superset.mcp_service.utils.schema_utils import parse_json_or_list
# Accepts multiple formats
items = parse_json_or_list(value, param_name="items")
# value can be:
# '["a", "b"]' (JSON array)
# ["a", "b"] (Python list)
# "a, b, c" (comma-separated string)
```
#### parse_json_or_model
Parse to Pydantic model from JSON or dict:
```python
from superset.mcp_service.utils.schema_utils import parse_json_or_model
# Accepts JSON string or dict
config = parse_json_or_model(value, ConfigModel, param_name="config")
# value can be: '{"name": "test"}' OR {"name": "test"}
```
#### parse_json_or_model_list
Parse to list of Pydantic models:
```python
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
# Accepts JSON array or list of dicts
filters = parse_json_or_model_list(value, FilterModel, param_name="filters")
# value can be: '[{"col": "name"}]' OR [{"col": "name"}]
```
**Using with Pydantic validators:**
```python
from pydantic import BaseModel, field_validator
from superset.mcp_service.utils.schema_utils import parse_json_or_list, parse_json_or_model_list
class MyToolRequest(BaseModel):
filters: List[FilterModel] = Field(default_factory=list)
select_columns: List[str] = Field(default_factory=list)
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v):
"""Accept both JSON string and list of objects."""
return parse_json_or_model_list(v, FilterModel, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_columns(cls, v):
"""Accept JSON array, list, or comma-separated string."""
return parse_json_or_list(v, "select_columns")
```
**Core classes already use these utilities:**
- `ModelListCore` uses them for `filters` and `select_columns`
- No need to add parsing logic in individual tools that use core classes
**When to use:**
- Tool parameters that accept complex objects (dicts, lists)
- Parameters that may come from CLI tools (JSON strings) or LLM clients (objects)
- Any field where you want maximum flexibility
### 7. Error Handling
**Use consistent error schemas**:

View File

@@ -734,6 +734,33 @@ class ListChartsRequest(MetadataCacheControl):
"specified.",
),
]
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v: Any) -> List[ChartFilter]:
    """
    Parse filters from JSON string or list.
    Handles Claude Code bug where objects are double-serialized as strings.
    See: https://github.com/anthropics/claude-code/issues/5504
    """
    # Local import -- presumably deferred to avoid a circular import at
    # module load; NOTE(review): confirm schema_utils does not import this
    # schemas module.
    from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
    return parse_json_or_model_list(v, ChartFilter, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_select_columns(cls, v: Any) -> List[str]:
    """
    Parse select_columns from JSON string, list, or CSV string.
    Handles Claude Code bug where arrays are double-serialized as strings.
    See: https://github.com/anthropics/claude-code/issues/5504
    """
    # Local import kept deferred for the same reason as parse_filters.
    from superset.mcp_service.utils.schema_utils import parse_json_or_list
    return parse_json_or_list(v, "select_columns")
search: Annotated[
str | None,
Field(

View File

@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
URLPreview,
)
from superset.mcp_service.utils.schema_utils import parse_request
from superset.mcp_service.utils.url_utils import (
get_chart_screenshot_url,
get_superset_base_url,
@@ -49,6 +50,7 @@ logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(GenerateChartRequest)
async def generate_chart( # noqa: C901
request: GenerateChartRequest, ctx: Context
) -> GenerateChartResponse:

View File

@@ -30,12 +30,14 @@ from superset.mcp_service.chart.schemas import (
GetChartAvailableFiltersRequest,
)
from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(GetChartAvailableFiltersRequest)
def get_chart_available_filters(
request: GetChartAvailableFiltersRequest, ctx: Context
) -> ChartAvailableFiltersResponse:

View File

@@ -32,12 +32,14 @@ from superset.mcp_service.chart.schemas import (
serialize_chart_object,
)
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(GetChartInfoRequest)
async def get_chart_info(
request: GetChartInfoRequest, ctx: Context
) -> ChartInfo | ChartError:

View File

@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
URLPreview,
VegaLitePreview,
)
from superset.mcp_service.utils.schema_utils import parse_request
from superset.mcp_service.utils.url_utils import get_superset_base_url
logger = logging.getLogger(__name__)
@@ -2021,6 +2022,7 @@ async def _get_chart_preview_internal( # noqa: C901
@mcp.tool
@mcp_auth_hook
@parse_request(GetChartPreviewRequest)
async def get_chart_preview(
request: GetChartPreviewRequest, ctx: Context
) -> ChartPreview | ChartError:

View File

@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
serialize_chart_object,
)
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@@ -67,6 +68,7 @@ SORTABLE_CHART_COLUMNS = [
@mcp.tool
@mcp_auth_hook
@parse_request(ListChartsRequest)
async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList:
"""List charts with filtering and search.

View File

@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
UpdateChartRequest,
)
from superset.mcp_service.utils.schema_utils import parse_request
from superset.mcp_service.utils.url_utils import (
get_chart_screenshot_url,
get_superset_base_url,
@@ -49,6 +50,7 @@ logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(UpdateChartRequest)
async def update_chart(
request: UpdateChartRequest, ctx: Context
) -> GenerateChartResponse:

View File

@@ -40,6 +40,7 @@ from superset.mcp_service.chart.schemas import (
UpdateChartPreviewRequest,
URLPreview,
)
from superset.mcp_service.utils.schema_utils import parse_request
from superset.mcp_service.utils.url_utils import get_mcp_service_url
logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(UpdateChartPreviewRequest)
def update_chart_preview(
request: UpdateChartPreviewRequest, ctx: Context
) -> Dict[str, Any]:

View File

@@ -68,7 +68,14 @@ from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, model_validator, PositiveInt
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
PositiveInt,
)
if TYPE_CHECKING:
from superset.models.dashboard import Dashboard
@@ -202,6 +209,33 @@ class ListDashboardsRequest(MetadataCacheControl):
"if not specified.",
),
]
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v: Any) -> List[DashboardFilter]:
    """
    Parse filters from JSON string or list.
    Handles Claude Code bug where objects are double-serialized as strings.
    See: https://github.com/anthropics/claude-code/issues/5504
    """
    # Local import -- presumably deferred to avoid a circular import at
    # module load; NOTE(review): confirm schema_utils does not import this
    # schemas module.
    from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
    return parse_json_or_model_list(v, DashboardFilter, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_select_columns(cls, v: Any) -> List[str]:
    """
    Parse select_columns from JSON string, list, or CSV string.
    Handles Claude Code bug where arrays are double-serialized as strings.
    See: https://github.com/anthropics/claude-code/issues/5504
    """
    # Local import kept deferred for the same reason as parse_filters.
    from superset.mcp_service.utils.schema_utils import parse_json_or_list
    return parse_json_or_list(v, "select_columns")
search: Annotated[
str | None,
Field(

View File

@@ -33,6 +33,7 @@ from superset.mcp_service.dashboard.schemas import (
AddChartToDashboardResponse,
DashboardInfo,
)
from superset.mcp_service.utils.schema_utils import parse_request
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.utils import json
@@ -136,6 +137,7 @@ def _ensure_layout_structure(layout: Dict[str, Any], row_key: str) -> None:
@mcp.tool
@mcp_auth_hook
@parse_request(AddChartToDashboardRequest)
def add_chart_to_existing_dashboard(
request: AddChartToDashboardRequest, ctx: Context
) -> AddChartToDashboardResponse:

View File

@@ -33,6 +33,7 @@ from superset.mcp_service.dashboard.schemas import (
GenerateDashboardRequest,
GenerateDashboardResponse,
)
from superset.mcp_service.utils.schema_utils import parse_request
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.utils import json
@@ -119,6 +120,7 @@ def _create_dashboard_layout(chart_objects: List[Any]) -> Dict[str, Any]:
@mcp.tool
@mcp_auth_hook
@parse_request(GenerateDashboardRequest)
def generate_dashboard(
request: GenerateDashboardRequest, ctx: Context
) -> GenerateDashboardResponse:

View File

@@ -29,12 +29,14 @@ from superset.mcp_service.dashboard.schemas import (
GetDashboardAvailableFiltersRequest,
)
from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(GetDashboardAvailableFiltersRequest)
async def get_dashboard_available_filters(
request: GetDashboardAvailableFiltersRequest, ctx: Context
) -> DashboardAvailableFilters:

View File

@@ -36,12 +36,14 @@ from superset.mcp_service.dashboard.schemas import (
GetDashboardInfoRequest,
)
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(GetDashboardInfoRequest)
async def get_dashboard_info(
request: GetDashboardInfoRequest, ctx: Context
) -> DashboardInfo | DashboardError:

View File

@@ -36,6 +36,7 @@ from superset.mcp_service.dashboard.schemas import (
serialize_dashboard_object,
)
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@@ -61,6 +62,7 @@ SORTABLE_DASHBOARD_COLUMNS = [
@mcp.tool
@mcp_auth_hook
@parse_request(ListDashboardsRequest)
async def list_dashboards(
request: ListDashboardsRequest, ctx: Context
) -> DashboardList:
@@ -70,7 +72,6 @@ async def list_dashboards(
Sortable columns for order_column: id, dashboard_title, slug, published,
changed_on, created_on
"""
from superset.daos.dashboard import DashboardDAO
tool = ModelListCore(

View File

@@ -29,12 +29,14 @@ from superset.mcp_service.dataset.schemas import (
GetDatasetAvailableFiltersRequest,
)
from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(GetDatasetAvailableFiltersRequest)
async def get_dataset_available_filters(
request: GetDatasetAvailableFiltersRequest, ctx: Context
) -> DatasetAvailableFilters:

View File

@@ -36,12 +36,14 @@ from superset.mcp_service.dataset.schemas import (
serialize_dataset_object,
)
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(GetDatasetInfoRequest)
async def get_dataset_info(
request: GetDatasetInfoRequest, ctx: Context
) -> DatasetInfo | DatasetError:

View File

@@ -36,6 +36,7 @@ from superset.mcp_service.dataset.schemas import (
serialize_dataset_object,
)
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@@ -64,6 +65,7 @@ SORTABLE_DATASET_COLUMNS = [
@mcp.tool
@mcp_auth_hook
@parse_request(ListDatasetsRequest)
async def list_datasets(request: ListDatasetsRequest, ctx: Context) -> DatasetList:
"""List datasets with filtering and search.

View File

@@ -35,10 +35,12 @@ from superset.mcp_service.chart.chart_utils import (
from superset.mcp_service.chart.schemas import (
GenerateExploreLinkRequest,
)
from superset.mcp_service.utils.schema_utils import parse_request
@mcp.tool
@mcp_auth_hook
@parse_request(GenerateExploreLinkRequest)
async def generate_explore_link(
request: GenerateExploreLinkRequest, ctx: Context
) -> Dict[str, Any]:

View File

@@ -128,17 +128,19 @@ class ModelListCore(BaseCore, Generic[L]):
page: int = 0,
page_size: int = 10,
) -> L:
# If filters is a string (e.g., from a test), parse it as JSON
if isinstance(filters, str):
from superset.utils import json
# Parse filters using generic utility (accepts JSON string or object)
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_passthrough,
)
filters = json.loads(filters)
# Ensure select_columns is a list and track what was requested
filters = parse_json_or_passthrough(filters, param_name="filters")
# Parse select_columns using generic utility (accepts JSON, list, or CSV)
if select_columns:
if isinstance(select_columns, str):
select_columns = [
col.strip() for col in select_columns.split(",") if col.strip()
]
select_columns = parse_json_or_list(
select_columns, param_name="select_columns"
)
columns_to_load = select_columns
columns_requested = select_columns
else:

View File

@@ -33,12 +33,14 @@ from superset.mcp_service.sql_lab.schemas import (
ExecuteSqlRequest,
ExecuteSqlResponse,
)
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(ExecuteSqlRequest)
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse:
"""Execute SQL query against database.

View File

@@ -32,12 +32,14 @@ from superset.mcp_service.sql_lab.schemas import (
OpenSqlLabRequest,
SqlLabResponse,
)
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
@parse_request(OpenSqlLabRequest)
def open_sql_lab_with_context(
request: OpenSqlLabRequest, ctx: Context
) -> SqlLabResponse:

View File

@@ -38,6 +38,7 @@ from superset.mcp_service.system.system_utils import (
calculate_popular_content,
calculate_recent_activity,
)
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@@ -71,6 +72,7 @@ _instance_info_core = InstanceInfoCore(
@mcp.tool
@mcp_auth_hook
@parse_request(GetSupersetInstanceInfoRequest)
def get_instance_info(
request: GetSupersetInstanceInfoRequest, ctx: Context
) -> InstanceInfo:

View File

@@ -0,0 +1,444 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Generic utilities for flexible schema input handling in MCP tools.
This module provides utilities to accept both JSON string and object formats
for input parameters, making MCP tools more flexible for different clients.
"""
from __future__ import annotations
import asyncio
import logging
from functools import wraps
from typing import Any, Callable, List, Type, TypeVar
from pydantic import BaseModel, ValidationError
logger = logging.getLogger(__name__)
T = TypeVar("T")
class JSONParseError(ValueError):
    """Raised when JSON parsing fails with helpful context.

    Carries the offending value, the underlying exception, and the parameter
    name so callers can produce actionable error reports.
    """

    def __init__(self, value: Any, error: Exception, param_name: str = "parameter"):
        message = (
            f"Failed to parse {param_name} from JSON string: {error}. "
            f"Received value: {value!r}"
        )
        super().__init__(message)
        self.value = value
        self.original_error = error
        self.param_name = param_name
def parse_json_or_passthrough(
    value: Any, param_name: str = "parameter", strict: bool = False
) -> Any:
    """Parse *value* as JSON when it is a string; otherwise return it unchanged.

    Supports the common MCP pattern where a parameter may arrive either as a
    JSON string (CLI tools, tests) or as a native Python object (LLM clients).

    Args:
        value: Input of any type; only strings are decoded.
        param_name: Parameter name used in log and error messages.
        strict: When True, a failed parse raises JSONParseError instead of
            logging a warning and returning the original string.

    Returns:
        The decoded object when ``value`` was a JSON string, otherwise
        ``value`` unchanged.

    Raises:
        JSONParseError: When ``strict`` is True and JSON decoding fails.
    """
    if not isinstance(value, str):
        # Non-strings are already in object form -- nothing to decode.
        return value

    from superset.utils import json

    try:
        decoded = json.loads(value)
    except (ValueError, TypeError) as exc:
        if strict:
            raise JSONParseError(value, exc, param_name) from None
        logger.warning(
            "Failed to parse %s from JSON string: %s. Received: %r. "
            "Returning original value.",
            param_name,
            exc,
            value,
        )
        return value

    logger.debug("Successfully parsed %s from JSON string", param_name)
    return decoded
def parse_json_or_list(
value: Any, param_name: str = "parameter", item_separator: str = ","
) -> List[Any]:
"""
Parse a value into a list, accepting JSON string, list, or comma-separated string.
This function provides maximum flexibility for list parameters by accepting:
- JSON array string: '["item1", "item2"]'
- Python list: ["item1", "item2"]
- Comma-separated string: "item1, item2, item3"
- Empty/None: returns empty list
Args:
value: Input value to parse into a list
param_name: Name of the parameter for error messages
item_separator: Separator for comma-separated strings (default: ",")
Returns:
List of items. Returns empty list if value is None or empty.
Examples:
>>> parse_json_or_list('["a", "b"]', 'items')
['a', 'b']
>>> parse_json_or_list(['a', 'b'], 'items')
['a', 'b']
>>> parse_json_or_list('a, b, c', 'items')
['a', 'b', 'c']
>>> parse_json_or_list(None, 'items')
[]
"""
# Handle None and empty values
if value is None or value == "":
return []
# Already a list, return as-is
if isinstance(value, list):
return value
# Try to parse as JSON
if isinstance(value, str):
try:
from superset.utils import json
parsed = json.loads(value)
# If successfully parsed and it's a list, return it
if isinstance(parsed, list):
logger.debug("Successfully parsed %s from JSON string", param_name)
return parsed
# If parsed to non-list (e.g., single value), wrap in list
logger.debug(
"Parsed %s from JSON to non-list, wrapping in list", param_name
)
return [parsed]
except (ValueError, TypeError):
# Not valid JSON, try comma-separated parsing
logger.debug(
"Could not parse %s as JSON, trying comma-separated", param_name
)
items = [
item.strip() for item in value.split(item_separator) if item.strip()
]
return items
# For any other type, wrap in a list
logger.debug("Wrapping %s value in list", param_name)
return [value]
def parse_json_or_model(
    value: Any, model_class: Type[BaseModel], param_name: str = "parameter"
) -> BaseModel:
    """Coerce *value* into a validated ``model_class`` instance.

    Accepts an existing model instance (returned unchanged), a dict, or a
    JSON string describing the model's fields.

    Args:
        value: JSON string, dict, or ``model_class`` instance.
        model_class: Pydantic model class to validate against.
        param_name: Parameter name used in error messages.

    Returns:
        A validated ``model_class`` instance.

    Raises:
        JSONParseError: If ``value`` is a string that is not valid JSON
            (strict parsing is used here).
        ValidationError: If the parsed payload does not satisfy the model.
    """
    if isinstance(value, model_class):
        # Already a validated instance -- skip the decode/validate round-trip.
        return value

    payload = parse_json_or_passthrough(value, param_name, strict=True)

    try:
        return model_class.model_validate(payload)
    except ValidationError:
        logger.error(
            "Failed to validate %s against %s", param_name, model_class.__name__
        )
        raise
def parse_json_or_model_list(
    value: Any,
    model_class: Type[BaseModel],
    param_name: str = "parameter",
) -> List[BaseModel]:
    """Coerce *value* into a list of validated ``model_class`` instances.

    Accepts a JSON array string, a list of dicts, a list of model instances,
    or a mix of dicts and instances. ``None`` and ``""`` yield an empty list.

    Args:
        value: Input to coerce (JSON string, list, None, or "").
        model_class: Pydantic model class for the list items.
        param_name: Parameter name used in error messages.

    Returns:
        A list of validated model instances.

    Raises:
        ValidationError: If any item fails validation against ``model_class``.
    """
    if value is None or value == "":
        return []

    validated: List[BaseModel] = []
    for index, item in enumerate(parse_json_or_list(value, param_name)):
        if isinstance(item, model_class):
            # Item is already a model instance -- keep it as-is.
            validated.append(item)
            continue
        try:
            validated.append(model_class.model_validate(item))
        except ValidationError:
            # Log the failing index for easier debugging, then re-raise
            # the original validation error.
            logger.error(
                "Failed to validate %s[%s] against %s",
                param_name,
                index,
                model_class.__name__,
            )
            raise
    return validated
# Pydantic validator decorators for common use cases
def json_or_passthrough_validator(
param_name: str | None = None, strict: bool = False
) -> Callable[[Type[BaseModel], Any, Any], Any]:
"""
Decorator factory for Pydantic field validators that accept JSON or objects.
This creates a validator that can be used with Pydantic's @field_validator
decorator to automatically parse JSON strings.
Args:
param_name: Parameter name for error messages (uses field name if None)
strict: Whether to raise errors on parse failures
Returns:
Validator function compatible with @field_validator
Example:
>>> class MySchema(BaseModel):
... config: dict
...
... @field_validator('config', mode='before')
... @classmethod
... def parse_config(cls, v):
... return parse_json_or_passthrough(v, 'config')
"""
def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> Any:
# Use field name from validation info if param_name not provided
field_name = param_name or (info.field_name if info else "field")
return parse_json_or_passthrough(v, field_name, strict)
return validator
def json_or_list_validator(
param_name: str | None = None, item_separator: str = ","
) -> Callable[[Type[BaseModel], Any, Any], List[Any]]:
"""
Decorator factory for Pydantic validators that parse values into lists.
Args:
param_name: Parameter name for error messages
item_separator: Separator for comma-separated strings
Returns:
Validator function compatible with @field_validator
Example:
>>> class MySchema(BaseModel):
... items: List[str]
...
... @field_validator('items', mode='before')
... @classmethod
... def parse_items(cls, v):
... return parse_json_or_list(v, 'items')
"""
def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> List[Any]:
field_name = param_name or (info.field_name if info else "field")
return parse_json_or_list(v, field_name, item_separator)
return validator
def json_or_model_list_validator(
model_class: Type[BaseModel], param_name: str | None = None
) -> Callable[[Type[BaseModel], Any, Any], List[BaseModel]]:
"""
Decorator factory for Pydantic validators that parse lists of models.
Args:
model_class: Pydantic model class for list items
param_name: Parameter name for error messages
Returns:
Validator function compatible with @field_validator
Example:
>>> class FilterModel(BaseModel):
... col: str
... value: str
...
>>> class MySchema(BaseModel):
... filters: List[FilterModel]
...
... @field_validator('filters', mode='before')
... @classmethod
... def parse_filters(cls, v):
... return parse_json_or_model_list(v, FilterModel, 'filters')
"""
def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> List[BaseModel]:
field_name = param_name or (info.field_name if info else "field")
return parse_json_or_model_list(v, model_class, field_name)
return validator
def parse_request(
request_class: Type[BaseModel],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator to handle Claude Code bug where requests are double-serialized as strings.
Automatically parses string requests to Pydantic models before calling
the tool function.
This eliminates the need for manual parsing code in every tool function.
See: https://github.com/anthropics/claude-code/issues/5504
Args:
request_class: The Pydantic model class for the request
Returns:
Decorator function that wraps the tool function
Usage:
@mcp.tool
@mcp_auth_hook
@parse_request(ListChartsRequest)
async def list_charts(
request: ListChartsRequest, ctx: Context
) -> ChartList:
# Decorator handles string conversion automatically
await ctx.info(f"Listing charts: page={request.page}")
...
Note:
- Works with both async and sync functions
- Request must be the first positional argument
- If request is already a model instance, it passes through unchanged
- Handles JSON string parsing with helpful error messages
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
# Parse if string, otherwise pass through
# (parse_json_or_model handles both)
parsed_request = parse_json_or_model(request, request_class, "request")
return await func(parsed_request, *args, **kwargs)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
# Parse if string, otherwise pass through
# (parse_json_or_model handles both)
parsed_request = parse_json_or_model(request, request_class, "request")
return func(parsed_request, *args, **kwargs)
return sync_wrapper
return decorator

View File

@@ -0,0 +1,515 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP service schema utilities.
"""
import pytest
from pydantic import BaseModel, ValidationError
from superset.mcp_service.utils.schema_utils import (
JSONParseError,
parse_json_or_list,
parse_json_or_model,
parse_json_or_model_list,
parse_json_or_passthrough,
parse_request,
)
class TestParseJsonOrPassthrough:
    """Test parse_json_or_passthrough function."""

    def test_parse_valid_json_string(self):
        """Should parse valid JSON string to Python object."""
        result = parse_json_or_passthrough('{"key": "value"}', "config")
        assert result == {"key": "value"}

    def test_parse_json_array(self):
        """Should parse JSON array string to Python list."""
        result = parse_json_or_passthrough("[1, 2, 3]", "numbers")
        assert result == [1, 2, 3]

    def test_passthrough_dict(self):
        """Should return dict as-is without parsing."""
        input_dict = {"key": "value"}
        result = parse_json_or_passthrough(input_dict, "config")
        # Identity check (`is`), not equality: the input must not be copied.
        assert result is input_dict

    def test_passthrough_list(self):
        """Should return list as-is without parsing."""
        input_list = [1, 2, 3]
        result = parse_json_or_passthrough(input_list, "numbers")
        # Identity check: passthrough must hand back the same object.
        assert result is input_list

    def test_passthrough_none(self):
        """Should return None as-is."""
        result = parse_json_or_passthrough(None, "value")
        assert result is None

    def test_invalid_json_non_strict(self):
        """Should return original value when JSON parsing fails (non-strict)."""
        # Non-strict mode logs a warning and falls back to the raw string.
        result = parse_json_or_passthrough("not valid json", "config", strict=False)
        assert result == "not valid json"

    def test_invalid_json_strict(self):
        """Should raise JSONParseError when parsing fails (strict mode)."""
        with pytest.raises(JSONParseError) as exc_info:
            parse_json_or_passthrough("not valid json", "config", strict=True)
        # The error carries both the parameter name and the offending value.
        assert exc_info.value.param_name == "config"
        assert exc_info.value.value == "not valid json"

    def test_parse_json_number(self):
        """Should parse numeric JSON string to int, not leave it a string."""
        result = parse_json_or_passthrough("42", "number")
        assert result == 42

    def test_parse_json_boolean(self):
        """Should parse boolean JSON string."""
        result = parse_json_or_passthrough("true", "flag")
        assert result is True

    def test_parse_nested_json(self):
        """Should parse nested JSON structures."""
        json_str = '{"outer": {"inner": [1, 2, 3]}}'
        result = parse_json_or_passthrough(json_str, "nested")
        assert result == {"outer": {"inner": [1, 2, 3]}}
class TestParseJsonOrList:
    """Tests for the parse_json_or_list utility."""

    def test_parse_json_array(self):
        """A JSON array string decodes to a list."""
        assert parse_json_or_list('["a", "b", "c"]', "items") == ["a", "b", "c"]

    def test_passthrough_list(self):
        """A list input is returned untouched (identical object)."""
        original = ["a", "b", "c"]
        assert parse_json_or_list(original, "items") is original

    def test_parse_comma_separated(self):
        """A comma-separated string is split into a list."""
        assert parse_json_or_list("a, b, c", "items") == ["a", "b", "c"]

    def test_parse_comma_separated_with_whitespace(self):
        """Whitespace around comma-separated items is stripped."""
        assert parse_json_or_list(" a , b , c ", "items") == ["a", "b", "c"]

    def test_empty_string_returns_empty_list(self):
        """An empty string yields an empty list."""
        assert parse_json_or_list("", "items") == []

    def test_none_returns_empty_list(self):
        """None yields an empty list."""
        assert parse_json_or_list(None, "items") == []

    def test_single_json_value_wrapped_in_list(self):
        """A scalar JSON value is wrapped in a one-element list."""
        assert parse_json_or_list('"single"', "items") == ["single"]

    def test_custom_separator(self):
        """A custom item_separator splits on that character instead."""
        split = parse_json_or_list("a|b|c", "items", item_separator="|")
        assert split == ["a", "b", "c"]

    def test_non_list_wrapped(self):
        """A non-list, non-string value is wrapped in a list."""
        assert parse_json_or_list(42, "items") == [42]

    def test_parse_empty_items_in_csv(self):
        """Empty segments in a comma-separated string are dropped."""
        assert parse_json_or_list("a,,b,,c", "items") == ["a", "b", "c"]
class TestParseJsonOrModel:
    """Test parse_json_or_model function."""

    # NOTE: deliberately NOT named with a ``Test`` prefix. pytest also
    # collects ``Test*`` classes nested inside test classes, and because
    # ``BaseModel`` defines ``__init__`` the previous ``TestModel`` name
    # triggered a PytestCollectionWarning on every run.
    class SampleModel(BaseModel):
        """Pydantic model used as the parse target in these tests."""

        name: str
        value: int

    def test_parse_json_string(self):
        """Should parse JSON string to model instance."""
        result = parse_json_or_model(
            '{"name": "test", "value": 42}', self.SampleModel, "config"
        )
        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 42

    def test_parse_dict(self):
        """Should parse dict to model instance."""
        result = parse_json_or_model(
            {"name": "test", "value": 42}, self.SampleModel, "config"
        )
        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 42

    def test_passthrough_model_instance(self):
        """Should return model instance as-is."""
        instance = self.SampleModel(name="test", value=42)
        result = parse_json_or_model(instance, self.SampleModel, "config")
        assert result is instance

    def test_invalid_json_raises_error(self):
        """Should raise JSONParseError for invalid JSON."""
        with pytest.raises(JSONParseError):
            parse_json_or_model("not valid json", self.SampleModel, "config")

    def test_invalid_model_data_raises_validation_error(self):
        """Should raise ValidationError for invalid model data."""
        with pytest.raises(ValidationError):
            parse_json_or_model({"name": "test"}, self.SampleModel, "config")

    def test_wrong_type_raises_validation_error(self):
        """Should raise ValidationError for wrong data types."""
        with pytest.raises(ValidationError):
            parse_json_or_model(
                {"name": "test", "value": "not_a_number"}, self.SampleModel, "config"
            )
class TestParseJsonOrModelList:
    """Tests for the parse_json_or_model_list utility."""

    class ItemModel(BaseModel):
        """Pydantic model used for list items in these tests."""

        name: str
        value: int

    def test_parse_json_array(self):
        """A JSON array string becomes a list of model instances."""
        json_str = '[{"name": "a", "value": 1}, {"name": "b", "value": 2}]'
        models = parse_json_or_model_list(json_str, self.ItemModel, "items")
        assert len(models) == 2
        assert all(isinstance(m, self.ItemModel) for m in models)
        assert models[0].name == "a"
        assert models[1].value == 2

    def test_parse_list_of_dicts(self):
        """A list of dicts is converted to a list of model instances."""
        raw = [{"name": "a", "value": 1}, {"name": "b", "value": 2}]
        models = parse_json_or_model_list(raw, self.ItemModel, "items")
        assert len(models) == 2
        assert all(isinstance(m, self.ItemModel) for m in models)

    def test_passthrough_list_of_models(self):
        """Existing model instances pass through unchanged (same objects)."""
        originals = [
            self.ItemModel(name="a", value=1),
            self.ItemModel(name="b", value=2),
        ]
        models = parse_json_or_model_list(originals, self.ItemModel, "items")
        assert len(models) == 2
        assert models[0] is originals[0]
        assert models[1] is originals[1]

    def test_empty_returns_empty_list(self):
        """None, empty string, and empty list all yield an empty list."""
        for empty in (None, "", []):
            assert parse_json_or_model_list(empty, self.ItemModel, "items") == []

    def test_invalid_item_raises_validation_error(self):
        """One invalid item makes the whole call raise ValidationError."""
        raw = [{"name": "a", "value": 1}, {"name": "b"}]  # Missing value
        with pytest.raises(ValidationError):
            parse_json_or_model_list(raw, self.ItemModel, "items")

    def test_mixed_models_and_dicts(self):
        """A mixed list of instances and dicts normalizes to all instances."""
        raw = [
            self.ItemModel(name="a", value=1),
            {"name": "b", "value": 2},
        ]
        models = parse_json_or_model_list(raw, self.ItemModel, "items")
        assert len(models) == 2
        assert all(isinstance(m, self.ItemModel) for m in models)
class TestPydanticIntegration:
    """Integration tests with Pydantic ``mode="before"`` field validators."""

    def test_field_validator_with_json_string(self):
        """parse_json_or_passthrough works inside a field validator."""
        from pydantic import field_validator

        class TestSchema(BaseModel):
            """Schema whose config field accepts JSON strings or dicts."""

            config: dict

            @field_validator("config", mode="before")
            @classmethod
            def parse_config(cls, v):
                """Parse config from JSON or dict."""
                return parse_json_or_passthrough(v, "config")

        expected = {"key": "value"}
        # A JSON string is decoded before validation runs...
        assert TestSchema.model_validate({"config": '{"key": "value"}'}).config == expected
        # ...and a native dict is accepted directly.
        assert TestSchema.model_validate({"config": {"key": "value"}}).config == expected

    def test_field_validator_with_list(self):
        """parse_json_or_list works inside a field validator."""
        from pydantic import field_validator

        class TestSchema(BaseModel):
            """Schema whose items field accepts several list encodings."""

            items: list

            @field_validator("items", mode="before")
            @classmethod
            def parse_items(cls, v):
                """Parse items from various formats."""
                return parse_json_or_list(v, "items")

        expected = ["a", "b", "c"]
        # JSON array string, native list, and CSV string all normalize.
        assert TestSchema.model_validate({"items": '["a", "b", "c"]'}).items == expected
        assert TestSchema.model_validate({"items": ["a", "b", "c"]}).items == expected
        assert TestSchema.model_validate({"items": "a, b, c"}).items == expected
class TestEdgeCases:
    """Edge cases and error conditions for the parsing utilities."""

    def test_empty_json_object(self):
        """'{}' decodes to an empty dict."""
        assert parse_json_or_passthrough("{}", "config") == {}

    def test_empty_json_array(self):
        """'[]' decodes to an empty list."""
        assert parse_json_or_list("[]", "items") == []

    def test_whitespace_only_string(self):
        """A whitespace-only string yields an empty list."""
        assert parse_json_or_list("   ", "items") == []

    def test_malformed_json(self):
        """Malformed JSON in non-strict mode is returned verbatim."""
        raw = '{"key": invalid}'
        assert parse_json_or_passthrough(raw, "config", strict=False) == raw

    def test_unicode_in_json(self):
        """Unicode characters survive the JSON round-trip."""
        parsed = parse_json_or_passthrough('{"name": "测试"}', "config")
        assert parsed == {"name": "测试"}

    def test_special_characters_in_csv(self):
        """Hyphens, underscores, and dots are preserved in CSV items."""
        parsed = parse_json_or_list("item-1, item_2, item.3", "items")
        assert parsed == ["item-1", "item_2", "item.3"]
class TestParseRequestDecorator:
    """Tests for the parse_request decorator used by MCP tools."""

    class RequestModel(BaseModel):
        """Request model used as the decorator's parse target."""

        name: str
        count: int

    def test_decorator_with_json_string_async(self):
        """A JSON string request is parsed before the async tool runs."""
        import asyncio

        @parse_request(self.RequestModel)
        async def async_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert asyncio.run(async_tool('{"name": "test", "count": 5}')) == "test:5"

    def test_decorator_with_json_string_sync(self):
        """A JSON string request is parsed before the sync tool runs."""

        @parse_request(self.RequestModel)
        def sync_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert sync_tool('{"name": "test", "count": 5}') == "test:5"

    def test_decorator_with_dict_async(self):
        """A dict request is validated into the model for an async tool."""
        import asyncio

        @parse_request(self.RequestModel)
        async def async_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert asyncio.run(async_tool({"name": "test", "count": 5})) == "test:5"

    def test_decorator_with_dict_sync(self):
        """A dict request is validated into the model for a sync tool."""

        @parse_request(self.RequestModel)
        def sync_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert sync_tool({"name": "test", "count": 5}) == "test:5"

    def test_decorator_with_model_instance_async(self):
        """An already-built model instance passes straight through (async)."""
        import asyncio

        @parse_request(self.RequestModel)
        async def async_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        instance = self.RequestModel(name="test", count=5)
        assert asyncio.run(async_tool(instance)) == "test:5"

    def test_decorator_with_model_instance_sync(self):
        """An already-built model instance passes straight through (sync)."""

        @parse_request(self.RequestModel)
        def sync_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        instance = self.RequestModel(name="test", count=5)
        assert sync_tool(instance) == "test:5"

    def test_decorator_preserves_function_signature_async(self):
        """Extra keyword arguments still reach the wrapped async tool."""
        import asyncio

        @parse_request(self.RequestModel)
        async def async_tool(request, ctx=None, extra=None):
            return f"{request.name}:{request.count}:{extra}"

        outcome = asyncio.run(
            async_tool('{"name": "test", "count": 5}', ctx=None, extra="data")
        )
        assert outcome == "test:5:data"

    def test_decorator_preserves_function_signature_sync(self):
        """Extra keyword arguments still reach the wrapped sync tool."""

        @parse_request(self.RequestModel)
        def sync_tool(request, ctx=None, extra=None):
            return f"{request.name}:{request.count}:{extra}"

        outcome = sync_tool('{"name": "test", "count": 5}', ctx=None, extra="data")
        assert outcome == "test:5:data"

    def test_decorator_raises_validation_error_async(self):
        """Incomplete request data raises ValidationError (async)."""
        import asyncio

        @parse_request(self.RequestModel)
        async def async_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        with pytest.raises(ValidationError):
            asyncio.run(async_tool('{"name": "test"}'))  # Missing count

    def test_decorator_raises_validation_error_sync(self):
        """Incomplete request data raises ValidationError (sync)."""

        @parse_request(self.RequestModel)
        def sync_tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        with pytest.raises(ValidationError):
            sync_tool('{"name": "test"}')  # Missing count

    def test_decorator_with_complex_model_async(self):
        """Nested models inside the request are parsed too (async)."""
        import asyncio

        class NestedModel(BaseModel):
            """Nested payload."""

            value: int

        class ComplexModel(BaseModel):
            """Request model containing a nested model."""

            name: str
            nested: NestedModel

        @parse_request(ComplexModel)
        async def async_tool(request, ctx=None):
            return f"{request.name}:{request.nested.value}"

        payload = '{"name": "test", "nested": {"value": 42}}'
        assert asyncio.run(async_tool(payload)) == "test:42"

    def test_decorator_with_complex_model_sync(self):
        """Nested models inside the request are parsed too (sync)."""

        class NestedModel(BaseModel):
            """Nested payload."""

            value: int

        class ComplexModel(BaseModel):
            """Request model containing a nested model."""

            name: str
            nested: NestedModel

        @parse_request(ComplexModel)
        def sync_tool(request, ctx=None):
            return f"{request.name}:{request.nested.value}"

        payload = '{"name": "test", "nested": {"value": 42}}'
        assert sync_tool(payload) == "test:42"