feat: OAuth2 client initial work (#29109)

This commit is contained in:
Beto Dealmeida
2024-06-09 22:11:18 -04:00
committed by GitHub
parent fc9bc175e6
commit 5660f8e554
6 changed files with 210 additions and 27 deletions

View File

@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=import-outside-toplevel
from datetime import datetime
import pytest
@@ -24,11 +25,23 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils import json
from tests.unit_tests.conftest import with_feature_flags
# sample config for OAuth2 tests
oauth2_client_info = {
"oauth2_client_info": {
"id": "my_client_id",
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:SYSADMIN",
}
}
def test_get_metrics(mocker: MockFixture) -> None:
"""
@@ -378,3 +391,73 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockFixture) -> None:
make_url("trino:///"),
connect_args={"user": "alice.doe", "source": "Apache Superset"},
)
def test_is_oauth2_enabled() -> None:
"""
Test the `is_oauth2_enabled` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)
assert not database.is_oauth2_enabled()
database.encrypted_extra = json.dumps(oauth2_client_info)
assert database.is_oauth2_enabled()
def test_get_oauth2_config(app_context: None) -> None:
"""
Test the `get_oauth2_config` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)
assert database.get_oauth2_config() is None
database.encrypted_extra = json.dumps(oauth2_client_info)
assert database.get_oauth2_config() == {
"id": "my_client_id",
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:SYSADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
}
def test_raw_connection_oauth(mocker: MockFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
Some databases that use OAuth2 need to trigger the flow when the connection is
created, rather than when the query runs. This happens when the SQLAlchemy engine
URI cannot be built without the user personal token.
This test verifies that the exception is captured and raised correctly so that the
frontend can trigger the OAuth2 dance.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42
database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
"OAuth2 required"
)
with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."