diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 82d26c2bd2b..bd1e2f9c361 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -43,7 +43,7 @@ import requests from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from deprecation import deprecated -from flask import current_app as app, g, url_for +from flask import current_app as app, g from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ from marshmallow import fields, Schema @@ -88,6 +88,7 @@ from superset.utils.oauth2 import ( encode_oauth2_state, generate_code_challenge, generate_code_verifier, + get_oauth2_redirect_uri, ) if TYPE_CHECKING: @@ -654,10 +655,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods from superset.daos.key_value import KeyValueDAO tab_id = str(uuid4()) - default_redirect_uri = app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - url_for("DatabaseRestApi.oauth2", _external=True), - ) + default_redirect_uri = get_oauth2_redirect_uri() # Generate PKCE code verifier (RFC 7636) code_verifier = generate_code_verifier() @@ -720,10 +718,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return None db_engine_spec_config = oauth2_config[cls.engine_name] - redirect_uri = app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - url_for("DatabaseRestApi.oauth2", _external=True), - ) + redirect_uri = get_oauth2_redirect_uri() config: OAuth2ClientConfig = { "id": db_engine_spec_config["id"], diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index be6d2390cc9..494af9b5614 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -29,6 +29,7 @@ from sqlalchemy.exc import SQLAlchemyError from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException +from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.extensions import event_logger from superset.mcp_service.auth import has_dataset_access from superset.mcp_service.chart.chart_utils import ( @@ -46,6 +47,10 @@ from superset.mcp_service.chart.schemas import ( parse_chart_config, PerformanceMetadata, ) +from superset.mcp_service.utils.oauth2_utils import ( + build_oauth2_redirect_message, + OAUTH2_CONFIG_ERROR_MESSAGE, +) from superset.mcp_service.utils.url_utils import get_superset_base_url from superset.utils import json @@ -826,6 +831,37 @@ async def generate_chart( # noqa: C901 ) return GenerateChartResponse.model_validate(result) + except OAuth2RedirectError as ex: + await ctx.error( + "Chart generation requires OAuth authentication: dataset_id=%s" + % request.dataset_id + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "success": False, + "error": { + "error_type": "OAUTH2_REDIRECT", + "message": build_oauth2_redirect_message(ex), + "details": "OAuth2 authentication required", + }, + } + ) + except OAuth2Error: + await ctx.error( + "OAuth2 configuration error: dataset_id=%s" % request.dataset_id + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "success": False, + "error": { + "error_type": "OAUTH2_REDIRECT_ERROR", + "message": OAUTH2_CONFIG_ERROR_MESSAGE, + "details": "OAuth2 configuration or provider error", + }, + } + ) except (CommandException, SQLAlchemyError, KeyError, ValueError) as e: from superset import db diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index c725ab49a81..7ee7eefee02 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from superset.commands.exceptions import CommandException from superset.commands.explore.form_data.parameters import CommandParameters -from superset.exceptions import SupersetException +from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException from superset.extensions import event_logger from superset.mcp_service.chart.chart_utils import validate_chart_dataset from superset.mcp_service.chart.schemas import ( @@ -43,6 +43,10 @@ from superset.mcp_service.chart.schemas import ( PerformanceMetadata, ) from superset.mcp_service.utils.cache_utils import get_cache_status_from_result +from superset.mcp_service.utils.oauth2_utils import ( + build_oauth2_redirect_message, + OAUTH2_CONFIG_ERROR_MESSAGE, +) from superset.utils.core import merge_extra_filters logger = logging.getLogger(__name__) @@ -797,6 +801,23 @@ async def get_chart_data( # noqa: C901 error_type="DataError", ) + except OAuth2RedirectError as ex: + await ctx.error( + "Chart data requires OAuth authentication: identifier=%s" + % request.identifier + ) + return ChartError( + error=build_oauth2_redirect_message(ex), + error_type="OAUTH2_REDIRECT", + ) + except OAuth2Error: + await ctx.error( + "OAuth2 configuration error: identifier=%s" % request.identifier + ) + return ChartError( + error=OAUTH2_CONFIG_ERROR_MESSAGE, + error_type="OAUTH2_REDIRECT_ERROR", + ) except Exception as e: await ctx.error( "Chart data retrieval failed: identifier=%s, error=%s, error_type=%s" diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index df044dec595..77072a5943c 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -26,7 +26,7 @@ from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException -from superset.exceptions import SupersetException +from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException from superset.extensions import event_logger from superset.mcp_service.chart.chart_utils import validate_chart_dataset from superset.mcp_service.chart.schemas import ( @@ -41,6 +41,10 @@ from superset.mcp_service.chart.schemas import ( URLPreview, VegaLitePreview, ) +from superset.mcp_service.utils.oauth2_utils import ( + build_oauth2_redirect_message, + OAUTH2_CONFIG_ERROR_MESSAGE, +) from superset.mcp_service.utils.url_utils import get_superset_base_url logger = logging.getLogger(__name__) @@ -2247,6 +2251,23 @@ async def get_chart_preview( ) return result + except OAuth2RedirectError as ex: + await ctx.error( + "Chart preview requires OAuth authentication: identifier=%s" + % request.identifier + ) + return ChartError( + error=build_oauth2_redirect_message(ex), + error_type="OAUTH2_REDIRECT", + ) + except OAuth2Error: + await ctx.error( + "OAuth2 configuration error: identifier=%s" % request.identifier + ) + return ChartError( + error=OAUTH2_CONFIG_ERROR_MESSAGE, + error_type="OAUTH2_REDIRECT_ERROR", + ) except Exception as e: await ctx.error( "Chart preview generation failed: identifier=%s, error=%s, error_type=%s" diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py index 3987afd75b8..e5be2d03324 100644 --- a/superset/mcp_service/chart/tool/update_chart.py +++ b/superset/mcp_service/chart/tool/update_chart.py @@ -28,6 +28,7 @@ from sqlalchemy.exc import SQLAlchemyError from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException +from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.extensions import event_logger from superset.mcp_service.chart.chart_utils import ( analyze_chart_capabilities, @@ -42,6 +43,10 @@ from superset.mcp_service.chart.schemas import ( PerformanceMetadata, UpdateChartRequest, ) +from superset.mcp_service.utils.oauth2_utils import ( + build_oauth2_redirect_message, + OAUTH2_CONFIG_ERROR_MESSAGE, +) from superset.mcp_service.utils.url_utils import get_superset_base_url from superset.utils import json @@ -118,7 +123,7 @@ def _build_update_payload( destructiveHint=True, ), ) -async def update_chart( +async def update_chart( # noqa: C901 request: UpdateChartRequest, ctx: Context ) -> GenerateChartResponse: """Update existing chart with new configuration. @@ -313,6 +318,35 @@ async def update_chart( } return GenerateChartResponse.model_validate(result) + except OAuth2RedirectError as ex: + await ctx.error( + "Chart update requires OAuth authentication: identifier=%s" + % request.identifier + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "success": False, + "error": { + "error_type": "OAUTH2_REDIRECT", + "message": build_oauth2_redirect_message(ex), + "details": "OAuth2 authentication required", + }, + } + ) + except OAuth2Error: + await ctx.error("OAuth2 configuration error: chart_id=%s" % request.identifier) + return GenerateChartResponse.model_validate( + { + "chart": None, + "success": False, + "error": { + "error_type": "OAUTH2_REDIRECT_ERROR", + "message": OAUTH2_CONFIG_ERROR_MESSAGE, + "details": "OAuth2 configuration or provider error", + }, + } + ) except ( CommandException, SQLAlchemyError, diff --git a/superset/mcp_service/chart/tool/update_chart_preview.py b/superset/mcp_service/chart/tool/update_chart_preview.py index ccb6e53abe9..ae86d8df1c0 100644 --- a/superset/mcp_service/chart/tool/update_chart_preview.py +++ b/superset/mcp_service/chart/tool/update_chart_preview.py @@ -26,6 +26,7 @@ from typing import Any, Dict from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations +from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.extensions import event_logger from superset.mcp_service.chart.chart_utils import ( analyze_chart_capabilities, @@ -40,6 +41,10 @@ from superset.mcp_service.chart.schemas import ( PerformanceMetadata, UpdateChartPreviewRequest, ) +from superset.mcp_service.utils.oauth2_utils import ( + build_oauth2_redirect_message, + OAUTH2_CONFIG_ERROR_MESSAGE, +) from superset.utils import json as utils_json logger = logging.getLogger(__name__) @@ -175,6 +180,25 @@ def update_chart_preview( } return result + except OAuth2RedirectError as ex: + logger.warning( + "Chart preview update requires OAuth authentication: form_data_key=%s", + request.form_data_key, + ) + return { + "chart": None, + "error": build_oauth2_redirect_message(ex), + "success": False, + } + except OAuth2Error: + logger.warning( + "OAuth2 configuration error: form_data_key=%s", request.form_data_key + ) + return { + "chart": None, + "error": OAUTH2_CONFIG_ERROR_MESSAGE, + "success": False, + } except Exception as e: execution_time = int((time.time() - start_time) * 1000) return { diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py index edfbc9c8a50..60842383dbc 100644 --- a/superset/mcp_service/sql_lab/tool/execute_sql.py +++ b/superset/mcp_service/sql_lab/tool/execute_sql.py @@ -37,6 +37,7 @@ from superset_core.queries.types import ( ) from superset.errors import SupersetErrorType +from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.extensions import event_logger from superset.mcp_service.sql_lab.schemas import ( ColumnInfo, @@ -45,6 +46,10 @@ from superset.mcp_service.sql_lab.schemas import ( StatementData, StatementInfo, ) +from superset.mcp_service.utils.oauth2_utils import ( + build_oauth2_redirect_message, + OAUTH2_CONFIG_ERROR_MESSAGE, +) logger = logging.getLogger(__name__) @@ -147,6 +152,25 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes return response + except OAuth2RedirectError as ex: + await ctx.error( + "Database requires OAuth authentication: database_id=%s" + % request.database_id + ) + return ExecuteSqlResponse( + success=False, + error=build_oauth2_redirect_message(ex), + error_type=SupersetErrorType.OAUTH2_REDIRECT.value, + ) + except OAuth2Error: + await ctx.error( + "OAuth2 configuration/flow error: database_id=%s" % request.database_id + ) + return ExecuteSqlResponse( + success=False, + error=OAUTH2_CONFIG_ERROR_MESSAGE, + error_type=SupersetErrorType.OAUTH2_REDIRECT_ERROR.value, + ) except Exception as e: await ctx.error( "SQL execution failed: error=%s, database_id=%s" diff --git a/superset/mcp_service/utils/oauth2_utils.py b/superset/mcp_service/utils/oauth2_utils.py new file mode 100644 index 00000000000..f6f6f87068a --- /dev/null +++ b/superset/mcp_service/utils/oauth2_utils.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Utilities for handling OAuth2 errors in MCP tools. +""" + +from superset.exceptions import OAuth2RedirectError + + +def build_oauth2_redirect_message(ex: OAuth2RedirectError) -> str: + """ + Build a user-facing message for OAuth2RedirectError. + + Extracts the authorization URL from the exception and includes it + so the MCP client can present it to the user for authentication. + """ + # extra is always set by OAuth2RedirectError.__init__ + assert ex.error.extra is not None # noqa: S101 + oauth_url = ex.error.extra["url"] + return ( + "This database uses OAuth for authentication. " + "Please open the following URL in your browser to " + "authorize access, then retry this request:\n\n" + f"{oauth_url}" + ) + + +OAUTH2_CONFIG_ERROR_MESSAGE = ( + "OAuth authentication failed due to a configuration " + "or provider error. " + "Please contact your Superset administrator." +) diff --git a/superset/sql/execution/executor.py b/superset/sql/execution/executor.py index 90be63b5d2a..13e12fc1a47 100644 --- a/superset/sql/execution/executor.py +++ b/superset/sql/execution/executor.py @@ -69,6 +69,8 @@ from flask import current_app as app, g, has_app_context from superset import db from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( + OAuth2Error, + OAuth2RedirectError, SupersetSecurityException, SupersetTimeoutException, ) @@ -318,6 +320,10 @@ class SQLExecutor: ) except SupersetSecurityException as ex: return self._create_error_result(QueryStatus.FAILED, str(ex), start_time) + except (OAuth2RedirectError, OAuth2Error): + # Let OAuth2 exceptions propagate so callers (MCP, API) can + # handle them with context-appropriate responses. + raise except Exception as ex: error_msg = self.database.db_engine_spec.extract_error_message(ex) return self._create_error_result(QueryStatus.FAILED, error_msg, start_time) diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 4978c0af5c5..76b671b4109 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -29,10 +29,11 @@ import backoff import jwt from flask import current_app as app, url_for from marshmallow import EXCLUDE, fields, post_load, Schema, validate +from werkzeug.routing import BuildError from superset import db from superset.distributed_lock import DistributedLock -from superset.exceptions import AcquireDistributedLockFailedException +from superset.exceptions import AcquireDistributedLockFailedException, OAuth2Error from superset.superset_typing import OAuth2ClientConfig, OAuth2State if TYPE_CHECKING: @@ -245,16 +246,34 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State: return state +def get_oauth2_redirect_uri() -> str: + """ + Return the OAuth2 redirect URI. + + Tries the explicit config first, then falls back to url_for(). + If url_for() fails (e.g. in headless/MCP contexts where the + DatabaseRestApi blueprint may not be registered), raises + OAuth2Error so callers don't silently proceed with an invalid URI. + """ + if configured := app.config.get("DATABASE_OAUTH2_REDIRECT_URI"): + return configured + + try: + return url_for("DatabaseRestApi.oauth2", _external=True) + except (BuildError, RuntimeError): + raise OAuth2Error( + "Unable to determine the OAuth2 redirect URI. " + "Set DATABASE_OAUTH2_REDIRECT_URI in the configuration." + ) from None + + class OAuth2ClientConfigSchema(Schema): id = fields.String(required=True) secret = fields.String(required=True) scope = fields.String(required=True) redirect_uri = fields.String( required=False, - load_default=lambda: app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - url_for("DatabaseRestApi.oauth2", _external=True), - ), + load_default=get_oauth2_redirect_uri, ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 14fa82b55af..22ec7d0aa18 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -1196,7 +1196,7 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None }, ) mocker.patch( - "superset.db_engine_specs.base.url_for", + "superset.utils.oauth2.url_for", return_value=fallback_uri, ) mocker.patch("superset.daos.key_value.KeyValueDAO") diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py index bcd66cbe20e..697bec39d59 100644 --- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py +++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py @@ -1093,3 +1093,74 @@ class TestSanitizeRowValues: assert rows[0]["name"] == "test" assert rows[0]["price"] == 9.99 assert rows[0]["blob"] == "000102ff" + + +class TestExecuteSqlOAuth2: + """Tests for OAuth2 error handling in execute_sql.""" + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_oauth2_redirect_error( + self, mock_db, mock_security_manager, mcp_server + ): + """Test that OAuth2RedirectError is caught and returns a clear message.""" + from superset.exceptions import OAuth2RedirectError + + mock_database = _mock_database() + mock_database.execute.side_effect = OAuth2RedirectError( + url="https://oauth.example.com/authorize", + tab_id="test-tab-id", + redirect_uri="https://superset.example.com/callback", + ) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT 1", + "limit": 100, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + data = result.structured_content + assert data["success"] is False + assert "OAuth" in data["error"] + assert "https://oauth.example.com/authorize" in data["error"] + assert data["error_type"] == "OAUTH2_REDIRECT" + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_oauth2_error( + self, mock_db, mock_security_manager, mcp_server + ): + """Test that OAuth2Error is caught and returns a clear message.""" + from superset.exceptions import OAuth2Error + + mock_database = _mock_database() + mock_database.execute.side_effect = OAuth2Error( + "Unable to determine the OAuth2 redirect URI." + ) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT 1", + "limit": 100, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + data = result.structured_content + assert data["success"] is False + assert "configuration" in data["error"] + assert data["error_type"] == "OAUTH2_REDIRECT_ERROR" diff --git a/tests/unit_tests/sql/execution/test_executor.py b/tests/unit_tests/sql/execution/test_executor.py index 18b814a54d9..fdaf52f8a79 100644 --- a/tests/unit_tests/sql/execution/test_executor.py +++ b/tests/unit_tests/sql/execution/test_executor.py @@ -644,6 +644,78 @@ def test_execute_error( assert "Database error" in result.error_message +def test_execute_oauth2_redirect_error_propagates( + mocker: MockerFixture, database: Database, app_context: None +) -> None: + """Test that OAuth2RedirectError propagates instead of being swallowed.""" + from superset.exceptions import OAuth2RedirectError + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock(return_value=False) + mocker.patch.object(database, "get_raw_connection", return_value=mock_conn) + mocker.patch.object( + database, "mutate_sql_based_on_config", side_effect=lambda sql, **kw: sql + ) + mocker.patch.object( + database.db_engine_spec, + "execute", + side_effect=OAuth2RedirectError( + url="https://oauth.example.com/authorize", + tab_id="test-tab", + redirect_uri="https://superset.example.com/callback", + ), + ) + mocker.patch.dict( + current_app.config, + { + "SQL_QUERY_MUTATOR": None, + "SQLLAB_TIMEOUT": 30, + "SQL_MAX_ROW": None, + "QUERY_LOGGER": None, + }, + ) + + with pytest.raises(OAuth2RedirectError): + database.execute("SELECT 1") + + +def test_execute_oauth2_error_propagates( + mocker: MockerFixture, database: Database, app_context: None +) -> None: + """Test that OAuth2Error propagates instead of being swallowed.""" + from superset.exceptions import OAuth2Error + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock(return_value=False) + mocker.patch.object(database, "get_raw_connection", return_value=mock_conn) + mocker.patch.object( + database, "mutate_sql_based_on_config", side_effect=lambda sql, **kw: sql + ) + mocker.patch.object( + database.db_engine_spec, + "execute", + side_effect=OAuth2Error("No configuration found for OAuth2"), + ) + mocker.patch.dict( + current_app.config, + { + "SQL_QUERY_MUTATOR": None, + "SQLLAB_TIMEOUT": 30, + "SQL_MAX_ROW": None, + "QUERY_LOGGER": None, + }, + ) + + with pytest.raises(OAuth2Error): + database.execute("SELECT 1") + + # ============================================================================= # Async Execution Tests # ============================================================================= diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index f04ae26e7c2..a320a8a23e8 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -33,6 +33,7 @@ from superset.utils.oauth2 import ( generate_code_challenge, generate_code_verifier, get_oauth2_access_token, + get_oauth2_redirect_uri, refresh_oauth2_token, ) @@ -335,3 +336,66 @@ def test_encode_decode_oauth2_state( assert "code_verifier" not in decoded assert decoded["database_id"] == 1 assert decoded["user_id"] == 2 + + +def test_get_oauth2_redirect_uri_from_config(mocker: MockerFixture) -> None: + """ + Test that get_oauth2_redirect_uri returns the configured value when set. + """ + custom_uri = "https://proxy.example.com/oauth2/" + mocker.patch( + "flask.current_app.config", + {"DATABASE_OAUTH2_REDIRECT_URI": custom_uri}, + ) + assert get_oauth2_redirect_uri() == custom_uri + + +def test_get_oauth2_redirect_uri_falls_back_to_url_for(mocker: MockerFixture) -> None: + """ + Test that get_oauth2_redirect_uri falls back to url_for when config is not set. + """ + fallback_uri = "http://localhost:8088/api/v1/database/oauth2/" + mocker.patch("flask.current_app.config", {}) + mocker.patch( + "superset.utils.oauth2.url_for", + return_value=fallback_uri, + ) + assert get_oauth2_redirect_uri() == fallback_uri + + +def test_get_oauth2_redirect_uri_raises_on_build_error( + mocker: MockerFixture, +) -> None: + """ + Test that get_oauth2_redirect_uri raises OAuth2Error when url_for raises + BuildError (e.g. in headless/MCP contexts). + """ + from werkzeug.routing import BuildError + + from superset.exceptions import OAuth2Error + + mocker.patch("flask.current_app.config", {}) + mocker.patch( + "superset.utils.oauth2.url_for", + side_effect=BuildError("DatabaseRestApi.oauth2", {}, ("GET",)), + ) + with pytest.raises(OAuth2Error): + get_oauth2_redirect_uri() + + +def test_get_oauth2_redirect_uri_raises_on_runtime_error( + mocker: MockerFixture, +) -> None: + """ + Test that get_oauth2_redirect_uri raises OAuth2Error when url_for raises + RuntimeError (e.g. no request context and no SERVER_NAME). + """ + from superset.exceptions import OAuth2Error + + mocker.patch("flask.current_app.config", {}) + mocker.patch( + "superset.utils.oauth2.url_for", + side_effect=RuntimeError("Unable to build URL outside of request context"), + ) + with pytest.raises(OAuth2Error): + get_oauth2_redirect_uri()