mirror of
https://github.com/apache/superset.git
synced 2026-04-27 12:05:24 +00:00
fix(db oauth2): Improve OAuth2 flow (#39499)
This commit is contained in:
@@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
|
||||
import json # noqa: TID251
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
@@ -44,6 +45,7 @@ from superset.superset_typing import (
|
||||
)
|
||||
from superset.utils.core import FilterOperator, GenericDataType
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
from tests.conftest import with_config
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
|
||||
|
||||
|
||||
@@ -597,6 +599,19 @@ def test_extract_errors(mocker: MockerFixture) -> None:
|
||||
assert result == [expected]
|
||||
|
||||
|
||||
@with_config(
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("This connector does not support roles"): (
|
||||
"Custom error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
def test_extract_errors_from_config(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that custom error messages are extracted correctly from app config
|
||||
@@ -606,21 +621,6 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None:
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("This connector does not support roles"): (
|
||||
"Custom error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
msg = "This connector does not support roles"
|
||||
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples")
|
||||
|
||||
@@ -631,6 +631,19 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None:
|
||||
assert result == [expected]
|
||||
|
||||
|
||||
@with_config(
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("This connector does not support roles"): (
|
||||
"Custom error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that custom error messages are only applied to the specified database_name.
|
||||
@@ -639,21 +652,6 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("This connector does not support roles"): (
|
||||
"Custom error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
msg = "This connector does not support roles"
|
||||
# database_name doesn't match configured one, so default message is used
|
||||
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2")
|
||||
@@ -665,6 +663,27 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non
|
||||
assert result == [expected]
|
||||
|
||||
|
||||
@with_config(
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile(r'message="(?P<message>[^"]*)"'): (
|
||||
'Unexpected error: "%(message)s"',
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{
|
||||
"custom_doc_links": [
|
||||
{
|
||||
"url": "https://example.com/docs",
|
||||
"label": "Check documentation",
|
||||
},
|
||||
],
|
||||
"show_issue_info": False,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that custom error messages with regex, custom_doc_links,
|
||||
@@ -674,29 +693,6 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile(r'message="(?P<message>[^"]*)"'): (
|
||||
'Unexpected error: "%(message)s"',
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{
|
||||
"custom_doc_links": [
|
||||
{
|
||||
"url": "https://example.com/docs",
|
||||
"label": "Check documentation",
|
||||
},
|
||||
],
|
||||
"show_issue_info": False,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
msg = (
|
||||
"db error: SomeUserError(type=USER_ERROR, name=TABLE_NOT_FOUND, "
|
||||
'message="line 3:6: Table '
|
||||
@@ -735,6 +731,7 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
|
||||
]
|
||||
|
||||
|
||||
@with_config({"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}})
|
||||
def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
|
||||
"""
|
||||
Test that extract_errors doesn't fail when custom database errors
|
||||
@@ -744,11 +741,6 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"CUSTOM_DATABASE_ERRORS": "not a dict"},
|
||||
)
|
||||
|
||||
msg = "This connector does not support roles"
|
||||
result = TestEngineSpec.extract_errors(Exception(msg))
|
||||
|
||||
@@ -759,6 +751,7 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
|
||||
assert result == [expected]
|
||||
|
||||
|
||||
@with_config({"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}})
|
||||
def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture):
|
||||
"""
|
||||
Test that extract_errors doesn't fail when database-specific custom errors
|
||||
@@ -768,11 +761,6 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}},
|
||||
)
|
||||
|
||||
msg = "This connector does not support roles"
|
||||
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples")
|
||||
|
||||
@@ -783,6 +771,19 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture
|
||||
assert result == [expected]
|
||||
|
||||
|
||||
@with_config(
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("This connector does not support roles"): (
|
||||
"",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
|
||||
"""
|
||||
Test that when the custom error message is empty,
|
||||
@@ -792,21 +793,6 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("This connector does not support roles"): (
|
||||
"",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
msg = "This connector does not support roles"
|
||||
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples")
|
||||
|
||||
@@ -817,6 +803,26 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
|
||||
assert result == [expected]
|
||||
|
||||
|
||||
@with_config(
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("connection error"): (
|
||||
"Examples DB error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
},
|
||||
"examples_2": {
|
||||
re.compile("connection error"): (
|
||||
"Examples_2 DB error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that custom error messages are matched by database_name.
|
||||
@@ -825,28 +831,6 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("connection error"): (
|
||||
"Examples DB error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
},
|
||||
"examples_2": {
|
||||
re.compile("connection error"): (
|
||||
"Examples_2 DB error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
msg = "connection error occurred"
|
||||
# When database_name is examples_2 we should get that specific message
|
||||
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2")
|
||||
@@ -858,6 +842,19 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -
|
||||
assert result == [expected]
|
||||
|
||||
|
||||
@with_config(
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("connection error"): (
|
||||
"Examples DB error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that when database_name has no match, the original error message is preserved.
|
||||
@@ -866,21 +863,6 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None:
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"CUSTOM_DATABASE_ERRORS": {
|
||||
"examples": {
|
||||
re.compile("connection error"): (
|
||||
"Examples DB error message",
|
||||
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
{},
|
||||
)
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
msg = "some other error"
|
||||
result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2")
|
||||
|
||||
@@ -980,16 +962,13 @@ def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None:
|
||||
assert query["code_challenge"][0] == expected_challenge
|
||||
|
||||
|
||||
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
|
||||
def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
|
||||
)
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "test-access-token", # noqa: S105
|
||||
@@ -1015,6 +994,7 @@ def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
|
||||
assert "code_verifier" not in request_body
|
||||
|
||||
|
||||
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
|
||||
def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided.
|
||||
@@ -1022,10 +1002,6 @@ def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.utils.oauth2 import generate_code_verifier
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
|
||||
)
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "test-access-token", # noqa: S105
|
||||
@@ -1097,6 +1073,7 @@ def test_get_oauth2_authorization_uri_additional_params(
|
||||
assert query["access_type"][0] == "offline"
|
||||
|
||||
|
||||
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
|
||||
def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that a subclass can inject additional params into the token request body
|
||||
@@ -1109,10 +1086,6 @@ def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
|
||||
"audience": "https://api.example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
|
||||
)
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "test-access-token", # noqa: S105
|
||||
@@ -1143,6 +1116,94 @@ def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
|
||||
assert request_body["audience"] == "https://api.example.com"
|
||||
|
||||
|
||||
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
|
||||
def test_get_oauth2_fresh_token_success(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that get_oauth2_fresh_token returns the token response on success.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.status_code = 200
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
result = BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token")
|
||||
assert result == {"access_token": "new-access-token", "expires_in": 3600}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [400, 401, 403])
|
||||
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
|
||||
def test_get_oauth2_fresh_token_raises_on_auth_error(
|
||||
mocker: MockerFixture,
|
||||
status_code: int,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_oauth2_fresh_token raises OAuth2TokenRefreshError on 400/401/403.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.exceptions import OAuth2TokenRefreshError
|
||||
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.status_code = status_code
|
||||
mock_post.return_value.text = '{"error": "invalid_grant"}'
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
with pytest.raises(OAuth2TokenRefreshError) as exc_info:
|
||||
BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token")
|
||||
|
||||
assert exc_info.value.error.extra["error"] == '{"error": "invalid_grant"}'
|
||||
|
||||
|
||||
@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)})
|
||||
def test_get_oauth2_fresh_token_raises_on_server_error(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that get_oauth2_fresh_token raises HTTPError (not OAuth2TokenRefreshError)
|
||||
on 5xx.
|
||||
"""
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.raise_for_status.side_effect = HTTPError("500 Server Error")
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token")
|
||||
|
||||
|
||||
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set.
|
||||
@@ -1182,19 +1243,18 @@ def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N
|
||||
assert error.extra["redirect_uri"] == custom_redirect_uri
|
||||
|
||||
|
||||
@with_config(
|
||||
{
|
||||
"SECRET_KEY": "test-secret-key",
|
||||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||
}
|
||||
)
|
||||
def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that start_oauth2_dance falls back to url_for when no config is set.
|
||||
"""
|
||||
fallback_uri = "http://localhost:8088/api/v1/database/oauth2/"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"SECRET_KEY": "test-secret-key",
|
||||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"superset.utils.oauth2.url_for",
|
||||
return_value=fallback_uri,
|
||||
|
||||
Reference in New Issue
Block a user