feat(SIP-85): OAuth2 for databases (#27631)

This commit is contained in:
Beto Dealmeida
2024-04-02 22:05:33 -04:00
committed by GitHub
parent fdc2dbe7db
commit 9022f5c519
46 changed files with 2080 additions and 44 deletions

View File

@@ -65,8 +65,9 @@ def test_execute_connection_error() -> None:
cursor.execute.side_effect = NewConnectionError(
HTTPConnection("localhost"), "Exception with sensitive data"
)
with pytest.raises(SupersetDBAPIDatabaseError) as ex:
ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1")
with pytest.raises(SupersetDBAPIDatabaseError) as excinfo:
ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1", 1)
assert str(excinfo.value) == "Connection failed"
@pytest.mark.parametrize(

View File

@@ -66,8 +66,9 @@ def test_execute_connection_error() -> None:
cursor.execute.side_effect = NewConnectionError(
HTTPConnection("Dummypool"), "Exception with sensitive data"
)
with pytest.raises(SupersetDBAPIDatabaseError) as ex:
DatabendEngineSpec.execute(cursor, "SELECT col1 from table1")
with pytest.raises(SupersetDBAPIDatabaseError) as excinfo:
DatabendEngineSpec.execute(cursor, "SELECT col1 from table1", 1)
assert str(excinfo.value) == "Connection failed"
@pytest.mark.parametrize(

View File

@@ -38,7 +38,7 @@ def test_odbc_impersonation() -> None:
url = URL.create("drill+odbc")
username = "DoAsUser"
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
assert url.query["DelegationUID"] == username
@@ -54,7 +54,7 @@ def test_jdbc_impersonation() -> None:
url = URL.create("drill+jdbc")
username = "DoAsUser"
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
assert url.query["impersonation_target"] == username
@@ -70,7 +70,7 @@ def test_sadrill_impersonation() -> None:
url = URL.create("drill+sadrill")
username = "DoAsUser"
url = DrillEngineSpec.get_url_for_impersonation(url, True, username)
url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
assert url.query["impersonation_target"] == username
@@ -90,7 +90,7 @@ def test_invalid_impersonation() -> None:
username = "DoAsUser"
with pytest.raises(SupersetDBAPIProgrammingError):
DrillEngineSpec.get_url_for_impersonation(url, True, username)
DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
@pytest.mark.parametrize(

View File

@@ -101,6 +101,8 @@ def test_opendistro_strip_comments() -> None:
mock_cursor.execute.return_value = []
OpenDistroEngineSpec.execute(
mock_cursor, "-- some comment \nSELECT 1\n --other comment"
mock_cursor,
"-- some comment \nSELECT 1\n --other comment",
1,
)
mock_cursor.execute.assert_called_once_with("SELECT 1\n")

View File

@@ -18,14 +18,21 @@
# pylint: disable=import-outside-toplevel, invalid-name, line-too-long
import json
from typing import TYPE_CHECKING
from urllib.parse import parse_qs, urlparse
import pandas as pd
import pytest
from pytest_mock import MockFixture
from sqlalchemy.engine.url import make_url
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.sql_parse import Table
from superset.utils.oauth2 import decode_oauth2_state
if TYPE_CHECKING:
from superset.db_engine_specs.base import OAuth2State
class ProgrammingError(Exception):
@@ -399,3 +406,223 @@ def test_upload_existing(mocker: MockFixture) -> None:
mocker.call().json(),
]
)
def test_get_url_for_impersonation_username(mocker: MockFixture) -> None:
"""
Test passing a username to `get_url_for_impersonation`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
user = mocker.MagicMock()
user.email = "alice@example.org"
mocker.patch(
"superset.db_engine_specs.gsheets.security_manager.find_user",
return_value=user,
)
assert GSheetsEngineSpec.get_url_for_impersonation(
url=make_url("gsheets://"),
impersonate_user=True,
username="alice",
access_token=None,
) == make_url("gsheets://?subject=alice%40example.org")
def test_get_url_for_impersonation_access_token() -> None:
"""
Test passing an access token to `get_url_for_impersonation`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
assert GSheetsEngineSpec.get_url_for_impersonation(
url=make_url("gsheets://"),
impersonate_user=True,
username=None,
access_token="access-token",
) == make_url("gsheets://?access_token=access-token")
def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={"DATABASE_OAUTH2_CREDENTIALS": {}},
)
assert GSheetsEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config(mocker: MockFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
}
},
)
assert GSheetsEngineSpec.is_oauth2_enabled() is True
def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None:
"""
Test `get_oauth2_authorization_uri`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
},
"SECRET_KEY": "not-a-secret",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"tab_id": "1234",
}
url = GSheetsEngineSpec.get_oauth2_authorization_uri(state)
parsed = urlparse(url)
assert parsed.netloc == "accounts.google.com"
assert parsed.path == "/o/oauth2/v2/auth"
query = parse_qs(parsed.query)
assert query["scope"][0] == (
"https://www.googleapis.com/auth/drive.readonly "
"https://www.googleapis.com/auth/spreadsheets "
"https://spreadsheets.google.com/feeds"
)
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_token(mocker: MockFixture) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
http = mocker.patch("superset.db_engine_specs.gsheets.http")
http.request().data.decode.return_value = json.dumps(
{
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
)
mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
},
"SECRET_KEY": "not-a-secret",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"tab_id": "1234",
}
assert GSheetsEngineSpec.get_oauth2_token("code", state) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
http.request.assert_called_with(
"POST",
"https://oauth2.googleapis.com/token",
fields={
"code": "code",
"client_id": "XXX.apps.googleusercontent.com",
"client_secret": "GOCSPX-YYY",
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"grant_type": "authorization_code",
},
)
def test_get_oauth2_fresh_token(mocker: MockFixture) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
http = mocker.patch("superset.db_engine_specs.gsheets.http")
http.request().data.decode.return_value = json.dumps(
{
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
)
mocker.patch(
"superset.db_engine_specs.gsheets.current_app.config",
new={
"DATABASE_OAUTH2_CREDENTIALS": {
"Google Sheets": {
"CLIENT_ID": "XXX.apps.googleusercontent.com",
"CLIENT_SECRET": "GOCSPX-YYY",
},
},
"SECRET_KEY": "not-a-secret",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
assert GSheetsEngineSpec.get_oauth2_fresh_token("refresh-token") == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
http.request.assert_called_with(
"POST",
"https://oauth2.googleapis.com/token",
fields={
"client_id": "XXX.apps.googleusercontent.com",
"client_secret": "GOCSPX-YYY",
"refresh_token": "refresh-token",
"grant_type": "refresh_token",
},
)