diff --git a/superset/mcp_service/CLAUDE.md b/superset/mcp_service/CLAUDE.md index 94b33a1bce4..5d42c07849f 100644 --- a/superset/mcp_service/CLAUDE.md +++ b/superset/mcp_service/CLAUDE.md @@ -293,7 +293,113 @@ def my_function( - Still import `List`, `Dict`, `Any`, etc. from typing (for now) - All new code must follow this pattern -### 6. Error Handling +### 6. Flexible Input Parsing (JSON String or Object) + +**MCP tools accept both JSON string and native object formats for parameters** using utilities from `superset.mcp_service.utils.schema_utils`. This makes tools flexible for different client types (LLM clients send objects, CLI tools send JSON strings). + +**PREFERRED: Use the `@parse_request` decorator** for tool functions to automatically handle request parsing: + +```python +from superset.mcp_service.utils.schema_utils import parse_request + +@mcp.tool +@mcp_auth_hook +@parse_request(ListChartsRequest) # Automatically parses string requests! +async def list_charts(request: ListChartsRequest | str, ctx: Context) -> ChartList: + """List charts with filtering and search.""" + # request is guaranteed to be ListChartsRequest here - no manual parsing needed! + await ctx.info(f"Listing charts: page={request.page}") + ... +``` + +**Benefits:** +- Eliminates 5 lines of boilerplate code per tool +- Handles both async and sync functions automatically +- Works with Claude Code bug (GitHub issue #5504) +- Cleaner, more maintainable code + +**Available utilities for other use cases:** + +#### parse_json_or_passthrough +Parse JSON string or return object as-is: + +```python +from superset.mcp_service.utils.schema_utils import parse_json_or_passthrough + +# Accepts both formats +config = parse_json_or_passthrough(value, param_name="config") +# value can be: '{"key": "value"}' (JSON string) OR {"key": "value"} (dict) +``` + +#### parse_json_or_list +Parse to list from JSON, list, or comma-separated string: + +```python +from superset.mcp_service.utils.schema_utils import parse_json_or_list + +# Accepts multiple formats +items = parse_json_or_list(value, param_name="items") +# value can be: +# '["a", "b"]' (JSON array) +# ["a", "b"] (Python list) +# "a, b, c" (comma-separated string) +``` + +#### parse_json_or_model +Parse to Pydantic model from JSON or dict: + +```python +from superset.mcp_service.utils.schema_utils import parse_json_or_model + +# Accepts JSON string or dict +config = parse_json_or_model(value, ConfigModel, param_name="config") +# value can be: '{"name": "test"}' OR {"name": "test"} +``` + +#### parse_json_or_model_list +Parse to list of Pydantic models: + +```python +from superset.mcp_service.utils.schema_utils import parse_json_or_model_list + +# Accepts JSON array or list of dicts +filters = parse_json_or_model_list(value, FilterModel, param_name="filters") +# value can be: '[{"col": "name"}]' OR [{"col": "name"}] +``` + +**Using with Pydantic validators:** + +```python +from pydantic import BaseModel, field_validator +from superset.mcp_service.utils.schema_utils import parse_json_or_list + +class MyToolRequest(BaseModel): + filters: List[FilterModel] = Field(default_factory=list) + select_columns: List[str] = Field(default_factory=list) + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v): + """Accept both JSON string and list of objects.""" + return parse_json_or_model_list(v, FilterModel, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v): + """Accept JSON array, list, or comma-separated string.""" + return parse_json_or_list(v, "select_columns") +``` + +**Core classes already use these utilities:** +- `ModelListCore` uses them for `filters` and `select_columns` +- No need to add parsing logic in individual tools that use core classes + +**When to use:** +- Tool parameters that accept complex objects (dicts, lists) +- Parameters that may come from CLI tools (JSON strings) or LLM clients (objects) +- Any field where you want maximum flexibility + +### 7. Error Handling **Use consistent error schemas**: diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index b7ccbb363ba..6c9198696ca 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -734,6 +734,33 @@ class ListChartsRequest(MetadataCacheControl): "specified.", ), ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[ChartFilter]: + """ + Parse filters from JSON string or list. + + Handles Claude Code bug where objects are double-serialized as strings. + See: https://github.com/anthropics/claude-code/issues/5504 + """ + from superset.mcp_service.utils.schema_utils import parse_json_or_model_list + + return parse_json_or_model_list(v, ChartFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_select_columns(cls, v: Any) -> List[str]: + """ + Parse select_columns from JSON string, list, or CSV string. + + Handles Claude Code bug where arrays are double-serialized as strings. + See: https://github.com/anthropics/claude-code/issues/5504 + """ + from superset.mcp_service.utils.schema_utils import parse_json_or_list + + return parse_json_or_list(v, "select_columns") + search: Annotated[ str | None, Field( diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 9afe5ba8949..bb827f2d74b 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import ( PerformanceMetadata, URLPreview, ) +from superset.mcp_service.utils.schema_utils import parse_request from superset.mcp_service.utils.url_utils import ( get_chart_screenshot_url, get_superset_base_url, @@ -49,6 +50,7 @@ logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(GenerateChartRequest) async def generate_chart( # noqa: C901 request: GenerateChartRequest, ctx: Context ) -> GenerateChartResponse: diff --git a/superset/mcp_service/chart/tool/get_chart_available_filters.py b/superset/mcp_service/chart/tool/get_chart_available_filters.py index 47f6713e876..30df99e1591 100644 --- a/superset/mcp_service/chart/tool/get_chart_available_filters.py +++ b/superset/mcp_service/chart/tool/get_chart_available_filters.py @@ -30,12 +30,14 @@ from superset.mcp_service.chart.schemas import ( GetChartAvailableFiltersRequest, ) from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(GetChartAvailableFiltersRequest) def get_chart_available_filters( request: GetChartAvailableFiltersRequest, ctx: Context ) -> ChartAvailableFiltersResponse: diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py index 9c59131dee7..ed09b03acfc 100644 --- a/superset/mcp_service/chart/tool/get_chart_info.py +++ b/superset/mcp_service/chart/tool/get_chart_info.py @@ -32,12 +32,14 @@ from superset.mcp_service.chart.schemas import ( serialize_chart_object, ) from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(GetChartInfoRequest) async def get_chart_info( request: GetChartInfoRequest, ctx: Context ) -> ChartInfo | ChartError: diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index 24f1ff70b35..3f2ed87e551 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import ( URLPreview, VegaLitePreview, ) +from superset.mcp_service.utils.schema_utils import parse_request from superset.mcp_service.utils.url_utils import get_superset_base_url logger = logging.getLogger(__name__) @@ -2021,6 +2022,7 @@ async def _get_chart_preview_internal( # noqa: C901 @mcp.tool @mcp_auth_hook +@parse_request(GetChartPreviewRequest) async def get_chart_preview( request: GetChartPreviewRequest, ctx: Context ) -> ChartPreview | ChartError: diff --git a/superset/mcp_service/chart/tool/list_charts.py b/superset/mcp_service/chart/tool/list_charts.py index ae8d8013ec6..5795c6f5588 100644 --- a/superset/mcp_service/chart/tool/list_charts.py +++ b/superset/mcp_service/chart/tool/list_charts.py @@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import ( serialize_chart_object, ) from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @@ -67,6 +68,7 @@ SORTABLE_CHART_COLUMNS = [ @mcp.tool @mcp_auth_hook +@parse_request(ListChartsRequest) async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList: """List charts with filtering and search. diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py index 9ed3052a5a5..739a72240eb 100644 --- a/superset/mcp_service/chart/tool/update_chart.py +++ b/superset/mcp_service/chart/tool/update_chart.py @@ -38,6 +38,7 @@ from superset.mcp_service.chart.schemas import ( PerformanceMetadata, UpdateChartRequest, ) +from superset.mcp_service.utils.schema_utils import parse_request from superset.mcp_service.utils.url_utils import ( get_chart_screenshot_url, get_superset_base_url, @@ -49,6 +50,7 @@ logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(UpdateChartRequest) async def update_chart( request: UpdateChartRequest, ctx: Context ) -> GenerateChartResponse: diff --git a/superset/mcp_service/chart/tool/update_chart_preview.py b/superset/mcp_service/chart/tool/update_chart_preview.py index 5b01d1a0b14..7cf69bdb6b6 100644 --- a/superset/mcp_service/chart/tool/update_chart_preview.py +++ b/superset/mcp_service/chart/tool/update_chart_preview.py @@ -40,6 +40,7 @@ from superset.mcp_service.chart.schemas import ( UpdateChartPreviewRequest, URLPreview, ) +from superset.mcp_service.utils.schema_utils import parse_request from superset.mcp_service.utils.url_utils import get_mcp_service_url logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(UpdateChartPreviewRequest) def update_chart_preview( request: UpdateChartPreviewRequest, ctx: Context ) -> Dict[str, Any]: diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index fd6bd632714..346892bc63d 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -68,7 +68,14 @@ from __future__ import annotations from datetime import datetime from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING -from pydantic import BaseModel, ConfigDict, Field, model_validator, PositiveInt +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + PositiveInt, +) if TYPE_CHECKING: from superset.models.dashboard import Dashboard @@ -202,6 +209,33 @@ class ListDashboardsRequest(MetadataCacheControl): "if not specified.", ), ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[DashboardFilter]: + """ + Parse filters from JSON string or list. + + Handles Claude Code bug where objects are double-serialized as strings. + See: https://github.com/anthropics/claude-code/issues/5504 + """ + from superset.mcp_service.utils.schema_utils import parse_json_or_model_list + + return parse_json_or_model_list(v, DashboardFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_select_columns(cls, v: Any) -> List[str]: + """ + Parse select_columns from JSON string, list, or CSV string. + + Handles Claude Code bug where arrays are double-serialized as strings. + See: https://github.com/anthropics/claude-code/issues/5504 + """ + from superset.mcp_service.utils.schema_utils import parse_json_or_list + + return parse_json_or_list(v, "select_columns") + search: Annotated[ str | None, Field( diff --git a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py index 3cd5f65a243..671b0850d46 100644 --- a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py +++ b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py @@ -33,6 +33,7 @@ from superset.mcp_service.dashboard.schemas import ( AddChartToDashboardResponse, DashboardInfo, ) +from superset.mcp_service.utils.schema_utils import parse_request from superset.mcp_service.utils.url_utils import get_superset_base_url from superset.utils import json @@ -136,6 +137,7 @@ def _ensure_layout_structure(layout: Dict[str, Any], row_key: str) -> None: @mcp.tool @mcp_auth_hook +@parse_request(AddChartToDashboardRequest) def add_chart_to_existing_dashboard( request: AddChartToDashboardRequest, ctx: Context ) -> AddChartToDashboardResponse: diff --git a/superset/mcp_service/dashboard/tool/generate_dashboard.py b/superset/mcp_service/dashboard/tool/generate_dashboard.py index c0d5fd517fb..caa727f449c 100644 --- a/superset/mcp_service/dashboard/tool/generate_dashboard.py +++ b/superset/mcp_service/dashboard/tool/generate_dashboard.py @@ -33,6 +33,7 @@ from superset.mcp_service.dashboard.schemas import ( GenerateDashboardRequest, GenerateDashboardResponse, ) +from superset.mcp_service.utils.schema_utils import parse_request from superset.mcp_service.utils.url_utils import get_superset_base_url from superset.utils import json @@ -119,6 +120,7 @@ def _create_dashboard_layout(chart_objects: List[Any]) -> Dict[str, Any]: @mcp.tool @mcp_auth_hook +@parse_request(GenerateDashboardRequest) def generate_dashboard( request: GenerateDashboardRequest, ctx: Context ) -> GenerateDashboardResponse: diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_available_filters.py b/superset/mcp_service/dashboard/tool/get_dashboard_available_filters.py index 0b4c31c8e4e..8839f634685 100644 --- a/superset/mcp_service/dashboard/tool/get_dashboard_available_filters.py +++ b/superset/mcp_service/dashboard/tool/get_dashboard_available_filters.py @@ -29,12 +29,14 @@ from superset.mcp_service.dashboard.schemas import ( GetDashboardAvailableFiltersRequest, ) from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(GetDashboardAvailableFiltersRequest) async def get_dashboard_available_filters( request: GetDashboardAvailableFiltersRequest, ctx: Context ) -> DashboardAvailableFilters: diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_info.py b/superset/mcp_service/dashboard/tool/get_dashboard_info.py index 9a3bfa23b8b..19c398b1516 100644 --- a/superset/mcp_service/dashboard/tool/get_dashboard_info.py +++ b/superset/mcp_service/dashboard/tool/get_dashboard_info.py @@ -36,12 +36,14 @@ from superset.mcp_service.dashboard.schemas import ( GetDashboardInfoRequest, ) from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(GetDashboardInfoRequest) async def get_dashboard_info( request: GetDashboardInfoRequest, ctx: Context ) -> DashboardInfo | DashboardError: diff --git a/superset/mcp_service/dashboard/tool/list_dashboards.py b/superset/mcp_service/dashboard/tool/list_dashboards.py index aaa2fda7fe9..8a5ebdfa486 100644 --- a/superset/mcp_service/dashboard/tool/list_dashboards.py +++ b/superset/mcp_service/dashboard/tool/list_dashboards.py @@ -36,6 +36,7 @@ from superset.mcp_service.dashboard.schemas import ( serialize_dashboard_object, ) from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @@ -61,6 +62,7 @@ SORTABLE_DASHBOARD_COLUMNS = [ @mcp.tool @mcp_auth_hook +@parse_request(ListDashboardsRequest) async def list_dashboards( request: ListDashboardsRequest, ctx: Context ) -> DashboardList: @@ -70,7 +72,6 @@ async def list_dashboards( Sortable columns for order_column: id, dashboard_title, slug, published, changed_on, created_on """ - from superset.daos.dashboard import DashboardDAO tool = ModelListCore( diff --git a/superset/mcp_service/dataset/tool/get_dataset_available_filters.py b/superset/mcp_service/dataset/tool/get_dataset_available_filters.py index 4072f05f3c3..6f1b1038bdb 100644 --- a/superset/mcp_service/dataset/tool/get_dataset_available_filters.py +++ b/superset/mcp_service/dataset/tool/get_dataset_available_filters.py @@ -29,12 +29,14 @@ from superset.mcp_service.dataset.schemas import ( GetDatasetAvailableFiltersRequest, ) from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(GetDatasetAvailableFiltersRequest) async def get_dataset_available_filters( request: GetDatasetAvailableFiltersRequest, ctx: Context ) -> DatasetAvailableFilters: diff --git a/superset/mcp_service/dataset/tool/get_dataset_info.py b/superset/mcp_service/dataset/tool/get_dataset_info.py index a4b83030e1a..a82d22bba68 100644 --- a/superset/mcp_service/dataset/tool/get_dataset_info.py +++ b/superset/mcp_service/dataset/tool/get_dataset_info.py @@ -36,12 +36,14 @@ from superset.mcp_service.dataset.schemas import ( serialize_dataset_object, ) from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(GetDatasetInfoRequest) async def get_dataset_info( request: GetDatasetInfoRequest, ctx: Context ) -> DatasetInfo | DatasetError: diff --git a/superset/mcp_service/dataset/tool/list_datasets.py b/superset/mcp_service/dataset/tool/list_datasets.py index 32edd6def36..3fb3426b913 100644 --- a/superset/mcp_service/dataset/tool/list_datasets.py +++ b/superset/mcp_service/dataset/tool/list_datasets.py @@ -36,6 +36,7 @@ from superset.mcp_service.dataset.schemas import ( serialize_dataset_object, ) from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @@ -64,6 +65,7 @@ SORTABLE_DATASET_COLUMNS = [ @mcp.tool @mcp_auth_hook +@parse_request(ListDatasetsRequest) async def list_datasets(request: ListDatasetsRequest, ctx: Context) -> DatasetList: """List datasets with filtering and search. diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py index aa09c624e1d..8ed3ef850ab 100644 --- a/superset/mcp_service/explore/tool/generate_explore_link.py +++ b/superset/mcp_service/explore/tool/generate_explore_link.py @@ -35,10 +35,12 @@ from superset.mcp_service.chart.chart_utils import ( from superset.mcp_service.chart.schemas import ( GenerateExploreLinkRequest, ) +from superset.mcp_service.utils.schema_utils import parse_request @mcp.tool @mcp_auth_hook +@parse_request(GenerateExploreLinkRequest) async def generate_explore_link( request: GenerateExploreLinkRequest, ctx: Context ) -> Dict[str, Any]: diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py index 669e21420d4..ded6b74ebf9 100644 --- a/superset/mcp_service/mcp_core.py +++ b/superset/mcp_service/mcp_core.py @@ -128,17 +128,19 @@ class ModelListCore(BaseCore, Generic[L]): page: int = 0, page_size: int = 10, ) -> L: - # If filters is a string (e.g., from a test), parse it as JSON - if isinstance(filters, str): - from superset.utils import json + # Parse filters using generic utility (accepts JSON string or object) + from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_passthrough, + ) - filters = json.loads(filters) - # Ensure select_columns is a list and track what was requested + filters = parse_json_or_passthrough(filters, param_name="filters") + + # Parse select_columns using generic utility (accepts JSON, list, or CSV) if select_columns: - if isinstance(select_columns, str): - select_columns = [ - col.strip() for col in select_columns.split(",") if col.strip() - ] + select_columns = parse_json_or_list( + select_columns, param_name="select_columns" + ) columns_to_load = select_columns columns_requested = select_columns else: diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py index 4fb8914d60e..ca2302b9b7c 100644 --- a/superset/mcp_service/sql_lab/tool/execute_sql.py +++ b/superset/mcp_service/sql_lab/tool/execute_sql.py @@ -33,12 +33,14 @@ from superset.mcp_service.sql_lab.schemas import ( ExecuteSqlRequest, ExecuteSqlResponse, ) +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(ExecuteSqlRequest) async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse: """Execute SQL query against database. diff --git a/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py b/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py index 8e25a82cd4f..61693b97d73 100644 --- a/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py +++ b/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py @@ -32,12 +32,14 @@ from superset.mcp_service.sql_lab.schemas import ( OpenSqlLabRequest, SqlLabResponse, ) +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool @mcp_auth_hook +@parse_request(OpenSqlLabRequest) def open_sql_lab_with_context( request: OpenSqlLabRequest, ctx: Context ) -> SqlLabResponse: diff --git a/superset/mcp_service/system/tool/get_instance_info.py b/superset/mcp_service/system/tool/get_instance_info.py index f83ca3da201..7a01ba618b1 100644 --- a/superset/mcp_service/system/tool/get_instance_info.py +++ b/superset/mcp_service/system/tool/get_instance_info.py @@ -38,6 +38,7 @@ from superset.mcp_service.system.system_utils import ( calculate_popular_content, calculate_recent_activity, ) +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @@ -71,6 +72,7 @@ _instance_info_core = InstanceInfoCore( @mcp.tool @mcp_auth_hook +@parse_request(GetSupersetInstanceInfoRequest) def get_instance_info( request: GetSupersetInstanceInfoRequest, ctx: Context ) -> InstanceInfo: diff --git a/superset/mcp_service/utils/schema_utils.py b/superset/mcp_service/utils/schema_utils.py new file mode 100644 index 00000000000..f78ad20f050 --- /dev/null +++ b/superset/mcp_service/utils/schema_utils.py @@ -0,0 +1,444 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Generic utilities for flexible schema input handling in MCP tools. + +This module provides utilities to accept both JSON string and object formats +for input parameters, making MCP tools more flexible for different clients. +""" + +from __future__ import annotations + +import asyncio +import logging +from functools import wraps +from typing import Any, Callable, List, Type, TypeVar + +from pydantic import BaseModel, ValidationError + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class JSONParseError(ValueError): + """Raised when JSON parsing fails with helpful context.""" + + def __init__(self, value: Any, error: Exception, param_name: str = "parameter"): + self.value = value + self.original_error = error + self.param_name = param_name + super().__init__( + f"Failed to parse {param_name} from JSON string: {error}. " + f"Received value: {value!r}" + ) + + +def parse_json_or_passthrough( + value: Any, param_name: str = "parameter", strict: bool = False +) -> Any: + """ + Parse a value that can be either a JSON string or a native Python object. + + This function handles the common pattern where API parameters can be provided + as either: + - A JSON string (e.g., from CLI tools or tests): '{"key": "value"}' + - A native Python object (e.g., from LLM clients): {"key": "value"} + + Args: + value: The input value to parse. Can be a string, list, dict, or any JSON- + serializable type. + param_name: Name of the parameter for error messages (default: "parameter") + strict: If True, raises JSONParseError on parse failures. If False, logs + warning and returns original value (default: False) + + Returns: + Parsed Python object if value was a JSON string, otherwise returns value + unchanged. + + Raises: + JSONParseError: If strict=True and JSON parsing fails. + + Examples: + >>> parse_json_or_passthrough('[1, 2, 3]', 'numbers') + [1, 2, 3] + + >>> parse_json_or_passthrough([1, 2, 3], 'numbers') + [1, 2, 3] + + >>> parse_json_or_passthrough('{"key": "value"}', 'config') + {'key': 'value'} + + >>> parse_json_or_passthrough({'key': 'value'}, 'config') + {'key': 'value'} + """ + # If not a string, return as-is (already in object form) + if not isinstance(value, str): + return value + + # Try to parse as JSON + try: + from superset.utils import json + + parsed = json.loads(value) + logger.debug("Successfully parsed %s from JSON string", param_name) + return parsed + except (ValueError, TypeError) as e: + error_msg = ( + f"Failed to parse {param_name} from JSON string: {e}. Received: {value!r}" + ) + + if strict: + raise JSONParseError(value, e, param_name) from None + + logger.warning("%s. Returning original value.", error_msg) + return value + + +def parse_json_or_list( + value: Any, param_name: str = "parameter", item_separator: str = "," +) -> List[Any]: + """ + Parse a value into a list, accepting JSON string, list, or comma-separated string. + + This function provides maximum flexibility for list parameters by accepting: + - JSON array string: '["item1", "item2"]' + - Python list: ["item1", "item2"] + - Comma-separated string: "item1, item2, item3" + - Empty/None: returns empty list + + Args: + value: Input value to parse into a list + param_name: Name of the parameter for error messages + item_separator: Separator for comma-separated strings (default: ",") + + Returns: + List of items. Returns empty list if value is None or empty. + + Examples: + >>> parse_json_or_list('["a", "b"]', 'items') + ['a', 'b'] + + >>> parse_json_or_list(['a', 'b'], 'items') + ['a', 'b'] + + >>> parse_json_or_list('a, b, c', 'items') + ['a', 'b', 'c'] + + >>> parse_json_or_list(None, 'items') + [] + """ + # Handle None and empty values + if value is None or value == "": + return [] + + # Already a list, return as-is + if isinstance(value, list): + return value + + # Try to parse as JSON + if isinstance(value, str): + try: + from superset.utils import json + + parsed = json.loads(value) + # If successfully parsed and it's a list, return it + if isinstance(parsed, list): + logger.debug("Successfully parsed %s from JSON string", param_name) + return parsed + + # If parsed to non-list (e.g., single value), wrap in list + logger.debug( + "Parsed %s from JSON to non-list, wrapping in list", param_name + ) + return [parsed] + except (ValueError, TypeError): + # Not valid JSON, try comma-separated parsing + logger.debug( + "Could not parse %s as JSON, trying comma-separated", param_name + ) + items = [ + item.strip() for item in value.split(item_separator) if item.strip() + ] + return items + + # For any other type, wrap in a list + logger.debug("Wrapping %s value in list", param_name) + return [value] + + +def parse_json_or_model( + value: Any, model_class: Type[BaseModel], param_name: str = "parameter" +) -> BaseModel: + """ + Parse a value into a Pydantic model, accepting JSON string or dict. + + Args: + value: Input value to parse (JSON string, dict, or model instance) + model_class: Pydantic model class to validate against + param_name: Name of the parameter for error messages + + Returns: + Validated Pydantic model instance + + Raises: + ValidationError: If the value cannot be parsed or validated + + Examples: + >>> class MyModel(BaseModel): + ... name: str + ... value: int + + >>> parse_json_or_model('{"name": "test", "value": 42}', MyModel) + MyModel(name='test', value=42) + + >>> parse_json_or_model({"name": "test", "value": 42}, MyModel) + MyModel(name='test', value=42) + """ + # If already an instance of the model, return as-is + if isinstance(value, model_class): + return value + + # Parse JSON string if needed + parsed_value = parse_json_or_passthrough(value, param_name, strict=True) + + # Validate and construct the model + try: + return model_class.model_validate(parsed_value) + except ValidationError: + logger.error( + "Failed to validate %s against %s", param_name, model_class.__name__ + ) + raise + + +def parse_json_or_model_list( + value: Any, + model_class: Type[BaseModel], + param_name: str = "parameter", +) -> List[BaseModel]: + """ + Parse a value into a list of Pydantic models, accepting JSON string or list. + + Args: + value: Input value to parse (JSON string, list of dicts, or list of models) + model_class: Pydantic model class for list items + param_name: Name of the parameter for error messages + + Returns: + List of validated Pydantic model instances + + Raises: + ValidationError: If any item cannot be parsed or validated + + Examples: + >>> class Item(BaseModel): + ... name: str + + >>> parse_json_or_model_list('[{"name": "a"}, {"name": "b"}]', Item) + [Item(name='a'), Item(name='b')] + + >>> parse_json_or_model_list([{"name": "a"}, {"name": "b"}], Item) + [Item(name='a'), Item(name='b')] + """ + # Handle None and empty + if value is None or value == "": + return [] + + # Parse to list first + items = parse_json_or_list(value, param_name) + + # Validate each item against the model + validated_items = [] + for i, item in enumerate(items): + try: + if isinstance(item, model_class): + validated_items.append(item) + else: + validated_items.append(model_class.model_validate(item)) + except ValidationError: + logger.error( + "Failed to validate %s[%s] against %s", + param_name, + i, + model_class.__name__, + ) + # Re-raise original validation error + raise + + return validated_items + + +# Pydantic validator decorators for common use cases +def json_or_passthrough_validator( + param_name: str | None = None, strict: bool = False +) -> Callable[[Type[BaseModel], Any, Any], Any]: + """ + Decorator factory for Pydantic field validators that accept JSON or objects. + + This creates a validator that can be used with Pydantic's @field_validator + decorator to automatically parse JSON strings. + + Args: + param_name: Parameter name for error messages (uses field name if None) + strict: Whether to raise errors on parse failures + + Returns: + Validator function compatible with @field_validator + + Example: + >>> class MySchema(BaseModel): + ... config: dict + ... + ... @field_validator('config', mode='before') + ... @classmethod + ... def parse_config(cls, v): + ... return parse_json_or_passthrough(v, 'config') + """ + + def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> Any: + # Use field name from validation info if param_name not provided + field_name = param_name or (info.field_name if info else "field") + return parse_json_or_passthrough(v, field_name, strict) + + return validator + + +def json_or_list_validator( + param_name: str | None = None, item_separator: str = "," +) -> Callable[[Type[BaseModel], Any, Any], List[Any]]: + """ + Decorator factory for Pydantic validators that parse values into lists. + + Args: + param_name: Parameter name for error messages + item_separator: Separator for comma-separated strings + + Returns: + Validator function compatible with @field_validator + + Example: + >>> class MySchema(BaseModel): + ... items: List[str] + ... + ... @field_validator('items', mode='before') + ... @classmethod + ... def parse_items(cls, v): + ... return parse_json_or_list(v, 'items') + """ + + def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> List[Any]: + field_name = param_name or (info.field_name if info else "field") + return parse_json_or_list(v, field_name, item_separator) + + return validator + + +def json_or_model_list_validator( + model_class: Type[BaseModel], param_name: str | None = None +) -> Callable[[Type[BaseModel], Any, Any], List[BaseModel]]: + """ + Decorator factory for Pydantic validators that parse lists of models. + + Args: + model_class: Pydantic model class for list items + param_name: Parameter name for error messages + + Returns: + Validator function compatible with @field_validator + + Example: + >>> class FilterModel(BaseModel): + ... col: str + ... value: str + ... + >>> class MySchema(BaseModel): + ... filters: List[FilterModel] + ... + ... @field_validator('filters', mode='before') + ... @classmethod + ... def parse_filters(cls, v): + ... return parse_json_or_model_list(v, FilterModel, 'filters') + """ + + def validator(cls: Type[BaseModel], v: Any, info: Any = None) -> List[BaseModel]: + field_name = param_name or (info.field_name if info else "field") + return parse_json_or_model_list(v, model_class, field_name) + + return validator + + +def parse_request( + request_class: Type[BaseModel], +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Decorator to handle Claude Code bug where requests are double-serialized as strings. + + Automatically parses string requests to Pydantic models before calling + the tool function. + This eliminates the need for manual parsing code in every tool function. + + See: https://github.com/anthropics/claude-code/issues/5504 + + Args: + request_class: The Pydantic model class for the request + + Returns: + Decorator function that wraps the tool function + + Usage: + @mcp.tool + @mcp_auth_hook + @parse_request(ListChartsRequest) + async def list_charts( + request: ListChartsRequest, ctx: Context + ) -> ChartList: + # Decorator handles string conversion automatically + await ctx.info(f"Listing charts: page={request.page}") + ... + + Note: + - Works with both async and sync functions + - Request must be the first positional argument + - If request is already a model instance, it passes through unchanged + - Handles JSON string parsing with helpful error messages + """ + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any: + # Parse if string, otherwise pass through + # (parse_json_or_model handles both) + parsed_request = parse_json_or_model(request, request_class, "request") + return await func(parsed_request, *args, **kwargs) + + return async_wrapper + else: + + @wraps(func) + def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any: + # Parse if string, otherwise pass through + # (parse_json_or_model handles both) + parsed_request = parse_json_or_model(request, request_class, "request") + return func(parsed_request, *args, **kwargs) + + return sync_wrapper + + return decorator diff --git a/tests/unit_tests/mcp_service/utils/test_schema_utils.py b/tests/unit_tests/mcp_service/utils/test_schema_utils.py new file mode 100644 index 00000000000..2c28f40782b --- /dev/null +++ b/tests/unit_tests/mcp_service/utils/test_schema_utils.py @@ -0,0 +1,515 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for MCP service schema utilities. +""" + +import pytest +from pydantic import BaseModel, ValidationError + +from superset.mcp_service.utils.schema_utils import ( + JSONParseError, + parse_json_or_list, + parse_json_or_model, + parse_json_or_model_list, + parse_json_or_passthrough, + parse_request, +) + + +class TestParseJsonOrPassthrough: + """Test parse_json_or_passthrough function.""" + + def test_parse_valid_json_string(self): + """Should parse valid JSON string to Python object.""" + result = parse_json_or_passthrough('{"key": "value"}', "config") + assert result == {"key": "value"} + + def test_parse_json_array(self): + """Should parse JSON array string to Python list.""" + result = parse_json_or_passthrough("[1, 2, 3]", "numbers") + assert result == [1, 2, 3] + + def test_passthrough_dict(self): + """Should return dict as-is without parsing.""" + input_dict = {"key": "value"} + result = parse_json_or_passthrough(input_dict, "config") + assert result is input_dict + + def test_passthrough_list(self): + """Should return list as-is without parsing.""" + input_list = [1, 2, 3] + result = parse_json_or_passthrough(input_list, "numbers") + assert result is input_list + + def test_passthrough_none(self): + """Should return None as-is.""" + result = parse_json_or_passthrough(None, "value") + assert result is None + + def test_invalid_json_non_strict(self): + """Should return original value when JSON parsing fails (non-strict).""" + result = parse_json_or_passthrough("not valid json", "config", strict=False) + assert result == "not valid json" + + def test_invalid_json_strict(self): + """Should raise JSONParseError when parsing fails (strict mode).""" + with pytest.raises(JSONParseError) as exc_info: + parse_json_or_passthrough("not valid json", "config", strict=True) + + assert exc_info.value.param_name == "config" + assert exc_info.value.value == "not valid json" + + def test_parse_json_number(self): + """Should parse numeric JSON string.""" + result = parse_json_or_passthrough("42", "number") + assert result == 42 + + def test_parse_json_boolean(self): + """Should parse boolean JSON string.""" + result = parse_json_or_passthrough("true", "flag") + assert result is True + + def test_parse_nested_json(self): + """Should parse nested JSON structures.""" + json_str = '{"outer": {"inner": [1, 2, 3]}}' + result = parse_json_or_passthrough(json_str, "nested") + assert result == {"outer": {"inner": [1, 2, 3]}} + + +class TestParseJsonOrList: + """Test parse_json_or_list function.""" + + def test_parse_json_array(self): + """Should parse JSON array string to list.""" + result = parse_json_or_list('["a", "b", "c"]', "items") + assert result == ["a", "b", "c"] + + def test_passthrough_list(self): + """Should return list as-is.""" + input_list = ["a", "b", "c"] + result = parse_json_or_list(input_list, "items") + assert result is input_list + + def test_parse_comma_separated(self): + """Should parse comma-separated string to list.""" + result = parse_json_or_list("a, b, c", "items") + assert result == ["a", "b", "c"] + + def test_parse_comma_separated_with_whitespace(self): + """Should handle whitespace in comma-separated strings.""" + result = parse_json_or_list(" a , b , c ", "items") + assert result == ["a", "b", "c"] + + def test_empty_string_returns_empty_list(self): + """Should return empty list for empty string.""" + result = parse_json_or_list("", "items") + assert result == [] + + def test_none_returns_empty_list(self): + """Should return empty list for None.""" + result = parse_json_or_list(None, "items") + assert result == [] + + def test_single_json_value_wrapped_in_list(self): + """Should wrap single JSON value in list.""" + result = parse_json_or_list('"single"', "items") + assert result == ["single"] + + def test_custom_separator(self): + """Should use custom separator when provided.""" + result = parse_json_or_list("a|b|c", "items", item_separator="|") + assert result == ["a", "b", "c"] + + def test_non_list_wrapped(self): + """Should wrap non-list types in a list.""" + result = parse_json_or_list(42, "items") + assert result == [42] + + def test_parse_empty_items_in_csv(self): + """Should ignore empty items in comma-separated string.""" + result = parse_json_or_list("a,,b,,c", "items") + assert result == ["a", "b", "c"] + + +class TestParseJsonOrModel: + """Test parse_json_or_model function.""" + + class TestModel(BaseModel): + """Test Pydantic model.""" + + name: str + value: int + + def test_parse_json_string(self): + """Should parse JSON string to model instance.""" + result = parse_json_or_model( + '{"name": "test", "value": 42}', self.TestModel, "config" + ) + assert isinstance(result, self.TestModel) + assert result.name == "test" + assert result.value == 42 + + def test_parse_dict(self): + """Should parse dict to model instance.""" + result = parse_json_or_model( + {"name": "test", "value": 42}, self.TestModel, "config" + ) + assert isinstance(result, self.TestModel) + assert result.name == "test" + assert result.value == 42 + + def test_passthrough_model_instance(self): + """Should return model instance as-is.""" + instance = self.TestModel(name="test", value=42) + result = parse_json_or_model(instance, self.TestModel, "config") + assert result is instance + + def test_invalid_json_raises_error(self): + """Should raise JSONParseError for invalid JSON.""" + with pytest.raises(JSONParseError): + parse_json_or_model("not valid json", self.TestModel, "config") + + def test_invalid_model_data_raises_validation_error(self): + """Should raise ValidationError for invalid model data.""" + with pytest.raises(ValidationError): + parse_json_or_model({"name": "test"}, self.TestModel, "config") + + def test_wrong_type_raises_validation_error(self): + """Should raise ValidationError for wrong data types.""" + with pytest.raises(ValidationError): + parse_json_or_model( + {"name": "test", "value": "not_a_number"}, self.TestModel, "config" + ) + + +class TestParseJsonOrModelList: + """Test parse_json_or_model_list function.""" + + class ItemModel(BaseModel): + """Test Pydantic model for list items.""" + + name: str + value: int + + def test_parse_json_array(self): + """Should parse JSON array to list of models.""" + json_str = '[{"name": "a", "value": 1}, {"name": "b", "value": 2}]' + result = parse_json_or_model_list(json_str, self.ItemModel, "items") + assert len(result) == 2 + assert all(isinstance(item, self.ItemModel) for item in result) + assert result[0].name == "a" + assert result[1].value == 2 + + def test_parse_list_of_dicts(self): + """Should parse list of dicts to list of models.""" + input_list = [{"name": "a", "value": 1}, {"name": "b", "value": 2}] + result = parse_json_or_model_list(input_list, self.ItemModel, "items") + assert len(result) == 2 + assert all(isinstance(item, self.ItemModel) for item in result) + + def test_passthrough_list_of_models(self): + """Should return list of models as-is.""" + input_list = [ + self.ItemModel(name="a", value=1), + self.ItemModel(name="b", value=2), + ] + result = parse_json_or_model_list(input_list, self.ItemModel, "items") + assert len(result) == 2 + assert result[0] is input_list[0] + assert result[1] is input_list[1] + + def test_empty_returns_empty_list(self): + """Should return empty list for empty input.""" + assert parse_json_or_model_list(None, self.ItemModel, "items") == [] + assert parse_json_or_model_list("", self.ItemModel, "items") == [] + assert parse_json_or_model_list([], self.ItemModel, "items") == [] + + def test_invalid_item_raises_validation_error(self): + """Should raise ValidationError for invalid item in list.""" + input_list = [{"name": "a", "value": 1}, {"name": "b"}] # Missing value + with pytest.raises(ValidationError): + parse_json_or_model_list(input_list, self.ItemModel, "items") + + def test_mixed_models_and_dicts(self): + """Should handle mixed list of models and dicts.""" + input_list = [ + self.ItemModel(name="a", value=1), + {"name": "b", "value": 2}, + ] + result = parse_json_or_model_list(input_list, self.ItemModel, "items") + assert len(result) == 2 + assert all(isinstance(item, self.ItemModel) for item in result) + + +class TestPydanticIntegration: + """Test integration with Pydantic validators.""" + + def test_field_validator_with_json_string(self): + """Should work with Pydantic field validators for JSON strings.""" + from pydantic import field_validator + + class TestSchema(BaseModel): + """Test schema with field validator.""" + + config: dict + + @field_validator("config", mode="before") + @classmethod + def parse_config(cls, v): + """Parse config from JSON or dict.""" + return parse_json_or_passthrough(v, "config") + + # Test with JSON string + schema = TestSchema.model_validate({"config": '{"key": "value"}'}) + assert schema.config == {"key": "value"} + + # Test with dict + schema = TestSchema.model_validate({"config": {"key": "value"}}) + assert schema.config == {"key": "value"} + + def test_field_validator_with_list(self): + """Should work with Pydantic field validators for lists.""" + from pydantic import field_validator + + class TestSchema(BaseModel): + """Test schema with list field validator.""" + + items: list + + @field_validator("items", mode="before") + @classmethod + def parse_items(cls, v): + """Parse items from various formats.""" + return parse_json_or_list(v, "items") + + # Test with JSON array + schema = TestSchema.model_validate({"items": '["a", "b", "c"]'}) + assert schema.items == ["a", "b", "c"] + + # Test with list + schema = TestSchema.model_validate({"items": ["a", "b", "c"]}) + assert schema.items == ["a", "b", "c"] + + # Test with CSV string + schema = TestSchema.model_validate({"items": "a, b, c"}) + assert schema.items == ["a", "b", "c"] + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_json_object(self): + """Should handle empty JSON objects.""" + result = parse_json_or_passthrough("{}", "config") + assert result == {} + + def test_empty_json_array(self): + """Should handle empty JSON arrays.""" + result = parse_json_or_list("[]", "items") + assert result == [] + + def test_whitespace_only_string(self): + """Should handle whitespace-only strings.""" + result = parse_json_or_list(" ", "items") + assert result == [] + + def test_malformed_json(self): + """Should handle malformed JSON gracefully.""" + result = parse_json_or_passthrough('{"key": invalid}', "config", strict=False) + assert result == '{"key": invalid}' + + def test_unicode_in_json(self): + """Should handle Unicode characters in JSON.""" + result = parse_json_or_passthrough('{"name": "测试"}', "config") + assert result == {"name": "测试"} + + def test_special_characters_in_csv(self): + """Should handle special characters in CSV strings.""" + result = parse_json_or_list("item-1, item_2, item.3", "items") + assert result == ["item-1", "item_2", "item.3"] + + +class TestParseRequestDecorator: + """Test parse_request decorator for MCP tools.""" + + class RequestModel(BaseModel): + """Test request model.""" + + name: str + count: int + + def test_decorator_with_json_string_async(self): + """Should parse JSON string request in async function.""" + + @parse_request(self.RequestModel) + async def async_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + import asyncio + + result = asyncio.run(async_tool('{"name": "test", "count": 5}')) + assert result == "test:5" + + def test_decorator_with_json_string_sync(self): + """Should parse JSON string request in sync function.""" + + @parse_request(self.RequestModel) + def sync_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + result = sync_tool('{"name": "test", "count": 5}') + assert result == "test:5" + + def test_decorator_with_dict_async(self): + """Should handle dict request in async function.""" + + @parse_request(self.RequestModel) + async def async_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + import asyncio + + result = asyncio.run(async_tool({"name": "test", "count": 5})) + assert result == "test:5" + + def test_decorator_with_dict_sync(self): + """Should handle dict request in sync function.""" + + @parse_request(self.RequestModel) + def sync_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + result = sync_tool({"name": "test", "count": 5}) + assert result == "test:5" + + def test_decorator_with_model_instance_async(self): + """Should pass through model instance in async function.""" + + @parse_request(self.RequestModel) + async def async_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + import asyncio + + instance = self.RequestModel(name="test", count=5) + result = asyncio.run(async_tool(instance)) + assert result == "test:5" + + def test_decorator_with_model_instance_sync(self): + """Should pass through model instance in sync function.""" + + @parse_request(self.RequestModel) + def sync_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + instance = self.RequestModel(name="test", count=5) + result = sync_tool(instance) + assert result == "test:5" + + def test_decorator_preserves_function_signature_async(self): + """Should preserve original async function signature.""" + + @parse_request(self.RequestModel) + async def async_tool(request, ctx=None, extra=None): + return f"{request.name}:{request.count}:{extra}" + + import asyncio + + result = asyncio.run( + async_tool('{"name": "test", "count": 5}', ctx=None, extra="data") + ) + assert result == "test:5:data" + + def test_decorator_preserves_function_signature_sync(self): + """Should preserve original sync function signature.""" + + @parse_request(self.RequestModel) + def sync_tool(request, ctx=None, extra=None): + return f"{request.name}:{request.count}:{extra}" + + result = sync_tool('{"name": "test", "count": 5}', ctx=None, extra="data") + assert result == "test:5:data" + + def test_decorator_raises_validation_error_async(self): + """Should raise ValidationError for invalid data in async function.""" + + @parse_request(self.RequestModel) + async def async_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + import asyncio + + with pytest.raises(ValidationError): + asyncio.run(async_tool('{"name": "test"}')) # Missing count + + def test_decorator_raises_validation_error_sync(self): + """Should raise ValidationError for invalid data in sync function.""" + + @parse_request(self.RequestModel) + def sync_tool(request, ctx=None): + return f"{request.name}:{request.count}" + + with pytest.raises(ValidationError): + sync_tool('{"name": "test"}') # Missing count + + def test_decorator_with_complex_model_async(self): + """Should handle complex nested models in async function.""" + + class NestedModel(BaseModel): + """Nested model.""" + + value: int + + class ComplexModel(BaseModel): + """Complex request model.""" + + name: str + nested: NestedModel + + @parse_request(ComplexModel) + async def async_tool(request, ctx=None): + return f"{request.name}:{request.nested.value}" + + import asyncio + + json_str = '{"name": "test", "nested": {"value": 42}}' + result = asyncio.run(async_tool(json_str)) + assert result == "test:42" + + def test_decorator_with_complex_model_sync(self): + """Should handle complex nested models in sync function.""" + + class NestedModel(BaseModel): + """Nested model.""" + + value: int + + class ComplexModel(BaseModel): + """Complex request model.""" + + name: str + nested: NestedModel + + @parse_request(ComplexModel) + def sync_tool(request, ctx=None): + return f"{request.name}:{request.nested.value}" + + json_str = '{"name": "test", "nested": {"value": 42}}' + result = sync_tool(json_str) + assert result == "test:42"