fix: more DB OAuth2 fixes (#37398)

This commit is contained in:
Vitor Avila
2026-01-30 21:11:26 -03:00
committed by GitHub
parent 1ee14c5993
commit 6043e7e7e3
6 changed files with 543 additions and 76 deletions

View File

@@ -23,18 +23,27 @@ import json # noqa: TID251
import re
from textwrap import dedent
from typing import Any
from urllib.parse import parse_qs, urlparse
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import types
from sqlalchemy import Boolean, Column, Integer, types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.sql import sqltypes
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2RedirectError
from superset.sql.parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils.core import GenericDataType
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
ResultSetColumnType,
SQLAColumnType,
)
from superset.utils.core import FilterOperator, GenericDataType
from superset.utils.oauth2 import decode_oauth2_state
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
@@ -68,9 +77,6 @@ def test_get_text_clause_with_colon() -> None:
"""
Make sure text clauses are correctly escaped
"""
from superset.db_engine_specs.base import BaseEngineSpec
text_clause = BaseEngineSpec.get_text_clause(
"SELECT foo FROM tbl WHERE foo = '123:456')"
)
@@ -90,8 +96,6 @@ def test_validate_db_uri(mocker: MockerFixture) -> None:
{"DB_SQLA_URI_VALIDATOR": mock_validate},
)
from superset.db_engine_specs.base import BaseEngineSpec
with pytest.raises(ValueError): # noqa: PT011
BaseEngineSpec.validate_database_uri(URL.create("sqlite"))
@@ -130,8 +134,6 @@ select 'USD' as cur
],
)
def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None:
from superset.db_engine_specs.base import BaseEngineSpec
actual = BaseEngineSpec.get_cte_query(original)
assert actual == expected
@@ -197,8 +199,6 @@ def test_get_column_spec(
def test_convert_inspector_columns(
cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType]
):
from superset.db_engine_specs.base import convert_inspector_columns
assert convert_inspector_columns(cols) == expected_result
@@ -206,8 +206,6 @@ def test_select_star(mocker: MockerFixture) -> None:
"""
Test the ``select_star`` method.
"""
from superset.db_engine_specs.base import BaseEngineSpec
cols: list[ResultSetColumnType] = [
{
"column_name": "a",
@@ -249,7 +247,6 @@ def test_extra_table_metadata(mocker: MockerFixture) -> None:
"""
Test the deprecated `extra_table_metadata` method.
"""
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database
class ThirdPartyDBEngineSpec(BaseEngineSpec):
@@ -285,8 +282,6 @@ def test_get_default_catalog(mocker: MockerFixture) -> None:
"""
Test the `get_default_catalog` method.
"""
from superset.db_engine_specs.base import BaseEngineSpec
database = mocker.MagicMock()
assert BaseEngineSpec.get_default_catalog(database) is None
@@ -295,7 +290,6 @@ def test_quote_table() -> None:
"""
Test the `quote_table` function.
"""
from superset.db_engine_specs.base import BaseEngineSpec
dialect = sqlite.dialect()
@@ -318,8 +312,6 @@ def test_mask_encrypted_extra() -> None:
"""
Test that the private key is masked when the database is edited.
"""
from superset.db_engine_specs.base import BaseEngineSpec
config = json.dumps(
{
"foo": "bar",
@@ -342,8 +334,6 @@ def test_unmask_encrypted_extra() -> None:
"""
Test that the private key can be reused from the previous `encrypted_extra`.
"""
from superset.db_engine_specs.base import BaseEngineSpec
old = json.dumps(
{
"foo": "bar",
@@ -375,8 +365,6 @@ def test_impersonate_user_backwards_compatible(mocker: MockerFixture) -> None:
"""
Test that the `impersonate_user` method calls the original methods it replaced.
"""
from superset.db_engine_specs.base import BaseEngineSpec
database = mocker.MagicMock()
url = make_url("sqlite://foo.db")
new_url = make_url("sqlite://bar.db")
@@ -417,8 +405,6 @@ def test_impersonate_user_no_database(mocker: MockerFixture) -> None:
"""
Test `impersonate_user` when `update_impersonation_config` has an old signature.
"""
from superset.db_engine_specs.base import BaseEngineSpec
database = mocker.MagicMock()
url = make_url("sqlite://foo.db")
new_url = make_url("sqlite://bar.db")
@@ -457,10 +443,6 @@ def test_handle_boolean_filter_default_behavior() -> None:
"""
Test that BaseEngineSpec uses IS operators for boolean filters by default.
"""
from sqlalchemy import Boolean, Column
from superset.db_engine_specs.base import BaseEngineSpec
# Create a mock SQLAlchemy column
bool_col = Column("test_col", Boolean)
@@ -479,9 +461,6 @@ def test_handle_boolean_filter_with_equality() -> None:
"""
Test that BaseEngineSpec can use equality operators when configured.
"""
from sqlalchemy import Boolean, Column
from superset.db_engine_specs.base import BaseEngineSpec
# Create a test engine spec that uses equality
class TestEngineSpec(BaseEngineSpec):
@@ -502,15 +481,9 @@ def test_handle_null_filter() -> None:
"""
Test null/not null filter handling.
"""
from sqlalchemy import Boolean, Column
from superset.db_engine_specs.base import BaseEngineSpec
bool_col = Column("test_col", Boolean)
# Test IS_NULL - use actual FilterOperator values
from superset.utils.core import FilterOperator
result_null = BaseEngineSpec.handle_null_filter(bool_col, FilterOperator.IS_NULL)
assert hasattr(result_null, "left")
assert hasattr(result_null, "right")
@@ -531,15 +504,9 @@ def test_handle_comparison_filter() -> None:
"""
Test comparison filter handling for all operators.
"""
from sqlalchemy import Column, Integer
from superset.db_engine_specs.base import BaseEngineSpec
int_col = Column("test_col", Integer)
# Test all comparison operators - use actual FilterOperator values
from superset.utils.core import FilterOperator
operators_and_values = [
(FilterOperator.EQUALS, 5),
(FilterOperator.NOT_EQUALS, 5),
@@ -563,8 +530,6 @@ def test_use_equality_for_boolean_filters_property() -> None:
"""
Test that BaseEngineSpec has the correct default value for boolean filter property.
"""
from superset.db_engine_specs.base import BaseEngineSpec
# Default should be False (use IS operators)
assert BaseEngineSpec.use_equality_for_boolean_filters is False
@@ -573,9 +538,6 @@ def test_extract_errors(mocker: MockerFixture) -> None:
"""
Test that error is extracted correctly when no custom error message is provided.
"""
from superset.db_engine_specs.base import BaseEngineSpec
mocker.patch(
"flask.current_app.config",
{},
@@ -597,8 +559,6 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None:
using database_name.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -632,8 +592,6 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non
Test that custom error messages are only applied to the specified database_name.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -669,8 +627,6 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None:
and show_issue_info are extracted correctly from config.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -740,7 +696,6 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture):
Test that extract_errors doesn't fail when custom database errors
are in wrong format.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -765,7 +720,6 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture
Test that extract_errors doesn't fail when database-specific custom errors
are in wrong format.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -790,7 +744,6 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture):
Test that when the custom error message is empty,
the original error message is preserved.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -824,7 +777,6 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -
"""
Test that custom error messages are matched by database_name.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -866,7 +818,6 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None:
"""
Test that when database_name has no match, the original error message is preserved.
"""
from superset.db_engine_specs.base import BaseEngineSpec
class TestEngineSpec(BaseEngineSpec):
engine_name = "ExampleEngine"
@@ -901,12 +852,6 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
Test that BaseEngineSpec.get_oauth2_authorization_uri uses standard OAuth 2.0
parameters only and does not include provider-specific params like prompt=consent.
"""
from urllib.parse import parse_qs, urlparse
from superset.db_engine_specs.base import BaseEngineSpec
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
from superset.utils.oauth2 import decode_oauth2_state
config: OAuth2ClientConfig = {
"id": "client-id",
"secret": "client-secret",
@@ -943,3 +888,81 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
assert "prompt" not in query
assert "access_type" not in query
assert "include_granted_scopes" not in query
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
"""
Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set.
"""
custom_redirect_uri = "https://proxy.example.com/oauth2/"
mocker.patch(
"flask.current_app.config",
{
"DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri,
"SECRET_KEY": "test-secret-key",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
g = mocker.patch("superset.db_engine_specs.base.g")
g.user.id = 1
database = mocker.MagicMock()
database.id = 1
database.get_oauth2_config.return_value = {
"id": "client-id",
"secret": "client-secret",
"scope": "read write",
"redirect_uri": "https://another-link.com",
"authorization_request_uri": "https://oauth.example.com/authorize",
"token_request_uri": "https://oauth.example.com/token",
}
with pytest.raises(OAuth2RedirectError) as exc_info:
BaseEngineSpec.start_oauth2_dance(database)
error = exc_info.value.error
assert error.extra["redirect_uri"] == custom_redirect_uri
def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None:
"""
Test that start_oauth2_dance falls back to url_for when no config is set.
"""
fallback_uri = "http://localhost:8088/api/v1/database/oauth2/"
mocker.patch(
"flask.current_app.config",
{
"SECRET_KEY": "test-secret-key",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
mocker.patch(
"superset.db_engine_specs.base.url_for",
return_value=fallback_uri,
)
g = mocker.patch("superset.db_engine_specs.base.g")
g.user.id = 1
database = mocker.MagicMock()
database.id = 1
database.get_oauth2_config.return_value = {
"id": "client-id",
"secret": "client-secret",
"scope": "read write",
"redirect_uri": "https://another-link.com",
"authorization_request_uri": "https://oauth.example.com/authorize",
"token_request_uri": "https://oauth.example.com/token",
"request_content_type": "json",
}
with pytest.raises(OAuth2RedirectError) as exc_info:
BaseEngineSpec.start_oauth2_dance(database)
error = exc_info.value.error
assert error.extra["redirect_uri"] == fallback_uri