Compare commits

...

17 Commits

Author SHA1 Message Date
Evan
563edfdd9f fix(databricks): tighten 401 oauth2 signal to avoid false positives
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
98ba9da18c test(databricks): add docstring to mock database helper
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
910e79dd8c fix(databricks): correct OAuth2 trigger and derive endpoints from workspace host
Addresses review feedback on the Databricks OAuth2 flow:

- `oauth2_exception` was set to `OAuth2RedirectError` (Superset's own redirect
  signal, also the base default), so `needs_oauth2()` never matched a real
  Databricks token failure and the dance never auto-started. The driver has no
  dedicated auth exception, so detect auth failures from the error message
  instead (mirrors `GSheetsEngineSpec.needs_oauth2`).

- The per-cloud endpoint templates pointed Azure at Entra ID directly
  (`login.microsoftonline.com`) and required `account_id`/`tenant_id`
  substitution. Databricks fronts the U2M flow on every workspace at
  `https://<host>/oidc/v1/{authorize,token}` across AWS/Azure/GCP, so the
  authorization endpoint now derives from the workspace host with no account
  identifier. The token endpoint still requires explicit config (no DB context
  at exchange time); the error and docs now point at the workspace-host URL.

Shared OAuth logic is consolidated onto `DatabricksDynamicBaseEngineSpec`,
removing the duplicated overrides in both engine specs.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
6251280aa5 fix(databricks): guard non-string cloud_provider before .lower()
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan Rusackas
9d0a2209c6 test(databricks): cover invalid cloud_provider fallback to hostname
The only uncovered branch in `_detect_cloud_provider`: an unrecognized explicit
`cloud_provider` should be ignored and detection should fall back to hostname
sniffing rather than returning the bad value.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
cb7d5c8847 docs(databricks): clarify token endpoint is not auto-detected
The authorization endpoint auto-resolves from the hostname, but the token
exchange has no database context, so token_request_uri must be supplied for
the auto-detected flow. Docs implied both endpoints auto-detect.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
fe69d222bd test(databricks): docstring the shared OAuth2 state helper
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
8e0871bb2e test(databricks): exercise provider detection without pre-set OAuth2 URI
The multi-cloud OAuth2 URI tests passed a config with a fully-resolved
authorization_request_uri, which the engine spec now preserves. Drop the
URI for the Azure/GCP detection cases (and give those mock databases an
account_id/tenant_id) so the per-provider endpoint is actually resolved.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan Rusackas
ea1ba587f2 fix(databricks): resolve account_id in OAuth2 endpoints, preserve configured URIs
The per-cloud OAuth2 endpoint templates carry a `{}` placeholder for the
Databricks account id (or Azure tenant id) that was never substituted, so
auto-detected authorize/token URLs were emitted as `.../accounts/{}/v1/...`.
The authorization-URI methods also unconditionally overwrote a fully-resolved
`authorization_request_uri` supplied via DATABASE_OAUTH2_CLIENTS.

- Add `_resolve_oauth2_endpoint`: substitutes `account_id`/`tenant_id` from the
  database extra into the template, or raises OAuth2Error when absent instead of
  issuing a request to an unresolved endpoint.
- Preserve a configured `authorization_request_uri`; only auto-detect/resolve
  when none is set.
- `get_oauth2_token` has no database context to auto-detect, so fail fast on a
  missing `token_request_uri` rather than POST to `.../{}/v1/token`.
- Cover auto-detect/resolve, preserve-configured, and fail-fast paths for both
  the native and Python-connector specs; document `account_id`/`tenant_id`.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
dc99373579 test(databricks): add return/param type annotations to multi-cloud OAuth fixtures
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
8737f010f3 fix(databricks): preserve resolved OAuth2 token request URI
get_oauth2_token clobbered the config's already-resolved token_request_uri
with the AWS template that still contained an unsubstituted account-id
placeholder, so the token exchange POSTed to .../accounts/{}/v1/token. Only
fall back to the AWS endpoint when no token_request_uri is configured.

Co-authored-by: fabian_zse <fabian@zalando.de>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
fabian_zse
278cfbb694 cloud providers test 2026-06-27 21:33:16 -07:00
fabian_zse
1de21ec5c6 support all cloud providers 2026-06-27 21:33:15 -07:00
Fabian Halkivaha
2ed41ae8a6 fix docs slightly 2026-06-27 21:33:15 -07:00
fabian_zse
a0fdb2aa31 add databricks oauth support 2026-06-27 21:33:15 -07:00
ʈᵃᵢ
25c9f3510a test(mcp): set embedded on update_dashboard test mock (#41495) 2026-06-28 11:19:01 +07:00
Đỗ Trọng Hải
b8fd2e9725 feat(websocket,embedded-sdk): replace Jest with modern Vitest (#38308)
Signed-off-by: hainenber <dotronghai96@gmail.com>
Co-authored-by: codeant-ai-for-open-source[bot] <244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
2026-06-28 11:12:37 +07:00
28 changed files with 3942 additions and 18963 deletions

View File

@@ -28,6 +28,10 @@ jobs:
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
persist-credentials: false
- name: Setup Node.js
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
node-version-file: './superset-websocket/.nvmrc'
- name: Install dependencies
working-directory: ./superset-websocket
run: npm ci

View File

@@ -519,6 +519,80 @@ For a connection to a SQL endpoint you need to use the HTTP path from the endpoi
{"connect_args": {"http_path": "/sql/1.0/endpoints/****", "driver_path": "/path/to/odbc/driver"}}
```
##### OAuth2 Authentication
Superset supports OAuth2 authentication for Databricks, allowing users to authenticate with their personal Databricks accounts instead of using shared access tokens. This provides better security and audit capabilities.
###### Prerequisites
1. Create an OAuth2 application in your Databricks account:
- Go to your Databricks account console
- Navigate to **Settings** → **Developer** → **OAuth apps**
- Create a new OAuth app with the redirect URI: `http://your-superset-host:port/api/v1/database/oauth2/`
2. Configure OAuth2 in your `superset_config.py`:
```python
from datetime import timedelta
# OAuth2 configuration for Databricks
# The authorization endpoint is derived from your Databricks workspace host; the
# token endpoint must be set explicitly (see notes below).
DATABASE_OAUTH2_CLIENTS = {
"Databricks (legacy)": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
},
"Databricks": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
},
}
# OAuth2 redirect URI (adjust hostname/port for your setup)
DATABASE_OAUTH2_REDIRECT_URI = "http://your-superset-host:port/api/v1/database/oauth2/"
# Optional: OAuth2 timeout
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
```
Replace the following placeholders:
- `your-databricks-client-id`: Your Databricks OAuth2 application client ID
- `your-databricks-client-secret`: Your Databricks OAuth2 application client secret
- `your-superset-host:port`: Your Superset instance hostname and port
**Multi-Cloud Provider Support**
Databricks fronts the user-to-machine (U2M) OAuth2 flow on every workspace at
`https://<workspace-host>/oidc/v1/authorize` and
`https://<workspace-host>/oidc/v1/token`, regardless of whether the workspace
runs on AWS, Azure, or GCP. Superset derives the **authorization** endpoint
directly from your connection's host, so no cloud provider or account/tenant
identifier needs to be configured.
The **token** endpoint cannot be auto-derived (token exchange has no database
context to read the host), so you must supply `token_request_uri` in
`DATABASE_OAUTH2_CLIENTS`, set to `https://<workspace-host>/oidc/v1/token` for
your workspace.
If you supply a fully-resolved `authorization_request_uri` (and/or
`token_request_uri`), those values take precedence over the host-derived
defaults.
###### Usage
Once configured, users can:
1. Connect to Databricks databases normally using access tokens
2. When querying data, Superset will automatically redirect users to authenticate with Databricks if needed
3. User-specific OAuth2 tokens will be used for database connections, providing better security and audit trails
This feature works with both "Databricks (legacy)" and "Databricks" engine types and automatically supports all major cloud providers (AWS, Azure, GCP).
#### Denodo
The recommended connector library for Denodo is

View File

@@ -32,6 +32,7 @@ and therefore are not easily unit-testable. We have instead opted to test the sd
This way, the tests can assert that the sdk actually mounts the iframe and communicates with it correctly.
At time of writing, these tests are not written yet, because we haven't yet put together the demo app that they will leverage.
### Things to e2e test once we have a demo app:
**happy path:**

View File

@@ -41,12 +41,12 @@ npm install --save @superset-ui/embedded-sdk
```
```js
import { embedDashboard } from '@superset-ui/embedded-sdk';
import { embedDashboard } from "@superset-ui/embedded-sdk";
embedDashboard({
id: 'abc123', // given by the Superset embedding UI
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'), // any html element that can contain an iframe
id: "abc123", // given by the Superset embedding UI
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"), // any html element that can contain an iframe
fetchGuestToken: () => fetchGuestTokenFromBackend(),
dashboardUiConfig: {
// dashboard UI config: hideTitle, hideTab, hideChartControls, filters.visible, filters.expanded (optional), urlParams (optional)
@@ -55,21 +55,21 @@ embedDashboard({
expanded: true,
},
urlParams: {
foo: 'value1',
bar: 'value2',
foo: "value1",
bar: "value2",
// themeMode: 'dark', // set the initial theme: 'dark' | 'system' | 'default' (default: 'default')
// ...
},
},
// optional additional iframe sandbox attributes
iframeSandboxExtras: [
'allow-top-navigation',
'allow-popups-to-escape-sandbox',
"allow-top-navigation",
"allow-popups-to-escape-sandbox",
],
// optional Permissions Policy features
iframeAllowExtras: ['clipboard-write', 'fullscreen'],
iframeAllowExtras: ["clipboard-write", "fullscreen"],
// optional config to enforce a particular referrerPolicy
referrerPolicy: 'same-origin',
referrerPolicy: "same-origin",
// optional callback to customize permalink URLs
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
});
@@ -163,13 +163,13 @@ Use the `themeMode` URL parameter to control the embedded dashboard's initial co
```js
embedDashboard({
id: 'abc123',
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'),
id: "abc123",
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"),
fetchGuestToken: () => fetchGuestTokenFromBackend(),
dashboardUiConfig: {
urlParams: {
themeMode: 'dark', // 'dark' | 'system' | 'default' (default: 'default')
themeMode: "dark", // 'dark' | 'system' | 'default' (default: 'default')
},
},
});
@@ -193,7 +193,7 @@ To pass additional sandbox attributes you can use `iframeSandboxExtras`:
```js
// optional additional iframe sandbox attributes
iframeSandboxExtras: ['allow-top-navigation', 'allow-popups-to-escape-sandbox'];
iframeSandboxExtras: ["allow-top-navigation", "allow-popups-to-escape-sandbox"];
```
### Permissions Policy
@@ -202,7 +202,7 @@ To enable specific browser features within the embedded iframe, use `iframeAllow
```js
// optional Permissions Policy features
iframeAllowExtras: ['clipboard-write', 'fullscreen'];
iframeAllowExtras: ["clipboard-write", "fullscreen"];
```
Common permissions you might need:
@@ -225,9 +225,9 @@ When users click share buttons inside an embedded dashboard, Superset generates
```js
embedDashboard({
id: 'abc123',
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'),
id: "abc123",
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"),
fetchGuestToken: () => fetchGuestTokenFromBackend(),
// Customize permalink URLs
@@ -245,9 +245,9 @@ To restore the dashboard state from a permalink in your app:
const permalinkKey = routeParams.key;
embedDashboard({
id: 'abc123',
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'),
id: "abc123",
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"),
fetchGuestToken: () => fetchGuestTokenFromBackend(),
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
dashboardUiConfig: {

View File

@@ -18,9 +18,6 @@
*/
module.exports = {
presets: [
"@babel/preset-typescript",
"@babel/preset-env"
],
presets: ["@babel/preset-typescript", "@babel/preset-env"],
sourceMaps: true,
};

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,7 @@
"scripts": {
"build": "tsc && babel src --out-dir lib --extensions '.ts,.tsx' && webpack --mode production",
"ci:release": "node ./release-if-necessary.js",
"test": "jest"
"test": "vitest --run --dir src"
},
"browserslist": [
"last 3 chrome versions",
@@ -41,12 +41,11 @@
"@babel/core": "^7.25.2",
"@babel/preset-env": "^7.25.4",
"@babel/preset-typescript": "^7.24.7",
"@types/jest": "^29.5.12",
"@types/node": "^22.5.4",
"@types/node": "^25.4.0",
"babel-loader": "^9.1.3",
"jest": "^29.7.0",
"tscw-config": "^1.1.2",
"typescript": "^5.6.2",
"typescript": "^5.9.3",
"vitest": "^4.0.18",
"webpack": "^5.94.0",
"webpack-cli": "^5.1.4"
},

View File

@@ -17,15 +17,15 @@
* under the License.
*/
const { execSync } = require('child_process');
const { name, version } = require('./package.json');
const { execSync } = require("child_process");
const { name, version } = require("./package.json");
function log(...args) {
console.log('[embedded-sdk-release]', ...args);
console.log("[embedded-sdk-release]", ...args);
}
function logError(...args) {
console.error('[embedded-sdk-release]', ...args);
console.error("[embedded-sdk-release]", ...args);
}
(async () => {
@@ -38,13 +38,13 @@ function logError(...args) {
const { status } = await fetch(packageUrl);
if (status === 200) {
log('version already exists on npm, exiting');
log("version already exists on npm, exiting");
} else if (status === 404) {
log('release required, building');
log("release required, building");
try {
execSync('npm run build', { stdio: 'pipe' });
log('build successful, publishing')
execSync('npm publish --access public', { stdio: 'pipe' });
execSync("npm run build", { stdio: "pipe" });
log("build successful, publishing");
execSync("npm publish --access public", { stdio: "pipe" });
log(`published ${version} to npm`);
} catch (err) {
// npm writes failure details to stderr (auth/permission/registry
@@ -52,7 +52,7 @@ function logError(...args) {
// the real cause in CI logs.
if (err.stdout) console.error(String(err.stdout));
if (err.stderr) console.error(String(err.stderr));
logError('Encountered an error, details should be above');
logError("Encountered an error, details should be above");
process.exitCode = 1;
}
} else {

View File

@@ -18,7 +18,9 @@
*/
export const IFRAME_COMMS_MESSAGE_TYPE = "__embedded_comms__";
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: { [index: string]: any } = {
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: {
[index: string]: any;
} = {
visible: "show_filters",
expanded: "expand_filters",
}
};

View File

@@ -24,22 +24,23 @@ import {
DEFAULT_TOKEN_EXP_MS,
DEFAULT_TOKEN_REFRESH_RETRY_MS,
} from "./guestTokenRefresh";
import { afterAll, beforeAll, it, expect, describe, vi } from "vitest";
describe("guest token refresh", () => {
beforeAll(() => {
jest.useFakeTimers();
jest.setSystemTime(new Date("2022-03-03 01:00"));
jest.spyOn(global, "setTimeout");
vi.useFakeTimers();
vi.setSystemTime(new Date("2022-03-03 01:00"));
vi.spyOn(globalThis, "setTimeout");
});
afterAll(() => {
jest.useRealTimers();
vi.useRealTimers();
});
function makeFakeJWT(claims: any) {
// not a valid jwt, but close enough for this code
const tokenifiedClaims = Buffer.from(JSON.stringify(claims)).toString(
"base64"
"base64",
);
return `abc.${tokenifiedClaims}.xyz`;
}

View File

@@ -18,17 +18,23 @@
*/
import { jwtDecode } from "jwt-decode";
export const REFRESH_TIMING_BUFFER_MS = 5000 // refresh guest token early to avoid failed superset requests
export const MIN_REFRESH_WAIT_MS = 10000 // avoid blasting requests as fast as the cpu can handle
export const DEFAULT_TOKEN_EXP_MS = 300000 // (5 min) used only when parsing guest token exp fails
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000 // wait before retrying a failed/timed-out token refresh
export const REFRESH_TIMING_BUFFER_MS = 5000; // refresh guest token early to avoid failed superset requests
export const MIN_REFRESH_WAIT_MS = 10000; // avoid blasting requests as fast as the cpu can handle
export const DEFAULT_TOKEN_EXP_MS = 300000; // (5 min) used only when parsing guest token exp fails
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000; // wait before retrying a failed/timed-out token refresh
// when do we refresh the guest token?
export function getGuestTokenRefreshTiming(currentGuestToken: string) {
const parsedJwt = jwtDecode<Record<string, any>>(currentGuestToken);
// if exp is int, it is in seconds, but Date() takes milliseconds
const exp = new Date(/[^0-9\.]/g.test(parsedJwt.exp) ? parsedJwt.exp : parseFloat(parsedJwt.exp) * 1000);
const isValidDate = exp.toString() !== 'Invalid Date';
const ttl = isValidDate ? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now()) : DEFAULT_TOKEN_EXP_MS;
const exp = new Date(
/[^0-9\.]/g.test(parsedJwt.exp)
? parsedJwt.exp
: parseFloat(parsedJwt.exp) * 1000,
);
const isValidDate = exp.toString() !== "Invalid Date";
const ttl = isValidDate
? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now())
: DEFAULT_TOKEN_EXP_MS;
return ttl - REFRESH_TIMING_BUFFER_MS;
}

View File

@@ -20,15 +20,15 @@
import {
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY,
IFRAME_COMMS_MESSAGE_TYPE,
} from './const';
} from "./const";
// We can swap this out for the actual switchboard package once it gets published
import { Switchboard } from '@superset-ui/switchboard';
import { Switchboard } from "@superset-ui/switchboard";
import {
getGuestTokenRefreshTiming,
DEFAULT_TOKEN_REFRESH_RETRY_MS,
} from './guestTokenRefresh';
import { withTimeout } from './withTimeout';
} from "./guestTokenRefresh";
import { withTimeout } from "./withTimeout";
/**
* The function to fetch a guest token from your Host App's backend server.
@@ -97,7 +97,7 @@ export type ObserveDataMaskCallbackFn = (
nativeFiltersChanged: boolean;
},
) => void;
export type ThemeMode = 'default' | 'dark' | 'system';
export type ThemeMode = "default" | "dark" | "system";
/**
* Callback to resolve permalink URLs.
@@ -113,12 +113,12 @@ export type EmbeddedDashboard = {
unmount: () => void;
getDashboardPermalink: (anchor: string) => Promise<string>;
getActiveTabs: () => Promise<string[]>;
observeDataMask: (
callbackFn: ObserveDataMaskCallbackFn,
) => void;
observeDataMask: (callbackFn: ObserveDataMaskCallbackFn) => void;
getDataMask: () => Promise<Record<string, any>>;
getChartStates: () => Promise<Record<string, any>>;
getChartDataPayloads: (params?: { chartId?: number }) => Promise<Record<string, any>>;
getChartDataPayloads: (params?: {
chartId?: number;
}) => Promise<Record<string, any>>;
setThemeConfig: (themeConfig: Record<string, any>) => void;
setThemeMode: (mode: ThemeMode) => void;
};
@@ -133,7 +133,7 @@ export async function embedDashboard({
fetchGuestToken,
dashboardUiConfig,
debug = false,
iframeTitle = 'Embedded Dashboard',
iframeTitle = "Embedded Dashboard",
iframeSandboxExtras = [],
iframeAllowExtras = [],
referrerPolicy,
@@ -152,13 +152,13 @@ export async function embedDashboard({
return withTimeout(
fetchGuestToken(),
guestTokenFetchTimeoutMs,
'fetchGuestToken',
"fetchGuestToken",
);
}
log('embedding');
log("embedding");
if (supersetDomain.endsWith('/')) {
if (supersetDomain.endsWith("/")) {
supersetDomain = supersetDomain.slice(0, -1);
}
@@ -185,15 +185,15 @@ export async function embedDashboard({
}
async function mountIframe(): Promise<Switchboard> {
return new Promise(resolve => {
const iframe = document.createElement('iframe');
return new Promise((resolve) => {
const iframe = document.createElement("iframe");
const dashboardConfigUrlParams = dashboardUiConfig
? { uiConfig: `${calculateConfig()}` }
: undefined;
const filterConfig = dashboardUiConfig?.filters || {};
const filterConfigKeys = Object.keys(filterConfig);
const filterConfigUrlParams = Object.fromEntries(
filterConfigKeys.map(key => [
filterConfigKeys.map((key) => [
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY[key],
filterConfig[key],
]),
@@ -206,16 +206,16 @@ export async function embedDashboard({
...dashboardUiConfig?.urlParams,
};
const urlParamsString = Object.keys(urlParams).length
? '?' + new URLSearchParams(urlParams).toString()
: '';
? "?" + new URLSearchParams(urlParams).toString()
: "";
// set up the iframe's sandbox configuration
iframe.sandbox.add('allow-same-origin'); // needed for postMessage to work
iframe.sandbox.add('allow-scripts'); // obviously the iframe needs scripts
iframe.sandbox.add('allow-presentation'); // for fullscreen charts
iframe.sandbox.add('allow-downloads'); // for downloading charts as image
iframe.sandbox.add('allow-forms'); // for forms to submit
iframe.sandbox.add('allow-popups'); // for exporting charts as csv
iframe.sandbox.add("allow-same-origin"); // needed for postMessage to work
iframe.sandbox.add("allow-scripts"); // obviously the iframe needs scripts
iframe.sandbox.add("allow-presentation"); // for fullscreen charts
iframe.sandbox.add("allow-downloads"); // for downloading charts as image
iframe.sandbox.add("allow-forms"); // for forms to submit
iframe.sandbox.add("allow-popups"); // for exporting charts as csv
// additional sandbox props
iframeSandboxExtras.forEach((key: string) => {
iframe.sandbox.add(key);
@@ -226,7 +226,7 @@ export async function embedDashboard({
}
// add the event listener before setting src, to be 100% sure that we capture the load event
iframe.addEventListener('load', () => {
iframe.addEventListener("load", () => {
// MessageChannel allows us to send and receive messages smoothly between our window and the iframe
// See https://developer.mozilla.org/en-US/docs/Web/API/Channel_Messaging_API
const commsChannel = new MessageChannel();
@@ -237,35 +237,35 @@ export async function embedDashboard({
// See https://developer.mozilla.org/en-US/docs/Web/API/Window/postMessage
// we know the content window isn't null because we are in the load event handler.
iframe.contentWindow!.postMessage(
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: 'port transfer' },
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: "port transfer" },
supersetDomain,
[theirPort],
);
log('sent message channel to the iframe');
log("sent message channel to the iframe");
// return our port from the promise
resolve(
new Switchboard({
port: ourPort,
name: 'superset-embedded-sdk',
name: "superset-embedded-sdk",
debug,
}),
);
});
iframe.src = `${supersetDomain}/embedded/${id}${urlParamsString}`;
iframe.title = iframeTitle;
iframe.style.background = 'transparent';
iframe.style.background = "transparent";
// Permissions Policy features the embedded dashboard relies on. Modern
// browsers gate these APIs on the iframe's `allow` attribute regardless
// of sandbox flags, so we include them by default. Host apps can extend
// the list via `iframeAllowExtras`.
const allowFeatures = Array.from(
new Set(['fullscreen', 'clipboard-write', ...iframeAllowExtras]),
new Set(["fullscreen", "clipboard-write", ...iframeAllowExtras]),
);
iframe.setAttribute('allow', allowFeatures.join('; '));
iframe.setAttribute("allow", allowFeatures.join("; "));
//@ts-ignore
mountPoint.replaceChildren(iframe);
log('placed the iframe');
log("placed the iframe");
});
}
@@ -285,8 +285,8 @@ export async function embedDashboard({
throw err;
}
ourPort.emit('guestToken', { guestToken });
log('sent guest token');
ourPort.emit("guestToken", { guestToken });
log("sent guest token");
// Track the pending refresh timer so it can be cancelled on unmount, and
// stop the cycle once unmounted so it cannot leak across mount/unmount cycles.
@@ -298,7 +298,7 @@ export async function embedDashboard({
try {
const newGuestToken = await fetchGuestTokenWithTimeout();
if (unmounted) return;
ourPort.emit('guestToken', { guestToken: newGuestToken });
ourPort.emit("guestToken", { guestToken: newGuestToken });
refreshTimer = setTimeout(
refreshGuestToken,
getGuestTokenRefreshTiming(newGuestToken),
@@ -307,7 +307,7 @@ export async function embedDashboard({
// A transient fetch failure or timeout must not permanently stop the
// refresh cycle. Log it and retry so the session can recover once the
// host callback succeeds again.
log('failed to refresh guest token, will retry:', err);
log("failed to refresh guest token, will retry:", err);
if (unmounted) return;
refreshTimer = setTimeout(
refreshGuestToken,
@@ -325,7 +325,7 @@ export async function embedDashboard({
// Returns null if no callback provided or on error, allowing iframe to use default URL
ourPort.start();
ourPort.defineMethod(
'resolvePermalinkUrl',
"resolvePermalinkUrl",
async ({ key }: { key: string }): Promise<string | null> => {
if (!resolvePermalinkUrl) {
return null;
@@ -333,14 +333,14 @@ export async function embedDashboard({
try {
return await resolvePermalinkUrl({ key });
} catch (error) {
log('Error in resolvePermalinkUrl callback:', error);
log("Error in resolvePermalinkUrl callback:", error);
return null;
}
},
);
function unmount() {
log('unmounting');
log("unmounting");
unmounted = true;
if (refreshTimer !== undefined) {
clearTimeout(refreshTimer);
@@ -350,24 +350,25 @@ export async function embedDashboard({
mountPoint.replaceChildren();
}
const getScrollSize = () => ourPort.get<Size>('getScrollSize');
const getScrollSize = () => ourPort.get<Size>("getScrollSize");
const getDashboardPermalink = (anchor: string) =>
ourPort.get<string>('getDashboardPermalink', { anchor });
const getActiveTabs = () => ourPort.get<string[]>('getActiveTabs');
const getDataMask = () => ourPort.get<Record<string, any>>('getDataMask');
const getChartStates = () => ourPort.get<Record<string, any>>('getChartStates');
ourPort.get<string>("getDashboardPermalink", { anchor });
const getActiveTabs = () => ourPort.get<string[]>("getActiveTabs");
const getDataMask = () => ourPort.get<Record<string, any>>("getDataMask");
const getChartStates = () =>
ourPort.get<Record<string, any>>("getChartStates");
const getChartDataPayloads = (params?: { chartId?: number }) =>
ourPort.get<Record<string, any>>('getChartDataPayloads', params);
const observeDataMask = (
callbackFn: ObserveDataMaskCallbackFn,
) => {
ourPort.defineMethod('observeDataMask', callbackFn);
ourPort.get<Record<string, any>>("getChartDataPayloads", params);
const observeDataMask = (callbackFn: ObserveDataMaskCallbackFn) => {
ourPort.defineMethod("observeDataMask", callbackFn);
};
// TODO: Add proper types once theming branch is merged
const setThemeConfig = async (themeConfig: Record<string, any>): Promise<void> => {
const setThemeConfig = async (
themeConfig: Record<string, any>,
): Promise<void> => {
try {
ourPort.emit('setThemeConfig', { themeConfig });
log('Theme config sent successfully (or at least message dispatched)');
ourPort.emit("setThemeConfig", { themeConfig });
log("Theme config sent successfully (or at least message dispatched)");
} catch (error) {
log(
'Error sending theme config. Ensure the iframe side implements the "setThemeConfig" method.',
@@ -378,7 +379,7 @@ export async function embedDashboard({
const setThemeMode = (mode: ThemeMode): void => {
try {
ourPort.emit('setThemeMode', { mode });
ourPort.emit("setThemeMode", { mode });
log(`Theme mode set to: ${mode}`);
} catch (error) {
log(

View File

@@ -18,22 +18,23 @@
*/
import { withTimeout } from "./withTimeout";
import { test, expect } from "vitest";
test("resolves with the value when the promise settles in time", async () => {
await expect(withTimeout(Promise.resolve("ok"), 1000, "fetch")).resolves.toBe(
"ok"
"ok",
);
});
test("rejects when the promise does not settle within the timeout", async () => {
const never = new Promise<string>(() => {});
await expect(withTimeout(never, 10, "fetch")).rejects.toThrow(
/fetch did not resolve within 10ms/
/fetch did not resolve within 10ms/,
);
});
test("passes the promise through unchanged when the timeout is disabled", async () => {
await expect(withTimeout(Promise.resolve("ok"), 0, "fetch")).resolves.toBe(
"ok"
"ok",
);
});

View File

@@ -3,7 +3,7 @@
// syntax rules
"strict": true,
"moduleResolution": "node",
"moduleResolution": "bundler",
// environment
"target": "es6",
@@ -13,7 +13,9 @@
// output
"outDir": "./dist",
"emitDeclarationOnly": true,
"declaration": true
"declaration": true,
"types": ["node"]
},
"include": [
@@ -21,7 +23,6 @@
],
"exclude": [
"tests",
"dist",
"lib",
"node_modules"

View File

@@ -17,19 +17,19 @@
* under the License.
*/
const path = require('path');
const path = require("path");
module.exports = {
entry: './src/index.ts',
entry: "./src/index.ts",
output: {
filename: 'index.js',
path: path.resolve(__dirname, 'bundle'),
filename: "index.js",
path: path.resolve(__dirname, "bundle"),
// this exposes the library's exports under a global variable
library: {
name: "supersetEmbeddedSdk",
type: "umd"
}
type: "umd",
},
},
devtool: "source-map",
module: {
@@ -38,12 +38,12 @@ module.exports = {
test: /\.[tj]s$/,
// babel-loader is faster than ts-loader because it ignores types.
// We do type checking in a separate process, so that's fine.
use: 'babel-loader',
use: "babel-loader",
exclude: /node_modules/,
},
],
},
resolve: {
extensions: ['.ts', '.js'],
extensions: [".ts", ".js"],
},
};

View File

@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Superset WebSocket Server
A Node.js WebSocket server for sending async event data to the Superset web application frontend.
@@ -164,4 +165,4 @@ HEAD /health
## Containerization
*TODO: containerize websocket server*
_TODO: containerize websocket server_

View File

@@ -1,22 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
};

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,7 @@
"main": "index.js",
"scripts": {
"start": "node dist/index.js start",
"test": "NODE_ENV=test jest -i spec",
"test": "npx vitest --run --dir spec",
"type": "tsc --noEmit",
"eslint": "eslint",
"lint": "npm run eslint -- . && npm run type",
@@ -28,24 +28,22 @@
"devDependencies": {
"@eslint/js": "^9.25.1",
"@types/eslint__js": "^8.42.3",
"@types/jest": "^29.5.14",
"@types/jsonwebtoken": "^9.0.10",
"@types/lodash": "^4.17.24",
"@types/node": "^25.9.3",
"@types/ws": "^8.18.1",
"@typescript-eslint/eslint-plugin": "^8.61.1",
"@typescript-eslint/parser": "^8.61.1",
"@typescript-eslint/eslint-plugin": "^8.62.0",
"@typescript-eslint/parser": "^8.62.0",
"eslint": "^10.5.0",
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-lodash": "^8.0.0",
"globals": "^17.6.0",
"jest": "^29.7.0",
"prettier": "^3.8.4",
"ts-jest": "^29.4.11",
"ts-node": "^10.9.2",
"tscw-config": "^1.1.2",
"typescript": "^6.0.3",
"typescript-eslint": "^8.61.1"
"typescript-eslint": "^8.62.0",
"vitest": "^4.1.5"
},
"engines": {
"node": "^24.16.0",

View File

@@ -17,7 +17,7 @@
* under the License.
*/
import { buildConfig } from '../src/config';
import { expect, test } from '@jest/globals';
import { expect, test } from 'vitest';
test('buildConfig() builds configuration and applies env var overrides', () => {
let config = buildConfig();

View File

@@ -25,29 +25,29 @@ import {
test,
beforeEach,
afterEach,
jest,
} from '@jest/globals';
vi,
type Mock,
} from 'vitest';
import * as http from 'http';
import * as net from 'net';
import { WebSocket } from 'ws';
import * as server from '../src/index';
import { statsd } from '../src/index';
interface MockedRedisXrange {
(): Promise<server.StreamResult[]>;
}
// NOTE: these mock variables needs to start with "mock" due to
// calls to `jest.mock` being hoisted to the top of the file.
// https://jestjs.io/docs/es6-class-mocks#calling-jestmock-with-the-module-factory-parameter
const mockRedisXrange = jest.fn() as jest.MockedFunction<MockedRedisXrange>;
jest.mock('ws');
jest.mock('ioredis', () => {
return jest.fn().mockImplementation(() => {
return { xrange: mockRedisXrange, on: jest.fn() };
});
const { mockRedisXrange } = vi.hoisted(() => {
return { mockRedisXrange: vi.fn() };
});
const wsMock = WebSocket as jest.Mocked<typeof WebSocket>;
vi.mock('ws');
vi.mock('ioredis', () => {
return {
Redis: vi.fn().mockImplementation(function () {
return { xrange: mockRedisXrange, on: vi.fn() };
}),
};
});
const wsMock = WebSocket as unknown as Mock<typeof WebSocket>;
const channelId = 'bc9e040c-7b4a-4817-99b9-292832d97ec7';
const streamReturnValue: server.StreamResult[] = [
[
@@ -66,16 +66,13 @@ const streamReturnValue: server.StreamResult[] = [
],
];
import * as server from '../src/index';
import { statsd } from '../src/index';
describe('server', () => {
let statsdIncrementMock: jest.SpiedFunction<typeof statsd.increment>;
let statsdIncrementMock: Mock<typeof statsd.increment>;
beforeEach(() => {
mockRedisXrange.mockClear();
server.resetState();
statsdIncrementMock = jest.spyOn(statsd, 'increment').mockReturnValue();
statsdIncrementMock = vi.spyOn(statsd, 'increment').mockReturnValue();
});
afterEach(() => {
@@ -84,8 +81,8 @@ describe('server', () => {
describe('HTTP requests', () => {
test('services health checks', () => {
const endMock = jest.fn();
const writeHeadMock = jest.fn();
const endMock = vi.fn();
const writeHeadMock = vi.fn();
const request = {
url: '/health',
@@ -113,8 +110,8 @@ describe('server', () => {
});
test('responds with a 404 when not found', () => {
const endMock = jest.fn();
const writeHeadMock = jest.fn();
const endMock = vi.fn();
const writeHeadMock = vi.fn();
const request = {
url: '/unsupported',
@@ -245,7 +242,7 @@ describe('server', () => {
describe('processStreamResults', () => {
test('sends data to channel', async () => {
const ws = new wsMock('localhost');
const sendMock = jest.spyOn(ws, 'send');
const sendMock = vi.spyOn(ws, 'send');
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
@@ -267,7 +264,7 @@ describe('server', () => {
test('channel not present', async () => {
const ws = new wsMock('localhost');
const sendMock = jest.spyOn(ws, 'send');
const sendMock = vi.spyOn(ws, 'send');
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
server.processStreamResults(streamReturnValue);
@@ -278,10 +275,9 @@ describe('server', () => {
test('error sending data to client', async () => {
const ws = new wsMock('localhost');
const sendMock = jest.spyOn(ws, 'send').mockImplementation(() => {
const sendMock = vi.spyOn(ws, 'send').mockImplementation(() => {
throw new Error();
});
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
@@ -300,9 +296,7 @@ describe('server', () => {
);
expect(sendMock).toHaveBeenCalled();
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
cleanChannelMock.mockRestore();
expect(Object.keys(server.channels)).toHaveLength(0);
});
const makeItem = (i: number): server.StreamResult =>
@@ -330,8 +324,8 @@ describe('server', () => {
channel: channelId,
pongTs: Date.now(),
});
const sendMock = jest.spyOn(ws, 'send');
const setImmediateSpy = jest.spyOn(global, 'setImmediate');
const sendMock = vi.spyOn(ws, 'send');
const setImmediateSpy = vi.spyOn(global, 'setImmediate');
const results = [0, 1, 2, 3, 4].map(makeItem);
await server.processStreamResults(results);
@@ -351,7 +345,7 @@ describe('server', () => {
channel: channelId,
pongTs: Date.now(),
});
const sendMock = jest.spyOn(ws, 'send');
const sendMock = vi.spyOn(ws, 'send');
const results = [0, 1, 2, 3, 4].map(makeItem);
await server.processStreamResults(results);
@@ -372,16 +366,16 @@ describe('server', () => {
server.opts.maxSocketBufferBytes = 0;
// Restore any spies (e.g. on server.cleanChannel) so they don't leak
// across tests and cause order-dependent failures.
jest.restoreAllMocks();
vi.restoreAllMocks();
});
test('does not terminate when cap disabled (0)', () => {
server.opts.maxSocketBufferBytes = 0;
const ws = new wsMock('localhost');
// simulate a large outbound buffer
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 10_000_000;
const terminateMock = jest.spyOn(ws, 'terminate');
const sendMock = jest.spyOn(ws, 'send');
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(10_000_000);
const terminateMock = vi.spyOn(ws, 'terminate');
const sendMock = vi.spyOn(ws, 'send');
server.trackClient(channelId, {
ws,
channel: channelId,
@@ -397,10 +391,9 @@ describe('server', () => {
test('terminates a slow client whose buffer exceeds the cap', () => {
server.opts.maxSocketBufferBytes = 1024;
const ws = new wsMock('localhost');
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 2048;
const terminateMock = jest.spyOn(ws, 'terminate');
const sendMock = jest.spyOn(ws, 'send');
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(2048);
const terminateMock = vi.spyOn(ws, 'terminate');
const sendMock = vi.spyOn(ws, 'send');
server.trackClient(channelId, {
ws,
channel: channelId,
@@ -414,15 +407,15 @@ describe('server', () => {
expect(statsdIncrementMock).toHaveBeenCalledWith(
'ws_client_backpressure_disconnect',
);
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
expect(Object.keys(server.channels)).toHaveLength(0);
});
test('keeps sending to a client within the cap', () => {
server.opts.maxSocketBufferBytes = 1024;
const ws = new wsMock('localhost');
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 16;
const terminateMock = jest.spyOn(ws, 'terminate');
const sendMock = jest.spyOn(ws, 'send');
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(16);
const terminateMock = vi.spyOn(ws, 'terminate');
const sendMock = vi.spyOn(ws, 'send');
server.trackClient(channelId, {
ws,
channel: channelId,
@@ -443,7 +436,7 @@ describe('server', () => {
test('success with results', async () => {
mockRedisXrange.mockResolvedValueOnce(streamReturnValue);
const cb = jest.fn() as jest.MockedFunction<
const cb = vi.fn() as Mock<
(results: server.StreamResult[]) => void | Promise<void>
>;
await server.fetchRangeFromStream({
@@ -462,7 +455,7 @@ describe('server', () => {
});
test('success no results', async () => {
const cb = jest.fn() as jest.MockedFunction<
const cb = vi.fn() as Mock<
(results: server.StreamResult[]) => void | Promise<void>
>;
await server.fetchRangeFromStream({
@@ -481,7 +474,7 @@ describe('server', () => {
});
test('error', async () => {
const cb = jest.fn() as jest.MockedFunction<
const cb = vi.fn() as Mock<
(results: server.StreamResult[]) => void | Promise<void>
>;
mockRedisXrange.mockRejectedValueOnce(new Error());
@@ -503,12 +496,8 @@ describe('server', () => {
describe('wsConnection', () => {
let ws: WebSocket;
let wsEventMock: jest.SpiedFunction<typeof ws.on>;
let trackClientSpy: jest.SpiedFunction<typeof server.trackClient>;
let fetchRangeFromStreamSpy: jest.SpiedFunction<
typeof server.fetchRangeFromStream
>;
let dateNowSpy: jest.SpiedFunction<typeof Date.now>;
let wsEventMock: Mock<typeof ws.on>;
let dateNowSpy: Mock<typeof Date.now>;
let socketInstanceExpected: server.SocketInstance;
const getRequest = (token: string, url: string): http.IncomingMessage => {
@@ -521,10 +510,8 @@ describe('server', () => {
beforeEach(() => {
ws = new wsMock('localhost');
wsEventMock = jest.spyOn(ws, 'on');
trackClientSpy = jest.spyOn(server, 'trackClient');
fetchRangeFromStreamSpy = jest.spyOn(server, 'fetchRangeFromStream');
dateNowSpy = jest
wsEventMock = vi.spyOn(ws, 'on');
dateNowSpy = vi
.spyOn(global.Date, 'now')
.mockImplementation(() =>
new Date('2021-03-10T11:01:58.135Z').valueOf(),
@@ -537,10 +524,8 @@ describe('server', () => {
});
afterEach(() => {
wsEventMock.mockRestore();
trackClientSpy.mockRestore();
fetchRangeFromStreamSpy.mockRestore();
dateNowSpy.mockRestore();
wsEventMock?.mockRestore();
dateNowSpy?.mockRestore();
});
test('invalid JWT', async () => {
@@ -558,11 +543,14 @@ describe('server', () => {
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
expect(mockRedisXrange).not.toHaveBeenCalled();
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
});
@@ -575,12 +563,15 @@ describe('server', () => {
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
// Malformed last_id must not trigger a stream range fetch.
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
});
test('valid JWT, with lastId', async () => {
@@ -593,16 +584,18 @@ describe('server', () => {
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
sessionId: channelId,
startId: '1615426152415-1',
endId: '+',
listener: server.processStreamResults,
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
expect(mockRedisXrange).toHaveBeenCalledWith(
expect.stringContaining(channelId),
'1615426152415-1',
'+',
);
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
});
@@ -618,16 +611,18 @@ describe('server', () => {
server.setLastFirehoseId(lastFirehoseId);
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
sessionId: channelId,
startId: '1615426152415-1',
endId: lastFirehoseId,
listener: server.processStreamResults,
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
expect(mockRedisXrange).toHaveBeenCalledWith(
expect.stringContaining(channelId),
'1615426152415-1',
lastFirehoseId,
);
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
});
});
@@ -662,7 +657,7 @@ describe('server', () => {
test('total connection limit reached', () => {
server.opts.maxTotalConnections = 1;
const ws = new wsMock('localhost');
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const socketInstance = {
ws,
channel: channelId,
@@ -677,7 +672,7 @@ describe('server', () => {
test('per-channel connection limit reached', () => {
server.opts.maxConnectionsPerChannel = 1;
const ws = new wsMock('localhost');
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const socketInstance = {
ws,
channel: channelId,
@@ -699,7 +694,7 @@ describe('server', () => {
};
server.trackClient(channelId, socketInstance);
// simulate the socket having closed but not yet been GC'd
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
expect(server.connectionLimitReason('some-other-channel')).toBeNull();
});
@@ -713,13 +708,13 @@ describe('server', () => {
};
server.trackClient(channelId, socketInstance);
// simulate the socket having closed but not yet been GC'd
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
expect(server.connectionLimitReason(channelId)).toBeNull();
});
test('isSocketActive reflects the socket readyState', () => {
const ws = new wsMock('localhost');
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const socketId = server.trackClient(channelId, {
ws,
channel: channelId,
@@ -727,9 +722,9 @@ describe('server', () => {
});
expect(server.isSocketActive(socketId)).toBe(true);
// CONNECTING is also considered active (see SOCKET_ACTIVE_STATES)
setReadyState(ws, WebSocket.CONNECTING);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CONNECTING);
expect(server.isSocketActive(socketId)).toBe(true);
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
expect(server.isSocketActive(socketId)).toBe(false);
// unknown socket ids are never active
expect(server.isSocketActive('does-not-exist')).toBe(false);
@@ -737,14 +732,14 @@ describe('server', () => {
test('activeSocketCount counts only active sockets', () => {
const openWs = new wsMock('localhost');
setReadyState(openWs, WebSocket.OPEN);
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, {
ws: openWs,
channel: channelId,
pongTs: Date.now(),
});
const closedWs = new wsMock('localhost');
setReadyState(closedWs, WebSocket.CLOSED);
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
server.trackClient(channelId, {
ws: closedWs,
channel: channelId,
@@ -755,14 +750,14 @@ describe('server', () => {
test('activeChannelSocketCount counts only active sockets on the channel', () => {
const openWs = new wsMock('localhost');
setReadyState(openWs, WebSocket.OPEN);
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, {
ws: openWs,
channel: channelId,
pongTs: Date.now(),
});
const closedWs = new wsMock('localhost');
setReadyState(closedWs, WebSocket.CLOSED);
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
server.trackClient(channelId, {
ws: closedWs,
channel: channelId,
@@ -776,7 +771,7 @@ describe('server', () => {
test('wsConnection refuses over-limit connection without tracking', () => {
server.opts.maxConnectionsPerChannel = 1;
const existingWs = new wsMock('localhost');
setReadyState(existingWs, WebSocket.OPEN);
vi.spyOn(existingWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const existing = {
ws: existingWs,
channel: channelId,
@@ -784,7 +779,7 @@ describe('server', () => {
};
server.trackClient(channelId, existing);
const trackClientSpy = jest.spyOn(server, 'trackClient');
const trackClientSpy = vi.spyOn(server, 'trackClient');
const ws = new wsMock('localhost');
const validToken = jwt.sign({ channel: channelId }, config.jwtSecret);
server.wsConnection(ws, getRequest(validToken, 'http://localhost'));
@@ -800,8 +795,8 @@ describe('server', () => {
describe('httpUpgrade', () => {
let socket: net.Socket;
let socketDestroySpy: jest.SpiedFunction<typeof socket.destroy>;
let wssUpgradeSpy: jest.SpiedFunction<typeof server.wss.handleUpgrade>;
let socketDestroySpy: Mock<typeof socket.destroy>;
let wssUpgradeSpy: Mock<typeof server.wss.handleUpgrade>;
const getRequest = (token: string, url: string): http.IncomingMessage => {
const request = new http.IncomingMessage(new net.Socket());
@@ -813,8 +808,8 @@ describe('server', () => {
beforeEach(() => {
socket = new net.Socket();
socketDestroySpy = jest.spyOn(socket, 'destroy');
wssUpgradeSpy = jest.spyOn(server.wss, 'handleUpgrade');
socketDestroySpy = vi.spyOn(socket, 'destroy');
wssUpgradeSpy = vi.spyOn(server.wss, 'handleUpgrade');
});
afterEach(() => {
@@ -952,33 +947,21 @@ describe('server', () => {
});
});
const setReadyState = (ws: WebSocket, value: typeof ws.readyState) => {
// workaround for not being able to do
// spyOn(instance,'readyState','get').and.returnValue(value);
// See for details: https://github.com/facebook/jest/issues/9675
Object.defineProperty(ws, 'readyState', {
configurable: true,
get() {
return value;
},
});
};
describe('checkSockets', () => {
let ws: WebSocket;
let pingSpy: jest.SpiedFunction<typeof ws.ping>;
let terminateSpy: jest.SpiedFunction<typeof ws.terminate>;
let pingSpy: Mock<typeof ws.ping>;
let terminateSpy: Mock<typeof ws.terminate>;
let socketInstance: server.SocketInstance;
beforeEach(() => {
ws = new wsMock('localhost');
pingSpy = jest.spyOn(ws, 'ping');
terminateSpy = jest.spyOn(ws, 'terminate');
pingSpy = vi.spyOn(ws, 'ping');
terminateSpy = vi.spyOn(ws, 'terminate');
socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
});
test('active sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, socketInstance);
server.checkSockets();
@@ -989,7 +972,7 @@ describe('server', () => {
});
test('stale sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
socketInstance.pongTs = Date.now() - 60000;
server.trackClient(channelId, socketInstance);
@@ -1001,7 +984,7 @@ describe('server', () => {
});
test('closed sockets', () => {
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
server.trackClient(channelId, socketInstance);
server.checkSockets();
@@ -1027,7 +1010,7 @@ describe('server', () => {
});
test('active sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, socketInstance);
server.cleanChannel(channelId);
@@ -1036,7 +1019,7 @@ describe('server', () => {
});
test('closing sockets', () => {
setReadyState(ws, WebSocket.CLOSING);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSING);
server.trackClient(channelId, socketInstance);
server.cleanChannel(channelId);
@@ -1045,11 +1028,12 @@ describe('server', () => {
});
test('multiple sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, socketInstance);
const ws2 = new wsMock('localhost');
setReadyState(ws2, WebSocket.OPEN);
const readyStateSpy = vi.spyOn(ws2, 'readyState', 'get');
readyStateSpy.mockReturnValue(WebSocket.OPEN);
const socketInstance2 = {
ws: ws2,
channel: channelId,
@@ -1061,7 +1045,7 @@ describe('server', () => {
expect(server.channels[channelId].sockets.length).toBe(2);
setReadyState(ws2, WebSocket.CLOSED);
readyStateSpy.mockReturnValue(WebSocket.CLOSED);
server.cleanChannel(channelId);
expect(server.channels[channelId].sockets.length).toBe(1);

View File

@@ -19,11 +19,11 @@
import * as http from 'http';
import * as net from 'net';
import { inspect } from 'util';
import WebSocket, { WebSocketServer } from 'ws';
import { WebSocket, WebSocketServer } from 'ws';
import { randomUUID } from 'crypto';
import jwt, { Algorithm } from 'jsonwebtoken';
import { parse } from 'cookie';
import Redis, { RedisOptions } from 'ioredis';
import { Redis, RedisOptions } from 'ioredis';
import StatsD from 'hot-shots';
import { createLogger } from './logger';

View File

@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Test & development utilities
The files provided here are for testing and development only, and are not required to run the WebSocket server application.

View File

@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Test client application
This Express web application is provided for testing the WebSocket server. It is not required for running the server application, and is provided here for testing and development purposes only.

View File

@@ -17,10 +17,11 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union
from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask import g
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
@@ -38,12 +39,18 @@ from superset.db_engine_specs.base import (
)
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
from superset.models.core import Database
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
OAuth2TokenResponse,
)
try:
@@ -277,6 +284,135 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
"port": "port",
}
# The Databricks SQL driver has no dedicated authentication exception, so an
# expired or missing token surfaces as a generic driver error. These case-
# insensitive substrings flag the errors that should bootstrap a re-auth.
oauth2_auth_failure_signals = (
"http 401",
"unauthorized",
"unauthenticated",
"invalid access token",
"invalid token",
"expired token",
"token expired",
)
@classmethod
def _workspace_oauth2_endpoint(cls, database: Database, path: str) -> str:
"""
Build a Databricks OAuth2 (U2M) endpoint from the workspace host.
Databricks fronts the user-to-machine OAuth2 flow on every workspace at
``https://<workspace-host>/oidc/v1/{authorize,token}`` across AWS, Azure
and GCP, so the endpoints derive directly from the connection host and
need no account or tenant identifier.
"""
host = database.url_object.host
if not host:
raise OAuth2Error(
"Databricks OAuth2 endpoint could not be resolved: the database "
"connection has no host."
)
return f"https://{host}/oidc/v1/{path}"
@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
"""
Identify driver errors that should trigger the OAuth2 dance.
Unlike Trino (``TrinoAuthError``) or GSheets (``UnauthenticatedError``),
the Databricks driver raises no dedicated auth exception, so in addition
to the base ``isinstance`` check we match the auth signals above on the
error message (mirrors ``GSheetsEngineSpec.needs_oauth2``).
"""
if not (g and hasattr(g, "user")):
return False
if isinstance(ex, cls.oauth2_exception):
return True
message = str(ex).lower()
return any(signal in message for signal in cls.oauth2_auth_failure_signals)
@classmethod
def get_oauth2_authorization_uri(
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
code_verifier: str | None = None,
) -> str:
"""
Return the URI for the initial OAuth2 request.
A fully-resolved ``authorization_request_uri`` from
``DATABASE_OAUTH2_CLIENTS`` is preserved; otherwise the endpoint is
derived from the workspace host (``https://<host>/oidc/v1/authorize``),
which is valid on AWS, Azure and GCP.
"""
if not config.get("authorization_request_uri"):
from superset import db
from superset.models.core import Database
database_id = state["database_id"]
if database := db.session.get(Database, database_id):
config = cast(
"OAuth2ClientConfig",
dict(config)
| {
"authorization_request_uri": cls._workspace_oauth2_endpoint(
database, "authorize"
)
},
)
return super().get_oauth2_authorization_uri(config, state, code_verifier)
@classmethod
def get_oauth2_token(
cls,
config: "OAuth2ClientConfig",
code: str,
code_verifier: str | None = None,
) -> "OAuth2TokenResponse":
"""
Exchange the authorization code for refresh/access tokens.
Token exchange runs in a separate request with no database context, so
the workspace host is not available to derive the endpoint here. Require
a configured ``token_request_uri``
(``https://<workspace-host>/oidc/v1/token``) and fail fast rather than
POST to an unresolved endpoint.
"""
if not config.get("token_request_uri"):
raise OAuth2Error(
"Databricks OAuth2 token endpoint is not configured: set "
"`token_request_uri` to https://<workspace-host>/oidc/v1/token "
"in DATABASE_OAUTH2_CLIENTS."
)
return super().get_oauth2_token(config, code, code_verifier)
@classmethod
def impersonate_user(
cls,
database: Database,
username: str | None,
user_token: str | None,
url: URL,
engine_kwargs: dict[str, Any],
) -> tuple[URL, dict[str, Any]]:
"""
Update connection with OAuth2 access token for user impersonation.
"""
if user_token:
# Replace the access token in the URL with the user's OAuth2 token
url = url.set(password=user_token)
# Also update connect_args if they contain access token
connect_args = engine_kwargs.setdefault("connect_args", {})
if "access_token" in connect_args:
connect_args["access_token"] = user_token
return url, engine_kwargs
@staticmethod
def get_extra_params(
database: Database, source: QuerySource | None = None
@@ -474,6 +610,16 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_catalog = True
supports_cross_catalog_queries = True
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
supports_oauth2 = True
oauth2_scope = "sql"
# Authorization endpoint is derived from the workspace host at runtime; the
# token endpoint must be configured (no DB context at exchange time).
oauth2_authorization_request_uri = ""
oauth2_token_request_uri = ""
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksNativeParametersType, *_
@@ -685,6 +831,16 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
supports_oauth2 = True
oauth2_scope = "sql"
# Authorization endpoint is derived from the workspace host at runtime; the
# token endpoint must be configured (no DB context at exchange time).
oauth2_authorization_request_uri = ""
oauth2_token_request_uri = ""
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksPythonConnectorParametersType, *_

View File

@@ -17,14 +17,23 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from datetime import datetime
from typing import Optional
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine.url import make_url
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from superset.db_engine_specs.base import OAuth2State
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.superset_typing import OAuth2ClientConfig
from superset.utils import json
from superset.utils.oauth2 import decode_oauth2_state
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@@ -291,3 +300,595 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
"USE CATALOG `evil`` USE CATALOG bad`",
"USE SCHEMA `evil`` USE SCHEMA bad`",
]
# OAuth2 Tests
def test_oauth2_attributes() -> None:
"""
Test that OAuth2 attributes are properly set for both engine specs.
"""
# Test DatabricksNativeEngineSpec
assert DatabricksNativeEngineSpec.supports_oauth2 is True
assert DatabricksNativeEngineSpec.oauth2_scope == "sql"
# The authorization endpoint is derived from the workspace host at runtime;
# the token endpoint must be configured explicitly.
assert DatabricksNativeEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksNativeEngineSpec.oauth2_token_request_uri == ""
# Test DatabricksPythonConnectorEngineSpec
assert DatabricksPythonConnectorEngineSpec.supports_oauth2 is True
assert DatabricksPythonConnectorEngineSpec.oauth2_scope == "sql"
assert DatabricksPythonConnectorEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksPythonConnectorEngineSpec.oauth2_token_request_uri == ""
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"message",
[
"Error during request to server: HTTP 401 Unauthorized",
"Invalid access token",
"The access token expired",
"UNAUTHENTICATED: token is no longer valid",
],
)
def test_needs_oauth2_detects_auth_failure_from_message(
mocker: MockerFixture,
spec: Any,
message: str,
) -> None:
"""
The Databricks driver has no dedicated auth exception, so `needs_oauth2`
matches auth-failure signals in the error message to bootstrap a re-auth.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
assert spec.needs_oauth2(Exception(message)) is True
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"message",
[
"Table not found",
# A bare 401 in an unrelated position must not look like an auth failure.
"Query failed at line 401: syntax error",
],
)
def test_needs_oauth2_ignores_unrelated_errors(
mocker: MockerFixture,
spec: Any,
message: str,
) -> None:
"""
A non-auth driver error must not trigger the OAuth2 dance.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
assert spec.needs_oauth2(Exception(message)) is False
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_needs_oauth2_matches_oauth2_redirect_error(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
The inherited `isinstance` check against `oauth2_exception` still holds.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
ex = OAuth2RedirectError("https://example/authorize", "tab", "redirect")
assert spec.needs_oauth2(ex) is True
def test_impersonate_user_with_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method with OAuth2 token for DatabricksNativeEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
def test_impersonate_user_without_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method without OAuth2 token.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test without user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token=None,
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that nothing was changed
assert url.password == "original-token" # noqa: S105
assert kwargs["connect_args"]["access_token"] == "original-token" # noqa: S105
def test_impersonate_user_python_connector(mocker: MockerFixture) -> None:
"""
Test impersonate_user method for DatabricksPythonConnectorEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks://token:original-token@host:443?http_path=path&catalog=main&schema=default"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksPythonConnectorEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
@pytest.fixture
def oauth2_config_native() -> OAuth2ClientConfig:
"""
Config for Databricks Native OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
@pytest.fixture
def oauth2_config_python() -> OAuth2ClientConfig:
"""
Config for Databricks Python Connector OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
def test_is_oauth2_enabled_no_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks (legacy)": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is True
def test_is_oauth2_enabled_no_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is True
def test_get_oauth2_authorization_uri_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Native engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
oauth2_config_native, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_authorization_uri_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Python Connector engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
oauth2_config_python, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_token(
oauth2_config_native, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_token(
oauth2_config_python, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_fresh_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_fresh_token(
oauth2_config_native, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)
def _oauth2_state() -> OAuth2State:
"""
Build the default OAuth2 state shared by the OAuth2 tests.
"""
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
return state
def _unresolved_oauth2_config() -> OAuth2ClientConfig:
"""
Config as built by `get_oauth2_config` when no endpoints are overridden:
the URIs default to the spec's empty class attributes.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"host",
[
"dbc-abc.cloud.databricks.com",
"adb-123456789.12.azuredatabricks.net",
"123456789.gcp.databricks.com",
],
)
def test_get_oauth2_authorization_uri_derives_from_workspace_host(
mocker: MockerFixture,
spec: Any,
host: str,
) -> None:
"""
With no configured `authorization_request_uri`, the endpoint is derived from
the workspace host (`https://<host>/oidc/v1/authorize`) on every cloud, with
no account/tenant identifier required.
"""
database = mocker.MagicMock()
database.url_object.host = host
mocker.patch("superset.db.session.get", return_value=database)
url = spec.get_oauth2_authorization_uri(
_unresolved_oauth2_config(), _oauth2_state()
)
parsed = urlparse(url)
assert parsed.netloc == host
assert parsed.path == "/oidc/v1/authorize"
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_preserves_configured(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
A fully-resolved `authorization_request_uri` is never overwritten by the
host-derived endpoint, and no database lookup is needed.
"""
session_get = mocker.patch("superset.db.session.get")
config = _unresolved_oauth2_config()
config["authorization_request_uri"] = (
"https://accounts.cloud.databricks.com/oidc/accounts/override/v1/authorize"
)
url = spec.get_oauth2_authorization_uri(config, _oauth2_state())
assert urlparse(url).path == "/oidc/accounts/override/v1/authorize"
session_get.assert_not_called()
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_fails_without_host(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
When the endpoint must be derived but the connection has no host, fail fast
instead of emitting an invalid `https:///oidc/v1/authorize` URL.
"""
database = mocker.MagicMock()
database.url_object.host = None
mocker.patch("superset.db.session.get", return_value=database)
with pytest.raises(OAuth2Error):
spec.get_oauth2_authorization_uri(_unresolved_oauth2_config(), _oauth2_state())
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_token_fails_without_uri(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
Token exchange has no database context to auto-detect the endpoint, so a
missing `token_request_uri` fails fast rather than POSTing to `.../{}/...`.
"""
with pytest.raises(OAuth2Error):
spec.get_oauth2_token(_unresolved_oauth2_config(), "authorization-code")
def test_get_oauth2_fresh_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_fresh_token(
oauth2_config_python, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)

View File

@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from typing import Any
from unittest.mock import MagicMock
from urllib.parse import parse_qs, urlparse
import pytest
from pytest_mock import MockerFixture
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.superset_typing import OAuth2ClientConfig
from superset.utils.oauth2 import decode_oauth2_state
# Multi-Cloud Provider Tests
#
# Databricks fronts the user-to-machine OAuth2 flow on every workspace at
# `https://<workspace-host>/oidc/v1/{authorize,token}`, regardless of cloud, so
# the authorization endpoint derives from the connection host with no per-cloud
# account/tenant identifier.
SPECS = [DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec]
# Representative workspace hosts for each cloud provider.
CLOUD_HOSTS = [
"my-cluster.cloud.databricks.com", # AWS
"adb-123456789.12.azuredatabricks.net", # Azure
"123456789.gcp.databricks.com", # GCP
]
@pytest.fixture
def oauth2_config_no_uri() -> OAuth2ClientConfig:
"""
Config for Databricks OAuth2 without a pre-configured endpoint, so the
authorization endpoint is derived from the workspace host.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
def _mock_database(mocker: MockerFixture, host: str) -> MagicMock:
"""
Build a mock database whose URL resolves to the given workspace host.
"""
database = mocker.MagicMock()
database.url_object.host = host
database.id = 1
return database
@pytest.mark.parametrize("spec", SPECS)
@pytest.mark.parametrize("host", CLOUD_HOSTS)
def test_get_oauth2_authorization_uri_uses_workspace_host(
mocker: MockerFixture,
spec: Any,
host: str,
oauth2_config_no_uri: OAuth2ClientConfig,
) -> None:
"""
The authorization endpoint is the workspace host on AWS, Azure, and GCP.
"""
from superset.db_engine_specs.base import OAuth2State
mocker.patch(
"superset.extensions.db.session.get",
return_value=_mock_database(mocker, host),
)
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = spec.get_oauth2_authorization_uri(oauth2_config_no_uri, state)
parsed = urlparse(url)
assert parsed.netloc == host
assert parsed.path == "/oidc/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
@pytest.mark.parametrize("spec", SPECS)
@pytest.mark.parametrize("host", CLOUD_HOSTS)
def test_workspace_oauth2_endpoint_builds_token_uri(
mocker: MockerFixture,
spec: Any,
host: str,
) -> None:
"""
The helper builds the matching token endpoint from the same workspace host.
"""
database = _mock_database(mocker, host)
assert (
spec._workspace_oauth2_endpoint(database, "token")
== f"https://{host}/oidc/v1/token"
)

View File

@@ -81,6 +81,7 @@ def _mock_dashboard(
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.embedded = []
return dashboard