From d79eb5842adb389897c21d34cfe52bda3d3b709b Mon Sep 17 00:00:00 2001 From: Richard Fogaca Nienkotter <63572350+richardfogaca@users.noreply.github.com> Date: Fri, 24 Apr 2026 09:40:39 -0300 Subject: [PATCH] fix(mcp): protect data-model metadata from dashboard viewers (#39599) Co-authored-by: Elizabeth Thompson Co-authored-by: Claude Sonnet 4.6 --- .../mcp_service/chart/tool/get_chart_info.py | 13 +- .../mcp_service/chart/tool/get_chart_sql.py | 3 +- .../mcp_service/chart/tool/list_charts.py | 38 +++- .../database/tool/get_database_info.py | 14 ++ .../database/tool/list_databases.py | 26 ++- superset/mcp_service/dataset/schemas.py | 8 +- .../dataset/tool/get_dataset_info.py | 14 ++ .../mcp_service/dataset/tool/list_datasets.py | 26 ++- superset/mcp_service/privacy.py | 128 +++++++++++++- superset/mcp_service/server.py | 119 ++++++++++--- superset/mcp_service/system/schemas.py | 7 + .../system/tool/get_instance_info.py | 16 ++ .../mcp_service/system/tool/get_schema.py | 54 +++++- .../chart/tool/test_get_chart_info.py | 163 +++++++++++++++++ .../chart/tool/test_get_chart_sql.py | 8 + .../chart/tool/test_list_charts.py | 85 ++++++++- .../database/tool/test_database_tools.py | 43 +++++ .../dataset/tool/test_dataset_tools.py | 93 ++++++++++ .../system/tool/test_get_current_user.py | 152 ++++++++++++++++ .../system/tool/test_get_schema.py | 16 ++ tests/unit_tests/mcp_service/test_privacy.py | 82 ++++++--- .../mcp_service/test_tool_search_transform.py | 166 +++++++++++++++++- 22 files changed, 1207 insertions(+), 67 deletions(-) create mode 100644 tests/unit_tests/mcp_service/chart/tool/test_get_chart_info.py diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py index f508a62c45e..68c13e4bafb 100644 --- a/superset/mcp_service/chart/tool/get_chart_info.py +++ b/superset/mcp_service/chart/tool/get_chart_info.py @@ -36,6 +36,10 @@ from superset.mcp_service.chart.schemas import ( serialize_chart_object, ) from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.privacy import ( + redact_chart_data_model_fields, + user_can_view_data_model_metadata, +) logger = logging.getLogger(__name__) @@ -155,6 +159,7 @@ async def get_chart_info( "Retrieving chart information: identifier=%s, form_data_key=%s" % (request.identifier, request.form_data_key) ) + can_view_data_model_metadata = user_can_view_data_model_metadata() # Handle unsaved chart (form_data_key only, no identifier) if not request.identifier and request.form_data_key: @@ -165,7 +170,10 @@ async def get_chart_info( "No chart identifier provided - retrieving unsaved chart from cache: " "form_data_key=%s" % (request.form_data_key,) ) - return _build_unsaved_chart_info(request.form_data_key) + result = _build_unsaved_chart_info(request.form_data_key) + if not can_view_data_model_metadata: + return redact_chart_data_model_fields(result) + return result # At this point identifier must be set (validator ensures at least one # of identifier/form_data_key is provided, and the form_data_key-only @@ -202,6 +210,9 @@ async def get_chart_info( ) _apply_unsaved_state_override(result, request.form_data_key) + if not can_view_data_model_metadata: + result = redact_chart_data_model_fields(result) + await ctx.info( "Chart information retrieved successfully: chart_name=%s, " "is_unsaved_state=%s" % (result.slice_name, result.is_unsaved_state) diff --git a/superset/mcp_service/chart/tool/get_chart_sql.py b/superset/mcp_service/chart/tool/get_chart_sql.py index 861d8ea89b2..42dc0b5cecc 100644 --- a/superset/mcp_service/chart/tool/get_chart_sql.py +++ b/superset/mcp_service/chart/tool/get_chart_sql.py @@ -348,7 +348,8 @@ def _extract_sql_from_result( @tool( tags=["data"], - class_permission_name="Chart", + class_permission_name="SQLLab", + method_permission_name="execute_sql_query", annotations=ToolAnnotations( title="Get chart SQL", readOnlyHint=True, diff --git a/superset/mcp_service/chart/tool/list_charts.py b/superset/mcp_service/chart/tool/list_charts.py index 8bfcf347615..f18b5c08f73 100644 --- a/superset/mcp_service/chart/tool/list_charts.py +++ b/superset/mcp_service/chart/tool/list_charts.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from superset.extensions import event_logger from superset.mcp_service.chart.schemas import ( + ChartError, ChartFilter, ChartInfo, ChartLike, @@ -38,6 +39,12 @@ from superset.mcp_service.chart.schemas import ( serialize_chart_object, ) from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + remove_chart_data_model_columns, + request_uses_chart_data_model_filter, + user_can_view_data_model_metadata, +) logger = logging.getLogger(__name__) @@ -58,7 +65,6 @@ SORTABLE_CHART_COLUMNS = [ "id", "slice_name", "viz_type", - "datasource_name", "description", "changed_on", "created_on", @@ -74,14 +80,16 @@ SORTABLE_CHART_COLUMNS = [ destructiveHint=False, ), ) -async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList: +async def list_charts( + request: ListChartsRequest, ctx: Context +) -> ChartList | ChartError: """List charts with filtering and search. Returns chart metadata including id, name, viz_type, URL, and last modified time. - Sortable columns for order_column: id, slice_name, viz_type, - datasource_name, description, changed_on, created_on + Sortable columns for order_column: id, slice_name, viz_type, description, + changed_on, created_on """ await ctx.info( "Listing charts: page=%s, page_size=%s, search=%s" @@ -107,8 +115,26 @@ async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList: get_chart_columns, ) + can_view_data_model_metadata = user_can_view_data_model_metadata() + if not can_view_data_model_metadata and request_uses_chart_data_model_filter( + request.filters + ): + return ChartError( + error=( + "You don't have permission to access underlying dataset or " + "database details for your role." + ), + error_type=DATA_MODEL_METADATA_ERROR_TYPE, + ) + # Get all column names dynamically from the model all_columns = get_all_column_names(get_chart_columns()) + sortable_columns = CHART_SORTABLE_COLUMNS + select_columns = request.select_columns + if not can_view_data_model_metadata: + all_columns = remove_chart_data_model_columns(all_columns) + sortable_columns = remove_chart_data_model_columns(sortable_columns) + select_columns = remove_chart_data_model_columns(select_columns) def _serialize_chart( obj: "Slice | None", cols: list[str] | None @@ -129,7 +155,7 @@ async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList: list_field_name="charts", output_list_schema=ChartList, all_columns=all_columns, - sortable_columns=CHART_SORTABLE_COLUMNS, + sortable_columns=sortable_columns, logger=logger, ) @@ -138,7 +164,7 @@ async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList: result = tool.run_tool( filters=request.filters, search=request.search, - select_columns=request.select_columns, + select_columns=select_columns, order_column=request.order_column, order_direction=request.order_direction, page=max(request.page - 1, 0), diff --git a/superset/mcp_service/database/tool/get_database_info.py b/superset/mcp_service/database/tool/get_database_info.py index 56aef1d2bb4..c91b4306ad8 100644 --- a/superset/mcp_service/database/tool/get_database_info.py +++ b/superset/mcp_service/database/tool/get_database_info.py @@ -36,6 +36,11 @@ from superset.mcp_service.database.schemas import ( serialize_database_object, ) from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + requires_data_model_metadata_access, + user_can_view_data_model_metadata, +) logger = logging.getLogger(__name__) @@ -49,6 +54,7 @@ logger = logging.getLogger(__name__) destructiveHint=False, ), ) +@requires_data_model_metadata_access async def get_database_info( request: GetDatabaseInfoRequest, ctx: Context ) -> DatabaseInfo | DatabaseError: @@ -87,6 +93,14 @@ async def get_database_info( ) ) + # The decorator hides this tool from search; this check enforces direct calls. + if not user_can_view_data_model_metadata(): + await ctx.warning("Database metadata lookup blocked by privacy controls") + return DatabaseError.create( + error="You don't have permission to access database details for your role.", + error_type=DATA_MODEL_METADATA_ERROR_TYPE, + ) + try: from superset.daos.database import DatabaseDAO diff --git a/superset/mcp_service/database/tool/list_databases.py b/superset/mcp_service/database/tool/list_databases.py index 8db4e0fbb35..6f3959f4528 100644 --- a/superset/mcp_service/database/tool/list_databases.py +++ b/superset/mcp_service/database/tool/list_databases.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from superset.extensions import event_logger from superset.mcp_service.database.schemas import ( + DatabaseError, DatabaseFilter, DatabaseInfo, DatabaseList, @@ -40,9 +41,16 @@ from superset.mcp_service.database.schemas import ( serialize_database_object, ) from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + requires_data_model_metadata_access, + user_can_view_data_model_metadata, +) logger = logging.getLogger(__name__) +_DEFAULT_LIST_DATABASES_REQUEST = ListDatabasesRequest() + @tool( tags=["core"], @@ -53,7 +61,11 @@ logger = logging.getLogger(__name__) destructiveHint=False, ), ) -async def list_databases(request: ListDatabasesRequest, ctx: Context) -> DatabaseList: +@requires_data_model_metadata_access +async def list_databases( + request: ListDatabasesRequest | None = None, + ctx: Context | None = None, +) -> DatabaseList | DatabaseError: """List database connections with filtering and search. Returns database metadata including name, backend type, and permissions. @@ -61,6 +73,11 @@ async def list_databases(request: ListDatabasesRequest, ctx: Context) -> Databas Sortable columns for order_column: id, database_name, changed_on, created_on """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_databases") + + request = request or _DEFAULT_LIST_DATABASES_REQUEST.model_copy(deep=True) + await ctx.info( "Listing databases: page=%s, page_size=%s, search=%s" % ( @@ -88,6 +105,13 @@ async def list_databases(request: ListDatabasesRequest, ctx: Context) -> Databas ) ) + if not user_can_view_data_model_metadata(): + await ctx.warning("Database listing blocked by data-model privacy controls") + return DatabaseError.create( + error="You don't have permission to access database details for your role.", + error_type=DATA_MODEL_METADATA_ERROR_TYPE, + ) + try: from superset.daos.database import DatabaseDAO from superset.mcp_service.common.schema_discovery import ( diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index 3a8216c1d4a..d09b312aa41 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -286,9 +286,13 @@ class DatasetError(BaseModel): @classmethod def create(cls, error: str, error_type: str) -> "DatasetError": """Create a standardized DatasetError with timestamp.""" - from datetime import datetime + from datetime import datetime, timezone - return cls(error=error, error_type=error_type, timestamp=datetime.now()) + return cls( + error=error, + error_type=error_type, + timestamp=datetime.now(timezone.utc), + ) class GetDatasetInfoRequest(MetadataCacheControl): diff --git a/superset/mcp_service/dataset/tool/get_dataset_info.py b/superset/mcp_service/dataset/tool/get_dataset_info.py index c211c618d63..93d5f21ddf0 100644 --- a/superset/mcp_service/dataset/tool/get_dataset_info.py +++ b/superset/mcp_service/dataset/tool/get_dataset_info.py @@ -37,6 +37,11 @@ from superset.mcp_service.dataset.schemas import ( serialize_dataset_object, ) from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + requires_data_model_metadata_access, + user_can_view_data_model_metadata, +) logger = logging.getLogger(__name__) @@ -50,6 +55,7 @@ logger = logging.getLogger(__name__) destructiveHint=False, ), ) +@requires_data_model_metadata_access async def get_dataset_info( request: GetDatasetInfoRequest, ctx: Context ) -> DatasetInfo | DatasetError: @@ -93,6 +99,14 @@ async def get_dataset_info( ) ) + # The decorator hides this tool from search; this check enforces direct calls. + if not user_can_view_data_model_metadata(): + await ctx.warning("Dataset metadata lookup blocked by privacy controls") + return DatasetError.create( + error="You don't have permission to access dataset details for your role.", + error_type=DATA_MODEL_METADATA_ERROR_TYPE, + ) + try: from superset.connectors.sqla.models import SqlaTable from superset.daos.dataset import DatasetDAO diff --git a/superset/mcp_service/dataset/tool/list_datasets.py b/superset/mcp_service/dataset/tool/list_datasets.py index e4d2d3bd427..945f9fd3c08 100644 --- a/superset/mcp_service/dataset/tool/list_datasets.py +++ b/superset/mcp_service/dataset/tool/list_datasets.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from superset.extensions import event_logger from superset.mcp_service.dataset.schemas import ( + DatasetError, DatasetFilter, DatasetInfo, DatasetList, @@ -40,6 +41,11 @@ from superset.mcp_service.dataset.schemas import ( serialize_dataset_object, ) from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + requires_data_model_metadata_access, + user_can_view_data_model_metadata, +) logger = logging.getLogger(__name__) @@ -68,6 +74,8 @@ SORTABLE_DATASET_COLUMNS = [ "created_on", ] +_DEFAULT_LIST_DATASETS_REQUEST = ListDatasetsRequest() + @tool( tags=["core"], @@ -78,7 +86,11 @@ SORTABLE_DATASET_COLUMNS = [ destructiveHint=False, ), ) -async def list_datasets(request: ListDatasetsRequest, ctx: Context) -> DatasetList: +@requires_data_model_metadata_access +async def list_datasets( + request: ListDatasetsRequest | None = None, + ctx: Context | None = None, +) -> DatasetList | DatasetError: """List datasets with filtering and search. Returns dataset metadata including table name, schema, and last modified @@ -87,6 +99,11 @@ async def list_datasets(request: ListDatasetsRequest, ctx: Context) -> DatasetLi Sortable columns for order_column: id, table_name, schema, changed_on, created_on """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_datasets") + + request = request or _DEFAULT_LIST_DATASETS_REQUEST.model_copy(deep=True) + await ctx.info( "Listing datasets: page=%s, page_size=%s, search=%s" % ( @@ -114,6 +131,13 @@ async def list_datasets(request: ListDatasetsRequest, ctx: Context) -> DatasetLi ) ) + if not user_can_view_data_model_metadata(): + await ctx.warning("Dataset listing blocked by data-model privacy controls") + return DatasetError.create( + error="You don't have permission to access dataset details for your role.", + error_type=DATA_MODEL_METADATA_ERROR_TYPE, + ) + try: from superset.daos.dataset import DatasetDAO from superset.mcp_service.common.schema_discovery import ( diff --git a/superset/mcp_service/privacy.py b/superset/mcp_service/privacy.py index dd26b9ee92e..74949c84f2c 100644 --- a/superset/mcp_service/privacy.py +++ b/superset/mcp_service/privacy.py @@ -15,12 +15,17 @@ # specific language governing permissions and limitations # under the License. -"""Privacy helpers for MCP user-directory and access-list metadata.""" +"""Privacy helpers for MCP user-directory and data-model metadata.""" from __future__ import annotations from collections.abc import Iterable -from typing import Any +from datetime import datetime, timezone +from typing import Any, Callable, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + +F = TypeVar("F", bound=Callable[..., Any]) USER_DIRECTORY_FIELDS = frozenset( { @@ -39,6 +44,83 @@ USER_DIRECTORY_FIELDS = frozenset( } ) +DATA_MODEL_METADATA_ACCESS_ATTR = "_requires_data_model_metadata_access" +DATA_MODEL_METADATA_ERROR_TYPE = "DataModelMetadataRestricted" +DATA_MODEL_METADATA_PRIVACY_SCOPE = "data_model" +DATA_MODEL_METADATA_ERROR_MESSAGE = ( + "You don't have permission to access underlying dataset or database details " + "for your role." +) + +# Fields that reveal dataset/database metadata through chart list and schema surfaces. +# ChartInfo only exposes a subset of these as direct model fields. +CHART_DATA_MODEL_COLUMNS = frozenset( + { + "catalog_perm", + "datasource_id", + "datasource_name", + "datasource_type", + "filters", + "form_data", + "params", + "perm", + "query_context", + "schema_perm", + } +) + + +class PrivacyError(BaseModel): + """Structured privacy/permission denial for MCP tool responses.""" + + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + privacy_scope: str = Field(..., description="Privacy scope for the denial") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create_data_model_metadata_denied(cls) -> "PrivacyError": + return cls( + error=DATA_MODEL_METADATA_ERROR_MESSAGE, + error_type=DATA_MODEL_METADATA_ERROR_TYPE, + privacy_scope=DATA_MODEL_METADATA_PRIVACY_SCOPE, + timestamp=datetime.now(timezone.utc), + ) + + +def requires_data_model_metadata_access(func: F) -> F: + """Mark a tool as requiring data-model metadata permission.""" + setattr(func, DATA_MODEL_METADATA_ACCESS_ATTR, True) + return func + + +def tool_requires_data_model_metadata_access(func: Any) -> bool: + """Return whether a tool requires data-model metadata access.""" + return bool(getattr(func, DATA_MODEL_METADATA_ACCESS_ATTR, False)) + + +def user_can_view_data_model_metadata() -> bool: + """Return whether the current user can inspect data-model metadata. + + Dataset drill/write permissions indicate active data-model introspection access. + Dashboard-only viewers may have Dataset read access for chart rendering, but that + should not expose dataset/database metadata through MCP tools. + """ + try: + from superset import security_manager + + return any( + security_manager.can_access(permission_name, "Dataset") + for permission_name in ( + "can_get_drill_info", + "can_get_or_create_dataset", + "can_write", + ) + ) + except Exception: # noqa: BLE001 + return False + def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]: """Remove fields that expose users, roles, owners, or access metadata.""" @@ -50,3 +132,45 @@ def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]: def filter_user_directory_columns(columns: Iterable[str]) -> list[str]: """Remove user-directory columns while preserving order.""" return [column for column in columns if column not in USER_DIRECTORY_FIELDS] + + +def remove_chart_data_model_columns(columns: Iterable[str]) -> list[str]: + """Remove chart fields that reveal data-model metadata.""" + return [column for column in columns if column not in CHART_DATA_MODEL_COLUMNS] + + +def redact_chart_data_model_fields(chart_info: Any) -> Any: + """Redact chart fields that expose dataset or database metadata. + + Fails closed: if redaction cannot be applied, the exception propagates + rather than returning unredacted data. + """ + from superset.mcp_service.chart.schemas import ChartInfo + + if isinstance(chart_info, ChartInfo): + return chart_info.model_copy( + update={ + "datasource_name": None, + "datasource_type": None, + "filters": None, + "form_data": None, + } + ) + return chart_info + + +def request_uses_chart_data_model_filter(filters: Iterable[Any]) -> bool: + """Return whether chart filters target hidden data-model fields.""" + return any( + getattr(filter_, "col", None) in CHART_DATA_MODEL_COLUMNS for filter_ in filters + ) + + +def is_data_model_metadata_error(data: Any) -> bool: + """Return whether tool output is a structured data-model privacy denial.""" + return ( + isinstance(data, dict) + and data.get("error_type") == DATA_MODEL_METADATA_ERROR_TYPE + and data.get("privacy_scope", DATA_MODEL_METADATA_PRIVACY_SCOPE) + == DATA_MODEL_METADATA_PRIVACY_SCOPE + ) diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index ca330e0655c..14ae162c284 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -25,7 +25,7 @@ For multi-pod deployments, configure MCP_EVENT_STORE_CONFIG with Redis URL. import logging import os from collections.abc import Sequence -from typing import Annotated, Any +from typing import Annotated, Any, Callable import uvicorn from fastmcp.server.middleware import Middleware @@ -42,6 +42,10 @@ from superset.mcp_service.middleware import ( LoggingMiddleware, StructuredContentStripperMiddleware, ) +from superset.mcp_service.privacy import ( + tool_requires_data_model_metadata_access, + user_can_view_data_model_metadata, +) from superset.mcp_service.storage import _create_redis_store from superset.utils import json @@ -360,6 +364,51 @@ def _build_summary_serializer(max_desc: int) -> Any: return _summary_serializer +def _tool_allowed_for_current_user(tool: Any) -> bool: + """Return whether the current Flask user can see this tool in search results.""" + try: + from flask import current_app, g + + if not current_app.config.get("MCP_RBAC_ENABLED", True): + return True + + from superset import security_manager + from superset.mcp_service.auth import ( + CLASS_PERMISSION_ATTR, + get_user_from_request, + METHOD_PERMISSION_ATTR, + PERMISSION_PREFIX, + ) + + tool_func = getattr(tool, "fn", None) + if tool_requires_data_model_metadata_access(tool_func) and not ( + user_can_view_data_model_metadata() + ): + return False + + class_permission_name = getattr(tool_func, CLASS_PERMISSION_ATTR, None) + if not class_permission_name: + return True + + if not getattr(g, "user", None): + try: + g.user = get_user_from_request() + except ValueError: + return False + + method_permission_name = getattr(tool_func, METHOD_PERMISSION_ATTR, "read") + permission_name = f"{PERMISSION_PREFIX}{method_permission_name}" + return security_manager.can_access(permission_name, class_permission_name) + except (AttributeError, RuntimeError, ValueError): + logger.debug("Could not evaluate tool search permission", exc_info=True) + return False + + +def _filter_tools_by_current_user_permission(tools: Sequence[Any]) -> list[Any]: + """Filter search candidates to tools the current user can execute.""" + return [tool for tool in tools if _tool_allowed_for_current_user(tool)] + + def _create_search_result_serializer( config: dict[str, Any], ) -> Any: @@ -526,26 +575,11 @@ def _apply_tool_search_transform(mcp_instance: Any, config: dict[str, Any]) -> N tool = Tool.from_function(fn=call_tool, name=transform._call_tool_name) return _fix_call_tool_arguments(tool) - if strategy == "regex": - from fastmcp.server.transforms.search import RegexSearchTransform - - class _FixedRegexSearchTransform(RegexSearchTransform): - """Regex search with fixed call_tool schema and arg normalization.""" - - def _make_call_tool(self) -> Tool: - return _make_normalizing_call_tool(self) - - transform = _FixedRegexSearchTransform(**kwargs) - else: - from fastmcp.server.transforms.search import BM25SearchTransform - - class _FixedBM25SearchTransform(BM25SearchTransform): - """BM25 search with fixed call_tool schema and arg normalization.""" - - def _make_call_tool(self) -> Tool: - return _make_normalizing_call_tool(self) - - transform = _FixedBM25SearchTransform(**kwargs) + transform = _create_search_transform( + strategy=strategy, + kwargs=kwargs, + make_normalizing_call_tool=_make_normalizing_call_tool, + ) mcp_instance.add_transform(transform) logger.info( @@ -556,6 +590,49 @@ def _apply_tool_search_transform(mcp_instance: Any, config: dict[str, Any]) -> N ) +def _create_search_transform( + *, + strategy: str, + kwargs: dict[str, Any], + make_normalizing_call_tool: Callable[[Any], Any], +) -> Any: + """Create the configured search transform with tool-permission filtering.""" + from fastmcp.server.context import Context + + if strategy == "regex": + from fastmcp.server.transforms.search import RegexSearchTransform + + class _FixedRegexSearchTransform(RegexSearchTransform): + """Regex search with fixed call_tool schema and arg normalization.""" + + async def _get_visible_tools(self, ctx: Context) -> Sequence[Any]: + """Return only tools visible to the current authenticated user.""" + tools = await super()._get_visible_tools(ctx) + return _filter_tools_by_current_user_permission(tools) + + def _make_call_tool(self) -> Any: + """Build the normalized ``call_tool`` proxy for regex search.""" + return make_normalizing_call_tool(self) + + return _FixedRegexSearchTransform(**kwargs) + + from fastmcp.server.transforms.search import BM25SearchTransform + + class _FixedBM25SearchTransform(BM25SearchTransform): + """BM25 search with fixed call_tool schema and arg normalization.""" + + async def _get_visible_tools(self, ctx: Context) -> Sequence[Any]: + """Return only tools visible to the current authenticated user.""" + tools = await super()._get_visible_tools(ctx) + return _filter_tools_by_current_user_permission(tools) + + def _make_call_tool(self) -> Any: + """Build the normalized ``call_tool`` proxy for BM25 search.""" + return make_normalizing_call_tool(self) + + return _FixedBM25SearchTransform(**kwargs) + + def _create_auth_provider(flask_app: Any) -> Any | None: """Create an auth provider from Flask app config. diff --git a/superset/mcp_service/system/schemas.py b/superset/mcp_service/system/schemas.py index 582ae84aef9..85a78277c6d 100644 --- a/superset/mcp_service/system/schemas.py +++ b/superset/mcp_service/system/schemas.py @@ -121,6 +121,13 @@ class InstanceInfo(BaseModel): description="Information about the authenticated user.", ) feature_availability: FeatureAvailability + data_model_metadata_redacted: bool = Field( + default=False, + description=( + "True when dataset/database summary fields were removed because " + "the current user cannot inspect data model metadata." + ), + ) timestamp: datetime diff --git a/superset/mcp_service/system/tool/get_instance_info.py b/superset/mcp_service/system/tool/get_instance_info.py index 80bf8b915a7..2aaf4692ebd 100644 --- a/superset/mcp_service/system/tool/get_instance_info.py +++ b/superset/mcp_service/system/tool/get_instance_info.py @@ -28,6 +28,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations from superset.extensions import db, event_logger from superset.mcp_service.mcp_core import InstanceInfoCore +from superset.mcp_service.privacy import user_can_view_data_model_metadata from superset.mcp_service.system.schemas import ( GetSupersetInstanceInfoRequest, InstanceInfo, @@ -76,6 +77,18 @@ _instance_info_core = InstanceInfoCore( _DEFAULT_INSTANCE_INFO_REQUEST = GetSupersetInstanceInfoRequest() +def _redact_data_model_metadata(result: InstanceInfo) -> InstanceInfo: + """Remove dataset/database counts and activity from instance overview.""" + data = result.model_copy(deep=True) + data.instance_summary.total_datasets = 0 + data.instance_summary.total_databases = 0 + data.recent_activity.datasets_created_last_30_days = 0 + data.recent_activity.datasets_modified_last_7_days = 0 + data.database_breakdown.by_type = {} + data.data_model_metadata_redacted = True + return data + + @tool( tags=["core"], annotations=ToolAnnotations( @@ -161,6 +174,9 @@ def _run_instance_info() -> InstanceInfo: with event_logger.log_context(action="mcp.get_instance_info.metrics"): result = _instance_info_core.run_tool() + if not user_can_view_data_model_metadata(): + result = _redact_data_model_metadata(result) + if (user := getattr(g, "user", None)) is not None: result.current_user = serialize_user_object(user) diff --git a/superset/mcp_service/system/tool/get_schema.py b/superset/mcp_service/system/tool/get_schema.py index a0bc8213a4f..c5ca965c52d 100644 --- a/superset/mcp_service/system/tool/get_schema.py +++ b/superset/mcp_service/system/tool/get_schema.py @@ -53,6 +53,11 @@ from superset.mcp_service.common.schema_discovery import ( ) from superset.mcp_service.constants import ModelType from superset.mcp_service.mcp_core import ModelGetSchemaCore +from superset.mcp_service.privacy import ( + PrivacyError, + remove_chart_data_model_columns, + user_can_view_data_model_metadata, +) logger = logging.getLogger(__name__) @@ -149,13 +154,16 @@ _SCHEMA_CORE_FACTORIES: dict[ @tool( tags=["discovery"], + class_permission_name="Dataset", annotations=ToolAnnotations( title="Get schema", readOnlyHint=True, destructiveHint=False, ), ) -async def get_schema(request: GetSchemaRequest, ctx: Context) -> GetSchemaResponse: +async def get_schema( + request: GetSchemaRequest, ctx: Context +) -> GetSchemaResponse | PrivacyError: """ Get comprehensive schema metadata for a model type. @@ -177,6 +185,17 @@ async def get_schema(request: GetSchemaRequest, ctx: Context) -> GetSchemaRespon """ await ctx.info(f"Getting schema for model_type={request.model_type}") + can_view_data_model_metadata = user_can_view_data_model_metadata() + if not can_view_data_model_metadata and request.model_type in { + "dataset", + "database", + }: + await ctx.warning( + "Schema discovery blocked by data-model privacy controls: " + f"model_type={request.model_type}" + ) + return PrivacyError.create_data_model_metadata_denied() + # Get the appropriate core factory with defensive lookup factory = _SCHEMA_CORE_FACTORIES.get(request.model_type) if factory is None: @@ -191,6 +210,39 @@ async def get_schema(request: GetSchemaRequest, ctx: Context) -> GetSchemaRespon core = factory() schema_info = core.run_tool() + if not can_view_data_model_metadata and request.model_type == "chart": + schema_info = schema_info.model_copy(deep=True) + allowed_chart_columns = set( + remove_chart_data_model_columns( + [column.name for column in schema_info.select_columns] + ) + ) + schema_info.select_columns = [ + column + for column in schema_info.select_columns + if column.name in allowed_chart_columns + ] + schema_info.filter_columns = { + column: operators + for column, operators in schema_info.filter_columns.items() + if column in allowed_chart_columns + } + schema_info.sortable_columns = [ + column + for column in schema_info.sortable_columns + if column in allowed_chart_columns + ] + schema_info.default_select = [ + column + for column in schema_info.default_select + if column in allowed_chart_columns + ] + schema_info.search_columns = [ + column + for column in schema_info.search_columns + if column in allowed_chart_columns + ] + await ctx.debug( f"Schema for {request.model_type}: " f"{len(schema_info.select_columns)} select columns, " diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_info.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_info.py new file mode 100644 index 00000000000..b2a7fe31497 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_info.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for get_chart_info MCP tool privacy behavior. +""" + +import importlib +from contextlib import nullcontext +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest +from fastmcp import Client + +from superset.mcp_service.app import mcp +from superset.mcp_service.chart.schemas import ( + ChartInfo, + extract_filters_from_form_data, + GetChartInfoRequest, +) +from superset.utils import json + +get_chart_info_module = importlib.import_module( + "superset.mcp_service.chart.tool.get_chart_info" +) + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield + + +def _make_chart_info() -> ChartInfo: + form_data = { + "viz_type": "table", + "datasource": "12__table", + "datasource_name": "vehicle_sales", + "filters": [{"col": "state", "op": "IN", "val": ["CA"]}], + } + return ChartInfo( + id=123, + slice_name="Vehicle Sales", + viz_type="table", + datasource_name="vehicle_sales", + datasource_type="table", + filters=extract_filters_from_form_data(form_data), + form_data=form_data, + ) + + +class TestGetChartInfoPrivacy: + @pytest.mark.asyncio + async def test_restricted_user_redacts_saved_chart_data_model_fields( + self, mcp_server + ) -> None: + chart_info = _make_chart_info() + + with ( + patch.object( + get_chart_info_module.event_logger, + "log_context", + return_value=nullcontext(), + ), + patch.object( + get_chart_info_module.ModelGetInfoCore, + "run_tool", + return_value=chart_info, + ), + patch.object( + get_chart_info_module, + "user_can_view_data_model_metadata", + return_value=False, + create=True, + ), + patch.object( + get_chart_info_module, + "validate_chart_dataset", + return_value=SimpleNamespace(is_valid=True, warnings=[]), + ), + patch("superset.daos.chart.ChartDAO.find_by_id", return_value=Mock()), + patch("superset.mcp_service.auth.check_tool_permission", return_value=True), + ): + async with Client(mcp_server) as client: + response = await client.call_tool( + "get_chart_info", + {"request": GetChartInfoRequest(identifier=123).model_dump()}, + ) + + result = json.loads(response.content[0].text) + assert result["datasource_name"] is None + assert result["datasource_type"] is None + assert result["filters"] is None + assert result["form_data"] is None + + @pytest.mark.asyncio + async def test_restricted_user_redacts_unsaved_chart_data_model_fields( + self, mcp_server + ) -> None: + cached_form_data = ( + '{"viz_type":"table","datasource_name":"vehicle_sales",' + '"datasource_type":"table","filters":[{"col":"state","op":"IN",' + '"val":["CA"]}],"metrics":["count"]}' + ) + + with ( + patch.object( + get_chart_info_module.event_logger, + "log_context", + return_value=nullcontext(), + ), + patch.object( + get_chart_info_module, + "user_can_view_data_model_metadata", + return_value=False, + create=True, + ), + patch.object( + get_chart_info_module, + "get_cached_form_data", + return_value=cached_form_data, + ), + patch("superset.mcp_service.auth.check_tool_permission", return_value=True), + ): + async with Client(mcp_server) as client: + response = await client.call_tool( + "get_chart_info", + { + "request": GetChartInfoRequest( + form_data_key="cached-key" + ).model_dump() + }, + ) + + result = json.loads(response.content[0].text) + assert result["datasource_name"] is None + assert result["datasource_type"] is None + assert result["filters"] is None + assert result["form_data"] is None diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py index 00fa65991c5..7546e55d68a 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py @@ -24,6 +24,7 @@ from unittest.mock import Mock, patch import pytest +from superset.mcp_service.auth import CLASS_PERMISSION_ATTR, METHOD_PERMISSION_ATTR from superset.mcp_service.chart.schemas import ( ChartError, ChartSql, @@ -37,6 +38,7 @@ from superset.mcp_service.chart.tool.get_chart_sql import ( _resolve_groupby, _resolve_metrics, _resolve_metrics_and_groupby, + get_chart_sql, ) _get_chart_sql_mod = importlib.import_module( @@ -44,6 +46,12 @@ _get_chart_sql_mod = importlib.import_module( ) +def test_get_chart_sql_requires_sql_lab_execute_permission(): + """Rendered SQL should not be exposed through basic chart read permission.""" + assert getattr(get_chart_sql, CLASS_PERMISSION_ATTR) == "SQLLab" + assert getattr(get_chart_sql, METHOD_PERMISSION_ATTR) == "execute_sql_query" + + class TestGetChartSqlRequestSchema: """Tests for GetChartSqlRequest schema validation.""" diff --git a/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py b/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py index 9d4d053b8c0..e85dab43156 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py @@ -19,9 +19,11 @@ Tests for the list_charts request schema """ -from unittest.mock import Mock +import importlib +from unittest.mock import Mock, patch import pytest +from fastmcp import Client from superset.mcp_service.app import mcp from superset.mcp_service.chart.schemas import ( @@ -29,6 +31,17 @@ from superset.mcp_service.chart.schemas import ( ListChartsRequest, ) from superset.mcp_service.constants import MAX_PAGE_SIZE +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + remove_chart_data_model_columns, + request_uses_chart_data_model_filter, + user_can_view_data_model_metadata, +) +from superset.utils import json + +list_charts_module = importlib.import_module( + "superset.mcp_service.chart.tool.list_charts" +) @pytest.fixture @@ -36,6 +49,17 @@ def mcp_server(): return mcp +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for client-based tool tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + @pytest.fixture def mock_chart(): """Create a mock chart object.""" @@ -235,3 +259,62 @@ class TestChartDefaultColumnFiltering: "description", "cache_timeout", } + + +class TestChartDataModelMetadataPrivacy: + """Test data-model field privacy helpers for chart listing.""" + + def test_remove_data_model_columns(self): + assert remove_chart_data_model_columns( + ["id", "slice_name", "datasource_name", "form_data", "url"] + ) == ["id", "slice_name", "url"] + + def test_uses_data_model_filter(self): + request = ListChartsRequest( + filters=[ + ChartFilter( + col="datasource_name", + opr="like", + value="Vehicle Sales", + ) + ] + ) + + assert request_uses_chart_data_model_filter(request.filters) is True + + def test_user_can_view_data_model_metadata_uses_dataset_permission(self): + with patch("superset.security_manager", new_callable=Mock) as security_manager: + security_manager.can_access.side_effect = [False, True, False] + + assert user_can_view_data_model_metadata() is True + + security_manager.can_access.assert_any_call("can_get_drill_info", "Dataset") + security_manager.can_access.assert_any_call( + "can_get_or_create_dataset", "Dataset" + ) + + @pytest.mark.asyncio + async def test_list_charts_returns_structured_privacy_error(self, mcp_server): + request = ListChartsRequest( + filters=[ + ChartFilter( + col="datasource_name", + opr="like", + value="Vehicle Sales", + ) + ] + ) + + with patch.object( + list_charts_module, + "user_can_view_data_model_metadata", + return_value=False, + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_charts", + {"request": request.model_dump()}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE diff --git a/tests/unit_tests/mcp_service/database/tool/test_database_tools.py b/tests/unit_tests/mcp_service/database/tool/test_database_tools.py index b98413fc269..e8cc747b822 100644 --- a/tests/unit_tests/mcp_service/database/tool/test_database_tools.py +++ b/tests/unit_tests/mcp_service/database/tool/test_database_tools.py @@ -16,6 +16,7 @@ # under the License. +import importlib import logging from unittest.mock import MagicMock, patch @@ -25,10 +26,17 @@ from fastmcp.exceptions import ToolError from superset.mcp_service.app import mcp from superset.mcp_service.database.schemas import ListDatabasesRequest +from superset.mcp_service.privacy import DATA_MODEL_METADATA_ERROR_TYPE from superset.utils import json logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) +list_databases_module = importlib.import_module( + "superset.mcp_service.database.tool.list_databases" +) +get_database_info_module = importlib.import_module( + "superset.mcp_service.database.tool.get_database_info" +) def create_mock_database( @@ -90,6 +98,41 @@ def mock_auth(): yield mock_get_user +@pytest.fixture(autouse=True) +def allow_data_model_metadata(): + """Keep database tests in the normal metadata-allowed path by default.""" + with ( + patch.object( + list_databases_module, + "user_can_view_data_model_metadata", + return_value=True, + ), + patch.object( + get_database_info_module, + "user_can_view_data_model_metadata", + return_value=True, + ), + ): + yield + + +@pytest.mark.asyncio +async def test_list_databases_without_request_returns_structured_privacy_error( + mcp_server, +) -> None: + """Restricted users are denied even when the request payload is omitted.""" + with patch.object( + list_databases_module, + "user_can_view_data_model_metadata", + return_value=False, + ): + async with Client(mcp_server) as client: + result = await client.call_tool("list_databases", {}) + + data = json.loads(result.content[0].text) + assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE + + @patch("superset.daos.database.DatabaseDAO.list") @pytest.mark.asyncio async def test_list_databases_basic(mock_list, mcp_server): diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py index c64801df22f..1177a60533c 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py @@ -16,6 +16,7 @@ # under the License. +import importlib import logging from unittest.mock import MagicMock, patch @@ -29,10 +30,20 @@ from superset.mcp_service.dataset.schemas import ( CreateVirtualDatasetRequest, ListDatasetsRequest, ) +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + tool_requires_data_model_metadata_access, +) from superset.utils import json logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) +list_datasets_module = importlib.import_module( + "superset.mcp_service.dataset.tool.list_datasets" +) +get_dataset_info_module = importlib.import_module( + "superset.mcp_service.dataset.tool.get_dataset_info" +) def create_mock_dataset( @@ -78,6 +89,70 @@ def create_mock_dataset( return dataset +def test_dataset_discovery_tools_require_drill_permission() -> None: + """Dataset discovery tools are marked as metadata-restricted.""" + from superset.mcp_service.dataset.tool.get_dataset_info import get_dataset_info + from superset.mcp_service.dataset.tool.list_datasets import list_datasets + + assert tool_requires_data_model_metadata_access(list_datasets) is True + assert tool_requires_data_model_metadata_access(get_dataset_info) is True + + +@pytest.mark.asyncio +async def test_list_datasets_returns_structured_privacy_error(mcp_server) -> None: + """Restricted users receive a structured denial for dataset listing.""" + with patch.object( + list_datasets_module, + "user_can_view_data_model_metadata", + return_value=False, + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_datasets", + {"request": ListDatasetsRequest().model_dump()}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE + + +@pytest.mark.asyncio +async def test_list_datasets_without_request_returns_structured_privacy_error( + mcp_server, +) -> None: + """Restricted users are denied even when the request payload is omitted.""" + with patch.object( + list_datasets_module, + "user_can_view_data_model_metadata", + return_value=False, + ): + async with Client(mcp_server) as client: + result = await client.call_tool("list_datasets", {}) + + data = json.loads(result.content[0].text) + assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE + + +@pytest.mark.asyncio +async def test_get_dataset_info_returns_structured_privacy_error(mcp_server) -> None: + """Restricted users receive a structured denial for dataset details.""" + from superset.mcp_service.dataset.schemas import GetDatasetInfoRequest + + with patch.object( + get_dataset_info_module, + "user_can_view_data_model_metadata", + return_value=False, + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_dataset_info", + {"request": GetDatasetInfoRequest(identifier=1).model_dump()}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE + + @pytest.fixture def mcp_server(): return mcp @@ -96,6 +171,24 @@ def mock_auth(): yield mock_get_user +@pytest.fixture(autouse=True) +def allow_data_model_metadata(): + """Keep dataset tests in the normal metadata-allowed path by default.""" + with ( + patch.object( + list_datasets_module, + "user_can_view_data_model_metadata", + return_value=True, + ), + patch.object( + get_dataset_info_module, + "user_can_view_data_model_metadata", + return_value=True, + ), + ): + yield + + @patch("superset.daos.dataset.DatasetDAO.list") @pytest.mark.asyncio async def test_list_datasets_basic(mock_list, mcp_server): diff --git a/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py index 4b57fa13f69..ca234eeb5ac 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py +++ b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py @@ -17,6 +17,7 @@ """Tests for current_user in get_instance_info and user-directory filtering.""" +import importlib from unittest.mock import Mock, patch import pytest @@ -25,10 +26,25 @@ from pydantic import ValidationError from superset.mcp_service.app import mcp from superset.mcp_service.chart.schemas import ChartFilter +from superset.mcp_service.common.schema_discovery import ( + ColumnMetadata, + ModelSchemaInfo, +) from superset.mcp_service.dashboard.schemas import DashboardFilter +from superset.mcp_service.privacy import ( + CHART_DATA_MODEL_COLUMNS, + DATA_MODEL_METADATA_ERROR_TYPE, + tool_requires_data_model_metadata_access, + user_can_view_data_model_metadata, +) from superset.mcp_service.system.schemas import InstanceInfo, UserInfo +from superset.mcp_service.system.tool.get_schema import get_schema from superset.utils import json +get_schema_module = importlib.import_module( + "superset.mcp_service.system.tool.get_schema" +) + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -101,6 +117,142 @@ def _make_instance_info(**kwargs): return InstanceInfo(**defaults) +def test_get_schema_is_not_globally_hidden_from_tool_search() -> None: + """Per-model privacy is enforced inside get_schema.""" + assert tool_requires_data_model_metadata_access(get_schema) is False + + +def test_redact_data_model_metadata_removes_dataset_and_database_summary(): + from superset.mcp_service.system.schemas import ( + DatabaseBreakdown, + InstanceSummary, + RecentActivity, + ) + from superset.mcp_service.system.tool.get_instance_info import ( + _redact_data_model_metadata, + ) + + instance_info = _make_instance_info( + instance_summary=InstanceSummary( + total_dashboards=2, + total_charts=4, + total_datasets=7, + total_databases=3, + total_users=5, + total_roles=6, + total_tags=8, + avg_charts_per_dashboard=2.0, + ), + recent_activity=RecentActivity( + dashboards_created_last_30_days=1, + charts_created_last_30_days=2, + datasets_created_last_30_days=3, + dashboards_modified_last_7_days=4, + charts_modified_last_7_days=5, + datasets_modified_last_7_days=6, + ), + database_breakdown=DatabaseBreakdown(by_type={"postgresql": 2}), + ) + + redacted = _redact_data_model_metadata(instance_info) + + assert redacted.instance_summary.total_dashboards == 2 + assert redacted.instance_summary.total_charts == 4 + assert redacted.instance_summary.total_datasets == 0 + assert redacted.instance_summary.total_databases == 0 + assert redacted.recent_activity.dashboards_created_last_30_days == 1 + assert redacted.recent_activity.charts_created_last_30_days == 2 + assert redacted.recent_activity.datasets_created_last_30_days == 0 + assert redacted.recent_activity.datasets_modified_last_7_days == 0 + assert redacted.database_breakdown.by_type == {} + assert redacted.data_model_metadata_redacted is True + + +def test_user_can_view_data_model_metadata_requires_stronger_dataset_permission( + app_context, +): + with patch("superset.security_manager", new_callable=Mock) as mock_security_manager: + mock_security_manager.can_access.side_effect = ( + lambda permission_name, view_name: permission_name == "can_read" + ) + assert user_can_view_data_model_metadata() is False + + mock_security_manager.can_access.side_effect = ( + lambda permission_name, view_name: ( + view_name == "Dataset" and permission_name == "can_get_drill_info" + ) + ) + assert user_can_view_data_model_metadata() is True + + +@pytest.mark.asyncio +async def test_get_schema_returns_structured_privacy_error_for_dataset(mcp_server): + with patch.object( + get_schema_module, + "user_can_view_data_model_metadata", + return_value=False, + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_schema", + {"request": {"model_type": "dataset"}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE + assert data["privacy_scope"] == "data_model" + + +@pytest.mark.asyncio +async def test_get_schema_redacts_chart_data_model_fields(mcp_server): + mock_schema = ModelSchemaInfo( + model_type="chart", + select_columns=[ + ColumnMetadata(name="id"), + ColumnMetadata(name="datasource_name"), + ColumnMetadata(name="url"), + ], + filter_columns={"slice_name": ["eq"], "datasource_name": ["like"]}, + sortable_columns=["slice_name", "datasource_name"], + default_select=["id", "datasource_name", "slice_name"], + default_sort="changed_on", + default_sort_direction="desc", + search_columns=["slice_name", "description", "datasource_name"], + ) + + mock_core = Mock() + mock_core.run_tool.return_value = mock_schema + + with ( + patch.object( + get_schema_module, + "user_can_view_data_model_metadata", + return_value=False, + ), + patch.dict( + get_schema_module._SCHEMA_CORE_FACTORIES, + {"chart": lambda: mock_core}, + clear=False, + ), + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_schema", + {"request": {"model_type": "chart"}}, + ) + + data = json.loads(result.content[0].text) + schema_info = data["schema_info"] + assert all( + column["name"] not in CHART_DATA_MODEL_COLUMNS + for column in schema_info["select_columns"] + ) + assert "datasource_name" not in schema_info["filter_columns"] + assert "datasource_name" not in schema_info["sortable_columns"] + assert "datasource_name" not in schema_info["default_select"] + assert "datasource_name" not in schema_info["search_columns"] + + # --------------------------------------------------------------------------- # Schema-level tests: UserInfo # --------------------------------------------------------------------------- diff --git a/tests/unit_tests/mcp_service/system/tool/test_get_schema.py b/tests/unit_tests/mcp_service/system/tool/test_get_schema.py index bd139da23dc..65ffd633f34 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_get_schema.py +++ b/tests/unit_tests/mcp_service/system/tool/test_get_schema.py @@ -19,6 +19,7 @@ Tests for the get_schema unified schema discovery tool. """ +import importlib from unittest.mock import patch import pytest @@ -40,6 +41,10 @@ from superset.mcp_service.common.schema_discovery import ( ) from superset.utils import json +get_schema_module = importlib.import_module( + "superset.mcp_service.system.tool.get_schema" +) + @pytest.fixture def mcp_server(): @@ -59,6 +64,17 @@ def mock_auth(): yield mock_get_user +@pytest.fixture(autouse=True) +def allow_data_model_metadata(): + """Keep the standalone get_schema suite in the unrestricted default path.""" + with patch.object( + get_schema_module, + "user_can_view_data_model_metadata", + return_value=True, + ): + yield + + class TestGetSchemaRequest: """Test the GetSchemaRequest schema validation.""" diff --git a/tests/unit_tests/mcp_service/test_privacy.py b/tests/unit_tests/mcp_service/test_privacy.py index df29c71e61b..cdbdeb14acd 100644 --- a/tests/unit_tests/mcp_service/test_privacy.py +++ b/tests/unit_tests/mcp_service/test_privacy.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Tests for MCP user-directory privacy filtering.""" +"""Tests for MCP privacy helpers.""" import pytest @@ -23,39 +23,63 @@ from superset.mcp_service.chart.schemas import ChartInfo from superset.mcp_service.dashboard.schemas import DashboardInfo from superset.mcp_service.database.schemas import DatabaseInfo from superset.mcp_service.dataset.schemas import DatasetInfo +from superset.mcp_service.privacy import ( + is_data_model_metadata_error, + redact_chart_data_model_fields, +) + + +def test_is_data_model_metadata_error_accepts_missing_privacy_scope() -> None: + assert ( + is_data_model_metadata_error( + { + "error": "denied", + "error_type": "DataModelMetadataRestricted", + "timestamp": "2026-04-23T11:20:54.286885", + } + ) + is True + ) + + +def test_is_data_model_metadata_error_rejects_wrong_privacy_scope() -> None: + assert ( + is_data_model_metadata_error( + { + "error": "denied", + "error_type": "DataModelMetadataRestricted", + "privacy_scope": "user_directory", + } + ) + is False + ) + + +def test_redact_chart_data_model_fields_removes_restricted_fields() -> None: + chart_info = ChartInfo( + id=1, + slice_name="Revenue", + datasource_name="sales", + datasource_type="table", + filters={"time_range": "Last year"}, + form_data={"datasource": "1__table"}, + ) + + redacted = redact_chart_data_model_fields(chart_info) + + assert redacted.datasource_name is None + assert redacted.datasource_type is None + assert redacted.filters is None + assert redacted.form_data is None @pytest.mark.parametrize( "model", [ - ChartInfo( - id=1, - slice_name="Revenue", - created_by="creator", - changed_by="modifier", - owners=[], - ), - DashboardInfo( - id=1, - dashboard_title="Executive Dashboard", - created_by="creator", - changed_by="modifier", - owners=[], - roles=[], - ), - DatasetInfo( - id=1, - table_name="sales", - created_by="creator", - changed_by="modifier", - owners=[], - ), - DatabaseInfo( - id=1, - database_name="warehouse", - created_by="creator", - changed_by="modifier", - ), + ChartInfo(id=1, slice_name="Revenue"), + DashboardInfo(id=1, dashboard_title="Executive Dashboard", owners=[], roles=[]), + DatasetInfo(id=1, table_name="sales"), + DatabaseInfo(id=1, database_name="warehouse"), ], ) def test_user_directory_fields_removed_from_python_and_json_dumps(model): diff --git a/tests/unit_tests/mcp_service/test_tool_search_transform.py b/tests/unit_tests/mcp_service/test_tool_search_transform.py index f9d44e64dce..34b1df52441 100644 --- a/tests/unit_tests/mcp_service/test_tool_search_transform.py +++ b/tests/unit_tests/mcp_service/test_tool_search_transform.py @@ -18,16 +18,20 @@ """Tests for MCP tool search transform configuration and application.""" from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock, patch from fastmcp.server.transforms.search import BM25SearchTransform, RegexSearchTransform +from flask import Flask, g +from superset.mcp_service.auth import CLASS_PERMISSION_ATTR, METHOD_PERMISSION_ATTR from superset.mcp_service.mcp_config import MCP_TOOL_SEARCH_CONFIG +from superset.mcp_service.privacy import requires_data_model_metadata_access from superset.mcp_service.server import ( _apply_tool_search_transform, _compact_schema, _create_search_result_serializer, _extract_parameter_names, + _filter_tools_by_current_user_permission, _fix_call_tool_arguments, _normalize_call_tool_arguments, _serialize_tools_without_output_schema, @@ -843,6 +847,166 @@ def test_apply_transform_uses_compact_serializer(): ) +def test_tool_search_permission_filter_hides_disallowed_tools(): + """Search candidates exclude tools the current user cannot execute.""" + app = Flask(__name__) + app.config["MCP_RBAC_ENABLED"] = True + + def permitted_tool(): + pass + + def denied_tool(): + pass + + for func in (permitted_tool, denied_tool): + setattr(func, CLASS_PERMISSION_ATTR, "Dataset") + setattr(func, METHOD_PERMISSION_ATTR, "get_drill_info") + + permitted = SimpleNamespace(fn=permitted_tool) + denied = SimpleNamespace(fn=denied_tool) + public = SimpleNamespace(fn=lambda: None) + + with app.app_context(): + g.user = SimpleNamespace(username="viewer") + with patch( + "superset.security_manager", new_callable=MagicMock + ) as security_manager: + security_manager.can_access.side_effect = [True, False] + + result = _filter_tools_by_current_user_permission( + [permitted, denied, public] + ) + + assert result == [permitted, public] + security_manager.can_access.assert_any_call("can_get_drill_info", "Dataset") + + +def test_tool_search_permission_filter_hides_protected_tools_without_user() -> None: + """Protected tools are hidden from search when no Flask user is present.""" + app = Flask(__name__) + app.config["MCP_RBAC_ENABLED"] = True + + def protected_tool(): + pass + + setattr(protected_tool, CLASS_PERMISSION_ATTR, "Dataset") + setattr(protected_tool, METHOD_PERMISSION_ATTR, "get_drill_info") + + protected = SimpleNamespace(fn=protected_tool) + public = SimpleNamespace(fn=lambda: None) + + with app.app_context(): + result = _filter_tools_by_current_user_permission([protected, public]) + + assert result == [public] + + +def test_tool_search_filter_hides_metadata_tools_without_access() -> None: + """Privacy-marked tools are hidden even if broad Dataset read exists.""" + app = Flask(__name__) + app.config["MCP_RBAC_ENABLED"] = True + + @requires_data_model_metadata_access + def metadata_tool(): + pass + + metadata = SimpleNamespace(fn=metadata_tool) + public = SimpleNamespace(fn=lambda: None) + + with app.app_context(): + g.user = SimpleNamespace(username="viewer") + with patch( + "superset.mcp_service.server.user_can_view_data_model_metadata", + return_value=False, + ): + result = _filter_tools_by_current_user_permission([metadata, public]) + + assert result == [public] + + +def test_tool_search_permission_filter_still_applies_rbac_to_metadata_tools() -> None: + """Privacy-marked tools still require the underlying tool permission.""" + app = Flask(__name__) + app.config["MCP_RBAC_ENABLED"] = True + + @requires_data_model_metadata_access + def metadata_tool(): + pass + + setattr(metadata_tool, CLASS_PERMISSION_ATTR, "Dataset") + setattr(metadata_tool, METHOD_PERMISSION_ATTR, "get_drill_info") + + metadata = SimpleNamespace(fn=metadata_tool) + public = SimpleNamespace(fn=lambda: None) + + with app.app_context(): + g.user = SimpleNamespace(username="viewer") + with ( + patch( + "superset.mcp_service.server.user_can_view_data_model_metadata", + return_value=True, + ), + patch("superset.security_manager", new_callable=Mock) as security_manager, + ): + security_manager.can_access.return_value = False + result = _filter_tools_by_current_user_permission([metadata, public]) + + assert result == [public] + + +def test_tool_search_permission_filter_resolves_user_from_request() -> None: + """Search filtering resolves the current user when g.user is not already set.""" + app = Flask(__name__) + app.config["MCP_RBAC_ENABLED"] = True + + def protected_tool(): + pass + + setattr(protected_tool, CLASS_PERMISSION_ATTR, "Dataset") + setattr(protected_tool, METHOD_PERMISSION_ATTR, "read") + + protected = SimpleNamespace(fn=protected_tool) + + with app.app_context(): + with ( + patch( + "superset.mcp_service.auth.get_user_from_request", + return_value=SimpleNamespace(username="viewer"), + ), + patch("superset.security_manager", new_callable=Mock) as security_manager, + ): + security_manager.can_access.return_value = True + result = _filter_tools_by_current_user_permission([protected]) + + assert result == [protected] + + +def test_tool_search_permission_filter_keeps_get_schema_visible_without_metadata() -> ( + None +): + """get_schema remains discoverable when only safe model types are available.""" + from superset.mcp_service.system.tool.get_schema import get_schema + + app = Flask(__name__) + app.config["MCP_RBAC_ENABLED"] = True + + schema_tool = SimpleNamespace(fn=get_schema) + + with app.app_context(): + g.user = SimpleNamespace(username="viewer") + with ( + patch( + "superset.mcp_service.server.user_can_view_data_model_metadata", + return_value=False, + ), + patch("superset.security_manager", new_callable=Mock) as security_manager, + ): + security_manager.can_access.return_value = True + result = _filter_tools_by_current_user_permission([schema_tool]) + + assert result == [schema_tool] + + # -- _extract_parameter_names tests --