fix: skip DB filter when doing OAuth2 (#32486)

This commit is contained in:
Beto Dealmeida
2025-03-04 13:33:53 -05:00
committed by GitHub
parent c0e92b1639
commit 813e79fa9f
2 changed files with 71 additions and 1 deletions

View File

@@ -802,6 +802,76 @@ def test_oauth2_happy_path(
assert token.refresh_token == "ZZZ" # noqa: S105
def test_oauth2_permissions(
mocker: MockerFixture,
session: Session,
client: Any,
) -> None:
"""
Test the OAuth2 endpoint works for users without DB permissions.
Anyone should be able to authenticate with OAuth2, even if they don't have
permissions to read the database (which is needed to get the OAuth2 config).
"""
from superset.databases.api import DatabaseRestApi
from superset.models.core import Database, DatabaseUserOAuth2Tokens
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
db.session.add(
Database(
database_name="my_db",
sqlalchemy_uri="sqlite://",
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
)
)
db.session.commit()
mocker.patch.object(
SqliteEngineSpec,
"get_oauth2_config",
return_value={"id": "one", "secret": "two"},
)
get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token")
get_oauth2_token.return_value = {
"access_token": "YYY",
"expires_in": 3600,
"refresh_token": "ZZZ",
}
state = {
"user_id": 1,
"database_id": 1,
"tab_id": 42,
}
decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
decode_oauth2_state.return_value = state
mocker.patch("superset.databases.api.render_template", return_value="OK")
with freeze_time("2024-01-01T00:00:00Z"):
response = client.get(
"/api/v1/database/oauth2/",
query_string={
"state": "some%2Estate",
"code": "XXX",
},
)
assert response.status_code == 200
decode_oauth2_state.assert_called_with("some%2Estate")
get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
token = db.session.query(DatabaseUserOAuth2Tokens).one()
assert token.user_id == 1
assert token.database_id == 1
assert token.access_token == "YYY" # noqa: S105
assert token.access_token_expiration == datetime(2024, 1, 1, 1, 0)
assert token.refresh_token == "ZZZ" # noqa: S105
def test_oauth2_multiple_tokens(
mocker: MockerFixture,
session: Session,