mirror of
https://github.com/apache/superset.git
synced 2026-06-10 10:09:14 +00:00
Compare commits
1 Commits
ci/cypress
...
oauth-duri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8af308afb |
@@ -622,6 +622,15 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||
const [editNewDb, setEditNewDb] = useState<boolean>(false);
|
||||
const [isLoading, setLoading] = useState<boolean>(false);
|
||||
const [testInProgress, setTestInProgress] = useState<boolean>(false);
|
||||
// Stable id sent on every "Test Connection" call so that the backend can
|
||||
// correlate the wizard's OAuth2 dance across requests. Generated once per
|
||||
// modal session — keeping the same value lets the second test_connection
|
||||
// call (after the user finishes the OAuth2 dance) find the cached token.
|
||||
const oauth2TabIdRef = useRef<string>(
|
||||
typeof crypto !== 'undefined' && crypto.randomUUID
|
||||
? crypto.randomUUID()
|
||||
: Math.random().toString(36).slice(2),
|
||||
);
|
||||
const [passwords, setPasswords] = useState<Record<string, string>>({});
|
||||
const [sshTunnelPasswords, setSSHTunnelPasswords] = useState<
|
||||
Record<string, string>
|
||||
@@ -727,6 +736,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||
}, [setValidationErrors, setHasValidated, clearError]);
|
||||
|
||||
// Test Connection logic
|
||||
// keep the latest ``testConnection`` callable accessible to long-lived
|
||||
// listeners (BroadcastChannel/storage) without making them depend on every
|
||||
// render of the modal.
|
||||
const testConnectionRef = useRef<() => void>(() => {});
|
||||
|
||||
const testConnection = () => {
|
||||
handleClearValidationErrors();
|
||||
if (!db?.sqlalchemy_uri) {
|
||||
@@ -748,6 +762,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||
server_port: Number(db.ssh_tunnel!.server_port),
|
||||
}
|
||||
: undefined,
|
||||
oauth2_tab_id: oauth2TabIdRef.current,
|
||||
};
|
||||
setTestInProgress(true);
|
||||
testDatabaseConnection(
|
||||
@@ -764,6 +779,47 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||
},
|
||||
);
|
||||
};
|
||||
testConnectionRef.current = testConnection;
|
||||
|
||||
// Re-run "Test Connection" automatically when the OAuth2 dance completes
|
||||
// in the popup tab. The dance posts a message on the ``oauth`` broadcast
|
||||
// channel (and a localStorage event for cross-context delivery) carrying
|
||||
// the wizard's own tab_id; we react only to our own.
|
||||
useEffect(() => {
|
||||
const tabId = oauth2TabIdRef.current;
|
||||
|
||||
const handleComplete = (incomingTabId?: string) => {
|
||||
if (incomingTabId === tabId) {
|
||||
testConnectionRef.current();
|
||||
}
|
||||
};
|
||||
|
||||
const channel =
|
||||
typeof BroadcastChannel !== 'undefined'
|
||||
? new BroadcastChannel('oauth')
|
||||
: null;
|
||||
if (channel) {
|
||||
channel.onmessage = event => handleComplete(event.data?.tabId);
|
||||
}
|
||||
|
||||
const handleStorage = (event: StorageEvent) => {
|
||||
if (event.key !== 'oauth2_auth_complete' || !event.newValue) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const payload = JSON.parse(event.newValue) as { tabId?: string };
|
||||
handleComplete(payload.tabId);
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
};
|
||||
window.addEventListener('storage', handleStorage);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('storage', handleStorage);
|
||||
channel?.close();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const getPlaceholder = (field: string) => {
|
||||
if (field === 'database') {
|
||||
|
||||
@@ -25,66 +25,157 @@ from superset.commands.database.exceptions import DatabaseNotFoundError
|
||||
from superset.daos.database import DatabaseUserOAuth2TokensDAO
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.exceptions import OAuth2Error
|
||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||
from superset.superset_typing import OAuth2State
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
OAuth2State,
|
||||
OAuth2TokenResponse,
|
||||
)
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
|
||||
# how long the pre-create token cache lives in the KV store
|
||||
PRE_CREATE_TOKEN_TTL = timedelta(minutes=5)
|
||||
|
||||
|
||||
class OAuth2StoreTokenCommand(BaseCommand):
|
||||
"""
|
||||
Command to store OAuth2 tokens in the database.
|
||||
Command to store OAuth2 tokens.
|
||||
|
||||
Normal flow: the OAuth2 callback resolves the database via ``state.database_id``
|
||||
and persists access/refresh tokens to ``database_user_oauth2_tokens``.
|
||||
|
||||
Pre-create flow: when ``state.database_id`` is ``None`` (the database hasn't
|
||||
been saved yet — typically the "Create database" wizard), the command reads
|
||||
the OAuth2 client config and engine name from the KV store entry that
|
||||
:meth:`BaseEngineSpec.start_oauth2_dance` stashed there, exchanges the code,
|
||||
and caches the resulting access token in the same KV entry for
|
||||
:func:`get_oauth2_access_token` to pick up on the retry.
|
||||
"""
|
||||
|
||||
def __init__(self, parameters: OAuth2ProviderResponseSchema):
|
||||
self._parameters = parameters
|
||||
self._state: OAuth2State | None = None
|
||||
self._database: Database | None = None
|
||||
self._oauth2_config: OAuth2ClientConfig | None = None
|
||||
self._engine_spec: type[BaseEngineSpec] | None = None
|
||||
self._tab_uuid: UUID | None = None
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=OAuth2Error))
|
||||
def run(self) -> DatabaseUserOAuth2Tokens:
|
||||
def run(self) -> DatabaseUserOAuth2Tokens | None:
|
||||
self.validate()
|
||||
self._database = cast(Database, self._database)
|
||||
self._state = cast(OAuth2State, self._state)
|
||||
|
||||
oauth2_config = self._database.get_oauth2_config()
|
||||
if oauth2_config is None:
|
||||
raise OAuth2Error("No configuration found for OAuth2")
|
||||
self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config)
|
||||
self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec)
|
||||
|
||||
# Look up PKCE code_verifier from KV store (RFC 7636)
|
||||
code_verifier = None
|
||||
tab_id = self._state["tab_id"]
|
||||
try:
|
||||
tab_uuid = UUID(tab_id)
|
||||
except ValueError:
|
||||
tab_uuid = None
|
||||
code_verifier = self._pop_code_verifier()
|
||||
|
||||
if tab_uuid:
|
||||
kv_value = KeyValueDAO.get_value(
|
||||
resource=KeyValueResource.PKCE_CODE_VERIFIER,
|
||||
key=tab_uuid,
|
||||
codec=JsonKeyValueCodec(),
|
||||
)
|
||||
if kv_value:
|
||||
code_verifier = kv_value.get("code_verifier")
|
||||
KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, tab_uuid)
|
||||
|
||||
token_response = self._database.db_engine_spec.get_oauth2_token(
|
||||
oauth2_config,
|
||||
token_response = self._engine_spec.get_oauth2_token(
|
||||
self._oauth2_config,
|
||||
self._parameters["code"],
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
# delete old tokens
|
||||
if self._database is None:
|
||||
# Pre-create flow: cache the access token in the KV entry the
|
||||
# initial dance created. The retry of "Test Connection" will read
|
||||
# it via ``get_oauth2_access_token``.
|
||||
self._cache_pre_create_token(token_response)
|
||||
return None
|
||||
|
||||
return self._persist_token(token_response)
|
||||
|
||||
def validate(self) -> None:
|
||||
if error := self._parameters.get("error"):
|
||||
raise OAuth2Error(error)
|
||||
|
||||
self._state = decode_oauth2_state(self._parameters["state"])
|
||||
|
||||
try:
|
||||
self._tab_uuid = UUID(self._state["tab_id"])
|
||||
except (KeyError, ValueError):
|
||||
# Legacy paths may use non-UUID tab ids; we still want to support
|
||||
# them when ``database_id`` is set. The pre-create path below
|
||||
# requires a valid UUID.
|
||||
self._tab_uuid = None
|
||||
|
||||
if database_id := self._state.get("database_id"):
|
||||
self._database = DatabaseUserOAuth2TokensDAO.get_database(database_id)
|
||||
if self._database is None:
|
||||
raise DatabaseNotFoundError("Database not found")
|
||||
self._oauth2_config = self._database.get_oauth2_config()
|
||||
self._engine_spec = self._database.db_engine_spec
|
||||
else:
|
||||
if self._tab_uuid is None:
|
||||
raise OAuth2Error(
|
||||
"Pre-create OAuth2 callback requires a UUID tab_id",
|
||||
)
|
||||
cached = KeyValueDAO.get_value(
|
||||
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||
key=self._tab_uuid,
|
||||
codec=JsonKeyValueCodec(),
|
||||
)
|
||||
if not cached or not cached.get("config"):
|
||||
raise OAuth2Error("Pre-create OAuth2 context not found or expired")
|
||||
self._oauth2_config = cast(OAuth2ClientConfig, cached["config"])
|
||||
engine = self._state.get("engine") or cached.get("engine")
|
||||
if not engine:
|
||||
raise OAuth2Error("Pre-create OAuth2 context missing engine name")
|
||||
self._engine_spec = get_engine_spec(engine)
|
||||
|
||||
if self._oauth2_config is None:
|
||||
raise OAuth2Error("No configuration found for OAuth2")
|
||||
|
||||
def _pop_code_verifier(self) -> str | None:
|
||||
if self._tab_uuid is None:
|
||||
return None
|
||||
kv_value = KeyValueDAO.get_value(
|
||||
resource=KeyValueResource.PKCE_CODE_VERIFIER,
|
||||
key=self._tab_uuid,
|
||||
codec=JsonKeyValueCodec(),
|
||||
)
|
||||
if not kv_value:
|
||||
return None
|
||||
KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, self._tab_uuid)
|
||||
return kv_value.get("code_verifier")
|
||||
|
||||
def _cache_pre_create_token(self, token_response: OAuth2TokenResponse) -> None:
|
||||
self._state = cast(OAuth2State, self._state)
|
||||
self._tab_uuid = cast(UUID, self._tab_uuid)
|
||||
self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config)
|
||||
self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec)
|
||||
|
||||
expires_on = datetime.now() + PRE_CREATE_TOKEN_TTL
|
||||
KeyValueDAO.upsert_entry(
|
||||
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||
key=self._tab_uuid,
|
||||
value={
|
||||
"engine": self._engine_spec.engine,
|
||||
"config": self._oauth2_config,
|
||||
"user_id": self._state["user_id"],
|
||||
"access_token": token_response["access_token"],
|
||||
},
|
||||
codec=JsonKeyValueCodec(),
|
||||
expires_on=expires_on,
|
||||
)
|
||||
|
||||
def _persist_token(
|
||||
self,
|
||||
token_response: OAuth2TokenResponse,
|
||||
) -> DatabaseUserOAuth2Tokens:
|
||||
self._state = cast(OAuth2State, self._state)
|
||||
|
||||
if existing := DatabaseUserOAuth2TokensDAO.find_one_or_none(
|
||||
user_id=self._state["user_id"],
|
||||
database_id=self._state["database_id"],
|
||||
):
|
||||
DatabaseUserOAuth2TokensDAO.delete([existing])
|
||||
|
||||
# store tokens
|
||||
expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
|
||||
return DatabaseUserOAuth2TokensDAO.create(
|
||||
attributes={
|
||||
@@ -95,16 +186,3 @@ class OAuth2StoreTokenCommand(BaseCommand):
|
||||
"refresh_token": token_response.get("refresh_token"),
|
||||
},
|
||||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
if error := self._parameters.get("error"):
|
||||
raise OAuth2Error(error)
|
||||
|
||||
self._state = decode_oauth2_state(self._parameters["state"])
|
||||
|
||||
if database := DatabaseUserOAuth2TokensDAO.get_database(
|
||||
self._state["database_id"]
|
||||
):
|
||||
self._database = database
|
||||
else:
|
||||
raise DatabaseNotFoundError("Database not found")
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import g
|
||||
from flask_babel import gettext as _
|
||||
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
|
||||
|
||||
@@ -94,6 +95,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||
self.validate()
|
||||
ex_str = ""
|
||||
|
||||
# Surface the wizard's tab_id (sent by the frontend) so that
|
||||
# ``get_oauth2_access_token`` can find the pre-create OAuth2 token
|
||||
# cached in the KV store, and so that ``start_oauth2_dance`` reuses
|
||||
# this id instead of generating a new one.
|
||||
if oauth2_tab_id := self._properties.get("oauth2_tab_id"):
|
||||
g.oauth2_tab_id = oauth2_tab_id
|
||||
|
||||
url = make_url_safe(self._uri)
|
||||
engine_name = url.get_backend_name()
|
||||
|
||||
|
||||
@@ -623,6 +623,16 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema):
|
||||
)
|
||||
|
||||
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||
oauth2_tab_id = fields.String(
|
||||
metadata={
|
||||
"description": (
|
||||
"UUID identifying the wizard tab for the pre-create OAuth2 flow."
|
||||
" Optional; when supplied, the engine will look up the pre-create"
|
||||
" OAuth2 token cached under this key."
|
||||
),
|
||||
},
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
|
||||
class TableMetadataOptionsResponse(TypedDict):
|
||||
|
||||
@@ -664,7 +664,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
# Prevent circular import.
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
tab_id = str(uuid4())
|
||||
# Reuse the wizard-supplied tab_id when present so the retry of
|
||||
# "Test Connection" can find the same KV-cached token. Otherwise we
|
||||
# generate a fresh one as before.
|
||||
tab_id = getattr(g, "oauth2_tab_id", None) or str(uuid4())
|
||||
default_redirect_uri = get_oauth2_redirect_uri()
|
||||
|
||||
# Generate PKCE code verifier (RFC 7636)
|
||||
@@ -690,6 +693,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
# belongs to.
|
||||
state: OAuth2State = {
|
||||
# Database ID and user ID are the primary key associated with the token.
|
||||
# ``database_id`` is ``None`` during the "Create database" wizard — the
|
||||
# callback caches the access token in the KV store instead of inserting
|
||||
# a row in ``database_user_oauth2_tokens``.
|
||||
"database_id": database.id,
|
||||
"user_id": g.user.id,
|
||||
# In multi-instance deployments there might be a single proxy handling
|
||||
@@ -710,6 +716,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
if oauth2_config is None:
|
||||
raise OAuth2Error("No configuration found for OAuth2")
|
||||
|
||||
# Pre-create flow: the database has no row yet, so we can't look it up by id
|
||||
# in the callback. Stash the engine spec + oauth2 config in the KV store
|
||||
# alongside the existing PKCE entry, keyed by the same ``tab_id``. The
|
||||
# callback reads this to exchange the code without a persisted ``Database``.
|
||||
if database.id is None:
|
||||
state["engine"] = cls.engine
|
||||
KeyValueDAO.delete_expired_entries(KeyValueResource.OAUTH2_PRE_CREATE_TOKEN)
|
||||
KeyValueDAO.create_entry(
|
||||
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||
value={
|
||||
"engine": cls.engine,
|
||||
"config": oauth2_config,
|
||||
},
|
||||
codec=JsonKeyValueCodec(),
|
||||
key=UUID(tab_id),
|
||||
expires_on=datetime.now() + timedelta(minutes=5),
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
oauth_url = cls.get_oauth2_authorization_uri(
|
||||
oauth2_config,
|
||||
state,
|
||||
|
||||
@@ -46,6 +46,7 @@ class KeyValueResource(StrEnum):
|
||||
METASTORE_CACHE = "superset_metastore_cache"
|
||||
LOCK = "lock"
|
||||
PKCE_CODE_VERIFIER = "pkce_code_verifier"
|
||||
OAUTH2_PRE_CREATE_TOKEN = "oauth2_pre_create_token" # noqa: S105
|
||||
SQLLAB_PERMALINK = "sqllab_permalink"
|
||||
|
||||
|
||||
|
||||
@@ -402,7 +402,15 @@ class OAuth2State(TypedDict, total=False):
|
||||
Type for the state passed during OAuth2.
|
||||
"""
|
||||
|
||||
database_id: int
|
||||
# ``database_id`` is ``None`` during the "Create database" wizard, where the
|
||||
# OAuth2 dance runs before the database has been persisted. In that case the
|
||||
# access token is cached in the KV store keyed by ``tab_id`` until the user
|
||||
# saves the database.
|
||||
database_id: int | None
|
||||
user_id: int
|
||||
default_redirect_uri: str
|
||||
tab_id: str
|
||||
# Engine backend code (e.g. ``"semanticapi"``), present only for pre-create
|
||||
# dances so the callback can resolve the engine spec without a persisted
|
||||
# database row.
|
||||
engine: str
|
||||
|
||||
@@ -24,10 +24,11 @@ import secrets
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Iterator, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
import backoff
|
||||
import jwt
|
||||
from flask import current_app as app, url_for
|
||||
from flask import current_app as app, g, url_for
|
||||
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
||||
from werkzeug.routing import BuildError
|
||||
|
||||
@@ -87,7 +88,7 @@ def generate_code_challenge(code_verifier: str) -> str:
|
||||
)
|
||||
def get_oauth2_access_token(
|
||||
config: OAuth2ClientConfig,
|
||||
database_id: int,
|
||||
database_id: int | None,
|
||||
user_id: int,
|
||||
db_engine_spec: type[BaseEngineSpec],
|
||||
) -> str | None:
|
||||
@@ -100,10 +101,18 @@ def get_oauth2_access_token(
|
||||
simultaneous requests for refreshing a stale token; in that case only the first
|
||||
process to acquire the lock will perform the refresh, and othe process should find a
|
||||
a valid token when they retry.
|
||||
|
||||
When ``database_id`` is ``None`` (the "Create database" wizard, where the
|
||||
database hasn't been persisted yet), look up the token in the KV store under
|
||||
``g.oauth2_tab_id``. The token was cached there by
|
||||
:class:`OAuth2StoreTokenCommand` after the wizard's OAuth2 dance.
|
||||
""" # noqa: E501
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.models.core import DatabaseUserOAuth2Tokens
|
||||
|
||||
if not database_id:
|
||||
return _get_pre_create_access_token(user_id)
|
||||
|
||||
token = (
|
||||
db.session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(user_id=user_id, database_id=database_id)
|
||||
@@ -124,6 +133,39 @@ def get_oauth2_access_token(
|
||||
return None
|
||||
|
||||
|
||||
def _get_pre_create_access_token(user_id: int) -> str | None:
|
||||
"""
|
||||
Look up a pre-create OAuth2 token in the KV store.
|
||||
|
||||
The KV entry is keyed by the wizard's ``tab_id``; the frontend passes it on
|
||||
every "Test Connection" request, and the test-connection command exposes it
|
||||
via :data:`flask.g.oauth2_tab_id`.
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||
|
||||
tab_id = getattr(g, "oauth2_tab_id", None)
|
||||
if not tab_id:
|
||||
return None
|
||||
try:
|
||||
tab_uuid = UUID(tab_id)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
cached = KeyValueDAO.get_value(
|
||||
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||
key=tab_uuid,
|
||||
codec=JsonKeyValueCodec(),
|
||||
)
|
||||
if not cached:
|
||||
return None
|
||||
if cached.get("user_id") != user_id:
|
||||
# the tab_id is the user's own, but be defensive
|
||||
return None
|
||||
return cached.get("access_token")
|
||||
|
||||
|
||||
def refresh_oauth2_token(
|
||||
config: OAuth2ClientConfig,
|
||||
database_id: int,
|
||||
@@ -203,14 +245,20 @@ def refresh_oauth2_token(
|
||||
def encode_oauth2_state(state: OAuth2State) -> str:
|
||||
"""
|
||||
Encode the OAuth2 state.
|
||||
|
||||
``database_id`` is ``None`` for the pre-create-database OAuth2 dance,
|
||||
which means the token won't be persisted to ``database_user_oauth2_tokens``
|
||||
when the callback runs — see :class:`OAuth2StoreTokenCommand`.
|
||||
"""
|
||||
payload: dict[str, Any] = {
|
||||
"exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
|
||||
"database_id": state["database_id"],
|
||||
"database_id": state.get("database_id"),
|
||||
"user_id": state["user_id"],
|
||||
"default_redirect_uri": state["default_redirect_uri"],
|
||||
"tab_id": state["tab_id"],
|
||||
}
|
||||
if engine := state.get("engine"):
|
||||
payload["engine"] = engine
|
||||
|
||||
encoded_state = jwt.encode(
|
||||
payload=payload,
|
||||
@@ -225,10 +273,11 @@ def encode_oauth2_state(state: OAuth2State) -> str:
|
||||
|
||||
|
||||
class OAuth2StateSchema(Schema):
|
||||
database_id = fields.Int(required=True)
|
||||
database_id = fields.Int(required=True, allow_none=True)
|
||||
user_id = fields.Int(required=True)
|
||||
default_redirect_uri = fields.Str(required=True)
|
||||
tab_id = fields.Str(required=True)
|
||||
engine = fields.Str(required=False, load_default=None)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@post_load
|
||||
@@ -237,12 +286,15 @@ class OAuth2StateSchema(Schema):
|
||||
data: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> OAuth2State:
|
||||
return {
|
||||
state: OAuth2State = {
|
||||
"database_id": data["database_id"],
|
||||
"user_id": data["user_id"],
|
||||
"default_redirect_uri": data["default_redirect_uri"],
|
||||
"tab_id": data["tab_id"],
|
||||
}
|
||||
if data.get("engine"):
|
||||
state["engine"] = data["engine"]
|
||||
return state
|
||||
|
||||
class Meta: # pylint: disable=too-few-public-methods
|
||||
# ignore `exp`
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -27,6 +27,7 @@ from superset.daos.database import DatabaseUserOAuth2TokensDAO
|
||||
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
||||
from superset.exceptions import OAuth2Error
|
||||
from superset.models.core import Database
|
||||
from superset.superset_typing import OAuth2State
|
||||
from superset.utils.oauth2 import decode_oauth2_state, encode_oauth2_state
|
||||
|
||||
|
||||
@@ -135,6 +136,80 @@ def test_run_success(
|
||||
mock_create.assert_called_once()
|
||||
|
||||
|
||||
def test_run_pre_create_caches_token_in_kv(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
With ``state.database_id is None`` the command reads the engine + config
|
||||
from the pre-create KV entry, exchanges the code, and writes the token
|
||||
back to KV — *not* to ``database_user_oauth2_tokens``.
|
||||
"""
|
||||
state: OAuth2State = {
|
||||
"user_id": 1,
|
||||
"database_id": None,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||
"tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||
"engine": "semanticapi",
|
||||
}
|
||||
parameters: dict[str, Any] = {
|
||||
"code": "the-code",
|
||||
"state": encode_oauth2_state(state),
|
||||
}
|
||||
|
||||
kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO")
|
||||
kv_dao.get_value.side_effect = [
|
||||
{"engine": "semanticapi", "config": {"id": "x", "secret": "y"}}, # validate()
|
||||
None, # code_verifier lookup
|
||||
]
|
||||
|
||||
engine_spec = mocker.MagicMock()
|
||||
engine_spec.engine = "semanticapi"
|
||||
engine_spec.get_oauth2_token.return_value = {
|
||||
"access_token": "fresh-token",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "refresh",
|
||||
}
|
||||
mocker.patch(
|
||||
"superset.commands.database.oauth2.get_engine_spec",
|
||||
return_value=engine_spec,
|
||||
)
|
||||
|
||||
dao_get_db = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "get_database")
|
||||
dao_create = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "create")
|
||||
|
||||
result = OAuth2StoreTokenCommand(
|
||||
cast(OAuth2ProviderResponseSchema, parameters),
|
||||
).run()
|
||||
|
||||
assert result is None
|
||||
dao_get_db.assert_not_called()
|
||||
dao_create.assert_not_called()
|
||||
kv_dao.upsert_entry.assert_called_once()
|
||||
upsert_kwargs = kv_dao.upsert_entry.call_args.kwargs
|
||||
assert upsert_kwargs["value"]["access_token"] == "fresh-token" # noqa: S105
|
||||
assert upsert_kwargs["value"]["user_id"] == 1
|
||||
|
||||
|
||||
def test_run_pre_create_missing_kv_entry(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Pre-create flow with no cached entry should fail with a clear OAuth2Error.
|
||||
"""
|
||||
state: OAuth2State = {
|
||||
"user_id": 1,
|
||||
"database_id": None,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||
"tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||
}
|
||||
parameters: dict[str, Any] = {"code": "x", "state": encode_oauth2_state(state)}
|
||||
|
||||
kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO")
|
||||
kv_dao.get_value.return_value = None
|
||||
|
||||
with pytest.raises(OAuth2Error) as exc_info:
|
||||
OAuth2StoreTokenCommand(
|
||||
cast(OAuth2ProviderResponseSchema, parameters),
|
||||
).validate()
|
||||
assert "Pre-create OAuth2 context" in exc_info.value.error.extra["error"]
|
||||
|
||||
|
||||
def test_run_existing_token(
|
||||
mocker: MockerFixture,
|
||||
mock_database: MagicMock,
|
||||
|
||||
@@ -458,6 +458,83 @@ def test_encode_decode_oauth2_state(
|
||||
assert decoded["user_id"] == 2
|
||||
|
||||
|
||||
def test_encode_decode_oauth2_state_pre_create(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
The pre-create dance encodes ``database_id=None`` and an ``engine``
|
||||
field; both must survive the round-trip so the callback can resolve the
|
||||
engine spec without a persisted database.
|
||||
"""
|
||||
from superset.superset_typing import OAuth2State
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"SECRET_KEY": "test-secret-key", "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256"},
|
||||
)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": None,
|
||||
"user_id": 2,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||
"tab_id": "abc",
|
||||
"engine": "semanticapi",
|
||||
}
|
||||
with freeze_time("2024-01-01"):
|
||||
decoded = decode_oauth2_state(encode_oauth2_state(state))
|
||||
|
||||
assert decoded["database_id"] is None
|
||||
assert decoded["engine"] == "semanticapi"
|
||||
|
||||
|
||||
def test_get_oauth2_access_token_pre_create_hit(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
With ``database_id=None`` and a matching KV entry, the cached pre-create
|
||||
token is returned.
|
||||
"""
|
||||
mocker.patch(
|
||||
"superset.utils.oauth2.g",
|
||||
oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||
)
|
||||
kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO")
|
||||
kv_dao.get_value.return_value = {
|
||||
"access_token": "cached-token",
|
||||
"user_id": 7,
|
||||
}
|
||||
|
||||
assert (
|
||||
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||
== "cached-token"
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_access_token_pre_create_miss(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Without a tab id (or with no KV entry / wrong user) the lookup returns None.
|
||||
"""
|
||||
mocker.patch("superset.utils.oauth2.g", oauth2_tab_id=None)
|
||||
assert (
|
||||
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||
is None
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"superset.utils.oauth2.g",
|
||||
oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||
)
|
||||
kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO")
|
||||
kv_dao.get_value.return_value = None
|
||||
assert (
|
||||
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||
is None
|
||||
)
|
||||
|
||||
kv_dao.get_value.return_value = {"access_token": "x", "user_id": 99}
|
||||
# token belongs to another user
|
||||
assert (
|
||||
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_access_token_lock_not_acquired_no_error_log(
|
||||
mocker: MockerFixture,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
|
||||
Reference in New Issue
Block a user