mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
feat(SIP-85): OAuth2 for databases (#27631)
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, line-too-long
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
@@ -25,10 +26,12 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
from freezegun import freeze_time
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
||||
|
||||
|
||||
def test_filter_by_uuid(
|
||||
@@ -638,3 +641,170 @@ def test_apply_dynamic_database_filter(
|
||||
|
||||
# Ensure that the filter has been called once
|
||||
assert base_filter_mock.call_count == 1
|
||||
|
||||
|
||||
def test_oauth2_happy_path(
|
||||
mocker: MockFixture,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test the OAuth2 endpoint when everything goes well.
|
||||
"""
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||
|
||||
DatabaseRestApi.datamodel.session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
db.session.add(
|
||||
Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
|
||||
get_oauth2_token.return_value = {
|
||||
"access_token": "YYY",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "ZZZ",
|
||||
}
|
||||
|
||||
state = {
|
||||
"user_id": 1,
|
||||
"database_id": 1,
|
||||
"tab_id": 42,
|
||||
}
|
||||
decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
|
||||
decode_oauth2_state.return_value = state
|
||||
|
||||
mocker.patch("superset.databases.api.render_template", return_value="OK")
|
||||
|
||||
with freeze_time("2024-01-01T00:00:00Z"):
|
||||
response = client.get(
|
||||
"/api/v1/database/oauth2/",
|
||||
query_string={
|
||||
"state": "some%2Estate",
|
||||
"code": "XXX",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
decode_oauth2_state.assert_called_with("some%2Estate")
|
||||
get_oauth2_token.assert_called_with("XXX", state)
|
||||
|
||||
token = db.session.query(DatabaseUserOAuth2Tokens).one()
|
||||
assert token.user_id == 1
|
||||
assert token.database_id == 1
|
||||
assert token.access_token == "YYY"
|
||||
assert token.access_token_expiration == datetime(2024, 1, 1, 1, 0)
|
||||
assert token.refresh_token == "ZZZ"
|
||||
|
||||
|
||||
def test_oauth2_multiple_tokens(
|
||||
mocker: MockFixture,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test the OAuth2 endpoint when a second token is added.
|
||||
"""
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||
|
||||
DatabaseRestApi.datamodel.session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
db.session.add(
|
||||
Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
|
||||
get_oauth2_token.side_effect = [
|
||||
{
|
||||
"access_token": "YYY",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "ZZZ",
|
||||
},
|
||||
{
|
||||
"access_token": "YYY2",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "ZZZ2",
|
||||
},
|
||||
]
|
||||
|
||||
state = {
|
||||
"user_id": 1,
|
||||
"database_id": 1,
|
||||
"tab_id": 42,
|
||||
}
|
||||
decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
|
||||
decode_oauth2_state.return_value = state
|
||||
|
||||
mocker.patch("superset.databases.api.render_template", return_value="OK")
|
||||
|
||||
with freeze_time("2024-01-01T00:00:00Z"):
|
||||
response = client.get(
|
||||
"/api/v1/database/oauth2/",
|
||||
query_string={
|
||||
"state": "some%2Estate",
|
||||
"code": "XXX",
|
||||
},
|
||||
)
|
||||
|
||||
# second request should delete token from the first request
|
||||
response = client.get(
|
||||
"/api/v1/database/oauth2/",
|
||||
query_string={
|
||||
"state": "some%2Estate",
|
||||
"code": "XXX",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
tokens = db.session.query(DatabaseUserOAuth2Tokens).all()
|
||||
assert len(tokens) == 1
|
||||
token = tokens[0]
|
||||
assert token.access_token == "YYY2"
|
||||
assert token.refresh_token == "ZZZ2"
|
||||
|
||||
|
||||
def test_oauth2_error(
|
||||
mocker: MockFixture,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test the OAuth2 endpoint when OAuth2 errors.
|
||||
"""
|
||||
response = client.get(
|
||||
"/api/v1/database/oauth2/",
|
||||
query_string={
|
||||
"error": "Something bad hapened",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
assert response.json == {
|
||||
"errors": [
|
||||
{
|
||||
"message": "Something went wrong while doing OAuth2",
|
||||
"error_type": "OAUTH2_REDIRECT_ERROR",
|
||||
"level": "error",
|
||||
"extra": {"error": "Something bad hapened"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user