diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 9624964bc72..f43df61e487 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -76,7 +76,7 @@ Schema Discovery: - get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters) System Information: -- get_instance_info: Get instance-wide statistics and metadata +- get_instance_info: Get instance-wide statistics, metadata, and current user identity - health_check: Simple health check tool (takes NO parameters, call without arguments) Available Resources: @@ -95,6 +95,13 @@ To create a chart: 3. generate_explore_link(dataset_id, config) -> preview interactively 4. generate_chart(dataset_id, config, save_chart=True) -> save permanently +To find your own charts/dashboards: +1. get_instance_info -> get current_user.id +2. list_charts(filters=[{{"col": "created_by_fk", + "opr": "eq", "value": current_user.id}}]) +3. Or: list_dashboards(filters=[{{"col": "created_by_fk", + "opr": "eq", "value": current_user.id}}]) + To explore data with SQL: 1. get_instance_info -> find database_id 2. execute_sql(database_id, sql) -> run query @@ -127,6 +134,10 @@ Query Examples: - List time series charts: filters=[{{"col": "viz_type", "opr": "sw", "value": "echarts_timeseries"}}] - Search by name: search="sales" +- My charts (use current_user.id from get_instance_info): + filters=[{{"col": "created_by_fk", "opr": "eq", "value": current_user.id}}] +- My dashboards: + filters=[{{"col": "created_by_fk", "opr": "eq", "value": current_user.id}}] General usage tips: - All listing tools use 1-based pagination (first page is 1) @@ -143,6 +154,9 @@ Input format: If you are unsure which tool to use, start with get_instance_info or use the quickstart prompt for an interactive guide. + +When you first connect, call get_instance_info to learn the user's identity. +Greet them by their first name (from current_user) and offer to help. 
""" diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index d02f3c7de02..a27c3e3ceef 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -270,10 +270,12 @@ class ChartFilter(ColumnOperator): "slice_name", "viz_type", "datasource_name", + "created_by_fk", ] = Field( ..., description="Column to filter on. Use get_schema(model_type='chart') for " - "available filter columns.", + "available filter columns. Use created_by_fk with the user ID from " + "get_instance_info's current_user to find charts created by a specific user.", ) opr: ColumnOperatorEnum = Field( ..., diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index d9dec129548..54c4e724471 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -163,10 +163,16 @@ class DashboardFilter(ColumnOperator): "dashboard_title", "published", "favorite", + "created_by_fk", ] = Field( ..., - description="Column to filter on. Use get_schema(model_type='dashboard') for " - "available filter columns.", + description=( + "Column to filter on. Use " + "get_schema(model_type='dashboard') for available " + "filter columns. Use created_by_fk with the user " + "ID from get_instance_info's current_user to find " + "dashboards created by a specific user." + ), ) opr: ColumnOperatorEnum = Field( ..., diff --git a/superset/mcp_service/system/schemas.py b/superset/mcp_service/system/schemas.py index 5d4a8ab7db0..93b676a0e1a 100644 --- a/superset/mcp_service/system/schemas.py +++ b/superset/mcp_service/system/schemas.py @@ -22,6 +22,8 @@ This module contains Pydantic models for serializing Superset instance metadata system-level info. 
""" +from __future__ import annotations + from datetime import datetime from typing import Dict, List @@ -122,6 +124,11 @@ class InstanceInfo(BaseModel): popular_content: PopularContent = Field( ..., description="Popular content information" ) + current_user: UserInfo | None = Field( + None, + description="The authenticated user making the request. " + "Use current_user.id with created_by_fk filter to find your own assets.", + ) timestamp: datetime = Field(..., description="Response timestamp") diff --git a/superset/mcp_service/system/tool/get_instance_info.py b/superset/mcp_service/system/tool/get_instance_info.py index 7c142acabfe..8d383ae167f 100644 --- a/superset/mcp_service/system/tool/get_instance_info.py +++ b/superset/mcp_service/system/tool/get_instance_info.py @@ -30,6 +30,7 @@ from superset.mcp_service.mcp_core import InstanceInfoCore from superset.mcp_service.system.schemas import ( GetSupersetInstanceInfoRequest, InstanceInfo, + UserInfo, ) from superset.mcp_service.system.system_utils import ( calculate_dashboard_breakdown, @@ -81,6 +82,8 @@ def get_instance_info( """ try: # Import DAOs at runtime to avoid circular imports + from flask import g + from superset.daos.chart import ChartDAO from superset.daos.dashboard import DashboardDAO from superset.daos.database import DatabaseDAO @@ -100,7 +103,20 @@ def get_instance_info( # Run the configurable core with event_logger.log_context(action="mcp.get_instance_info.metrics"): - return _instance_info_core.run_tool() + result = _instance_info_core.run_tool() + + # Attach the authenticated user's identity to the response + user = getattr(g, "user", None) + if user is not None: + result.current_user = UserInfo( + id=getattr(user, "id", None), + username=getattr(user, "username", None), + first_name=getattr(user, "first_name", None), + last_name=getattr(user, "last_name", None), + email=getattr(user, "email", None), + ) + + return result except Exception as e: error_msg = f"Unexpected error in instance info: 
{str(e)}" diff --git a/superset/mcp_service/utils/schema_utils.py b/superset/mcp_service/utils/schema_utils.py index 3c0a617f4fc..62992b8594d 100644 --- a/superset/mcp_service/utils/schema_utils.py +++ b/superset/mcp_service/utils/schema_utils.py @@ -446,10 +446,8 @@ def parse_request( def decorator(func: Callable[..., Any]) -> Callable[..., Any]: import types - parse_enabled = _is_parse_request_enabled() - def _maybe_parse(request: Any) -> Any: - if parse_enabled: + if _is_parse_request_enabled(): return parse_json_or_model(request, request_class, "request") return request @@ -501,7 +499,7 @@ def parse_request( # Copy docstring from original function (not wrapper, which has no docstring) new_wrapper.__doc__ = func.__doc__ - request_annotation = str | request_class if parse_enabled else request_class + request_annotation = str | request_class _apply_signature_for_fastmcp(new_wrapper, func, request_annotation) return new_wrapper diff --git a/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py new file mode 100644 index 00000000000..91d8fbc6e8b --- /dev/null +++ b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py @@ -0,0 +1,363 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for current_user in get_instance_info and created_by_fk filtering.""" + +from unittest.mock import Mock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.chart.schemas import ChartFilter +from superset.mcp_service.dashboard.schemas import DashboardFilter +from superset.mcp_service.system.schemas import InstanceInfo, UserInfo +from superset.utils import json + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +# --------------------------------------------------------------------------- +# Helper to build a minimal InstanceInfo +# --------------------------------------------------------------------------- +def _make_instance_info(**kwargs): + """Build a minimal InstanceInfo with defaults; override with kwargs.""" + from datetime import datetime, timezone + + from superset.mcp_service.system.schemas import ( + DashboardBreakdown, + DatabaseBreakdown, + InstanceSummary, + PopularContent, + RecentActivity, + ) + + defaults = { + "instance_summary": InstanceSummary( + total_dashboards=0, + total_charts=0, + total_datasets=0, + total_databases=0, + total_users=0, + total_roles=0, + total_tags=0, + avg_charts_per_dashboard=0.0, + ), + "recent_activity": RecentActivity( + dashboards_created_last_30_days=0, + charts_created_last_30_days=0, + 
datasets_created_last_30_days=0, + dashboards_modified_last_7_days=0, + charts_modified_last_7_days=0, + datasets_modified_last_7_days=0, + ), + "dashboard_breakdown": DashboardBreakdown( + published=0, + unpublished=0, + certified=0, + with_charts=0, + without_charts=0, + ), + "database_breakdown": DatabaseBreakdown(by_type={}), + "popular_content": PopularContent(top_tags=[], top_creators=[]), + "timestamp": datetime.now(timezone.utc), + } + defaults.update(kwargs) + return InstanceInfo(**defaults) + + +# --------------------------------------------------------------------------- +# Schema-level tests: UserInfo +# --------------------------------------------------------------------------- + + +def test_user_info_all_fields(): + """Test UserInfo with all fields populated.""" + user = UserInfo( + id=42, + username="sophie", + first_name="Sophie", + last_name="Test", + email="sophie@example.com", + ) + assert user.id == 42 + assert user.username == "sophie" + assert user.first_name == "Sophie" + assert user.last_name == "Test" + assert user.email == "sophie@example.com" + + +def test_user_info_with_minimal_fields(): + """Test UserInfo with only required fields (all optional).""" + user = UserInfo(id=1, username="admin") + assert user.id == 1 + assert user.username == "admin" + assert user.first_name is None + assert user.last_name is None + assert user.email is None + + +def test_user_info_serialization_roundtrip(): + """Test UserInfo can be serialized to dict and back.""" + user = UserInfo(id=7, username="testuser", first_name="Test", email="t@example.com") + data = user.model_dump() + assert data["id"] == 7 + assert data["username"] == "testuser" + assert data["first_name"] == "Test" + assert data["last_name"] is None + assert data["email"] == "t@example.com" + + # Reconstruct + user2 = UserInfo(**data) + assert user2 == user + + +# --------------------------------------------------------------------------- +# Schema-level tests: InstanceInfo.current_user +# 
--------------------------------------------------------------------------- + + +def test_instance_info_current_user_default_none(): + """Test that InstanceInfo.current_user defaults to None.""" + info = _make_instance_info() + assert info.current_user is None + + +def test_instance_info_with_current_user(): + """Test that InstanceInfo accepts a current_user UserInfo.""" + user = UserInfo( + id=42, + username="sophie", + first_name="Sophie", + last_name="Test", + email="sophie@example.com", + ) + info = _make_instance_info(current_user=user) + assert info.current_user is not None + assert info.current_user.id == 42 + assert info.current_user.username == "sophie" + assert info.current_user.first_name == "Sophie" + assert info.current_user.last_name == "Test" + assert info.current_user.email == "sophie@example.com" + + +def test_instance_info_current_user_in_serialized_output(): + """Test current_user appears when InstanceInfo is serialized to JSON.""" + user = UserInfo(id=1, username="admin", first_name="Admin") + info = _make_instance_info(current_user=user) + data = json.loads(info.model_dump_json()) + assert "current_user" in data + assert data["current_user"]["id"] == 1 + assert data["current_user"]["username"] == "admin" + assert data["current_user"]["first_name"] == "Admin" + + +def test_instance_info_none_current_user_in_serialized_output(): + """Test current_user is null when not set in serialized output.""" + info = _make_instance_info() + data = json.loads(info.model_dump_json()) + assert "current_user" in data + assert data["current_user"] is None + + +# --------------------------------------------------------------------------- +# Tool-level tests: get_instance_info via MCP Client +# --------------------------------------------------------------------------- + + +class TestGetInstanceInfoCurrentUserViaMCP: + """Test get_instance_info tool returns current_user via MCP client.""" + + @pytest.mark.asyncio + async def 
test_get_instance_info_returns_current_user(self, mcp_server): + """Test that get_instance_info populates current_user from g.user.""" + # Patch run_tool on the CLASS so all instances (including the + # module-level _instance_info_core) use the mock. We avoid patching + # via dotted module path because __init__.py re-exports + # get_instance_info as a function, which shadows the submodule name + # and breaks mock resolution on Python 3.10. + from superset.mcp_service.mcp_core import InstanceInfoCore + + mock_g_user = Mock() + mock_g_user.id = 5 + mock_g_user.username = "sophie" + mock_g_user.first_name = "Sophie" + mock_g_user.last_name = "Beaumont" + mock_g_user.email = "sophie@preset.io" + + with ( + patch.object( + InstanceInfoCore, + "run_tool", + return_value=_make_instance_info(), + ), + patch("flask.g") as mock_g, + ): + mock_g.user = mock_g_user + + async with Client(mcp_server) as client: + result = await client.call_tool("get_instance_info", {"request": {}}) + + data = json.loads(result.content[0].text) + assert "current_user" in data + cu = data["current_user"] + assert cu["id"] == 5 + assert cu["username"] == "sophie" + assert cu["first_name"] == "Sophie" + assert cu["last_name"] == "Beaumont" + assert cu["email"] == "sophie@preset.io" + + @pytest.mark.asyncio + async def test_get_instance_info_no_user_returns_null(self, mcp_server): + """Test that current_user is null when g.user is not set.""" + from superset.mcp_service.mcp_core import InstanceInfoCore + + with ( + patch.object( + InstanceInfoCore, + "run_tool", + return_value=_make_instance_info(), + ), + patch("flask.g") as mock_g, + ): + # Simulate no user on g so getattr(g, "user", None) returns None + mock_g.user = None + + async with Client(mcp_server) as client: + result = await client.call_tool("get_instance_info", {"request": {}}) + + data = json.loads(result.content[0].text) + assert data["current_user"] is None + + @pytest.mark.asyncio + async def 
test_get_instance_info_user_missing_optional_attrs(self, mcp_server): + """Test current_user when g.user is missing optional attributes.""" + from superset.mcp_service.mcp_core import InstanceInfoCore + + # User object with only id and username (no first_name, etc.) + mock_g_user = Mock(spec=["id", "username"]) + mock_g_user.id = 99 + mock_g_user.username = "bot" + + with ( + patch.object( + InstanceInfoCore, + "run_tool", + return_value=_make_instance_info(), + ), + patch("flask.g") as mock_g, + ): + mock_g.user = mock_g_user + + async with Client(mcp_server) as client: + result = await client.call_tool("get_instance_info", {"request": {}}) + + data = json.loads(result.content[0].text) + cu = data["current_user"] + assert cu["id"] == 99 + assert cu["username"] == "bot" + # Missing attrs should be None via getattr default + assert cu["first_name"] is None + assert cu["last_name"] is None + assert cu["email"] is None + + +# --------------------------------------------------------------------------- +# Filter schema tests: created_by_fk +# --------------------------------------------------------------------------- + + +def test_chart_filter_accepts_created_by_fk(): + """Test that ChartFilter accepts created_by_fk as a valid column.""" + f = ChartFilter(col="created_by_fk", opr="eq", value=42) + assert f.col == "created_by_fk" + assert f.opr == "eq" + assert f.value == 42 + + +def test_chart_filter_created_by_fk_with_ne_operator(): + """Test created_by_fk with 'ne' (not equal) operator.""" + f = ChartFilter(col="created_by_fk", opr="ne", value=1) + assert f.col == "created_by_fk" + assert f.opr == "ne" + assert f.value == 1 + + +def test_chart_filter_rejects_invalid_column(): + """Test that ChartFilter rejects invalid column names.""" + with pytest.raises(ValidationError): + ChartFilter(col="nonexistent_column", opr="eq", value=42) + + +def test_dashboard_filter_accepts_created_by_fk(): + """Test that DashboardFilter accepts created_by_fk as a valid column.""" + f = 
DashboardFilter(col="created_by_fk", opr="eq", value=42) + assert f.col == "created_by_fk" + assert f.opr == "eq" + assert f.value == 42 + + +def test_dashboard_filter_created_by_fk_with_ne_operator(): + """Test created_by_fk with 'ne' (not equal) operator on dashboards.""" + f = DashboardFilter(col="created_by_fk", opr="ne", value=1) + assert f.col == "created_by_fk" + assert f.opr == "ne" + assert f.value == 1 + + +def test_dashboard_filter_rejects_invalid_column(): + """Test that DashboardFilter rejects invalid column names.""" + with pytest.raises(ValidationError): + DashboardFilter(col="nonexistent_column", opr="eq", value=42) + + +# --------------------------------------------------------------------------- +# Existing filter columns still work +# --------------------------------------------------------------------------- + + +def test_chart_filter_existing_columns_still_work(): + """Test that pre-existing chart filter columns are not broken.""" + for col in ("slice_name", "viz_type", "datasource_name"): + f = ChartFilter(col=col, opr="eq", value="test") + assert f.col == col + + +def test_dashboard_filter_existing_columns_still_work(): + """Test that pre-existing dashboard filter columns are not broken.""" + for col in ("dashboard_title", "published", "favorite"): + f = DashboardFilter(col=col, opr="eq", value="test") + assert f.col == col diff --git a/tests/unit_tests/mcp_service/system/tool/test_get_schema.py b/tests/unit_tests/mcp_service/system/tool/test_get_schema.py index e00c1cb9cd7..38c30a3e47d 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_get_schema.py +++ b/tests/unit_tests/mcp_service/system/tool/test_get_schema.py @@ -236,9 +236,15 @@ class TestGetSchemaToolViaClient: assert "dashboard_title" in info["sortable_columns"] assert "changed_on" in info["sortable_columns"] + @patch( + "superset.mcp_service.utils.schema_utils._is_parse_request_enabled", + return_value=True, + ) 
@patch("superset.daos.chart.ChartDAO.get_filterable_columns_and_operators") @pytest.mark.asyncio - async def test_get_schema_with_json_string_request(self, mock_filters, mcp_server): + async def test_get_schema_with_json_string_request( + self, mock_filters, mock_parse_enabled, mcp_server + ): """Test get_schema accepts JSON string request (Claude Code compatibility).""" mock_filters.return_value = {"slice_name": ["eq"]} diff --git a/tests/unit_tests/mcp_service/utils/test_schema_utils.py b/tests/unit_tests/mcp_service/utils/test_schema_utils.py index 4e592159c07..95b4c2bbbe4 100644 --- a/tests/unit_tests/mcp_service/utils/test_schema_utils.py +++ b/tests/unit_tests/mcp_service/utils/test_schema_utils.py @@ -354,6 +354,17 @@ class TestParseRequestDecorator: name: str count: int + @pytest.fixture(autouse=True) + def _enable_parse_request(self): + """Ensure MCP_PARSE_REQUEST_ENABLED=True for all parsing tests.""" + from unittest.mock import patch + + with patch( + "superset.mcp_service.utils.schema_utils._is_parse_request_enabled", + return_value=True, + ): + yield + def test_decorator_with_json_string_async(self): """Should parse JSON string request in async function.""" from unittest.mock import MagicMock, patch