fix(db oauth2): Improve OAuth2 flow (#39499)

This commit is contained in:
Vitor Avila
2026-04-21 11:54:52 -03:00
committed by GitHub
parent a222dab781
commit 191337e08d
7 changed files with 239 additions and 182 deletions

View File

@@ -62,7 +62,11 @@ from superset import db
from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.exceptions import (
OAuth2Error,
OAuth2RedirectError,
OAuth2TokenRefreshError,
)
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.sql.parse import (
BaseSQLStatement,
@@ -828,6 +832,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if config["request_content_type"] == "data"
else requests.post(uri, json=req_body, timeout=timeout)
)
if response.status_code in (400, 401, 403):
raise OAuth2TokenRefreshError(response.text)
response.raise_for_status()
return response.json()

View File

@@ -30,9 +30,7 @@ from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from requests import Session
from requests.exceptions import HTTPError
from shillelagh.adapters.api.gsheets.lib import SCOPES
from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
@@ -43,7 +41,6 @@ from superset.db_engine_specs.base import DatabaseCategory
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.superset_typing import OAuth2TokenResponse
from superset.utils import json
from superset.utils.oauth2 import get_oauth2_access_token
@@ -154,7 +151,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
"https://accounts.google.com/o/oauth2/v2/auth"
)
oauth2_token_request_uri = "https://oauth2.googleapis.com/token" # noqa: S105
oauth2_exception = UnauthenticatedError
@classmethod
def get_oauth2_authorization_uri(
@@ -218,29 +214,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
)
)
@classmethod
def get_oauth2_fresh_token(
cls,
config: OAuth2ClientConfig,
refresh_token: str,
) -> OAuth2TokenResponse:
"""
Refresh an OAuth2 access token that has expired.
When trying to refresh an expired token that was revoked on Google side,
the request fails with 400 status code.
"""
try:
return super().get_oauth2_fresh_token(config, refresh_token)
except HTTPError as ex:
if ex.response is not None and ex.response.status_code == 400:
error_data = ex.response.json()
if error_data.get("error") == "invalid_grant":
raise UnauthenticatedError(
error_data.get("error_description", "Token has been revoked")
) from ex
raise
@classmethod
def impersonate_user(
cls,

View File

@@ -368,6 +368,27 @@ class OAuth2RedirectError(SupersetErrorException):
)
class OAuth2TokenRefreshError(OAuth2RedirectError):
"""
Raised when an OAuth2 refresh token request fails with a 400/401/403 error.
The stored token is no longer valid and the user must re-authenticate.
Subclasses OAuth2RedirectError so that existing oauth2_exception checks
match it automatically, triggering start_oauth2_dance() via check_for_oauth2.
"""
def __init__(self, response_text: str) -> None:
SupersetErrorException.__init__(
self,
SupersetError(
message="OAuth2 token refresh failed, re-authentication required.",
error_type=SupersetErrorType.OAUTH2_REDIRECT,
level=ErrorLevel.WARNING,
extra={"error": response_text},
),
)
class OAuth2Error(SupersetErrorException):
"""
Exception for when OAuth2 goes wrong.

View File

@@ -79,9 +79,9 @@ def generate_code_challenge(code_verifier: str) -> str:
@backoff.on_exception(
backoff.expo,
AcquireDistributedLockFailedException,
factor=10,
factor=0.1,
base=2,
max_tries=5,
max_tries=8,
raise_on_giveup=False,
giveup_log_level=logging.DEBUG,
)
@@ -143,14 +143,17 @@ def refresh_oauth2_token(
config,
token.refresh_token,
)
except db_engine_spec.oauth2_exception:
except db_engine_spec.oauth2_exception as ex:
# OAuth token is no longer valid, delete it and start OAuth2 dance
logger.warning(
"OAuth2 token refresh failed for user=%s db=%s, deleting invalid token",
"OAuth2 token refresh failed for user=%s db=%s, "
"deleting token. Error: %s",
user_id,
database_id,
ex,
)
db.session.delete(token)
db.session.flush()
raise
except Exception:
# non-OAuth related failure, log the exception

View File

@@ -21,6 +21,7 @@ from __future__ import annotations
import json # noqa: TID251
import re
from datetime import timedelta
from textwrap import dedent
from typing import Any
from urllib.parse import parse_qs, urlparse
@@ -44,6 +45,7 @@ from superset.superset_typing import (
)
from superset.utils.core import FilterOperator, GenericDataType
from superset.utils.oauth2 import decode_oauth2_state
from tests.conftest import with_config
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
@@ -597,6 +599,19 @@ def test_extract_errors(mocker: MockerFixture) -> None:
assert result == [expected]
@with_config(
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("This connector does not support roles"): (
"Custom error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
}
}
},
)
def test_extract_errors_from_config(mocker: MockerFixture) -> None:
"""
Test that custom error messages are extracted correctly from app config
@@ -606,21 +621,6 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None:
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("This connector does not support roles"): (
"Custom error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
}
}
},
)
msg = "This connector does not support roles"
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples")
@@ -631,6 +631,19 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None:
assert result == [expected]
@with_config(
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("This connector does not support roles"): (
"Custom error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
}
}
},
)
def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> None:
"""
Test that custom error messages are only applied to the specified database_name.
@@ -639,21 +652,6 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("This connector does not support roles"): (
"Custom error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
}
}
},
)
msg = "This connector does not support roles"
# database_name doesn't match configured one, so default message is used
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2")
@@ -665,6 +663,27 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non
assert result == [expected]
@with_config(
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile(r'message="(?P<message>[^"]*)"'): (
'Unexpected error: "%(message)s"',
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{
"custom_doc_links": [
{
"url": "https://example.com/docs",
"label": "Check documentation",
},
],
"show_issue_info": False,
},
)
}
}
},
)
def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
"""
Test that custom error messages with regex, custom_doc_links,
@@ -674,29 +693,6 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile(r'message="(?P<message>[^"]*)"'): (
'Unexpected error: "%(message)s"',
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{
"custom_doc_links": [
{
"url": "https://example.com/docs",
"label": "Check documentation",
},
],
"show_issue_info": False,
},
)
}
}
},
)
msg = (
"db error: SomeUserError(type=USER_ERROR, name=TABLE_NOT_FOUND, "
'message="line 3:6: Table '
@@ -735,6 +731,7 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
]
@with_config({"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}})
def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
"""
Test that extract_errors doesn't fail when custom database errors
@@ -744,11 +741,6 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{"CUSTOM_DATABASE_ERRORS": "not a dict"},
)
msg = "This connector does not support roles"
result = TestEngineSpec.extract_errors(Exception(msg))
@@ -759,6 +751,7 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
assert result == [expected]
@with_config({"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}})
def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture):
"""
Test that extract_errors doesn't fail when database-specific custom errors
@@ -768,11 +761,6 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}},
)
msg = "This connector does not support roles"
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples")
@@ -783,6 +771,19 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture
assert result == [expected]
@with_config(
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("This connector does not support roles"): (
"",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
}
}
},
)
def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
"""
Test that when the custom error message is empty,
@@ -792,21 +793,6 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("This connector does not support roles"): (
"",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
}
}
},
)
msg = "This connector does not support roles"
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples")
@@ -817,6 +803,26 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
assert result == [expected]
@with_config(
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("connection error"): (
"Examples DB error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
},
"examples_2": {
re.compile("connection error"): (
"Examples_2 DB error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
},
}
},
)
def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -> None:
"""
Test that custom error messages are matched by database_name.
@@ -825,28 +831,6 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("connection error"): (
"Examples DB error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
},
"examples_2": {
re.compile("connection error"): (
"Examples_2 DB error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
},
}
},
)
msg = "connection error occurred"
# When database_name is examples_2 we should get that specific message
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2")
@@ -858,6 +842,19 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -
assert result == [expected]
@with_config(
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("connection error"): (
"Examples DB error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
},
}
},
)
def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None:
"""
Test that when database_name has no match, the original error message is preserved.
@@ -866,21 +863,6 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None:
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
mocker.patch(
"flask.current_app.config",
{
"CUSTOM_DATABASE_ERRORS": {
"examples": {
re.compile("connection error"): (
"Examples DB error message",
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{},
)
},
}
},
)
msg = "some other error"
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2")
@@ -980,16 +962,13 @@ def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None:
assert query["code_challenge"][0] == expected_challenge
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
"""
Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier.
"""
from superset.db_engine_specs.base import BaseEngineSpec
mocker.patch(
"flask.current_app.config",
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
)
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
mock_post.return_value.json.return_value = {
"access_token": "test-access-token", # noqa: S105
@@ -1015,6 +994,7 @@ def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
assert "code_verifier" not in request_body
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
"""
Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided.
@@ -1022,10 +1002,6 @@ def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils.oauth2 import generate_code_verifier
mocker.patch(
"flask.current_app.config",
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
)
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
mock_post.return_value.json.return_value = {
"access_token": "test-access-token", # noqa: S105
@@ -1097,6 +1073,7 @@ def test_get_oauth2_authorization_uri_additional_params(
assert query["access_type"][0] == "offline"
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
"""
Test that a subclass can inject additional params into the token request body
@@ -1109,10 +1086,6 @@ def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
"audience": "https://api.example.com",
}
mocker.patch(
"flask.current_app.config",
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
)
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
mock_post.return_value.json.return_value = {
"access_token": "test-access-token", # noqa: S105
@@ -1143,6 +1116,94 @@ def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
assert request_body["audience"] == "https://api.example.com"
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
def test_get_oauth2_fresh_token_success(mocker: MockerFixture) -> None:
"""
Test that get_oauth2_fresh_token returns the token response on success.
"""
from superset.db_engine_specs.base import BaseEngineSpec
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
}
config: OAuth2ClientConfig = {
"id": "client-id",
"secret": "client-secret",
"scope": "read write",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://oauth.example.com/authorize",
"token_request_uri": "https://oauth.example.com/token",
"request_content_type": "json",
}
result = BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token")
assert result == {"access_token": "new-access-token", "expires_in": 3600}
@pytest.mark.parametrize("status_code", [400, 401, 403])
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
def test_get_oauth2_fresh_token_raises_on_auth_error(
mocker: MockerFixture,
status_code: int,
) -> None:
"""
Test that get_oauth2_fresh_token raises OAuth2TokenRefreshError on 400/401/403.
"""
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import OAuth2TokenRefreshError
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
mock_post.return_value.status_code = status_code
mock_post.return_value.text = '{"error": "invalid_grant"}'
config: OAuth2ClientConfig = {
"id": "client-id",
"secret": "client-secret",
"scope": "read write",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://oauth.example.com/authorize",
"token_request_uri": "https://oauth.example.com/token",
"request_content_type": "json",
}
with pytest.raises(OAuth2TokenRefreshError) as exc_info:
BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token")
assert exc_info.value.error.extra["error"] == '{"error": "invalid_grant"}'
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
def test_get_oauth2_fresh_token_raises_on_server_error(mocker: MockerFixture) -> None:
"""
Test that get_oauth2_fresh_token raises HTTPError (not OAuth2TokenRefreshError)
on 5xx.
"""
from requests.exceptions import HTTPError
from superset.db_engine_specs.base import BaseEngineSpec
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.raise_for_status.side_effect = HTTPError("500 Server Error")
config: OAuth2ClientConfig = {
"id": "client-id",
"secret": "client-secret",
"scope": "read write",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://oauth.example.com/authorize",
"token_request_uri": "https://oauth.example.com/token",
"request_content_type": "json",
}
with pytest.raises(HTTPError):
BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token")
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
"""
Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set.
@@ -1182,19 +1243,18 @@ def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N
assert error.extra["redirect_uri"] == custom_redirect_uri
@with_config(
{
"SECRET_KEY": "test-secret-key",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
}
)
def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None:
"""
Test that start_oauth2_dance falls back to url_for when no config is set.
"""
fallback_uri = "http://localhost:8088/api/v1/database/oauth2/"
mocker.patch(
"flask.current_app.config",
{
"SECRET_KEY": "test-secret-key",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
mocker.patch(
"superset.utils.oauth2.url_for",
return_value=fallback_uri,

View File

@@ -24,11 +24,10 @@ import pandas as pd
import pytest
from pytest_mock import MockerFixture
from requests.exceptions import HTTPError
from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine.url import make_url
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.exceptions import OAuth2TokenRefreshError, SupersetException
from superset.sql.parse import Table
from superset.superset_typing import OAuth2ClientConfig
from superset.utils import json
@@ -817,26 +816,20 @@ def test_get_oauth2_fresh_token_invalid_grant(
oauth2_config: OAuth2ClientConfig,
) -> None:
"""
Test that get_oauth2_fresh_token raises UnauthenticatedError for invalid_grant.
Test that get_oauth2_fresh_token raises OAuth2TokenRefreshError for a 400 response.
When a token is revoked on Google side, the refresh request returns 400
with error=invalid_grant.
When a token is revoked on Google side, the refresh request returns 400.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
mock_response = mocker.MagicMock()
mock_response.status_code = 400
mock_response.json.return_value = {
"error": "invalid_grant",
"error_description": "Token has been expired or revoked.",
}
http_error = HTTPError()
http_error.response = mock_response
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().raise_for_status.side_effect = http_error
requests.post().status_code = 400
requests.post().text = (
'{"error": "invalid_grant",'
' "error_description": "Token has been expired or revoked."}'
)
with pytest.raises(UnauthenticatedError):
with pytest.raises(OAuth2TokenRefreshError):
GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token")

View File

@@ -137,6 +137,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
db.session.delete.assert_called_with(token)
db.session.flush.assert_called_once()
def test_refresh_oauth2_token_keeps_token_on_other_exception(