mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
fix(mcp): Improve validation errors and field aliases to reduce failed LLM tool calls (#38625)
This commit is contained in:
committed by
GitHub
parent
56d6bb1913
commit
d91b96814e
@@ -70,7 +70,7 @@ Chart Management:
|
||||
|
||||
SQL Lab Integration:
|
||||
- execute_sql: Execute SQL queries and get results (requires database_id)
|
||||
- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled query
|
||||
- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql
|
||||
|
||||
Schema Discovery:
|
||||
- get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters)
|
||||
@@ -103,7 +103,7 @@ To find your own charts/dashboards:
|
||||
"opr": "eq", "value": current_user.id}}])
|
||||
|
||||
To explore data with SQL:
|
||||
1. get_instance_info -> find database_id
|
||||
1. list_datasets -> find a dataset and note its database_id
|
||||
2. execute_sql(database_id, sql) -> run query
|
||||
3. open_sql_lab_with_context(database_id) -> open SQL Lab UI
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class ExecuteSqlRequest(BaseModel):
|
||||
@@ -119,7 +119,9 @@ class OpenSqlLabRequest(BaseModel):
|
||||
"""Request schema for opening SQL Lab with context."""
|
||||
|
||||
database_connection_id: int = Field(
|
||||
..., description="Database connection ID to use in SQL Lab"
|
||||
...,
|
||||
description="Database connection ID to use in SQL Lab",
|
||||
validation_alias=AliasChoices("database_connection_id", "database_id"),
|
||||
)
|
||||
schema_name: str | None = Field(
|
||||
None, description="Default schema to select in SQL Lab", alias="schema"
|
||||
@@ -127,7 +129,11 @@ class OpenSqlLabRequest(BaseModel):
|
||||
dataset_in_context: str | None = Field(
|
||||
None, description="Dataset name/table to provide as context"
|
||||
)
|
||||
sql: str | None = Field(None, description="SQL query to pre-populate in the editor")
|
||||
sql: str | None = Field(
|
||||
None,
|
||||
description="SQL to pre-populate in the editor",
|
||||
validation_alias=AliasChoices("sql", "query"),
|
||||
)
|
||||
title: str | None = Field(None, description="Title for the SQL Lab tab/query")
|
||||
|
||||
|
||||
|
||||
@@ -190,7 +190,20 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
|
||||
|
||||
if data_stmt is not None and data_stmt.data is not None:
|
||||
# SELECT query - convert DataFrame
|
||||
import pandas as pd
|
||||
|
||||
df = data_stmt.data
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
logger.error(
|
||||
"Expected DataFrame but got %s for statement data",
|
||||
type(df).__name__,
|
||||
)
|
||||
return ExecuteSqlResponse(
|
||||
success=False,
|
||||
error=f"Internal error: unexpected data type ({type(df).__name__})",
|
||||
error_type="data_conversion_error",
|
||||
statements=statements,
|
||||
)
|
||||
rows = df.to_dict(orient="records")
|
||||
columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
|
||||
row_count = len(df)
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"""
|
||||
Open SQL Lab with Context MCP Tool
|
||||
|
||||
Tool for generating SQL Lab URLs with pre-populated query and context.
|
||||
Tool for generating SQL Lab URLs with pre-populated sql and context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -43,9 +43,9 @@ logger = logging.getLogger(__name__)
|
||||
def open_sql_lab_with_context(
|
||||
request: OpenSqlLabRequest, ctx: Context
|
||||
) -> SqlLabResponse:
|
||||
"""Generate SQL Lab URL with pre-populated query and context.
|
||||
"""Generate SQL Lab URL with pre-populated sql and context.
|
||||
|
||||
Returns URL for direct navigation.
|
||||
Pass the sql parameter to pre-fill the editor. Returns URL for direct navigation.
|
||||
"""
|
||||
try:
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
||||
@@ -448,7 +448,25 @@ def parse_request(
|
||||
|
||||
def _maybe_parse(request: Any) -> Any:
|
||||
if _is_parse_request_enabled():
|
||||
return parse_json_or_model(request, request_class, "request")
|
||||
try:
|
||||
return parse_json_or_model(request, request_class, "request")
|
||||
except ValidationError as e:
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
details = []
|
||||
for err in e.errors():
|
||||
field = " -> ".join(str(loc) for loc in err["loc"])
|
||||
details.append(f"{field}: {err['msg']}")
|
||||
required_fields = [
|
||||
f.alias or name
|
||||
for name, f in request_class.model_fields.items()
|
||||
if f.is_required()
|
||||
]
|
||||
raise ToolError(
|
||||
f"Invalid request parameters: {'; '.join(details)}. "
|
||||
f"Required fields for {request_class.__name__}: "
|
||||
f"{', '.join(required_fields)}"
|
||||
) from None
|
||||
return request
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
@@ -481,10 +481,12 @@ class TestParseRequestDecorator:
|
||||
result = sync_tool('{"name": "test", "count": 5}', extra="data")
|
||||
assert result == "test:5:data"
|
||||
|
||||
def test_decorator_raises_validation_error_async(self):
|
||||
"""Should raise ValidationError for invalid data in async function."""
|
||||
def test_decorator_raises_tool_error_for_invalid_data_async(self):
|
||||
"""Should raise ToolError with field details for invalid data."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
async def async_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}"
|
||||
@@ -493,20 +495,22 @@ class TestParseRequestDecorator:
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
with pytest.raises(ValidationError):
|
||||
with pytest.raises(ToolError, match="Required fields for RequestModel"):
|
||||
asyncio.run(async_tool('{"name": "test"}')) # Missing count
|
||||
|
||||
def test_decorator_raises_validation_error_sync(self):
|
||||
"""Should raise ValidationError for invalid data in sync function."""
|
||||
def test_decorator_raises_tool_error_for_invalid_data_sync(self):
|
||||
"""Should raise ToolError with field details for invalid data."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
@parse_request(self.RequestModel)
|
||||
def sync_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
with pytest.raises(ValidationError):
|
||||
with pytest.raises(ToolError, match="Required fields for RequestModel"):
|
||||
sync_tool('{"name": "test"}') # Missing count
|
||||
|
||||
def test_decorator_with_complex_model_async(self):
|
||||
|
||||
Reference in New Issue
Block a user