mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat: purge OAuth2 tokens when DB changes (#31164)
This commit is contained in:
@@ -23,6 +23,7 @@ import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.errors import SupersetErrorType
|
||||
@@ -603,3 +604,82 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
|
||||
source=None,
|
||||
sqlalchemy_uri="trino://",
|
||||
)
|
||||
|
||||
|
||||
def test_purge_oauth2_tokens(session: Session) -> None:
|
||||
"""
|
||||
Test the `purge_oauth2_tokens` method.
|
||||
"""
|
||||
from flask_appbuilder.security.sqla.models import Role, User # noqa: F401
|
||||
|
||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
user = User(
|
||||
first_name="Alice",
|
||||
last_name="Doe",
|
||||
email="adoe@example.org",
|
||||
username="adoe",
|
||||
)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
database1 = Database(database_name="my_oauth2_db", sqlalchemy_uri="sqlite://")
|
||||
database2 = Database(database_name="my_other_oauth2_db", sqlalchemy_uri="sqlite://")
|
||||
session.add_all([database1, database2])
|
||||
session.flush()
|
||||
|
||||
tokens = [
|
||||
DatabaseUserOAuth2Tokens(
|
||||
user_id=user.id,
|
||||
database_id=database1.id,
|
||||
access_token="my_access_token",
|
||||
access_token_expiration=datetime(2023, 1, 1),
|
||||
refresh_token="my_refresh_token",
|
||||
),
|
||||
DatabaseUserOAuth2Tokens(
|
||||
user_id=user.id,
|
||||
database_id=database2.id,
|
||||
access_token="my_other_access_token",
|
||||
access_token_expiration=datetime(2024, 1, 1),
|
||||
refresh_token="my_other_refresh_token",
|
||||
),
|
||||
]
|
||||
session.add_all(tokens)
|
||||
session.flush()
|
||||
|
||||
assert len(session.query(DatabaseUserOAuth2Tokens).all()) == 2
|
||||
|
||||
token = (
|
||||
session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(database_id=database1.id)
|
||||
.one()
|
||||
)
|
||||
assert token.user_id == user.id
|
||||
assert token.database_id == database1.id
|
||||
assert token.access_token == "my_access_token"
|
||||
assert token.access_token_expiration == datetime(2023, 1, 1)
|
||||
assert token.refresh_token == "my_refresh_token"
|
||||
|
||||
database1.purge_oauth2_tokens()
|
||||
|
||||
# confirm token was deleted
|
||||
token = (
|
||||
session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(database_id=database1.id)
|
||||
.one_or_none()
|
||||
)
|
||||
assert token is None
|
||||
|
||||
# make sure other DB tokens weren't deleted
|
||||
token = (
|
||||
session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(database_id=database2.id)
|
||||
.one()
|
||||
)
|
||||
assert token is not None
|
||||
|
||||
# make sure database was not deleted... just in case
|
||||
database = session.query(Database).filter_by(id=database1.id).one()
|
||||
assert database.name == "my_oauth2_db"
|
||||
|
||||
Reference in New Issue
Block a user