diff --git a/superset/mcp_service/__main__.py b/superset/mcp_service/__main__.py index 0da7df6655e..30759c6e805 100644 --- a/superset/mcp_service/__main__.py +++ b/superset/mcp_service/__main__.py @@ -44,6 +44,36 @@ if os.environ.get("FASTMCP_TRANSPORT", "stdio") == "stdio": from superset.mcp_service.app import init_fastmcp_server, mcp +def _add_default_middlewares() -> None: + """Add the standard middleware stack to the MCP instance. + + This ensures all entry points (stdio, streamable-http, etc.) get + the same protection middlewares that the Flask CLI and server.py add. + Order is innermost → outermost (last-added wraps everything). + """ + from superset.mcp_service.middleware import ( + create_response_size_guard_middleware, + GlobalErrorHandlerMiddleware, + LoggingMiddleware, + StructuredContentStripperMiddleware, + ) + + # Response size guard (innermost among these) + if size_guard := create_response_size_guard_middleware(): + mcp.add_middleware(size_guard) + limit = size_guard.token_limit + sys.stderr.write(f"[MCP] Response size guard enabled (token_limit={limit})\n") + + # Logging + mcp.add_middleware(LoggingMiddleware()) + + # Global error handler + mcp.add_middleware(GlobalErrorHandlerMiddleware()) + + # Structured content stripper (must be outermost) + mcp.add_middleware(StructuredContentStripperMiddleware()) + + def main() -> None: """ Run the MCP service in stdio mode with proper output suppression. @@ -97,6 +127,7 @@ def main() -> None: # Initialize the FastMCP server # Disable auth config for stdio mode to avoid Flask app output init_fastmcp_server() + _add_default_middlewares() # Log captured output to stderr for debugging (optional) captured = captured_output.getvalue() @@ -118,6 +149,7 @@ def main() -> None: else: # For other transports, use normal initialization init_fastmcp_server() + _add_default_middlewares() # Run with specified transport if transport == "streamable-http": diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 8914660c4fa..2da6c3ef804 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -46,6 +46,7 @@ from superset.mcp_service.common.cache_schemas import ( QueryCacheControl, ) from superset.mcp_service.common.error_schemas import ChartGenerationError +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.system.schemas import ( PaginationInfo, serialize_user_object, @@ -1094,7 +1095,13 @@ class ListChartsRequest(MetadataCacheControl): Field(default=1, description="Page number for pagination (1-based)"), ] page_size: Annotated[ - PositiveInt, Field(default=10, description="Number of items per page") + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), ] @model_validator(mode="after") diff --git a/superset/mcp_service/constants.py b/superset/mcp_service/constants.py index 7abf91147a8..a23a7949e94 100644 --- a/superset/mcp_service/constants.py +++ b/superset/mcp_service/constants.py @@ -16,6 +16,10 @@ # under the License. """Constants for the MCP service.""" +# Pagination defaults +DEFAULT_PAGE_SIZE = 10 # Default number of items per page +MAX_PAGE_SIZE = 100 # Maximum allowed page_size to prevent oversized responses + # Response size guard defaults DEFAULT_TOKEN_LIMIT = 25_000 # ~25k tokens prevents overwhelming LLM context windows DEFAULT_WARN_THRESHOLD_PCT = 80 # Log warnings above 80% of limit diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index 97661d1c794..99d42b2fa35 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -84,6 +84,7 @@ if TYPE_CHECKING: from superset.daos.base import ColumnOperator, ColumnOperatorEnum from superset.mcp_service.chart.schemas import ChartInfo, serialize_chart_object from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.system.schemas import ( PaginationInfo, RoleInfo, @@ -244,7 +245,13 @@ class ListDashboardsRequest(MetadataCacheControl): Field(default=1, description="Page number for pagination (1-based)"), ] page_size: Annotated[ - PositiveInt, Field(default=10, description="Number of items per page") + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), ] @model_validator(mode="after") diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index 1fc5d67e112..2d88b8bbc38 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -35,6 +35,7 @@ from pydantic import ( from superset.daos.base import ColumnOperator, ColumnOperatorEnum from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.system.schemas import ( PaginationInfo, serialize_user_object, @@ -247,7 +248,13 @@ class ListDatasetsRequest(MetadataCacheControl): Field(default=1, description="Page number for pagination (1-based)"), ] page_size: Annotated[ - PositiveInt, Field(default=10, description="Number of items per page") + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), ] @model_validator(mode="after") diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py index 0b1a817d26d..051aa20809a 100644 --- a/superset/mcp_service/mcp_core.py +++ b/superset/mcp_service/mcp_core.py @@ -142,6 +142,11 @@ class ModelListCore(BaseCore, Generic[L]): page: int = 0, page_size: int = 10, ) -> L: + from superset.mcp_service.constants import MAX_PAGE_SIZE + + # Clamp page_size to MAX_PAGE_SIZE as defense-in-depth + page_size = min(page_size, MAX_PAGE_SIZE) + # Parse filters using generic utility (accepts JSON string or object) from superset.mcp_service.utils.schema_utils import ( parse_json_or_list, diff --git a/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py b/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py index c7fa505ac93..c5ef6469ecc 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py @@ -28,6 +28,7 @@ from superset.mcp_service.chart.schemas import ( ChartFilter, ListChartsRequest, ) +from superset.mcp_service.constants import MAX_PAGE_SIZE @pytest.fixture @@ -133,6 +134,19 @@ class TestListChartsRequestSchema: with pytest.raises(ValueError, match="Input should be greater than 0"): ListChartsRequest(page_size=0) + def test_page_size_exceeds_max(self): + """Test that page_size over MAX_PAGE_SIZE raises validation error.""" + with pytest.raises( + ValueError, + match=f"Input should be less than or equal to {MAX_PAGE_SIZE}", + ): + ListChartsRequest(page_size=MAX_PAGE_SIZE + 1) + + def test_page_size_at_max(self): + """Test that page_size at MAX_PAGE_SIZE is accepted.""" + request = ListChartsRequest(page_size=MAX_PAGE_SIZE) + assert request.page_size == MAX_PAGE_SIZE + def test_filter_validation(self): """Test that filter validation works correctly.""" # Valid filter