Files
superset2/tests/unit_tests/mcp_service/utils/test_schema_utils.py

567 lines
21 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP service schema utilities.
"""
import pytest
from pydantic import BaseModel, ValidationError
from superset.mcp_service.utils.schema_utils import (
JSONParseError,
parse_json_or_list,
parse_json_or_model,
parse_json_or_model_list,
parse_json_or_passthrough,
parse_request,
)
class TestParseJsonOrPassthrough:
"""Test parse_json_or_passthrough function."""
def test_parse_valid_json_string(self):
"""Should parse valid JSON string to Python object."""
result = parse_json_or_passthrough('{"key": "value"}', "config")
assert result == {"key": "value"}
def test_parse_json_array(self):
"""Should parse JSON array string to Python list."""
result = parse_json_or_passthrough("[1, 2, 3]", "numbers")
assert result == [1, 2, 3]
def test_passthrough_dict(self):
"""Should return dict as-is without parsing."""
input_dict = {"key": "value"}
result = parse_json_or_passthrough(input_dict, "config")
assert result is input_dict
def test_passthrough_list(self):
"""Should return list as-is without parsing."""
input_list = [1, 2, 3]
result = parse_json_or_passthrough(input_list, "numbers")
assert result is input_list
def test_passthrough_none(self):
"""Should return None as-is."""
result = parse_json_or_passthrough(None, "value")
assert result is None
def test_invalid_json_non_strict(self):
"""Should return original value when JSON parsing fails (non-strict)."""
result = parse_json_or_passthrough("not valid json", "config", strict=False)
assert result == "not valid json"
def test_invalid_json_strict(self):
"""Should raise JSONParseError when parsing fails (strict mode)."""
with pytest.raises(JSONParseError) as exc_info:
parse_json_or_passthrough("not valid json", "config", strict=True)
assert exc_info.value.param_name == "config"
assert exc_info.value.value == "not valid json"
def test_parse_json_number(self):
"""Should parse numeric JSON string."""
result = parse_json_or_passthrough("42", "number")
assert result == 42
def test_parse_json_boolean(self):
"""Should parse boolean JSON string."""
result = parse_json_or_passthrough("true", "flag")
assert result is True
def test_parse_nested_json(self):
"""Should parse nested JSON structures."""
json_str = '{"outer": {"inner": [1, 2, 3]}}'
result = parse_json_or_passthrough(json_str, "nested")
assert result == {"outer": {"inner": [1, 2, 3]}}
class TestParseJsonOrList:
"""Test parse_json_or_list function."""
def test_parse_json_array(self):
"""Should parse JSON array string to list."""
result = parse_json_or_list('["a", "b", "c"]', "items")
assert result == ["a", "b", "c"]
def test_passthrough_list(self):
"""Should return list as-is."""
input_list = ["a", "b", "c"]
result = parse_json_or_list(input_list, "items")
assert result is input_list
def test_parse_comma_separated(self):
"""Should parse comma-separated string to list."""
result = parse_json_or_list("a, b, c", "items")
assert result == ["a", "b", "c"]
def test_parse_comma_separated_with_whitespace(self):
"""Should handle whitespace in comma-separated strings."""
result = parse_json_or_list(" a , b , c ", "items")
assert result == ["a", "b", "c"]
def test_empty_string_returns_empty_list(self):
"""Should return empty list for empty string."""
result = parse_json_or_list("", "items")
assert result == []
def test_none_returns_empty_list(self):
"""Should return empty list for None."""
result = parse_json_or_list(None, "items")
assert result == []
def test_single_json_value_wrapped_in_list(self):
"""Should wrap single JSON value in list."""
result = parse_json_or_list('"single"', "items")
assert result == ["single"]
def test_custom_separator(self):
"""Should use custom separator when provided."""
result = parse_json_or_list("a|b|c", "items", item_separator="|")
assert result == ["a", "b", "c"]
def test_non_list_wrapped(self):
"""Should wrap non-list types in a list."""
result = parse_json_or_list(42, "items")
assert result == [42]
def test_parse_empty_items_in_csv(self):
"""Should ignore empty items in comma-separated string."""
result = parse_json_or_list("a,,b,,c", "items")
assert result == ["a", "b", "c"]
class TestParseJsonOrModel:
"""Test parse_json_or_model function."""
class TestModel(BaseModel):
"""Test Pydantic model."""
name: str
value: int
def test_parse_json_string(self):
"""Should parse JSON string to model instance."""
result = parse_json_or_model(
'{"name": "test", "value": 42}', self.TestModel, "config"
)
assert isinstance(result, self.TestModel)
assert result.name == "test"
assert result.value == 42
def test_parse_dict(self):
"""Should parse dict to model instance."""
result = parse_json_or_model(
{"name": "test", "value": 42}, self.TestModel, "config"
)
assert isinstance(result, self.TestModel)
assert result.name == "test"
assert result.value == 42
def test_passthrough_model_instance(self):
"""Should return model instance as-is."""
instance = self.TestModel(name="test", value=42)
result = parse_json_or_model(instance, self.TestModel, "config")
assert result is instance
def test_invalid_json_raises_error(self):
"""Should raise JSONParseError for invalid JSON."""
with pytest.raises(JSONParseError):
parse_json_or_model("not valid json", self.TestModel, "config")
def test_invalid_model_data_raises_validation_error(self):
"""Should raise ValidationError for invalid model data."""
with pytest.raises(ValidationError):
parse_json_or_model({"name": "test"}, self.TestModel, "config")
def test_wrong_type_raises_validation_error(self):
"""Should raise ValidationError for wrong data types."""
with pytest.raises(ValidationError):
parse_json_or_model(
{"name": "test", "value": "not_a_number"}, self.TestModel, "config"
)
class TestParseJsonOrModelList:
"""Test parse_json_or_model_list function."""
class ItemModel(BaseModel):
"""Test Pydantic model for list items."""
name: str
value: int
def test_parse_json_array(self):
"""Should parse JSON array to list of models."""
json_str = '[{"name": "a", "value": 1}, {"name": "b", "value": 2}]'
result = parse_json_or_model_list(json_str, self.ItemModel, "items")
assert len(result) == 2
assert all(isinstance(item, self.ItemModel) for item in result)
assert result[0].name == "a"
assert result[1].value == 2
def test_parse_list_of_dicts(self):
"""Should parse list of dicts to list of models."""
input_list = [{"name": "a", "value": 1}, {"name": "b", "value": 2}]
result = parse_json_or_model_list(input_list, self.ItemModel, "items")
assert len(result) == 2
assert all(isinstance(item, self.ItemModel) for item in result)
def test_passthrough_list_of_models(self):
"""Should return list of models as-is."""
input_list = [
self.ItemModel(name="a", value=1),
self.ItemModel(name="b", value=2),
]
result = parse_json_or_model_list(input_list, self.ItemModel, "items")
assert len(result) == 2
assert result[0] is input_list[0]
assert result[1] is input_list[1]
def test_empty_returns_empty_list(self):
"""Should return empty list for empty input."""
assert parse_json_or_model_list(None, self.ItemModel, "items") == []
assert parse_json_or_model_list("", self.ItemModel, "items") == []
assert parse_json_or_model_list([], self.ItemModel, "items") == []
def test_invalid_item_raises_validation_error(self):
"""Should raise ValidationError for invalid item in list."""
input_list = [{"name": "a", "value": 1}, {"name": "b"}] # Missing value
with pytest.raises(ValidationError):
parse_json_or_model_list(input_list, self.ItemModel, "items")
def test_mixed_models_and_dicts(self):
"""Should handle mixed list of models and dicts."""
input_list = [
self.ItemModel(name="a", value=1),
{"name": "b", "value": 2},
]
result = parse_json_or_model_list(input_list, self.ItemModel, "items")
assert len(result) == 2
assert all(isinstance(item, self.ItemModel) for item in result)
class TestPydanticIntegration:
"""Test integration with Pydantic validators."""
def test_field_validator_with_json_string(self):
"""Should work with Pydantic field validators for JSON strings."""
from pydantic import field_validator
class TestSchema(BaseModel):
"""Test schema with field validator."""
config: dict
@field_validator("config", mode="before")
@classmethod
def parse_config(cls, v):
"""Parse config from JSON or dict."""
return parse_json_or_passthrough(v, "config")
# Test with JSON string
schema = TestSchema.model_validate({"config": '{"key": "value"}'})
assert schema.config == {"key": "value"}
# Test with dict
schema = TestSchema.model_validate({"config": {"key": "value"}})
assert schema.config == {"key": "value"}
def test_field_validator_with_list(self):
"""Should work with Pydantic field validators for lists."""
from pydantic import field_validator
class TestSchema(BaseModel):
"""Test schema with list field validator."""
items: list
@field_validator("items", mode="before")
@classmethod
def parse_items(cls, v):
"""Parse items from various formats."""
return parse_json_or_list(v, "items")
# Test with JSON array
schema = TestSchema.model_validate({"items": '["a", "b", "c"]'})
assert schema.items == ["a", "b", "c"]
# Test with list
schema = TestSchema.model_validate({"items": ["a", "b", "c"]})
assert schema.items == ["a", "b", "c"]
# Test with CSV string
schema = TestSchema.model_validate({"items": "a, b, c"})
assert schema.items == ["a", "b", "c"]
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_empty_json_object(self):
"""Should handle empty JSON objects."""
result = parse_json_or_passthrough("{}", "config")
assert result == {}
def test_empty_json_array(self):
"""Should handle empty JSON arrays."""
result = parse_json_or_list("[]", "items")
assert result == []
def test_whitespace_only_string(self):
"""Should handle whitespace-only strings."""
result = parse_json_or_list(" ", "items")
assert result == []
def test_malformed_json(self):
"""Should handle malformed JSON gracefully."""
result = parse_json_or_passthrough('{"key": invalid}', "config", strict=False)
assert result == '{"key": invalid}'
def test_unicode_in_json(self):
"""Should handle Unicode characters in JSON."""
result = parse_json_or_passthrough('{"name": "测试"}', "config")
assert result == {"name": "测试"}
def test_special_characters_in_csv(self):
"""Should handle special characters in CSV strings."""
result = parse_json_or_list("item-1, item_2, item.3", "items")
assert result == ["item-1", "item_2", "item.3"]
class TestParseRequestDecorator:
"""Test parse_request decorator for MCP tools."""
class RequestModel(BaseModel):
"""Test request model."""
name: str
count: int
@pytest.fixture(autouse=True)
def _enable_parse_request(self):
"""Ensure MCP_PARSE_REQUEST_ENABLED=True for all parsing tests."""
from unittest.mock import patch
with patch(
"superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
return_value=True,
):
yield
def test_decorator_with_json_string_async(self):
"""Should parse JSON string request in async function."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
async def async_tool(request, ctx=None):
return f"{request.name}:{request.count}"
import asyncio
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(async_tool('{"name": "test", "count": 5}'))
assert result == "test:5"
def test_decorator_with_json_string_sync(self):
"""Should parse JSON string request in sync function."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
def sync_tool(request, ctx=None):
return f"{request.name}:{request.count}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = sync_tool('{"name": "test", "count": 5}')
assert result == "test:5"
def test_decorator_with_dict_async(self):
"""Should handle dict request in async function."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
async def async_tool(request, ctx=None):
return f"{request.name}:{request.count}"
import asyncio
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(async_tool({"name": "test", "count": 5}))
assert result == "test:5"
def test_decorator_with_dict_sync(self):
"""Should handle dict request in sync function."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
def sync_tool(request, ctx=None):
return f"{request.name}:{request.count}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = sync_tool({"name": "test", "count": 5})
assert result == "test:5"
def test_decorator_with_model_instance_async(self):
"""Should pass through model instance in async function."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
async def async_tool(request, ctx=None):
return f"{request.name}:{request.count}"
import asyncio
mock_ctx = MagicMock()
instance = self.RequestModel(name="test", count=5)
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(async_tool(instance))
assert result == "test:5"
def test_decorator_with_model_instance_sync(self):
"""Should pass through model instance in sync function."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
def sync_tool(request, ctx=None):
return f"{request.name}:{request.count}"
mock_ctx = MagicMock()
instance = self.RequestModel(name="test", count=5)
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = sync_tool(instance)
assert result == "test:5"
def test_decorator_preserves_function_signature_async(self):
"""Should preserve original async function signature."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
async def async_tool(request, ctx=None, extra=None):
return f"{request.name}:{request.count}:{extra}"
import asyncio
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(
async_tool('{"name": "test", "count": 5}', extra="data")
)
assert result == "test:5:data"
def test_decorator_preserves_function_signature_sync(self):
"""Should preserve original sync function signature."""
from unittest.mock import MagicMock, patch
@parse_request(self.RequestModel)
def sync_tool(request, ctx=None, extra=None):
return f"{request.name}:{request.count}:{extra}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = sync_tool('{"name": "test", "count": 5}', extra="data")
assert result == "test:5:data"
def test_decorator_raises_tool_error_for_invalid_data_async(self):
"""Should raise ToolError with field details for invalid data."""
from unittest.mock import MagicMock, patch
from fastmcp.exceptions import ToolError
@parse_request(self.RequestModel)
async def async_tool(request, ctx=None):
return f"{request.name}:{request.count}"
import asyncio
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
with pytest.raises(ToolError, match="Required fields for RequestModel"):
asyncio.run(async_tool('{"name": "test"}')) # Missing count
def test_decorator_raises_tool_error_for_invalid_data_sync(self):
"""Should raise ToolError with field details for invalid data."""
from unittest.mock import MagicMock, patch
from fastmcp.exceptions import ToolError
@parse_request(self.RequestModel)
def sync_tool(request, ctx=None):
return f"{request.name}:{request.count}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
with pytest.raises(ToolError, match="Required fields for RequestModel"):
sync_tool('{"name": "test"}') # Missing count
def test_decorator_with_complex_model_async(self):
"""Should handle complex nested models in async function."""
from unittest.mock import MagicMock, patch
class NestedModel(BaseModel):
"""Nested model."""
value: int
class ComplexModel(BaseModel):
"""Complex request model."""
name: str
nested: NestedModel
@parse_request(ComplexModel)
async def async_tool(request, ctx=None):
return f"{request.name}:{request.nested.value}"
import asyncio
mock_ctx = MagicMock()
json_str = '{"name": "test", "nested": {"value": 42}}'
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(async_tool(json_str))
assert result == "test:42"
def test_decorator_with_complex_model_sync(self):
"""Should handle complex nested models in sync function."""
from unittest.mock import MagicMock, patch
class NestedModel(BaseModel):
"""Nested model."""
value: int
class ComplexModel(BaseModel):
"""Complex request model."""
name: str
nested: NestedModel
@parse_request(ComplexModel)
def sync_tool(request, ctx=None):
return f"{request.name}:{request.nested.value}"
mock_ctx = MagicMock()
json_str = '{"name": "test", "nested": {"value": 42}}'
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = sync_tool(json_str)
assert result == "test:42"