mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
feat(mcp): Add flexible input parsing to handle double-serialized requests (#36249)
This commit is contained in:
@@ -293,7 +293,113 @@ def my_function(
|
||||
- Still import `List`, `Dict`, `Any`, etc. from typing (for now)
|
||||
- All new code must follow this pattern
|
||||
|
||||
### 6. Error Handling
|
||||
### 6. Flexible Input Parsing (JSON String or Object)
|
||||
|
||||
**MCP tools accept both JSON string and native object formats for parameters** using utilities from `superset.mcp_service.utils.schema_utils`. This makes tools flexible for different client types (LLM clients send objects, CLI tools send JSON strings).
|
||||
|
||||
**PREFERRED: Use the `@parse_request` decorator** for tool functions to automatically handle request parsing:
|
||||
|
||||
```python
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(ListChartsRequest) # Automatically parses string requests!
|
||||
async def list_charts(request: ListChartsRequest | str, ctx: Context) -> ChartList:
|
||||
"""List charts with filtering and search."""
|
||||
# request is guaranteed to be ListChartsRequest here - no manual parsing needed!
|
||||
await ctx.info(f"Listing charts: page={request.page}")
|
||||
...
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Eliminates 5 lines of boilerplate code per tool
|
||||
- Handles both async and sync functions automatically
|
||||
- Works around the Claude Code double-serialization bug (GitHub issue #5504)
|
||||
- Cleaner, more maintainable code
|
||||
|
||||
**Available utilities for other use cases:**
|
||||
|
||||
#### parse_json_or_passthrough
|
||||
Parse JSON string or return object as-is:
|
||||
|
||||
```python
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_passthrough
|
||||
|
||||
# Accepts both formats
|
||||
config = parse_json_or_passthrough(value, param_name="config")
|
||||
# value can be: '{"key": "value"}' (JSON string) OR {"key": "value"} (dict)
|
||||
```
|
||||
|
||||
#### parse_json_or_list
|
||||
Parse to list from JSON, list, or comma-separated string:
|
||||
|
||||
```python
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_list
|
||||
|
||||
# Accepts multiple formats
|
||||
items = parse_json_or_list(value, param_name="items")
|
||||
# value can be:
|
||||
# '["a", "b"]' (JSON array)
|
||||
# ["a", "b"] (Python list)
|
||||
# "a, b, c" (comma-separated string)
|
||||
```
|
||||
|
||||
#### parse_json_or_model
|
||||
Parse to Pydantic model from JSON or dict:
|
||||
|
||||
```python
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model
|
||||
|
||||
# Accepts JSON string or dict
|
||||
config = parse_json_or_model(value, ConfigModel, param_name="config")
|
||||
# value can be: '{"name": "test"}' OR {"name": "test"}
|
||||
```
|
||||
|
||||
#### parse_json_or_model_list
|
||||
Parse to list of Pydantic models:
|
||||
|
||||
```python
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
|
||||
|
||||
# Accepts JSON array or list of dicts
|
||||
filters = parse_json_or_model_list(value, FilterModel, param_name="filters")
|
||||
# value can be: '[{"col": "name"}]' OR [{"col": "name"}]
|
||||
```
|
||||
|
||||
**Using with Pydantic validators:**
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, field_validator
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_list
|
||||
|
||||
class MyToolRequest(BaseModel):
|
||||
filters: List[FilterModel] = Field(default_factory=list)
|
||||
select_columns: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v):
|
||||
"""Accept both JSON string and list of objects."""
|
||||
return parse_json_or_model_list(v, FilterModel, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v):
|
||||
"""Accept JSON array, list, or comma-separated string."""
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
```
|
||||
|
||||
**Core classes already use these utilities:**
|
||||
- `ModelListCore` uses them for `filters` and `select_columns`
|
||||
- No need to add parsing logic in individual tools that use core classes
|
||||
|
||||
**When to use:**
|
||||
- Tool parameters that accept complex objects (dicts, lists)
|
||||
- Parameters that may come from CLI tools (JSON strings) or LLM clients (objects)
|
||||
- Any field where you want maximum flexibility
|
||||
|
||||
### 7. Error Handling
|
||||
|
||||
**Use consistent error schemas**:
|
||||
|
||||
|
||||
@@ -734,6 +734,33 @@ class ListChartsRequest(MetadataCacheControl):
|
||||
"specified.",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[ChartFilter]:
|
||||
"""
|
||||
Parse filters from JSON string or list.
|
||||
|
||||
Handles Claude Code bug where objects are double-serialized as strings.
|
||||
See: https://github.com/anthropics/claude-code/issues/5504
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
|
||||
|
||||
return parse_json_or_model_list(v, ChartFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_select_columns(cls, v: Any) -> List[str]:
|
||||
"""
|
||||
Parse select_columns from JSON string, list, or CSV string.
|
||||
|
||||
Handles Claude Code bug where arrays are double-serialized as strings.
|
||||
See: https://github.com/anthropics/claude-code/issues/5504
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_list
|
||||
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
|
||||
@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
PerformanceMetadata,
|
||||
URLPreview,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
from superset.mcp_service.utils.url_utils import (
|
||||
get_chart_screenshot_url,
|
||||
get_superset_base_url,
|
||||
@@ -49,6 +50,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GenerateChartRequest)
|
||||
async def generate_chart( # noqa: C901
|
||||
request: GenerateChartRequest, ctx: Context
|
||||
) -> GenerateChartResponse:
|
||||
|
||||
@@ -30,12 +30,14 @@ from superset.mcp_service.chart.schemas import (
|
||||
GetChartAvailableFiltersRequest,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetChartAvailableFiltersRequest)
|
||||
def get_chart_available_filters(
|
||||
request: GetChartAvailableFiltersRequest, ctx: Context
|
||||
) -> ChartAvailableFiltersResponse:
|
||||
|
||||
@@ -32,12 +32,14 @@ from superset.mcp_service.chart.schemas import (
|
||||
serialize_chart_object,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetChartInfoRequest)
|
||||
async def get_chart_info(
|
||||
request: GetChartInfoRequest, ctx: Context
|
||||
) -> ChartInfo | ChartError:
|
||||
|
||||
@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
URLPreview,
|
||||
VegaLitePreview,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -2021,6 +2022,7 @@ async def _get_chart_preview_internal( # noqa: C901
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetChartPreviewRequest)
|
||||
async def get_chart_preview(
|
||||
request: GetChartPreviewRequest, ctx: Context
|
||||
) -> ChartPreview | ChartError:
|
||||
|
||||
@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
serialize_chart_object,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -67,6 +68,7 @@ SORTABLE_CHART_COLUMNS = [
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(ListChartsRequest)
|
||||
async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList:
|
||||
"""List charts with filtering and search.
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
PerformanceMetadata,
|
||||
UpdateChartRequest,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
from superset.mcp_service.utils.url_utils import (
|
||||
get_chart_screenshot_url,
|
||||
get_superset_base_url,
|
||||
@@ -49,6 +50,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(UpdateChartRequest)
|
||||
async def update_chart(
|
||||
request: UpdateChartRequest, ctx: Context
|
||||
) -> GenerateChartResponse:
|
||||
|
||||
@@ -40,6 +40,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
UpdateChartPreviewRequest,
|
||||
URLPreview,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
from superset.mcp_service.utils.url_utils import get_mcp_service_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -47,6 +48,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(UpdateChartPreviewRequest)
|
||||
def update_chart_preview(
|
||||
request: UpdateChartPreviewRequest, ctx: Context
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
@@ -68,7 +68,14 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator, PositiveInt
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.dashboard import Dashboard
|
||||
@@ -202,6 +209,33 @@ class ListDashboardsRequest(MetadataCacheControl):
|
||||
"if not specified.",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[DashboardFilter]:
|
||||
"""
|
||||
Parse filters from JSON string or list.
|
||||
|
||||
Handles Claude Code bug where objects are double-serialized as strings.
|
||||
See: https://github.com/anthropics/claude-code/issues/5504
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
|
||||
|
||||
return parse_json_or_model_list(v, DashboardFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_select_columns(cls, v: Any) -> List[str]:
|
||||
"""
|
||||
Parse select_columns from JSON string, list, or CSV string.
|
||||
|
||||
Handles Claude Code bug where arrays are double-serialized as strings.
|
||||
See: https://github.com/anthropics/claude-code/issues/5504
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_list
|
||||
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
|
||||
@@ -33,6 +33,7 @@ from superset.mcp_service.dashboard.schemas import (
|
||||
AddChartToDashboardResponse,
|
||||
DashboardInfo,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.utils import json
|
||||
|
||||
@@ -136,6 +137,7 @@ def _ensure_layout_structure(layout: Dict[str, Any], row_key: str) -> None:
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(AddChartToDashboardRequest)
|
||||
def add_chart_to_existing_dashboard(
|
||||
request: AddChartToDashboardRequest, ctx: Context
|
||||
) -> AddChartToDashboardResponse:
|
||||
|
||||
@@ -33,6 +33,7 @@ from superset.mcp_service.dashboard.schemas import (
|
||||
GenerateDashboardRequest,
|
||||
GenerateDashboardResponse,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.utils import json
|
||||
|
||||
@@ -119,6 +120,7 @@ def _create_dashboard_layout(chart_objects: List[Any]) -> Dict[str, Any]:
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GenerateDashboardRequest)
|
||||
def generate_dashboard(
|
||||
request: GenerateDashboardRequest, ctx: Context
|
||||
) -> GenerateDashboardResponse:
|
||||
|
||||
@@ -29,12 +29,14 @@ from superset.mcp_service.dashboard.schemas import (
|
||||
GetDashboardAvailableFiltersRequest,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetDashboardAvailableFiltersRequest)
|
||||
async def get_dashboard_available_filters(
|
||||
request: GetDashboardAvailableFiltersRequest, ctx: Context
|
||||
) -> DashboardAvailableFilters:
|
||||
|
||||
@@ -36,12 +36,14 @@ from superset.mcp_service.dashboard.schemas import (
|
||||
GetDashboardInfoRequest,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetDashboardInfoRequest)
|
||||
async def get_dashboard_info(
|
||||
request: GetDashboardInfoRequest, ctx: Context
|
||||
) -> DashboardInfo | DashboardError:
|
||||
|
||||
@@ -36,6 +36,7 @@ from superset.mcp_service.dashboard.schemas import (
|
||||
serialize_dashboard_object,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,6 +62,7 @@ SORTABLE_DASHBOARD_COLUMNS = [
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(ListDashboardsRequest)
|
||||
async def list_dashboards(
|
||||
request: ListDashboardsRequest, ctx: Context
|
||||
) -> DashboardList:
|
||||
@@ -70,7 +72,6 @@ async def list_dashboards(
|
||||
Sortable columns for order_column: id, dashboard_title, slug, published,
|
||||
changed_on, created_on
|
||||
"""
|
||||
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
|
||||
tool = ModelListCore(
|
||||
|
||||
@@ -29,12 +29,14 @@ from superset.mcp_service.dataset.schemas import (
|
||||
GetDatasetAvailableFiltersRequest,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetDatasetAvailableFiltersRequest)
|
||||
async def get_dataset_available_filters(
|
||||
request: GetDatasetAvailableFiltersRequest, ctx: Context
|
||||
) -> DatasetAvailableFilters:
|
||||
|
||||
@@ -36,12 +36,14 @@ from superset.mcp_service.dataset.schemas import (
|
||||
serialize_dataset_object,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetDatasetInfoRequest)
|
||||
async def get_dataset_info(
|
||||
request: GetDatasetInfoRequest, ctx: Context
|
||||
) -> DatasetInfo | DatasetError:
|
||||
|
||||
@@ -36,6 +36,7 @@ from superset.mcp_service.dataset.schemas import (
|
||||
serialize_dataset_object,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,6 +65,7 @@ SORTABLE_DATASET_COLUMNS = [
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(ListDatasetsRequest)
|
||||
async def list_datasets(request: ListDatasetsRequest, ctx: Context) -> DatasetList:
|
||||
"""List datasets with filtering and search.
|
||||
|
||||
|
||||
@@ -35,10 +35,12 @@ from superset.mcp_service.chart.chart_utils import (
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
GenerateExploreLinkRequest,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GenerateExploreLinkRequest)
|
||||
async def generate_explore_link(
|
||||
request: GenerateExploreLinkRequest, ctx: Context
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
@@ -128,17 +128,19 @@ class ModelListCore(BaseCore, Generic[L]):
|
||||
page: int = 0,
|
||||
page_size: int = 10,
|
||||
) -> L:
|
||||
# If filters is a string (e.g., from a test), parse it as JSON
|
||||
if isinstance(filters, str):
|
||||
from superset.utils import json
|
||||
# Parse filters using generic utility (accepts JSON string or object)
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_passthrough,
|
||||
)
|
||||
|
||||
filters = json.loads(filters)
|
||||
# Ensure select_columns is a list and track what was requested
|
||||
filters = parse_json_or_passthrough(filters, param_name="filters")
|
||||
|
||||
# Parse select_columns using generic utility (accepts JSON, list, or CSV)
|
||||
if select_columns:
|
||||
if isinstance(select_columns, str):
|
||||
select_columns = [
|
||||
col.strip() for col in select_columns.split(",") if col.strip()
|
||||
]
|
||||
select_columns = parse_json_or_list(
|
||||
select_columns, param_name="select_columns"
|
||||
)
|
||||
columns_to_load = select_columns
|
||||
columns_requested = select_columns
|
||||
else:
|
||||
|
||||
@@ -33,12 +33,14 @@ from superset.mcp_service.sql_lab.schemas import (
|
||||
ExecuteSqlRequest,
|
||||
ExecuteSqlResponse,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(ExecuteSqlRequest)
|
||||
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse:
|
||||
"""Execute SQL query against database.
|
||||
|
||||
|
||||
@@ -32,12 +32,14 @@ from superset.mcp_service.sql_lab.schemas import (
|
||||
OpenSqlLabRequest,
|
||||
SqlLabResponse,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(OpenSqlLabRequest)
|
||||
def open_sql_lab_with_context(
|
||||
request: OpenSqlLabRequest, ctx: Context
|
||||
) -> SqlLabResponse:
|
||||
|
||||
@@ -38,6 +38,7 @@ from superset.mcp_service.system.system_utils import (
|
||||
calculate_popular_content,
|
||||
calculate_recent_activity,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,6 +72,7 @@ _instance_info_core = InstanceInfoCore(
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
@parse_request(GetSupersetInstanceInfoRequest)
|
||||
def get_instance_info(
|
||||
request: GetSupersetInstanceInfoRequest, ctx: Context
|
||||
) -> InstanceInfo:
|
||||
|
||||
444
superset/mcp_service/utils/schema_utils.py
Normal file
444
superset/mcp_service/utils/schema_utils.py
Normal file
@@ -0,0 +1,444 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Generic utilities for flexible schema input handling in MCP tools.
|
||||
|
||||
This module provides utilities to accept both JSON string and object formats
|
||||
for input parameters, making MCP tools more flexible for different clients.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class JSONParseError(ValueError):
    """JSON decoding failure enriched with the offending value and context.

    Carries the raw input, the underlying exception, and the parameter name
    so callers can log or report the failure programmatically instead of
    re-parsing the message string.
    """

    def __init__(self, value: Any, error: Exception, param_name: str = "parameter"):
        message = (
            f"Failed to parse {param_name} from JSON string: {error}. "
            f"Received value: {value!r}"
        )
        super().__init__(message)
        # Keep the raw inputs around for programmatic inspection.
        self.value = value
        self.original_error = error
        self.param_name = param_name
|
||||
|
||||
|
||||
def parse_json_or_passthrough(
    value: Any, param_name: str = "parameter", strict: bool = False
) -> Any:
    """
    Accept either a JSON-encoded string or an already-decoded Python object.

    MCP tool parameters may arrive as a JSON string (CLI tools, tests, or
    clients that double-serialize) or as a native object (LLM clients). This
    helper normalizes both shapes: strings are decoded as JSON, anything else
    is returned untouched.

    Args:
        value: Candidate input; a JSON string or any JSON-serializable object.
        param_name: Label used in log/error messages (default: "parameter").
        strict: When True, a string that fails to decode raises
            JSONParseError; when False, the failure is logged as a warning
            and the original string is returned unchanged.

    Returns:
        The decoded object when ``value`` was a JSON string, otherwise
        ``value`` exactly as given.

    Raises:
        JSONParseError: Only when ``strict`` is True and decoding fails.
    """
    # Non-strings are already in object form; nothing to decode.
    if not isinstance(value, str):
        return value

    try:
        from superset.utils import json

        decoded = json.loads(value)
    except (ValueError, TypeError) as exc:
        if strict:
            raise JSONParseError(value, exc, param_name) from None

        error_msg = (
            f"Failed to parse {param_name} from JSON string: {exc}. Received: {value!r}"
        )
        logger.warning("%s. Returning original value.", error_msg)
        return value

    logger.debug("Successfully parsed %s from JSON string", param_name)
    return decoded
|
||||
|
||||
|
||||
def parse_json_or_list(
    value: Any, param_name: str = "parameter", item_separator: str = ","
) -> List[Any]:
    """
    Normalize a flexible input into a list.

    Accepted shapes:
        - ``None`` or ``""``      -> empty list
        - a Python list           -> returned unchanged
        - a JSON array string     -> decoded list
        - a JSON scalar string    -> decoded value wrapped in a list
        - any other string        -> split on ``item_separator``, items
                                     stripped, empty pieces dropped
        - anything else           -> wrapped in a single-item list

    Args:
        value: Input to coerce into a list.
        param_name: Label used in debug log messages.
        item_separator: Delimiter for separator-joined strings (default ",").

    Returns:
        A list of items; empty when ``value`` is None or the empty string.
    """
    if value is None or value == "":
        return []

    if isinstance(value, list):
        return value

    if not isinstance(value, str):
        # Scalar / non-sequence input: promote to a one-element list.
        logger.debug("Wrapping %s value in list", param_name)
        return [value]

    try:
        from superset.utils import json

        decoded = json.loads(value)
    except (ValueError, TypeError):
        # Not JSON at all -- treat the string as separator-delimited items.
        logger.debug(
            "Could not parse %s as JSON, trying comma-separated", param_name
        )
        return [
            piece.strip() for piece in value.split(item_separator) if piece.strip()
        ]

    if isinstance(decoded, list):
        logger.debug("Successfully parsed %s from JSON string", param_name)
        return decoded

    # Valid JSON but not an array (e.g. a bare number or string) -- wrap it.
    logger.debug("Parsed %s from JSON to non-list, wrapping in list", param_name)
    return [decoded]
|
||||
|
||||
|
||||
def parse_json_or_model(
    value: Any, model_class: Type[BaseModel], param_name: str = "parameter"
) -> BaseModel:
    """
    Coerce a JSON string, dict, or model instance into ``model_class``.

    Args:
        value: JSON string, mapping, or an existing ``model_class`` instance.
        model_class: Pydantic model class to validate against.
        param_name: Label used in log/error messages.

    Returns:
        A validated ``model_class`` instance.

    Raises:
        JSONParseError: If ``value`` is a string that is not valid JSON.
        ValidationError: If the decoded payload does not satisfy the model.
    """
    # Already the right type -- no work to do.
    if isinstance(value, model_class):
        return value

    # Decode JSON strings strictly so malformed input fails loudly here
    # rather than as a confusing pydantic error below.
    payload = parse_json_or_passthrough(value, param_name, strict=True)

    try:
        return model_class.model_validate(payload)
    except ValidationError:
        logger.error(
            "Failed to validate %s against %s", param_name, model_class.__name__
        )
        raise
|
||||
|
||||
|
||||
def parse_json_or_model_list(
    value: Any,
    model_class: Type[BaseModel],
    param_name: str = "parameter",
) -> List[BaseModel]:
    """
    Coerce a JSON string or list into a list of ``model_class`` instances.

    Args:
        value: JSON array string, list of dicts, or list of model instances.
        model_class: Pydantic model class for the list items.
        param_name: Label used in log/error messages.

    Returns:
        Validated model instances; empty list for ``None`` or ``""``.

    Raises:
        ValidationError: If any item fails validation against ``model_class``.
    """
    if value is None or value == "":
        return []

    result: List[BaseModel] = []
    for index, item in enumerate(parse_json_or_list(value, param_name)):
        # Items that already are model instances are kept as-is.
        if isinstance(item, model_class):
            result.append(item)
            continue
        try:
            result.append(model_class.model_validate(item))
        except ValidationError:
            # Log which element failed, then surface the original error.
            logger.error(
                "Failed to validate %s[%s] against %s",
                param_name,
                index,
                model_class.__name__,
            )
            raise
    return result
|
||||
|
||||
|
||||
# Pydantic validator decorators for common use cases
|
||||
def json_or_passthrough_validator(
    param_name: str | None = None, strict: bool = False
) -> Callable[[Type[BaseModel], Any, Any], Any]:
    """
    Build a ``@field_validator(..., mode="before")``-compatible callable that
    decodes JSON strings and passes other values through unchanged.

    Args:
        param_name: Label for log/error messages; falls back to the field
            name from pydantic's validation info when omitted.
        strict: Forwarded to :func:`parse_json_or_passthrough`.

    Returns:
        A validator function suitable for use with ``@field_validator``.
    """

    def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> Any:
        label = param_name
        if not label:
            # No explicit name given: prefer the field name pydantic reports.
            label = info.field_name if info else "field"
        return parse_json_or_passthrough(v, label, strict)

    return validator
|
||||
|
||||
|
||||
def json_or_list_validator(
    param_name: str | None = None, item_separator: str = ","
) -> Callable[[Type[BaseModel], Any, Any], List[Any]]:
    """
    Build a ``@field_validator``-compatible callable that coerces its input
    into a list (JSON array string, Python list, or separator-joined string).

    Args:
        param_name: Label for log messages; falls back to the field name
            from pydantic's validation info when omitted.
        item_separator: Delimiter for separator-joined strings.

    Returns:
        A validator function suitable for use with ``@field_validator``.
    """

    def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> List[Any]:
        label = param_name
        if not label:
            # No explicit name given: prefer the field name pydantic reports.
            label = info.field_name if info else "field"
        return parse_json_or_list(v, label, item_separator)

    return validator
|
||||
|
||||
|
||||
def json_or_model_list_validator(
    model_class: Type[BaseModel], param_name: str | None = None
) -> Callable[[Type[BaseModel], Any, Any], List[BaseModel]]:
    """
    Build a Pydantic ``@field_validator`` hook yielding a list of models.

    The returned callable delegates to :func:`parse_json_or_model_list`, so a
    field using it accepts a JSON array string, a list of dicts, or a list of
    ``model_class`` instances.

    Args:
        model_class: Pydantic model class each list item is coerced into.
        param_name: Name used in error messages; defaults to the field name.

    Returns:
        A validator function suitable for ``@field_validator``.

    Example:
        >>> class FilterModel(BaseModel):
        ...     col: str
        ...     value: str
        ...
        >>> class MySchema(BaseModel):
        ...     filters: List[FilterModel]
        ...
        ...     @field_validator('filters', mode='before')
        ...     @classmethod
        ...     def parse_filters(cls, v):
        ...         return parse_json_or_model_list(v, FilterModel, 'filters')
    """

    def _validate(
        cls: Type[BaseModel], value: Any, info: Any = None
    ) -> List[BaseModel]:
        # Resolve the error-message label: explicit name, then Pydantic's
        # field name, then a generic fallback.
        if param_name:
            label = param_name
        elif info:
            label = info.field_name
        else:
            label = "field"
        return parse_json_or_model_list(value, model_class, label)

    return _validate
|
||||
|
||||
|
||||
def parse_request(
    request_class: Type[BaseModel],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator that transparently parses string requests into Pydantic models.

    Works around the Claude Code bug where tool requests arrive
    double-serialized as JSON strings, removing the need for manual parsing
    boilerplate inside every tool function.

    See: https://github.com/anthropics/claude-code/issues/5504

    Args:
        request_class: The Pydantic model class for the request

    Returns:
        Decorator function that wraps the tool function

    Usage:
        @mcp.tool
        @mcp_auth_hook
        @parse_request(ListChartsRequest)
        async def list_charts(
            request: ListChartsRequest, ctx: Context
        ) -> ChartList:
            # Decorator handles string conversion automatically
            await ctx.info(f"Listing charts: page={request.page}")
            ...

    Note:
        - Works with both async and sync functions
        - Request must be the first positional argument
        - If request is already a model instance, it passes through unchanged
        - Handles JSON string parsing with helpful error messages
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        def _coerce(raw: Any) -> Any:
            # parse_json_or_model parses JSON strings and dicts, and passes
            # existing model instances through unchanged.
            return parse_json_or_model(raw, request_class, "request")

        if not asyncio.iscoroutinefunction(func):

            @wraps(func)
            def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
                return func(_coerce(request), *args, **kwargs)

            return sync_wrapper

        @wraps(func)
        async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
            return await func(_coerce(request), *args, **kwargs)

        return async_wrapper

    return decorator
|
||||
515
tests/unit_tests/mcp_service/utils/test_schema_utils.py
Normal file
515
tests/unit_tests/mcp_service/utils/test_schema_utils.py
Normal file
@@ -0,0 +1,515 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP service schema utilities.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
JSONParseError,
|
||||
parse_json_or_list,
|
||||
parse_json_or_model,
|
||||
parse_json_or_model_list,
|
||||
parse_json_or_passthrough,
|
||||
parse_request,
|
||||
)
|
||||
|
||||
|
||||
class TestParseJsonOrPassthrough:
    """Tests for parse_json_or_passthrough."""

    def test_parse_valid_json_string(self):
        """A JSON object string is decoded into a dict."""
        assert parse_json_or_passthrough('{"key": "value"}', "config") == {
            "key": "value"
        }

    def test_parse_json_array(self):
        """A JSON array string is decoded into a list."""
        assert parse_json_or_passthrough("[1, 2, 3]", "numbers") == [1, 2, 3]

    def test_passthrough_dict(self):
        """A dict input is returned untouched (same object)."""
        payload = {"key": "value"}
        assert parse_json_or_passthrough(payload, "config") is payload

    def test_passthrough_list(self):
        """A list input is returned untouched (same object)."""
        payload = [1, 2, 3]
        assert parse_json_or_passthrough(payload, "numbers") is payload

    def test_passthrough_none(self):
        """None passes through unchanged."""
        assert parse_json_or_passthrough(None, "value") is None

    def test_invalid_json_non_strict(self):
        """Unparseable input comes back verbatim when strict=False."""
        out = parse_json_or_passthrough("not valid json", "config", strict=False)
        assert out == "not valid json"

    def test_invalid_json_strict(self):
        """Unparseable input raises JSONParseError when strict=True."""
        with pytest.raises(JSONParseError) as exc_info:
            parse_json_or_passthrough("not valid json", "config", strict=True)

        err = exc_info.value
        assert err.param_name == "config"
        assert err.value == "not valid json"

    def test_parse_json_number(self):
        """A numeric JSON string decodes to an int."""
        assert parse_json_or_passthrough("42", "number") == 42

    def test_parse_json_boolean(self):
        """A boolean JSON string decodes to a bool."""
        assert parse_json_or_passthrough("true", "flag") is True

    def test_parse_nested_json(self):
        """Nested JSON structures are fully decoded."""
        raw = '{"outer": {"inner": [1, 2, 3]}}'
        assert parse_json_or_passthrough(raw, "nested") == {
            "outer": {"inner": [1, 2, 3]}
        }
|
||||
|
||||
|
||||
class TestParseJsonOrList:
    """Tests for parse_json_or_list."""

    def test_parse_json_array(self):
        """A JSON array string decodes to a list."""
        assert parse_json_or_list('["a", "b", "c"]', "items") == ["a", "b", "c"]

    def test_passthrough_list(self):
        """A list input is returned untouched (same object)."""
        payload = ["a", "b", "c"]
        assert parse_json_or_list(payload, "items") is payload

    def test_parse_comma_separated(self):
        """A comma-separated string splits into trimmed items."""
        assert parse_json_or_list("a, b, c", "items") == ["a", "b", "c"]

    def test_parse_comma_separated_with_whitespace(self):
        """Surrounding whitespace is stripped from each item."""
        assert parse_json_or_list("  a  ,  b  ,  c  ", "items") == ["a", "b", "c"]

    def test_empty_string_returns_empty_list(self):
        """An empty string yields an empty list."""
        assert parse_json_or_list("", "items") == []

    def test_none_returns_empty_list(self):
        """None yields an empty list."""
        assert parse_json_or_list(None, "items") == []

    def test_single_json_value_wrapped_in_list(self):
        """A scalar JSON value is wrapped in a one-element list."""
        assert parse_json_or_list('"single"', "items") == ["single"]

    def test_custom_separator(self):
        """A custom separator splits the string instead of commas."""
        out = parse_json_or_list("a|b|c", "items", item_separator="|")
        assert out == ["a", "b", "c"]

    def test_non_list_wrapped(self):
        """A non-list, non-string value is wrapped in a list."""
        assert parse_json_or_list(42, "items") == [42]

    def test_parse_empty_items_in_csv(self):
        """Empty segments in a comma-separated string are dropped."""
        assert parse_json_or_list("a,,b,,c", "items") == ["a", "b", "c"]
|
||||
|
||||
|
||||
class TestParseJsonOrModel:
    """Test parse_json_or_model function."""

    # Named without the "Test" prefix: pytest's default collection pattern
    # matches classes called Test*, and attempting to collect a pydantic
    # model (which defines __init__) emits a PytestCollectionWarning.
    class SampleModel(BaseModel):
        """Minimal Pydantic model used as the parse target."""

        name: str
        value: int

    def test_parse_json_string(self):
        """Should parse JSON string to model instance."""
        result = parse_json_or_model(
            '{"name": "test", "value": 42}', self.SampleModel, "config"
        )
        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 42

    def test_parse_dict(self):
        """Should parse dict to model instance."""
        result = parse_json_or_model(
            {"name": "test", "value": 42}, self.SampleModel, "config"
        )
        assert isinstance(result, self.SampleModel)
        assert result.name == "test"
        assert result.value == 42

    def test_passthrough_model_instance(self):
        """Should return model instance as-is."""
        instance = self.SampleModel(name="test", value=42)
        result = parse_json_or_model(instance, self.SampleModel, "config")
        assert result is instance

    def test_invalid_json_raises_error(self):
        """Should raise JSONParseError for invalid JSON."""
        with pytest.raises(JSONParseError):
            parse_json_or_model("not valid json", self.SampleModel, "config")

    def test_invalid_model_data_raises_validation_error(self):
        """Should raise ValidationError for invalid model data."""
        with pytest.raises(ValidationError):
            # "value" is required but missing
            parse_json_or_model({"name": "test"}, self.SampleModel, "config")

    def test_wrong_type_raises_validation_error(self):
        """Should raise ValidationError for wrong data types."""
        with pytest.raises(ValidationError):
            parse_json_or_model(
                {"name": "test", "value": "not_a_number"}, self.SampleModel, "config"
            )
|
||||
|
||||
|
||||
class TestParseJsonOrModelList:
    """Tests for parse_json_or_model_list."""

    class ItemModel(BaseModel):
        """Pydantic model used for list items in these tests."""

        name: str
        value: int

    def test_parse_json_array(self):
        """A JSON array string becomes a list of model instances."""
        raw = '[{"name": "a", "value": 1}, {"name": "b", "value": 2}]'
        parsed = parse_json_or_model_list(raw, self.ItemModel, "items")
        assert len(parsed) == 2
        assert all(isinstance(item, self.ItemModel) for item in parsed)
        assert parsed[0].name == "a"
        assert parsed[1].value == 2

    def test_parse_list_of_dicts(self):
        """A list of dicts becomes a list of model instances."""
        parsed = parse_json_or_model_list(
            [{"name": "a", "value": 1}, {"name": "b", "value": 2}],
            self.ItemModel,
            "items",
        )
        assert len(parsed) == 2
        assert all(isinstance(item, self.ItemModel) for item in parsed)

    def test_passthrough_list_of_models(self):
        """Existing model instances are passed through by identity."""
        originals = [
            self.ItemModel(name="a", value=1),
            self.ItemModel(name="b", value=2),
        ]
        parsed = parse_json_or_model_list(originals, self.ItemModel, "items")
        assert len(parsed) == 2
        assert parsed[0] is originals[0]
        assert parsed[1] is originals[1]

    def test_empty_returns_empty_list(self):
        """None, an empty string, and an empty list all yield []."""
        for empty in (None, "", []):
            assert parse_json_or_model_list(empty, self.ItemModel, "items") == []

    def test_invalid_item_raises_validation_error(self):
        """A list item missing a required field raises ValidationError."""
        bad = [{"name": "a", "value": 1}, {"name": "b"}]  # second item lacks value
        with pytest.raises(ValidationError):
            parse_json_or_model_list(bad, self.ItemModel, "items")

    def test_mixed_models_and_dicts(self):
        """A mix of model instances and dicts is normalized to models."""
        mixed = [
            self.ItemModel(name="a", value=1),
            {"name": "b", "value": 2},
        ]
        parsed = parse_json_or_model_list(mixed, self.ItemModel, "items")
        assert len(parsed) == 2
        assert all(isinstance(item, self.ItemModel) for item in parsed)
|
||||
|
||||
|
||||
class TestPydanticIntegration:
    """Exercise the helpers when wired into Pydantic field validators."""

    def test_field_validator_with_json_string(self):
        """A dict field accepts both a JSON string and a native dict."""
        from pydantic import field_validator

        class Schema(BaseModel):
            """Schema whose config field tolerates JSON strings."""

            config: dict

            @field_validator("config", mode="before")
            @classmethod
            def parse_config(cls, v):
                """Parse config from JSON or dict."""
                return parse_json_or_passthrough(v, "config")

        # JSON-string form
        assert Schema.model_validate({"config": '{"key": "value"}'}).config == {
            "key": "value"
        }
        # Native-dict form
        assert Schema.model_validate({"config": {"key": "value"}}).config == {
            "key": "value"
        }

    def test_field_validator_with_list(self):
        """A list field accepts JSON arrays, native lists, and CSV strings."""
        from pydantic import field_validator

        class Schema(BaseModel):
            """Schema whose items field tolerates several input formats."""

            items: list

            @field_validator("items", mode="before")
            @classmethod
            def parse_items(cls, v):
                """Parse items from various formats."""
                return parse_json_or_list(v, "items")

        expected = ["a", "b", "c"]
        # JSON array string
        assert Schema.model_validate({"items": '["a", "b", "c"]'}).items == expected
        # Native list
        assert Schema.model_validate({"items": ["a", "b", "c"]}).items == expected
        # Comma-separated string
        assert Schema.model_validate({"items": "a, b, c"}).items == expected
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Edge cases and error conditions for the parsing helpers."""

    def test_empty_json_object(self):
        """'{}' decodes to an empty dict."""
        assert parse_json_or_passthrough("{}", "config") == {}

    def test_empty_json_array(self):
        """'[]' decodes to an empty list."""
        assert parse_json_or_list("[]", "items") == []

    def test_whitespace_only_string(self):
        """A whitespace-only string yields an empty list."""
        assert parse_json_or_list("   ", "items") == []

    def test_malformed_json(self):
        """Malformed JSON passes through unchanged when not strict."""
        raw = '{"key": invalid}'
        assert parse_json_or_passthrough(raw, "config", strict=False) == raw

    def test_unicode_in_json(self):
        """Non-ASCII characters survive JSON decoding."""
        assert parse_json_or_passthrough('{"name": "测试"}', "config") == {
            "name": "测试"
        }

    def test_special_characters_in_csv(self):
        """Hyphens, underscores, and dots are preserved in CSV items."""
        assert parse_json_or_list("item-1, item_2, item.3", "items") == [
            "item-1",
            "item_2",
            "item.3",
        ]
|
||||
|
||||
|
||||
class TestParseRequestDecorator:
    """Tests for the parse_request decorator used by MCP tools."""

    class RequestModel(BaseModel):
        """Request model used as the decorator's parse target."""

        name: str
        count: int

    def test_decorator_with_json_string_async(self):
        """A JSON string request is parsed before an async tool runs."""
        import asyncio

        @parse_request(self.RequestModel)
        async def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert asyncio.run(tool('{"name": "test", "count": 5}')) == "test:5"

    def test_decorator_with_json_string_sync(self):
        """A JSON string request is parsed before a sync tool runs."""

        @parse_request(self.RequestModel)
        def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert tool('{"name": "test", "count": 5}') == "test:5"

    def test_decorator_with_dict_async(self):
        """A dict request is coerced to the model for an async tool."""
        import asyncio

        @parse_request(self.RequestModel)
        async def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert asyncio.run(tool({"name": "test", "count": 5})) == "test:5"

    def test_decorator_with_dict_sync(self):
        """A dict request is coerced to the model for a sync tool."""

        @parse_request(self.RequestModel)
        def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        assert tool({"name": "test", "count": 5}) == "test:5"

    def test_decorator_with_model_instance_async(self):
        """An already-parsed model passes through to an async tool."""
        import asyncio

        @parse_request(self.RequestModel)
        async def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        req = self.RequestModel(name="test", count=5)
        assert asyncio.run(tool(req)) == "test:5"

    def test_decorator_with_model_instance_sync(self):
        """An already-parsed model passes through to a sync tool."""

        @parse_request(self.RequestModel)
        def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        req = self.RequestModel(name="test", count=5)
        assert tool(req) == "test:5"

    def test_decorator_preserves_function_signature_async(self):
        """Extra keyword parameters still reach the async tool."""
        import asyncio

        @parse_request(self.RequestModel)
        async def tool(request, ctx=None, extra=None):
            return f"{request.name}:{request.count}:{extra}"

        out = asyncio.run(
            tool('{"name": "test", "count": 5}', ctx=None, extra="data")
        )
        assert out == "test:5:data"

    def test_decorator_preserves_function_signature_sync(self):
        """Extra keyword parameters still reach the sync tool."""

        @parse_request(self.RequestModel)
        def tool(request, ctx=None, extra=None):
            return f"{request.name}:{request.count}:{extra}"

        out = tool('{"name": "test", "count": 5}', ctx=None, extra="data")
        assert out == "test:5:data"

    def test_decorator_raises_validation_error_async(self):
        """Invalid request data surfaces a ValidationError (async)."""
        import asyncio

        @parse_request(self.RequestModel)
        async def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        with pytest.raises(ValidationError):
            asyncio.run(tool('{"name": "test"}'))  # count is missing

    def test_decorator_raises_validation_error_sync(self):
        """Invalid request data surfaces a ValidationError (sync)."""

        @parse_request(self.RequestModel)
        def tool(request, ctx=None):
            return f"{request.name}:{request.count}"

        with pytest.raises(ValidationError):
            tool('{"name": "test"}')  # count is missing

    def test_decorator_with_complex_model_async(self):
        """Nested models inside the request are parsed (async)."""
        import asyncio

        class Inner(BaseModel):
            """Nested model."""

            value: int

        class Outer(BaseModel):
            """Request model containing a nested model."""

            name: str
            nested: Inner

        @parse_request(Outer)
        async def tool(request, ctx=None):
            return f"{request.name}:{request.nested.value}"

        payload = '{"name": "test", "nested": {"value": 42}}'
        assert asyncio.run(tool(payload)) == "test:42"

    def test_decorator_with_complex_model_sync(self):
        """Nested models inside the request are parsed (sync)."""

        class Inner(BaseModel):
            """Nested model."""

            value: int

        class Outer(BaseModel):
            """Request model containing a nested model."""

            name: str
            nested: Inner

        @parse_request(Outer)
        def tool(request, ctx=None):
            return f"{request.name}:{request.nested.value}"

        payload = '{"name": "test", "nested": {"value": 42}}'
        assert tool(payload) == "test:42"
|
||||
Reference in New Issue
Block a user