mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
fix: more DB OAuth2 fixes (#37398)
This commit is contained in:
@@ -23,18 +23,27 @@ import json # noqa: TID251
|
||||
import re
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy import Boolean, Column, Integer, types
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
from superset.sql.parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
||||
from superset.utils.core import GenericDataType
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
OAuth2State,
|
||||
ResultSetColumnType,
|
||||
SQLAColumnType,
|
||||
)
|
||||
from superset.utils.core import FilterOperator, GenericDataType
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
|
||||
|
||||
|
||||
@@ -68,9 +77,6 @@ def test_get_text_clause_with_colon() -> None:
|
||||
"""
|
||||
Make sure text clauses are correctly escaped
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
text_clause = BaseEngineSpec.get_text_clause(
|
||||
"SELECT foo FROM tbl WHERE foo = '123:456')"
|
||||
)
|
||||
@@ -90,8 +96,6 @@ def test_validate_db_uri(mocker: MockerFixture) -> None:
|
||||
{"DB_SQLA_URI_VALIDATOR": mock_validate},
|
||||
)
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
with pytest.raises(ValueError): # noqa: PT011
|
||||
BaseEngineSpec.validate_database_uri(URL.create("sqlite"))
|
||||
|
||||
@@ -130,8 +134,6 @@ select 'USD' as cur
|
||||
],
|
||||
)
|
||||
def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None:
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
actual = BaseEngineSpec.get_cte_query(original)
|
||||
assert actual == expected
|
||||
|
||||
@@ -197,8 +199,6 @@ def test_get_column_spec(
|
||||
def test_convert_inspector_columns(
|
||||
cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType]
|
||||
):
|
||||
from superset.db_engine_specs.base import convert_inspector_columns
|
||||
|
||||
assert convert_inspector_columns(cols) == expected_result
|
||||
|
||||
|
||||
@@ -206,8 +206,6 @@ def test_select_star(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the ``select_star`` method.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
cols: list[ResultSetColumnType] = [
|
||||
{
|
||||
"column_name": "a",
|
||||
@@ -249,7 +247,6 @@ def test_extra_table_metadata(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the deprecated `extra_table_metadata` method.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
class ThirdPartyDBEngineSpec(BaseEngineSpec):
|
||||
@@ -285,8 +282,6 @@ def test_get_default_catalog(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `get_default_catalog` method.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
database = mocker.MagicMock()
|
||||
assert BaseEngineSpec.get_default_catalog(database) is None
|
||||
|
||||
@@ -295,7 +290,6 @@ def test_quote_table() -> None:
|
||||
"""
|
||||
Test the `quote_table` function.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
dialect = sqlite.dialect()
|
||||
|
||||
@@ -318,8 +312,6 @@ def test_mask_encrypted_extra() -> None:
|
||||
"""
|
||||
Test that the private key is masked when the database is edited.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
config = json.dumps(
|
||||
{
|
||||
"foo": "bar",
|
||||
@@ -342,8 +334,6 @@ def test_unmask_encrypted_extra() -> None:
|
||||
"""
|
||||
Test that the private key can be reused from the previous `encrypted_extra`.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
old = json.dumps(
|
||||
{
|
||||
"foo": "bar",
|
||||
@@ -375,8 +365,6 @@ def test_impersonate_user_backwards_compatible(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that the `impersonate_user` method calls the original methods it replaced.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
database = mocker.MagicMock()
|
||||
url = make_url("sqlite://foo.db")
|
||||
new_url = make_url("sqlite://bar.db")
|
||||
@@ -417,8 +405,6 @@ def test_impersonate_user_no_database(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `impersonate_user` when `update_impersonation_config` has an old signature.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
database = mocker.MagicMock()
|
||||
url = make_url("sqlite://foo.db")
|
||||
new_url = make_url("sqlite://bar.db")
|
||||
@@ -457,10 +443,6 @@ def test_handle_boolean_filter_default_behavior() -> None:
|
||||
"""
|
||||
Test that BaseEngineSpec uses IS operators for boolean filters by default.
|
||||
"""
|
||||
from sqlalchemy import Boolean, Column
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
# Create a mock SQLAlchemy column
|
||||
bool_col = Column("test_col", Boolean)
|
||||
|
||||
@@ -479,9 +461,6 @@ def test_handle_boolean_filter_with_equality() -> None:
|
||||
"""
|
||||
Test that BaseEngineSpec can use equality operators when configured.
|
||||
"""
|
||||
from sqlalchemy import Boolean, Column
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
# Create a test engine spec that uses equality
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
@@ -502,15 +481,9 @@ def test_handle_null_filter() -> None:
|
||||
"""
|
||||
Test null/not null filter handling.
|
||||
"""
|
||||
from sqlalchemy import Boolean, Column
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
bool_col = Column("test_col", Boolean)
|
||||
|
||||
# Test IS_NULL - use actual FilterOperator values
|
||||
from superset.utils.core import FilterOperator
|
||||
|
||||
result_null = BaseEngineSpec.handle_null_filter(bool_col, FilterOperator.IS_NULL)
|
||||
assert hasattr(result_null, "left")
|
||||
assert hasattr(result_null, "right")
|
||||
@@ -531,15 +504,9 @@ def test_handle_comparison_filter() -> None:
|
||||
"""
|
||||
Test comparison filter handling for all operators.
|
||||
"""
|
||||
from sqlalchemy import Column, Integer
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
int_col = Column("test_col", Integer)
|
||||
|
||||
# Test all comparison operators - use actual FilterOperator values
|
||||
from superset.utils.core import FilterOperator
|
||||
|
||||
operators_and_values = [
|
||||
(FilterOperator.EQUALS, 5),
|
||||
(FilterOperator.NOT_EQUALS, 5),
|
||||
@@ -563,8 +530,6 @@ def test_use_equality_for_boolean_filters_property() -> None:
|
||||
"""
|
||||
Test that BaseEngineSpec has the correct default value for boolean filter property.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
# Default should be False (use IS operators)
|
||||
assert BaseEngineSpec.use_equality_for_boolean_filters is False
|
||||
|
||||
@@ -573,9 +538,6 @@ def test_extract_errors(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that error is extracted correctly when no custom error message is provided.
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{},
|
||||
@@ -597,8 +559,6 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None:
|
||||
using database_name.
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
@@ -632,8 +592,6 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non
|
||||
Test that custom error messages are only applied to the specified database_name.
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
@@ -669,8 +627,6 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
|
||||
and show_issue_info are extracted correctly from config.
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
|
||||
@@ -740,7 +696,6 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
|
||||
Test that extract_errors doesn't fail when custom database errors
|
||||
are in wrong format.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
@@ -765,7 +720,6 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture
|
||||
Test that extract_errors doesn't fail when database-specific custom errors
|
||||
are in wrong format.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
@@ -790,7 +744,6 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
|
||||
Test that when the custom error message is empty,
|
||||
the original error message is preserved.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
@@ -824,7 +777,6 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -
|
||||
"""
|
||||
Test that custom error messages are matched by database_name.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
@@ -866,7 +818,6 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that when database_name has no match, the original error message is preserved.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class TestEngineSpec(BaseEngineSpec):
|
||||
engine_name = "ExampleEngine"
|
||||
@@ -901,12 +852,6 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
|
||||
Test that BaseEngineSpec.get_oauth2_authorization_uri uses standard OAuth 2.0
|
||||
parameters only and does not include provider-specific params like prompt=consent.
|
||||
"""
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
@@ -943,3 +888,81 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
|
||||
assert "prompt" not in query
|
||||
assert "access_type" not in query
|
||||
assert "include_granted_scopes" not in query
|
||||
|
||||
|
||||
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set.
|
||||
"""
|
||||
custom_redirect_uri = "https://proxy.example.com/oauth2/"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri,
|
||||
"SECRET_KEY": "test-secret-key",
|
||||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||
},
|
||||
)
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user.id = 1
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.id = 1
|
||||
database.get_oauth2_config.return_value = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "https://another-link.com",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
}
|
||||
|
||||
with pytest.raises(OAuth2RedirectError) as exc_info:
|
||||
BaseEngineSpec.start_oauth2_dance(database)
|
||||
|
||||
error = exc_info.value.error
|
||||
|
||||
assert error.extra["redirect_uri"] == custom_redirect_uri
|
||||
|
||||
|
||||
def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that start_oauth2_dance falls back to url_for when no config is set.
|
||||
"""
|
||||
fallback_uri = "http://localhost:8088/api/v1/database/oauth2/"
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"SECRET_KEY": "test-secret-key",
|
||||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"superset.db_engine_specs.base.url_for",
|
||||
return_value=fallback_uri,
|
||||
)
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user.id = 1
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.id = 1
|
||||
database.get_oauth2_config.return_value = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "https://another-link.com",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
with pytest.raises(OAuth2RedirectError) as exc_info:
|
||||
BaseEngineSpec.start_oauth2_dance(database)
|
||||
|
||||
error = exc_info.value.error
|
||||
|
||||
assert error.extra["redirect_uri"] == fallback_uri
|
||||
|
||||
@@ -23,6 +23,8 @@ from urllib.parse import parse_qs, urlparse
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from requests.exceptions import HTTPError
|
||||
from shillelagh.exceptions import UnauthenticatedError
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
@@ -737,3 +739,253 @@ def test_update_params_from_encrypted_extra(mocker: MockerFixture) -> None:
|
||||
|
||||
GSheetsEngineSpec.update_params_from_encrypted_extra(database, params)
|
||||
assert params == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_needs_oauth2_with_credentials_error(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that needs_oauth2 returns True for google-auth credentials error.
|
||||
|
||||
When a token is manually revoked on Google side, google-auth tries to
|
||||
refresh credentials but fails with this message.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
ex = Exception("credentials do not contain the necessary fields")
|
||||
assert GSheetsEngineSpec.needs_oauth2(ex) is True
|
||||
|
||||
|
||||
def test_needs_oauth2_with_other_error(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that needs_oauth2 returns False for other errors.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
ex = Exception("Some other error")
|
||||
assert GSheetsEngineSpec.needs_oauth2(ex) is False
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_success(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_oauth2_fresh_token returns token on success.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
result = GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token")
|
||||
assert result == {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_invalid_grant(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_oauth2_fresh_token raises UnauthenticatedError for invalid_grant.
|
||||
|
||||
When a token is revoked on Google side, the refresh request returns 400
|
||||
with error=invalid_grant.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.json.return_value = {
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Token has been expired or revoked.",
|
||||
}
|
||||
http_error = HTTPError()
|
||||
http_error.response = mock_response
|
||||
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().raise_for_status.side_effect = http_error
|
||||
|
||||
with pytest.raises(UnauthenticatedError):
|
||||
GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token")
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_other_http_error(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_oauth2_fresh_token re-raises non-invalid_grant HTTP errors.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.json.return_value = {"error": "server_error"}
|
||||
|
||||
http_error = HTTPError()
|
||||
http_error.response = mock_response
|
||||
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().raise_for_status.side_effect = http_error
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token")
|
||||
|
||||
|
||||
def test_get_table_names_triggers_oauth2_dance(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that get_table_names triggers OAuth2 dance when no token exists.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user.id = 1
|
||||
|
||||
get_oauth2_access_token = mocker.patch(
|
||||
"superset.db_engine_specs.gsheets.get_oauth2_access_token",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.id = 1
|
||||
database.is_oauth2_enabled.return_value = True
|
||||
database.get_oauth2_config.return_value = {"id": "client-id"}
|
||||
database.db_engine_spec = GSheetsEngineSpec
|
||||
|
||||
inspector = mocker.MagicMock()
|
||||
|
||||
GSheetsEngineSpec.get_table_names(database, inspector, None)
|
||||
|
||||
database.start_oauth2_dance.assert_called_once()
|
||||
get_oauth2_access_token.assert_called_once()
|
||||
|
||||
|
||||
def test_get_table_names_does_not_trigger_oauth2_when_token_exists(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_table_names does not trigger OAuth2 dance when token exists.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user.id = 1
|
||||
|
||||
get_oauth2_access_token = mocker.patch(
|
||||
"superset.db_engine_specs.gsheets.get_oauth2_access_token",
|
||||
return_value="valid-token",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"superset.db_engine_specs.shillelagh.ShillelaghEngineSpec.get_table_names",
|
||||
return_value={"sheet1", "sheet2"},
|
||||
)
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.id = 1
|
||||
database.is_oauth2_enabled.return_value = True
|
||||
database.get_oauth2_config.return_value = {"id": "client-id"}
|
||||
database.db_engine_spec = GSheetsEngineSpec
|
||||
|
||||
inspector = mocker.MagicMock()
|
||||
|
||||
result = GSheetsEngineSpec.get_table_names(database, inspector, None)
|
||||
|
||||
database.start_oauth2_dance.assert_not_called()
|
||||
get_oauth2_access_token.assert_called_once()
|
||||
assert result == {"sheet1", "sheet2"}
|
||||
|
||||
|
||||
def test_validate_parameters_skips_oauth2_connections_with_parameters(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that validate_parameters skips validation for OAuth2 connections.
|
||||
|
||||
When oauth2_client_info is present in parameters, the validation should
|
||||
skip URL checks since the user will authenticate via OAuth2.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import (
|
||||
GSheetsEngineSpec,
|
||||
GSheetsPropertiesType,
|
||||
)
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user.email = "admin@example.org"
|
||||
|
||||
create_engine = mocker.patch("superset.db_engine_specs.gsheets.create_engine")
|
||||
conn = create_engine.return_value.connect.return_value
|
||||
results = conn.execute.return_value
|
||||
results.fetchall.side_effect = ProgrammingError(
|
||||
"The caller does not have permission"
|
||||
)
|
||||
|
||||
properties: GSheetsPropertiesType = {
|
||||
"parameters": {
|
||||
"service_account_info": "",
|
||||
"catalog": {},
|
||||
"oauth2_client_info": {"id": "client-id", "secret": "client-secret"},
|
||||
},
|
||||
"catalog": {
|
||||
"sheet1": "https://docs.google.com/spreadsheets/d/1/edit",
|
||||
},
|
||||
}
|
||||
errors = GSheetsEngineSpec.validate_parameters(properties)
|
||||
|
||||
assert errors == []
|
||||
conn.execute.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_parameters_skips_oauth2_connections_with_masked_encrypted_extra(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test validate_parameters skips validation for OAuth2 via masked_encrypted_extra.
|
||||
|
||||
When oauth2_client_info is present in masked_encrypted_extra (used during
|
||||
create/update), the validation should skip URL checks.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import (
|
||||
GSheetsEngineSpec,
|
||||
GSheetsPropertiesType,
|
||||
)
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user.email = "admin@example.org"
|
||||
|
||||
create_engine = mocker.patch("superset.db_engine_specs.gsheets.create_engine")
|
||||
conn = create_engine.return_value.connect.return_value
|
||||
results = conn.execute.return_value
|
||||
results.fetchall.side_effect = ProgrammingError(
|
||||
"The caller does not have permission"
|
||||
)
|
||||
|
||||
properties: GSheetsPropertiesType = {
|
||||
"parameters": {
|
||||
"service_account_info": "",
|
||||
"catalog": {},
|
||||
},
|
||||
"catalog": {
|
||||
"sheet1": "https://docs.google.com/spreadsheets/d/1/edit",
|
||||
},
|
||||
"masked_encrypted_extra": json.dumps(
|
||||
{
|
||||
"oauth2_client_info": {"id": "client-id", "secret": "XXXXXXXXXX"},
|
||||
}
|
||||
),
|
||||
}
|
||||
errors = GSheetsEngineSpec.validate_parameters(properties)
|
||||
|
||||
assert errors == []
|
||||
conn.execute.assert_not_called()
|
||||
|
||||
@@ -18,11 +18,16 @@
|
||||
# pylint: disable=invalid-name, disallowed-name
|
||||
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.utils.oauth2 import get_oauth2_access_token
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token
|
||||
|
||||
DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {})
|
||||
|
||||
|
||||
def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None:
|
||||
@@ -93,3 +98,82 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None:
|
||||
|
||||
# check that token was deleted
|
||||
db.session.delete.assert_called_with(token)
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that refresh_oauth2_token deletes the token on OAuth2-specific exception.
|
||||
|
||||
When the token refresh fails with an OAuth2-specific exception (e.g., token
|
||||
was revoked), the invalid token should be deleted and the exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.oauth2_exception = OAuth2ExceptionError
|
||||
db_engine_spec.get_oauth2_fresh_token.side_effect = OAuth2ExceptionError(
|
||||
"Token revoked"
|
||||
)
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
|
||||
with pytest.raises(OAuth2ExceptionError):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
|
||||
db.session.delete.assert_called_with(token)
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_keeps_token_on_other_exception(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that refresh_oauth2_token keeps the token on non-OAuth2 exceptions.
|
||||
|
||||
When the token refresh fails with a transient error (e.g., network issue),
|
||||
the token should be kept (refresh token may still be valid) and the
|
||||
exception re-raised.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
|
||||
class OAuth2ExceptionError(Exception):
|
||||
pass
|
||||
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.oauth2_exception = OAuth2ExceptionError
|
||||
db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
|
||||
with pytest.raises(Exception, match="Network error"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
|
||||
db.session.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that refresh_oauth2_token returns None when no access_token in response.
|
||||
|
||||
This can happen when the refresh token was revoked.
|
||||
"""
|
||||
mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.KeyValueDistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
"error": "invalid_grant",
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
|
||||
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
|
||||
assert result is None
|
||||
|
||||
Reference in New Issue
Block a user