mirror of
https://github.com/apache/superset.git
synced 2026-05-16 05:15:16 +00:00
Compare commits
5 Commits
work-pr-39
...
chat-proto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c85661f4fd | ||
|
|
a06e6ea19b | ||
|
|
ee9eec25f9 | ||
|
|
ffa32414ef | ||
|
|
407321e394 |
0
extensions/chat/PUT_FILES_HERE.txt
Normal file
0
extensions/chat/PUT_FILES_HERE.txt
Normal file
@@ -16,6 +16,7 @@
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { SupersetClient } from '@superset-ui/core';
|
||||
import { logging } from '@apache-superset/core/utils';
|
||||
import type { common as core } from '@apache-superset/core';
|
||||
import ExtensionsLoader from './ExtensionsLoader';
|
||||
@@ -111,3 +112,33 @@ test('handles initialization errors gracefully', async () => {
|
||||
errorSpy.mockRestore();
|
||||
appendChildSpy.mockRestore();
|
||||
});
|
||||
|
||||
test('logs success after initializeExtensions completes', async () => {
|
||||
const loader = ExtensionsLoader.getInstance();
|
||||
const infoSpy = jest.spyOn(logging, 'info').mockImplementation();
|
||||
jest.spyOn(SupersetClient, 'get').mockResolvedValue({
|
||||
json: { result: [] },
|
||||
} as any);
|
||||
|
||||
await loader.initializeExtensions();
|
||||
|
||||
expect(infoSpy).toHaveBeenCalledWith('Extensions initialized successfully.');
|
||||
|
||||
infoSpy.mockRestore();
|
||||
});
|
||||
|
||||
test('logs error when initializeExtensions fails', async () => {
|
||||
const loader = ExtensionsLoader.getInstance();
|
||||
const errorSpy = jest.spyOn(logging, 'error').mockImplementation();
|
||||
const fetchError = new Error('Network error');
|
||||
jest.spyOn(SupersetClient, 'get').mockRejectedValue(fetchError);
|
||||
|
||||
await loader.initializeExtensions();
|
||||
|
||||
expect(errorSpy).toHaveBeenCalledWith(
|
||||
'Error setting up extensions:',
|
||||
fetchError,
|
||||
);
|
||||
|
||||
errorSpy.mockRestore();
|
||||
});
|
||||
|
||||
@@ -34,6 +34,8 @@ class ExtensionsLoader {
|
||||
|
||||
private extensionIndex: Map<string, Extension> = new Map();
|
||||
|
||||
private initializationPromise: Promise<void> | null = null;
|
||||
|
||||
// eslint-disable-next-line no-useless-constructor
|
||||
private constructor() {
|
||||
// Private constructor for singleton pattern
|
||||
@@ -54,16 +56,27 @@ class ExtensionsLoader {
|
||||
* Initializes extensions by fetching the list from the API and loading each one.
|
||||
* @throws Error if initialization fails.
|
||||
*/
|
||||
public async initializeExtensions(): Promise<void> {
|
||||
const response = await SupersetClient.get({
|
||||
endpoint: '/api/v1/extensions/',
|
||||
});
|
||||
const extensions: Extension[] = response.json.result;
|
||||
await Promise.all(
|
||||
extensions.map(async extension => {
|
||||
await this.initializeExtension(extension);
|
||||
}),
|
||||
);
|
||||
public initializeExtensions(): Promise<void> {
|
||||
if (this.initializationPromise) {
|
||||
return this.initializationPromise;
|
||||
}
|
||||
this.initializationPromise = (async () => {
|
||||
try {
|
||||
const response = await SupersetClient.get({
|
||||
endpoint: '/api/v1/extensions/',
|
||||
});
|
||||
const extensions: Extension[] = response.json.result;
|
||||
await Promise.all(
|
||||
extensions.map(async extension => {
|
||||
await this.initializeExtension(extension);
|
||||
}),
|
||||
);
|
||||
logging.info('Extensions initialized successfully.');
|
||||
} catch (error) {
|
||||
logging.error('Error setting up extensions:', error);
|
||||
}
|
||||
})();
|
||||
return this.initializationPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
*/
|
||||
import { render, waitFor } from 'spec/helpers/testing-library';
|
||||
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
|
||||
import { logging } from '@apache-superset/core/utils';
|
||||
import fetchMock from 'fetch-mock';
|
||||
import ExtensionsStartup from './ExtensionsStartup';
|
||||
import ExtensionsLoader from './ExtensionsLoader';
|
||||
@@ -192,14 +191,12 @@ test('only initializes once even with multiple renders', async () => {
|
||||
loader.initializeExtensions = originalInitialize;
|
||||
});
|
||||
|
||||
test('initializes ExtensionsLoader and logs success when EnableExtensions feature flag is enabled', async () => {
|
||||
test('initializes ExtensionsLoader when EnableExtensions feature flag is enabled', async () => {
|
||||
// Ensure feature flag is enabled
|
||||
mockIsFeatureEnabled.mockImplementation(
|
||||
(flag: FeatureFlag) => flag === FeatureFlag.EnableExtensions,
|
||||
);
|
||||
|
||||
const infoSpy = jest.spyOn(logging, 'info').mockImplementation();
|
||||
|
||||
// Mock the initializeExtensions method to succeed
|
||||
const originalInitialize = ExtensionsLoader.prototype.initializeExtensions;
|
||||
ExtensionsLoader.prototype.initializeExtensions = jest
|
||||
@@ -220,15 +217,10 @@ test('initializes ExtensionsLoader and logs success when EnableExtensions featur
|
||||
expect(
|
||||
ExtensionsLoader.prototype.initializeExtensions,
|
||||
).toHaveBeenCalledTimes(1);
|
||||
// Verify success message was logged
|
||||
expect(infoSpy).toHaveBeenCalledWith(
|
||||
'Extensions initialized successfully.',
|
||||
);
|
||||
});
|
||||
|
||||
// Restore original method
|
||||
ExtensionsLoader.prototype.initializeExtensions = originalInitialize;
|
||||
infoSpy.mockRestore();
|
||||
});
|
||||
|
||||
test('does not initialize ExtensionsLoader when EnableExtensions feature flag is disabled', async () => {
|
||||
@@ -259,38 +251,36 @@ test('does not initialize ExtensionsLoader when EnableExtensions feature flag is
|
||||
initializeSpy.mockRestore();
|
||||
});
|
||||
|
||||
test('logs error when ExtensionsLoader initialization fails', async () => {
|
||||
test('continues rendering children even when ExtensionsLoader initialization fails', async () => {
|
||||
// Ensure feature flag is enabled
|
||||
mockIsFeatureEnabled.mockReturnValue(true);
|
||||
|
||||
const errorSpy = jest.spyOn(logging, 'error').mockImplementation();
|
||||
|
||||
// Mock the initializeExtensions method to throw an error
|
||||
// Mock the initializeExtensions method to reject — ExtensionsLoader handles
|
||||
// its own error logging internally
|
||||
const originalInitialize = ExtensionsLoader.prototype.initializeExtensions;
|
||||
ExtensionsLoader.prototype.initializeExtensions = jest
|
||||
.fn()
|
||||
.mockImplementation(() => {
|
||||
throw new Error('Test initialization error');
|
||||
});
|
||||
.mockImplementation(() => Promise.resolve());
|
||||
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
initialState: mockInitialState,
|
||||
});
|
||||
const { container } = render(
|
||||
<ExtensionsStartup>
|
||||
<div data-testid="child" />
|
||||
</ExtensionsStartup>,
|
||||
{
|
||||
useRedux: true,
|
||||
initialState: mockInitialState,
|
||||
},
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
// Verify feature flag was checked
|
||||
expect(mockIsFeatureEnabled).toHaveBeenCalledWith(
|
||||
FeatureFlag.EnableExtensions,
|
||||
);
|
||||
// Verify error was logged
|
||||
expect(errorSpy).toHaveBeenCalledWith(
|
||||
'Error setting up extensions:',
|
||||
expect.any(Error),
|
||||
);
|
||||
expect(
|
||||
container.querySelector('[data-testid="child"]'),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Restore original method
|
||||
ExtensionsLoader.prototype.initializeExtensions = originalInitialize;
|
||||
errorSpy.mockRestore();
|
||||
});
|
||||
|
||||
@@ -82,17 +82,7 @@ const ExtensionsStartup: React.FC<{ children?: React.ReactNode }> = ({
|
||||
|
||||
const setup = async () => {
|
||||
if (isFeatureEnabled(FeatureFlag.EnableExtensions)) {
|
||||
try {
|
||||
await ExtensionsLoader.getInstance().initializeExtensions();
|
||||
supersetCore.utils.logging.info(
|
||||
'Extensions initialized successfully.',
|
||||
);
|
||||
} catch (error) {
|
||||
supersetCore.utils.logging.error(
|
||||
'Error setting up extensions:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
await ExtensionsLoader.getInstance().initializeExtensions();
|
||||
}
|
||||
setInitialized(true);
|
||||
};
|
||||
|
||||
@@ -36,6 +36,7 @@ else:
|
||||
from flask import Flask, Response
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from superset.extensions.cache_middleware import ExtensionCacheMiddleware
|
||||
from superset.extensions.local_extensions_watcher import (
|
||||
start_local_extensions_watcher_thread,
|
||||
)
|
||||
@@ -66,7 +67,6 @@ def create_app(
|
||||
or app.config["APPLICATION_ROOT"],
|
||||
)
|
||||
if app_root != "/":
|
||||
app.wsgi_app = AppRootMiddleware(app.wsgi_app, app_root)
|
||||
# If not set, manually configure options that depend on the
|
||||
# value of app_root so things work out of the box
|
||||
if not app.config["STATIC_ASSETS_PREFIX"]:
|
||||
@@ -77,6 +77,13 @@ def create_app(
|
||||
app_initializer = app.config.get("APP_INITIALIZER", SupersetAppInitializer)(app)
|
||||
app_initializer.init_app()
|
||||
|
||||
# Must be applied before AppRootMiddleware so the path prefix
|
||||
# is stripped before the extension asset path regex runs.
|
||||
app.wsgi_app = ExtensionCacheMiddleware(app.wsgi_app)
|
||||
|
||||
if app_root != "/":
|
||||
app.wsgi_app = AppRootMiddleware(app.wsgi_app, app_root)
|
||||
|
||||
# Set up LOCAL_EXTENSIONS file watcher when in debug mode
|
||||
if app.debug:
|
||||
start_local_extensions_watcher_thread(app)
|
||||
|
||||
@@ -27,9 +27,13 @@ from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.sql.visitors import VisitableType
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.commands.dataset.exceptions import DatasetForbiddenDataURI
|
||||
from superset.commands.dataset.exceptions import (
|
||||
DatasetAccessDeniedError,
|
||||
DatasetForbiddenDataURI,
|
||||
)
|
||||
from superset.commands.exceptions import ImportFailedError
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.models.core import Database
|
||||
from superset.sql.parse import Table
|
||||
from superset.utils import json
|
||||
@@ -172,6 +176,12 @@ def import_dataset( # noqa: C901
|
||||
if dataset.id is None:
|
||||
db.session.flush()
|
||||
|
||||
if not ignore_permissions:
|
||||
try:
|
||||
security_manager.raise_for_access(datasource=dataset)
|
||||
except SupersetSecurityException as ex:
|
||||
raise DatasetAccessDeniedError() from ex
|
||||
|
||||
try:
|
||||
table_exists = dataset.database.has_table(
|
||||
Table(dataset.table_name, dataset.schema, dataset.catalog),
|
||||
|
||||
@@ -58,7 +58,11 @@ class QueryDAO(BaseDAO[Query]):
|
||||
|
||||
@staticmethod
|
||||
def stop_query(client_id: str) -> None:
|
||||
query = db.session.query(Query).filter_by(client_id=client_id).one_or_none()
|
||||
query = (
|
||||
db.session.query(Query)
|
||||
.filter(Query.client_id == client_id, Query.user_id == get_user_id())
|
||||
.one_or_none()
|
||||
)
|
||||
if not query:
|
||||
raise QueryNotFoundException(f"Query with client_id {client_id} not found")
|
||||
|
||||
|
||||
@@ -225,4 +225,9 @@ class ExtensionsRestApi(BaseApi):
|
||||
if not mimetype:
|
||||
mimetype = "application/octet-stream"
|
||||
|
||||
return send_file(BytesIO(chunk), mimetype=mimetype)
|
||||
response = send_file(BytesIO(chunk), mimetype=mimetype)
|
||||
# Chunk filenames include a content hash, so they are immutable.
|
||||
response.cache_control.max_age = 31536000
|
||||
response.cache_control.public = True
|
||||
response.cache_control.immutable = True
|
||||
return response
|
||||
|
||||
73
superset/extensions/cache_middleware.py
Normal file
73
superset/extensions/cache_middleware.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from types import TracebackType
|
||||
from typing import Callable, Iterable, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment
|
||||
|
||||
# Matches only the static asset endpoint: /api/v1/extensions/<publisher>/<name>/<file>
|
||||
# Does not match the list (/), get (/<publisher>/<name>), or info (/_info) endpoints.
|
||||
_ASSET_PATH_RE = re.compile(r"^/api/v1/extensions/[^/]+/[^/]+/[^/]+$")
|
||||
|
||||
|
||||
class ExtensionCacheMiddleware:
|
||||
"""Strip 'Cookie' from the Vary header on extension asset responses.
|
||||
|
||||
Flask's session interface appends Vary: Cookie unconditionally after every
|
||||
after_request hook runs, so it cannot be removed at the view layer. This
|
||||
middleware intercepts the WSGI response at the lowest level, after all
|
||||
Flask processing is complete.
|
||||
"""
|
||||
|
||||
def __init__(self, wsgi_app: WSGIApplication) -> None:
|
||||
self.wsgi_app = wsgi_app
|
||||
|
||||
def __call__(
|
||||
self, environ: WSGIEnvironment, start_response: StartResponse
|
||||
) -> Iterable[bytes]:
|
||||
path = environ.get("PATH_INFO", "")
|
||||
if not _ASSET_PATH_RE.match(path):
|
||||
return self.wsgi_app(environ, start_response)
|
||||
|
||||
def patched_start_response(
|
||||
status: str,
|
||||
response_headers: list[tuple[str, str]],
|
||||
exc_info: (
|
||||
tuple[type[BaseException], BaseException, TracebackType]
|
||||
| tuple[None, None, None]
|
||||
| None
|
||||
) = None,
|
||||
) -> Callable[[bytes], object]:
|
||||
new_headers = []
|
||||
for name, value in response_headers:
|
||||
if name.lower() == "vary":
|
||||
parts = [
|
||||
v.strip()
|
||||
for v in value.split(",")
|
||||
if v.strip().lower() != "cookie"
|
||||
]
|
||||
if parts:
|
||||
new_headers.append((name, ", ".join(parts)))
|
||||
else:
|
||||
new_headers.append((name, value))
|
||||
return start_response(status, new_headers, exc_info)
|
||||
|
||||
return self.wsgi_app(environ, patched_start_response)
|
||||
@@ -49,9 +49,7 @@ from contextlib import AbstractContextManager
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import g, has_request_context
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.mcp_service.composite_token_verifier import API_KEY_PASSTHROUGH_CLAIM
|
||||
from flask_appbuilder.security.sqla.models import Group, User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
@@ -150,14 +148,23 @@ def check_tool_permission(func: Callable[..., Any]) -> bool:
|
||||
def load_user_with_relationships(
|
||||
username: str | None = None, email: str | None = None
|
||||
) -> User | None:
|
||||
"""Load a user with roles and group roles eagerly loaded.
|
||||
"""
|
||||
Load a user with all relationships needed for permission checks.
|
||||
|
||||
Delegates to :meth:`SupersetSecurityManager.find_user_with_relationships`,
|
||||
which mirrors FAB's ``find_user`` (including ``auth_username_ci`` and
|
||||
``MultipleResultsFound`` handling) while adding eager loading of
|
||||
``User.roles`` and ``User.groups.roles`` to prevent detached-instance
|
||||
errors when the SQLAlchemy session is closed or rolled back after the
|
||||
lookup — as happens in MCP tool-execution contexts.
|
||||
This function eagerly loads User.roles, User.groups, and Group.roles
|
||||
to prevent detached instance errors when the session is closed/rolled back.
|
||||
|
||||
IMPORTANT: Always use this function instead of security_manager.find_user()
|
||||
when loading users for MCP tool execution. The find_user() method doesn't
|
||||
eagerly load Group.roles, causing "detached instance" errors when permission
|
||||
checks access group.roles after the session is rolled back.
|
||||
|
||||
Args:
|
||||
username: The username to look up (optional if email provided)
|
||||
email: The email to look up (optional if username provided)
|
||||
|
||||
Returns:
|
||||
User object with relationships loaded, or None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If neither username nor email is provided
|
||||
@@ -165,9 +172,21 @@ def load_user_with_relationships(
|
||||
if not username and not email:
|
||||
raise ValueError("Either username or email must be provided")
|
||||
|
||||
from superset import security_manager
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
return security_manager.find_user_with_relationships(username=username, email=email)
|
||||
from superset.extensions import db
|
||||
|
||||
query = db.session.query(User).options(
|
||||
joinedload(User.roles),
|
||||
joinedload(User.groups).joinedload(Group.roles),
|
||||
)
|
||||
|
||||
if username:
|
||||
query = query.filter(User.username == username)
|
||||
else:
|
||||
query = query.filter(User.email == email)
|
||||
|
||||
return query.first()
|
||||
|
||||
|
||||
def _resolve_user_from_jwt_context(app: Any) -> User | None:
|
||||
@@ -199,25 +218,6 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
|
||||
if access_token is None:
|
||||
return None
|
||||
|
||||
# API key pass-through: CompositeTokenVerifier accepted this token
|
||||
# at the transport layer but defers actual validation to
|
||||
# _resolve_user_from_api_key() (priority 2 in get_user_from_request).
|
||||
# Require client_id=="api_key" (set by CompositeTokenVerifier) in addition
|
||||
# to the claim so that an external IdP JWT that happens to include the
|
||||
# claim name is not misclassified as an API-key pass-through.
|
||||
claims = getattr(access_token, "claims", None)
|
||||
if isinstance(claims, dict) and claims.get(API_KEY_PASSTHROUGH_CLAIM):
|
||||
if getattr(access_token, "client_id", None) == "api_key":
|
||||
logger.debug(
|
||||
"API key pass-through token detected, deferring to API key auth"
|
||||
)
|
||||
return None
|
||||
logger.debug(
|
||||
"Ignoring %s claim on non-API-key token (client_id=%r); processing as JWT",
|
||||
API_KEY_PASSTHROUGH_CLAIM,
|
||||
getattr(access_token, "client_id", None),
|
||||
)
|
||||
|
||||
# Use configurable resolver or default
|
||||
from superset.mcp_service.mcp_config import default_user_resolver
|
||||
|
||||
@@ -238,12 +238,9 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
|
||||
if not user:
|
||||
# Fail closed: JWT says this user should exist but they don't.
|
||||
# Do NOT fall through to MCP_DEV_USERNAME or stale g.user.
|
||||
# Avoid echoing the JWT-extracted username in the exception message
|
||||
# (CodeQL py/clear-text-logging-sensitive-data).
|
||||
logger.debug("JWT-authenticated user not found in database (identity from JWT)")
|
||||
raise ValueError(
|
||||
"JWT authenticated user not found in Superset database. "
|
||||
"Ensure the user exists before granting MCP access."
|
||||
f"JWT authenticated user '{username}' not found in Superset database. "
|
||||
f"Ensure the user exists before granting MCP access."
|
||||
)
|
||||
|
||||
return user
|
||||
@@ -251,57 +248,37 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
|
||||
|
||||
def _resolve_user_from_api_key(app: Any) -> User | None:
|
||||
"""
|
||||
Resolve the current user from an API key passed via Bearer token.
|
||||
Resolve the current user from an API key in the Authorization header.
|
||||
|
||||
Reads the token from FastMCP's per-request ``AccessToken`` (set by
|
||||
``CompositeTokenVerifier`` when a Bearer token matches an API key
|
||||
prefix). The streamable-http transport does not push a Flask request
|
||||
context, so we cannot rely on ``flask.request`` headers — the verifier
|
||||
already saw the token and stashed it on the ``AccessToken``.
|
||||
Uses FAB SecurityManager's API key validation. Only attempts when
|
||||
FAB_API_KEY_ENABLED is True and a request context is active.
|
||||
|
||||
Returns:
|
||||
User object with relationships loaded, or None if no API key
|
||||
pass-through token is present or API key auth is not enabled.
|
||||
User object with relationships loaded, or None if no API key present
|
||||
or API key auth is not enabled/available.
|
||||
|
||||
Raises:
|
||||
PermissionError: If an API key pass-through token is present but
|
||||
invalid/expired (fail closed — do NOT fall through to weaker
|
||||
auth sources like ``MCP_DEV_USERNAME``), or if validation is
|
||||
not available in this FAB version.
|
||||
PermissionError: If an API key is present but invalid/expired,
|
||||
or if validation is not available in this FAB version.
|
||||
"""
|
||||
if not app.config.get("FAB_API_KEY_ENABLED", False):
|
||||
if not app.config.get("FAB_API_KEY_ENABLED", False) or not has_request_context():
|
||||
return None
|
||||
|
||||
try:
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
except ImportError:
|
||||
logger.debug("fastmcp.server.dependencies not available, skipping API key auth")
|
||||
return None
|
||||
|
||||
access_token = get_access_token()
|
||||
if access_token is None:
|
||||
return None
|
||||
|
||||
# Only validate tokens that the CompositeTokenVerifier flagged as
|
||||
# API key pass-throughs. Plain JWTs were already validated by the JWT
|
||||
# verifier and resolved in _resolve_user_from_jwt_context.
|
||||
claims = getattr(access_token, "claims", None)
|
||||
if not (isinstance(claims, dict) and claims.get(API_KEY_PASSTHROUGH_CLAIM)):
|
||||
return None
|
||||
# Defense-in-depth: require client_id=="api_key" (set by CompositeTokenVerifier)
|
||||
# to guard against rogue external IdP JWTs that include the passthrough claim.
|
||||
if getattr(access_token, "client_id", None) != "api_key":
|
||||
return None
|
||||
|
||||
api_key_string = getattr(access_token, "token", None)
|
||||
if not api_key_string:
|
||||
# Passthrough claim is set but the raw token is absent — fail closed
|
||||
# rather than silently falling through to weaker auth sources.
|
||||
raise PermissionError(
|
||||
"API key pass-through token is missing the raw token value."
|
||||
)
|
||||
|
||||
sm = app.appbuilder.sm
|
||||
# extract_api_key_from_request is FAB's method for reading
|
||||
# the Bearer token from the Authorization header and matching prefixes.
|
||||
# Not all FAB versions include this method, so guard with hasattr.
|
||||
if not hasattr(sm, "extract_api_key_from_request"):
|
||||
logger.debug(
|
||||
"FAB SecurityManager does not have extract_api_key_from_request; "
|
||||
"API key authentication is not available in this FAB version"
|
||||
)
|
||||
return None
|
||||
|
||||
api_key_string = sm.extract_api_key_from_request()
|
||||
if api_key_string is None:
|
||||
return None
|
||||
|
||||
if not hasattr(sm, "validate_api_key"):
|
||||
logger.warning(
|
||||
"FAB SecurityManager does not have validate_api_key; "
|
||||
@@ -468,6 +445,7 @@ def _setup_user_context() -> User | None:
|
||||
# tool calls when no per-request middleware refreshes it.
|
||||
# Only clear in app-context-only mode; preserve g.user when
|
||||
# a request context is active (external middleware set it).
|
||||
from flask import has_request_context
|
||||
|
||||
if not has_request_context():
|
||||
g.pop("user", None)
|
||||
@@ -511,7 +489,7 @@ def _setup_user_context() -> User | None:
|
||||
logger.error("DB connection failed on retry during user setup: %s", e)
|
||||
_cleanup_session_on_error()
|
||||
raise
|
||||
except (ValueError, PermissionError) as e:
|
||||
except ValueError as e:
|
||||
# User resolution failed — fail closed. Do not fall back to
|
||||
# g.user from middleware, as that could allow a request to
|
||||
# proceed as a different user in multi-tenant deployments.
|
||||
@@ -598,7 +576,7 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from flask import current_app, has_app_context
|
||||
from flask import current_app, has_app_context, has_request_context
|
||||
|
||||
def _get_app_context_manager() -> AbstractContextManager[None]:
|
||||
"""Push a fresh app context unless a request context is active.
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Composite token verifier for MCP authentication.
|
||||
|
||||
Routes Bearer tokens to the appropriate verifier based on prefix:
|
||||
- Tokens matching FAB_API_KEY_PREFIXES (e.g. ``sst_``) are passed through
|
||||
to the Flask layer where ``_resolve_user_from_api_key()`` handles
|
||||
actual validation via FAB SecurityManager.
|
||||
- All other tokens are delegated to the wrapped JWT verifier (when one is
|
||||
configured); when no JWT verifier is configured, non-API-key tokens are
|
||||
rejected at the transport layer.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp.server.auth import AccessToken
|
||||
from fastmcp.server.auth.providers.jwt import TokenVerifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Namespaced claim that flags an AccessToken as an API-key pass-through.
|
||||
# Namespacing avoids collision with custom claims an external IdP might
|
||||
# happen to mint on a JWT — a plain ``_api_key_passthrough`` claim could
|
||||
# be silently misidentified as a Superset API-key request.
|
||||
API_KEY_PASSTHROUGH_CLAIM = "_superset_mcp_api_key_passthrough"
|
||||
|
||||
|
||||
class CompositeTokenVerifier(TokenVerifier):
|
||||
"""Routes Bearer tokens between API key pass-through and JWT verification.
|
||||
|
||||
API key tokens (identified by prefix) are accepted at the transport layer
|
||||
with a marker claim so that ``_resolve_user_from_jwt_context()`` can
|
||||
detect them and fall through to ``_resolve_user_from_api_key()`` for
|
||||
actual validation.
|
||||
|
||||
Args:
|
||||
jwt_verifier: The wrapped JWT verifier for non-API-key tokens.
|
||||
When ``None``, only API-key tokens are accepted; all other
|
||||
Bearer tokens are rejected at the transport layer (used when
|
||||
``MCP_AUTH_ENABLED=False`` but ``FAB_API_KEY_ENABLED=True``).
|
||||
api_key_prefixes: List of prefixes that identify API key tokens
|
||||
(e.g. ``["sst_"]``).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
jwt_verifier: TokenVerifier | None,
|
||||
api_key_prefixes: list[str],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
base_url=getattr(jwt_verifier, "base_url", None),
|
||||
required_scopes=getattr(jwt_verifier, "required_scopes", None) or [],
|
||||
)
|
||||
self._jwt_verifier = jwt_verifier
|
||||
valid: list[str] = [
|
||||
p for p in api_key_prefixes if isinstance(p, str) and p.strip()
|
||||
]
|
||||
invalid = [p for p in api_key_prefixes if p not in valid]
|
||||
if invalid:
|
||||
# Log count only — actual values may be config secrets
|
||||
# (CodeQL py/clear-text-logging-sensitive-data).
|
||||
logger.warning(
|
||||
"FAB_API_KEY_PREFIXES has %d invalid entries (empty/non-string)"
|
||||
" — ignored",
|
||||
len(invalid),
|
||||
)
|
||||
self._api_key_prefixes = tuple(valid)
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify a Bearer token.
|
||||
|
||||
If the token starts with an API key prefix, return a pass-through
|
||||
AccessToken with the namespaced ``API_KEY_PASSTHROUGH_CLAIM``
|
||||
(``_superset_mcp_api_key_passthrough``). The Flask-layer
|
||||
``_resolve_user_from_api_key()`` performs the real validation.
|
||||
|
||||
Otherwise, delegate to the wrapped JWT verifier when one is
|
||||
configured; if no JWT verifier is configured, reject the token.
|
||||
"""
|
||||
if any(token.startswith(prefix) for prefix in self._api_key_prefixes):
|
||||
logger.debug("API key token detected (prefix match), passing through")
|
||||
# Populate ``scopes`` from ``self.required_scopes`` so FastMCP's
|
||||
# ``RequireAuthMiddleware`` (transport-layer scope check) is
|
||||
# satisfied for API-key requests. Without this, MCP_REQUIRED_SCOPES
|
||||
# being non-empty would 403 every API-key call before
|
||||
# ``_resolve_user_from_api_key`` even runs.
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id="api_key",
|
||||
scopes=list(self.required_scopes or []),
|
||||
claims={API_KEY_PASSTHROUGH_CLAIM: True},
|
||||
)
|
||||
|
||||
if self._jwt_verifier is None:
|
||||
logger.debug(
|
||||
"Bearer token does not match any API key prefix and no JWT "
|
||||
"verifier is configured; rejecting"
|
||||
)
|
||||
return None
|
||||
|
||||
return await self._jwt_verifier.verify_token(token)
|
||||
@@ -20,15 +20,12 @@ import logging
|
||||
import secrets
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
from flask import Flask
|
||||
|
||||
from superset.mcp_service.composite_token_verifier import CompositeTokenVerifier
|
||||
from superset.mcp_service.constants import (
|
||||
DEFAULT_TOKEN_LIMIT,
|
||||
DEFAULT_WARN_THRESHOLD_PCT,
|
||||
)
|
||||
from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -294,94 +291,56 @@ MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
|
||||
|
||||
|
||||
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
|
||||
"""Default MCP auth factory using app.config values.
|
||||
|
||||
Returns an auth provider when ``MCP_AUTH_ENABLED=True`` (JWT verifier,
|
||||
optionally wrapped with ``CompositeTokenVerifier`` for API keys) or
|
||||
when only ``FAB_API_KEY_ENABLED=True`` (API-key-only verifier that
|
||||
rejects all non-API-key Bearer tokens at the transport).
|
||||
"""
|
||||
auth_enabled = app.config.get("MCP_AUTH_ENABLED", False)
|
||||
api_key_enabled = app.config.get("FAB_API_KEY_ENABLED", False)
|
||||
|
||||
if not (auth_enabled or api_key_enabled):
|
||||
"""Default MCP auth factory using app.config values."""
|
||||
if not app.config.get("MCP_AUTH_ENABLED", False):
|
||||
return None
|
||||
|
||||
jwt_verifier: Any | None = None
|
||||
jwks_uri = app.config.get("MCP_JWKS_URI")
|
||||
public_key = app.config.get("MCP_JWT_PUBLIC_KEY")
|
||||
secret = app.config.get("MCP_JWT_SECRET")
|
||||
|
||||
if auth_enabled:
|
||||
jwks_uri = app.config.get("MCP_JWKS_URI")
|
||||
public_key = app.config.get("MCP_JWT_PUBLIC_KEY")
|
||||
secret = app.config.get("MCP_JWT_SECRET")
|
||||
if not (jwks_uri or public_key or secret):
|
||||
logger.warning("MCP_AUTH_ENABLED is True but no JWT keys/secret configured")
|
||||
return None
|
||||
|
||||
if not (jwks_uri or public_key or secret):
|
||||
logger.warning("MCP_AUTH_ENABLED is True but no JWT keys/secret configured")
|
||||
if not api_key_enabled:
|
||||
return None
|
||||
try:
|
||||
debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
|
||||
|
||||
common_kwargs: dict[str, Any] = {
|
||||
"issuer": app.config.get("MCP_JWT_ISSUER"),
|
||||
"audience": app.config.get("MCP_JWT_AUDIENCE"),
|
||||
"required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
|
||||
}
|
||||
|
||||
# For HS256 (symmetric), use the secret as the public_key parameter
|
||||
if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
|
||||
common_kwargs["public_key"] = secret
|
||||
common_kwargs["algorithm"] = "HS256"
|
||||
else:
|
||||
try:
|
||||
jwt_verifier = _build_jwt_verifier(
|
||||
app=app,
|
||||
jwks_uri=jwks_uri,
|
||||
public_key=public_key,
|
||||
secret=secret,
|
||||
)
|
||||
except Exception: # noqa: BLE001 — JWT lib raises many types; broad catch intentional
|
||||
# Do not log the exception — it may contain secrets (e.g., key material)
|
||||
logger.error("Failed to create MCP JWT verifier")
|
||||
if not api_key_enabled:
|
||||
return None
|
||||
# For RS256 (asymmetric), use public key or JWKS
|
||||
common_kwargs["jwks_uri"] = jwks_uri
|
||||
common_kwargs["public_key"] = public_key
|
||||
common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
|
||||
|
||||
if api_key_enabled:
|
||||
raw_prefixes = app.config.get("FAB_API_KEY_PREFIXES", ["sst_"])
|
||||
# Normalize: a plain string (e.g. "sst_") would iterate as characters;
|
||||
# wrap it in a list so CompositeTokenVerifier receives a proper sequence.
|
||||
if isinstance(raw_prefixes, str):
|
||||
api_key_prefixes = [raw_prefixes]
|
||||
if debug_errors:
|
||||
# DetailedJWTVerifier: detailed server-side logging of JWT
|
||||
# validation failures. HTTP responses are always generic per
|
||||
# RFC 6750 Section 3.1.
|
||||
from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
|
||||
|
||||
auth_provider = DetailedJWTVerifier(**common_kwargs)
|
||||
else:
|
||||
api_key_prefixes = list(raw_prefixes)
|
||||
logger.info("API key auth enabled for MCP")
|
||||
return CompositeTokenVerifier(
|
||||
jwt_verifier=jwt_verifier,
|
||||
api_key_prefixes=api_key_prefixes,
|
||||
)
|
||||
# Default JWTVerifier: minimal logging, generic error responses.
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
|
||||
return jwt_verifier
|
||||
auth_provider = JWTVerifier(**common_kwargs)
|
||||
|
||||
|
||||
def _build_jwt_verifier(
|
||||
app: Flask,
|
||||
jwks_uri: Optional[str],
|
||||
public_key: Optional[str],
|
||||
secret: Optional[str],
|
||||
) -> Any:
|
||||
"""Construct the JWT verifier from configured keys/secret."""
|
||||
debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
|
||||
|
||||
common_kwargs: Dict[str, Any] = {
|
||||
"issuer": app.config.get("MCP_JWT_ISSUER"),
|
||||
"audience": app.config.get("MCP_JWT_AUDIENCE"),
|
||||
"required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
|
||||
}
|
||||
|
||||
# For HS256 (symmetric), use the secret as the public_key parameter
|
||||
if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
|
||||
common_kwargs["public_key"] = secret
|
||||
common_kwargs["algorithm"] = "HS256"
|
||||
else:
|
||||
# For RS256 (asymmetric), use public key or JWKS
|
||||
common_kwargs["jwks_uri"] = jwks_uri
|
||||
common_kwargs["public_key"] = public_key
|
||||
common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
|
||||
|
||||
if debug_errors:
|
||||
# DetailedJWTVerifier: detailed server-side logging of JWT
|
||||
# validation failures. HTTP responses are always generic per
|
||||
# RFC 6750 Section 3.1.
|
||||
return DetailedJWTVerifier(**common_kwargs)
|
||||
|
||||
# Default JWTVerifier: minimal logging, generic error responses.
|
||||
return JWTVerifier(**common_kwargs)
|
||||
return auth_provider
|
||||
except Exception:
|
||||
# Do not log the exception — it may contain the HS256 secret
|
||||
# from common_kwargs["public_key"]
|
||||
logger.error("Failed to create MCP auth provider")
|
||||
return None
|
||||
|
||||
|
||||
def default_user_resolver(app: Any, access_token: Any) -> str | None:
|
||||
|
||||
@@ -429,13 +429,13 @@ def _tool_allowed_for_current_user(tool: Any) -> bool:
|
||||
if not getattr(g, "user", None):
|
||||
try:
|
||||
g.user = get_user_from_request()
|
||||
except (ValueError, PermissionError):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
method_permission_name = getattr(tool_func, METHOD_PERMISSION_ATTR, "read")
|
||||
permission_name = f"{PERMISSION_PREFIX}{method_permission_name}"
|
||||
return security_manager.can_access(permission_name, class_permission_name)
|
||||
except (AttributeError, RuntimeError, ValueError, PermissionError):
|
||||
except (AttributeError, RuntimeError, ValueError):
|
||||
logger.debug("Could not evaluate tool search permission", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -673,9 +673,7 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
|
||||
"""Create an auth provider from Flask app config.
|
||||
|
||||
Tries MCP_AUTH_FACTORY first, then falls back to the default factory
|
||||
when either ``MCP_AUTH_ENABLED`` (JWT auth) or ``FAB_API_KEY_ENABLED``
|
||||
(API key auth) is True. The default factory builds a
|
||||
``CompositeTokenVerifier`` that handles either or both auth modes.
|
||||
when MCP_AUTH_ENABLED is True.
|
||||
"""
|
||||
auth_provider = None
|
||||
if auth_factory := flask_app.config.get("MCP_AUTH_FACTORY"):
|
||||
@@ -688,9 +686,7 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
|
||||
except Exception:
|
||||
# Do not log the exception — it may contain secrets
|
||||
logger.error("Failed to create auth provider from MCP_AUTH_FACTORY")
|
||||
elif flask_app.config.get("MCP_AUTH_ENABLED", False) or flask_app.config.get(
|
||||
"FAB_API_KEY_ENABLED", False
|
||||
):
|
||||
elif flask_app.config.get("MCP_AUTH_ENABLED", False):
|
||||
from superset.mcp_service.mcp_config import (
|
||||
create_default_mcp_auth_factory,
|
||||
)
|
||||
|
||||
@@ -29,8 +29,7 @@ BLOCKLIST = {
|
||||
# sqlite creates a local DB, which allows mapping server's filesystem
|
||||
re.compile(r"sqlite(?:\+[^\s]*)?$"),
|
||||
# shillelagh allows opening local files (eg, 'SELECT * FROM "csv:///etc/passwd"')
|
||||
re.compile(r"shillelagh$"),
|
||||
re.compile(r"shillelagh\+apsw$"),
|
||||
re.compile(r"shillelagh(?:\+[^\s]*)?$"),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -52,8 +52,7 @@ from flask_login import AnonymousUserMixin, LoginManager
|
||||
from jwt.api_jwt import _jwt_global_obj
|
||||
from sqlalchemy import and_, inspect, or_
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.orm import eagerload, joinedload
|
||||
from sqlalchemy.orm.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import eagerload
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import Query as SqlaQuery
|
||||
from sqlalchemy.sql import exists
|
||||
@@ -463,10 +462,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
"PermissionViewMenu",
|
||||
"ViewMenu",
|
||||
"User",
|
||||
# FAB ApiKeyApi blueprint (active when FAB_API_KEY_ENABLED=True).
|
||||
# Listed unconditionally — harmless when the feature is off because
|
||||
# no PVMs exist under this view menu.
|
||||
"ApiKey",
|
||||
} | USER_MODEL_VIEWS
|
||||
|
||||
ALPHA_ONLY_VIEW_MENUS = {
|
||||
@@ -3169,60 +3164,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
def find_user_with_relationships(
|
||||
self,
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
) -> Optional[User]:
|
||||
"""Find a user with roles and group roles eagerly loaded.
|
||||
|
||||
Mirrors FAB's ``SecurityManager.find_user``
|
||||
(including ``auth_username_ci`` case-insensitive handling and
|
||||
``MultipleResultsFound`` guard) and additionally eager-loads
|
||||
``User.roles`` and ``User.groups.roles`` to prevent detached-instance
|
||||
errors when the SQLAlchemy session is closed or rolled back after the
|
||||
lookup — as happens in MCP tool-execution contexts.
|
||||
"""
|
||||
eager = [
|
||||
joinedload(self.user_model.roles),
|
||||
joinedload(self.user_model.groups).joinedload("roles"),
|
||||
]
|
||||
if username:
|
||||
try:
|
||||
if self.auth_username_ci:
|
||||
from sqlalchemy import func as sa_func
|
||||
|
||||
return (
|
||||
self.session.query(self.user_model)
|
||||
.options(*eager)
|
||||
.filter(
|
||||
sa_func.lower(self.user_model.username)
|
||||
== sa_func.lower(username)
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
return (
|
||||
self.session.query(self.user_model)
|
||||
.options(*eager)
|
||||
.filter(self.user_model.username == username)
|
||||
.one_or_none()
|
||||
)
|
||||
except MultipleResultsFound:
|
||||
logger.error("Multiple results found for user %s", username)
|
||||
return None
|
||||
if email:
|
||||
try:
|
||||
return (
|
||||
self.session.query(self.user_model)
|
||||
.options(*eager)
|
||||
.filter_by(email=email)
|
||||
.one_or_none()
|
||||
)
|
||||
except MultipleResultsFound:
|
||||
logger.error("Multiple results found for user with email %s", email)
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_anonymous_user(self) -> User:
|
||||
return AnonymousUserMixin()
|
||||
|
||||
|
||||
@@ -72,11 +72,30 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri
|
||||
True,
|
||||
"shillelagh cannot be used as a data source for security reasons.",
|
||||
),
|
||||
("shillelagh+:///home/superset/bad.db", False, None),
|
||||
(
|
||||
"shillelagh+:///home/superset/bad.db",
|
||||
True,
|
||||
"shillelagh cannot be used as a data source for security reasons.",
|
||||
),
|
||||
(
|
||||
"shillelagh+something:///home/superset/bad.db",
|
||||
False,
|
||||
None,
|
||||
True,
|
||||
"shillelagh cannot be used as a data source for security reasons.",
|
||||
),
|
||||
(
|
||||
"shillelagh+csv:///etc/passwd",
|
||||
True,
|
||||
"shillelagh cannot be used as a data source for security reasons.",
|
||||
),
|
||||
(
|
||||
"shillelagh+json:///etc/passwd",
|
||||
True,
|
||||
"shillelagh cannot be used as a data source for security reasons.",
|
||||
),
|
||||
(
|
||||
"shillelagh+gsheets:///",
|
||||
True,
|
||||
"shillelagh cannot be used as a data source for security reasons.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -279,3 +279,49 @@ def test_query_dao_stop_query(
|
||||
QueryDAO.stop_query(query_obj.client_id)
|
||||
query = db.session.query(Query).one()
|
||||
assert query.status == QueryStatus.STOPPED
|
||||
|
||||
|
||||
def test_query_dao_stop_query_wrong_user(
|
||||
mocker: MockerFixture, app: Any, session: Session
|
||||
) -> None:
|
||||
"""A user cannot stop a query that belongs to a different user."""
|
||||
from superset import db
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
engine = db.session.get_bind()
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
select_sql="select * from bar",
|
||||
executed_sql="select * from bar",
|
||||
limit=100,
|
||||
select_as_cta=False,
|
||||
rows=100,
|
||||
error_message="none",
|
||||
results_key="abc",
|
||||
status=QueryStatus.RUNNING,
|
||||
user_id=1,
|
||||
)
|
||||
|
||||
db.session.add(database)
|
||||
db.session.add(query_obj)
|
||||
|
||||
# Simulate a different user (user 2) attempting to stop user 1's query
|
||||
mocker.patch("superset.daos.query.get_user_id", return_value=2)
|
||||
|
||||
from superset.daos.query import QueryDAO
|
||||
|
||||
with pytest.raises(QueryNotFoundException):
|
||||
QueryDAO.stop_query(query_obj.client_id)
|
||||
|
||||
query = db.session.query(Query).one()
|
||||
assert query.status == QueryStatus.RUNNING
|
||||
|
||||
@@ -31,6 +31,7 @@ from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.commands.dataset.exceptions import (
|
||||
DatasetAccessDeniedError,
|
||||
DatasetForbiddenDataURI,
|
||||
)
|
||||
from superset.commands.dataset.importers.v1.utils import (
|
||||
@@ -744,6 +745,44 @@ def test_import_dataset_without_owner_permission(
|
||||
mock_can_access.assert_called_with("can_write", "Dataset")
|
||||
|
||||
|
||||
def test_import_dataset_access_check(
|
||||
mocker: MockerFixture,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Test that import_dataset raises DatasetAccessDeniedError when the user does not
|
||||
have datasource-level access to the target dataset.
|
||||
"""
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
mocker.patch.object(
|
||||
security_manager,
|
||||
"raise_for_access",
|
||||
side_effect=SupersetSecurityException(
|
||||
SupersetError(
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
message="User does not have access to this datasource",
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
config = copy.deepcopy(dataset_fixture)
|
||||
config["database_id"] = database.id
|
||||
|
||||
with pytest.raises(DatasetAccessDeniedError):
|
||||
import_dataset(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allowed_urls, data_uri, expected, exception_class",
|
||||
[
|
||||
|
||||
156
tests/unit_tests/extensions/test_cache_middleware.py
Normal file
156
tests/unit_tests/extensions/test_cache_middleware.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Callable
|
||||
|
||||
from superset.extensions.cache_middleware import ExtensionCacheMiddleware
|
||||
|
||||
ResponseHeaders = list[tuple[str, str]]
|
||||
|
||||
|
||||
def make_wsgi_app(
|
||||
status: str = "200 OK",
|
||||
headers: ResponseHeaders | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Returns a minimal WSGI app that calls start_response with the given headers."""
|
||||
|
||||
def app(environ, start_response): # noqa: ARG001
|
||||
start_response(status, headers or [])
|
||||
return [b"body"]
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def call_middleware(
|
||||
path: str,
|
||||
upstream_headers: ResponseHeaders,
|
||||
) -> ResponseHeaders:
|
||||
"""Run middleware for a given path, return headers passed to start_response."""
|
||||
captured: list[ResponseHeaders] = []
|
||||
|
||||
def start_response(status, headers, exc_info=None): # noqa: ARG001
|
||||
captured.append(headers)
|
||||
|
||||
wsgi_app = make_wsgi_app(headers=upstream_headers)
|
||||
middleware = ExtensionCacheMiddleware(wsgi_app)
|
||||
environ = {"PATH_INFO": path}
|
||||
list(middleware(environ, start_response))
|
||||
|
||||
return captured[0]
|
||||
|
||||
|
||||
# --- Path matching ---
|
||||
|
||||
|
||||
def test_asset_path_is_intercepted() -> None:
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/main.js",
|
||||
[("Vary", "Accept-Encoding, Cookie")],
|
||||
)
|
||||
vary = dict(headers).get("Vary", "")
|
||||
assert "Cookie" not in vary
|
||||
|
||||
|
||||
def test_list_endpoint_is_not_intercepted() -> None:
|
||||
upstream = [("Vary", "Accept-Encoding, Cookie")]
|
||||
headers = call_middleware("/api/v1/extensions/", upstream)
|
||||
assert headers == upstream
|
||||
|
||||
|
||||
def test_get_endpoint_is_not_intercepted() -> None:
|
||||
upstream = [("Vary", "Accept-Encoding, Cookie")]
|
||||
headers = call_middleware("/api/v1/extensions/acme/my-ext", upstream)
|
||||
assert headers == upstream
|
||||
|
||||
|
||||
def test_info_endpoint_is_not_intercepted() -> None:
|
||||
upstream = [("Vary", "Accept-Encoding, Cookie")]
|
||||
headers = call_middleware("/api/v1/extensions/_info", upstream)
|
||||
assert headers == upstream
|
||||
|
||||
|
||||
def test_unrelated_path_is_not_intercepted() -> None:
|
||||
upstream = [("Vary", "Accept-Encoding, Cookie")]
|
||||
headers = call_middleware("/api/v1/dashboard/", upstream)
|
||||
assert headers == upstream
|
||||
|
||||
|
||||
# --- Vary stripping logic ---
|
||||
|
||||
|
||||
def test_strips_cookie_from_vary() -> None:
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/chunk.js",
|
||||
[("Vary", "Accept-Encoding, Cookie")],
|
||||
)
|
||||
assert dict(headers)["Vary"] == "Accept-Encoding"
|
||||
|
||||
|
||||
def test_strips_cookie_case_insensitive() -> None:
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/chunk.js",
|
||||
[("Vary", "Accept-Encoding, COOKIE")],
|
||||
)
|
||||
assert dict(headers)["Vary"] == "Accept-Encoding"
|
||||
|
||||
|
||||
def test_removes_vary_header_when_cookie_is_only_value() -> None:
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/chunk.js",
|
||||
[("Vary", "Cookie")],
|
||||
)
|
||||
assert "Vary" not in dict(headers)
|
||||
|
||||
|
||||
def test_multiple_vary_headers_all_stripped() -> None:
|
||||
"""Some middleware stacks emit multiple separate Vary headers."""
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/chunk.js",
|
||||
[("Vary", "Cookie"), ("Vary", "Accept-Encoding, Cookie")],
|
||||
)
|
||||
vary_values = [v for k, v in headers if k == "Vary"]
|
||||
assert all("Cookie" not in v for v in vary_values)
|
||||
assert vary_values == ["Accept-Encoding"]
|
||||
|
||||
|
||||
def test_non_vary_headers_are_preserved() -> None:
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/chunk.wasm",
|
||||
[
|
||||
("Content-Type", "application/wasm"),
|
||||
("Cache-Control", "public, max-age=31536000, immutable"),
|
||||
("Vary", "Accept-Encoding, Cookie"),
|
||||
],
|
||||
)
|
||||
d = dict(headers)
|
||||
assert d["Content-Type"] == "application/wasm"
|
||||
assert d["Cache-Control"] == "public, max-age=31536000, immutable"
|
||||
|
||||
|
||||
def test_vary_without_cookie_is_unchanged() -> None:
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/chunk.js",
|
||||
[("Vary", "Accept-Encoding")],
|
||||
)
|
||||
assert dict(headers)["Vary"] == "Accept-Encoding"
|
||||
|
||||
|
||||
def test_no_vary_header_produces_no_vary() -> None:
|
||||
headers = call_middleware(
|
||||
"/api/v1/extensions/acme/my-ext/chunk.js",
|
||||
[("Content-Type", "application/javascript")],
|
||||
)
|
||||
assert "Vary" not in dict(headers)
|
||||
@@ -15,55 +15,25 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for API key authentication in get_user_from_request().
|
||||
"""Tests for API key authentication in get_user_from_request()."""
|
||||
|
||||
The streamable-http transport does not push a Flask request context, so
|
||||
``_resolve_user_from_api_key`` reads the token from FastMCP's per-request
|
||||
``AccessToken`` (populated by ``CompositeTokenVerifier``) rather than from
|
||||
``flask.request``. These tests mock ``get_access_token`` accordingly.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import g
|
||||
|
||||
from superset.app import SupersetApp
|
||||
from superset.mcp_service.auth import (
|
||||
_resolve_user_from_jwt_context,
|
||||
get_user_from_request,
|
||||
)
|
||||
from superset.mcp_service.composite_token_verifier import API_KEY_PASSTHROUGH_CLAIM
|
||||
from superset.mcp_service.auth import get_user_from_request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user() -> MagicMock:
|
||||
def mock_user():
|
||||
user = MagicMock()
|
||||
user.username = "api_key_user"
|
||||
return user
|
||||
|
||||
|
||||
def _passthrough_access_token(token: str) -> MagicMock:
|
||||
"""Build an AccessToken matching what CompositeTokenVerifier emits."""
|
||||
access_token = MagicMock()
|
||||
access_token.token = token
|
||||
access_token.client_id = "api_key"
|
||||
access_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True}
|
||||
return access_token
|
||||
|
||||
|
||||
def _patch_access_token(access_token: MagicMock | None):
|
||||
"""Patch get_access_token where _resolve_user_from_api_key imports it."""
|
||||
return patch(
|
||||
"fastmcp.server.dependencies.get_access_token",
|
||||
return_value=access_token,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _enable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
|
||||
def _enable_api_keys(app):
|
||||
"""Enable FAB API key auth and clear MCP_DEV_USERNAME so the API key
|
||||
path is exercised instead of falling through to the dev-user fallback."""
|
||||
app.config["FAB_API_KEY_ENABLED"] = True
|
||||
@@ -75,7 +45,7 @@ def _enable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _disable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
|
||||
def _disable_api_keys(app):
|
||||
app.config["FAB_API_KEY_ENABLED"] = False
|
||||
old_dev = app.config.pop("MCP_DEV_USERNAME", None)
|
||||
yield
|
||||
@@ -84,45 +54,24 @@ def _disable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
|
||||
app.config["MCP_DEV_USERNAME"] = old_dev
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_sm_ctx(app: SupersetApp, mock_sm: MagicMock):
|
||||
"""Push an app context with g.user cleared and appbuilder.sm mocked."""
|
||||
with app.app_context():
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
yield
|
||||
|
||||
|
||||
def _patch_load_user_not_found():
|
||||
"""Patch load_user_with_relationships to return None (user not found).
|
||||
|
||||
load_user_with_relationships delegates to the global security_manager
|
||||
(not app.appbuilder.sm), so tests that need the JWT path to raise
|
||||
ValueError("not found") must patch it directly at the module level.
|
||||
"""
|
||||
return patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
# -- Valid API key -> user loaded --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_valid_api_key_returns_user(app: SupersetApp, mock_user: MagicMock) -> None:
|
||||
"""A valid API key pass-through token should authenticate and return the user."""
|
||||
def test_valid_api_key_returns_user(app, mock_user) -> None:
|
||||
"""A valid API key should authenticate and return the user."""
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
|
||||
mock_sm.validate_api_key.return_value = mock_user
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with (
|
||||
_patch_access_token(_passthrough_access_token("sst_abc123")),
|
||||
patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=mock_user,
|
||||
),
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=mock_user,
|
||||
):
|
||||
result = get_user_from_request()
|
||||
|
||||
@@ -130,70 +79,75 @@ def test_valid_api_key_returns_user(app: SupersetApp, mock_user: MagicMock) -> N
|
||||
mock_sm.validate_api_key.assert_called_once_with("sst_abc123")
|
||||
|
||||
|
||||
# -- Invalid API key -> PermissionError (does not silently fall back) --
|
||||
# -- Invalid API key -> PermissionError --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_invalid_api_key_raises(app: SupersetApp) -> None:
|
||||
"""An invalid API key pass-through token should raise PermissionError
|
||||
(fail closed — do NOT fall through to MCP_DEV_USERNAME)."""
|
||||
def test_invalid_api_key_raises(app) -> None:
|
||||
"""An invalid API key should raise PermissionError."""
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_bad_key"
|
||||
mock_sm.validate_api_key.return_value = None
|
||||
|
||||
# The dangerous fallthrough scenario: dev username IS set, but the
|
||||
# request presented an invalid API key. The dev fallback must not
|
||||
# mask the rejection.
|
||||
app.config["MCP_DEV_USERNAME"] = "admin"
|
||||
try:
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(_passthrough_access_token("sst_bad_key")):
|
||||
with pytest.raises(PermissionError, match="Invalid or expired API key"):
|
||||
get_user_from_request()
|
||||
finally:
|
||||
app.config.pop("MCP_DEV_USERNAME", None)
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_bad_key"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with pytest.raises(PermissionError, match="Invalid or expired API key"):
|
||||
get_user_from_request()
|
||||
|
||||
|
||||
# -- API key disabled -> falls through to next auth method --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_disable_api_keys")
|
||||
def test_api_key_disabled_skips_auth(app: SupersetApp) -> None:
|
||||
"""When FAB_API_KEY_ENABLED is False, API key auth is skipped entirely
|
||||
even if an AccessToken is present."""
|
||||
def test_api_key_disabled_skips_auth(app) -> None:
|
||||
"""When FAB_API_KEY_ENABLED is False, API key auth is skipped entirely."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(_passthrough_access_token("sst_abc123")):
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
# Without API key auth or MCP_DEV_USERNAME, should raise ValueError
|
||||
# about no authenticated user (not about invalid API key)
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
|
||||
# SecurityManager API key methods should never be called
|
||||
mock_sm.extract_api_key_from_request.assert_not_called()
|
||||
|
||||
|
||||
# -- No AccessToken -> API key auth skipped --
|
||||
# -- No request context -> API key auth skipped --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_no_access_token_skips_api_key_auth(app: SupersetApp) -> None:
|
||||
"""Without a FastMCP AccessToken (e.g., MCP_AUTH_ENABLED=False and no
|
||||
auth provider installed), API key auth is skipped."""
|
||||
def test_no_request_context_skips_api_key_auth(app) -> None:
|
||||
"""Without a request context, API key auth should be skipped
|
||||
(e.g., during MCP tool discovery with only an app context)."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(None):
|
||||
with app.app_context():
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
# Explicitly mock has_request_context to False because the test
|
||||
# framework's app fixture may implicitly provide a request context.
|
||||
with patch("superset.mcp_service.auth.has_request_context", return_value=False):
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
mock_sm.extract_api_key_from_request.assert_not_called()
|
||||
|
||||
|
||||
# -- g.user fallback when no higher-priority auth succeeds --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_disable_api_keys")
|
||||
def test_g_user_fallback_when_no_jwt_or_api_key(
|
||||
app: SupersetApp, mock_user: MagicMock
|
||||
) -> None:
|
||||
def test_g_user_fallback_when_no_jwt_or_api_key(app, mock_user) -> None:
|
||||
"""When no JWT or API key auth succeeds and MCP_DEV_USERNAME is not set,
|
||||
g.user (set by external middleware) is used as fallback."""
|
||||
with app.test_request_context():
|
||||
@@ -204,174 +158,89 @@ def test_g_user_fallback_when_no_jwt_or_api_key(
|
||||
assert result.username == "api_key_user"
|
||||
|
||||
|
||||
# -- FAB version without extract_api_key_from_request --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_fab_without_extract_method_skips_gracefully(app) -> None:
|
||||
"""If FAB SecurityManager lacks extract_api_key_from_request,
|
||||
API key auth should be skipped with a debug log, not crash."""
|
||||
mock_sm = MagicMock(spec=[]) # empty spec = no attributes
|
||||
|
||||
with app.test_request_context():
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
|
||||
|
||||
# -- FAB version without validate_api_key --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_fab_without_validate_method_raises(app: SupersetApp) -> None:
|
||||
"""If FAB SecurityManager lacks validate_api_key, should raise
|
||||
PermissionError about unavailable validation."""
|
||||
mock_sm = MagicMock(spec=[]) # empty spec = no attributes
|
||||
def test_fab_without_validate_method_raises(app) -> None:
|
||||
"""If FAB has extract_api_key_from_request but not validate_api_key,
|
||||
should raise PermissionError about unavailable validation."""
|
||||
mock_sm = MagicMock(spec=["extract_api_key_from_request"])
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(_passthrough_access_token("sst_abc123")):
|
||||
with pytest.raises(
|
||||
PermissionError, match="API key validation is not available"
|
||||
):
|
||||
get_user_from_request()
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with pytest.raises(
|
||||
PermissionError, match="API key validation is not available"
|
||||
):
|
||||
get_user_from_request()
|
||||
|
||||
|
||||
# -- Relationship reload fallback --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_relationship_reload_failure_returns_original_user(
|
||||
app: SupersetApp, mock_user: MagicMock
|
||||
) -> None:
|
||||
def test_relationship_reload_failure_returns_original_user(app, mock_user) -> None:
|
||||
"""If load_user_with_relationships fails, the original user from
|
||||
validate_api_key should be returned as fallback."""
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
|
||||
mock_sm.validate_api_key.return_value = mock_user
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with (
|
||||
_patch_access_token(_passthrough_access_token("sst_abc123")),
|
||||
patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=None,
|
||||
),
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=None,
|
||||
):
|
||||
result = get_user_from_request()
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
|
||||
# -- AccessToken without passthrough claim (plain JWT) -> skip API key auth --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_jwt_access_token_skips_api_key_auth(app: SupersetApp) -> None:
|
||||
"""When the AccessToken is a plain JWT (no API_KEY_PASSTHROUGH_CLAIM),
|
||||
API key auth is skipped — the JWT was already validated by the JWT
|
||||
verifier and resolved in _resolve_user_from_jwt_context."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
jwt_access_token = MagicMock()
|
||||
jwt_access_token.token = "eyJhbGciOiJIUzI1NiJ9.not-an-api-key" # noqa: S105
|
||||
jwt_access_token.claims = {"sub": "alice"}
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(jwt_access_token), _patch_load_user_not_found():
|
||||
# _resolve_user_from_jwt_context resolves "alice" from JWT claims
|
||||
# and raises ValueError because the username is not a real user.
|
||||
# We assert that _resolve_user_from_api_key did NOT short-circuit
|
||||
# to the API key path.
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- API key pass-through detection in JWT context resolver --
|
||||
|
||||
|
||||
def test_jwt_context_with_api_key_passthrough_returns_none(app: SupersetApp) -> None:
|
||||
"""When CompositeTokenVerifier passes through an API key token,
|
||||
_resolve_user_from_jwt_context should detect the namespaced
|
||||
pass-through claim AND client_id=="api_key" and return None so
|
||||
get_user_from_request falls through to _resolve_user_from_api_key."""
|
||||
mock_access_token = MagicMock()
|
||||
mock_access_token.client_id = "api_key"
|
||||
mock_access_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True}
|
||||
|
||||
with patch(
|
||||
"fastmcp.server.dependencies.get_access_token",
|
||||
return_value=mock_access_token,
|
||||
):
|
||||
result = _resolve_user_from_jwt_context(app)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_namespaced_claim_without_api_key_client_id_is_ignored(
|
||||
app: SupersetApp,
|
||||
) -> None:
|
||||
"""An external IdP JWT that includes the namespaced API_KEY_PASSTHROUGH_CLAIM
|
||||
but does NOT have client_id=='api_key' must NOT divert into the API-key path.
|
||||
The client_id guard prevents misclassification / DoS for affected JWT users."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
rogue_token = MagicMock()
|
||||
rogue_token.token = "eyJhbGciOiJSUzI1NiJ9.idp_jwt_with_rogue_claim" # noqa: S105
|
||||
rogue_token.client_id = "some-idp-client"
|
||||
rogue_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True, "sub": "alice"}
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(rogue_token), _patch_load_user_not_found():
|
||||
# JWT path resolves "alice" from claims and raises ValueError
|
||||
# because no such user exists.
|
||||
# validate_api_key must NOT be called — the rogue claim was ignored.
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- Plain JWT with a colliding non-namespaced claim is NOT mistaken for API key --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_unnamespaced_passthrough_claim_does_not_trigger_api_key_path(
|
||||
app: SupersetApp,
|
||||
) -> None:
|
||||
"""A JWT minted by an external IdP that happens to include a custom
|
||||
``_api_key_passthrough`` claim (legacy unnamespaced name) must NOT be
|
||||
treated as an API-key pass-through. Only the namespaced
|
||||
``API_KEY_PASSTHROUGH_CLAIM`` triggers the API-key path."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
rogue_token = MagicMock()
|
||||
rogue_token.token = "eyJhbGciOiJSUzI1NiJ9.rogue_jwt" # noqa: S105
|
||||
rogue_token.claims = {"_api_key_passthrough": True, "sub": "alice"}
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(rogue_token), _patch_load_user_not_found():
|
||||
# JWT path resolves "alice" from claims and raises ValueError.
|
||||
# validate_api_key must NOT be called — the rogue claim was ignored.
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- SecurityManager method name regression test --
|
||||
|
||||
|
||||
def test_security_manager_has_expected_api_key_methods(app: SupersetApp) -> None:
|
||||
"""Regression test: verify the SecurityManager method name referenced in
|
||||
auth._resolve_user_from_api_key() actually exists on the FAB
|
||||
SecurityManager class. Catches future renames before they silently break
|
||||
API key auth at runtime (see PR #39437)."""
|
||||
with app.app_context():
|
||||
from superset import security_manager
|
||||
def test_security_manager_has_expected_api_key_methods() -> None:
|
||||
"""Regression test: verify the SecurityManager method names referenced in
|
||||
auth._resolve_user_from_api_key() actually exist on the FAB SecurityManager
|
||||
class. This catches future renames before they silently break API key auth
|
||||
at runtime (SC-99414: _extract_api_key_from_request vs
|
||||
extract_api_key_from_request)."""
|
||||
from superset import security_manager
|
||||
|
||||
sm = security_manager
|
||||
assert hasattr(sm, "validate_api_key"), (
|
||||
"FAB SecurityManager is missing 'validate_api_key'. "
|
||||
"auth._resolve_user_from_api_key() references this method by name — "
|
||||
"update auth.py if the FAB API changed."
|
||||
)
|
||||
|
||||
|
||||
def test_security_manager_has_find_user_with_relationships(app: SupersetApp) -> None:
|
||||
"""Regression test: verify SupersetSecurityManager.find_user_with_relationships
|
||||
exists. load_user_with_relationships() in auth.py delegates to it — a rename
|
||||
or removal would silently break MCP user resolution at runtime."""
|
||||
with app.app_context():
|
||||
from superset import security_manager
|
||||
|
||||
assert hasattr(security_manager, "find_user_with_relationships"), (
|
||||
"SupersetSecurityManager is missing 'find_user_with_relationships'. "
|
||||
"auth.load_user_with_relationships() delegates to this method — "
|
||||
"update auth.py if the method was renamed or removed."
|
||||
)
|
||||
sm = security_manager
|
||||
assert hasattr(sm, "extract_api_key_from_request"), (
|
||||
"FAB SecurityManager is missing 'extract_api_key_from_request'. "
|
||||
"auth._resolve_user_from_api_key() references this method by name — "
|
||||
"update auth.py if the FAB API changed."
|
||||
)
|
||||
assert hasattr(sm, "validate_api_key"), (
|
||||
"FAB SecurityManager is missing 'validate_api_key'. "
|
||||
"auth._resolve_user_from_api_key() references this method by name — "
|
||||
"update auth.py if the FAB API changed."
|
||||
)
|
||||
|
||||
@@ -285,7 +285,7 @@ def test_mcp_auth_hook_clears_stale_g_user(app) -> None:
|
||||
# framework's autouse app_context fixture may implicitly provide
|
||||
# a request context in some CI environments.
|
||||
with (
|
||||
patch("superset.mcp_service.auth.has_request_context", return_value=False),
|
||||
patch("flask.has_request_context", return_value=False),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
side_effect=lambda: _assert_cleared_then_return(),
|
||||
@@ -324,7 +324,7 @@ def test_mcp_auth_hook_clears_stale_g_user_async(app) -> None:
|
||||
with app.app_context():
|
||||
g.user = stale_user
|
||||
with (
|
||||
patch("superset.mcp_service.auth.has_request_context", return_value=False),
|
||||
patch("flask.has_request_context", return_value=False),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
side_effect=lambda: _assert_cleared_then_return(),
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for CompositeTokenVerifier."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastmcp.server.auth import AccessToken
|
||||
|
||||
from superset.mcp_service.composite_token_verifier import (
|
||||
API_KEY_PASSTHROUGH_CLAIM,
|
||||
CompositeTokenVerifier,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_verifier() -> MagicMock:
|
||||
verifier = MagicMock()
|
||||
verifier.required_scopes = []
|
||||
verifier.verify_token = AsyncMock()
|
||||
return verifier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composite_verifier(mock_jwt_verifier: MagicMock) -> CompositeTokenVerifier:
|
||||
return CompositeTokenVerifier(
|
||||
jwt_verifier=mock_jwt_verifier,
|
||||
api_key_prefixes=["sst_", "pat_"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_token_returns_passthrough(
|
||||
composite_verifier: CompositeTokenVerifier,
|
||||
) -> None:
|
||||
"""Tokens matching an API key prefix return a pass-through AccessToken."""
|
||||
api_key = "sst_abc123secret" # noqa: S105
|
||||
result = await composite_verifier.verify_token(api_key)
|
||||
|
||||
assert result is not None
|
||||
assert result.token == api_key
|
||||
assert result.client_id == "api_key"
|
||||
assert result.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_prefix_matches(
|
||||
composite_verifier: CompositeTokenVerifier,
|
||||
) -> None:
|
||||
"""All configured prefixes are checked, not just the first."""
|
||||
result = await composite_verifier.verify_token("pat_mytoken")
|
||||
|
||||
assert result is not None
|
||||
assert result.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_token_delegates_to_wrapped_verifier(
|
||||
composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
|
||||
) -> None:
|
||||
"""Non-API-key tokens are delegated to the wrapped JWT verifier."""
|
||||
jwt_token = "eyJhbGciOiJSUzI1NiJ9.jwt_payload" # noqa: S105
|
||||
jwt_result = AccessToken(
|
||||
token=jwt_token,
|
||||
client_id="oauth_client",
|
||||
scopes=["read"],
|
||||
claims={"sub": "user1"},
|
||||
)
|
||||
mock_jwt_verifier.verify_token.return_value = jwt_result
|
||||
|
||||
result = await composite_verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
|
||||
|
||||
assert result is jwt_result
|
||||
mock_jwt_verifier.verify_token.assert_awaited_once_with(
|
||||
"eyJhbGciOiJSUzI1NiJ9.jwt_payload"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_jwt_returns_none(
|
||||
composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
|
||||
) -> None:
|
||||
"""When the JWT verifier rejects a token, None is returned."""
|
||||
mock_jwt_verifier.verify_token.return_value = None
|
||||
|
||||
result = await composite_verifier.verify_token("not_a_valid_token")
|
||||
|
||||
assert result is None
|
||||
mock_jwt_verifier.verify_token.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_does_not_call_jwt_verifier(
|
||||
composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
|
||||
) -> None:
|
||||
"""API key tokens bypass the JWT verifier entirely."""
|
||||
await composite_verifier.verify_token("sst_test_key")
|
||||
|
||||
mock_jwt_verifier.verify_token.assert_not_awaited()
|
||||
|
||||
|
||||
# -- API-key-only mode (no JWT verifier configured) --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_only_mode_accepts_api_keys() -> None:
|
||||
"""When jwt_verifier is None, API key tokens are still passed through."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["sst_"])
|
||||
|
||||
result = await verifier.verify_token("sst_abc123")
|
||||
|
||||
assert result is not None
|
||||
assert result.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_only_mode_rejects_non_api_key_tokens() -> None:
|
||||
"""When jwt_verifier is None, non-API-key Bearer tokens are rejected at
|
||||
the transport instead of being silently accepted."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["sst_"])
|
||||
|
||||
result = await verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_prefix_is_filtered_out() -> None:
|
||||
"""An empty-string prefix would match every Bearer token (DoS vector).
|
||||
It must be silently dropped and never stored in _api_key_prefixes."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=[""])
|
||||
|
||||
assert "" not in verifier._api_key_prefixes
|
||||
# A plain JWT must NOT be misidentified as an API key.
|
||||
result = await verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_prefix_is_filtered_out() -> None:
|
||||
"""A whitespace-only prefix is also invalid and must be dropped."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=[" "])
|
||||
|
||||
assert " " not in verifier._api_key_prefixes
|
||||
result = await verifier.verify_token(" starts_with_spaces")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_string_prefix_is_filtered_out() -> None:
|
||||
"""Non-string entries (e.g. None, int) must not be stored and must not
|
||||
cause a TypeError during verify_token."""
|
||||
verifier = CompositeTokenVerifier(
|
||||
jwt_verifier=None,
|
||||
api_key_prefixes=[None, 42, "sst_"], # type: ignore[list-item]
|
||||
)
|
||||
|
||||
assert None not in verifier._api_key_prefixes
|
||||
assert 42 not in verifier._api_key_prefixes
|
||||
assert verifier._api_key_prefixes == ("sst_",)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_prefixes_emit_warning(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Invalid prefix entries must trigger a logger.warning so operators can
|
||||
detect misconfiguration in FAB_API_KEY_PREFIXES."""
|
||||
import logging
|
||||
|
||||
logger_name = "superset.mcp_service.composite_token_verifier"
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["", "sst_"])
|
||||
|
||||
assert any("invalid" in record.message.lower() for record in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_invalid_prefixes_accepts_no_api_keys() -> None:
|
||||
"""When all prefixes are invalid and filtered out, no token should match
|
||||
the API key path."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["", " "])
|
||||
|
||||
assert verifier._api_key_prefixes == ()
|
||||
result = await verifier.verify_token("sst_abc123")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_passthrough_propagates_required_scopes() -> None:
|
||||
"""The pass-through AccessToken must carry the verifier's required_scopes
|
||||
so FastMCP's transport-level ``RequireAuthMiddleware`` does not 403 the
|
||||
request before ``_resolve_user_from_api_key`` runs."""
|
||||
jwt_verifier = MagicMock()
|
||||
jwt_verifier.required_scopes = ["read", "write"]
|
||||
jwt_verifier.verify_token = AsyncMock()
|
||||
|
||||
verifier = CompositeTokenVerifier(
|
||||
jwt_verifier=jwt_verifier, api_key_prefixes=["sst_"]
|
||||
)
|
||||
|
||||
result = await verifier.verify_token("sst_abc123")
|
||||
|
||||
assert result is not None
|
||||
assert result.scopes == ["read", "write"]
|
||||
@@ -91,17 +91,6 @@ def test_is_gamma_pvm_excludes_export_image(app_context: None) -> None:
|
||||
assert sm._is_gamma_pvm(pvm) is False
|
||||
|
||||
|
||||
def test_api_key_view_menu_is_admin_only() -> None:
|
||||
"""Regression test: 'ApiKey' must be in ADMIN_ONLY_VIEW_MENUS.
|
||||
|
||||
FAB registers an ApiKeyApi blueprint when FAB_API_KEY_ENABLED=True.
|
||||
Without this guard any Gamma user could reach the API key management
|
||||
endpoints. A rename or removal of the entry would silently re-open
|
||||
that access hole.
|
||||
"""
|
||||
assert "ApiKey" in SupersetSecurityManager.ADMIN_ONLY_VIEW_MENUS
|
||||
|
||||
|
||||
def test_is_gamma_pvm_allows_copy_clipboard(app_context: None) -> None:
|
||||
"""Verify _is_gamma_pvm returns True for can_copy_clipboard."""
|
||||
from superset.extensions import appbuilder
|
||||
|
||||
Reference in New Issue
Block a user