diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py old mode 100644 new mode 100755 index 4f7268458e8..115c5ce1c78 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -747,6 +747,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods # Configuration of feature_flags must be done first to allow init features # conditionally self.configure_feature_flags() + self.configure_mcp_chart_registry() self.configure_db_encrypt() self.setup_db() @@ -821,6 +822,22 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods def configure_feature_flags(self) -> None: feature_flag_manager.init_app(self.superset_app) + def configure_mcp_chart_registry(self) -> None: + from superset.mcp_service.chart import registry + from superset.mcp_service.mcp_config import ( + MCP_CHART_PLUGIN_ENABLED_FUNC, + MCP_DISABLED_CHART_PLUGINS, + ) + + registry.configure( + disabled=self.config.get( + "MCP_DISABLED_CHART_PLUGINS", MCP_DISABLED_CHART_PLUGINS + ), + enabled_func=self.config.get( + "MCP_CHART_PLUGIN_ENABLED_FUNC", MCP_CHART_PLUGIN_ENABLED_FUNC + ), + ) + def configure_sqlglot_dialects(self) -> None: extensions = self.config["SQLGLOT_DIALECTS_EXTENSIONS"] diff --git a/superset/mcp_service/chart/registry.py b/superset/mcp_service/chart/registry.py index 920cfcc5592..82b1a1972ff 100755 --- a/superset/mcp_service/chart/registry.py +++ b/superset/mcp_service/chart/registry.py @@ -38,6 +38,8 @@ from __future__ import annotations import logging import threading +from collections.abc import Callable, Iterable +from dataclasses import dataclass, field from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -49,6 +51,22 @@ _REGISTRY: dict[str, "ChartTypePlugin"] = {} _plugins_loaded = False _plugins_lock = threading.Lock() +# --------------------------------------------------------------------------- +# Plugin filter — replaced atomically by configure() at app startup. +# Default: all registered plugins visible (no disabled set, no callable). +# --------------------------------------------------------------------------- + +PluginEnabledFunc = Callable[[str], bool] + + +@dataclass(frozen=True) +class _PluginFilterConfig: + disabled_plugins: frozenset[str] = field(default_factory=frozenset) + enabled_func: PluginEnabledFunc | None = None + + +_filter_config: _PluginFilterConfig = _PluginFilterConfig() + def _ensure_plugins_loaded() -> None: """Lazily import the plugins package to populate _REGISTRY. @@ -70,6 +88,60 @@ def _ensure_plugins_loaded() -> None: logger.exception("Failed to load built-in chart type plugins") +def configure( + disabled: Iterable[str] | None = None, + enabled_func: PluginEnabledFunc | None = None, +) -> None: + """Set runtime plugin filters. Called once during app initialization. + + Replaces the filter config atomically with a single assignment so concurrent + readers always observe a consistent (disabled_plugins, enabled_func) pair. + + Args: + disabled: chart_type strings to suppress. Accepts any iterable (set, + frozenset, list, tuple). Ignored when enabled_func is provided. + enabled_func: callable(chart_type) -> bool. When set, overrides + ``disabled``. Must be cheap and in-process — no network I/O per + call. On exception the registry fails *closed* (plugin hidden). + """ + global _filter_config + + if enabled_func is not None and not callable(enabled_func): + raise TypeError("enabled_func must be callable or None") + + new_config = _PluginFilterConfig( + disabled_plugins=frozenset(disabled or ()), + enabled_func=enabled_func, + ) + _filter_config = new_config + + if new_config.disabled_plugins: + logger.info( + "MCP chart plugins disabled: %s", sorted(new_config.disabled_plugins) + ) + if new_config.enabled_func is not None: + logger.info( + "MCP chart plugin dynamic filter configured: %r", new_config.enabled_func + ) + + +def _is_plugin_enabled(chart_type: str) -> bool: + """Return True if the plugin is currently enabled (not filtered out).""" + config = _filter_config # read once — atomic reference in CPython + if config.enabled_func is not None: + try: + return bool(config.enabled_func(chart_type)) + except Exception: + logger.warning( + "MCP_CHART_PLUGIN_ENABLED_FUNC raised for chart_type=%r; " + "failing closed (plugin hidden)", + chart_type, + exc_info=True, + ) + return False + return chart_type not in config.disabled_plugins + + def register(plugin: "ChartTypePlugin") -> None: """Register a chart type plugin in the global registry.""" if not plugin.chart_type: @@ -83,23 +155,35 @@ def register(plugin: "ChartTypePlugin") -> None: def get(chart_type: str) -> "ChartTypePlugin | None": - """Return the plugin for a given chart_type, or None if not registered.""" + """Return the plugin for chart_type, or None if unknown or disabled.""" _ensure_plugins_loaded() - return _REGISTRY.get(chart_type) + if chart_type not in _REGISTRY or not _is_plugin_enabled(chart_type): + return None + return _REGISTRY[chart_type] def all_types() -> list[str]: - """Return all registered chart type strings in insertion order.""" + """Return enabled registered chart type strings in insertion order.""" _ensure_plugins_loaded() - return list(_REGISTRY.keys()) + return [ct for ct in _REGISTRY if _is_plugin_enabled(ct)] def is_registered(chart_type: str) -> bool: - """Return True if chart_type has a registered plugin.""" + """Return True if chart_type has a registered plugin, regardless of enabled state. + + Use this to distinguish an unknown chart type from a disabled one. + Use is_enabled() to check whether the plugin is currently available. + """ _ensure_plugins_loaded() return chart_type in _REGISTRY +def is_enabled(chart_type: str) -> bool: + """Return True if chart_type is registered AND currently enabled.""" + _ensure_plugins_loaded() + return chart_type in _REGISTRY and _is_plugin_enabled(chart_type) + + def display_name_for_viz_type(viz_type: str) -> str | None: """Return the user-facing display name for a Superset-internal viz_type. @@ -137,5 +221,8 @@ class _RegistryProxy: def is_registered(self, chart_type: str) -> bool: return is_registered(chart_type) + def is_enabled(self, chart_type: str) -> bool: + return is_enabled(chart_type) + def display_name_for_viz_type(self, viz_type: str) -> str | None: return display_name_for_viz_type(viz_type) diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py index 11f0f6ada8a..8650a79122f 100755 --- a/superset/mcp_service/chart/validation/schema_validator.py +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -172,6 +172,21 @@ class SchemaValidator: error_code="INVALID_CHART_TYPE", ) + if not registry.is_enabled(chart_type): + valid_types = ", ".join(registry.all_types()) + return False, ChartGenerationError( + error_type="disabled_chart_type", + message=f"Chart type '{chart_type}' is not enabled on this instance", + details=f"Chart type '{chart_type}' is registered but has been " + f"disabled by the operator. " + f"Enabled chart types: {valid_types}", + suggestions=[ + f"Use one of the enabled chart types: {valid_types}", + "Contact your administrator if you believe this is an error", + ], + error_code="DISABLED_CHART_TYPE", + ) + plugin = registry.get(chart_type) if plugin is None: return False, ChartGenerationError( diff --git a/superset/mcp_service/flask_singleton.py b/superset/mcp_service/flask_singleton.py index d3a7124ec13..b39b73756e4 100644 --- a/superset/mcp_service/flask_singleton.py +++ b/superset/mcp_service/flask_singleton.py @@ -81,6 +81,17 @@ try: mcp_config = get_mcp_config(_mcp_app.config) _mcp_app.config.update(mcp_config) + # Re-configure chart registry so MCP-specific overrides (e.g. + # MCP_DISABLED_CHART_PLUGINS set by the operator) take effect after + # the MCP config overlay. SupersetAppInitializer.configure_mcp_chart_registry() + # ran earlier with pre-overlay values; this call corrects them. + from superset.mcp_service.chart import registry as _chart_registry + + _chart_registry.configure( + disabled=_mcp_app.config.get("MCP_DISABLED_CHART_PLUGINS"), + enabled_func=_mcp_app.config.get("MCP_CHART_PLUGIN_ENABLED_FUNC"), + ) + with _mcp_app.app_context(): from superset.core.mcp.core_mcp_injection import ( initialize_core_mcp_dependencies, diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py index d12b44bbc87..e47e27b182d 100644 --- a/superset/mcp_service/mcp_config.py +++ b/superset/mcp_service/mcp_config.py @@ -18,6 +18,7 @@ import logging import secrets +from collections.abc import Callable from typing import Any, Dict, Optional from flask import Flask @@ -69,6 +70,46 @@ MCP_RBAC_ENABLED = True # MCP_DISABLED_TOOLS = {"extensions.myorg.myext.some_tool"} MCP_DISABLED_TOOLS: set[str] = set() +# ============================================================================= +# MCP Chart Plugin Filtering +# ============================================================================= +# +# Overview: +# --------- +# These two settings let operators enable/disable individual chart type plugins +# at runtime without a code deploy. +# +# Use cases: +# - Emergency kill switch: add "handlebars" to MCP_DISABLED_CHART_PLUGINS and +# restart to immediately hide it from all callers. +# - Dynamic per-request control (A/B test, gradual rollout): set +# MCP_CHART_PLUGIN_ENABLED_FUNC to an in-process predicate that can vary +# by user, request header, or any other context available at call time. +# +# Priority: +# MCP_CHART_PLUGIN_ENABLED_FUNC takes precedence over MCP_DISABLED_CHART_PLUGINS. +# When the callable is set, the deny-list is ignored entirely. +# +# MCP_CHART_PLUGIN_ENABLED_FUNC contract: +# - Called as enabled_func(chart_type: str) -> bool for every registry lookup. +# - Must be cheap and in-process: consult already-loaded feature flags or +# request-local context (e.g. Flask g). Do NOT perform network I/O per call. +# - On exception, the registry fails CLOSED (plugin hidden) and logs a warning. +# - Example (Harness / Split via pre-fetched flags in g): +# from flask import g +# def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool: +# flags = getattr(g, "feature_flags", {}) +# return flags.get(f"mcp_chart_{chart_type}", True) +# ============================================================================= + +# Chart types in this set are hidden from all registry lookups. +# Use frozenset to avoid accidental mutation. +MCP_DISABLED_CHART_PLUGINS: frozenset[str] = frozenset() + +# Dynamic per-call predicate. When set, overrides MCP_DISABLED_CHART_PLUGINS. +# Signature: (chart_type: str) -> bool +MCP_CHART_PLUGIN_ENABLED_FUNC: Callable[[str], bool] | None = None + # MCP JWT Debug Errors - controls server-side JWT debug logging. # When False (default), uses the default JWTVerifier with minimal logging. # When True, uses DetailedJWTVerifier with tiered logging: @@ -416,6 +457,8 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]: "MCP_DEBUG": MCP_DEBUG, "MCP_RBAC_ENABLED": MCP_RBAC_ENABLED, "MCP_DISABLED_TOOLS": set(MCP_DISABLED_TOOLS), + "MCP_DISABLED_CHART_PLUGINS": MCP_DISABLED_CHART_PLUGINS, + "MCP_CHART_PLUGIN_ENABLED_FUNC": MCP_CHART_PLUGIN_ENABLED_FUNC, **MCP_SESSION_CONFIG, **MCP_CSRF_CONFIG, } diff --git a/tests/unit_tests/mcp_service/chart/test_registry_filters.py b/tests/unit_tests/mcp_service/chart/test_registry_filters.py new file mode 100644 index 00000000000..241f014fc2b --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_registry_filters.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for registry plugin filtering (configure / is_enabled / get / all_types).""" + +import pytest + +import superset.mcp_service.chart.registry as registry_module +from superset.mcp_service.chart.plugin import BaseChartPlugin +from superset.mcp_service.chart.registry import ( + _PluginFilterConfig, + all_types, + configure, + display_name_for_viz_type, + get, + is_enabled, + is_registered, + register, +) + + +@pytest.fixture(autouse=True) +def _isolated_registry(monkeypatch): + """Isolated registry with two known plugins and a clean filter for each test.""" + monkeypatch.setattr(registry_module, "_REGISTRY", {}) + monkeypatch.setattr(registry_module, "_plugins_loaded", True) + monkeypatch.setattr(registry_module, "_filter_config", _PluginFilterConfig()) + register(_AlphaPlugin()) + register(_BetaPlugin()) + + +class _AlphaPlugin(BaseChartPlugin): + chart_type = "alpha" + display_name = "Alpha Chart" + native_viz_types = {"alpha_viz": "Alpha Viz"} + + +class _BetaPlugin(BaseChartPlugin): + chart_type = "beta" + display_name = "Beta Chart" + native_viz_types = {"beta_viz": "Beta Viz"} + + +# --------------------------------------------------------------------------- +# Static deny-list tests +# --------------------------------------------------------------------------- + + +def test_get_returns_plugin_when_enabled(): + assert get("alpha") is not None + assert get("beta") is not None + + +def test_get_returns_none_for_disabled_plugin(): + configure(disabled={"alpha"}) + assert get("alpha") is None + + +def test_get_still_returns_other_plugins_when_one_is_disabled(): + configure(disabled={"alpha"}) + assert get("beta") is not None + + +def test_all_types_excludes_disabled(): + configure(disabled={"alpha"}) + types = all_types() + assert "alpha" not in types + assert "beta" in types + + +def test_all_types_empty_when_all_disabled(): + configure(disabled={"alpha", "beta"}) + assert all_types() == [] + + +def test_is_registered_ignores_deny_list(): + configure(disabled={"alpha"}) + assert is_registered("alpha") is True + + +def test_is_enabled_returns_false_for_disabled(): + configure(disabled={"alpha"}) + assert is_enabled("alpha") is False + + +def test_is_enabled_returns_true_when_not_disabled(): + configure(disabled={"alpha"}) + assert is_enabled("beta") is True + + +def test_is_enabled_returns_false_for_unknown(): + assert is_enabled("nonexistent") is False + + +# --------------------------------------------------------------------------- +# configure() accepts different iterable shapes +# --------------------------------------------------------------------------- + + +def test_configure_accepts_list(): + configure(disabled=["alpha"]) + assert get("alpha") is None + + +def test_configure_accepts_tuple(): + configure(disabled=("alpha",)) + assert get("alpha") is None + + +def test_configure_accepts_frozenset(): + configure(disabled=frozenset({"alpha"})) + assert get("alpha") is None + + +def test_configure_accepts_none_disabled(): + configure(disabled=None) + assert get("alpha") is not None + + +def test_configure_rejects_noncallable_enabled_func(): + with pytest.raises(TypeError): + configure(enabled_func="not_a_callable") + + +# --------------------------------------------------------------------------- +# Dynamic callable hook tests +# --------------------------------------------------------------------------- + + +def test_enabled_func_overrides_deny_list(): + # alpha is in deny-list but callable says True → should be visible + configure(disabled={"alpha"}, enabled_func=lambda ct: ct == "alpha") + assert get("alpha") is not None + + +def test_enabled_func_can_disable_plugin(): + configure(enabled_func=lambda ct: ct != "beta") + assert get("beta") is None + assert get("alpha") is not None + + +def test_enabled_func_called_per_lookup(): + calls = [] + + def hook(ct: str) -> bool: + calls.append(ct) + return True + + configure(enabled_func=hook) + get("alpha") + get("alpha") + assert calls.count("alpha") == 2 + + +def test_enabled_func_exception_fails_closed(caplog): + import logging + + def bad_hook(ct: str) -> bool: + raise RuntimeError("Harness down") + + configure(enabled_func=bad_hook) + with caplog.at_level(logging.WARNING, logger="superset.mcp_service.chart.registry"): + result = get("alpha") + + assert result is None # fail closed + assert "failing closed" in caplog.text.lower() or "alpha" in caplog.text + + +def test_enabled_func_all_types_respects_hook(): + configure(enabled_func=lambda ct: ct == "alpha") + assert all_types() == ["alpha"] + + +# --------------------------------------------------------------------------- +# display_name_for_viz_type is NOT filtered +# --------------------------------------------------------------------------- + + +def test_display_name_unaffected_by_deny_list(): + configure(disabled={"alpha"}) + # Even though alpha is disabled, its viz_type should still resolve + assert display_name_for_viz_type("alpha_viz") == "Alpha Viz" + + +def test_display_name_unaffected_by_callable(): + configure(enabled_func=lambda ct: False) + assert display_name_for_viz_type("beta_viz") == "Beta Viz" + + +# --------------------------------------------------------------------------- +# configure() atomicity: replacing config is visible to next lookup +# --------------------------------------------------------------------------- + + +def test_reconfigure_replaces_previous_filter(): + configure(disabled={"alpha"}) + assert get("alpha") is None + configure(disabled=set()) + assert get("alpha") is not None + + +def test_reconfigure_with_func_then_none_falls_back_to_deny_list(): + configure(enabled_func=lambda ct: False) + assert get("alpha") is None + + configure(disabled={"beta"}, enabled_func=None) + assert get("alpha") is not None + assert get("beta") is None