feat: OAuth2StoreTokenCommand (#32546)

This commit is contained in:
Beto Dealmeida
2025-03-13 09:45:24 -04:00
committed by GitHub
parent 29b4c40e43
commit 8695239372
5 changed files with 290 additions and 61 deletions

View File

@@ -19,7 +19,7 @@
from __future__ import annotations
import logging
from datetime import datetime, timedelta
from datetime import datetime
from io import BytesIO
from typing import Any, cast
from zipfile import is_zipfile, ZipFile
@@ -46,6 +46,7 @@ from superset.commands.database.exceptions import (
)
from superset.commands.database.export import ExportDatabasesCommand
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
from superset.commands.database.oauth2 import OAuth2StoreTokenCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
@@ -72,7 +73,7 @@ from superset.commands.importers.exceptions import (
)
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
from superset.daos.database import DatabaseDAO
from superset.databases.decorators import check_table_access
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
from superset.databases.schemas import (
@@ -109,7 +110,6 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
DatabaseNotFoundException,
InvalidPayloadSchemaError,
OAuth2Error,
OAuth2RedirectError,
SupersetErrorsException,
SupersetException,
@@ -1433,51 +1433,15 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500'
"""
parameters = OAuth2ProviderResponseSchema().load(request.args)
command = OAuth2StoreTokenCommand(parameters)
command.run()
if "error" in parameters:
raise OAuth2Error(parameters["error"])
# note that when decoding the state we will perform JWT validation, preventing a
# malicious payload that would insert a bogus database token, or delete an
# existing one.
state = decode_oauth2_state(parameters["state"])
tab_id = state["tab_id"]
# exchange code for access/refresh tokens
database = DatabaseDAO.find_by_id(state["database_id"], skip_base_filter=True)
if database is None:
return self.response_404()
oauth2_config = database.get_oauth2_config()
if oauth2_config is None:
raise OAuth2Error("No configuration found for OAuth2")
token_response = database.db_engine_spec.get_oauth2_token(
oauth2_config,
parameters["code"],
)
# delete old tokens
existing = DatabaseUserOAuth2TokensDAO.find_one_or_none(
user_id=state["user_id"],
database_id=state["database_id"],
)
if existing:
DatabaseUserOAuth2TokensDAO.delete([existing])
# store tokens
expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
DatabaseUserOAuth2TokensDAO.create(
attributes={
"user_id": state["user_id"],
"database_id": state["database_id"],
"access_token": token_response["access_token"],
"access_token_expiration": expiration,
"refresh_token": token_response.get("refresh_token"),
},
)
# return blank page that closes itself
return make_response(
render_template("superset/oauth2.html", tab_id=state["tab_id"]),
render_template("superset/oauth2.html", tab_id=tab_id),
200,
)