mirror of
https://github.com/apache/superset.git
synced 2026-07-03 05:15:35 +00:00
Compare commits
1 Commits
chore/ci/s
...
oauth-duri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8af308afb |
@@ -622,6 +622,15 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||||||
const [editNewDb, setEditNewDb] = useState<boolean>(false);
|
const [editNewDb, setEditNewDb] = useState<boolean>(false);
|
||||||
const [isLoading, setLoading] = useState<boolean>(false);
|
const [isLoading, setLoading] = useState<boolean>(false);
|
||||||
const [testInProgress, setTestInProgress] = useState<boolean>(false);
|
const [testInProgress, setTestInProgress] = useState<boolean>(false);
|
||||||
|
// Stable id sent on every "Test Connection" call so that the backend can
|
||||||
|
// correlate the wizard's OAuth2 dance across requests. Generated once per
|
||||||
|
// modal session — keeping the same value lets the second test_connection
|
||||||
|
// call (after the user finishes the OAuth2 dance) find the cached token.
|
||||||
|
const oauth2TabIdRef = useRef<string>(
|
||||||
|
typeof crypto !== 'undefined' && crypto.randomUUID
|
||||||
|
? crypto.randomUUID()
|
||||||
|
: Math.random().toString(36).slice(2),
|
||||||
|
);
|
||||||
const [passwords, setPasswords] = useState<Record<string, string>>({});
|
const [passwords, setPasswords] = useState<Record<string, string>>({});
|
||||||
const [sshTunnelPasswords, setSSHTunnelPasswords] = useState<
|
const [sshTunnelPasswords, setSSHTunnelPasswords] = useState<
|
||||||
Record<string, string>
|
Record<string, string>
|
||||||
@@ -727,6 +736,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||||||
}, [setValidationErrors, setHasValidated, clearError]);
|
}, [setValidationErrors, setHasValidated, clearError]);
|
||||||
|
|
||||||
// Test Connection logic
|
// Test Connection logic
|
||||||
|
// keep the latest ``testConnection`` callable accessible to long-lived
|
||||||
|
// listeners (BroadcastChannel/storage) without making them depend on every
|
||||||
|
// render of the modal.
|
||||||
|
const testConnectionRef = useRef<() => void>(() => {});
|
||||||
|
|
||||||
const testConnection = () => {
|
const testConnection = () => {
|
||||||
handleClearValidationErrors();
|
handleClearValidationErrors();
|
||||||
if (!db?.sqlalchemy_uri) {
|
if (!db?.sqlalchemy_uri) {
|
||||||
@@ -748,6 +762,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||||||
server_port: Number(db.ssh_tunnel!.server_port),
|
server_port: Number(db.ssh_tunnel!.server_port),
|
||||||
}
|
}
|
||||||
: undefined,
|
: undefined,
|
||||||
|
oauth2_tab_id: oauth2TabIdRef.current,
|
||||||
};
|
};
|
||||||
setTestInProgress(true);
|
setTestInProgress(true);
|
||||||
testDatabaseConnection(
|
testDatabaseConnection(
|
||||||
@@ -764,6 +779,47 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
testConnectionRef.current = testConnection;
|
||||||
|
|
||||||
|
// Re-run "Test Connection" automatically when the OAuth2 dance completes
|
||||||
|
// in the popup tab. The dance posts a message on the ``oauth`` broadcast
|
||||||
|
// channel (and a localStorage event for cross-context delivery) carrying
|
||||||
|
// the wizard's own tab_id; we react only to our own.
|
||||||
|
useEffect(() => {
|
||||||
|
const tabId = oauth2TabIdRef.current;
|
||||||
|
|
||||||
|
const handleComplete = (incomingTabId?: string) => {
|
||||||
|
if (incomingTabId === tabId) {
|
||||||
|
testConnectionRef.current();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const channel =
|
||||||
|
typeof BroadcastChannel !== 'undefined'
|
||||||
|
? new BroadcastChannel('oauth')
|
||||||
|
: null;
|
||||||
|
if (channel) {
|
||||||
|
channel.onmessage = event => handleComplete(event.data?.tabId);
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleStorage = (event: StorageEvent) => {
|
||||||
|
if (event.key !== 'oauth2_auth_complete' || !event.newValue) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const payload = JSON.parse(event.newValue) as { tabId?: string };
|
||||||
|
handleComplete(payload.tabId);
|
||||||
|
} catch {
|
||||||
|
/* ignore */
|
||||||
|
}
|
||||||
|
};
|
||||||
|
window.addEventListener('storage', handleStorage);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener('storage', handleStorage);
|
||||||
|
channel?.close();
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
const getPlaceholder = (field: string) => {
|
const getPlaceholder = (field: string) => {
|
||||||
if (field === 'database') {
|
if (field === 'database') {
|
||||||
|
|||||||
@@ -25,66 +25,157 @@ from superset.commands.database.exceptions import DatabaseNotFoundError
|
|||||||
from superset.daos.database import DatabaseUserOAuth2TokensDAO
|
from superset.daos.database import DatabaseUserOAuth2TokensDAO
|
||||||
from superset.daos.key_value import KeyValueDAO
|
from superset.daos.key_value import KeyValueDAO
|
||||||
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
||||||
|
from superset.db_engine_specs import get_engine_spec
|
||||||
|
from superset.db_engine_specs.base import BaseEngineSpec
|
||||||
from superset.exceptions import OAuth2Error
|
from superset.exceptions import OAuth2Error
|
||||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||||
from superset.superset_typing import OAuth2State
|
from superset.superset_typing import (
|
||||||
|
OAuth2ClientConfig,
|
||||||
|
OAuth2State,
|
||||||
|
OAuth2TokenResponse,
|
||||||
|
)
|
||||||
from superset.utils.decorators import on_error, transaction
|
from superset.utils.decorators import on_error, transaction
|
||||||
from superset.utils.oauth2 import decode_oauth2_state
|
from superset.utils.oauth2 import decode_oauth2_state
|
||||||
|
|
||||||
|
# how long the pre-create token cache lives in the KV store
|
||||||
|
PRE_CREATE_TOKEN_TTL = timedelta(minutes=5)
|
||||||
|
|
||||||
|
|
||||||
class OAuth2StoreTokenCommand(BaseCommand):
|
class OAuth2StoreTokenCommand(BaseCommand):
|
||||||
"""
|
"""
|
||||||
Command to store OAuth2 tokens in the database.
|
Command to store OAuth2 tokens.
|
||||||
|
|
||||||
|
Normal flow: the OAuth2 callback resolves the database via ``state.database_id``
|
||||||
|
and persists access/refresh tokens to ``database_user_oauth2_tokens``.
|
||||||
|
|
||||||
|
Pre-create flow: when ``state.database_id`` is ``None`` (the database hasn't
|
||||||
|
been saved yet — typically the "Create database" wizard), the command reads
|
||||||
|
the OAuth2 client config and engine name from the KV store entry that
|
||||||
|
:meth:`BaseEngineSpec.start_oauth2_dance` stashed there, exchanges the code,
|
||||||
|
and caches the resulting access token in the same KV entry for
|
||||||
|
:func:`get_oauth2_access_token` to pick up on the retry.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, parameters: OAuth2ProviderResponseSchema):
|
def __init__(self, parameters: OAuth2ProviderResponseSchema):
|
||||||
self._parameters = parameters
|
self._parameters = parameters
|
||||||
self._state: OAuth2State | None = None
|
self._state: OAuth2State | None = None
|
||||||
self._database: Database | None = None
|
self._database: Database | None = None
|
||||||
|
self._oauth2_config: OAuth2ClientConfig | None = None
|
||||||
|
self._engine_spec: type[BaseEngineSpec] | None = None
|
||||||
|
self._tab_uuid: UUID | None = None
|
||||||
|
|
||||||
@transaction(on_error=partial(on_error, reraise=OAuth2Error))
|
@transaction(on_error=partial(on_error, reraise=OAuth2Error))
|
||||||
def run(self) -> DatabaseUserOAuth2Tokens:
|
def run(self) -> DatabaseUserOAuth2Tokens | None:
|
||||||
self.validate()
|
self.validate()
|
||||||
self._database = cast(Database, self._database)
|
|
||||||
self._state = cast(OAuth2State, self._state)
|
self._state = cast(OAuth2State, self._state)
|
||||||
|
self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config)
|
||||||
oauth2_config = self._database.get_oauth2_config()
|
self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec)
|
||||||
if oauth2_config is None:
|
|
||||||
raise OAuth2Error("No configuration found for OAuth2")
|
|
||||||
|
|
||||||
# Look up PKCE code_verifier from KV store (RFC 7636)
|
# Look up PKCE code_verifier from KV store (RFC 7636)
|
||||||
code_verifier = None
|
code_verifier = self._pop_code_verifier()
|
||||||
tab_id = self._state["tab_id"]
|
|
||||||
try:
|
|
||||||
tab_uuid = UUID(tab_id)
|
|
||||||
except ValueError:
|
|
||||||
tab_uuid = None
|
|
||||||
|
|
||||||
if tab_uuid:
|
token_response = self._engine_spec.get_oauth2_token(
|
||||||
kv_value = KeyValueDAO.get_value(
|
self._oauth2_config,
|
||||||
resource=KeyValueResource.PKCE_CODE_VERIFIER,
|
|
||||||
key=tab_uuid,
|
|
||||||
codec=JsonKeyValueCodec(),
|
|
||||||
)
|
|
||||||
if kv_value:
|
|
||||||
code_verifier = kv_value.get("code_verifier")
|
|
||||||
KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, tab_uuid)
|
|
||||||
|
|
||||||
token_response = self._database.db_engine_spec.get_oauth2_token(
|
|
||||||
oauth2_config,
|
|
||||||
self._parameters["code"],
|
self._parameters["code"],
|
||||||
code_verifier=code_verifier,
|
code_verifier=code_verifier,
|
||||||
)
|
)
|
||||||
|
|
||||||
# delete old tokens
|
if self._database is None:
|
||||||
|
# Pre-create flow: cache the access token in the KV entry the
|
||||||
|
# initial dance created. The retry of "Test Connection" will read
|
||||||
|
# it via ``get_oauth2_access_token``.
|
||||||
|
self._cache_pre_create_token(token_response)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._persist_token(token_response)
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
if error := self._parameters.get("error"):
|
||||||
|
raise OAuth2Error(error)
|
||||||
|
|
||||||
|
self._state = decode_oauth2_state(self._parameters["state"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._tab_uuid = UUID(self._state["tab_id"])
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
# Legacy paths may use non-UUID tab ids; we still want to support
|
||||||
|
# them when ``database_id`` is set. The pre-create path below
|
||||||
|
# requires a valid UUID.
|
||||||
|
self._tab_uuid = None
|
||||||
|
|
||||||
|
if database_id := self._state.get("database_id"):
|
||||||
|
self._database = DatabaseUserOAuth2TokensDAO.get_database(database_id)
|
||||||
|
if self._database is None:
|
||||||
|
raise DatabaseNotFoundError("Database not found")
|
||||||
|
self._oauth2_config = self._database.get_oauth2_config()
|
||||||
|
self._engine_spec = self._database.db_engine_spec
|
||||||
|
else:
|
||||||
|
if self._tab_uuid is None:
|
||||||
|
raise OAuth2Error(
|
||||||
|
"Pre-create OAuth2 callback requires a UUID tab_id",
|
||||||
|
)
|
||||||
|
cached = KeyValueDAO.get_value(
|
||||||
|
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||||
|
key=self._tab_uuid,
|
||||||
|
codec=JsonKeyValueCodec(),
|
||||||
|
)
|
||||||
|
if not cached or not cached.get("config"):
|
||||||
|
raise OAuth2Error("Pre-create OAuth2 context not found or expired")
|
||||||
|
self._oauth2_config = cast(OAuth2ClientConfig, cached["config"])
|
||||||
|
engine = self._state.get("engine") or cached.get("engine")
|
||||||
|
if not engine:
|
||||||
|
raise OAuth2Error("Pre-create OAuth2 context missing engine name")
|
||||||
|
self._engine_spec = get_engine_spec(engine)
|
||||||
|
|
||||||
|
if self._oauth2_config is None:
|
||||||
|
raise OAuth2Error("No configuration found for OAuth2")
|
||||||
|
|
||||||
|
def _pop_code_verifier(self) -> str | None:
|
||||||
|
if self._tab_uuid is None:
|
||||||
|
return None
|
||||||
|
kv_value = KeyValueDAO.get_value(
|
||||||
|
resource=KeyValueResource.PKCE_CODE_VERIFIER,
|
||||||
|
key=self._tab_uuid,
|
||||||
|
codec=JsonKeyValueCodec(),
|
||||||
|
)
|
||||||
|
if not kv_value:
|
||||||
|
return None
|
||||||
|
KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, self._tab_uuid)
|
||||||
|
return kv_value.get("code_verifier")
|
||||||
|
|
||||||
|
def _cache_pre_create_token(self, token_response: OAuth2TokenResponse) -> None:
|
||||||
|
self._state = cast(OAuth2State, self._state)
|
||||||
|
self._tab_uuid = cast(UUID, self._tab_uuid)
|
||||||
|
self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config)
|
||||||
|
self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec)
|
||||||
|
|
||||||
|
expires_on = datetime.now() + PRE_CREATE_TOKEN_TTL
|
||||||
|
KeyValueDAO.upsert_entry(
|
||||||
|
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||||
|
key=self._tab_uuid,
|
||||||
|
value={
|
||||||
|
"engine": self._engine_spec.engine,
|
||||||
|
"config": self._oauth2_config,
|
||||||
|
"user_id": self._state["user_id"],
|
||||||
|
"access_token": token_response["access_token"],
|
||||||
|
},
|
||||||
|
codec=JsonKeyValueCodec(),
|
||||||
|
expires_on=expires_on,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _persist_token(
|
||||||
|
self,
|
||||||
|
token_response: OAuth2TokenResponse,
|
||||||
|
) -> DatabaseUserOAuth2Tokens:
|
||||||
|
self._state = cast(OAuth2State, self._state)
|
||||||
|
|
||||||
if existing := DatabaseUserOAuth2TokensDAO.find_one_or_none(
|
if existing := DatabaseUserOAuth2TokensDAO.find_one_or_none(
|
||||||
user_id=self._state["user_id"],
|
user_id=self._state["user_id"],
|
||||||
database_id=self._state["database_id"],
|
database_id=self._state["database_id"],
|
||||||
):
|
):
|
||||||
DatabaseUserOAuth2TokensDAO.delete([existing])
|
DatabaseUserOAuth2TokensDAO.delete([existing])
|
||||||
|
|
||||||
# store tokens
|
|
||||||
expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
|
expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
|
||||||
return DatabaseUserOAuth2TokensDAO.create(
|
return DatabaseUserOAuth2TokensDAO.create(
|
||||||
attributes={
|
attributes={
|
||||||
@@ -95,16 +186,3 @@ class OAuth2StoreTokenCommand(BaseCommand):
|
|||||||
"refresh_token": token_response.get("refresh_token"),
|
"refresh_token": token_response.get("refresh_token"),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self) -> None:
|
|
||||||
if error := self._parameters.get("error"):
|
|
||||||
raise OAuth2Error(error)
|
|
||||||
|
|
||||||
self._state = decode_oauth2_state(self._parameters["state"])
|
|
||||||
|
|
||||||
if database := DatabaseUserOAuth2TokensDAO.get_database(
|
|
||||||
self._state["database_id"]
|
|
||||||
):
|
|
||||||
self._database = database
|
|
||||||
else:
|
|
||||||
raise DatabaseNotFoundError("Database not found")
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from flask import g
|
||||||
from flask_babel import gettext as _
|
from flask_babel import gettext as _
|
||||||
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
|
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
|
||||||
|
|
||||||
@@ -94,6 +95,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||||||
self.validate()
|
self.validate()
|
||||||
ex_str = ""
|
ex_str = ""
|
||||||
|
|
||||||
|
# Surface the wizard's tab_id (sent by the frontend) so that
|
||||||
|
# ``get_oauth2_access_token`` can find the pre-create OAuth2 token
|
||||||
|
# cached in the KV store, and so that ``start_oauth2_dance`` reuses
|
||||||
|
# this id instead of generating a new one.
|
||||||
|
if oauth2_tab_id := self._properties.get("oauth2_tab_id"):
|
||||||
|
g.oauth2_tab_id = oauth2_tab_id
|
||||||
|
|
||||||
url = make_url_safe(self._uri)
|
url = make_url_safe(self._uri)
|
||||||
engine_name = url.get_backend_name()
|
engine_name = url.get_backend_name()
|
||||||
|
|
||||||
|
|||||||
@@ -623,6 +623,16 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema):
|
|||||||
)
|
)
|
||||||
|
|
||||||
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||||
|
oauth2_tab_id = fields.String(
|
||||||
|
metadata={
|
||||||
|
"description": (
|
||||||
|
"UUID identifying the wizard tab for the pre-create OAuth2 flow."
|
||||||
|
" Optional; when supplied, the engine will look up the pre-create"
|
||||||
|
" OAuth2 token cached under this key."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TableMetadataOptionsResponse(TypedDict):
|
class TableMetadataOptionsResponse(TypedDict):
|
||||||
|
|||||||
@@ -664,7 +664,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
# Prevent circular import.
|
# Prevent circular import.
|
||||||
from superset.daos.key_value import KeyValueDAO
|
from superset.daos.key_value import KeyValueDAO
|
||||||
|
|
||||||
tab_id = str(uuid4())
|
# Reuse the wizard-supplied tab_id when present so the retry of
|
||||||
|
# "Test Connection" can find the same KV-cached token. Otherwise we
|
||||||
|
# generate a fresh one as before.
|
||||||
|
tab_id = getattr(g, "oauth2_tab_id", None) or str(uuid4())
|
||||||
default_redirect_uri = get_oauth2_redirect_uri()
|
default_redirect_uri = get_oauth2_redirect_uri()
|
||||||
|
|
||||||
# Generate PKCE code verifier (RFC 7636)
|
# Generate PKCE code verifier (RFC 7636)
|
||||||
@@ -690,6 +693,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
# belongs to.
|
# belongs to.
|
||||||
state: OAuth2State = {
|
state: OAuth2State = {
|
||||||
# Database ID and user ID are the primary key associated with the token.
|
# Database ID and user ID are the primary key associated with the token.
|
||||||
|
# ``database_id`` is ``None`` during the "Create database" wizard — the
|
||||||
|
# callback caches the access token in the KV store instead of inserting
|
||||||
|
# a row in ``database_user_oauth2_tokens``.
|
||||||
"database_id": database.id,
|
"database_id": database.id,
|
||||||
"user_id": g.user.id,
|
"user_id": g.user.id,
|
||||||
# In multi-instance deployments there might be a single proxy handling
|
# In multi-instance deployments there might be a single proxy handling
|
||||||
@@ -710,6 +716,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
if oauth2_config is None:
|
if oauth2_config is None:
|
||||||
raise OAuth2Error("No configuration found for OAuth2")
|
raise OAuth2Error("No configuration found for OAuth2")
|
||||||
|
|
||||||
|
# Pre-create flow: the database has no row yet, so we can't look it up by id
|
||||||
|
# in the callback. Stash the engine spec + oauth2 config in the KV store
|
||||||
|
# alongside the existing PKCE entry, keyed by the same ``tab_id``. The
|
||||||
|
# callback reads this to exchange the code without a persisted ``Database``.
|
||||||
|
if database.id is None:
|
||||||
|
state["engine"] = cls.engine
|
||||||
|
KeyValueDAO.delete_expired_entries(KeyValueResource.OAUTH2_PRE_CREATE_TOKEN)
|
||||||
|
KeyValueDAO.create_entry(
|
||||||
|
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||||
|
value={
|
||||||
|
"engine": cls.engine,
|
||||||
|
"config": oauth2_config,
|
||||||
|
},
|
||||||
|
codec=JsonKeyValueCodec(),
|
||||||
|
key=UUID(tab_id),
|
||||||
|
expires_on=datetime.now() + timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
oauth_url = cls.get_oauth2_authorization_uri(
|
oauth_url = cls.get_oauth2_authorization_uri(
|
||||||
oauth2_config,
|
oauth2_config,
|
||||||
state,
|
state,
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class KeyValueResource(StrEnum):
|
|||||||
METASTORE_CACHE = "superset_metastore_cache"
|
METASTORE_CACHE = "superset_metastore_cache"
|
||||||
LOCK = "lock"
|
LOCK = "lock"
|
||||||
PKCE_CODE_VERIFIER = "pkce_code_verifier"
|
PKCE_CODE_VERIFIER = "pkce_code_verifier"
|
||||||
|
OAUTH2_PRE_CREATE_TOKEN = "oauth2_pre_create_token" # noqa: S105
|
||||||
SQLLAB_PERMALINK = "sqllab_permalink"
|
SQLLAB_PERMALINK = "sqllab_permalink"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -402,7 +402,15 @@ class OAuth2State(TypedDict, total=False):
|
|||||||
Type for the state passed during OAuth2.
|
Type for the state passed during OAuth2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
database_id: int
|
# ``database_id`` is ``None`` during the "Create database" wizard, where the
|
||||||
|
# OAuth2 dance runs before the database has been persisted. In that case the
|
||||||
|
# access token is cached in the KV store keyed by ``tab_id`` until the user
|
||||||
|
# saves the database.
|
||||||
|
database_id: int | None
|
||||||
user_id: int
|
user_id: int
|
||||||
default_redirect_uri: str
|
default_redirect_uri: str
|
||||||
tab_id: str
|
tab_id: str
|
||||||
|
# Engine backend code (e.g. ``"semanticapi"``), present only for pre-create
|
||||||
|
# dances so the callback can resolve the engine spec without a persisted
|
||||||
|
# database row.
|
||||||
|
engine: str
|
||||||
|
|||||||
@@ -24,10 +24,11 @@ import secrets
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Iterator, TYPE_CHECKING
|
from typing import Any, Iterator, TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import jwt
|
import jwt
|
||||||
from flask import current_app as app, url_for
|
from flask import current_app as app, g, url_for
|
||||||
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
||||||
from werkzeug.routing import BuildError
|
from werkzeug.routing import BuildError
|
||||||
|
|
||||||
@@ -87,7 +88,7 @@ def generate_code_challenge(code_verifier: str) -> str:
|
|||||||
)
|
)
|
||||||
def get_oauth2_access_token(
|
def get_oauth2_access_token(
|
||||||
config: OAuth2ClientConfig,
|
config: OAuth2ClientConfig,
|
||||||
database_id: int,
|
database_id: int | None,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
db_engine_spec: type[BaseEngineSpec],
|
db_engine_spec: type[BaseEngineSpec],
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
@@ -100,10 +101,18 @@ def get_oauth2_access_token(
|
|||||||
simultaneous requests for refreshing a stale token; in that case only the first
|
simultaneous requests for refreshing a stale token; in that case only the first
|
||||||
process to acquire the lock will perform the refresh, and othe process should find a
|
process to acquire the lock will perform the refresh, and othe process should find a
|
||||||
a valid token when they retry.
|
a valid token when they retry.
|
||||||
|
|
||||||
|
When ``database_id`` is ``None`` (the "Create database" wizard, where the
|
||||||
|
database hasn't been persisted yet), look up the token in the KV store under
|
||||||
|
``g.oauth2_tab_id``. The token was cached there by
|
||||||
|
:class:`OAuth2StoreTokenCommand` after the wizard's OAuth2 dance.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from superset.models.core import DatabaseUserOAuth2Tokens
|
from superset.models.core import DatabaseUserOAuth2Tokens
|
||||||
|
|
||||||
|
if not database_id:
|
||||||
|
return _get_pre_create_access_token(user_id)
|
||||||
|
|
||||||
token = (
|
token = (
|
||||||
db.session.query(DatabaseUserOAuth2Tokens)
|
db.session.query(DatabaseUserOAuth2Tokens)
|
||||||
.filter_by(user_id=user_id, database_id=database_id)
|
.filter_by(user_id=user_id, database_id=database_id)
|
||||||
@@ -124,6 +133,39 @@ def get_oauth2_access_token(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pre_create_access_token(user_id: int) -> str | None:
|
||||||
|
"""
|
||||||
|
Look up a pre-create OAuth2 token in the KV store.
|
||||||
|
|
||||||
|
The KV entry is keyed by the wizard's ``tab_id``; the frontend passes it on
|
||||||
|
every "Test Connection" request, and the test-connection command exposes it
|
||||||
|
via :data:`flask.g.oauth2_tab_id`.
|
||||||
|
"""
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from superset.daos.key_value import KeyValueDAO
|
||||||
|
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||||
|
|
||||||
|
tab_id = getattr(g, "oauth2_tab_id", None)
|
||||||
|
if not tab_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
tab_uuid = UUID(tab_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
cached = KeyValueDAO.get_value(
|
||||||
|
resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
|
||||||
|
key=tab_uuid,
|
||||||
|
codec=JsonKeyValueCodec(),
|
||||||
|
)
|
||||||
|
if not cached:
|
||||||
|
return None
|
||||||
|
if cached.get("user_id") != user_id:
|
||||||
|
# the tab_id is the user's own, but be defensive
|
||||||
|
return None
|
||||||
|
return cached.get("access_token")
|
||||||
|
|
||||||
|
|
||||||
def refresh_oauth2_token(
|
def refresh_oauth2_token(
|
||||||
config: OAuth2ClientConfig,
|
config: OAuth2ClientConfig,
|
||||||
database_id: int,
|
database_id: int,
|
||||||
@@ -203,14 +245,20 @@ def refresh_oauth2_token(
|
|||||||
def encode_oauth2_state(state: OAuth2State) -> str:
|
def encode_oauth2_state(state: OAuth2State) -> str:
|
||||||
"""
|
"""
|
||||||
Encode the OAuth2 state.
|
Encode the OAuth2 state.
|
||||||
|
|
||||||
|
``database_id`` is ``None`` for the pre-create-database OAuth2 dance,
|
||||||
|
which means the token won't be persisted to ``database_user_oauth2_tokens``
|
||||||
|
when the callback runs — see :class:`OAuth2StoreTokenCommand`.
|
||||||
"""
|
"""
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
|
"exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
|
||||||
"database_id": state["database_id"],
|
"database_id": state.get("database_id"),
|
||||||
"user_id": state["user_id"],
|
"user_id": state["user_id"],
|
||||||
"default_redirect_uri": state["default_redirect_uri"],
|
"default_redirect_uri": state["default_redirect_uri"],
|
||||||
"tab_id": state["tab_id"],
|
"tab_id": state["tab_id"],
|
||||||
}
|
}
|
||||||
|
if engine := state.get("engine"):
|
||||||
|
payload["engine"] = engine
|
||||||
|
|
||||||
encoded_state = jwt.encode(
|
encoded_state = jwt.encode(
|
||||||
payload=payload,
|
payload=payload,
|
||||||
@@ -225,10 +273,11 @@ def encode_oauth2_state(state: OAuth2State) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class OAuth2StateSchema(Schema):
|
class OAuth2StateSchema(Schema):
|
||||||
database_id = fields.Int(required=True)
|
database_id = fields.Int(required=True, allow_none=True)
|
||||||
user_id = fields.Int(required=True)
|
user_id = fields.Int(required=True)
|
||||||
default_redirect_uri = fields.Str(required=True)
|
default_redirect_uri = fields.Str(required=True)
|
||||||
tab_id = fields.Str(required=True)
|
tab_id = fields.Str(required=True)
|
||||||
|
engine = fields.Str(required=False, load_default=None)
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@post_load
|
@post_load
|
||||||
@@ -237,12 +286,15 @@ class OAuth2StateSchema(Schema):
|
|||||||
data: dict[str, Any],
|
data: dict[str, Any],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> OAuth2State:
|
) -> OAuth2State:
|
||||||
return {
|
state: OAuth2State = {
|
||||||
"database_id": data["database_id"],
|
"database_id": data["database_id"],
|
||||||
"user_id": data["user_id"],
|
"user_id": data["user_id"],
|
||||||
"default_redirect_uri": data["default_redirect_uri"],
|
"default_redirect_uri": data["default_redirect_uri"],
|
||||||
"tab_id": data["tab_id"],
|
"tab_id": data["tab_id"],
|
||||||
}
|
}
|
||||||
|
if data.get("engine"):
|
||||||
|
state["engine"] = data["engine"]
|
||||||
|
return state
|
||||||
|
|
||||||
class Meta: # pylint: disable=too-few-public-methods
|
class Meta: # pylint: disable=too-few-public-methods
|
||||||
# ignore `exp`
|
# ignore `exp`
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -27,6 +27,7 @@ from superset.daos.database import DatabaseUserOAuth2TokensDAO
|
|||||||
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
from superset.databases.schemas import OAuth2ProviderResponseSchema
|
||||||
from superset.exceptions import OAuth2Error
|
from superset.exceptions import OAuth2Error
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
from superset.superset_typing import OAuth2State
|
||||||
from superset.utils.oauth2 import decode_oauth2_state, encode_oauth2_state
|
from superset.utils.oauth2 import decode_oauth2_state, encode_oauth2_state
|
||||||
|
|
||||||
|
|
||||||
@@ -135,6 +136,80 @@ def test_run_success(
|
|||||||
mock_create.assert_called_once()
|
mock_create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_pre_create_caches_token_in_kv(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
With ``state.database_id is None`` the command reads the engine + config
|
||||||
|
from the pre-create KV entry, exchanges the code, and writes the token
|
||||||
|
back to KV — *not* to ``database_user_oauth2_tokens``.
|
||||||
|
"""
|
||||||
|
state: OAuth2State = {
|
||||||
|
"user_id": 1,
|
||||||
|
"database_id": None,
|
||||||
|
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||||
|
"tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||||
|
"engine": "semanticapi",
|
||||||
|
}
|
||||||
|
parameters: dict[str, Any] = {
|
||||||
|
"code": "the-code",
|
||||||
|
"state": encode_oauth2_state(state),
|
||||||
|
}
|
||||||
|
|
||||||
|
kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO")
|
||||||
|
kv_dao.get_value.side_effect = [
|
||||||
|
{"engine": "semanticapi", "config": {"id": "x", "secret": "y"}}, # validate()
|
||||||
|
None, # code_verifier lookup
|
||||||
|
]
|
||||||
|
|
||||||
|
engine_spec = mocker.MagicMock()
|
||||||
|
engine_spec.engine = "semanticapi"
|
||||||
|
engine_spec.get_oauth2_token.return_value = {
|
||||||
|
"access_token": "fresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "refresh",
|
||||||
|
}
|
||||||
|
mocker.patch(
|
||||||
|
"superset.commands.database.oauth2.get_engine_spec",
|
||||||
|
return_value=engine_spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
dao_get_db = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "get_database")
|
||||||
|
dao_create = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "create")
|
||||||
|
|
||||||
|
result = OAuth2StoreTokenCommand(
|
||||||
|
cast(OAuth2ProviderResponseSchema, parameters),
|
||||||
|
).run()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
dao_get_db.assert_not_called()
|
||||||
|
dao_create.assert_not_called()
|
||||||
|
kv_dao.upsert_entry.assert_called_once()
|
||||||
|
upsert_kwargs = kv_dao.upsert_entry.call_args.kwargs
|
||||||
|
assert upsert_kwargs["value"]["access_token"] == "fresh-token" # noqa: S105
|
||||||
|
assert upsert_kwargs["value"]["user_id"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_pre_create_missing_kv_entry(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Pre-create flow with no cached entry should fail with a clear OAuth2Error.
|
||||||
|
"""
|
||||||
|
state: OAuth2State = {
|
||||||
|
"user_id": 1,
|
||||||
|
"database_id": None,
|
||||||
|
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||||
|
"tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||||
|
}
|
||||||
|
parameters: dict[str, Any] = {"code": "x", "state": encode_oauth2_state(state)}
|
||||||
|
|
||||||
|
kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO")
|
||||||
|
kv_dao.get_value.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(OAuth2Error) as exc_info:
|
||||||
|
OAuth2StoreTokenCommand(
|
||||||
|
cast(OAuth2ProviderResponseSchema, parameters),
|
||||||
|
).validate()
|
||||||
|
assert "Pre-create OAuth2 context" in exc_info.value.error.extra["error"]
|
||||||
|
|
||||||
|
|
||||||
def test_run_existing_token(
|
def test_run_existing_token(
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
mock_database: MagicMock,
|
mock_database: MagicMock,
|
||||||
|
|||||||
@@ -458,6 +458,83 @@ def test_encode_decode_oauth2_state(
|
|||||||
assert decoded["user_id"] == 2
|
assert decoded["user_id"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_decode_oauth2_state_pre_create(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
The pre-create dance encodes ``database_id=None`` and an ``engine``
|
||||||
|
field; both must survive the round-trip so the callback can resolve the
|
||||||
|
engine spec without a persisted database.
|
||||||
|
"""
|
||||||
|
from superset.superset_typing import OAuth2State
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"flask.current_app.config",
|
||||||
|
{"SECRET_KEY": "test-secret-key", "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256"},
|
||||||
|
)
|
||||||
|
|
||||||
|
state: OAuth2State = {
|
||||||
|
"database_id": None,
|
||||||
|
"user_id": 2,
|
||||||
|
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||||
|
"tab_id": "abc",
|
||||||
|
"engine": "semanticapi",
|
||||||
|
}
|
||||||
|
with freeze_time("2024-01-01"):
|
||||||
|
decoded = decode_oauth2_state(encode_oauth2_state(state))
|
||||||
|
|
||||||
|
assert decoded["database_id"] is None
|
||||||
|
assert decoded["engine"] == "semanticapi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_access_token_pre_create_hit(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
With ``database_id=None`` and a matching KV entry, the cached pre-create
|
||||||
|
token is returned.
|
||||||
|
"""
|
||||||
|
mocker.patch(
|
||||||
|
"superset.utils.oauth2.g",
|
||||||
|
oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||||
|
)
|
||||||
|
kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO")
|
||||||
|
kv_dao.get_value.return_value = {
|
||||||
|
"access_token": "cached-token",
|
||||||
|
"user_id": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||||
|
== "cached-token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_access_token_pre_create_miss(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Without a tab id (or with no KV entry / wrong user) the lookup returns None.
|
||||||
|
"""
|
||||||
|
mocker.patch("superset.utils.oauth2.g", oauth2_tab_id=None)
|
||||||
|
assert (
|
||||||
|
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"superset.utils.oauth2.g",
|
||||||
|
oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
|
||||||
|
)
|
||||||
|
kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO")
|
||||||
|
kv_dao.get_value.return_value = None
|
||||||
|
assert (
|
||||||
|
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_dao.get_value.return_value = {"access_token": "x", "user_id": 99}
|
||||||
|
# token belongs to another user
|
||||||
|
assert (
|
||||||
|
get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_oauth2_access_token_lock_not_acquired_no_error_log(
|
def test_get_oauth2_access_token_lock_not_acquired_no_error_log(
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
caplog: pytest.LogCaptureFixture,
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
|||||||
Reference in New Issue
Block a user