mirror of
https://github.com/apache/superset.git
synced 2026-04-07 10:31:50 +00:00
feat(mcp): expose current user identity in get_instance_info and add created_by_fk filter (#37967)
This commit is contained in:
@@ -76,7 +76,7 @@ Schema Discovery:
|
||||
- get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters)
|
||||
|
||||
System Information:
|
||||
- get_instance_info: Get instance-wide statistics and metadata
|
||||
- get_instance_info: Get instance-wide statistics, metadata, and current user identity
|
||||
- health_check: Simple health check tool (takes NO parameters, call without arguments)
|
||||
|
||||
Available Resources:
|
||||
@@ -95,6 +95,13 @@ To create a chart:
|
||||
3. generate_explore_link(dataset_id, config) -> preview interactively
|
||||
4. generate_chart(dataset_id, config, save_chart=True) -> save permanently
|
||||
|
||||
To find your own charts/dashboards:
|
||||
1. get_instance_info -> get current_user.id
|
||||
2. list_charts(filters=[{{"col": "created_by_fk",
|
||||
"opr": "eq", "value": current_user.id}}])
|
||||
3. Or: list_dashboards(filters=[{{"col": "created_by_fk",
|
||||
"opr": "eq", "value": current_user.id}}])
|
||||
|
||||
To explore data with SQL:
|
||||
1. get_instance_info -> find database_id
|
||||
2. execute_sql(database_id, sql) -> run query
|
||||
@@ -127,6 +134,10 @@ Query Examples:
|
||||
- List time series charts:
|
||||
filters=[{{"col": "viz_type", "opr": "sw", "value": "echarts_timeseries"}}]
|
||||
- Search by name: search="sales"
|
||||
- My charts (use current_user.id from get_instance_info):
|
||||
filters=[{{"col": "created_by_fk", "opr": "eq", "value": <user_id>}}]
|
||||
- My dashboards:
|
||||
filters=[{{"col": "created_by_fk", "opr": "eq", "value": <user_id>}}]
|
||||
|
||||
General usage tips:
|
||||
- All listing tools use 1-based pagination (first page is 1)
|
||||
@@ -143,6 +154,9 @@ Input format:
|
||||
|
||||
If you are unsure which tool to use, start with get_instance_info
|
||||
or use the quickstart prompt for an interactive guide.
|
||||
|
||||
When you first connect, call get_instance_info to learn the user's identity.
|
||||
Greet them by their first name (from current_user) and offer to help.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -270,10 +270,12 @@ class ChartFilter(ColumnOperator):
|
||||
"slice_name",
|
||||
"viz_type",
|
||||
"datasource_name",
|
||||
"created_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description="Column to filter on. Use get_schema(model_type='chart') for "
|
||||
"available filter columns.",
|
||||
"available filter columns. Use created_by_fk with the user ID from "
|
||||
"get_instance_info's current_user to find charts created by a specific user.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
|
||||
@@ -163,10 +163,16 @@ class DashboardFilter(ColumnOperator):
|
||||
"dashboard_title",
|
||||
"published",
|
||||
"favorite",
|
||||
"created_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description="Column to filter on. Use get_schema(model_type='dashboard') for "
|
||||
"available filter columns.",
|
||||
description=(
|
||||
"Column to filter on. Use "
|
||||
"get_schema(model_type='dashboard') for available "
|
||||
"filter columns. Use created_by_fk with the user "
|
||||
"ID from get_instance_info's current_user to find "
|
||||
"dashboards created by a specific user."
|
||||
),
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
|
||||
@@ -22,6 +22,8 @@ This module contains Pydantic models for serializing Superset instance metadata
|
||||
system-level info.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -122,6 +124,11 @@ class InstanceInfo(BaseModel):
|
||||
popular_content: PopularContent = Field(
|
||||
..., description="Popular content information"
|
||||
)
|
||||
current_user: UserInfo | None = Field(
|
||||
None,
|
||||
description="The authenticated user making the request. "
|
||||
"Use current_user.id with created_by_fk filter to find your own assets.",
|
||||
)
|
||||
timestamp: datetime = Field(..., description="Response timestamp")
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ from superset.mcp_service.mcp_core import InstanceInfoCore
|
||||
from superset.mcp_service.system.schemas import (
|
||||
GetSupersetInstanceInfoRequest,
|
||||
InstanceInfo,
|
||||
UserInfo,
|
||||
)
|
||||
from superset.mcp_service.system.system_utils import (
|
||||
calculate_dashboard_breakdown,
|
||||
@@ -81,6 +82,8 @@ def get_instance_info(
|
||||
"""
|
||||
try:
|
||||
# Import DAOs at runtime to avoid circular imports
|
||||
from flask import g
|
||||
|
||||
from superset.daos.chart import ChartDAO
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
from superset.daos.database import DatabaseDAO
|
||||
@@ -100,7 +103,20 @@ def get_instance_info(
|
||||
|
||||
# Run the configurable core
|
||||
with event_logger.log_context(action="mcp.get_instance_info.metrics"):
|
||||
return _instance_info_core.run_tool()
|
||||
result = _instance_info_core.run_tool()
|
||||
|
||||
# Attach the authenticated user's identity to the response
|
||||
user = getattr(g, "user", None)
|
||||
if user is not None:
|
||||
result.current_user = UserInfo(
|
||||
id=getattr(user, "id", None),
|
||||
username=getattr(user, "username", None),
|
||||
first_name=getattr(user, "first_name", None),
|
||||
last_name=getattr(user, "last_name", None),
|
||||
email=getattr(user, "email", None),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error in instance info: {str(e)}"
|
||||
|
||||
@@ -446,10 +446,8 @@ def parse_request(
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
import types
|
||||
|
||||
parse_enabled = _is_parse_request_enabled()
|
||||
|
||||
def _maybe_parse(request: Any) -> Any:
|
||||
if parse_enabled:
|
||||
if _is_parse_request_enabled():
|
||||
return parse_json_or_model(request, request_class, "request")
|
||||
return request
|
||||
|
||||
@@ -501,7 +499,7 @@ def parse_request(
|
||||
# Copy docstring from original function (not wrapper, which has no docstring)
|
||||
new_wrapper.__doc__ = func.__doc__
|
||||
|
||||
request_annotation = str | request_class if parse_enabled else request_class
|
||||
request_annotation = str | request_class
|
||||
_apply_signature_for_fastmcp(new_wrapper, func, request_annotation)
|
||||
|
||||
return new_wrapper
|
||||
|
||||
@@ -0,0 +1,363 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for current_user in get_instance_info and created_by_fk filtering."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.chart.schemas import ChartFilter
|
||||
from superset.mcp_service.dashboard.schemas import DashboardFilter
|
||||
from superset.mcp_service.system.schemas import InstanceInfo, UserInfo
|
||||
from superset.utils import json
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests."""
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper to build a minimal InstanceInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
def _make_instance_info(**kwargs):
|
||||
"""Build a minimal InstanceInfo with defaults; override with kwargs."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from superset.mcp_service.system.schemas import (
|
||||
DashboardBreakdown,
|
||||
DatabaseBreakdown,
|
||||
InstanceSummary,
|
||||
PopularContent,
|
||||
RecentActivity,
|
||||
)
|
||||
|
||||
defaults = {
|
||||
"instance_summary": InstanceSummary(
|
||||
total_dashboards=0,
|
||||
total_charts=0,
|
||||
total_datasets=0,
|
||||
total_databases=0,
|
||||
total_users=0,
|
||||
total_roles=0,
|
||||
total_tags=0,
|
||||
avg_charts_per_dashboard=0.0,
|
||||
),
|
||||
"recent_activity": RecentActivity(
|
||||
dashboards_created_last_30_days=0,
|
||||
charts_created_last_30_days=0,
|
||||
datasets_created_last_30_days=0,
|
||||
dashboards_modified_last_7_days=0,
|
||||
charts_modified_last_7_days=0,
|
||||
datasets_modified_last_7_days=0,
|
||||
),
|
||||
"dashboard_breakdown": DashboardBreakdown(
|
||||
published=0,
|
||||
unpublished=0,
|
||||
certified=0,
|
||||
with_charts=0,
|
||||
without_charts=0,
|
||||
),
|
||||
"database_breakdown": DatabaseBreakdown(by_type={}),
|
||||
"popular_content": PopularContent(top_tags=[], top_creators=[]),
|
||||
"timestamp": datetime.now(timezone.utc),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return InstanceInfo(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema-level tests: UserInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_user_info_all_fields():
|
||||
"""Test UserInfo with all fields populated."""
|
||||
user = UserInfo(
|
||||
id=42,
|
||||
username="sophie",
|
||||
first_name="Sophie",
|
||||
last_name="Test",
|
||||
email="sophie@example.com",
|
||||
)
|
||||
assert user.id == 42
|
||||
assert user.username == "sophie"
|
||||
assert user.first_name == "Sophie"
|
||||
assert user.last_name == "Test"
|
||||
assert user.email == "sophie@example.com"
|
||||
|
||||
|
||||
def test_user_info_with_minimal_fields():
|
||||
"""Test UserInfo with only required fields (all optional)."""
|
||||
user = UserInfo(id=1, username="admin")
|
||||
assert user.id == 1
|
||||
assert user.username == "admin"
|
||||
assert user.first_name is None
|
||||
assert user.last_name is None
|
||||
assert user.email is None
|
||||
|
||||
|
||||
def test_user_info_serialization_roundtrip():
|
||||
"""Test UserInfo can be serialized to dict and back."""
|
||||
user = UserInfo(id=7, username="testuser", first_name="Test", email="t@example.com")
|
||||
data = user.model_dump()
|
||||
assert data["id"] == 7
|
||||
assert data["username"] == "testuser"
|
||||
assert data["first_name"] == "Test"
|
||||
assert data["last_name"] is None
|
||||
assert data["email"] == "t@example.com"
|
||||
|
||||
# Reconstruct
|
||||
user2 = UserInfo(**data)
|
||||
assert user2 == user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema-level tests: InstanceInfo.current_user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_instance_info_current_user_default_none():
|
||||
"""Test that InstanceInfo.current_user defaults to None."""
|
||||
info = _make_instance_info()
|
||||
assert info.current_user is None
|
||||
|
||||
|
||||
def test_instance_info_with_current_user():
|
||||
"""Test that InstanceInfo accepts a current_user UserInfo."""
|
||||
user = UserInfo(
|
||||
id=42,
|
||||
username="sophie",
|
||||
first_name="Sophie",
|
||||
last_name="Test",
|
||||
email="sophie@example.com",
|
||||
)
|
||||
info = _make_instance_info(current_user=user)
|
||||
assert info.current_user is not None
|
||||
assert info.current_user.id == 42
|
||||
assert info.current_user.username == "sophie"
|
||||
assert info.current_user.first_name == "Sophie"
|
||||
assert info.current_user.last_name == "Test"
|
||||
assert info.current_user.email == "sophie@example.com"
|
||||
|
||||
|
||||
def test_instance_info_current_user_in_serialized_output():
|
||||
"""Test current_user appears when InstanceInfo is serialized to JSON."""
|
||||
user = UserInfo(id=1, username="admin", first_name="Admin")
|
||||
info = _make_instance_info(current_user=user)
|
||||
data = json.loads(info.model_dump_json())
|
||||
assert "current_user" in data
|
||||
assert data["current_user"]["id"] == 1
|
||||
assert data["current_user"]["username"] == "admin"
|
||||
assert data["current_user"]["first_name"] == "Admin"
|
||||
|
||||
|
||||
def test_instance_info_none_current_user_in_serialized_output():
|
||||
"""Test current_user is null when not set in serialized output."""
|
||||
info = _make_instance_info()
|
||||
data = json.loads(info.model_dump_json())
|
||||
assert "current_user" in data
|
||||
assert data["current_user"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool-level tests: get_instance_info via MCP Client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
"""Test get_instance_info tool returns current_user via MCP client."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_info_returns_current_user(self, mcp_server):
|
||||
"""Test that get_instance_info populates current_user from g.user."""
|
||||
# Patch run_tool on the CLASS so all instances (including the
|
||||
# module-level _instance_info_core) use the mock. We avoid patching
|
||||
# via dotted module path because __init__.py re-exports
|
||||
# get_instance_info as a function, which shadows the submodule name
|
||||
# and breaks mock resolution on Python 3.10.
|
||||
from superset.mcp_service.mcp_core import InstanceInfoCore
|
||||
|
||||
mock_g_user = Mock()
|
||||
mock_g_user.id = 5
|
||||
mock_g_user.username = "sophie"
|
||||
mock_g_user.first_name = "Sophie"
|
||||
mock_g_user.last_name = "Beaumont"
|
||||
mock_g_user.email = "sophie@preset.io"
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InstanceInfoCore,
|
||||
"run_tool",
|
||||
return_value=_make_instance_info(),
|
||||
),
|
||||
patch("flask.g") as mock_g,
|
||||
):
|
||||
mock_g.user = mock_g_user
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "current_user" in data
|
||||
cu = data["current_user"]
|
||||
assert cu["id"] == 5
|
||||
assert cu["username"] == "sophie"
|
||||
assert cu["first_name"] == "Sophie"
|
||||
assert cu["last_name"] == "Beaumont"
|
||||
assert cu["email"] == "sophie@preset.io"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_info_no_user_returns_null(self, mcp_server):
|
||||
"""Test that current_user is null when g.user is not set."""
|
||||
from superset.mcp_service.mcp_core import InstanceInfoCore
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InstanceInfoCore,
|
||||
"run_tool",
|
||||
return_value=_make_instance_info(),
|
||||
),
|
||||
patch("flask.g") as mock_g,
|
||||
):
|
||||
# Simulate no user on g so getattr(g, "user", None) returns None
|
||||
mock_g.user = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["current_user"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_info_user_missing_optional_attrs(self, mcp_server):
|
||||
"""Test current_user when g.user is missing optional attributes."""
|
||||
from superset.mcp_service.mcp_core import InstanceInfoCore
|
||||
|
||||
# User object with only id and username (no first_name, etc.)
|
||||
mock_g_user = Mock(spec=["id", "username"])
|
||||
mock_g_user.id = 99
|
||||
mock_g_user.username = "bot"
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InstanceInfoCore,
|
||||
"run_tool",
|
||||
return_value=_make_instance_info(),
|
||||
),
|
||||
patch("flask.g") as mock_g,
|
||||
):
|
||||
mock_g.user = mock_g_user
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
cu = data["current_user"]
|
||||
assert cu["id"] == 99
|
||||
assert cu["username"] == "bot"
|
||||
# Missing attrs should be None via getattr default
|
||||
assert cu["first_name"] is None
|
||||
assert cu["last_name"] is None
|
||||
assert cu["email"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter schema tests: created_by_fk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_chart_filter_accepts_created_by_fk():
|
||||
"""Test that ChartFilter accepts created_by_fk as a valid column."""
|
||||
f = ChartFilter(col="created_by_fk", opr="eq", value=42)
|
||||
assert f.col == "created_by_fk"
|
||||
assert f.opr == "eq"
|
||||
assert f.value == 42
|
||||
|
||||
|
||||
def test_chart_filter_created_by_fk_with_ne_operator():
|
||||
"""Test created_by_fk with 'ne' (not equal) operator."""
|
||||
f = ChartFilter(col="created_by_fk", opr="ne", value=1)
|
||||
assert f.col == "created_by_fk"
|
||||
assert f.opr == "ne"
|
||||
assert f.value == 1
|
||||
|
||||
|
||||
def test_chart_filter_rejects_invalid_column():
|
||||
"""Test that ChartFilter rejects invalid column names."""
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter(col="nonexistent_column", opr="eq", value=42)
|
||||
|
||||
|
||||
def test_dashboard_filter_accepts_created_by_fk():
|
||||
"""Test that DashboardFilter accepts created_by_fk as a valid column."""
|
||||
f = DashboardFilter(col="created_by_fk", opr="eq", value=42)
|
||||
assert f.col == "created_by_fk"
|
||||
assert f.opr == "eq"
|
||||
assert f.value == 42
|
||||
|
||||
|
||||
def test_dashboard_filter_created_by_fk_with_ne_operator():
|
||||
"""Test created_by_fk with 'ne' (not equal) operator on dashboards."""
|
||||
f = DashboardFilter(col="created_by_fk", opr="ne", value=1)
|
||||
assert f.col == "created_by_fk"
|
||||
assert f.opr == "ne"
|
||||
assert f.value == 1
|
||||
|
||||
|
||||
def test_dashboard_filter_rejects_invalid_column():
|
||||
"""Test that DashboardFilter rejects invalid column names."""
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter(col="nonexistent_column", opr="eq", value=42)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Existing filter columns still work
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_chart_filter_existing_columns_still_work():
|
||||
"""Test that pre-existing chart filter columns are not broken."""
|
||||
for col in ("slice_name", "viz_type", "datasource_name"):
|
||||
f = ChartFilter(col=col, opr="eq", value="test")
|
||||
assert f.col == col
|
||||
|
||||
|
||||
def test_dashboard_filter_existing_columns_still_work():
|
||||
"""Test that pre-existing dashboard filter columns are not broken."""
|
||||
for col in ("dashboard_title", "published", "favorite"):
|
||||
f = DashboardFilter(col=col, opr="eq", value="test")
|
||||
assert f.col == col
|
||||
@@ -236,9 +236,15 @@ class TestGetSchemaToolViaClient:
|
||||
assert "dashboard_title" in info["sortable_columns"]
|
||||
assert "changed_on" in info["sortable_columns"]
|
||||
|
||||
@patch(
|
||||
"superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
@patch("superset.daos.chart.ChartDAO.get_filterable_columns_and_operators")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schema_with_json_string_request(self, mock_filters, mcp_server):
|
||||
async def test_get_schema_with_json_string_request(
|
||||
self, mock_filters, mock_parse_enabled, mcp_server
|
||||
):
|
||||
"""Test get_schema accepts JSON string request (Claude Code compatibility)."""
|
||||
mock_filters.return_value = {"slice_name": ["eq"]}
|
||||
|
||||
|
||||
@@ -354,6 +354,17 @@ class TestParseRequestDecorator:
|
||||
name: str
|
||||
count: int
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_parse_request(self):
|
||||
"""Ensure MCP_PARSE_REQUEST_ENABLED=True for all parsing tests."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
def test_decorator_with_json_string_async(self):
|
||||
"""Should parse JSON string request in async function."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
Reference in New Issue
Block a user