Compare commits

..

5 Commits

Author SHA1 Message Date
Michael S. Molina
c85661f4fd feat: Chat prototype 2026-05-15 15:15:42 -03:00
Michael S. Molina
a06e6ea19b fix(extensions): add cache headers and strip Vary: Cookie for extension static assets (#40120)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 09:23:39 -03:00
Shaitan
ee9eec25f9 fix(dataset): validate datasource access during import (#39998)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 12:06:47 +01:00
Shaitan
ffa32414ef fix(query): restrict query cancellation to the query owner (#39996)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 12:05:38 +01:00
Shaitan
407321e394 fix(database): extend shillelagh URI pattern to cover all driver variants (#39995)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 12:04:34 +01:00
24 changed files with 659 additions and 880 deletions

View File

View File

@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import { SupersetClient } from '@superset-ui/core';
import { logging } from '@apache-superset/core/utils';
import type { common as core } from '@apache-superset/core';
import ExtensionsLoader from './ExtensionsLoader';
@@ -111,3 +112,33 @@ test('handles initialization errors gracefully', async () => {
errorSpy.mockRestore();
appendChildSpy.mockRestore();
});
test('logs success after initializeExtensions completes', async () => {
  const loader = ExtensionsLoader.getInstance();
  const infoSpy = jest.spyOn(logging, 'info').mockImplementation();
  // Stub the extensions list endpoint with an empty result so the loader has
  // nothing to initialize and goes straight to its success log.
  const getSpy = jest.spyOn(SupersetClient, 'get').mockResolvedValue({
    json: { result: [] },
  } as any);

  await loader.initializeExtensions();

  expect(infoSpy).toHaveBeenCalledWith('Extensions initialized successfully.');

  // Restore every spy created here so the stubbed SupersetClient.get does not
  // leak into later tests.
  getSpy.mockRestore();
  infoSpy.mockRestore();
});
test('logs error when initializeExtensions fails', async () => {
  // NOTE(review): initializeExtensions memoizes its promise on the singleton,
  // so this test presumably relies on the suite resetting the instance between
  // tests — confirm against the file's setup hooks.
  const loader = ExtensionsLoader.getInstance();
  const errorSpy = jest.spyOn(logging, 'error').mockImplementation();
  const fetchError = new Error('Network error');
  // Make the extensions list request fail so the loader takes its error path.
  const getSpy = jest
    .spyOn(SupersetClient, 'get')
    .mockRejectedValue(fetchError);

  // The loader logs the failure instead of rejecting, so this await must not throw.
  await loader.initializeExtensions();

  expect(errorSpy).toHaveBeenCalledWith(
    'Error setting up extensions:',
    fetchError,
  );

  // Restore every spy created here so the rejecting SupersetClient.get stub
  // does not leak into later tests.
  getSpy.mockRestore();
  errorSpy.mockRestore();
});

View File

@@ -34,6 +34,8 @@ class ExtensionsLoader {
private extensionIndex: Map<string, Extension> = new Map();
private initializationPromise: Promise<void> | null = null;
// eslint-disable-next-line no-useless-constructor
private constructor() {
// Private constructor for singleton pattern
@@ -54,16 +56,27 @@ class ExtensionsLoader {
* Initializes extensions by fetching the list from the API and loading each one.
* Failures are caught and logged via logging.error; the returned promise never rejects.
*/
public async initializeExtensions(): Promise<void> {
const response = await SupersetClient.get({
endpoint: '/api/v1/extensions/',
});
const extensions: Extension[] = response.json.result;
await Promise.all(
extensions.map(async extension => {
await this.initializeExtension(extension);
}),
);
public initializeExtensions(): Promise<void> {
  // Start the fetch-and-load sequence only on the first call; every later call
  // skips this branch and receives the promise stored below.
  if (!this.initializationPromise) {
    const run = async (): Promise<void> => {
      try {
        const { json } = await SupersetClient.get({
          endpoint: '/api/v1/extensions/',
        });
        const extensions: Extension[] = json.result;
        // Kick off every extension's initialization and wait for all of them.
        await Promise.all(
          extensions.map(async extension =>
            this.initializeExtension(extension),
          ),
        );
        logging.info('Extensions initialized successfully.');
      } catch (error) {
        // Failures are logged rather than rethrown, so the stored promise
        // still resolves for callers awaiting it.
        logging.error('Error setting up extensions:', error);
      }
    };
    this.initializationPromise = run();
  }
  return this.initializationPromise;
}
/**

View File

@@ -18,7 +18,6 @@
*/
import { render, waitFor } from 'spec/helpers/testing-library';
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
import { logging } from '@apache-superset/core/utils';
import fetchMock from 'fetch-mock';
import ExtensionsStartup from './ExtensionsStartup';
import ExtensionsLoader from './ExtensionsLoader';
@@ -192,14 +191,12 @@ test('only initializes once even with multiple renders', async () => {
loader.initializeExtensions = originalInitialize;
});
test('initializes ExtensionsLoader and logs success when EnableExtensions feature flag is enabled', async () => {
test('initializes ExtensionsLoader when EnableExtensions feature flag is enabled', async () => {
// Ensure feature flag is enabled
mockIsFeatureEnabled.mockImplementation(
(flag: FeatureFlag) => flag === FeatureFlag.EnableExtensions,
);
const infoSpy = jest.spyOn(logging, 'info').mockImplementation();
// Mock the initializeExtensions method to succeed
const originalInitialize = ExtensionsLoader.prototype.initializeExtensions;
ExtensionsLoader.prototype.initializeExtensions = jest
@@ -220,15 +217,10 @@ test('initializes ExtensionsLoader and logs success when EnableExtensions featur
expect(
ExtensionsLoader.prototype.initializeExtensions,
).toHaveBeenCalledTimes(1);
// Verify success message was logged
expect(infoSpy).toHaveBeenCalledWith(
'Extensions initialized successfully.',
);
});
// Restore original method
ExtensionsLoader.prototype.initializeExtensions = originalInitialize;
infoSpy.mockRestore();
});
test('does not initialize ExtensionsLoader when EnableExtensions feature flag is disabled', async () => {
@@ -259,38 +251,36 @@ test('does not initialize ExtensionsLoader when EnableExtensions feature flag is
initializeSpy.mockRestore();
});
test('logs error when ExtensionsLoader initialization fails', async () => {
test('continues rendering children even when ExtensionsLoader initialization fails', async () => {
// Ensure feature flag is enabled
mockIsFeatureEnabled.mockReturnValue(true);
const errorSpy = jest.spyOn(logging, 'error').mockImplementation();
// Mock the initializeExtensions method to throw an error
// Mock initializeExtensions to resolve: the real ExtensionsLoader catches and
// logs its own initialization errors, so callers never observe a rejection
const originalInitialize = ExtensionsLoader.prototype.initializeExtensions;
ExtensionsLoader.prototype.initializeExtensions = jest
.fn()
.mockImplementation(() => {
throw new Error('Test initialization error');
});
.mockImplementation(() => Promise.resolve());
render(<ExtensionsStartup />, {
useRedux: true,
initialState: mockInitialState,
});
const { container } = render(
<ExtensionsStartup>
<div data-testid="child" />
</ExtensionsStartup>,
{
useRedux: true,
initialState: mockInitialState,
},
);
await waitFor(() => {
// Verify feature flag was checked
expect(mockIsFeatureEnabled).toHaveBeenCalledWith(
FeatureFlag.EnableExtensions,
);
// Verify error was logged
expect(errorSpy).toHaveBeenCalledWith(
'Error setting up extensions:',
expect.any(Error),
);
expect(
container.querySelector('[data-testid="child"]'),
).toBeInTheDocument();
});
// Restore original method
ExtensionsLoader.prototype.initializeExtensions = originalInitialize;
errorSpy.mockRestore();
});

View File

@@ -82,17 +82,7 @@ const ExtensionsStartup: React.FC<{ children?: React.ReactNode }> = ({
const setup = async () => {
if (isFeatureEnabled(FeatureFlag.EnableExtensions)) {
try {
await ExtensionsLoader.getInstance().initializeExtensions();
supersetCore.utils.logging.info(
'Extensions initialized successfully.',
);
} catch (error) {
supersetCore.utils.logging.error(
'Error setting up extensions:',
error,
);
}
await ExtensionsLoader.getInstance().initializeExtensions();
}
setInitialized(true);
};

View File

@@ -36,6 +36,7 @@ else:
from flask import Flask, Response
from werkzeug.exceptions import NotFound
from superset.extensions.cache_middleware import ExtensionCacheMiddleware
from superset.extensions.local_extensions_watcher import (
start_local_extensions_watcher_thread,
)
@@ -66,7 +67,6 @@ def create_app(
or app.config["APPLICATION_ROOT"],
)
if app_root != "/":
app.wsgi_app = AppRootMiddleware(app.wsgi_app, app_root)
# If not set, manually configure options that depend on the
# value of app_root so things work out of the box
if not app.config["STATIC_ASSETS_PREFIX"]:
@@ -77,6 +77,13 @@ def create_app(
app_initializer = app.config.get("APP_INITIALIZER", SupersetAppInitializer)(app)
app_initializer.init_app()
# Must be applied before AppRootMiddleware so the path prefix
# is stripped before the extension asset path regex runs.
app.wsgi_app = ExtensionCacheMiddleware(app.wsgi_app)
if app_root != "/":
app.wsgi_app = AppRootMiddleware(app.wsgi_app, app_root)
# Set up LOCAL_EXTENSIONS file watcher when in debug mode
if app.debug:
start_local_extensions_watcher_thread(app)

View File

@@ -27,9 +27,13 @@ from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.sql.visitors import VisitableType
from superset import db, security_manager
from superset.commands.dataset.exceptions import DatasetForbiddenDataURI
from superset.commands.dataset.exceptions import (
DatasetAccessDeniedError,
DatasetForbiddenDataURI,
)
from superset.commands.exceptions import ImportFailedError
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetSecurityException
from superset.models.core import Database
from superset.sql.parse import Table
from superset.utils import json
@@ -172,6 +176,12 @@ def import_dataset( # noqa: C901
if dataset.id is None:
db.session.flush()
if not ignore_permissions:
try:
security_manager.raise_for_access(datasource=dataset)
except SupersetSecurityException as ex:
raise DatasetAccessDeniedError() from ex
try:
table_exists = dataset.database.has_table(
Table(dataset.table_name, dataset.schema, dataset.catalog),

View File

@@ -58,7 +58,11 @@ class QueryDAO(BaseDAO[Query]):
@staticmethod
def stop_query(client_id: str) -> None:
query = db.session.query(Query).filter_by(client_id=client_id).one_or_none()
query = (
db.session.query(Query)
.filter(Query.client_id == client_id, Query.user_id == get_user_id())
.one_or_none()
)
if not query:
raise QueryNotFoundException(f"Query with client_id {client_id} not found")

View File

@@ -225,4 +225,9 @@ class ExtensionsRestApi(BaseApi):
if not mimetype:
mimetype = "application/octet-stream"
return send_file(BytesIO(chunk), mimetype=mimetype)
response = send_file(BytesIO(chunk), mimetype=mimetype)
# Chunk filenames include a content hash, so they are immutable.
response.cache_control.max_age = 31536000
response.cache_control.public = True
response.cache_control.immutable = True
return response

View File

@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import re
from types import TracebackType
from typing import Callable, Iterable, TYPE_CHECKING
if TYPE_CHECKING:
from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment
# Matches only the static asset endpoint: /api/v1/extensions/<publisher>/<name>/<file>,
# i.e. exactly three non-empty, slash-free path segments after /api/v1/extensions/.
# Does not match the list (/), get (/<publisher>/<name>), or info (/_info) endpoints.
# ExtensionCacheMiddleware tests PATH_INFO against this pattern to decide whether
# a response's Vary header is rewritten.
_ASSET_PATH_RE = re.compile(r"^/api/v1/extensions/[^/]+/[^/]+/[^/]+$")
class ExtensionCacheMiddleware:
    """WSGI middleware that removes ``Cookie`` from ``Vary`` on extension assets.

    Flask's session interface appends Vary: Cookie unconditionally after every
    after_request hook runs, so it cannot be removed at the view layer. This
    middleware therefore rewrites the headers at the WSGI level, once Flask
    has finished producing the response.

    Only requests whose ``PATH_INFO`` matches ``_ASSET_PATH_RE`` are touched;
    every other request is passed through unchanged.
    """

    def __init__(self, wsgi_app: WSGIApplication) -> None:
        # The wrapped application that actually produces the response.
        self.wsgi_app = wsgi_app

    @staticmethod
    def _without_cookie(vary_value: str) -> str | None:
        """Return ``vary_value`` minus its ``Cookie`` entries.

        Returns ``None`` when no entry is left, which tells the caller to drop
        the header entirely.
        """
        tokens = [token.strip() for token in vary_value.split(",")]
        kept = [token for token in tokens if token.lower() != "cookie"]
        return ", ".join(kept) if kept else None

    def __call__(
        self, environ: WSGIEnvironment, start_response: StartResponse
    ) -> Iterable[bytes]:
        """Serve the request, filtering ``Vary`` only for extension asset paths."""
        if _ASSET_PATH_RE.match(environ.get("PATH_INFO", "")) is None:
            # Not an extension asset: hand over to the wrapped app untouched.
            return self.wsgi_app(environ, start_response)

        def patched_start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: (
                tuple[type[BaseException], BaseException, TracebackType]
                | tuple[None, None, None]
                | None
            ) = None,
        ) -> Callable[[bytes], object]:
            # Rebuild the header list, rewriting (or dropping) each Vary header
            # and copying every other header as-is.
            rewritten: list[tuple[str, str]] = []
            for header_name, header_value in response_headers:
                if header_name.lower() != "vary":
                    rewritten.append((header_name, header_value))
                    continue
                remaining = self._without_cookie(header_value)
                if remaining is not None:
                    rewritten.append((header_name, remaining))
            return start_response(status, rewritten, exc_info)

        return self.wsgi_app(environ, patched_start_response)

View File

@@ -49,9 +49,7 @@ from contextlib import AbstractContextManager
from typing import Any, Callable, TYPE_CHECKING, TypeVar
from flask import g, has_request_context
from flask_appbuilder.security.sqla.models import User
from superset.mcp_service.composite_token_verifier import API_KEY_PASSTHROUGH_CLAIM
from flask_appbuilder.security.sqla.models import Group, User
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@@ -150,14 +148,23 @@ def check_tool_permission(func: Callable[..., Any]) -> bool:
def load_user_with_relationships(
username: str | None = None, email: str | None = None
) -> User | None:
"""Load a user with roles and group roles eagerly loaded.
"""
Load a user with all relationships needed for permission checks.
Delegates to :meth:`SupersetSecurityManager.find_user_with_relationships`,
which mirrors FAB's ``find_user`` (including ``auth_username_ci`` and
``MultipleResultsFound`` handling) while adding eager loading of
``User.roles`` and ``User.groups.roles`` to prevent detached-instance
errors when the SQLAlchemy session is closed or rolled back after the
lookup — as happens in MCP tool-execution contexts.
This function eagerly loads User.roles, User.groups, and Group.roles
to prevent detached instance errors when the session is closed/rolled back.
IMPORTANT: Always use this function instead of security_manager.find_user()
when loading users for MCP tool execution. The find_user() method doesn't
eagerly load Group.roles, causing "detached instance" errors when permission
checks access group.roles after the session is rolled back.
Args:
username: The username to look up (optional if email provided)
email: The email to look up (optional if username provided)
Returns:
User object with relationships loaded, or None if not found
Raises:
ValueError: If neither username nor email is provided
@@ -165,9 +172,21 @@ def load_user_with_relationships(
if not username and not email:
raise ValueError("Either username or email must be provided")
from superset import security_manager
from sqlalchemy.orm import joinedload
return security_manager.find_user_with_relationships(username=username, email=email)
from superset.extensions import db
query = db.session.query(User).options(
joinedload(User.roles),
joinedload(User.groups).joinedload(Group.roles),
)
if username:
query = query.filter(User.username == username)
else:
query = query.filter(User.email == email)
return query.first()
def _resolve_user_from_jwt_context(app: Any) -> User | None:
@@ -199,25 +218,6 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
if access_token is None:
return None
# API key pass-through: CompositeTokenVerifier accepted this token
# at the transport layer but defers actual validation to
# _resolve_user_from_api_key() (priority 2 in get_user_from_request).
# Require client_id=="api_key" (set by CompositeTokenVerifier) in addition
# to the claim so that an external IdP JWT that happens to include the
# claim name is not misclassified as an API-key pass-through.
claims = getattr(access_token, "claims", None)
if isinstance(claims, dict) and claims.get(API_KEY_PASSTHROUGH_CLAIM):
if getattr(access_token, "client_id", None) == "api_key":
logger.debug(
"API key pass-through token detected, deferring to API key auth"
)
return None
logger.debug(
"Ignoring %s claim on non-API-key token (client_id=%r); processing as JWT",
API_KEY_PASSTHROUGH_CLAIM,
getattr(access_token, "client_id", None),
)
# Use configurable resolver or default
from superset.mcp_service.mcp_config import default_user_resolver
@@ -238,12 +238,9 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
if not user:
# Fail closed: JWT says this user should exist but they don't.
# Do NOT fall through to MCP_DEV_USERNAME or stale g.user.
# Avoid echoing the JWT-extracted username in the exception message
# (CodeQL py/clear-text-logging-sensitive-data).
logger.debug("JWT-authenticated user not found in database (identity from JWT)")
raise ValueError(
"JWT authenticated user not found in Superset database. "
"Ensure the user exists before granting MCP access."
f"JWT authenticated user '{username}' not found in Superset database. "
f"Ensure the user exists before granting MCP access."
)
return user
@@ -251,57 +248,37 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
def _resolve_user_from_api_key(app: Any) -> User | None:
"""
Resolve the current user from an API key passed via Bearer token.
Resolve the current user from an API key in the Authorization header.
Reads the token from FastMCP's per-request ``AccessToken`` (set by
``CompositeTokenVerifier`` when a Bearer token matches an API key
prefix). The streamable-http transport does not push a Flask request
context, so we cannot rely on ``flask.request`` headers — the verifier
already saw the token and stashed it on the ``AccessToken``.
Uses FAB SecurityManager's API key validation. Only attempts when
FAB_API_KEY_ENABLED is True and a request context is active.
Returns:
User object with relationships loaded, or None if no API key
pass-through token is present or API key auth is not enabled.
User object with relationships loaded, or None if no API key present
or API key auth is not enabled/available.
Raises:
PermissionError: If an API key pass-through token is present but
invalid/expired (fail closed — do NOT fall through to weaker
auth sources like ``MCP_DEV_USERNAME``), or if validation is
not available in this FAB version.
PermissionError: If an API key is present but invalid/expired,
or if validation is not available in this FAB version.
"""
if not app.config.get("FAB_API_KEY_ENABLED", False):
if not app.config.get("FAB_API_KEY_ENABLED", False) or not has_request_context():
return None
try:
from fastmcp.server.dependencies import get_access_token
except ImportError:
logger.debug("fastmcp.server.dependencies not available, skipping API key auth")
return None
access_token = get_access_token()
if access_token is None:
return None
# Only validate tokens that the CompositeTokenVerifier flagged as
# API key pass-throughs. Plain JWTs were already validated by the JWT
# verifier and resolved in _resolve_user_from_jwt_context.
claims = getattr(access_token, "claims", None)
if not (isinstance(claims, dict) and claims.get(API_KEY_PASSTHROUGH_CLAIM)):
return None
# Defense-in-depth: require client_id=="api_key" (set by CompositeTokenVerifier)
# to guard against rogue external IdP JWTs that include the passthrough claim.
if getattr(access_token, "client_id", None) != "api_key":
return None
api_key_string = getattr(access_token, "token", None)
if not api_key_string:
# Passthrough claim is set but the raw token is absent — fail closed
# rather than silently falling through to weaker auth sources.
raise PermissionError(
"API key pass-through token is missing the raw token value."
)
sm = app.appbuilder.sm
# extract_api_key_from_request is FAB's method for reading
# the Bearer token from the Authorization header and matching prefixes.
# Not all FAB versions include this method, so guard with hasattr.
if not hasattr(sm, "extract_api_key_from_request"):
logger.debug(
"FAB SecurityManager does not have extract_api_key_from_request; "
"API key authentication is not available in this FAB version"
)
return None
api_key_string = sm.extract_api_key_from_request()
if api_key_string is None:
return None
if not hasattr(sm, "validate_api_key"):
logger.warning(
"FAB SecurityManager does not have validate_api_key; "
@@ -468,6 +445,7 @@ def _setup_user_context() -> User | None:
# tool calls when no per-request middleware refreshes it.
# Only clear in app-context-only mode; preserve g.user when
# a request context is active (external middleware set it).
from flask import has_request_context
if not has_request_context():
g.pop("user", None)
@@ -511,7 +489,7 @@ def _setup_user_context() -> User | None:
logger.error("DB connection failed on retry during user setup: %s", e)
_cleanup_session_on_error()
raise
except (ValueError, PermissionError) as e:
except ValueError as e:
# User resolution failed — fail closed. Do not fall back to
# g.user from middleware, as that could allow a request to
# proceed as a different user in multi-tenant deployments.
@@ -598,7 +576,7 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
import inspect
import types
from flask import current_app, has_app_context
from flask import current_app, has_app_context, has_request_context
def _get_app_context_manager() -> AbstractContextManager[None]:
"""Push a fresh app context unless a request context is active.

View File

@@ -1,117 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Composite token verifier for MCP authentication.
Routes Bearer tokens to the appropriate verifier based on prefix:
- Tokens matching FAB_API_KEY_PREFIXES (e.g. ``sst_``) are passed through
to the Flask layer where ``_resolve_user_from_api_key()`` handles
actual validation via FAB SecurityManager.
- All other tokens are delegated to the wrapped JWT verifier (when one is
configured); when no JWT verifier is configured, non-API-key tokens are
rejected at the transport layer.
"""
import logging
from fastmcp.server.auth import AccessToken
from fastmcp.server.auth.providers.jwt import TokenVerifier
logger = logging.getLogger(__name__)
# Namespaced claim that flags an AccessToken as an API-key pass-through.
# Namespacing avoids collision with custom claims an external IdP might
# happen to mint on a JWT — a plain ``_api_key_passthrough`` claim could
# be silently misidentified as a Superset API-key request.
# CompositeTokenVerifier.verify_token sets this claim to True on the
# AccessToken it returns for tokens that match an API key prefix.
API_KEY_PASSTHROUGH_CLAIM = "_superset_mcp_api_key_passthrough"
class CompositeTokenVerifier(TokenVerifier):
    """Dispatch Bearer tokens to API-key pass-through or JWT verification.

    A token that starts with one of the configured API key prefixes is accepted
    at the transport layer and tagged with a marker claim, so that
    ``_resolve_user_from_jwt_context()`` can detect it and fall through to
    ``_resolve_user_from_api_key()`` for the actual validation.

    Args:
        jwt_verifier: The wrapped JWT verifier for non-API-key tokens.
            When ``None``, only API-key tokens are accepted; all other
            Bearer tokens are rejected at the transport layer (used when
            ``MCP_AUTH_ENABLED=False`` but ``FAB_API_KEY_ENABLED=True``).
        api_key_prefixes: List of prefixes that identify API key tokens
            (e.g. ``["sst_"]``).
    """

    def __init__(
        self,
        jwt_verifier: TokenVerifier | None,
        api_key_prefixes: list[str],
    ) -> None:
        # Mirror the wrapped verifier's base_url / required_scopes when present;
        # with no JWT verifier these fall back to None and an empty list.
        super().__init__(
            base_url=getattr(jwt_verifier, "base_url", None),
            required_scopes=getattr(jwt_verifier, "required_scopes", None) or [],
        )
        self._jwt_verifier = jwt_verifier

        # Keep only non-blank string prefixes; anything else is ignored.
        accepted: list[str] = [
            prefix
            for prefix in api_key_prefixes
            if isinstance(prefix, str) and prefix.strip()
        ]
        rejected_count = sum(
            1 for prefix in api_key_prefixes if prefix not in accepted
        )
        if rejected_count:
            # Log count only — actual values may be config secrets
            # (CodeQL py/clear-text-logging-sensitive-data).
            logger.warning(
                "FAB_API_KEY_PREFIXES has %d invalid entries (empty/non-string)"
                " — ignored",
                rejected_count,
            )
        # Stored as a tuple so it can be handed directly to str.startswith.
        self._api_key_prefixes = tuple(accepted)

    async def verify_token(self, token: str) -> AccessToken | None:
        """Verify a Bearer token.

        A token that starts with an API key prefix yields a pass-through
        AccessToken carrying the namespaced ``API_KEY_PASSTHROUGH_CLAIM``
        (``_superset_mcp_api_key_passthrough``); the Flask-layer
        ``_resolve_user_from_api_key()`` performs the real validation.

        Any other token is delegated to the wrapped JWT verifier when one is
        configured, and rejected (``None``) when none is.
        """
        if not token.startswith(self._api_key_prefixes):
            if self._jwt_verifier is None:
                logger.debug(
                    "Bearer token does not match any API key prefix and no JWT "
                    "verifier is configured; rejecting"
                )
                return None
            return await self._jwt_verifier.verify_token(token)

        logger.debug("API key token detected (prefix match), passing through")
        # Populate ``scopes`` from ``self.required_scopes`` so FastMCP's
        # ``RequireAuthMiddleware`` (transport-layer scope check) is
        # satisfied for API-key requests. Without this, MCP_REQUIRED_SCOPES
        # being non-empty would 403 every API-key call before
        # ``_resolve_user_from_api_key`` even runs.
        return AccessToken(
            token=token,
            client_id="api_key",
            scopes=list(self.required_scopes or []),
            claims={API_KEY_PASSTHROUGH_CLAIM: True},
        )

View File

@@ -20,15 +20,12 @@ import logging
import secrets
from typing import Any, Dict, Optional
from fastmcp.server.auth.providers.jwt import JWTVerifier
from flask import Flask
from superset.mcp_service.composite_token_verifier import CompositeTokenVerifier
from superset.mcp_service.constants import (
DEFAULT_TOKEN_LIMIT,
DEFAULT_WARN_THRESHOLD_PCT,
)
from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
logger = logging.getLogger(__name__)
@@ -294,94 +291,56 @@ MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
"""Default MCP auth factory using app.config values.
Returns an auth provider when ``MCP_AUTH_ENABLED=True`` (JWT verifier,
optionally wrapped with ``CompositeTokenVerifier`` for API keys) or
when only ``FAB_API_KEY_ENABLED=True`` (API-key-only verifier that
rejects all non-API-key Bearer tokens at the transport).
"""
auth_enabled = app.config.get("MCP_AUTH_ENABLED", False)
api_key_enabled = app.config.get("FAB_API_KEY_ENABLED", False)
if not (auth_enabled or api_key_enabled):
"""Default MCP auth factory using app.config values."""
if not app.config.get("MCP_AUTH_ENABLED", False):
return None
jwt_verifier: Any | None = None
jwks_uri = app.config.get("MCP_JWKS_URI")
public_key = app.config.get("MCP_JWT_PUBLIC_KEY")
secret = app.config.get("MCP_JWT_SECRET")
if auth_enabled:
jwks_uri = app.config.get("MCP_JWKS_URI")
public_key = app.config.get("MCP_JWT_PUBLIC_KEY")
secret = app.config.get("MCP_JWT_SECRET")
if not (jwks_uri or public_key or secret):
logger.warning("MCP_AUTH_ENABLED is True but no JWT keys/secret configured")
return None
if not (jwks_uri or public_key or secret):
logger.warning("MCP_AUTH_ENABLED is True but no JWT keys/secret configured")
if not api_key_enabled:
return None
try:
debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
common_kwargs: dict[str, Any] = {
"issuer": app.config.get("MCP_JWT_ISSUER"),
"audience": app.config.get("MCP_JWT_AUDIENCE"),
"required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
}
# For HS256 (symmetric), use the secret as the public_key parameter
if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
common_kwargs["public_key"] = secret
common_kwargs["algorithm"] = "HS256"
else:
try:
jwt_verifier = _build_jwt_verifier(
app=app,
jwks_uri=jwks_uri,
public_key=public_key,
secret=secret,
)
except Exception: # noqa: BLE001 — JWT lib raises many types; broad catch intentional
# Do not log the exception — it may contain secrets (e.g., key material)
logger.error("Failed to create MCP JWT verifier")
if not api_key_enabled:
return None
# For RS256 (asymmetric), use public key or JWKS
common_kwargs["jwks_uri"] = jwks_uri
common_kwargs["public_key"] = public_key
common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
if api_key_enabled:
raw_prefixes = app.config.get("FAB_API_KEY_PREFIXES", ["sst_"])
# Normalize: a plain string (e.g. "sst_") would iterate as characters;
# wrap it in a list so CompositeTokenVerifier receives a proper sequence.
if isinstance(raw_prefixes, str):
api_key_prefixes = [raw_prefixes]
if debug_errors:
# DetailedJWTVerifier: detailed server-side logging of JWT
# validation failures. HTTP responses are always generic per
# RFC 6750 Section 3.1.
from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
auth_provider = DetailedJWTVerifier(**common_kwargs)
else:
api_key_prefixes = list(raw_prefixes)
logger.info("API key auth enabled for MCP")
return CompositeTokenVerifier(
jwt_verifier=jwt_verifier,
api_key_prefixes=api_key_prefixes,
)
# Default JWTVerifier: minimal logging, generic error responses.
from fastmcp.server.auth.providers.jwt import JWTVerifier
return jwt_verifier
auth_provider = JWTVerifier(**common_kwargs)
def _build_jwt_verifier(
app: Flask,
jwks_uri: Optional[str],
public_key: Optional[str],
secret: Optional[str],
) -> Any:
"""Construct the JWT verifier from configured keys/secret."""
debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
common_kwargs: Dict[str, Any] = {
"issuer": app.config.get("MCP_JWT_ISSUER"),
"audience": app.config.get("MCP_JWT_AUDIENCE"),
"required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
}
# For HS256 (symmetric), use the secret as the public_key parameter
if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
common_kwargs["public_key"] = secret
common_kwargs["algorithm"] = "HS256"
else:
# For RS256 (asymmetric), use public key or JWKS
common_kwargs["jwks_uri"] = jwks_uri
common_kwargs["public_key"] = public_key
common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
if debug_errors:
# DetailedJWTVerifier: detailed server-side logging of JWT
# validation failures. HTTP responses are always generic per
# RFC 6750 Section 3.1.
return DetailedJWTVerifier(**common_kwargs)
# Default JWTVerifier: minimal logging, generic error responses.
return JWTVerifier(**common_kwargs)
return auth_provider
except Exception:
# Do not log the exception — it may contain the HS256 secret
# from common_kwargs["public_key"]
logger.error("Failed to create MCP auth provider")
return None
def default_user_resolver(app: Any, access_token: Any) -> str | None:

View File

@@ -429,13 +429,13 @@ def _tool_allowed_for_current_user(tool: Any) -> bool:
if not getattr(g, "user", None):
try:
g.user = get_user_from_request()
except (ValueError, PermissionError):
except ValueError:
return False
method_permission_name = getattr(tool_func, METHOD_PERMISSION_ATTR, "read")
permission_name = f"{PERMISSION_PREFIX}{method_permission_name}"
return security_manager.can_access(permission_name, class_permission_name)
except (AttributeError, RuntimeError, ValueError, PermissionError):
except (AttributeError, RuntimeError, ValueError):
logger.debug("Could not evaluate tool search permission", exc_info=True)
return False
@@ -673,9 +673,7 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
"""Create an auth provider from Flask app config.
Tries MCP_AUTH_FACTORY first, then falls back to the default factory
when either ``MCP_AUTH_ENABLED`` (JWT auth) or ``FAB_API_KEY_ENABLED``
(API key auth) is True. The default factory builds a
``CompositeTokenVerifier`` that handles either or both auth modes.
when MCP_AUTH_ENABLED is True.
"""
auth_provider = None
if auth_factory := flask_app.config.get("MCP_AUTH_FACTORY"):
@@ -688,9 +686,7 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
except Exception:
# Do not log the exception — it may contain secrets
logger.error("Failed to create auth provider from MCP_AUTH_FACTORY")
elif flask_app.config.get("MCP_AUTH_ENABLED", False) or flask_app.config.get(
"FAB_API_KEY_ENABLED", False
):
elif flask_app.config.get("MCP_AUTH_ENABLED", False):
from superset.mcp_service.mcp_config import (
create_default_mcp_auth_factory,
)

View File

@@ -29,8 +29,7 @@ BLOCKLIST = {
# sqlite creates a local DB, which allows mapping server's filesystem
re.compile(r"sqlite(?:\+[^\s]*)?$"),
# shillelagh allows opening local files (eg, 'SELECT * FROM "csv:///etc/passwd"')
re.compile(r"shillelagh$"),
re.compile(r"shillelagh\+apsw$"),
re.compile(r"shillelagh(?:\+[^\s]*)?$"),
}

View File

@@ -52,8 +52,7 @@ from flask_login import AnonymousUserMixin, LoginManager
from jwt.api_jwt import _jwt_global_obj
from sqlalchemy import and_, inspect, or_
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import eagerload, joinedload
from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.orm import eagerload
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery
from sqlalchemy.sql import exists
@@ -463,10 +462,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"PermissionViewMenu",
"ViewMenu",
"User",
# FAB ApiKeyApi blueprint (active when FAB_API_KEY_ENABLED=True).
# Listed unconditionally — harmless when the feature is off because
# no PVMs exist under this view menu.
"ApiKey",
} | USER_MODEL_VIEWS
ALPHA_ONLY_VIEW_MENUS = {
@@ -3169,60 +3164,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
.one_or_none()
)
def find_user_with_relationships(
    self,
    username: Optional[str] = None,
    email: Optional[str] = None,
) -> Optional[User]:
    """Look up a single user with role relationships joined-loaded.

    The lookup is attempted by ``username`` first; ``email`` is only
    consulted when no username was supplied. Username matching is
    lower-cased on both sides when ``self.auth_username_ci`` is truthy.
    The query always applies ``joinedload`` options for ``roles`` and for
    ``groups`` followed by ``roles``.

    :param username: username to search for; takes precedence over ``email``
    :param email: email address to search for when ``username`` is falsy
    :returns: the matching user, or ``None`` when nothing matches, when
        neither argument is given, or when more than one row matches
        (the latter is logged as an error)
    """
    loader_options = [
        joinedload(self.user_model.roles),
        joinedload(self.user_model.groups).joinedload("roles"),
    ]

    if username:
        try:
            if self.auth_username_ci:
                from sqlalchemy import func as sa_func

                username_criterion = sa_func.lower(
                    self.user_model.username
                ) == sa_func.lower(username)
            else:
                username_criterion = self.user_model.username == username
            return (
                self.session.query(self.user_model)
                .options(*loader_options)
                .filter(username_criterion)
                .one_or_none()
            )
        except MultipleResultsFound:
            logger.error("Multiple results found for user %s", username)
            return None

    if email:
        try:
            by_email = (
                self.session.query(self.user_model)
                .options(*loader_options)
                .filter_by(email=email)
            )
            return by_email.one_or_none()
        except MultipleResultsFound:
            logger.error("Multiple results found for user with email %s", email)
            return None

    # Neither a username nor an email was provided.
    return None
def get_anonymous_user(self) -> User:
    """Return a freshly constructed ``AnonymousUserMixin`` instance."""
    anonymous_user = AnonymousUserMixin()
    return anonymous_user

View File

@@ -72,11 +72,30 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri
True,
"shillelagh cannot be used as a data source for security reasons.",
),
("shillelagh+:///home/superset/bad.db", False, None),
(
"shillelagh+:///home/superset/bad.db",
True,
"shillelagh cannot be used as a data source for security reasons.",
),
(
"shillelagh+something:///home/superset/bad.db",
False,
None,
True,
"shillelagh cannot be used as a data source for security reasons.",
),
(
"shillelagh+csv:///etc/passwd",
True,
"shillelagh cannot be used as a data source for security reasons.",
),
(
"shillelagh+json:///etc/passwd",
True,
"shillelagh cannot be used as a data source for security reasons.",
),
(
"shillelagh+gsheets:///",
True,
"shillelagh cannot be used as a data source for security reasons.",
),
],
)

View File

@@ -279,3 +279,49 @@ def test_query_dao_stop_query(
QueryDAO.stop_query(query_obj.client_id)
query = db.session.query(Query).one()
assert query.status == QueryStatus.STOPPED
def test_query_dao_stop_query_wrong_user(
    mocker: MockerFixture, app: Any, session: Session
) -> None:
    """Stopping a query owned by someone else fails and leaves it running.

    The query is stored with ``user_id=1`` while ``get_user_id`` is patched
    to return ``2``. ``QueryDAO.stop_query`` must raise
    ``QueryNotFoundException`` and the stored status must remain ``RUNNING``.
    """
    from superset import db
    from superset.common.db_query_status import QueryStatus
    from superset.models.core import Database
    from superset.models.sql_lab import Query

    # Create the tables registered on ``Query.metadata`` in the bound engine.
    Query.metadata.create_all(db.session.get_bind())  # pylint: disable=no-member

    owner_database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    running_query = Query(
        client_id="foo",
        database=owner_database,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
        status=QueryStatus.RUNNING,
        user_id=1,
    )

    db.session.add(owner_database)
    db.session.add(running_query)

    # The caller is user 2, while the query above belongs to user 1.
    mocker.patch("superset.daos.query.get_user_id", return_value=2)

    from superset.daos.query import QueryDAO

    with pytest.raises(QueryNotFoundException):
        QueryDAO.stop_query(running_query.client_id)

    # The rejected stop request must not have altered the stored status.
    assert db.session.query(Query).one().status == QueryStatus.RUNNING

View File

@@ -31,6 +31,7 @@ from sqlalchemy.orm.session import Session
from superset import db, security_manager
from superset.commands.dataset.exceptions import (
DatasetAccessDeniedError,
DatasetForbiddenDataURI,
)
from superset.commands.dataset.importers.v1.utils import (
@@ -744,6 +745,44 @@ def test_import_dataset_without_owner_permission(
mock_can_access.assert_called_with("can_write", "Dataset")
def test_import_dataset_access_check(
    mocker: MockerFixture,
    session: Session,
) -> None:
    """
    ``import_dataset`` must raise ``DatasetAccessDeniedError`` when
    ``security_manager.raise_for_access`` rejects the dataset, even though
    ``security_manager.can_access`` is patched to return True.
    """
    from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
    from superset.exceptions import SupersetSecurityException

    # Exception raised by the patched datasource-level access check.
    access_denied = SupersetSecurityException(
        SupersetError(
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            message="User does not have access to this datasource",
            level=ErrorLevel.ERROR,
        )
    )
    mocker.patch.object(security_manager, "can_access", return_value=True)
    mocker.patch.object(
        security_manager,
        "raise_for_access",
        side_effect=access_denied,
    )

    # Create the tables registered on ``SqlaTable.metadata`` in the bound engine.
    SqlaTable.metadata.create_all(db.session.get_bind())  # pylint: disable=no-member

    target_database = Database(
        database_name="my_database", sqlalchemy_uri="sqlite://"
    )
    db.session.add(target_database)
    # Flush so ``target_database.id`` is populated before it is referenced.
    db.session.flush()

    dataset_config = copy.deepcopy(dataset_fixture)
    dataset_config["database_id"] = target_database.id

    with pytest.raises(DatasetAccessDeniedError):
        import_dataset(dataset_config)
@pytest.mark.parametrize(
"allowed_urls, data_uri, expected, exception_class",
[

View File

@@ -0,0 +1,156 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Callable
from superset.extensions.cache_middleware import ExtensionCacheMiddleware
# Alias for a list of (header name, header value) string pairs, the form in
# which the helpers below hand headers to ``start_response``.
ResponseHeaders = list[tuple[str, str]]
def make_wsgi_app(
    status: str = "200 OK",
    headers: ResponseHeaders | None = None,
) -> Callable[..., Any]:
    """Build a stub WSGI application for driving middleware in tests.

    The returned callable ignores ``environ``, calls ``start_response`` with
    ``status`` and ``headers`` (an empty list when ``headers`` is falsy), and
    returns a single fixed body chunk.
    """

    def _stub_app(environ, start_response):  # noqa: ARG001
        response_headers = headers if headers else []
        start_response(status, response_headers)
        return [b"body"]

    return _stub_app
def call_middleware(
    path: str,
    upstream_headers: ResponseHeaders,
) -> ResponseHeaders:
    """Send one request for ``path`` through ``ExtensionCacheMiddleware``.

    A stub app emitting ``upstream_headers`` is wrapped by the middleware and
    invoked with an environ holding only ``PATH_INFO``. The headers given to
    the first ``start_response`` call are returned.
    """
    recorded: list[ResponseHeaders] = []

    def _record_headers(status, headers, exc_info=None):  # noqa: ARG001
        recorded.append(headers)

    wrapped_app = ExtensionCacheMiddleware(make_wsgi_app(headers=upstream_headers))

    # Exhaust the response iterable so the wrapped app runs to completion.
    for _chunk in wrapped_app({"PATH_INFO": path}, _record_headers):
        pass

    return recorded[0]
# --- Path matching ---
def test_asset_path_is_intercepted() -> None:
    """A request for an extension asset file loses ``Cookie`` from ``Vary``."""
    asset_path = "/api/v1/extensions/acme/my-ext/main.js"
    result = call_middleware(asset_path, [("Vary", "Accept-Encoding, Cookie")])
    vary_value = dict(result).get("Vary", "")
    assert "Cookie" not in vary_value
def test_list_endpoint_is_not_intercepted() -> None:
    """The extensions list endpoint passes its headers through untouched."""
    upstream = [("Vary", "Accept-Encoding, Cookie")]
    assert call_middleware("/api/v1/extensions/", upstream) == upstream
def test_get_endpoint_is_not_intercepted() -> None:
    """A publisher/name path without a file segment is passed through untouched."""
    upstream = [("Vary", "Accept-Encoding, Cookie")]
    assert call_middleware("/api/v1/extensions/acme/my-ext", upstream) == upstream
def test_info_endpoint_is_not_intercepted() -> None:
    """The ``_info`` endpoint passes its headers through untouched."""
    upstream = [("Vary", "Accept-Encoding, Cookie")]
    assert call_middleware("/api/v1/extensions/_info", upstream) == upstream
def test_unrelated_path_is_not_intercepted() -> None:
    """A path outside the extensions API passes its headers through untouched."""
    upstream = [("Vary", "Accept-Encoding, Cookie")]
    assert call_middleware("/api/v1/dashboard/", upstream) == upstream
# --- Vary stripping logic ---
def test_strips_cookie_from_vary() -> None:
    """``Cookie`` is dropped from ``Vary`` while the other value is kept."""
    asset_path = "/api/v1/extensions/acme/my-ext/chunk.js"
    result = call_middleware(asset_path, [("Vary", "Accept-Encoding, Cookie")])
    vary_value = dict(result)["Vary"]
    assert vary_value == "Accept-Encoding"
def test_strips_cookie_case_insensitive() -> None:
    """An upper-case ``COOKIE`` entry is dropped from ``Vary`` as well."""
    asset_path = "/api/v1/extensions/acme/my-ext/chunk.js"
    result = call_middleware(asset_path, [("Vary", "Accept-Encoding, COOKIE")])
    vary_value = dict(result)["Vary"]
    assert vary_value == "Accept-Encoding"
def test_removes_vary_header_when_cookie_is_only_value() -> None:
    """A ``Vary`` header holding only ``Cookie`` is omitted from the response."""
    asset_path = "/api/v1/extensions/acme/my-ext/chunk.js"
    result = call_middleware(asset_path, [("Vary", "Cookie")])
    header_map = dict(result)
    assert "Vary" not in header_map
def test_multiple_vary_headers_all_stripped() -> None:
    """Every ``Vary`` header is processed when upstream emits several of them.

    Some middleware stacks emit multiple separate Vary headers.
    """
    upstream = [("Vary", "Cookie"), ("Vary", "Accept-Encoding, Cookie")]
    result = call_middleware("/api/v1/extensions/acme/my-ext/chunk.js", upstream)

    remaining_vary = [value for name, value in result if name == "Vary"]

    assert not any("Cookie" in value for value in remaining_vary)
    # The Cookie-only header disappears; the other keeps its remaining value.
    assert remaining_vary == ["Accept-Encoding"]
def test_non_vary_headers_are_preserved() -> None:
    """Headers other than ``Vary`` come through with their values unchanged."""
    upstream = [
        ("Content-Type", "application/wasm"),
        ("Cache-Control", "public, max-age=31536000, immutable"),
        ("Vary", "Accept-Encoding, Cookie"),
    ]
    result = call_middleware("/api/v1/extensions/acme/my-ext/chunk.wasm", upstream)

    header_map = dict(result)
    assert header_map["Content-Type"] == "application/wasm"
    assert header_map["Cache-Control"] == "public, max-age=31536000, immutable"
def test_vary_without_cookie_is_unchanged() -> None:
    """A ``Vary`` header that never mentioned ``Cookie`` keeps its value."""
    asset_path = "/api/v1/extensions/acme/my-ext/chunk.js"
    result = call_middleware(asset_path, [("Vary", "Accept-Encoding")])
    vary_value = dict(result)["Vary"]
    assert vary_value == "Accept-Encoding"
def test_no_vary_header_produces_no_vary() -> None:
    """No ``Vary`` header is added when upstream did not send one."""
    asset_path = "/api/v1/extensions/acme/my-ext/chunk.js"
    result = call_middleware(
        asset_path, [("Content-Type", "application/javascript")]
    )
    header_map = dict(result)
    assert "Vary" not in header_map

View File

@@ -15,55 +15,25 @@
# specific language governing permissions and limitations
# under the License.
"""Tests for API key authentication in get_user_from_request().
"""Tests for API key authentication in get_user_from_request()."""
The streamable-http transport does not push a Flask request context, so
``_resolve_user_from_api_key`` reads the token from FastMCP's per-request
``AccessToken`` (populated by ``CompositeTokenVerifier``) rather than from
``flask.request``. These tests mock ``get_access_token`` accordingly.
"""
from collections.abc import Generator
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from flask import g
from superset.app import SupersetApp
from superset.mcp_service.auth import (
_resolve_user_from_jwt_context,
get_user_from_request,
)
from superset.mcp_service.composite_token_verifier import API_KEY_PASSTHROUGH_CLAIM
from superset.mcp_service.auth import get_user_from_request
@pytest.fixture
def mock_user() -> MagicMock:
def mock_user():
user = MagicMock()
user.username = "api_key_user"
return user
def _passthrough_access_token(token: str) -> MagicMock:
"""Build an AccessToken matching what CompositeTokenVerifier emits."""
access_token = MagicMock()
access_token.token = token
access_token.client_id = "api_key"
access_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True}
return access_token
def _patch_access_token(access_token: MagicMock | None):
"""Patch get_access_token where _resolve_user_from_api_key imports it."""
return patch(
"fastmcp.server.dependencies.get_access_token",
return_value=access_token,
)
@pytest.fixture
def _enable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
def _enable_api_keys(app):
"""Enable FAB API key auth and clear MCP_DEV_USERNAME so the API key
path is exercised instead of falling through to the dev-user fallback."""
app.config["FAB_API_KEY_ENABLED"] = True
@@ -75,7 +45,7 @@ def _enable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
@pytest.fixture
def _disable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
def _disable_api_keys(app):
app.config["FAB_API_KEY_ENABLED"] = False
old_dev = app.config.pop("MCP_DEV_USERNAME", None)
yield
@@ -84,45 +54,24 @@ def _disable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
app.config["MCP_DEV_USERNAME"] = old_dev
@contextmanager
def _mock_sm_ctx(app: SupersetApp, mock_sm: MagicMock):
"""Push an app context with g.user cleared and appbuilder.sm mocked."""
with app.app_context():
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
yield
def _patch_load_user_not_found():
"""Patch load_user_with_relationships to return None (user not found).
load_user_with_relationships delegates to the global security_manager
(not app.appbuilder.sm), so tests that need the JWT path to raise
ValueError("not found") must patch it directly at the module level.
"""
return patch(
"superset.mcp_service.auth.load_user_with_relationships",
return_value=None,
)
# -- Valid API key -> user loaded --
@pytest.mark.usefixtures("_enable_api_keys")
def test_valid_api_key_returns_user(app: SupersetApp, mock_user: MagicMock) -> None:
"""A valid API key pass-through token should authenticate and return the user."""
def test_valid_api_key_returns_user(app, mock_user) -> None:
"""A valid API key should authenticate and return the user."""
mock_sm = MagicMock()
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
mock_sm.validate_api_key.return_value = mock_user
with _mock_sm_ctx(app, mock_sm):
with (
_patch_access_token(_passthrough_access_token("sst_abc123")),
patch(
"superset.mcp_service.auth.load_user_with_relationships",
return_value=mock_user,
),
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
with patch(
"superset.mcp_service.auth.load_user_with_relationships",
return_value=mock_user,
):
result = get_user_from_request()
@@ -130,70 +79,75 @@ def test_valid_api_key_returns_user(app: SupersetApp, mock_user: MagicMock) -> N
mock_sm.validate_api_key.assert_called_once_with("sst_abc123")
# -- Invalid API key -> PermissionError (does not silently fall back) --
# -- Invalid API key -> PermissionError --
@pytest.mark.usefixtures("_enable_api_keys")
def test_invalid_api_key_raises(app: SupersetApp) -> None:
"""An invalid API key pass-through token should raise PermissionError
(fail closed — do NOT fall through to MCP_DEV_USERNAME)."""
def test_invalid_api_key_raises(app) -> None:
"""An invalid API key should raise PermissionError."""
mock_sm = MagicMock()
mock_sm.extract_api_key_from_request.return_value = "sst_bad_key"
mock_sm.validate_api_key.return_value = None
# The dangerous fallthrough scenario: dev username IS set, but the
# request presented an invalid API key. The dev fallback must not
# mask the rejection.
app.config["MCP_DEV_USERNAME"] = "admin"
try:
with _mock_sm_ctx(app, mock_sm):
with _patch_access_token(_passthrough_access_token("sst_bad_key")):
with pytest.raises(PermissionError, match="Invalid or expired API key"):
get_user_from_request()
finally:
app.config.pop("MCP_DEV_USERNAME", None)
with app.test_request_context(headers={"Authorization": "Bearer sst_bad_key"}):
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
with pytest.raises(PermissionError, match="Invalid or expired API key"):
get_user_from_request()
# -- API key disabled -> falls through to next auth method --
@pytest.mark.usefixtures("_disable_api_keys")
def test_api_key_disabled_skips_auth(app: SupersetApp) -> None:
"""When FAB_API_KEY_ENABLED is False, API key auth is skipped entirely
even if an AccessToken is present."""
def test_api_key_disabled_skips_auth(app) -> None:
"""When FAB_API_KEY_ENABLED is False, API key auth is skipped entirely."""
mock_sm = MagicMock()
with _mock_sm_ctx(app, mock_sm):
with _patch_access_token(_passthrough_access_token("sst_abc123")):
with pytest.raises(ValueError, match="No authenticated user found"):
get_user_from_request()
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
mock_sm.validate_api_key.assert_not_called()
# Without API key auth or MCP_DEV_USERNAME, should raise ValueError
# about no authenticated user (not about invalid API key)
with pytest.raises(ValueError, match="No authenticated user found"):
get_user_from_request()
# SecurityManager API key methods should never be called
mock_sm.extract_api_key_from_request.assert_not_called()
# -- No AccessToken -> API key auth skipped --
# -- No request context -> API key auth skipped --
@pytest.mark.usefixtures("_enable_api_keys")
def test_no_access_token_skips_api_key_auth(app: SupersetApp) -> None:
"""Without a FastMCP AccessToken (e.g., MCP_AUTH_ENABLED=False and no
auth provider installed), API key auth is skipped."""
def test_no_request_context_skips_api_key_auth(app) -> None:
"""Without a request context, API key auth should be skipped
(e.g., during MCP tool discovery with only an app context)."""
mock_sm = MagicMock()
with _mock_sm_ctx(app, mock_sm):
with _patch_access_token(None):
with app.app_context():
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
# Explicitly mock has_request_context to False because the test
# framework's app fixture may implicitly provide a request context.
with patch("superset.mcp_service.auth.has_request_context", return_value=False):
with pytest.raises(ValueError, match="No authenticated user found"):
get_user_from_request()
mock_sm.validate_api_key.assert_not_called()
mock_sm.extract_api_key_from_request.assert_not_called()
# -- g.user fallback when no higher-priority auth succeeds --
@pytest.mark.usefixtures("_disable_api_keys")
def test_g_user_fallback_when_no_jwt_or_api_key(
app: SupersetApp, mock_user: MagicMock
) -> None:
def test_g_user_fallback_when_no_jwt_or_api_key(app, mock_user) -> None:
"""When no JWT or API key auth succeeds and MCP_DEV_USERNAME is not set,
g.user (set by external middleware) is used as fallback."""
with app.test_request_context():
@@ -204,174 +158,89 @@ def test_g_user_fallback_when_no_jwt_or_api_key(
assert result.username == "api_key_user"
# -- FAB version without extract_api_key_from_request --
@pytest.mark.usefixtures("_enable_api_keys")
def test_fab_without_extract_method_skips_gracefully(app) -> None:
"""If FAB SecurityManager lacks extract_api_key_from_request,
API key auth should be skipped with a debug log, not crash."""
mock_sm = MagicMock(spec=[]) # empty spec = no attributes
with app.test_request_context():
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
with pytest.raises(ValueError, match="No authenticated user found"):
get_user_from_request()
# -- FAB version without validate_api_key --
@pytest.mark.usefixtures("_enable_api_keys")
def test_fab_without_validate_method_raises(app: SupersetApp) -> None:
"""If FAB SecurityManager lacks validate_api_key, should raise
PermissionError about unavailable validation."""
mock_sm = MagicMock(spec=[]) # empty spec = no attributes
def test_fab_without_validate_method_raises(app) -> None:
"""If FAB has extract_api_key_from_request but not validate_api_key,
should raise PermissionError about unavailable validation."""
mock_sm = MagicMock(spec=["extract_api_key_from_request"])
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
with _mock_sm_ctx(app, mock_sm):
with _patch_access_token(_passthrough_access_token("sst_abc123")):
with pytest.raises(
PermissionError, match="API key validation is not available"
):
get_user_from_request()
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
with pytest.raises(
PermissionError, match="API key validation is not available"
):
get_user_from_request()
# -- Relationship reload fallback --
@pytest.mark.usefixtures("_enable_api_keys")
def test_relationship_reload_failure_returns_original_user(
app: SupersetApp, mock_user: MagicMock
) -> None:
def test_relationship_reload_failure_returns_original_user(app, mock_user) -> None:
"""If load_user_with_relationships fails, the original user from
validate_api_key should be returned as fallback."""
mock_sm = MagicMock()
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
mock_sm.validate_api_key.return_value = mock_user
with _mock_sm_ctx(app, mock_sm):
with (
_patch_access_token(_passthrough_access_token("sst_abc123")),
patch(
"superset.mcp_service.auth.load_user_with_relationships",
return_value=None,
),
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
g.user = None
app.appbuilder = MagicMock()
app.appbuilder.sm = mock_sm
with patch(
"superset.mcp_service.auth.load_user_with_relationships",
return_value=None,
):
result = get_user_from_request()
assert result is mock_user
# -- AccessToken without passthrough claim (plain JWT) -> skip API key auth --
@pytest.mark.usefixtures("_enable_api_keys")
def test_jwt_access_token_skips_api_key_auth(app: SupersetApp) -> None:
"""When the AccessToken is a plain JWT (no API_KEY_PASSTHROUGH_CLAIM),
API key auth is skipped — the JWT was already validated by the JWT
verifier and resolved in _resolve_user_from_jwt_context."""
mock_sm = MagicMock()
jwt_access_token = MagicMock()
jwt_access_token.token = "eyJhbGciOiJIUzI1NiJ9.not-an-api-key" # noqa: S105
jwt_access_token.claims = {"sub": "alice"}
with _mock_sm_ctx(app, mock_sm):
with _patch_access_token(jwt_access_token), _patch_load_user_not_found():
# _resolve_user_from_jwt_context resolves "alice" from JWT claims
# and raises ValueError because the username is not a real user.
# We assert that _resolve_user_from_api_key did NOT short-circuit
# to the API key path.
with pytest.raises(ValueError, match="not found"):
get_user_from_request()
mock_sm.validate_api_key.assert_not_called()
# -- API key pass-through detection in JWT context resolver --
def test_jwt_context_with_api_key_passthrough_returns_none(app: SupersetApp) -> None:
"""When CompositeTokenVerifier passes through an API key token,
_resolve_user_from_jwt_context should detect the namespaced
pass-through claim AND client_id=="api_key" and return None so
get_user_from_request falls through to _resolve_user_from_api_key."""
mock_access_token = MagicMock()
mock_access_token.client_id = "api_key"
mock_access_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True}
with patch(
"fastmcp.server.dependencies.get_access_token",
return_value=mock_access_token,
):
result = _resolve_user_from_jwt_context(app)
assert result is None
def test_namespaced_claim_without_api_key_client_id_is_ignored(
app: SupersetApp,
) -> None:
"""An external IdP JWT that includes the namespaced API_KEY_PASSTHROUGH_CLAIM
but does NOT have client_id=='api_key' must NOT divert into the API-key path.
The client_id guard prevents misclassification / DoS for affected JWT users."""
mock_sm = MagicMock()
rogue_token = MagicMock()
rogue_token.token = "eyJhbGciOiJSUzI1NiJ9.idp_jwt_with_rogue_claim" # noqa: S105
rogue_token.client_id = "some-idp-client"
rogue_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True, "sub": "alice"}
with _mock_sm_ctx(app, mock_sm):
with _patch_access_token(rogue_token), _patch_load_user_not_found():
# JWT path resolves "alice" from claims and raises ValueError
# because no such user exists.
# validate_api_key must NOT be called — the rogue claim was ignored.
with pytest.raises(ValueError, match="not found"):
get_user_from_request()
mock_sm.validate_api_key.assert_not_called()
# -- Plain JWT with a colliding non-namespaced claim is NOT mistaken for API key --
@pytest.mark.usefixtures("_enable_api_keys")
def test_unnamespaced_passthrough_claim_does_not_trigger_api_key_path(
app: SupersetApp,
) -> None:
"""A JWT minted by an external IdP that happens to include a custom
``_api_key_passthrough`` claim (legacy unnamespaced name) must NOT be
treated as an API-key pass-through. Only the namespaced
``API_KEY_PASSTHROUGH_CLAIM`` triggers the API-key path."""
mock_sm = MagicMock()
rogue_token = MagicMock()
rogue_token.token = "eyJhbGciOiJSUzI1NiJ9.rogue_jwt" # noqa: S105
rogue_token.claims = {"_api_key_passthrough": True, "sub": "alice"}
with _mock_sm_ctx(app, mock_sm):
with _patch_access_token(rogue_token), _patch_load_user_not_found():
# JWT path resolves "alice" from claims and raises ValueError.
# validate_api_key must NOT be called — the rogue claim was ignored.
with pytest.raises(ValueError, match="not found"):
get_user_from_request()
mock_sm.validate_api_key.assert_not_called()
# -- SecurityManager method name regression test --
def test_security_manager_has_expected_api_key_methods(app: SupersetApp) -> None:
"""Regression test: verify the SecurityManager method name referenced in
auth._resolve_user_from_api_key() actually exists on the FAB
SecurityManager class. Catches future renames before they silently break
API key auth at runtime (see PR #39437)."""
with app.app_context():
from superset import security_manager
def test_security_manager_has_expected_api_key_methods() -> None:
"""Regression test: verify the SecurityManager method names referenced in
auth._resolve_user_from_api_key() actually exist on the FAB SecurityManager
class. This catches future renames before they silently break API key auth
at runtime (SC-99414: _extract_api_key_from_request vs
extract_api_key_from_request)."""
from superset import security_manager
sm = security_manager
assert hasattr(sm, "validate_api_key"), (
"FAB SecurityManager is missing 'validate_api_key'. "
"auth._resolve_user_from_api_key() references this method by name — "
"update auth.py if the FAB API changed."
)
def test_security_manager_has_find_user_with_relationships(app: SupersetApp) -> None:
"""Regression test: verify SupersetSecurityManager.find_user_with_relationships
exists. load_user_with_relationships() in auth.py delegates to it — a rename
or removal would silently break MCP user resolution at runtime."""
with app.app_context():
from superset import security_manager
assert hasattr(security_manager, "find_user_with_relationships"), (
"SupersetSecurityManager is missing 'find_user_with_relationships'. "
"auth.load_user_with_relationships() delegates to this method — "
"update auth.py if the method was renamed or removed."
)
sm = security_manager
assert hasattr(sm, "extract_api_key_from_request"), (
"FAB SecurityManager is missing 'extract_api_key_from_request'. "
"auth._resolve_user_from_api_key() references this method by name — "
"update auth.py if the FAB API changed."
)
assert hasattr(sm, "validate_api_key"), (
"FAB SecurityManager is missing 'validate_api_key'. "
"auth._resolve_user_from_api_key() references this method by name — "
"update auth.py if the FAB API changed."
)

View File

@@ -285,7 +285,7 @@ def test_mcp_auth_hook_clears_stale_g_user(app) -> None:
# framework's autouse app_context fixture may implicitly provide
# a request context in some CI environments.
with (
patch("superset.mcp_service.auth.has_request_context", return_value=False),
patch("flask.has_request_context", return_value=False),
patch(
"superset.mcp_service.auth.get_user_from_request",
side_effect=lambda: _assert_cleared_then_return(),
@@ -324,7 +324,7 @@ def test_mcp_auth_hook_clears_stale_g_user_async(app) -> None:
with app.app_context():
g.user = stale_user
with (
patch("superset.mcp_service.auth.has_request_context", return_value=False),
patch("flask.has_request_context", return_value=False),
patch(
"superset.mcp_service.auth.get_user_from_request",
side_effect=lambda: _assert_cleared_then_return(),

View File

@@ -1,218 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for CompositeTokenVerifier."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastmcp.server.auth import AccessToken
from superset.mcp_service.composite_token_verifier import (
API_KEY_PASSTHROUGH_CLAIM,
CompositeTokenVerifier,
)
@pytest.fixture
def mock_jwt_verifier() -> MagicMock:
    """Stand-in JWT verifier with no required scopes and an awaitable
    ``verify_token`` mock."""
    return MagicMock(required_scopes=[], verify_token=AsyncMock())
@pytest.fixture
def composite_verifier(mock_jwt_verifier: MagicMock) -> CompositeTokenVerifier:
    """``CompositeTokenVerifier`` wrapping the mocked JWT verifier, configured
    with the ``sst_`` and ``pat_`` API key prefixes."""
    configured_prefixes = ["sst_", "pat_"]
    return CompositeTokenVerifier(
        jwt_verifier=mock_jwt_verifier,
        api_key_prefixes=configured_prefixes,
    )
@pytest.mark.asyncio
async def test_api_key_token_returns_passthrough(
    composite_verifier: CompositeTokenVerifier,
) -> None:
    """A token with a configured API key prefix yields a pass-through
    ``AccessToken`` carrying the original token string."""
    presented_key = "sst_abc123secret"  # noqa: S105

    access_token = await composite_verifier.verify_token(presented_key)

    assert access_token is not None
    assert access_token.token == presented_key
    assert access_token.client_id == "api_key"
    assert access_token.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
@pytest.mark.asyncio
async def test_second_prefix_matches(
    composite_verifier: CompositeTokenVerifier,
) -> None:
    """A token using the second configured prefix (``pat_``) is also
    recognised, not only one using the first."""
    access_token = await composite_verifier.verify_token("pat_mytoken")

    assert access_token is not None
    assert access_token.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
@pytest.mark.asyncio
async def test_jwt_token_delegates_to_wrapped_verifier(
    composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
) -> None:
    """A token without an API key prefix is handed to the wrapped JWT
    verifier, and that verifier's result is returned as-is."""
    jwt_token = "eyJhbGciOiJSUzI1NiJ9.jwt_payload"  # noqa: S105
    expected = AccessToken(
        token=jwt_token,
        client_id="oauth_client",
        scopes=["read"],
        claims={"sub": "user1"},
    )
    mock_jwt_verifier.verify_token.return_value = expected

    actual = await composite_verifier.verify_token(jwt_token)

    # Identity check: the wrapped verifier's object is returned, not a copy.
    assert actual is expected
    mock_jwt_verifier.verify_token.assert_awaited_once_with(jwt_token)
@pytest.mark.asyncio
async def test_invalid_jwt_returns_none(
    composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
) -> None:
    """If the wrapped JWT verifier returns None for a token, the composite
    verifier returns None as well, after awaiting the JWT verifier once."""
    mock_jwt_verifier.verify_token.return_value = None

    outcome = await composite_verifier.verify_token("not_a_valid_token")

    assert outcome is None
    mock_jwt_verifier.verify_token.assert_awaited_once()
@pytest.mark.asyncio
async def test_api_key_does_not_call_jwt_verifier(
    composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
) -> None:
    """Verifying an API key token never awaits the wrapped JWT verifier."""
    api_key = "sst_test_key"

    await composite_verifier.verify_token(api_key)

    mock_jwt_verifier.verify_token.assert_not_awaited()
# -- API-key-only mode (no JWT verifier configured) --
@pytest.mark.asyncio
async def test_api_key_only_mode_accepts_api_keys() -> None:
    """With no JWT verifier configured, a token carrying a configured API key
    prefix still produces a pass-through AccessToken."""
    api_key_only = CompositeTokenVerifier(
        jwt_verifier=None,
        api_key_prefixes=["sst_"],
    )

    access_token = await api_key_only.verify_token("sst_abc123")

    assert access_token is not None
    assert access_token.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
@pytest.mark.asyncio
async def test_api_key_only_mode_rejects_non_api_key_tokens() -> None:
    """With no JWT verifier configured, a Bearer token that does not match an
    API key prefix is rejected (None) rather than silently accepted."""
    api_key_only = CompositeTokenVerifier(
        jwt_verifier=None,
        api_key_prefixes=["sst_"],
    )

    outcome = await api_key_only.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")

    assert outcome is None
@pytest.mark.asyncio
async def test_empty_string_prefix_is_filtered_out() -> None:
    """An empty-string prefix must be dropped from ``_api_key_prefixes``.

    Every string starts with ``""``, so keeping it would classify every
    Bearer token as an API key (a DoS vector).
    """
    verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=[""])
    stored_prefixes = verifier._api_key_prefixes
    assert "" not in stored_prefixes

    # A plain JWT must NOT be misidentified as an API key.
    outcome = await verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
    assert outcome is None
@pytest.mark.asyncio
async def test_whitespace_only_prefix_is_filtered_out() -> None:
    """A prefix made only of whitespace is invalid: it must not be stored, and
    a token that begins with that whitespace must not be accepted."""
    blank_prefix = " "
    verifier = CompositeTokenVerifier(
        jwt_verifier=None,
        api_key_prefixes=[blank_prefix],
    )
    assert blank_prefix not in verifier._api_key_prefixes

    outcome = await verifier.verify_token(" starts_with_spaces")
    assert outcome is None
@pytest.mark.asyncio
async def test_non_string_prefix_is_filtered_out() -> None:
    """Non-string entries (e.g. None, int) must not be stored and must not
    cause a TypeError during verify_token.

    The constructor is given a mix of invalid entries and one valid prefix.
    Only the valid prefix may survive, and ``verify_token`` must then work
    normally for both a matching and a non-matching token.
    """
    verifier = CompositeTokenVerifier(
        jwt_verifier=None,
        api_key_prefixes=[None, 42, "sst_"],  # type: ignore[list-item]
    )
    assert None not in verifier._api_key_prefixes
    assert 42 not in verifier._api_key_prefixes
    assert verifier._api_key_prefixes == ("sst_",)

    # Actually exercise verify_token: a leftover None/int prefix would make
    # the prefix match raise TypeError instead of returning a result.
    matched = await verifier.verify_token("sst_abc123")
    assert matched is not None
    assert matched.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True

    # With no JWT verifier configured, a non-matching token is rejected.
    unmatched = await verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
    assert unmatched is None
@pytest.mark.asyncio
async def test_invalid_prefixes_emit_warning(caplog: pytest.LogCaptureFixture) -> None:
    """Constructing the verifier with an invalid prefix entry must log a
    message containing "invalid" (case-insensitive), so that operators can
    spot a misconfigured FAB_API_KEY_PREFIXES."""
    import logging

    verifier_logger = "superset.mcp_service.composite_token_verifier"
    with caplog.at_level(logging.WARNING, logger=verifier_logger):
        CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["", "sst_"])

    mentions_invalid = False
    for record in caplog.records:
        if "invalid" in record.message.lower():
            mentions_invalid = True
            break
    assert mentions_invalid
@pytest.mark.asyncio
async def test_all_invalid_prefixes_accepts_no_api_keys() -> None:
    """If every configured prefix is invalid, the stored prefix tuple is empty
    and no token is accepted through the API key path."""
    invalid_prefixes = ["", " "]
    verifier = CompositeTokenVerifier(
        jwt_verifier=None,
        api_key_prefixes=invalid_prefixes,
    )
    assert verifier._api_key_prefixes == ()

    outcome = await verifier.verify_token("sst_abc123")
    assert outcome is None
@pytest.mark.asyncio
async def test_api_key_passthrough_propagates_required_scopes() -> None:
    """The pass-through AccessToken must carry the wrapped verifier's
    ``required_scopes``.

    Otherwise FastMCP's transport-level ``RequireAuthMiddleware`` would 403
    the request before ``_resolve_user_from_api_key`` gets a chance to run.
    """
    expected_scopes = ["read", "write"]
    scoped_jwt_verifier = MagicMock()
    scoped_jwt_verifier.verify_token = AsyncMock()
    # Hand the mock its own copy so the expectation list stays independent.
    scoped_jwt_verifier.required_scopes = list(expected_scopes)

    verifier = CompositeTokenVerifier(
        jwt_verifier=scoped_jwt_verifier,
        api_key_prefixes=["sst_"],
    )
    access_token = await verifier.verify_token("sst_abc123")

    assert access_token is not None
    assert access_token.scopes == expected_scopes

View File

@@ -91,17 +91,6 @@ def test_is_gamma_pvm_excludes_export_image(app_context: None) -> None:
assert sm._is_gamma_pvm(pvm) is False
def test_api_key_view_menu_is_admin_only() -> None:
    """Regression guard: ``ADMIN_ONLY_VIEW_MENUS`` must list 'ApiKey'.

    With FAB_API_KEY_ENABLED=True, FAB registers an ApiKeyApi blueprint.
    If the entry were renamed or removed, Gamma users could silently regain
    access to the API key management endpoints.
    """
    admin_only_menus = SupersetSecurityManager.ADMIN_ONLY_VIEW_MENUS
    assert "ApiKey" in admin_only_menus
def test_is_gamma_pvm_allows_copy_clipboard(app_context: None) -> None:
"""Verify _is_gamma_pvm returns True for can_copy_clipboard."""
from superset.extensions import appbuilder