From ac8d6b0c5323df4b43d96102d8e1264279351a43 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Tue, 24 Mar 2026 05:41:24 -0400 Subject: [PATCH] fix(mcp): prevent encoding errors and fix tool bugs on MCP client transports (#38786) (cherry picked from commit ed3c5280a9340237277156858ea82719e8ba05a5) --- superset/mcp_service/auth.py | 17 ++++++ .../tool/add_chart_to_existing_dashboard.py | 8 +-- .../dashboard/tool/generate_dashboard.py | 27 +++++--- superset/mcp_service/middleware.py | 61 ++++++++++++++++++- superset/mcp_service/server.py | 9 ++- .../mcp_service/sql_lab/tool/execute_sql.py | 51 +++++++++++++--- .../sql_lab/tool/save_sql_query.py | 6 +- superset/mcp_service/utils/schema_utils.py | 24 +++++++- superset/sql/execution/executor.py | 13 +++- .../tool/test_dashboard_generation.py | 2 + .../sql_lab/tool/test_save_sql_query.py | 6 ++ 11 files changed, 193 insertions(+), 31 deletions(-) diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py index 77544f224cb..41484d444cf 100644 --- a/superset/mcp_service/auth.py +++ b/superset/mcp_service/auth.py @@ -277,6 +277,23 @@ def _setup_user_context() -> User | None: logger.debug("No Flask app context available for user setup") return None raise + except ValueError as e: + # JWT user resolution failed (e.g. SAML subject not in DB). + # If middleware already set g.user (request context exists), + # use that instead of failing closed. + from flask import has_request_context + + if has_request_context() and hasattr(g, "user") and g.user: + logger.warning( + "JWT user resolution failed (%s), using middleware-provided g.user=%s", + e, + g.user.username, + ) + # Assign to local so relationship validation below runs + # (same as the normal path) to prevent detached instance errors. + user = g.user + else: + raise # Validate user has necessary relationships loaded # (Force access to ensure they're loaded if lazy) diff --git a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py index 93888701451..f9ce4f319a0 100644 --- a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py +++ b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py @@ -442,12 +442,8 @@ def add_chart_to_existing_dashboard( published=updated_dashboard.published, created_on=updated_dashboard.created_on, changed_on=updated_dashboard.changed_on, - created_by=updated_dashboard.created_by.username - if updated_dashboard.created_by - else None, - changed_by=updated_dashboard.changed_by.username - if updated_dashboard.changed_by - else None, + created_by=updated_dashboard.created_by_name or None, + changed_by=updated_dashboard.changed_by_name or None, uuid=str(updated_dashboard.uuid) if updated_dashboard.uuid else None, url=f"{get_superset_base_url()}/superset/dashboard/{updated_dashboard.id}/", chart_count=len(updated_dashboard.slices), diff --git a/superset/mcp_service/dashboard/tool/generate_dashboard.py b/superset/mcp_service/dashboard/tool/generate_dashboard.py index 3cddf6d4c5e..1a25e1120b2 100644 --- a/superset/mcp_service/dashboard/tool/generate_dashboard.py +++ b/superset/mcp_service/dashboard/tool/generate_dashboard.py @@ -237,10 +237,8 @@ def generate_dashboard( # Prepare dashboard data and create dashboard with event_logger.log_context(action="mcp.generate_dashboard.db_write"): - dashboard_data = { + dashboard_data: Dict[str, Any] = { "dashboard_title": dashboard_title, - "slug": None, # Let Superset auto-generate slug - "css": "", "json_metadata": json.dumps( { "filter_scopes": {}, @@ -271,8 +269,23 @@ def generate_dashboard( dashboard_data["description"] = request.description # Create the dashboard using Superset's command pattern - command = CreateDashboardCommand(dashboard_data) - dashboard = command.run() + try: + command = CreateDashboardCommand(dashboard_data) + dashboard = command.run() + except Exception as cmd_err: + # Surface the root cause from @transaction's error wrapping + root_cause = cmd_err.__cause__ or cmd_err + logger.error( + "CreateDashboardCommand failed: %s (cause: %s)", + cmd_err, + root_cause, + exc_info=True, + ) + return GenerateDashboardResponse( + dashboard=None, + dashboard_url=None, + error=f"Failed to create dashboard: {root_cause}", + ) # Re-fetch the dashboard with eager-loaded relationships to avoid # "Instance is not bound to a Session" errors when serializing @@ -309,8 +322,8 @@ def generate_dashboard( published=dashboard.published, created_on=dashboard.created_on, changed_on=dashboard.changed_on, - created_by=dashboard.created_by.username if dashboard.created_by else None, - changed_by=dashboard.changed_by.username if dashboard.changed_by else None, + created_by=dashboard.created_by_name or None, + changed_by=dashboard.changed_by_name or None, uuid=str(dashboard.uuid) if dashboard.uuid else None, url=f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/", chart_count=len(request.chart_ids), diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index 0fadd648ae8..b0ffc4f5f7c 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -18,10 +18,13 @@ import logging import time from collections import defaultdict -from typing import Any, Awaitable, Callable, Dict, Protocol +from typing import Any, Awaitable, Callable, Dict, Protocol, Sequence +import mcp.types as mt from fastmcp.exceptions import ToolError from fastmcp.server.middleware import Middleware, MiddlewareContext +from fastmcp.server.middleware.middleware import CallNext +from fastmcp.tools.tool import Tool, ToolResult from flask import has_app_context from pydantic import ValidationError from sqlalchemy.exc import OperationalError, TimeoutError @@ -259,6 +262,62 @@ class PrivateToolMiddleware(Middleware): return await call_next(context) +class StructuredContentStripperMiddleware(Middleware): + """Strip ``outputSchema`` and ``structured_content`` to prevent encoding errors. + + FastMCP 3.x auto-generates ``outputSchema`` in tool definitions + (``tools/list``) and ``structuredContent`` in tool call responses + (``tools/call``) when the tool has a typed return annotation. + + Some MCP client transports (e.g. Claude.ai's MCP bridge) cannot handle + ``structuredContent`` dicts, causing ``TypeError: encoding without a + string argument``. Additionally, if ``outputSchema`` is advertised but + ``structuredContent`` is stripped from the response, clients may raise + ``Output validation error: outputSchema defined but no structured output + returned``. + + This middleware handles both sides: + - ``on_list_tools``: removes ``output_schema`` from every tool definition + - ``on_call_tool``: removes ``structured_content`` from every tool result + """ + + async def on_list_tools( + self, + context: MiddlewareContext[mt.ListToolsRequest], + call_next: CallNext[mt.ListToolsRequest, Sequence[Tool]], + ) -> Sequence[Tool]: + tools = await call_next(context) + return [ + t.model_copy(update={"output_schema": None}) + if t.output_schema is not None + else t + for t in tools + ] + + async def on_call_tool( + self, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: Callable[[MiddlewareContext], Awaitable[ToolResult]], + ) -> ToolResult: + try: + result = await call_next(context) + except Exception as e: + # When exceptions propagate past the middleware chain to the + # MCP SDK layer, they become CallToolResult(isError=True). + # Some transports (Claude.ai's MCP bridge) cannot encode these + # error responses, producing "encoding without a string argument". + # Catch ALL exceptions (not just specific types) because any + # unhandled exception — including ToolError from + # GlobalErrorHandlerMiddleware, ValueError, TypeError, etc. — + # will cause encoding failures on the wire. + return ToolResult( + content=[mt.TextContent(type="text", text=f"Error: {e}")], + ) + if isinstance(result, ToolResult) and result.structured_content is not None: + result = ToolResult(content=result.content, meta=result.meta) + return result + + class GlobalErrorHandlerMiddleware(Middleware): """ Global error handler middleware that provides consistent error responses diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 73a078e7dda..03f089ddff0 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -39,6 +39,7 @@ from superset.mcp_service.middleware import ( create_response_size_guard_middleware, GlobalErrorHandlerMiddleware, LoggingMiddleware, + StructuredContentStripperMiddleware, ) from superset.mcp_service.storage import _create_redis_store from superset.utils import json @@ -436,9 +437,15 @@ def run_server( # Add logging middleware (logs all tool calls with duration tracking) middleware_list.append(LoggingMiddleware()) - # Add global error handler (outermost – catches all exceptions) + # Add global error handler (catches all exceptions, raises ToolError) middleware_list.append(GlobalErrorHandlerMiddleware()) + # Strip outputSchema from tool definitions and structuredContent from + # tool responses to prevent encoding errors on Claude.ai's MCP bridge. + # MUST be outermost so it catches ToolError from GlobalErrorHandler + # and converts to plain text before the MCP SDK tries to encode it. + middleware_list.append(StructuredContentStripperMiddleware()) + mcp_instance = init_fastmcp_server( auth=auth_provider, middleware=middleware_list or None, diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py index 61f00889200..8617bbcf9f1 100644 --- a/superset/mcp_service/sql_lab/tool/execute_sql.py +++ b/superset/mcp_service/sql_lab/tool/execute_sql.py @@ -28,6 +28,7 @@ import logging from decimal import Decimal from typing import Any +import pandas as pd from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations from superset_core.queries.types import ( @@ -176,6 +177,46 @@ def _sanitize_row_values(rows: list[dict[str, Any]]) -> None: row[key] = str(value) +def _data_to_statement_data(data: Any) -> StatementData: + """Convert statement data (DataFrame, list, dict, bytes) to StatementData. + + When results come from cache, data may be a dict/list/bytes instead of + a pandas DataFrame. This function handles all cases defensively. + """ + from superset.utils import json as json_utils + + if isinstance(data, list): + rows_data = data + elif isinstance(data, dict): + rows_data = data.get("data", [data]) + if not isinstance(rows_data, list): + rows_data = [rows_data] + elif isinstance(data, pd.DataFrame): + rows_data = data.to_dict(orient="records") + _sanitize_row_values(rows_data) + return StatementData( + rows=rows_data, + columns=[ + ColumnInfo(name=col, type=str(data[col].dtype)) for col in data.columns + ], + ) + elif isinstance(data, bytes): + try: + decoded = json_utils.loads(data) + rows_data = decoded if isinstance(decoded, list) else [decoded] + except (ValueError, UnicodeDecodeError): + rows_data = [] + else: + rows_data = [{"value": str(data)}] + + _sanitize_row_values(rows_data) + col_names = list(rows_data[0].keys()) if rows_data else [] + return StatementData( + rows=rows_data, + columns=[ColumnInfo(name=col, type="object") for col in col_names], + ) + + def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse: """Convert QueryResult to ExecuteSqlResponse.""" if result.status != QueryStatus.SUCCESS: @@ -193,15 +234,7 @@ def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse: for stmt in result.statements: stmt_data: StatementData | None = None if stmt.data is not None: - df = stmt.data - rows_data = df.to_dict(orient="records") - _sanitize_row_values(rows_data) - stmt_data = StatementData( - rows=rows_data, - columns=[ - ColumnInfo(name=col, type=str(df[col].dtype)) for col in df.columns - ], - ) + stmt_data = _data_to_statement_data(stmt.data) data_bearing_count += 1 statements.append( diff --git a/superset/mcp_service/sql_lab/tool/save_sql_query.py b/superset/mcp_service/sql_lab/tool/save_sql_query.py index a18a50483ab..66a8a5e99f8 100644 --- a/superset/mcp_service/sql_lab/tool/save_sql_query.py +++ b/superset/mcp_service/sql_lab/tool/save_sql_query.py @@ -126,10 +126,10 @@ async def save_sql_query( id=saved_query.id, label=saved_query.label, sql=saved_query.sql, - database_id=request.database_id, - schema_name=request.schema_name, + database_id=saved_query.db_id, + schema_name=saved_query.schema or None, catalog=getattr(saved_query, "catalog", None), - description=request.description, + description=saved_query.description or None, url=saved_query_url, ) diff --git a/superset/mcp_service/utils/schema_utils.py b/superset/mcp_service/utils/schema_utils.py index 5a57ac63d13..d3d787ab339 100644 --- a/superset/mcp_service/utils/schema_utils.py +++ b/superset/mcp_service/utils/schema_utils.py @@ -467,7 +467,18 @@ def _create_string_parsing_wrapper( """ import types + # Detect empty request models (no required fields) so + # FastMCP doesn't reject calls without arguments + # (e.g. get_instance_info via call_tool proxy). + _model_has_no_required_fields = not any( + f.is_required() for f in request_class.model_fields.values() + ) + def _maybe_parse(request: Any) -> Any: + # Auto-instantiate empty request models when no arguments provided + if (request is None or request == "{}") and _model_has_no_required_fields: + return request_class() + if _is_parse_request_enabled(): try: return parse_json_or_model(request, request_class, "request") @@ -539,7 +550,12 @@ def _create_string_parsing_wrapper( new_wrapper.__doc__ = func.__doc__ request_annotation = str | request_class - _apply_signature_for_fastmcp(new_wrapper, func, request_annotation) + _apply_signature_for_fastmcp( + new_wrapper, + func, + request_annotation, + request_default="{}" if _model_has_no_required_fields else None, + ) return new_wrapper @@ -678,6 +694,7 @@ def _apply_signature_for_fastmcp( wrapper: Any, original_func: Callable[..., Any], request_annotation: Any, + request_default: Any = None, ) -> None: """Apply annotations and signature to wrapper, stripping ctx for FastMCP. @@ -705,7 +722,10 @@ def _apply_signature_for_fastmcp( if _is_context_param(param, name, FMContext): continue if name == "request": - new_params.append(param.replace(annotation=request_annotation)) + replacement = {"annotation": request_annotation} + if request_default is not None: + replacement["default"] = request_default + new_params.append(param.replace(**replacement)) else: new_params.append(param) wrapper.__signature__ = orig_sig.replace(parameters=new_params) diff --git a/superset/sql/execution/executor.py b/superset/sql/execution/executor.py index 87d9fa671c5..90be63b5d2a 100644 --- a/superset/sql/execution/executor.py +++ b/superset/sql/execution/executor.py @@ -839,13 +839,22 @@ class SQLExecutor: or app.config.get("CACHE_DEFAULT_TIMEOUT", 300) ) - # Serialize statement results for caching + # Serialize statement results for caching. + # Convert DataFrames to list-of-dicts so the cache backend + # does not need to pickle pandas objects (which can fail to + # deserialize correctly with some backends or pandas versions). + import pandas as pd + cached_data = { "statements": [ { "original_sql": stmt.original_sql, "executed_sql": stmt.executed_sql, - "data": stmt.data, + "data": ( + stmt.data.to_dict(orient="records") + if isinstance(stmt.data, pd.DataFrame) + else stmt.data + ), "row_count": stmt.row_count, "execution_time_ms": stmt.execution_time_ms, } diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py index 7d299d32ef1..0f07432eea1 100644 --- a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py @@ -96,6 +96,8 @@ def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock: dashboard.created_by.username = "test_user" dashboard.changed_by = Mock() dashboard.changed_by.username = "test_user" + dashboard.created_by_name = "test_user" + dashboard.changed_by_name = "test_user" dashboard.uuid = f"dashboard-uuid-{id}" dashboard.slices = [] dashboard.owners = [] diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py index 469ca9fd43c..53b242c440e 100644 --- a/tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py +++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py @@ -248,6 +248,9 @@ class TestSaveSqlQueryToolLogic: mock_sq.id = 42 mock_sq.label = "Revenue Query" mock_sq.sql = "SELECT SUM(revenue) FROM sales" + mock_sq.db_id = 1 + mock_sq.schema = "" + mock_sq.description = "" mock_sq.catalog = None request = SaveSqlQueryRequest( @@ -412,6 +415,9 @@ class TestSaveSqlQueryToolLogic: mock_sq.id = 10 mock_sq.label = "Test" mock_sq.sql = "SELECT 1" + mock_sq.db_id = 1 + mock_sq.schema = "public" + mock_sq.description = "" mock_sq.catalog = None request = SaveSqlQueryRequest(