fix(mcp): Improve validation errors and field aliases to reduce failed LLM tool calls (#38625)

This commit is contained in:
Kamil Gabryjelski
2026-03-13 11:16:50 +01:00
committed by GitHub
parent 56d6bb1913
commit d91b96814e
6 changed files with 56 additions and 15 deletions

View File

@@ -70,7 +70,7 @@ Chart Management:
 SQL Lab Integration:
 - execute_sql: Execute SQL queries and get results (requires database_id)
-- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled query
+- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql
 Schema Discovery:
 - get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters)
@@ -103,7 +103,7 @@ To find your own charts/dashboards:
"opr": "eq", "value": current_user.id}}]) "opr": "eq", "value": current_user.id}}])
 To explore data with SQL:
-1. get_instance_info -> find database_id
+1. list_datasets -> find a dataset and note its database_id
 2. execute_sql(database_id, sql) -> run query
 3. open_sql_lab_with_context(database_id) -> open SQL Lab UI

View File

@@ -19,7 +19,7 @@
 from typing import Any
-from pydantic import BaseModel, Field, field_validator
+from pydantic import AliasChoices, BaseModel, Field, field_validator
 class ExecuteSqlRequest(BaseModel):
@@ -119,7 +119,9 @@ class OpenSqlLabRequest(BaseModel):
"""Request schema for opening SQL Lab with context.""" """Request schema for opening SQL Lab with context."""
 database_connection_id: int = Field(
-..., description="Database connection ID to use in SQL Lab"
+...,
+description="Database connection ID to use in SQL Lab",
+validation_alias=AliasChoices("database_connection_id", "database_id"),
 )
 schema_name: str | None = Field(
 None, description="Default schema to select in SQL Lab", alias="schema"
@@ -127,7 +129,11 @@ class OpenSqlLabRequest(BaseModel):
 dataset_in_context: str | None = Field(
 None, description="Dataset name/table to provide as context"
 )
-sql: str | None = Field(None, description="SQL query to pre-populate in the editor")
+sql: str | None = Field(
+None,
+description="SQL to pre-populate in the editor",
+validation_alias=AliasChoices("sql", "query"),
+)
 title: str | None = Field(None, description="Title for the SQL Lab tab/query")

View File

@@ -190,7 +190,20 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
 if data_stmt is not None and data_stmt.data is not None:
 # SELECT query - convert DataFrame
+import pandas as pd
 df = data_stmt.data
+if not isinstance(df, pd.DataFrame):
+logger.error(
+"Expected DataFrame but got %s for statement data",
+type(df).__name__,
+)
+return ExecuteSqlResponse(
+success=False,
+error=f"Internal error: unexpected data type ({type(df).__name__})",
+error_type="data_conversion_error",
+statements=statements,
+)
 rows = df.to_dict(orient="records")
 columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns]
 row_count = len(df)

View File

@@ -18,7 +18,7 @@
""" """
 Open SQL Lab with Context MCP Tool
-Tool for generating SQL Lab URLs with pre-populated query and context.
+Tool for generating SQL Lab URLs with pre-populated sql and context.
 """
 import logging
@@ -43,9 +43,9 @@ logger = logging.getLogger(__name__)
 def open_sql_lab_with_context(
 request: OpenSqlLabRequest, ctx: Context
 ) -> SqlLabResponse:
-"""Generate SQL Lab URL with pre-populated query and context.
+"""Generate SQL Lab URL with pre-populated sql and context.
Returns URL for direct navigation. Pass the sql parameter to pre-fill the editor. Returns URL for direct navigation.
""" """
try: try:
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO

View File

@@ -448,7 +448,25 @@ def parse_request(
 def _maybe_parse(request: Any) -> Any:
 if _is_parse_request_enabled():
-return parse_json_or_model(request, request_class, "request")
+try:
+return parse_json_or_model(request, request_class, "request")
+except ValidationError as e:
+from fastmcp.exceptions import ToolError
+details = []
+for err in e.errors():
+field = " -> ".join(str(loc) for loc in err["loc"])
+details.append(f"{field}: {err['msg']}")
+required_fields = [
+f.alias or name
+for name, f in request_class.model_fields.items()
+if f.is_required()
+]
+raise ToolError(
+f"Invalid request parameters: {'; '.join(details)}. "
+f"Required fields for {request_class.__name__}: "
+f"{', '.join(required_fields)}"
+) from None
 return request
 if asyncio.iscoroutinefunction(func):

View File

@@ -481,10 +481,12 @@ class TestParseRequestDecorator:
 result = sync_tool('{"name": "test", "count": 5}', extra="data")
 assert result == "test:5:data"
-def test_decorator_raises_validation_error_async(self):
-"""Should raise ValidationError for invalid data in async function."""
+def test_decorator_raises_tool_error_for_invalid_data_async(self):
+"""Should raise ToolError with field details for invalid data."""
 from unittest.mock import MagicMock, patch
+from fastmcp.exceptions import ToolError
 @parse_request(self.RequestModel)
 async def async_tool(request, ctx=None):
 return f"{request.name}:{request.count}"
@@ -493,20 +495,22 @@ class TestParseRequestDecorator:
 mock_ctx = MagicMock()
 with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
-with pytest.raises(ValidationError):
+with pytest.raises(ToolError, match="Required fields for RequestModel"):
 asyncio.run(async_tool('{"name": "test"}')) # Missing count
-def test_decorator_raises_validation_error_sync(self):
-"""Should raise ValidationError for invalid data in sync function."""
+def test_decorator_raises_tool_error_for_invalid_data_sync(self):
+"""Should raise ToolError with field details for invalid data."""
 from unittest.mock import MagicMock, patch
+from fastmcp.exceptions import ToolError
 @parse_request(self.RequestModel)
 def sync_tool(request, ctx=None):
 return f"{request.name}:{request.count}"
 mock_ctx = MagicMock()
 with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
-with pytest.raises(ValidationError):
+with pytest.raises(ToolError, match="Required fields for RequestModel"):
 sync_tool('{"name": "test"}') # Missing count
 def test_decorator_with_complex_model_async(self):