diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index f52b6250548..6c0cd77478c 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -551,17 +551,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ) -> str: """ Return URI for initial OAuth2 request. + + Uses standard OAuth 2.0 parameters only. Subclasses can override + to add provider-specific parameters (e.g., Google's prompt=consent). """ uri = config["authorization_request_uri"] params = { "scope": config["scope"], - "access_type": "offline", - "include_granted_scopes": "false", "response_type": "code", "state": encode_oauth2_state(state), "redirect_uri": config["redirect_uri"], "client_id": config["id"], - "prompt": "consent", } return urljoin(uri, "?" + urlencode(params)) diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 2c70d07e947..30da234fd2c 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -45,6 +45,7 @@ from superset.utils import json if TYPE_CHECKING: from superset.models.core import Database from superset.sql.parse import Table + from superset.superset_typing import OAuth2ClientConfig, OAuth2State _logger = logging.getLogger() @@ -129,6 +130,38 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): oauth2_token_request_uri = "https://oauth2.googleapis.com/token" # noqa: S105 oauth2_exception = UnauthenticatedError + @classmethod + def get_oauth2_authorization_uri( + cls, + config: "OAuth2ClientConfig", + state: "OAuth2State", + ) -> str: + """ + Return URI for initial OAuth2 request with Google-specific parameters. + + Google OAuth requires additional parameters for proper token refresh: + - access_type=offline: Request a refresh token + - include_granted_scopes=false: Don't include previously granted scopes + - prompt=consent: Force consent screen to ensure refresh token is returned + """ + from urllib.parse import urlencode, urljoin + + from superset.utils.oauth2 import encode_oauth2_state + + uri = config["authorization_request_uri"] + params = { + "scope": config["scope"], + "response_type": "code", + "state": encode_oauth2_state(state), + "redirect_uri": config["redirect_uri"], + "client_id": config["id"], + # Google-specific parameters + "access_type": "offline", + "include_granted_scopes": "false", + "prompt": "consent", + } + return urljoin(uri, "?" + urlencode(params)) + @classmethod def impersonate_user( cls, diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 6db0edb813d..bff4c931171 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -894,3 +894,52 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None: engine_name="ExampleEngine", ) assert result == [expected] + + +def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) -> None: + """ + Test that BaseEngineSpec.get_oauth2_authorization_uri uses standard OAuth 2.0 + parameters only and does not include provider-specific params like prompt=consent. + """ + from urllib.parse import parse_qs, urlparse + + from superset.db_engine_specs.base import BaseEngineSpec + from superset.superset_typing import OAuth2ClientConfig, OAuth2State + from superset.utils.oauth2 import decode_oauth2_state + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + state: OAuth2State = { + "database_id": 1, + "user_id": 1, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "1234", + } + + url = BaseEngineSpec.get_oauth2_authorization_uri(config, state) + parsed = urlparse(url) + assert parsed.netloc == "oauth.example.com" + assert parsed.path == "/authorize" + + query = parse_qs(parsed.query) + + # Verify standard OAuth 2.0 parameters are included + assert query["scope"][0] == "read write" + assert query["response_type"][0] == "code" + assert query["client_id"][0] == "client-id" + assert query["redirect_uri"][0] == "http://localhost:8088/api/v1/database/oauth2/" + encoded_state = query["state"][0].replace("%2E", ".") + assert decode_oauth2_state(encoded_state) == state + + # Verify Google-specific parameters are NOT included (standard OAuth 2.0) + assert "prompt" not in query + assert "access_type" not in query + assert "include_granted_scopes" not in query diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 148064d8a6f..2ed796c32d2 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -638,6 +638,11 @@ def test_get_oauth2_authorization_uri( encoded_state = query["state"][0].replace("%2E", ".") assert decode_oauth2_state(encoded_state) == state + # Verify Google-specific OAuth parameters are included + assert query["access_type"][0] == "offline" + assert query["include_granted_scopes"][0] == "false" + assert query["prompt"][0] == "consent" + def test_get_oauth2_token( mocker: MockerFixture, diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 2f165df009f..6e221e1494b 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -231,7 +231,7 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: "error_type": SupersetErrorType.OAUTH2_REDIRECT, "level": ErrorLevel.WARNING, "extra": { - "url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id&prompt=consent", + "url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id", "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187", "redirect_uri": "http://localhost/api/v1/database/oauth2/", },