feat: purge OAuth2 tokens when DB changes (#31164)

This commit is contained in:
Beto Dealmeida
2024-11-26 15:57:01 -05:00
committed by GitHub
parent f077323e6f
commit 68499a1199
5 changed files with 186 additions and 4 deletions

View File

@@ -23,6 +23,7 @@ import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
@@ -603,3 +604,82 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
source=None,
sqlalchemy_uri="trino://",
)
def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.
"""
from flask_appbuilder.security.sqla.models import Role, User # noqa: F401
from superset.models.core import Database, DatabaseUserOAuth2Tokens
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
user = User(
first_name="Alice",
last_name="Doe",
email="adoe@example.org",
username="adoe",
)
session.add(user)
session.flush()
database1 = Database(database_name="my_oauth2_db", sqlalchemy_uri="sqlite://")
database2 = Database(database_name="my_other_oauth2_db", sqlalchemy_uri="sqlite://")
session.add_all([database1, database2])
session.flush()
tokens = [
DatabaseUserOAuth2Tokens(
user_id=user.id,
database_id=database1.id,
access_token="my_access_token",
access_token_expiration=datetime(2023, 1, 1),
refresh_token="my_refresh_token",
),
DatabaseUserOAuth2Tokens(
user_id=user.id,
database_id=database2.id,
access_token="my_other_access_token",
access_token_expiration=datetime(2024, 1, 1),
refresh_token="my_other_refresh_token",
),
]
session.add_all(tokens)
session.flush()
assert len(session.query(DatabaseUserOAuth2Tokens).all()) == 2
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database1.id)
.one()
)
assert token.user_id == user.id
assert token.database_id == database1.id
assert token.access_token == "my_access_token"
assert token.access_token_expiration == datetime(2023, 1, 1)
assert token.refresh_token == "my_refresh_token"
database1.purge_oauth2_tokens()
# confirm token was deleted
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database1.id)
.one_or_none()
)
assert token is None
# make sure other DB tokens weren't deleted
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database2.id)
.one()
)
assert token is not None
# make sure database was not deleted... just in case
database = session.query(Database).filter_by(id=database1.id).one()
assert database.name == "my_oauth2_db"