mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
update: refactor auth and model tools
This commit is contained in:
168
tests/unit_tests/mcp_service/test_model_tools.py
Normal file
168
tests/unit_tests/mcp_service/test_model_tools.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from superset.mcp_service.model_tools import ModelGetInfoTool, ModelListTool
|
||||
|
||||
# Dummy Pydantic output schema
|
||||
class DummyOutputSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
# Dummy list schema
|
||||
class DummyListSchema(BaseModel):
|
||||
items: list
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: list
|
||||
columns_loaded: list
|
||||
filters_applied: list
|
||||
pagination: object
|
||||
timestamp: datetime
|
||||
|
||||
# Dummy error schema
|
||||
class DummyErrorSchema(BaseModel):
|
||||
error: str
|
||||
error_type: str
|
||||
timestamp: datetime
|
||||
|
||||
# Dummy DAO
|
||||
class DummyDAO:
|
||||
@classmethod
|
||||
def list(cls, **kwargs):
|
||||
# Return a list of dummy objects and a total count
|
||||
return [SimpleNamespace(id=1, name="foo"), SimpleNamespace(id=2, name="bar")], 2
|
||||
@classmethod
|
||||
def find_by_id(cls, id):
|
||||
if id == 1:
|
||||
return SimpleNamespace(id=1, name="foo")
|
||||
return None
|
||||
|
||||
def dummy_serializer(obj, columns=None):
|
||||
# Serialize mock object to DummyOutputSchema
|
||||
return DummyOutputSchema(id=obj.id, name=obj.name)
|
||||
|
||||
def test_model_list_tool_basic():
|
||||
tool = ModelListTool(
|
||||
dao_class=DummyDAO,
|
||||
output_schema=DummyOutputSchema,
|
||||
item_serializer=dummy_serializer,
|
||||
filter_type=None,
|
||||
default_columns=["id", "name"],
|
||||
search_columns=["name"],
|
||||
list_field_name="items",
|
||||
output_list_schema=DummyListSchema,
|
||||
)
|
||||
result = tool.run(page=1, page_size=2)
|
||||
assert result.count == 2
|
||||
assert result.total_count == 2
|
||||
assert isinstance(result.items[0], DummyOutputSchema)
|
||||
assert result.page == 1
|
||||
assert result.page_size == 2
|
||||
assert result.total_pages == 1
|
||||
# For page=1, ModelListTool sets has_previous=True
|
||||
assert result.has_previous is True
|
||||
assert result.has_next is False
|
||||
assert result.columns_requested == ["id", "name"]
|
||||
assert set(result.columns_loaded) == {"id", "name"}
|
||||
assert isinstance(result.timestamp, datetime)
|
||||
|
||||
def test_model_list_tool_with_filters_and_columns():
|
||||
tool = ModelListTool(
|
||||
dao_class=DummyDAO,
|
||||
output_schema=DummyOutputSchema,
|
||||
item_serializer=dummy_serializer,
|
||||
filter_type=None,
|
||||
default_columns=["id", "name"],
|
||||
search_columns=["name"],
|
||||
list_field_name="items",
|
||||
output_list_schema=DummyListSchema,
|
||||
)
|
||||
result = tool.run(filters=[{"col": "name", "opr": "eq", "value": "foo"}], select_columns=["id"])
|
||||
assert result.columns_requested == ["id"]
|
||||
assert "id" in result.columns_loaded
|
||||
|
||||
def test_model_list_tool_empty_result():
|
||||
class EmptyDAO:
|
||||
@classmethod
|
||||
def list(cls, **kwargs):
|
||||
return [], 0
|
||||
tool = ModelListTool(
|
||||
dao_class=EmptyDAO,
|
||||
output_schema=DummyOutputSchema,
|
||||
item_serializer=dummy_serializer,
|
||||
filter_type=None,
|
||||
default_columns=["id", "name"],
|
||||
search_columns=["name"],
|
||||
list_field_name="items",
|
||||
output_list_schema=DummyListSchema,
|
||||
)
|
||||
result = tool.run(page=1, page_size=10)
|
||||
assert result.count == 0
|
||||
assert result.total_count == 0
|
||||
assert result.items == []
|
||||
assert result.total_pages == 0
|
||||
assert result.has_next is False
|
||||
# For page=1 and no results, has_previous is True by ModelListTool logic
|
||||
assert result.has_previous is True
|
||||
|
||||
def test_model_list_tool_invalid_filters_json():
|
||||
tool = ModelListTool(
|
||||
dao_class=DummyDAO,
|
||||
output_schema=DummyOutputSchema,
|
||||
item_serializer=dummy_serializer,
|
||||
filter_type=None,
|
||||
default_columns=["id", "name"],
|
||||
search_columns=["name"],
|
||||
list_field_name="items",
|
||||
output_list_schema=DummyListSchema,
|
||||
)
|
||||
# Should parse JSON string filters
|
||||
result = tool.run(filters='[{"col": "name", "opr": "eq", "value": "foo"}]')
|
||||
assert result.count == 2
|
||||
|
||||
def test_model_get_info_tool_found():
|
||||
tool = ModelGetInfoTool(
|
||||
dao_class=DummyDAO,
|
||||
output_schema=DummyOutputSchema,
|
||||
error_schema=DummyErrorSchema,
|
||||
serializer=lambda obj: DummyOutputSchema(id=obj.id, name=obj.name),
|
||||
)
|
||||
result = tool.run(1)
|
||||
assert isinstance(result, DummyOutputSchema)
|
||||
assert result.id == 1
|
||||
assert result.name == "foo"
|
||||
|
||||
def test_model_get_info_tool_not_found():
|
||||
tool = ModelGetInfoTool(
|
||||
dao_class=DummyDAO,
|
||||
output_schema=DummyOutputSchema,
|
||||
error_schema=DummyErrorSchema,
|
||||
serializer=lambda obj: DummyOutputSchema(id=obj.id, name=obj.name),
|
||||
)
|
||||
result = tool.run(999)
|
||||
assert isinstance(result, DummyErrorSchema)
|
||||
assert result.error_type == "not_found"
|
||||
assert "not found" in result.error
|
||||
assert isinstance(result.timestamp, datetime)
|
||||
|
||||
def test_model_get_info_tool_exception():
|
||||
class FailingDAO:
|
||||
@classmethod
|
||||
def find_by_id(cls, id):
|
||||
raise Exception("fail")
|
||||
tool = ModelGetInfoTool(
|
||||
dao_class=FailingDAO,
|
||||
output_schema=DummyOutputSchema,
|
||||
error_schema=DummyErrorSchema,
|
||||
serializer=lambda obj: DummyOutputSchema(id=obj.id, name=obj.name),
|
||||
)
|
||||
with pytest.raises(Exception) as exc:
|
||||
tool.run(1)
|
||||
assert "fail" in str(exc.value)
|
||||
Reference in New Issue
Block a user