feat(SIP-85): OAuth2 for databases (#27631)

This commit is contained in:
Beto Dealmeida
2024-04-02 22:05:33 -04:00
committed by GitHub
parent fdc2dbe7db
commit 9022f5c519
46 changed files with 2080 additions and 44 deletions

View File

@@ -18,6 +18,7 @@
# pylint: disable=unused-argument, import-outside-toplevel, line-too-long
import json
from datetime import datetime
from io import BytesIO
from typing import Any
from unittest.mock import Mock
@@ -25,10 +26,12 @@ from uuid import UUID
import pytest
from flask import current_app
from freezegun import freeze_time
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.db_engine_specs.sqlite import SqliteEngineSpec
def test_filter_by_uuid(
@@ -638,3 +641,170 @@ def test_apply_dynamic_database_filter(
# Ensure that the filter has been called once
assert base_filter_mock.call_count == 1
def test_oauth2_happy_path(
mocker: MockFixture,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test the OAuth2 endpoint when everything goes well.
"""
from superset.databases.api import DatabaseRestApi
from superset.models.core import Database, DatabaseUserOAuth2Tokens
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
db.session.add(
Database(
database_name="my_db",
sqlalchemy_uri="sqlite://",
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
)
)
db.session.commit()
get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
get_oauth2_token.return_value = {
"access_token": "YYY",
"expires_in": 3600,
"refresh_token": "ZZZ",
}
state = {
"user_id": 1,
"database_id": 1,
"tab_id": 42,
}
decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
decode_oauth2_state.return_value = state
mocker.patch("superset.databases.api.render_template", return_value="OK")
with freeze_time("2024-01-01T00:00:00Z"):
response = client.get(
"/api/v1/database/oauth2/",
query_string={
"state": "some%2Estate",
"code": "XXX",
},
)
assert response.status_code == 200
decode_oauth2_state.assert_called_with("some%2Estate")
get_oauth2_token.assert_called_with("XXX", state)
token = db.session.query(DatabaseUserOAuth2Tokens).one()
assert token.user_id == 1
assert token.database_id == 1
assert token.access_token == "YYY"
assert token.access_token_expiration == datetime(2024, 1, 1, 1, 0)
assert token.refresh_token == "ZZZ"
def test_oauth2_multiple_tokens(
mocker: MockFixture,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test the OAuth2 endpoint when a second token is added.
"""
from superset.databases.api import DatabaseRestApi
from superset.models.core import Database, DatabaseUserOAuth2Tokens
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
db.session.add(
Database(
database_name="my_db",
sqlalchemy_uri="sqlite://",
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
)
)
db.session.commit()
get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
get_oauth2_token.side_effect = [
{
"access_token": "YYY",
"expires_in": 3600,
"refresh_token": "ZZZ",
},
{
"access_token": "YYY2",
"expires_in": 3600,
"refresh_token": "ZZZ2",
},
]
state = {
"user_id": 1,
"database_id": 1,
"tab_id": 42,
}
decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
decode_oauth2_state.return_value = state
mocker.patch("superset.databases.api.render_template", return_value="OK")
with freeze_time("2024-01-01T00:00:00Z"):
response = client.get(
"/api/v1/database/oauth2/",
query_string={
"state": "some%2Estate",
"code": "XXX",
},
)
# second request should delete token from the first request
response = client.get(
"/api/v1/database/oauth2/",
query_string={
"state": "some%2Estate",
"code": "XXX",
},
)
assert response.status_code == 200
tokens = db.session.query(DatabaseUserOAuth2Tokens).all()
assert len(tokens) == 1
token = tokens[0]
assert token.access_token == "YYY2"
assert token.refresh_token == "ZZZ2"
def test_oauth2_error(
mocker: MockFixture,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test the OAuth2 endpoint when OAuth2 errors.
"""
response = client.get(
"/api/v1/database/oauth2/",
query_string={
"error": "Something bad hapened",
},
)
assert response.status_code == 500
assert response.json == {
"errors": [
{
"message": "Something went wrong while doing OAuth2",
"error_type": "OAUTH2_REDIRECT_ERROR",
"level": "error",
"extra": {"error": "Something bad hapened"},
}
]
}