mirror of
https://github.com/apache/superset.git
synced 2026-06-03 22:59:21 +00:00
fix(mcp): handle OAuth-authenticated databases in execute_sql (#39166)
This commit is contained in:
@@ -43,7 +43,7 @@ import requests
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from deprecation import deprecated
|
||||
from flask import current_app as app, g, url_for
|
||||
from flask import current_app as app, g
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from marshmallow import fields, Schema
|
||||
@@ -88,6 +88,7 @@ from superset.utils.oauth2 import (
|
||||
encode_oauth2_state,
|
||||
generate_code_challenge,
|
||||
generate_code_verifier,
|
||||
get_oauth2_redirect_uri,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -654,10 +655,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
tab_id = str(uuid4())
|
||||
default_redirect_uri = app.config.get(
|
||||
"DATABASE_OAUTH2_REDIRECT_URI",
|
||||
url_for("DatabaseRestApi.oauth2", _external=True),
|
||||
)
|
||||
default_redirect_uri = get_oauth2_redirect_uri()
|
||||
|
||||
# Generate PKCE code verifier (RFC 7636)
|
||||
code_verifier = generate_code_verifier()
|
||||
@@ -720,10 +718,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
return None
|
||||
|
||||
db_engine_spec_config = oauth2_config[cls.engine_name]
|
||||
redirect_uri = app.config.get(
|
||||
"DATABASE_OAUTH2_REDIRECT_URI",
|
||||
url_for("DatabaseRestApi.oauth2", _external=True),
|
||||
)
|
||||
redirect_uri = get_oauth2_redirect_uri()
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": db_engine_spec_config["id"],
|
||||
|
||||
@@ -29,6 +29,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.auth import has_dataset_access
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
@@ -46,6 +47,10 @@ from superset.mcp_service.chart.schemas import (
|
||||
parse_chart_config,
|
||||
PerformanceMetadata,
|
||||
)
|
||||
from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.utils import json
|
||||
|
||||
@@ -826,6 +831,37 @@ async def generate_chart( # noqa: C901
|
||||
)
|
||||
return GenerateChartResponse.model_validate(result)
|
||||
|
||||
except OAuth2RedirectError as ex:
|
||||
await ctx.error(
|
||||
"Chart generation requires OAuth authentication: dataset_id=%s"
|
||||
% request.dataset_id
|
||||
)
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"success": False,
|
||||
"error": {
|
||||
"error_type": "OAUTH2_REDIRECT",
|
||||
"message": build_oauth2_redirect_message(ex),
|
||||
"details": "OAuth2 authentication required",
|
||||
},
|
||||
}
|
||||
)
|
||||
except OAuth2Error:
|
||||
await ctx.error(
|
||||
"OAuth2 configuration error: dataset_id=%s" % request.dataset_id
|
||||
)
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"success": False,
|
||||
"error": {
|
||||
"error_type": "OAUTH2_REDIRECT_ERROR",
|
||||
"message": OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
"details": "OAuth2 configuration or provider error",
|
||||
},
|
||||
}
|
||||
)
|
||||
except (CommandException, SQLAlchemyError, KeyError, ValueError) as e:
|
||||
from superset import db
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.commands.explore.form_data.parameters import CommandParameters
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
@@ -43,6 +43,10 @@ from superset.mcp_service.chart.schemas import (
|
||||
PerformanceMetadata,
|
||||
)
|
||||
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
|
||||
from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.utils.core import merge_extra_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -797,6 +801,23 @@ async def get_chart_data( # noqa: C901
|
||||
error_type="DataError",
|
||||
)
|
||||
|
||||
except OAuth2RedirectError as ex:
|
||||
await ctx.error(
|
||||
"Chart data requires OAuth authentication: identifier=%s"
|
||||
% request.identifier
|
||||
)
|
||||
return ChartError(
|
||||
error=build_oauth2_redirect_message(ex),
|
||||
error_type="OAUTH2_REDIRECT",
|
||||
)
|
||||
except OAuth2Error:
|
||||
await ctx.error(
|
||||
"OAuth2 configuration error: identifier=%s" % request.identifier
|
||||
)
|
||||
return ChartError(
|
||||
error=OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
error_type="OAUTH2_REDIRECT_ERROR",
|
||||
)
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Chart data retrieval failed: identifier=%s, error=%s, error_type=%s"
|
||||
|
||||
@@ -26,7 +26,7 @@ from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
@@ -41,6 +41,10 @@ from superset.mcp_service.chart.schemas import (
|
||||
URLPreview,
|
||||
VegaLitePreview,
|
||||
)
|
||||
from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -2247,6 +2251,23 @@ async def get_chart_preview(
|
||||
)
|
||||
|
||||
return result
|
||||
except OAuth2RedirectError as ex:
|
||||
await ctx.error(
|
||||
"Chart preview requires OAuth authentication: identifier=%s"
|
||||
% request.identifier
|
||||
)
|
||||
return ChartError(
|
||||
error=build_oauth2_redirect_message(ex),
|
||||
error_type="OAUTH2_REDIRECT",
|
||||
)
|
||||
except OAuth2Error:
|
||||
await ctx.error(
|
||||
"OAuth2 configuration error: identifier=%s" % request.identifier
|
||||
)
|
||||
return ChartError(
|
||||
error=OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
error_type="OAUTH2_REDIRECT_ERROR",
|
||||
)
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Chart preview generation failed: identifier=%s, error=%s, error_type=%s"
|
||||
|
||||
@@ -28,6 +28,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
@@ -42,6 +43,10 @@ from superset.mcp_service.chart.schemas import (
|
||||
PerformanceMetadata,
|
||||
UpdateChartRequest,
|
||||
)
|
||||
from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.utils import json
|
||||
|
||||
@@ -118,7 +123,7 @@ def _build_update_payload(
|
||||
destructiveHint=True,
|
||||
),
|
||||
)
|
||||
async def update_chart(
|
||||
async def update_chart( # noqa: C901
|
||||
request: UpdateChartRequest, ctx: Context
|
||||
) -> GenerateChartResponse:
|
||||
"""Update existing chart with new configuration.
|
||||
@@ -313,6 +318,35 @@ async def update_chart(
|
||||
}
|
||||
return GenerateChartResponse.model_validate(result)
|
||||
|
||||
except OAuth2RedirectError as ex:
|
||||
await ctx.error(
|
||||
"Chart update requires OAuth authentication: identifier=%s"
|
||||
% request.identifier
|
||||
)
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"success": False,
|
||||
"error": {
|
||||
"error_type": "OAUTH2_REDIRECT",
|
||||
"message": build_oauth2_redirect_message(ex),
|
||||
"details": "OAuth2 authentication required",
|
||||
},
|
||||
}
|
||||
)
|
||||
except OAuth2Error:
|
||||
await ctx.error("OAuth2 configuration error: chart_id=%s" % request.identifier)
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"success": False,
|
||||
"error": {
|
||||
"error_type": "OAUTH2_REDIRECT_ERROR",
|
||||
"message": OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
"details": "OAuth2 configuration or provider error",
|
||||
},
|
||||
}
|
||||
)
|
||||
except (
|
||||
CommandException,
|
||||
SQLAlchemyError,
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Any, Dict
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
@@ -40,6 +41,10 @@ from superset.mcp_service.chart.schemas import (
|
||||
PerformanceMetadata,
|
||||
UpdateChartPreviewRequest,
|
||||
)
|
||||
from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.utils import json as utils_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -175,6 +180,25 @@ def update_chart_preview(
|
||||
}
|
||||
return result
|
||||
|
||||
except OAuth2RedirectError as ex:
|
||||
logger.warning(
|
||||
"Chart preview update requires OAuth authentication: form_data_key=%s",
|
||||
request.form_data_key,
|
||||
)
|
||||
return {
|
||||
"chart": None,
|
||||
"error": build_oauth2_redirect_message(ex),
|
||||
"success": False,
|
||||
}
|
||||
except OAuth2Error:
|
||||
logger.warning(
|
||||
"OAuth2 configuration error: form_data_key=%s", request.form_data_key
|
||||
)
|
||||
return {
|
||||
"chart": None,
|
||||
"error": OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
"success": False,
|
||||
}
|
||||
except Exception as e:
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
return {
|
||||
|
||||
@@ -37,6 +37,7 @@ from superset_core.queries.types import (
|
||||
)
|
||||
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.sql_lab.schemas import (
|
||||
ColumnInfo,
|
||||
@@ -45,6 +46,10 @@ from superset.mcp_service.sql_lab.schemas import (
|
||||
StatementData,
|
||||
StatementInfo,
|
||||
)
|
||||
from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -147,6 +152,25 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
|
||||
|
||||
return response
|
||||
|
||||
except OAuth2RedirectError as ex:
|
||||
await ctx.error(
|
||||
"Database requires OAuth authentication: database_id=%s"
|
||||
% request.database_id
|
||||
)
|
||||
return ExecuteSqlResponse(
|
||||
success=False,
|
||||
error=build_oauth2_redirect_message(ex),
|
||||
error_type=SupersetErrorType.OAUTH2_REDIRECT.value,
|
||||
)
|
||||
except OAuth2Error:
|
||||
await ctx.error(
|
||||
"OAuth2 configuration/flow error: database_id=%s" % request.database_id
|
||||
)
|
||||
return ExecuteSqlResponse(
|
||||
success=False,
|
||||
error=OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
error_type=SupersetErrorType.OAUTH2_REDIRECT_ERROR.value,
|
||||
)
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"SQL execution failed: error=%s, database_id=%s"
|
||||
|
||||
47
superset/mcp_service/utils/oauth2_utils.py
Normal file
47
superset/mcp_service/utils/oauth2_utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Utilities for handling OAuth2 errors in MCP tools.
|
||||
"""
|
||||
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
|
||||
|
||||
def build_oauth2_redirect_message(ex: OAuth2RedirectError) -> str:
|
||||
"""
|
||||
Build a user-facing message for OAuth2RedirectError.
|
||||
|
||||
Extracts the authorization URL from the exception and includes it
|
||||
so the MCP client can present it to the user for authentication.
|
||||
"""
|
||||
# extra is always set by OAuth2RedirectError.__init__
|
||||
assert ex.error.extra is not None # noqa: S101
|
||||
oauth_url = ex.error.extra["url"]
|
||||
return (
|
||||
"This database uses OAuth for authentication. "
|
||||
"Please open the following URL in your browser to "
|
||||
"authorize access, then retry this request:\n\n"
|
||||
f"{oauth_url}"
|
||||
)
|
||||
|
||||
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE = (
|
||||
"OAuth authentication failed due to a configuration "
|
||||
"or provider error. "
|
||||
"Please contact your Superset administrator."
|
||||
)
|
||||
@@ -69,6 +69,8 @@ from flask import current_app as app, g, has_app_context
|
||||
from superset import db
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
OAuth2Error,
|
||||
OAuth2RedirectError,
|
||||
SupersetSecurityException,
|
||||
SupersetTimeoutException,
|
||||
)
|
||||
@@ -318,6 +320,10 @@ class SQLExecutor:
|
||||
)
|
||||
except SupersetSecurityException as ex:
|
||||
return self._create_error_result(QueryStatus.FAILED, str(ex), start_time)
|
||||
except (OAuth2RedirectError, OAuth2Error):
|
||||
# Let OAuth2 exceptions propagate so callers (MCP, API) can
|
||||
# handle them with context-appropriate responses.
|
||||
raise
|
||||
except Exception as ex:
|
||||
error_msg = self.database.db_engine_spec.extract_error_message(ex)
|
||||
return self._create_error_result(QueryStatus.FAILED, error_msg, start_time)
|
||||
|
||||
@@ -29,10 +29,11 @@ import backoff
|
||||
import jwt
|
||||
from flask import current_app as app, url_for
|
||||
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
||||
from werkzeug.routing import BuildError
|
||||
|
||||
from superset import db
|
||||
from superset.distributed_lock import DistributedLock
|
||||
from superset.exceptions import AcquireDistributedLockFailedException
|
||||
from superset.exceptions import AcquireDistributedLockFailedException, OAuth2Error
|
||||
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -245,16 +246,34 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State:
|
||||
return state
|
||||
|
||||
|
||||
def get_oauth2_redirect_uri() -> str:
|
||||
"""
|
||||
Return the OAuth2 redirect URI.
|
||||
|
||||
Tries the explicit config first, then falls back to url_for().
|
||||
If url_for() fails (e.g. in headless/MCP contexts where the
|
||||
DatabaseRestApi blueprint may not be registered), raises
|
||||
OAuth2Error so callers don't silently proceed with an invalid URI.
|
||||
"""
|
||||
if configured := app.config.get("DATABASE_OAUTH2_REDIRECT_URI"):
|
||||
return configured
|
||||
|
||||
try:
|
||||
return url_for("DatabaseRestApi.oauth2", _external=True)
|
||||
except (BuildError, RuntimeError):
|
||||
raise OAuth2Error(
|
||||
"Unable to determine the OAuth2 redirect URI. "
|
||||
"Set DATABASE_OAUTH2_REDIRECT_URI in the configuration."
|
||||
) from None
|
||||
|
||||
|
||||
class OAuth2ClientConfigSchema(Schema):
|
||||
id = fields.String(required=True)
|
||||
secret = fields.String(required=True)
|
||||
scope = fields.String(required=True)
|
||||
redirect_uri = fields.String(
|
||||
required=False,
|
||||
load_default=lambda: app.config.get(
|
||||
"DATABASE_OAUTH2_REDIRECT_URI",
|
||||
url_for("DatabaseRestApi.oauth2", _external=True),
|
||||
),
|
||||
load_default=get_oauth2_redirect_uri,
|
||||
)
|
||||
authorization_request_uri = fields.String(required=True)
|
||||
token_request_uri = fields.String(required=True)
|
||||
|
||||
@@ -1196,7 +1196,7 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"superset.db_engine_specs.base.url_for",
|
||||
"superset.utils.oauth2.url_for",
|
||||
return_value=fallback_uri,
|
||||
)
|
||||
mocker.patch("superset.daos.key_value.KeyValueDAO")
|
||||
|
||||
@@ -1093,3 +1093,74 @@ class TestSanitizeRowValues:
|
||||
assert rows[0]["name"] == "test"
|
||||
assert rows[0]["price"] == 9.99
|
||||
assert rows[0]["blob"] == "000102ff"
|
||||
|
||||
|
||||
class TestExecuteSqlOAuth2:
|
||||
"""Tests for OAuth2 error handling in execute_sql."""
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_oauth2_redirect_error(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that OAuth2RedirectError is caught and returns a clear message."""
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.side_effect = OAuth2RedirectError(
|
||||
url="https://oauth.example.com/authorize",
|
||||
tab_id="test-tab-id",
|
||||
redirect_uri="https://superset.example.com/callback",
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT 1",
|
||||
"limit": 100,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["success"] is False
|
||||
assert "OAuth" in data["error"]
|
||||
assert "https://oauth.example.com/authorize" in data["error"]
|
||||
assert data["error_type"] == "OAUTH2_REDIRECT"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_oauth2_error(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that OAuth2Error is caught and returns a clear message."""
|
||||
from superset.exceptions import OAuth2Error
|
||||
|
||||
mock_database = _mock_database()
|
||||
mock_database.execute.side_effect = OAuth2Error(
|
||||
"Unable to determine the OAuth2 redirect URI."
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT 1",
|
||||
"limit": 100,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["success"] is False
|
||||
assert "configuration" in data["error"]
|
||||
assert data["error_type"] == "OAUTH2_REDIRECT_ERROR"
|
||||
|
||||
@@ -644,6 +644,78 @@ def test_execute_error(
|
||||
assert "Database error" in result.error_message
|
||||
|
||||
|
||||
def test_execute_oauth2_redirect_error_propagates(
|
||||
mocker: MockerFixture, database: Database, app_context: None
|
||||
) -> None:
|
||||
"""Test that OAuth2RedirectError propagates instead of being swallowed."""
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mocker.patch.object(database, "get_raw_connection", return_value=mock_conn)
|
||||
mocker.patch.object(
|
||||
database, "mutate_sql_based_on_config", side_effect=lambda sql, **kw: sql
|
||||
)
|
||||
mocker.patch.object(
|
||||
database.db_engine_spec,
|
||||
"execute",
|
||||
side_effect=OAuth2RedirectError(
|
||||
url="https://oauth.example.com/authorize",
|
||||
tab_id="test-tab",
|
||||
redirect_uri="https://superset.example.com/callback",
|
||||
),
|
||||
)
|
||||
mocker.patch.dict(
|
||||
current_app.config,
|
||||
{
|
||||
"SQL_QUERY_MUTATOR": None,
|
||||
"SQLLAB_TIMEOUT": 30,
|
||||
"SQL_MAX_ROW": None,
|
||||
"QUERY_LOGGER": None,
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(OAuth2RedirectError):
|
||||
database.execute("SELECT 1")
|
||||
|
||||
|
||||
def test_execute_oauth2_error_propagates(
|
||||
mocker: MockerFixture, database: Database, app_context: None
|
||||
) -> None:
|
||||
"""Test that OAuth2Error propagates instead of being swallowed."""
|
||||
from superset.exceptions import OAuth2Error
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_conn.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_conn.__exit__ = MagicMock(return_value=False)
|
||||
mocker.patch.object(database, "get_raw_connection", return_value=mock_conn)
|
||||
mocker.patch.object(
|
||||
database, "mutate_sql_based_on_config", side_effect=lambda sql, **kw: sql
|
||||
)
|
||||
mocker.patch.object(
|
||||
database.db_engine_spec,
|
||||
"execute",
|
||||
side_effect=OAuth2Error("No configuration found for OAuth2"),
|
||||
)
|
||||
mocker.patch.dict(
|
||||
current_app.config,
|
||||
{
|
||||
"SQL_QUERY_MUTATOR": None,
|
||||
"SQLLAB_TIMEOUT": 30,
|
||||
"SQL_MAX_ROW": None,
|
||||
"QUERY_LOGGER": None,
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(OAuth2Error):
|
||||
database.execute("SELECT 1")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Async Execution Tests
|
||||
# =============================================================================
|
||||
|
||||
@@ -33,6 +33,7 @@ from superset.utils.oauth2 import (
|
||||
generate_code_challenge,
|
||||
generate_code_verifier,
|
||||
get_oauth2_access_token,
|
||||
get_oauth2_redirect_uri,
|
||||
refresh_oauth2_token,
|
||||
)
|
||||
|
||||
@@ -335,3 +336,66 @@ def test_encode_decode_oauth2_state(
|
||||
assert "code_verifier" not in decoded
|
||||
assert decoded["database_id"] == 1
|
||||
assert decoded["user_id"] == 2
|
||||
|
||||
|
||||
def test_get_oauth2_redirect_uri_from_config(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that get_oauth2_redirect_uri returns the configured value when set.
|
||||
"""
|
||||
custom_uri = "https://proxy.example.com/oauth2/"
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"DATABASE_OAUTH2_REDIRECT_URI": custom_uri},
|
||||
)
|
||||
assert get_oauth2_redirect_uri() == custom_uri
|
||||
|
||||
|
||||
def test_get_oauth2_redirect_uri_falls_back_to_url_for(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that get_oauth2_redirect_uri falls back to url_for when config is not set.
|
||||
"""
|
||||
fallback_uri = "http://localhost:8088/api/v1/database/oauth2/"
|
||||
mocker.patch("flask.current_app.config", {})
|
||||
mocker.patch(
|
||||
"superset.utils.oauth2.url_for",
|
||||
return_value=fallback_uri,
|
||||
)
|
||||
assert get_oauth2_redirect_uri() == fallback_uri
|
||||
|
||||
|
||||
def test_get_oauth2_redirect_uri_raises_on_build_error(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_oauth2_redirect_uri raises OAuth2Error when url_for raises
|
||||
BuildError (e.g. in headless/MCP contexts).
|
||||
"""
|
||||
from werkzeug.routing import BuildError
|
||||
|
||||
from superset.exceptions import OAuth2Error
|
||||
|
||||
mocker.patch("flask.current_app.config", {})
|
||||
mocker.patch(
|
||||
"superset.utils.oauth2.url_for",
|
||||
side_effect=BuildError("DatabaseRestApi.oauth2", {}, ("GET",)),
|
||||
)
|
||||
with pytest.raises(OAuth2Error):
|
||||
get_oauth2_redirect_uri()
|
||||
|
||||
|
||||
def test_get_oauth2_redirect_uri_raises_on_runtime_error(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_oauth2_redirect_uri raises OAuth2Error when url_for raises
|
||||
RuntimeError (e.g. no request context and no SERVER_NAME).
|
||||
"""
|
||||
from superset.exceptions import OAuth2Error
|
||||
|
||||
mocker.patch("flask.current_app.config", {})
|
||||
mocker.patch(
|
||||
"superset.utils.oauth2.url_for",
|
||||
side_effect=RuntimeError("Unable to build URL outside of request context"),
|
||||
)
|
||||
with pytest.raises(OAuth2Error):
|
||||
get_oauth2_redirect_uri()
|
||||
|
||||
Reference in New Issue
Block a user