fix(mcp): handle OAuth-authenticated databases in execute_sql (#39166)

This commit is contained in:
Amin Ghadersohi
2026-04-09 15:47:00 -04:00
committed by GitHub
parent 5815665cc6
commit 68067d7f44
14 changed files with 452 additions and 18 deletions

View File

@@ -43,7 +43,7 @@ import requests
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask import current_app as app, g, url_for
from flask import current_app as app, g
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
@@ -88,6 +88,7 @@ from superset.utils.oauth2 import (
encode_oauth2_state,
generate_code_challenge,
generate_code_verifier,
get_oauth2_redirect_uri,
)
if TYPE_CHECKING:
@@ -654,10 +655,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
from superset.daos.key_value import KeyValueDAO
tab_id = str(uuid4())
default_redirect_uri = app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
)
default_redirect_uri = get_oauth2_redirect_uri()
# Generate PKCE code verifier (RFC 7636)
code_verifier = generate_code_verifier()
@@ -720,10 +718,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
db_engine_spec_config = oauth2_config[cls.engine_name]
redirect_uri = app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
)
redirect_uri = get_oauth2_redirect_uri()
config: OAuth2ClientConfig = {
"id": db_engine_spec_config["id"],

View File

@@ -29,6 +29,7 @@ from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions import event_logger
from superset.mcp_service.auth import has_dataset_access
from superset.mcp_service.chart.chart_utils import (
@@ -46,6 +47,10 @@ from superset.mcp_service.chart.schemas import (
parse_chart_config,
PerformanceMetadata,
)
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.utils import json
@@ -826,6 +831,37 @@ async def generate_chart( # noqa: C901
)
return GenerateChartResponse.model_validate(result)
except OAuth2RedirectError as ex:
await ctx.error(
"Chart generation requires OAuth authentication: dataset_id=%s"
% request.dataset_id
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"success": False,
"error": {
"error_type": "OAUTH2_REDIRECT",
"message": build_oauth2_redirect_message(ex),
"details": "OAuth2 authentication required",
},
}
)
except OAuth2Error:
await ctx.error(
"OAuth2 configuration error: dataset_id=%s" % request.dataset_id
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"success": False,
"error": {
"error_type": "OAUTH2_REDIRECT_ERROR",
"message": OAUTH2_CONFIG_ERROR_MESSAGE,
"details": "OAuth2 configuration or provider error",
},
}
)
except (CommandException, SQLAlchemyError, KeyError, ValueError) as e:
from superset import db

View File

@@ -32,7 +32,7 @@ if TYPE_CHECKING:
from superset.commands.exceptions import CommandException
from superset.commands.explore.form_data.parameters import CommandParameters
from superset.exceptions import SupersetException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
@@ -43,6 +43,10 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
)
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.utils.core import merge_extra_filters
logger = logging.getLogger(__name__)
@@ -797,6 +801,23 @@ async def get_chart_data( # noqa: C901
error_type="DataError",
)
except OAuth2RedirectError as ex:
await ctx.error(
"Chart data requires OAuth authentication: identifier=%s"
% request.identifier
)
return ChartError(
error=build_oauth2_redirect_message(ex),
error_type="OAUTH2_REDIRECT",
)
except OAuth2Error:
await ctx.error(
"OAuth2 configuration error: identifier=%s" % request.identifier
)
return ChartError(
error=OAUTH2_CONFIG_ERROR_MESSAGE,
error_type="OAUTH2_REDIRECT_ERROR",
)
except Exception as e:
await ctx.error(
"Chart data retrieval failed: identifier=%s, error=%s, error_type=%s"

View File

@@ -26,7 +26,7 @@ from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import SupersetException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
@@ -41,6 +41,10 @@ from superset.mcp_service.chart.schemas import (
URLPreview,
VegaLitePreview,
)
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
logger = logging.getLogger(__name__)
@@ -2247,6 +2251,23 @@ async def get_chart_preview(
)
return result
except OAuth2RedirectError as ex:
await ctx.error(
"Chart preview requires OAuth authentication: identifier=%s"
% request.identifier
)
return ChartError(
error=build_oauth2_redirect_message(ex),
error_type="OAUTH2_REDIRECT",
)
except OAuth2Error:
await ctx.error(
"OAuth2 configuration error: identifier=%s" % request.identifier
)
return ChartError(
error=OAUTH2_CONFIG_ERROR_MESSAGE,
error_type="OAUTH2_REDIRECT_ERROR",
)
except Exception as e:
await ctx.error(
"Chart preview generation failed: identifier=%s, error=%s, error_type=%s"

View File

@@ -28,6 +28,7 @@ from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_utils import (
analyze_chart_capabilities,
@@ -42,6 +43,10 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
UpdateChartRequest,
)
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.utils import json
@@ -118,7 +123,7 @@ def _build_update_payload(
destructiveHint=True,
),
)
async def update_chart(
async def update_chart( # noqa: C901
request: UpdateChartRequest, ctx: Context
) -> GenerateChartResponse:
"""Update existing chart with new configuration.
@@ -313,6 +318,35 @@ async def update_chart(
}
return GenerateChartResponse.model_validate(result)
except OAuth2RedirectError as ex:
await ctx.error(
"Chart update requires OAuth authentication: identifier=%s"
% request.identifier
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"success": False,
"error": {
"error_type": "OAUTH2_REDIRECT",
"message": build_oauth2_redirect_message(ex),
"details": "OAuth2 authentication required",
},
}
)
except OAuth2Error:
await ctx.error("OAuth2 configuration error: chart_id=%s" % request.identifier)
return GenerateChartResponse.model_validate(
{
"chart": None,
"success": False,
"error": {
"error_type": "OAUTH2_REDIRECT_ERROR",
"message": OAUTH2_CONFIG_ERROR_MESSAGE,
"details": "OAuth2 configuration or provider error",
},
}
)
except (
CommandException,
SQLAlchemyError,

View File

@@ -26,6 +26,7 @@ from typing import Any, Dict
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_utils import (
analyze_chart_capabilities,
@@ -40,6 +41,10 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
UpdateChartPreviewRequest,
)
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.utils import json as utils_json
logger = logging.getLogger(__name__)
@@ -175,6 +180,25 @@ def update_chart_preview(
}
return result
except OAuth2RedirectError as ex:
logger.warning(
"Chart preview update requires OAuth authentication: form_data_key=%s",
request.form_data_key,
)
return {
"chart": None,
"error": build_oauth2_redirect_message(ex),
"success": False,
}
except OAuth2Error:
logger.warning(
"OAuth2 configuration error: form_data_key=%s", request.form_data_key
)
return {
"chart": None,
"error": OAUTH2_CONFIG_ERROR_MESSAGE,
"success": False,
}
except Exception as e:
execution_time = int((time.time() - start_time) * 1000)
return {

View File

@@ -37,6 +37,7 @@ from superset_core.queries.types import (
)
from superset.errors import SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions import event_logger
from superset.mcp_service.sql_lab.schemas import (
ColumnInfo,
@@ -45,6 +46,10 @@ from superset.mcp_service.sql_lab.schemas import (
StatementData,
StatementInfo,
)
from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
logger = logging.getLogger(__name__)
@@ -147,6 +152,25 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
return response
except OAuth2RedirectError as ex:
await ctx.error(
"Database requires OAuth authentication: database_id=%s"
% request.database_id
)
return ExecuteSqlResponse(
success=False,
error=build_oauth2_redirect_message(ex),
error_type=SupersetErrorType.OAUTH2_REDIRECT.value,
)
except OAuth2Error:
await ctx.error(
"OAuth2 configuration/flow error: database_id=%s" % request.database_id
)
return ExecuteSqlResponse(
success=False,
error=OAUTH2_CONFIG_ERROR_MESSAGE,
error_type=SupersetErrorType.OAUTH2_REDIRECT_ERROR.value,
)
except Exception as e:
await ctx.error(
"SQL execution failed: error=%s, database_id=%s"

View File

@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Utilities for handling OAuth2 errors in MCP tools.
"""
from superset.exceptions import OAuth2RedirectError
def build_oauth2_redirect_message(ex: OAuth2RedirectError) -> str:
"""
Build a user-facing message for OAuth2RedirectError.
Extracts the authorization URL from the exception and includes it
so the MCP client can present it to the user for authentication.
"""
# extra is always set by OAuth2RedirectError.__init__
assert ex.error.extra is not None # noqa: S101
oauth_url = ex.error.extra["url"]
return (
"This database uses OAuth for authentication. "
"Please open the following URL in your browser to "
"authorize access, then retry this request:\n\n"
f"{oauth_url}"
)
OAUTH2_CONFIG_ERROR_MESSAGE = (
"OAuth authentication failed due to a configuration "
"or provider error. "
"Please contact your Superset administrator."
)

View File

@@ -69,6 +69,8 @@ from flask import current_app as app, g, has_app_context
from superset import db
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
OAuth2Error,
OAuth2RedirectError,
SupersetSecurityException,
SupersetTimeoutException,
)
@@ -318,6 +320,10 @@ class SQLExecutor:
)
except SupersetSecurityException as ex:
return self._create_error_result(QueryStatus.FAILED, str(ex), start_time)
except (OAuth2RedirectError, OAuth2Error):
# Let OAuth2 exceptions propagate so callers (MCP, API) can
# handle them with context-appropriate responses.
raise
except Exception as ex:
error_msg = self.database.db_engine_spec.extract_error_message(ex)
return self._create_error_result(QueryStatus.FAILED, error_msg, start_time)

View File

@@ -29,10 +29,11 @@ import backoff
import jwt
from flask import current_app as app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from werkzeug.routing import BuildError
from superset import db
from superset.distributed_lock import DistributedLock
from superset.exceptions import AcquireDistributedLockFailedException
from superset.exceptions import AcquireDistributedLockFailedException, OAuth2Error
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
if TYPE_CHECKING:
@@ -245,16 +246,34 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State:
return state
def get_oauth2_redirect_uri() -> str:
"""
Return the OAuth2 redirect URI.
Tries the explicit config first, then falls back to url_for().
If url_for() fails (e.g. in headless/MCP contexts where the
DatabaseRestApi blueprint may not be registered), raises
OAuth2Error so callers don't silently proceed with an invalid URI.
"""
if configured := app.config.get("DATABASE_OAUTH2_REDIRECT_URI"):
return configured
try:
return url_for("DatabaseRestApi.oauth2", _external=True)
except (BuildError, RuntimeError):
raise OAuth2Error(
"Unable to determine the OAuth2 redirect URI. "
"Set DATABASE_OAUTH2_REDIRECT_URI in the configuration."
) from None
class OAuth2ClientConfigSchema(Schema):
id = fields.String(required=True)
secret = fields.String(required=True)
scope = fields.String(required=True)
redirect_uri = fields.String(
required=False,
load_default=lambda: app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
),
load_default=get_oauth2_redirect_uri,
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)

View File

@@ -1196,7 +1196,7 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
},
)
mocker.patch(
"superset.db_engine_specs.base.url_for",
"superset.utils.oauth2.url_for",
return_value=fallback_uri,
)
mocker.patch("superset.daos.key_value.KeyValueDAO")

View File

@@ -1093,3 +1093,74 @@ class TestSanitizeRowValues:
assert rows[0]["name"] == "test"
assert rows[0]["price"] == 9.99
assert rows[0]["blob"] == "000102ff"
class TestExecuteSqlOAuth2:
"""Tests for OAuth2 error handling in execute_sql."""
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_oauth2_redirect_error(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that OAuth2RedirectError is caught and returns a clear message."""
from superset.exceptions import OAuth2RedirectError
mock_database = _mock_database()
mock_database.execute.side_effect = OAuth2RedirectError(
url="https://oauth.example.com/authorize",
tab_id="test-tab-id",
redirect_uri="https://superset.example.com/callback",
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
data = result.structured_content
assert data["success"] is False
assert "OAuth" in data["error"]
assert "https://oauth.example.com/authorize" in data["error"]
assert data["error_type"] == "OAUTH2_REDIRECT"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_oauth2_error(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that OAuth2Error is caught and returns a clear message."""
from superset.exceptions import OAuth2Error
mock_database = _mock_database()
mock_database.execute.side_effect = OAuth2Error(
"Unable to determine the OAuth2 redirect URI."
)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
data = result.structured_content
assert data["success"] is False
assert "configuration" in data["error"]
assert data["error_type"] == "OAUTH2_REDIRECT_ERROR"

View File

@@ -644,6 +644,78 @@ def test_execute_error(
assert "Database error" in result.error_message
def test_execute_oauth2_redirect_error_propagates(
mocker: MockerFixture, database: Database, app_context: None
) -> None:
"""Test that OAuth2RedirectError propagates instead of being swallowed."""
from superset.exceptions import OAuth2RedirectError
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mocker.patch.object(database, "get_raw_connection", return_value=mock_conn)
mocker.patch.object(
database, "mutate_sql_based_on_config", side_effect=lambda sql, **kw: sql
)
mocker.patch.object(
database.db_engine_spec,
"execute",
side_effect=OAuth2RedirectError(
url="https://oauth.example.com/authorize",
tab_id="test-tab",
redirect_uri="https://superset.example.com/callback",
),
)
mocker.patch.dict(
current_app.config,
{
"SQL_QUERY_MUTATOR": None,
"SQLLAB_TIMEOUT": 30,
"SQL_MAX_ROW": None,
"QUERY_LOGGER": None,
},
)
with pytest.raises(OAuth2RedirectError):
database.execute("SELECT 1")
def test_execute_oauth2_error_propagates(
mocker: MockerFixture, database: Database, app_context: None
) -> None:
"""Test that OAuth2Error propagates instead of being swallowed."""
from superset.exceptions import OAuth2Error
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
mock_conn.__exit__ = MagicMock(return_value=False)
mocker.patch.object(database, "get_raw_connection", return_value=mock_conn)
mocker.patch.object(
database, "mutate_sql_based_on_config", side_effect=lambda sql, **kw: sql
)
mocker.patch.object(
database.db_engine_spec,
"execute",
side_effect=OAuth2Error("No configuration found for OAuth2"),
)
mocker.patch.dict(
current_app.config,
{
"SQL_QUERY_MUTATOR": None,
"SQLLAB_TIMEOUT": 30,
"SQL_MAX_ROW": None,
"QUERY_LOGGER": None,
},
)
with pytest.raises(OAuth2Error):
database.execute("SELECT 1")
# =============================================================================
# Async Execution Tests
# =============================================================================

View File

@@ -33,6 +33,7 @@ from superset.utils.oauth2 import (
generate_code_challenge,
generate_code_verifier,
get_oauth2_access_token,
get_oauth2_redirect_uri,
refresh_oauth2_token,
)
@@ -335,3 +336,66 @@ def test_encode_decode_oauth2_state(
assert "code_verifier" not in decoded
assert decoded["database_id"] == 1
assert decoded["user_id"] == 2
def test_get_oauth2_redirect_uri_from_config(mocker: MockerFixture) -> None:
"""
Test that get_oauth2_redirect_uri returns the configured value when set.
"""
custom_uri = "https://proxy.example.com/oauth2/"
mocker.patch(
"flask.current_app.config",
{"DATABASE_OAUTH2_REDIRECT_URI": custom_uri},
)
assert get_oauth2_redirect_uri() == custom_uri
def test_get_oauth2_redirect_uri_falls_back_to_url_for(mocker: MockerFixture) -> None:
"""
Test that get_oauth2_redirect_uri falls back to url_for when config is not set.
"""
fallback_uri = "http://localhost:8088/api/v1/database/oauth2/"
mocker.patch("flask.current_app.config", {})
mocker.patch(
"superset.utils.oauth2.url_for",
return_value=fallback_uri,
)
assert get_oauth2_redirect_uri() == fallback_uri
def test_get_oauth2_redirect_uri_raises_on_build_error(
mocker: MockerFixture,
) -> None:
"""
Test that get_oauth2_redirect_uri raises OAuth2Error when url_for raises
BuildError (e.g. in headless/MCP contexts).
"""
from werkzeug.routing import BuildError
from superset.exceptions import OAuth2Error
mocker.patch("flask.current_app.config", {})
mocker.patch(
"superset.utils.oauth2.url_for",
side_effect=BuildError("DatabaseRestApi.oauth2", {}, ("GET",)),
)
with pytest.raises(OAuth2Error):
get_oauth2_redirect_uri()
def test_get_oauth2_redirect_uri_raises_on_runtime_error(
mocker: MockerFixture,
) -> None:
"""
Test that get_oauth2_redirect_uri raises OAuth2Error when url_for raises
RuntimeError (e.g. no request context and no SERVER_NAME).
"""
from superset.exceptions import OAuth2Error
mocker.patch("flask.current_app.config", {})
mocker.patch(
"superset.utils.oauth2.url_for",
side_effect=RuntimeError("Unable to build URL outside of request context"),
)
with pytest.raises(OAuth2Error):
get_oauth2_redirect_uri()