diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py new file mode 100644 index 00000000000..743db87442c --- /dev/null +++ b/superset/mcp_service/auth.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import logging + +from flask import current_app, g +from flask_login import AnonymousUserMixin +from superset.extensions import security_manager + +logger = logging.getLogger(__name__) + + +def get_user_from_request(): + """ + Extract user info from the request context (e.g., from Bearer token, headers, etc.). + By default, returns admin user. Override for OIDC/OAuth/Okta integration. + """ + from flask import current_app + from superset.extensions import security_manager + admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") + return security_manager.get_user_by_username(admin_username) + + +def impersonate_user(user, run_as=None): + """ + Optionally impersonate another user if allowed. By default, returns the same user. + Override to enforce impersonation rules. + """ + return user + + +def has_permission(user, tool_func): + """ + Check if the user has permission to run the tool. By default, always True. + Override for RBAC. + """ + return True + + +def log_access(user, tool_name, args, kwargs): + """ + Log access/action for observability/audit. By default, does nothing. + Override to log to your system. + """ + pass + + +def mcp_auth_hook(tool_func): + """ + Decorator for MCP tool functions to enforce auth, impersonation, RBAC, and logging. + Also sets up Flask user context (g.user) for downstream DAO/model code. + All logic is overridable for enterprise integration. + """ + import functools + @functools.wraps(tool_func) + def wrapper(*args, **kwargs): + # --- Setup user context (was _setup_user_context) --- + admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") + admin_user = security_manager.get_user_by_username(admin_username) + if not admin_user: + g.user = AnonymousUserMixin() + else: + g.user = admin_user + # --- End user context setup --- + + user = get_user_from_request() + run_as = kwargs.get("run_as") + if run_as: + user = impersonate_user(user, run_as) + if not has_permission(user, tool_func): + raise PermissionError( + f"User {getattr(user, 'username', user)} not authorized for " + f"{tool_func.__name__}") + log_access(user, tool_func.__name__, args, kwargs) + return tool_func(*args, **kwargs) + + return wrapper diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py index d4785c1bab9..e35328a61e9 100644 --- a/superset/mcp_service/chart/tool/get_chart_info.py +++ b/superset/mcp_service/chart/tool/get_chart_info.py @@ -18,14 +18,14 @@ """ MCP tool: get_chart_info """ -from typing import Any, Dict, Optional, Annotated -from superset.mcp_service.pydantic_schemas import ChartInfo, ChartError -from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object -from datetime import datetime, timezone -from superset.daos.chart import ChartDAO -from pydantic import Field import logging -from superset.mcp_service.utils import ModelGetInfoTool +from typing import Annotated + +from pydantic import Field +from superset.daos.chart import ChartDAO +from superset.mcp_service.model_tools import ModelGetInfoTool +from superset.mcp_service.pydantic_schemas import ChartError, ChartInfo +from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object logger = logging.getLogger(__name__) diff --git a/superset/mcp_service/chart/tool/list_charts.py b/superset/mcp_service/chart/tool/list_charts.py index 77fb55ef80b..fd102155723 100644 --- a/superset/mcp_service/chart/tool/list_charts.py +++ b/superset/mcp_service/chart/tool/list_charts.py @@ -23,9 +23,9 @@ from typing import Annotated, Literal, Optional from pydantic import conlist, constr, Field, PositiveInt from superset.daos.chart import ChartDAO +from superset.mcp_service.model_tools import ModelListTool from superset.mcp_service.pydantic_schemas import ChartInfo, ChartList from superset.mcp_service.pydantic_schemas.chart_schemas import ChartFilter -from superset.mcp_service.utils import ModelListTool logger = logging.getLogger(__name__) diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_info.py b/superset/mcp_service/dashboard/tool/get_dashboard_info.py index 75135309f4d..da12579527c 100644 --- a/superset/mcp_service/dashboard/tool/get_dashboard_info.py +++ b/superset/mcp_service/dashboard/tool/get_dashboard_info.py @@ -27,8 +27,8 @@ from typing import Annotated from pydantic import Field -from superset.mcp_service.utils import ModelGetInfoTool from superset.daos.dashboard import DashboardDAO +from superset.mcp_service.model_tools import ModelGetInfoTool from superset.mcp_service.pydantic_schemas import DashboardInfo, DashboardError from superset.mcp_service.pydantic_schemas.system_schemas import RoleInfo, TagInfo, UserInfo from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object diff --git a/superset/mcp_service/dashboard/tool/list_dashboards.py b/superset/mcp_service/dashboard/tool/list_dashboards.py index 8f9ab61de1c..d40da524d09 100644 --- a/superset/mcp_service/dashboard/tool/list_dashboards.py +++ b/superset/mcp_service/dashboard/tool/list_dashboards.py @@ -26,9 +26,9 @@ from typing import Annotated, Literal, Optional from pydantic import conlist, constr, Field, PositiveInt from superset.daos.dashboard import DashboardDAO +from superset.mcp_service.model_tools import ModelListTool from superset.mcp_service.pydantic_schemas import ( DashboardFilter, DashboardInfo, DashboardList, ) -from superset.mcp_service.utils import ModelListTool logger = logging.getLogger(__name__) diff --git a/superset/mcp_service/dataset/tool/get_dataset_info.py b/superset/mcp_service/dataset/tool/get_dataset_info.py index d152a7204d7..40f16ea20cd 100644 --- a/superset/mcp_service/dataset/tool/get_dataset_info.py +++ b/superset/mcp_service/dataset/tool/get_dataset_info.py @@ -26,8 +26,8 @@ from datetime import datetime, timezone from typing import Any, Annotated from pydantic import Field from superset.daos.dataset import DatasetDAO +from superset.mcp_service.model_tools import ModelGetInfoTool from superset.mcp_service.pydantic_schemas import DatasetInfo, DatasetError, serialize_dataset_object -from superset.mcp_service.utils import ModelGetInfoTool from superset.mcp_service.pydantic_schemas.dataset_schemas import serialize_dataset_object logger = logging.getLogger(__name__) diff --git a/superset/mcp_service/dataset/tool/list_datasets.py b/superset/mcp_service/dataset/tool/list_datasets.py index 37b741890cd..5be1237ba1c 100644 --- a/superset/mcp_service/dataset/tool/list_datasets.py +++ b/superset/mcp_service/dataset/tool/list_datasets.py @@ -26,9 +26,9 @@ from typing import Annotated, Literal, Optional from pydantic import conlist, constr, Field, PositiveInt from superset.daos.dataset import DatasetDAO +from superset.mcp_service.model_tools import ModelListTool from superset.mcp_service.pydantic_schemas import (DatasetInfo, DatasetList) from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetFilter -from superset.mcp_service.utils import ModelListTool logger = logging.getLogger(__name__) diff --git a/superset/mcp_service/utils.py b/superset/mcp_service/model_tools.py similarity index 71% rename from superset/mcp_service/utils.py rename to superset/mcp_service/model_tools.py index f57dcc50ba9..743f07eafbd 100644 --- a/superset/mcp_service/utils.py +++ b/superset/mcp_service/model_tools.py @@ -15,87 +15,20 @@ # specific language governing permissions and limitations # under the License. - import logging -from flask import current_app, g -from flask_login import AnonymousUserMixin -from superset.extensions import security_manager - -logger = logging.getLogger(__name__) - - -def get_user_from_request(): - """ - Extract user info from the request context (e.g., from Bearer token, headers, etc.). - By default, returns admin user. Override for OIDC/OAuth/Okta integration. - """ - from flask import current_app - from superset.extensions import security_manager - admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") - return security_manager.get_user_by_username(admin_username) - - -def impersonate_user(user, run_as=None): - """ - Optionally impersonate another user if allowed. By default, returns the same user. - Override to enforce impersonation rules. - """ - return user - - -def has_permission(user, tool_func): - """ - Check if the user has permission to run the tool. By default, always True. - Override for RBAC. - """ - return True - - -def log_access(user, tool_name, args, kwargs): - """ - Log access/action for observability/audit. By default, does nothing. - Override to log to your system. - """ - pass - - -def mcp_auth_hook(tool_func): - """ - Decorator for MCP tool functions to enforce auth, impersonation, RBAC, and logging. - Also sets up Flask user context (g.user) for downstream DAO/model code. - All logic is overridable for enterprise integration. - """ - import functools - @functools.wraps(tool_func) - def wrapper(*args, **kwargs): - # --- Setup user context (was _setup_user_context) --- - admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") - admin_user = security_manager.get_user_by_username(admin_username) - if not admin_user: - g.user = AnonymousUserMixin() - else: - g.user = admin_user - # --- End user context setup --- - - user = get_user_from_request() - run_as = kwargs.get("run_as") - if run_as: - user = impersonate_user(user, run_as) - if not has_permission(user, tool_func): - raise PermissionError( - f"User {getattr(user, 'username', user)} not authorized for " - f"{tool_func.__name__}") - log_access(user, tool_func.__name__, args, kwargs) - return tool_func(*args, **kwargs) - - return wrapper - - class ModelListTool: """ Generic tool for listing model objects with filtering, search, pagination, and column selection. + - Paging is 0-based: page=0 is the first page (to match backend and API conventions). + - total_pages is 0 if there are no results; otherwise, it's ceil(total_count / page_size). + - has_previous is True if page > 0 or (page == 0 and total_count == 0) (so UI can disable prev button on empty results). + - has_next is True if there are more results after the current page. + - columns_requested/columns_loaded track what columns were requested/returned for LLM/OpenAPI friendliness. + - Returns a strongly-typed Pydantic list schema (output_list_schema) with all metadata. + - Handles both object-based and JSON string filters. + - Designed for use by LLM agents and API clients. """ def __init__( self, @@ -195,11 +128,14 @@ class ModelListTool: self.logger.info(f"Successfully retrieved {len(item_objs)} {self.list_field_name}") return response - class ModelGetInfoTool: """ Generic tool for retrieving a single model object by ID, with error handling and serialization. + - Returns output_schema if found, otherwise error_schema with error_type and timestamp. + - If the DAO raises an exception, the error is logged and re-raised (for testability and observability). + - Used for get_dashboard_info, get_chart_info, get_dataset_info, etc. + - Designed for LLM/OpenAPI compatibility and robust error reporting. """ def __init__( self, @@ -233,4 +169,4 @@ class ModelGetInfoTool: except Exception as context_error: error_msg = f"Error in ModelGetInfoTool: {str(context_error)}" self.logger.error(error_msg, exc_info=True) - raise + raise \ No newline at end of file diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 0d1abec4579..6a63a5968c9 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -32,7 +32,7 @@ def init_fastmcp_server() -> 'FastMCP': """ from fastmcp import FastMCP from superset.mcp_service.middleware import LoggingMiddleware, PrivateToolMiddleware - from superset.mcp_service.utils import mcp_auth_hook + from superset.mcp_service.auth import mcp_auth_hook logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/superset/mcp_service/tools/chart/get_chart_info.py b/superset/mcp_service/tools/chart/get_chart_info.py new file mode 100644 index 00000000000..0519ecba6ea --- /dev/null +++ b/superset/mcp_service/tools/chart/get_chart_info.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tests/unit_tests/mcp_service/test_auth.py b/tests/unit_tests/mcp_service/test_auth.py new file mode 100644 index 00000000000..ad70eda548c --- /dev/null +++ b/tests/unit_tests/mcp_service/test_auth.py @@ -0,0 +1,52 @@ +from datetime import datetime +from types import SimpleNamespace + +import pytest +from pydantic import BaseModel +from superset.mcp_service.auth import mcp_auth_hook + + +# Dummy Pydantic output schema +class DummyOutputSchema(BaseModel): + id: int + name: str + +# Dummy list schema +class DummyListSchema(BaseModel): + items: list + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: list + columns_loaded: list + filters_applied: list + pagination: object + timestamp: datetime + +# Dummy error schema +class DummyErrorSchema(BaseModel): + error: str + error_type: str + timestamp: datetime + +# Dummy DAO +class DummyDAO: + @classmethod + def list(cls, **kwargs): + # Return a list of dummy objects and a total count + return [SimpleNamespace(id=1, name="foo"), SimpleNamespace(id=2, name="bar")], 2 + @classmethod + def find_by_id(cls, id): + if id == 1: + return SimpleNamespace(id=1, name="foo") + return None + +def dummy_serializer(obj, columns=None): + # Serialize mock object to DummyOutputSchema + return DummyOutputSchema(id=obj.id, name=obj.name) + +# All ModelListTool and ModelGetInfoTool tests have been moved to test_model_tools.py diff --git a/tests/unit_tests/mcp_service/test_utils.py b/tests/unit_tests/mcp_service/test_model_tools.py similarity index 97% rename from tests/unit_tests/mcp_service/test_utils.py rename to tests/unit_tests/mcp_service/test_model_tools.py index 5df46667a29..f332f3924bd 100644 --- a/tests/unit_tests/mcp_service/test_utils.py +++ b/tests/unit_tests/mcp_service/test_model_tools.py @@ -3,8 +3,7 @@ from types import SimpleNamespace import pytest from pydantic import BaseModel -from superset.mcp_service.utils import ModelGetInfoTool, ModelListTool - +from superset.mcp_service.model_tools import ModelGetInfoTool, ModelListTool # Dummy Pydantic output schema class DummyOutputSchema(BaseModel): @@ -166,4 +165,4 @@ def test_model_get_info_tool_exception(): ) with pytest.raises(Exception) as exc: tool.run(1) - assert "fail" in str(exc.value) + assert "fail" in str(exc.value) \ No newline at end of file