mirror of
https://github.com/apache/superset.git
synced 2026-06-28 10:55:36 +00:00
Compare commits
17 Commits
chore/unif
...
adopt/data
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
563edfdd9f | ||
|
|
98ba9da18c | ||
|
|
910e79dd8c | ||
|
|
6251280aa5 | ||
|
|
9d0a2209c6 | ||
|
|
cb7d5c8847 | ||
|
|
fe69d222bd | ||
|
|
8e0871bb2e | ||
|
|
ea1ba587f2 | ||
|
|
dc99373579 | ||
|
|
8737f010f3 | ||
|
|
278cfbb694 | ||
|
|
1de21ec5c6 | ||
|
|
2ed41ae8a6 | ||
|
|
a0fdb2aa31 | ||
|
|
25c9f3510a | ||
|
|
b8fd2e9725 |
4
.github/workflows/superset-websocket.yml
vendored
4
.github/workflows/superset-websocket.yml
vendored
@@ -28,6 +28,10 @@ jobs:
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version-file: './superset-websocket/.nvmrc'
|
||||
- name: Install dependencies
|
||||
working-directory: ./superset-websocket
|
||||
run: npm ci
|
||||
|
||||
@@ -519,6 +519,80 @@ For a connection to a SQL endpoint you need to use the HTTP path from the endpoi
|
||||
{"connect_args": {"http_path": "/sql/1.0/endpoints/****", "driver_path": "/path/to/odbc/driver"}}
|
||||
```
|
||||
|
||||
##### OAuth2 Authentication
|
||||
|
||||
Superset supports OAuth2 authentication for Databricks, allowing users to authenticate with their personal Databricks accounts instead of using shared access tokens. This provides better security and audit capabilities.
|
||||
|
||||
###### Prerequisites
|
||||
|
||||
1. Create an OAuth2 application in your Databricks account:
|
||||
- Go to your Databricks account console
|
||||
- Navigate to **Settings** → **Developer** → **OAuth apps**
|
||||
- Create a new OAuth app with the redirect URI: `http://your-superset-host:port/api/v1/database/oauth2/`
|
||||
|
||||
2. Configure OAuth2 in your `superset_config.py`:
|
||||
|
||||
```python
|
||||
from datetime import timedelta
|
||||
|
||||
# OAuth2 configuration for Databricks
|
||||
# The authorization endpoint is derived from your Databricks workspace host; the
|
||||
# token endpoint must be set explicitly (see notes below).
|
||||
DATABASE_OAUTH2_CLIENTS = {
|
||||
"Databricks (legacy)": {
|
||||
"id": "your-databricks-client-id",
|
||||
"secret": "your-databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
|
||||
},
|
||||
"Databricks": {
|
||||
"id": "your-databricks-client-id",
|
||||
"secret": "your-databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
|
||||
},
|
||||
}
|
||||
|
||||
# OAuth2 redirect URI (adjust hostname/port for your setup)
|
||||
DATABASE_OAUTH2_REDIRECT_URI = "http://your-superset-host:port/api/v1/database/oauth2/"
|
||||
|
||||
# Optional: OAuth2 timeout
|
||||
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
|
||||
```
|
||||
|
||||
Replace the following placeholders:
|
||||
- `your-databricks-client-id`: Your Databricks OAuth2 application client ID
|
||||
- `your-databricks-client-secret`: Your Databricks OAuth2 application client secret
|
||||
- `your-superset-host:port`: Your Superset instance hostname and port
|
||||
|
||||
**Multi-Cloud Provider Support**
|
||||
|
||||
Databricks fronts the user-to-machine (U2M) OAuth2 flow on every workspace at
|
||||
`https://<workspace-host>/oidc/v1/authorize` and
|
||||
`https://<workspace-host>/oidc/v1/token`, regardless of whether the workspace
|
||||
runs on AWS, Azure, or GCP. Superset derives the **authorization** endpoint
|
||||
directly from your connection's host, so no cloud provider or account/tenant
|
||||
identifier needs to be configured.
|
||||
|
||||
The **token** endpoint cannot be auto-derived (token exchange has no database
|
||||
context to read the host), so you must supply `token_request_uri` in
|
||||
`DATABASE_OAUTH2_CLIENTS`, set to `https://<workspace-host>/oidc/v1/token` for
|
||||
your workspace.
|
||||
|
||||
If you supply a fully-resolved `authorization_request_uri` (and/or
|
||||
`token_request_uri`), those values take precedence over the host-derived
|
||||
defaults.
|
||||
|
||||
###### Usage
|
||||
|
||||
Once configured, users can:
|
||||
|
||||
1. Connect to Databricks databases normally using access tokens
|
||||
2. When querying data, Superset will automatically redirect users to authenticate with Databricks if needed
|
||||
3. User-specific OAuth2 tokens will be used for database connections, providing better security and audit trails
|
||||
|
||||
This feature works with both "Databricks (legacy)" and "Databricks" engine types and automatically supports all major cloud providers (AWS, Azure, GCP).
|
||||
|
||||
#### Denodo
|
||||
|
||||
The recommended connector library for Denodo is
|
||||
|
||||
@@ -32,6 +32,7 @@ and therefore are not easily unit-testable. We have instead opted to test the sd
|
||||
This way, the tests can assert that the sdk actually mounts the iframe and communicates with it correctly.
|
||||
|
||||
At time of writing, these tests are not written yet, because we haven't yet put together the demo app that they will leverage.
|
||||
|
||||
### Things to e2e test once we have a demo app:
|
||||
|
||||
**happy path:**
|
||||
|
||||
@@ -41,12 +41,12 @@ npm install --save @superset-ui/embedded-sdk
|
||||
```
|
||||
|
||||
```js
|
||||
import { embedDashboard } from '@superset-ui/embedded-sdk';
|
||||
import { embedDashboard } from "@superset-ui/embedded-sdk";
|
||||
|
||||
embedDashboard({
|
||||
id: 'abc123', // given by the Superset embedding UI
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'), // any html element that can contain an iframe
|
||||
id: "abc123", // given by the Superset embedding UI
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"), // any html element that can contain an iframe
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
dashboardUiConfig: {
|
||||
// dashboard UI config: hideTitle, hideTab, hideChartControls, filters.visible, filters.expanded (optional), urlParams (optional)
|
||||
@@ -55,21 +55,21 @@ embedDashboard({
|
||||
expanded: true,
|
||||
},
|
||||
urlParams: {
|
||||
foo: 'value1',
|
||||
bar: 'value2',
|
||||
foo: "value1",
|
||||
bar: "value2",
|
||||
// themeMode: 'dark', // set the initial theme: 'dark' | 'system' | 'default' (default: 'default')
|
||||
// ...
|
||||
},
|
||||
},
|
||||
// optional additional iframe sandbox attributes
|
||||
iframeSandboxExtras: [
|
||||
'allow-top-navigation',
|
||||
'allow-popups-to-escape-sandbox',
|
||||
"allow-top-navigation",
|
||||
"allow-popups-to-escape-sandbox",
|
||||
],
|
||||
// optional Permissions Policy features
|
||||
iframeAllowExtras: ['clipboard-write', 'fullscreen'],
|
||||
iframeAllowExtras: ["clipboard-write", "fullscreen"],
|
||||
// optional config to enforce a particular referrerPolicy
|
||||
referrerPolicy: 'same-origin',
|
||||
referrerPolicy: "same-origin",
|
||||
// optional callback to customize permalink URLs
|
||||
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
|
||||
});
|
||||
@@ -163,13 +163,13 @@ Use the `themeMode` URL parameter to control the embedded dashboard's initial co
|
||||
|
||||
```js
|
||||
embedDashboard({
|
||||
id: 'abc123',
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'),
|
||||
id: "abc123",
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"),
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
dashboardUiConfig: {
|
||||
urlParams: {
|
||||
themeMode: 'dark', // 'dark' | 'system' | 'default' (default: 'default')
|
||||
themeMode: "dark", // 'dark' | 'system' | 'default' (default: 'default')
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -193,7 +193,7 @@ To pass additional sandbox attributes you can use `iframeSandboxExtras`:
|
||||
|
||||
```js
|
||||
// optional additional iframe sandbox attributes
|
||||
iframeSandboxExtras: ['allow-top-navigation', 'allow-popups-to-escape-sandbox'];
|
||||
iframeSandboxExtras: ["allow-top-navigation", "allow-popups-to-escape-sandbox"];
|
||||
```
|
||||
|
||||
### Permissions Policy
|
||||
@@ -202,7 +202,7 @@ To enable specific browser features within the embedded iframe, use `iframeAllow
|
||||
|
||||
```js
|
||||
// optional Permissions Policy features
|
||||
iframeAllowExtras: ['clipboard-write', 'fullscreen'];
|
||||
iframeAllowExtras: ["clipboard-write", "fullscreen"];
|
||||
```
|
||||
|
||||
Common permissions you might need:
|
||||
@@ -225,9 +225,9 @@ When users click share buttons inside an embedded dashboard, Superset generates
|
||||
|
||||
```js
|
||||
embedDashboard({
|
||||
id: 'abc123',
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'),
|
||||
id: "abc123",
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"),
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
|
||||
// Customize permalink URLs
|
||||
@@ -245,9 +245,9 @@ To restore the dashboard state from a permalink in your app:
|
||||
const permalinkKey = routeParams.key;
|
||||
|
||||
embedDashboard({
|
||||
id: 'abc123',
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'),
|
||||
id: "abc123",
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"),
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
|
||||
dashboardUiConfig: {
|
||||
|
||||
@@ -18,9 +18,6 @@
|
||||
*/
|
||||
|
||||
module.exports = {
|
||||
presets: [
|
||||
"@babel/preset-typescript",
|
||||
"@babel/preset-env"
|
||||
],
|
||||
presets: ["@babel/preset-typescript", "@babel/preset-env"],
|
||||
sourceMaps: true,
|
||||
};
|
||||
|
||||
10769
superset-embedded-sdk/package-lock.json
generated
10769
superset-embedded-sdk/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@
|
||||
"scripts": {
|
||||
"build": "tsc && babel src --out-dir lib --extensions '.ts,.tsx' && webpack --mode production",
|
||||
"ci:release": "node ./release-if-necessary.js",
|
||||
"test": "jest"
|
||||
"test": "vitest --run --dir src"
|
||||
},
|
||||
"browserslist": [
|
||||
"last 3 chrome versions",
|
||||
@@ -41,12 +41,11 @@
|
||||
"@babel/core": "^7.25.2",
|
||||
"@babel/preset-env": "^7.25.4",
|
||||
"@babel/preset-typescript": "^7.24.7",
|
||||
"@types/jest": "^29.5.12",
|
||||
"@types/node": "^22.5.4",
|
||||
"@types/node": "^25.4.0",
|
||||
"babel-loader": "^9.1.3",
|
||||
"jest": "^29.7.0",
|
||||
"tscw-config": "^1.1.2",
|
||||
"typescript": "^5.6.2",
|
||||
"typescript": "^5.9.3",
|
||||
"vitest": "^4.0.18",
|
||||
"webpack": "^5.94.0",
|
||||
"webpack-cli": "^5.1.4"
|
||||
},
|
||||
|
||||
@@ -17,15 +17,15 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
const { execSync } = require('child_process');
|
||||
const { name, version } = require('./package.json');
|
||||
const { execSync } = require("child_process");
|
||||
const { name, version } = require("./package.json");
|
||||
|
||||
function log(...args) {
|
||||
console.log('[embedded-sdk-release]', ...args);
|
||||
console.log("[embedded-sdk-release]", ...args);
|
||||
}
|
||||
|
||||
function logError(...args) {
|
||||
console.error('[embedded-sdk-release]', ...args);
|
||||
console.error("[embedded-sdk-release]", ...args);
|
||||
}
|
||||
|
||||
(async () => {
|
||||
@@ -38,13 +38,13 @@ function logError(...args) {
|
||||
const { status } = await fetch(packageUrl);
|
||||
|
||||
if (status === 200) {
|
||||
log('version already exists on npm, exiting');
|
||||
log("version already exists on npm, exiting");
|
||||
} else if (status === 404) {
|
||||
log('release required, building');
|
||||
log("release required, building");
|
||||
try {
|
||||
execSync('npm run build', { stdio: 'pipe' });
|
||||
log('build successful, publishing')
|
||||
execSync('npm publish --access public', { stdio: 'pipe' });
|
||||
execSync("npm run build", { stdio: "pipe" });
|
||||
log("build successful, publishing");
|
||||
execSync("npm publish --access public", { stdio: "pipe" });
|
||||
log(`published ${version} to npm`);
|
||||
} catch (err) {
|
||||
// npm writes failure details to stderr (auth/permission/registry
|
||||
@@ -52,7 +52,7 @@ function logError(...args) {
|
||||
// the real cause in CI logs.
|
||||
if (err.stdout) console.error(String(err.stdout));
|
||||
if (err.stderr) console.error(String(err.stderr));
|
||||
logError('Encountered an error, details should be above');
|
||||
logError("Encountered an error, details should be above");
|
||||
process.exitCode = 1;
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -18,7 +18,9 @@
|
||||
*/
|
||||
|
||||
export const IFRAME_COMMS_MESSAGE_TYPE = "__embedded_comms__";
|
||||
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: { [index: string]: any } = {
|
||||
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: {
|
||||
[index: string]: any;
|
||||
} = {
|
||||
visible: "show_filters",
|
||||
expanded: "expand_filters",
|
||||
}
|
||||
};
|
||||
|
||||
@@ -24,22 +24,23 @@ import {
|
||||
DEFAULT_TOKEN_EXP_MS,
|
||||
DEFAULT_TOKEN_REFRESH_RETRY_MS,
|
||||
} from "./guestTokenRefresh";
|
||||
import { afterAll, beforeAll, it, expect, describe, vi } from "vitest";
|
||||
|
||||
describe("guest token refresh", () => {
|
||||
beforeAll(() => {
|
||||
jest.useFakeTimers();
|
||||
jest.setSystemTime(new Date("2022-03-03 01:00"));
|
||||
jest.spyOn(global, "setTimeout");
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date("2022-03-03 01:00"));
|
||||
vi.spyOn(globalThis, "setTimeout");
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
jest.useRealTimers();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
function makeFakeJWT(claims: any) {
|
||||
// not a valid jwt, but close enough for this code
|
||||
const tokenifiedClaims = Buffer.from(JSON.stringify(claims)).toString(
|
||||
"base64"
|
||||
"base64",
|
||||
);
|
||||
return `abc.${tokenifiedClaims}.xyz`;
|
||||
}
|
||||
|
||||
@@ -18,17 +18,23 @@
|
||||
*/
|
||||
import { jwtDecode } from "jwt-decode";
|
||||
|
||||
export const REFRESH_TIMING_BUFFER_MS = 5000 // refresh guest token early to avoid failed superset requests
|
||||
export const MIN_REFRESH_WAIT_MS = 10000 // avoid blasting requests as fast as the cpu can handle
|
||||
export const DEFAULT_TOKEN_EXP_MS = 300000 // (5 min) used only when parsing guest token exp fails
|
||||
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000 // wait before retrying a failed/timed-out token refresh
|
||||
export const REFRESH_TIMING_BUFFER_MS = 5000; // refresh guest token early to avoid failed superset requests
|
||||
export const MIN_REFRESH_WAIT_MS = 10000; // avoid blasting requests as fast as the cpu can handle
|
||||
export const DEFAULT_TOKEN_EXP_MS = 300000; // (5 min) used only when parsing guest token exp fails
|
||||
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000; // wait before retrying a failed/timed-out token refresh
|
||||
|
||||
// when do we refresh the guest token?
|
||||
export function getGuestTokenRefreshTiming(currentGuestToken: string) {
|
||||
const parsedJwt = jwtDecode<Record<string, any>>(currentGuestToken);
|
||||
// if exp is int, it is in seconds, but Date() takes milliseconds
|
||||
const exp = new Date(/[^0-9\.]/g.test(parsedJwt.exp) ? parsedJwt.exp : parseFloat(parsedJwt.exp) * 1000);
|
||||
const isValidDate = exp.toString() !== 'Invalid Date';
|
||||
const ttl = isValidDate ? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now()) : DEFAULT_TOKEN_EXP_MS;
|
||||
const exp = new Date(
|
||||
/[^0-9\.]/g.test(parsedJwt.exp)
|
||||
? parsedJwt.exp
|
||||
: parseFloat(parsedJwt.exp) * 1000,
|
||||
);
|
||||
const isValidDate = exp.toString() !== "Invalid Date";
|
||||
const ttl = isValidDate
|
||||
? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now())
|
||||
: DEFAULT_TOKEN_EXP_MS;
|
||||
return ttl - REFRESH_TIMING_BUFFER_MS;
|
||||
}
|
||||
|
||||
@@ -20,15 +20,15 @@
|
||||
import {
|
||||
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY,
|
||||
IFRAME_COMMS_MESSAGE_TYPE,
|
||||
} from './const';
|
||||
} from "./const";
|
||||
|
||||
// We can swap this out for the actual switchboard package once it gets published
|
||||
import { Switchboard } from '@superset-ui/switchboard';
|
||||
import { Switchboard } from "@superset-ui/switchboard";
|
||||
import {
|
||||
getGuestTokenRefreshTiming,
|
||||
DEFAULT_TOKEN_REFRESH_RETRY_MS,
|
||||
} from './guestTokenRefresh';
|
||||
import { withTimeout } from './withTimeout';
|
||||
} from "./guestTokenRefresh";
|
||||
import { withTimeout } from "./withTimeout";
|
||||
|
||||
/**
|
||||
* The function to fetch a guest token from your Host App's backend server.
|
||||
@@ -97,7 +97,7 @@ export type ObserveDataMaskCallbackFn = (
|
||||
nativeFiltersChanged: boolean;
|
||||
},
|
||||
) => void;
|
||||
export type ThemeMode = 'default' | 'dark' | 'system';
|
||||
export type ThemeMode = "default" | "dark" | "system";
|
||||
|
||||
/**
|
||||
* Callback to resolve permalink URLs.
|
||||
@@ -113,12 +113,12 @@ export type EmbeddedDashboard = {
|
||||
unmount: () => void;
|
||||
getDashboardPermalink: (anchor: string) => Promise<string>;
|
||||
getActiveTabs: () => Promise<string[]>;
|
||||
observeDataMask: (
|
||||
callbackFn: ObserveDataMaskCallbackFn,
|
||||
) => void;
|
||||
observeDataMask: (callbackFn: ObserveDataMaskCallbackFn) => void;
|
||||
getDataMask: () => Promise<Record<string, any>>;
|
||||
getChartStates: () => Promise<Record<string, any>>;
|
||||
getChartDataPayloads: (params?: { chartId?: number }) => Promise<Record<string, any>>;
|
||||
getChartDataPayloads: (params?: {
|
||||
chartId?: number;
|
||||
}) => Promise<Record<string, any>>;
|
||||
setThemeConfig: (themeConfig: Record<string, any>) => void;
|
||||
setThemeMode: (mode: ThemeMode) => void;
|
||||
};
|
||||
@@ -133,7 +133,7 @@ export async function embedDashboard({
|
||||
fetchGuestToken,
|
||||
dashboardUiConfig,
|
||||
debug = false,
|
||||
iframeTitle = 'Embedded Dashboard',
|
||||
iframeTitle = "Embedded Dashboard",
|
||||
iframeSandboxExtras = [],
|
||||
iframeAllowExtras = [],
|
||||
referrerPolicy,
|
||||
@@ -152,13 +152,13 @@ export async function embedDashboard({
|
||||
return withTimeout(
|
||||
fetchGuestToken(),
|
||||
guestTokenFetchTimeoutMs,
|
||||
'fetchGuestToken',
|
||||
"fetchGuestToken",
|
||||
);
|
||||
}
|
||||
|
||||
log('embedding');
|
||||
log("embedding");
|
||||
|
||||
if (supersetDomain.endsWith('/')) {
|
||||
if (supersetDomain.endsWith("/")) {
|
||||
supersetDomain = supersetDomain.slice(0, -1);
|
||||
}
|
||||
|
||||
@@ -185,15 +185,15 @@ export async function embedDashboard({
|
||||
}
|
||||
|
||||
async function mountIframe(): Promise<Switchboard> {
|
||||
return new Promise(resolve => {
|
||||
const iframe = document.createElement('iframe');
|
||||
return new Promise((resolve) => {
|
||||
const iframe = document.createElement("iframe");
|
||||
const dashboardConfigUrlParams = dashboardUiConfig
|
||||
? { uiConfig: `${calculateConfig()}` }
|
||||
: undefined;
|
||||
const filterConfig = dashboardUiConfig?.filters || {};
|
||||
const filterConfigKeys = Object.keys(filterConfig);
|
||||
const filterConfigUrlParams = Object.fromEntries(
|
||||
filterConfigKeys.map(key => [
|
||||
filterConfigKeys.map((key) => [
|
||||
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY[key],
|
||||
filterConfig[key],
|
||||
]),
|
||||
@@ -206,16 +206,16 @@ export async function embedDashboard({
|
||||
...dashboardUiConfig?.urlParams,
|
||||
};
|
||||
const urlParamsString = Object.keys(urlParams).length
|
||||
? '?' + new URLSearchParams(urlParams).toString()
|
||||
: '';
|
||||
? "?" + new URLSearchParams(urlParams).toString()
|
||||
: "";
|
||||
|
||||
// set up the iframe's sandbox configuration
|
||||
iframe.sandbox.add('allow-same-origin'); // needed for postMessage to work
|
||||
iframe.sandbox.add('allow-scripts'); // obviously the iframe needs scripts
|
||||
iframe.sandbox.add('allow-presentation'); // for fullscreen charts
|
||||
iframe.sandbox.add('allow-downloads'); // for downloading charts as image
|
||||
iframe.sandbox.add('allow-forms'); // for forms to submit
|
||||
iframe.sandbox.add('allow-popups'); // for exporting charts as csv
|
||||
iframe.sandbox.add("allow-same-origin"); // needed for postMessage to work
|
||||
iframe.sandbox.add("allow-scripts"); // obviously the iframe needs scripts
|
||||
iframe.sandbox.add("allow-presentation"); // for fullscreen charts
|
||||
iframe.sandbox.add("allow-downloads"); // for downloading charts as image
|
||||
iframe.sandbox.add("allow-forms"); // for forms to submit
|
||||
iframe.sandbox.add("allow-popups"); // for exporting charts as csv
|
||||
// additional sandbox props
|
||||
iframeSandboxExtras.forEach((key: string) => {
|
||||
iframe.sandbox.add(key);
|
||||
@@ -226,7 +226,7 @@ export async function embedDashboard({
|
||||
}
|
||||
|
||||
// add the event listener before setting src, to be 100% sure that we capture the load event
|
||||
iframe.addEventListener('load', () => {
|
||||
iframe.addEventListener("load", () => {
|
||||
// MessageChannel allows us to send and receive messages smoothly between our window and the iframe
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/API/Channel_Messaging_API
|
||||
const commsChannel = new MessageChannel();
|
||||
@@ -237,35 +237,35 @@ export async function embedDashboard({
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/API/Window/postMessage
|
||||
// we know the content window isn't null because we are in the load event handler.
|
||||
iframe.contentWindow!.postMessage(
|
||||
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: 'port transfer' },
|
||||
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: "port transfer" },
|
||||
supersetDomain,
|
||||
[theirPort],
|
||||
);
|
||||
log('sent message channel to the iframe');
|
||||
log("sent message channel to the iframe");
|
||||
|
||||
// return our port from the promise
|
||||
resolve(
|
||||
new Switchboard({
|
||||
port: ourPort,
|
||||
name: 'superset-embedded-sdk',
|
||||
name: "superset-embedded-sdk",
|
||||
debug,
|
||||
}),
|
||||
);
|
||||
});
|
||||
iframe.src = `${supersetDomain}/embedded/${id}${urlParamsString}`;
|
||||
iframe.title = iframeTitle;
|
||||
iframe.style.background = 'transparent';
|
||||
iframe.style.background = "transparent";
|
||||
// Permissions Policy features the embedded dashboard relies on. Modern
|
||||
// browsers gate these APIs on the iframe's `allow` attribute regardless
|
||||
// of sandbox flags, so we include them by default. Host apps can extend
|
||||
// the list via `iframeAllowExtras`.
|
||||
const allowFeatures = Array.from(
|
||||
new Set(['fullscreen', 'clipboard-write', ...iframeAllowExtras]),
|
||||
new Set(["fullscreen", "clipboard-write", ...iframeAllowExtras]),
|
||||
);
|
||||
iframe.setAttribute('allow', allowFeatures.join('; '));
|
||||
iframe.setAttribute("allow", allowFeatures.join("; "));
|
||||
//@ts-ignore
|
||||
mountPoint.replaceChildren(iframe);
|
||||
log('placed the iframe');
|
||||
log("placed the iframe");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -285,8 +285,8 @@ export async function embedDashboard({
|
||||
throw err;
|
||||
}
|
||||
|
||||
ourPort.emit('guestToken', { guestToken });
|
||||
log('sent guest token');
|
||||
ourPort.emit("guestToken", { guestToken });
|
||||
log("sent guest token");
|
||||
|
||||
// Track the pending refresh timer so it can be cancelled on unmount, and
|
||||
// stop the cycle once unmounted so it cannot leak across mount/unmount cycles.
|
||||
@@ -298,7 +298,7 @@ export async function embedDashboard({
|
||||
try {
|
||||
const newGuestToken = await fetchGuestTokenWithTimeout();
|
||||
if (unmounted) return;
|
||||
ourPort.emit('guestToken', { guestToken: newGuestToken });
|
||||
ourPort.emit("guestToken", { guestToken: newGuestToken });
|
||||
refreshTimer = setTimeout(
|
||||
refreshGuestToken,
|
||||
getGuestTokenRefreshTiming(newGuestToken),
|
||||
@@ -307,7 +307,7 @@ export async function embedDashboard({
|
||||
// A transient fetch failure or timeout must not permanently stop the
|
||||
// refresh cycle. Log it and retry so the session can recover once the
|
||||
// host callback succeeds again.
|
||||
log('failed to refresh guest token, will retry:', err);
|
||||
log("failed to refresh guest token, will retry:", err);
|
||||
if (unmounted) return;
|
||||
refreshTimer = setTimeout(
|
||||
refreshGuestToken,
|
||||
@@ -325,7 +325,7 @@ export async function embedDashboard({
|
||||
// Returns null if no callback provided or on error, allowing iframe to use default URL
|
||||
ourPort.start();
|
||||
ourPort.defineMethod(
|
||||
'resolvePermalinkUrl',
|
||||
"resolvePermalinkUrl",
|
||||
async ({ key }: { key: string }): Promise<string | null> => {
|
||||
if (!resolvePermalinkUrl) {
|
||||
return null;
|
||||
@@ -333,14 +333,14 @@ export async function embedDashboard({
|
||||
try {
|
||||
return await resolvePermalinkUrl({ key });
|
||||
} catch (error) {
|
||||
log('Error in resolvePermalinkUrl callback:', error);
|
||||
log("Error in resolvePermalinkUrl callback:", error);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
function unmount() {
|
||||
log('unmounting');
|
||||
log("unmounting");
|
||||
unmounted = true;
|
||||
if (refreshTimer !== undefined) {
|
||||
clearTimeout(refreshTimer);
|
||||
@@ -350,24 +350,25 @@ export async function embedDashboard({
|
||||
mountPoint.replaceChildren();
|
||||
}
|
||||
|
||||
const getScrollSize = () => ourPort.get<Size>('getScrollSize');
|
||||
const getScrollSize = () => ourPort.get<Size>("getScrollSize");
|
||||
const getDashboardPermalink = (anchor: string) =>
|
||||
ourPort.get<string>('getDashboardPermalink', { anchor });
|
||||
const getActiveTabs = () => ourPort.get<string[]>('getActiveTabs');
|
||||
const getDataMask = () => ourPort.get<Record<string, any>>('getDataMask');
|
||||
const getChartStates = () => ourPort.get<Record<string, any>>('getChartStates');
|
||||
ourPort.get<string>("getDashboardPermalink", { anchor });
|
||||
const getActiveTabs = () => ourPort.get<string[]>("getActiveTabs");
|
||||
const getDataMask = () => ourPort.get<Record<string, any>>("getDataMask");
|
||||
const getChartStates = () =>
|
||||
ourPort.get<Record<string, any>>("getChartStates");
|
||||
const getChartDataPayloads = (params?: { chartId?: number }) =>
|
||||
ourPort.get<Record<string, any>>('getChartDataPayloads', params);
|
||||
const observeDataMask = (
|
||||
callbackFn: ObserveDataMaskCallbackFn,
|
||||
) => {
|
||||
ourPort.defineMethod('observeDataMask', callbackFn);
|
||||
ourPort.get<Record<string, any>>("getChartDataPayloads", params);
|
||||
const observeDataMask = (callbackFn: ObserveDataMaskCallbackFn) => {
|
||||
ourPort.defineMethod("observeDataMask", callbackFn);
|
||||
};
|
||||
// TODO: Add proper types once theming branch is merged
|
||||
const setThemeConfig = async (themeConfig: Record<string, any>): Promise<void> => {
|
||||
const setThemeConfig = async (
|
||||
themeConfig: Record<string, any>,
|
||||
): Promise<void> => {
|
||||
try {
|
||||
ourPort.emit('setThemeConfig', { themeConfig });
|
||||
log('Theme config sent successfully (or at least message dispatched)');
|
||||
ourPort.emit("setThemeConfig", { themeConfig });
|
||||
log("Theme config sent successfully (or at least message dispatched)");
|
||||
} catch (error) {
|
||||
log(
|
||||
'Error sending theme config. Ensure the iframe side implements the "setThemeConfig" method.',
|
||||
@@ -378,7 +379,7 @@ export async function embedDashboard({
|
||||
|
||||
const setThemeMode = (mode: ThemeMode): void => {
|
||||
try {
|
||||
ourPort.emit('setThemeMode', { mode });
|
||||
ourPort.emit("setThemeMode", { mode });
|
||||
log(`Theme mode set to: ${mode}`);
|
||||
} catch (error) {
|
||||
log(
|
||||
|
||||
@@ -18,22 +18,23 @@
|
||||
*/
|
||||
|
||||
import { withTimeout } from "./withTimeout";
|
||||
import { test, expect } from "vitest";
|
||||
|
||||
test("resolves with the value when the promise settles in time", async () => {
|
||||
await expect(withTimeout(Promise.resolve("ok"), 1000, "fetch")).resolves.toBe(
|
||||
"ok"
|
||||
"ok",
|
||||
);
|
||||
});
|
||||
|
||||
test("rejects when the promise does not settle within the timeout", async () => {
|
||||
const never = new Promise<string>(() => {});
|
||||
await expect(withTimeout(never, 10, "fetch")).rejects.toThrow(
|
||||
/fetch did not resolve within 10ms/
|
||||
/fetch did not resolve within 10ms/,
|
||||
);
|
||||
});
|
||||
|
||||
test("passes the promise through unchanged when the timeout is disabled", async () => {
|
||||
await expect(withTimeout(Promise.resolve("ok"), 0, "fetch")).resolves.toBe(
|
||||
"ok"
|
||||
"ok",
|
||||
);
|
||||
});
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// syntax rules
|
||||
"strict": true,
|
||||
|
||||
"moduleResolution": "node",
|
||||
"moduleResolution": "bundler",
|
||||
|
||||
// environment
|
||||
"target": "es6",
|
||||
@@ -13,7 +13,9 @@
|
||||
// output
|
||||
"outDir": "./dist",
|
||||
"emitDeclarationOnly": true,
|
||||
"declaration": true
|
||||
"declaration": true,
|
||||
|
||||
"types": ["node"]
|
||||
},
|
||||
|
||||
"include": [
|
||||
@@ -21,7 +23,6 @@
|
||||
],
|
||||
|
||||
"exclude": [
|
||||
"tests",
|
||||
"dist",
|
||||
"lib",
|
||||
"node_modules"
|
||||
|
||||
@@ -17,19 +17,19 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
const path = require('path');
|
||||
const path = require("path");
|
||||
|
||||
module.exports = {
|
||||
entry: './src/index.ts',
|
||||
entry: "./src/index.ts",
|
||||
output: {
|
||||
filename: 'index.js',
|
||||
path: path.resolve(__dirname, 'bundle'),
|
||||
filename: "index.js",
|
||||
path: path.resolve(__dirname, "bundle"),
|
||||
|
||||
// this exposes the library's exports under a global variable
|
||||
library: {
|
||||
name: "supersetEmbeddedSdk",
|
||||
type: "umd"
|
||||
}
|
||||
type: "umd",
|
||||
},
|
||||
},
|
||||
devtool: "source-map",
|
||||
module: {
|
||||
@@ -38,12 +38,12 @@ module.exports = {
|
||||
test: /\.[tj]s$/,
|
||||
// babel-loader is faster than ts-loader because it ignores types.
|
||||
// We do type checking in a separate process, so that's fine.
|
||||
use: 'babel-loader',
|
||||
use: "babel-loader",
|
||||
exclude: /node_modules/,
|
||||
},
|
||||
],
|
||||
},
|
||||
resolve: {
|
||||
extensions: ['.ts', '.js'],
|
||||
extensions: [".ts", ".js"],
|
||||
},
|
||||
};
|
||||
|
||||
@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Superset WebSocket Server
|
||||
|
||||
A Node.js WebSocket server for sending async event data to the Superset web application frontend.
|
||||
@@ -164,4 +165,4 @@ HEAD /health
|
||||
|
||||
## Containerization
|
||||
|
||||
*TODO: containerize websocket server*
|
||||
_TODO: containerize websocket server_
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
module.exports = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
};
|
||||
10617
superset-websocket/package-lock.json
generated
10617
superset-websocket/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,7 @@
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"start": "node dist/index.js start",
|
||||
"test": "NODE_ENV=test jest -i spec",
|
||||
"test": "npx vitest --run --dir spec",
|
||||
"type": "tsc --noEmit",
|
||||
"eslint": "eslint",
|
||||
"lint": "npm run eslint -- . && npm run type",
|
||||
@@ -28,24 +28,22 @@
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.25.1",
|
||||
"@types/eslint__js": "^8.42.3",
|
||||
"@types/jest": "^29.5.14",
|
||||
"@types/jsonwebtoken": "^9.0.10",
|
||||
"@types/lodash": "^4.17.24",
|
||||
"@types/node": "^25.9.3",
|
||||
"@types/ws": "^8.18.1",
|
||||
"@typescript-eslint/eslint-plugin": "^8.61.1",
|
||||
"@typescript-eslint/parser": "^8.61.1",
|
||||
"@typescript-eslint/eslint-plugin": "^8.62.0",
|
||||
"@typescript-eslint/parser": "^8.62.0",
|
||||
"eslint": "^10.5.0",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-lodash": "^8.0.0",
|
||||
"globals": "^17.6.0",
|
||||
"jest": "^29.7.0",
|
||||
"prettier": "^3.8.4",
|
||||
"ts-jest": "^29.4.11",
|
||||
"ts-node": "^10.9.2",
|
||||
"tscw-config": "^1.1.2",
|
||||
"typescript": "^6.0.3",
|
||||
"typescript-eslint": "^8.61.1"
|
||||
"typescript-eslint": "^8.62.0",
|
||||
"vitest": "^4.1.5"
|
||||
},
|
||||
"engines": {
|
||||
"node": "^24.16.0",
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
* under the License.
|
||||
*/
|
||||
import { buildConfig } from '../src/config';
|
||||
import { expect, test } from '@jest/globals';
|
||||
import { expect, test } from 'vitest';
|
||||
|
||||
test('buildConfig() builds configuration and applies env var overrides', () => {
|
||||
let config = buildConfig();
|
||||
|
||||
@@ -25,29 +25,29 @@ import {
|
||||
test,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
jest,
|
||||
} from '@jest/globals';
|
||||
vi,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import { WebSocket } from 'ws';
|
||||
import * as server from '../src/index';
|
||||
import { statsd } from '../src/index';
|
||||
|
||||
interface MockedRedisXrange {
|
||||
(): Promise<server.StreamResult[]>;
|
||||
}
|
||||
|
||||
// NOTE: these mock variables needs to start with "mock" due to
|
||||
// calls to `jest.mock` being hoisted to the top of the file.
|
||||
// https://jestjs.io/docs/es6-class-mocks#calling-jestmock-with-the-module-factory-parameter
|
||||
const mockRedisXrange = jest.fn() as jest.MockedFunction<MockedRedisXrange>;
|
||||
|
||||
jest.mock('ws');
|
||||
jest.mock('ioredis', () => {
|
||||
return jest.fn().mockImplementation(() => {
|
||||
return { xrange: mockRedisXrange, on: jest.fn() };
|
||||
});
|
||||
const { mockRedisXrange } = vi.hoisted(() => {
|
||||
return { mockRedisXrange: vi.fn() };
|
||||
});
|
||||
|
||||
const wsMock = WebSocket as jest.Mocked<typeof WebSocket>;
|
||||
vi.mock('ws');
|
||||
vi.mock('ioredis', () => {
|
||||
return {
|
||||
Redis: vi.fn().mockImplementation(function () {
|
||||
return { xrange: mockRedisXrange, on: vi.fn() };
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const wsMock = WebSocket as unknown as Mock<typeof WebSocket>;
|
||||
const channelId = 'bc9e040c-7b4a-4817-99b9-292832d97ec7';
|
||||
const streamReturnValue: server.StreamResult[] = [
|
||||
[
|
||||
@@ -66,16 +66,13 @@ const streamReturnValue: server.StreamResult[] = [
|
||||
],
|
||||
];
|
||||
|
||||
import * as server from '../src/index';
|
||||
import { statsd } from '../src/index';
|
||||
|
||||
describe('server', () => {
|
||||
let statsdIncrementMock: jest.SpiedFunction<typeof statsd.increment>;
|
||||
let statsdIncrementMock: Mock<typeof statsd.increment>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockRedisXrange.mockClear();
|
||||
server.resetState();
|
||||
statsdIncrementMock = jest.spyOn(statsd, 'increment').mockReturnValue();
|
||||
statsdIncrementMock = vi.spyOn(statsd, 'increment').mockReturnValue();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -84,8 +81,8 @@ describe('server', () => {
|
||||
|
||||
describe('HTTP requests', () => {
|
||||
test('services health checks', () => {
|
||||
const endMock = jest.fn();
|
||||
const writeHeadMock = jest.fn();
|
||||
const endMock = vi.fn();
|
||||
const writeHeadMock = vi.fn();
|
||||
|
||||
const request = {
|
||||
url: '/health',
|
||||
@@ -113,8 +110,8 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('responds with a 404 when not found', () => {
|
||||
const endMock = jest.fn();
|
||||
const writeHeadMock = jest.fn();
|
||||
const endMock = vi.fn();
|
||||
const writeHeadMock = vi.fn();
|
||||
|
||||
const request = {
|
||||
url: '/unsupported',
|
||||
@@ -245,7 +242,7 @@ describe('server', () => {
|
||||
describe('processStreamResults', () => {
|
||||
test('sends data to channel', async () => {
|
||||
const ws = new wsMock('localhost');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
|
||||
|
||||
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
|
||||
@@ -267,7 +264,7 @@ describe('server', () => {
|
||||
|
||||
test('channel not present', async () => {
|
||||
const ws = new wsMock('localhost');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
|
||||
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
|
||||
server.processStreamResults(streamReturnValue);
|
||||
@@ -278,10 +275,9 @@ describe('server', () => {
|
||||
|
||||
test('error sending data to client', async () => {
|
||||
const ws = new wsMock('localhost');
|
||||
const sendMock = jest.spyOn(ws, 'send').mockImplementation(() => {
|
||||
const sendMock = vi.spyOn(ws, 'send').mockImplementation(() => {
|
||||
throw new Error();
|
||||
});
|
||||
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
|
||||
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
|
||||
|
||||
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
|
||||
@@ -300,9 +296,7 @@ describe('server', () => {
|
||||
);
|
||||
|
||||
expect(sendMock).toHaveBeenCalled();
|
||||
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
|
||||
|
||||
cleanChannelMock.mockRestore();
|
||||
expect(Object.keys(server.channels)).toHaveLength(0);
|
||||
});
|
||||
|
||||
const makeItem = (i: number): server.StreamResult =>
|
||||
@@ -330,8 +324,8 @@ describe('server', () => {
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const setImmediateSpy = jest.spyOn(global, 'setImmediate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
const setImmediateSpy = vi.spyOn(global, 'setImmediate');
|
||||
|
||||
const results = [0, 1, 2, 3, 4].map(makeItem);
|
||||
await server.processStreamResults(results);
|
||||
@@ -351,7 +345,7 @@ describe('server', () => {
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
|
||||
const results = [0, 1, 2, 3, 4].map(makeItem);
|
||||
await server.processStreamResults(results);
|
||||
@@ -372,16 +366,16 @@ describe('server', () => {
|
||||
server.opts.maxSocketBufferBytes = 0;
|
||||
// Restore any spies (e.g. on server.cleanChannel) so they don't leak
|
||||
// across tests and cause order-dependent failures.
|
||||
jest.restoreAllMocks();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
test('does not terminate when cap disabled (0)', () => {
|
||||
server.opts.maxSocketBufferBytes = 0;
|
||||
const ws = new wsMock('localhost');
|
||||
// simulate a large outbound buffer
|
||||
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 10_000_000;
|
||||
const terminateMock = jest.spyOn(ws, 'terminate');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(10_000_000);
|
||||
const terminateMock = vi.spyOn(ws, 'terminate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -397,10 +391,9 @@ describe('server', () => {
|
||||
test('terminates a slow client whose buffer exceeds the cap', () => {
|
||||
server.opts.maxSocketBufferBytes = 1024;
|
||||
const ws = new wsMock('localhost');
|
||||
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 2048;
|
||||
const terminateMock = jest.spyOn(ws, 'terminate');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
|
||||
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(2048);
|
||||
const terminateMock = vi.spyOn(ws, 'terminate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -414,15 +407,15 @@ describe('server', () => {
|
||||
expect(statsdIncrementMock).toHaveBeenCalledWith(
|
||||
'ws_client_backpressure_disconnect',
|
||||
);
|
||||
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
|
||||
expect(Object.keys(server.channels)).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('keeps sending to a client within the cap', () => {
|
||||
server.opts.maxSocketBufferBytes = 1024;
|
||||
const ws = new wsMock('localhost');
|
||||
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 16;
|
||||
const terminateMock = jest.spyOn(ws, 'terminate');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(16);
|
||||
const terminateMock = vi.spyOn(ws, 'terminate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -443,7 +436,7 @@ describe('server', () => {
|
||||
|
||||
test('success with results', async () => {
|
||||
mockRedisXrange.mockResolvedValueOnce(streamReturnValue);
|
||||
const cb = jest.fn() as jest.MockedFunction<
|
||||
const cb = vi.fn() as Mock<
|
||||
(results: server.StreamResult[]) => void | Promise<void>
|
||||
>;
|
||||
await server.fetchRangeFromStream({
|
||||
@@ -462,7 +455,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('success no results', async () => {
|
||||
const cb = jest.fn() as jest.MockedFunction<
|
||||
const cb = vi.fn() as Mock<
|
||||
(results: server.StreamResult[]) => void | Promise<void>
|
||||
>;
|
||||
await server.fetchRangeFromStream({
|
||||
@@ -481,7 +474,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('error', async () => {
|
||||
const cb = jest.fn() as jest.MockedFunction<
|
||||
const cb = vi.fn() as Mock<
|
||||
(results: server.StreamResult[]) => void | Promise<void>
|
||||
>;
|
||||
mockRedisXrange.mockRejectedValueOnce(new Error());
|
||||
@@ -503,12 +496,8 @@ describe('server', () => {
|
||||
|
||||
describe('wsConnection', () => {
|
||||
let ws: WebSocket;
|
||||
let wsEventMock: jest.SpiedFunction<typeof ws.on>;
|
||||
let trackClientSpy: jest.SpiedFunction<typeof server.trackClient>;
|
||||
let fetchRangeFromStreamSpy: jest.SpiedFunction<
|
||||
typeof server.fetchRangeFromStream
|
||||
>;
|
||||
let dateNowSpy: jest.SpiedFunction<typeof Date.now>;
|
||||
let wsEventMock: Mock<typeof ws.on>;
|
||||
let dateNowSpy: Mock<typeof Date.now>;
|
||||
let socketInstanceExpected: server.SocketInstance;
|
||||
|
||||
const getRequest = (token: string, url: string): http.IncomingMessage => {
|
||||
@@ -521,10 +510,8 @@ describe('server', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
ws = new wsMock('localhost');
|
||||
wsEventMock = jest.spyOn(ws, 'on');
|
||||
trackClientSpy = jest.spyOn(server, 'trackClient');
|
||||
fetchRangeFromStreamSpy = jest.spyOn(server, 'fetchRangeFromStream');
|
||||
dateNowSpy = jest
|
||||
wsEventMock = vi.spyOn(ws, 'on');
|
||||
dateNowSpy = vi
|
||||
.spyOn(global.Date, 'now')
|
||||
.mockImplementation(() =>
|
||||
new Date('2021-03-10T11:01:58.135Z').valueOf(),
|
||||
@@ -537,10 +524,8 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
wsEventMock.mockRestore();
|
||||
trackClientSpy.mockRestore();
|
||||
fetchRangeFromStreamSpy.mockRestore();
|
||||
dateNowSpy.mockRestore();
|
||||
wsEventMock?.mockRestore();
|
||||
dateNowSpy?.mockRestore();
|
||||
});
|
||||
|
||||
test('invalid JWT', async () => {
|
||||
@@ -558,11 +543,14 @@ describe('server', () => {
|
||||
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
expect(mockRedisXrange).not.toHaveBeenCalled();
|
||||
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
|
||||
});
|
||||
|
||||
@@ -575,12 +563,15 @@ describe('server', () => {
|
||||
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
|
||||
// Malformed last_id must not trigger a stream range fetch.
|
||||
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
});
|
||||
|
||||
test('valid JWT, with lastId', async () => {
|
||||
@@ -593,16 +584,18 @@ describe('server', () => {
|
||||
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
|
||||
sessionId: channelId,
|
||||
startId: '1615426152415-1',
|
||||
endId: '+',
|
||||
listener: server.processStreamResults,
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
expect(mockRedisXrange).toHaveBeenCalledWith(
|
||||
expect.stringContaining(channelId),
|
||||
'1615426152415-1',
|
||||
'+',
|
||||
);
|
||||
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
|
||||
});
|
||||
|
||||
@@ -618,16 +611,18 @@ describe('server', () => {
|
||||
server.setLastFirehoseId(lastFirehoseId);
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
|
||||
sessionId: channelId,
|
||||
startId: '1615426152415-1',
|
||||
endId: lastFirehoseId,
|
||||
listener: server.processStreamResults,
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
expect(mockRedisXrange).toHaveBeenCalledWith(
|
||||
expect.stringContaining(channelId),
|
||||
'1615426152415-1',
|
||||
lastFirehoseId,
|
||||
);
|
||||
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
|
||||
});
|
||||
});
|
||||
@@ -662,7 +657,7 @@ describe('server', () => {
|
||||
test('total connection limit reached', () => {
|
||||
server.opts.maxTotalConnections = 1;
|
||||
const ws = new wsMock('localhost');
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const socketInstance = {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -677,7 +672,7 @@ describe('server', () => {
|
||||
test('per-channel connection limit reached', () => {
|
||||
server.opts.maxConnectionsPerChannel = 1;
|
||||
const ws = new wsMock('localhost');
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const socketInstance = {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -699,7 +694,7 @@ describe('server', () => {
|
||||
};
|
||||
server.trackClient(channelId, socketInstance);
|
||||
// simulate the socket having closed but not yet been GC'd
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
expect(server.connectionLimitReason('some-other-channel')).toBeNull();
|
||||
});
|
||||
|
||||
@@ -713,13 +708,13 @@ describe('server', () => {
|
||||
};
|
||||
server.trackClient(channelId, socketInstance);
|
||||
// simulate the socket having closed but not yet been GC'd
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
expect(server.connectionLimitReason(channelId)).toBeNull();
|
||||
});
|
||||
|
||||
test('isSocketActive reflects the socket readyState', () => {
|
||||
const ws = new wsMock('localhost');
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const socketId = server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -727,9 +722,9 @@ describe('server', () => {
|
||||
});
|
||||
expect(server.isSocketActive(socketId)).toBe(true);
|
||||
// CONNECTING is also considered active (see SOCKET_ACTIVE_STATES)
|
||||
setReadyState(ws, WebSocket.CONNECTING);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CONNECTING);
|
||||
expect(server.isSocketActive(socketId)).toBe(true);
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
expect(server.isSocketActive(socketId)).toBe(false);
|
||||
// unknown socket ids are never active
|
||||
expect(server.isSocketActive('does-not-exist')).toBe(false);
|
||||
@@ -737,14 +732,14 @@ describe('server', () => {
|
||||
|
||||
test('activeSocketCount counts only active sockets', () => {
|
||||
const openWs = new wsMock('localhost');
|
||||
setReadyState(openWs, WebSocket.OPEN);
|
||||
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, {
|
||||
ws: openWs,
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const closedWs = new wsMock('localhost');
|
||||
setReadyState(closedWs, WebSocket.CLOSED);
|
||||
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
server.trackClient(channelId, {
|
||||
ws: closedWs,
|
||||
channel: channelId,
|
||||
@@ -755,14 +750,14 @@ describe('server', () => {
|
||||
|
||||
test('activeChannelSocketCount counts only active sockets on the channel', () => {
|
||||
const openWs = new wsMock('localhost');
|
||||
setReadyState(openWs, WebSocket.OPEN);
|
||||
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, {
|
||||
ws: openWs,
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const closedWs = new wsMock('localhost');
|
||||
setReadyState(closedWs, WebSocket.CLOSED);
|
||||
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
server.trackClient(channelId, {
|
||||
ws: closedWs,
|
||||
channel: channelId,
|
||||
@@ -776,7 +771,7 @@ describe('server', () => {
|
||||
test('wsConnection refuses over-limit connection without tracking', () => {
|
||||
server.opts.maxConnectionsPerChannel = 1;
|
||||
const existingWs = new wsMock('localhost');
|
||||
setReadyState(existingWs, WebSocket.OPEN);
|
||||
vi.spyOn(existingWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const existing = {
|
||||
ws: existingWs,
|
||||
channel: channelId,
|
||||
@@ -784,7 +779,7 @@ describe('server', () => {
|
||||
};
|
||||
server.trackClient(channelId, existing);
|
||||
|
||||
const trackClientSpy = jest.spyOn(server, 'trackClient');
|
||||
const trackClientSpy = vi.spyOn(server, 'trackClient');
|
||||
const ws = new wsMock('localhost');
|
||||
const validToken = jwt.sign({ channel: channelId }, config.jwtSecret);
|
||||
server.wsConnection(ws, getRequest(validToken, 'http://localhost'));
|
||||
@@ -800,8 +795,8 @@ describe('server', () => {
|
||||
|
||||
describe('httpUpgrade', () => {
|
||||
let socket: net.Socket;
|
||||
let socketDestroySpy: jest.SpiedFunction<typeof socket.destroy>;
|
||||
let wssUpgradeSpy: jest.SpiedFunction<typeof server.wss.handleUpgrade>;
|
||||
let socketDestroySpy: Mock<typeof socket.destroy>;
|
||||
let wssUpgradeSpy: Mock<typeof server.wss.handleUpgrade>;
|
||||
|
||||
const getRequest = (token: string, url: string): http.IncomingMessage => {
|
||||
const request = new http.IncomingMessage(new net.Socket());
|
||||
@@ -813,8 +808,8 @@ describe('server', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
socket = new net.Socket();
|
||||
socketDestroySpy = jest.spyOn(socket, 'destroy');
|
||||
wssUpgradeSpy = jest.spyOn(server.wss, 'handleUpgrade');
|
||||
socketDestroySpy = vi.spyOn(socket, 'destroy');
|
||||
wssUpgradeSpy = vi.spyOn(server.wss, 'handleUpgrade');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -952,33 +947,21 @@ describe('server', () => {
|
||||
});
|
||||
});
|
||||
|
||||
const setReadyState = (ws: WebSocket, value: typeof ws.readyState) => {
|
||||
// workaround for not being able to do
|
||||
// spyOn(instance,'readyState','get').and.returnValue(value);
|
||||
// See for details: https://github.com/facebook/jest/issues/9675
|
||||
Object.defineProperty(ws, 'readyState', {
|
||||
configurable: true,
|
||||
get() {
|
||||
return value;
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
describe('checkSockets', () => {
|
||||
let ws: WebSocket;
|
||||
let pingSpy: jest.SpiedFunction<typeof ws.ping>;
|
||||
let terminateSpy: jest.SpiedFunction<typeof ws.terminate>;
|
||||
let pingSpy: Mock<typeof ws.ping>;
|
||||
let terminateSpy: Mock<typeof ws.terminate>;
|
||||
let socketInstance: server.SocketInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
ws = new wsMock('localhost');
|
||||
pingSpy = jest.spyOn(ws, 'ping');
|
||||
terminateSpy = jest.spyOn(ws, 'terminate');
|
||||
pingSpy = vi.spyOn(ws, 'ping');
|
||||
terminateSpy = vi.spyOn(ws, 'terminate');
|
||||
socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
|
||||
});
|
||||
|
||||
test('active sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.checkSockets();
|
||||
@@ -989,7 +972,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('stale sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
socketInstance.pongTs = Date.now() - 60000;
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
@@ -1001,7 +984,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('closed sockets', () => {
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.checkSockets();
|
||||
@@ -1027,7 +1010,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('active sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.cleanChannel(channelId);
|
||||
@@ -1036,7 +1019,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('closing sockets', () => {
|
||||
setReadyState(ws, WebSocket.CLOSING);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSING);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.cleanChannel(channelId);
|
||||
@@ -1045,11 +1028,12 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('multiple sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
const ws2 = new wsMock('localhost');
|
||||
setReadyState(ws2, WebSocket.OPEN);
|
||||
const readyStateSpy = vi.spyOn(ws2, 'readyState', 'get');
|
||||
readyStateSpy.mockReturnValue(WebSocket.OPEN);
|
||||
const socketInstance2 = {
|
||||
ws: ws2,
|
||||
channel: channelId,
|
||||
@@ -1061,7 +1045,7 @@ describe('server', () => {
|
||||
|
||||
expect(server.channels[channelId].sockets.length).toBe(2);
|
||||
|
||||
setReadyState(ws2, WebSocket.CLOSED);
|
||||
readyStateSpy.mockReturnValue(WebSocket.CLOSED);
|
||||
server.cleanChannel(channelId);
|
||||
|
||||
expect(server.channels[channelId].sockets.length).toBe(1);
|
||||
|
||||
@@ -19,11 +19,11 @@
|
||||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import { inspect } from 'util';
|
||||
import WebSocket, { WebSocketServer } from 'ws';
|
||||
import { WebSocket, WebSocketServer } from 'ws';
|
||||
import { randomUUID } from 'crypto';
|
||||
import jwt, { Algorithm } from 'jsonwebtoken';
|
||||
import { parse } from 'cookie';
|
||||
import Redis, { RedisOptions } from 'ioredis';
|
||||
import { Redis, RedisOptions } from 'ioredis';
|
||||
import StatsD from 'hot-shots';
|
||||
|
||||
import { createLogger } from './logger';
|
||||
|
||||
@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Test & development utilities
|
||||
|
||||
The files provided here are for testing and development only, and are not required to run the WebSocket server application.
|
||||
|
||||
@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Test client application
|
||||
|
||||
This Express web application is provided for testing the WebSocket server. It is not required for running the server application, and is provided here for testing and development purposes only.
|
||||
|
||||
@@ -17,10 +17,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union
|
||||
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask import g
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.validate import Range
|
||||
@@ -38,12 +39,18 @@ from superset.db_engine_specs.base import (
|
||||
)
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error
|
||||
from superset.utils import json
|
||||
from superset.utils.core import get_user_agent, QuerySource
|
||||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
OAuth2State,
|
||||
OAuth2TokenResponse,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
@@ -277,6 +284,135 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
|
||||
"port": "port",
|
||||
}
|
||||
|
||||
# The Databricks SQL driver has no dedicated authentication exception, so an
|
||||
# expired or missing token surfaces as a generic driver error. These case-
|
||||
# insensitive substrings flag the errors that should bootstrap a re-auth.
|
||||
oauth2_auth_failure_signals = (
|
||||
"http 401",
|
||||
"unauthorized",
|
||||
"unauthenticated",
|
||||
"invalid access token",
|
||||
"invalid token",
|
||||
"expired token",
|
||||
"token expired",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _workspace_oauth2_endpoint(cls, database: Database, path: str) -> str:
|
||||
"""
|
||||
Build a Databricks OAuth2 (U2M) endpoint from the workspace host.
|
||||
|
||||
Databricks fronts the user-to-machine OAuth2 flow on every workspace at
|
||||
``https://<workspace-host>/oidc/v1/{authorize,token}`` across AWS, Azure
|
||||
and GCP, so the endpoints derive directly from the connection host and
|
||||
need no account or tenant identifier.
|
||||
"""
|
||||
host = database.url_object.host
|
||||
if not host:
|
||||
raise OAuth2Error(
|
||||
"Databricks OAuth2 endpoint could not be resolved: the database "
|
||||
"connection has no host."
|
||||
)
|
||||
return f"https://{host}/oidc/v1/{path}"
|
||||
|
||||
@classmethod
|
||||
def needs_oauth2(cls, ex: Exception) -> bool:
|
||||
"""
|
||||
Identify driver errors that should trigger the OAuth2 dance.
|
||||
|
||||
Unlike Trino (``TrinoAuthError``) or GSheets (``UnauthenticatedError``),
|
||||
the Databricks driver raises no dedicated auth exception, so in addition
|
||||
to the base ``isinstance`` check we match the auth signals above on the
|
||||
error message (mirrors ``GSheetsEngineSpec.needs_oauth2``).
|
||||
"""
|
||||
if not (g and hasattr(g, "user")):
|
||||
return False
|
||||
if isinstance(ex, cls.oauth2_exception):
|
||||
return True
|
||||
message = str(ex).lower()
|
||||
return any(signal in message for signal in cls.oauth2_auth_failure_signals)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_authorization_uri(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
state: "OAuth2State",
|
||||
code_verifier: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Return the URI for the initial OAuth2 request.
|
||||
|
||||
A fully-resolved ``authorization_request_uri`` from
|
||||
``DATABASE_OAUTH2_CLIENTS`` is preserved; otherwise the endpoint is
|
||||
derived from the workspace host (``https://<host>/oidc/v1/authorize``),
|
||||
which is valid on AWS, Azure and GCP.
|
||||
"""
|
||||
if not config.get("authorization_request_uri"):
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
|
||||
database_id = state["database_id"]
|
||||
if database := db.session.get(Database, database_id):
|
||||
config = cast(
|
||||
"OAuth2ClientConfig",
|
||||
dict(config)
|
||||
| {
|
||||
"authorization_request_uri": cls._workspace_oauth2_endpoint(
|
||||
database, "authorize"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
return super().get_oauth2_authorization_uri(config, state, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_token(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
code: str,
|
||||
code_verifier: str | None = None,
|
||||
) -> "OAuth2TokenResponse":
|
||||
"""
|
||||
Exchange the authorization code for refresh/access tokens.
|
||||
|
||||
Token exchange runs in a separate request with no database context, so
|
||||
the workspace host is not available to derive the endpoint here. Require
|
||||
a configured ``token_request_uri``
|
||||
(``https://<workspace-host>/oidc/v1/token``) and fail fast rather than
|
||||
POST to an unresolved endpoint.
|
||||
"""
|
||||
if not config.get("token_request_uri"):
|
||||
raise OAuth2Error(
|
||||
"Databricks OAuth2 token endpoint is not configured: set "
|
||||
"`token_request_uri` to https://<workspace-host>/oidc/v1/token "
|
||||
"in DATABASE_OAUTH2_CLIENTS."
|
||||
)
|
||||
|
||||
return super().get_oauth2_token(config, code, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def impersonate_user(
|
||||
cls,
|
||||
database: Database,
|
||||
username: str | None,
|
||||
user_token: str | None,
|
||||
url: URL,
|
||||
engine_kwargs: dict[str, Any],
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
"""
|
||||
Update connection with OAuth2 access token for user impersonation.
|
||||
"""
|
||||
if user_token:
|
||||
# Replace the access token in the URL with the user's OAuth2 token
|
||||
url = url.set(password=user_token)
|
||||
|
||||
# Also update connect_args if they contain access token
|
||||
connect_args = engine_kwargs.setdefault("connect_args", {})
|
||||
if "access_token" in connect_args:
|
||||
connect_args["access_token"] = user_token
|
||||
|
||||
return url, engine_kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(
|
||||
database: Database, source: QuerySource | None = None
|
||||
@@ -474,6 +610,16 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
supports_dynamic_catalog = True
|
||||
supports_cross_catalog_queries = True
|
||||
|
||||
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
|
||||
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
|
||||
supports_oauth2 = True
|
||||
oauth2_scope = "sql"
|
||||
|
||||
# Authorization endpoint is derived from the workspace host at runtime; the
|
||||
# token endpoint must be configured (no DB context at exchange time).
|
||||
oauth2_authorization_request_uri = ""
|
||||
oauth2_token_request_uri = ""
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_uri( # type: ignore
|
||||
cls, parameters: DatabricksNativeParametersType, *_
|
||||
@@ -685,6 +831,16 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
|
||||
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
|
||||
|
||||
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
|
||||
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
|
||||
supports_oauth2 = True
|
||||
oauth2_scope = "sql"
|
||||
|
||||
# Authorization endpoint is derived from the workspace host at runtime; the
|
||||
# token endpoint must be configured (no DB context at exchange time).
|
||||
oauth2_authorization_request_uri = ""
|
||||
oauth2_token_request_uri = ""
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_uri( # type: ignore
|
||||
cls, parameters: DatabricksPythonConnectorParametersType, *_
|
||||
|
||||
@@ -17,14 +17,23 @@
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
from superset.db_engine_specs.databricks import (
|
||||
DatabricksNativeEngineSpec,
|
||||
DatabricksPythonConnectorEngineSpec,
|
||||
)
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils import json
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from tests.unit_tests.fixtures.common import dttm # noqa: F401
|
||||
|
||||
@@ -291,3 +300,595 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"USE CATALOG `evil`` USE CATALOG bad`",
|
||||
"USE SCHEMA `evil`` USE SCHEMA bad`",
|
||||
]
|
||||
|
||||
|
||||
# OAuth2 Tests
|
||||
|
||||
|
||||
def test_oauth2_attributes() -> None:
|
||||
"""
|
||||
Test that OAuth2 attributes are properly set for both engine specs.
|
||||
"""
|
||||
# Test DatabricksNativeEngineSpec
|
||||
assert DatabricksNativeEngineSpec.supports_oauth2 is True
|
||||
assert DatabricksNativeEngineSpec.oauth2_scope == "sql"
|
||||
# The authorization endpoint is derived from the workspace host at runtime;
|
||||
# the token endpoint must be configured explicitly.
|
||||
assert DatabricksNativeEngineSpec.oauth2_authorization_request_uri == ""
|
||||
assert DatabricksNativeEngineSpec.oauth2_token_request_uri == ""
|
||||
|
||||
# Test DatabricksPythonConnectorEngineSpec
|
||||
assert DatabricksPythonConnectorEngineSpec.supports_oauth2 is True
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_scope == "sql"
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_authorization_request_uri == ""
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_token_request_uri == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"message",
|
||||
[
|
||||
"Error during request to server: HTTP 401 Unauthorized",
|
||||
"Invalid access token",
|
||||
"The access token expired",
|
||||
"UNAUTHENTICATED: token is no longer valid",
|
||||
],
|
||||
)
|
||||
def test_needs_oauth2_detects_auth_failure_from_message(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""
|
||||
The Databricks driver has no dedicated auth exception, so `needs_oauth2`
|
||||
matches auth-failure signals in the error message to bootstrap a re-auth.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.databricks.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
assert spec.needs_oauth2(Exception(message)) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"message",
|
||||
[
|
||||
"Table not found",
|
||||
# A bare 401 in an unrelated position must not look like an auth failure.
|
||||
"Query failed at line 401: syntax error",
|
||||
],
|
||||
)
|
||||
def test_needs_oauth2_ignores_unrelated_errors(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""
|
||||
A non-auth driver error must not trigger the OAuth2 dance.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.databricks.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
assert spec.needs_oauth2(Exception(message)) is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_needs_oauth2_matches_oauth2_redirect_error(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
The inherited `isinstance` check against `oauth2_exception` still holds.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.databricks.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
ex = OAuth2RedirectError("https://example/authorize", "tab", "redirect")
|
||||
assert spec.needs_oauth2(ex) is True
|
||||
|
||||
|
||||
def test_impersonate_user_with_token(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method with OAuth2 token for DatabricksNativeEngineSpec.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks+connector://token:original-token@host:443/database"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test with user token
|
||||
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token="user-oauth-token", # noqa: S106
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that the password (token) was updated in the URL
|
||||
assert url.password == "user-oauth-token" # noqa: S105
|
||||
# Check that access_token was updated in connect_args
|
||||
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
|
||||
|
||||
|
||||
def test_impersonate_user_without_token(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method without OAuth2 token.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks+connector://token:original-token@host:443/database"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test without user token
|
||||
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token=None,
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that nothing was changed
|
||||
assert url.password == "original-token" # noqa: S105
|
||||
assert kwargs["connect_args"]["access_token"] == "original-token" # noqa: S105
|
||||
|
||||
|
||||
def test_impersonate_user_python_connector(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method for DatabricksPythonConnectorEngineSpec.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks://token:original-token@host:443?http_path=path&catalog=main&schema=default"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test with user token
|
||||
url, kwargs = DatabricksPythonConnectorEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token="user-oauth-token", # noqa: S106
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that the password (token) was updated in the URL
|
||||
assert url.password == "user-oauth-token" # noqa: S105
|
||||
# Check that access_token was updated in connect_args
|
||||
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_native() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks Native OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_python() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks Python Connector OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_no_config_native(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is not configured for Native engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={"DATABASE_OAUTH2_CLIENTS": {}},
|
||||
)
|
||||
|
||||
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is False
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_config_native(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is configured for Native engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={
|
||||
"DATABASE_OAUTH2_CLIENTS": {
|
||||
"Databricks (legacy)": {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is True
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_no_config_python(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is not configured for Python Connector engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={"DATABASE_OAUTH2_CLIENTS": {}},
|
||||
)
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is False
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_config_python(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is configured for Python Connector engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={
|
||||
"DATABASE_OAUTH2_CLIENTS": {
|
||||
"Databricks": {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is True
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_authorization_uri` for Native engine.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_native, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.cloud.databricks.com"
|
||||
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_authorization_uri` for Python Connector engine.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_python, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.cloud.databricks.com"
|
||||
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_token_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token` for Native engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_oauth2_token(
|
||||
oauth2_config_native, "authorization-code"
|
||||
) == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"code": "authorization-code",
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_token_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token` for Python Connector engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.get_oauth2_token(
|
||||
oauth2_config_python, "authorization-code"
|
||||
) == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"code": "authorization-code",
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_fresh_token` for Native engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_oauth2_fresh_token(
|
||||
oauth2_config_native, "old-refresh-token"
|
||||
) == {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def _oauth2_state() -> OAuth2State:
|
||||
"""
|
||||
Build the default OAuth2 state shared by the OAuth2 tests.
|
||||
"""
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
return state
|
||||
|
||||
|
||||
def _unresolved_oauth2_config() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config as built by `get_oauth2_config` when no endpoints are overridden:
|
||||
the URIs default to the spec's empty class attributes.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "",
|
||||
"token_request_uri": "",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"host",
|
||||
[
|
||||
"dbc-abc.cloud.databricks.com",
|
||||
"adb-123456789.12.azuredatabricks.net",
|
||||
"123456789.gcp.databricks.com",
|
||||
],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_derives_from_workspace_host(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
host: str,
|
||||
) -> None:
|
||||
"""
|
||||
With no configured `authorization_request_uri`, the endpoint is derived from
|
||||
the workspace host (`https://<host>/oidc/v1/authorize`) on every cloud, with
|
||||
no account/tenant identifier required.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = host
|
||||
mocker.patch("superset.db.session.get", return_value=database)
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(
|
||||
_unresolved_oauth2_config(), _oauth2_state()
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == host
|
||||
assert parsed.path == "/oidc/v1/authorize"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_preserves_configured(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
A fully-resolved `authorization_request_uri` is never overwritten by the
|
||||
host-derived endpoint, and no database lookup is needed.
|
||||
"""
|
||||
session_get = mocker.patch("superset.db.session.get")
|
||||
config = _unresolved_oauth2_config()
|
||||
config["authorization_request_uri"] = (
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/override/v1/authorize"
|
||||
)
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(config, _oauth2_state())
|
||||
assert urlparse(url).path == "/oidc/accounts/override/v1/authorize"
|
||||
session_get.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_fails_without_host(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
When the endpoint must be derived but the connection has no host, fail fast
|
||||
instead of emitting an invalid `https:///oidc/v1/authorize` URL.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = None
|
||||
mocker.patch("superset.db.session.get", return_value=database)
|
||||
|
||||
with pytest.raises(OAuth2Error):
|
||||
spec.get_oauth2_authorization_uri(_unresolved_oauth2_config(), _oauth2_state())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_token_fails_without_uri(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Token exchange has no database context to auto-detect the endpoint, so a
|
||||
missing `token_request_uri` fails fast rather than POSTing to `.../{}/...`.
|
||||
"""
|
||||
with pytest.raises(OAuth2Error):
|
||||
spec.get_oauth2_token(_unresolved_oauth2_config(), "authorization-code")
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_fresh_token` for Python Connector engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.get_oauth2_fresh_token(
|
||||
oauth2_config_python, "old-refresh-token"
|
||||
) == {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
127
tests/unit_tests/db_engine_specs/test_databricks_multi_cloud.py
Normal file
127
tests/unit_tests/db_engine_specs/test_databricks_multi_cloud.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.db_engine_specs.databricks import (
|
||||
DatabricksNativeEngineSpec,
|
||||
DatabricksPythonConnectorEngineSpec,
|
||||
)
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
|
||||
# Multi-Cloud Provider Tests
|
||||
#
|
||||
# Databricks fronts the user-to-machine OAuth2 flow on every workspace at
|
||||
# `https://<workspace-host>/oidc/v1/{authorize,token}`, regardless of cloud, so
|
||||
# the authorization endpoint derives from the connection host with no per-cloud
|
||||
# account/tenant identifier.
|
||||
|
||||
SPECS = [DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec]
|
||||
|
||||
# Representative workspace hosts for each cloud provider.
|
||||
CLOUD_HOSTS = [
|
||||
"my-cluster.cloud.databricks.com", # AWS
|
||||
"adb-123456789.12.azuredatabricks.net", # Azure
|
||||
"123456789.gcp.databricks.com", # GCP
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_no_uri() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks OAuth2 without a pre-configured endpoint, so the
|
||||
authorization endpoint is derived from the workspace host.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "",
|
||||
"token_request_uri": "",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
def _mock_database(mocker: MockerFixture, host: str) -> MagicMock:
|
||||
"""
|
||||
Build a mock database whose URL resolves to the given workspace host.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = host
|
||||
database.id = 1
|
||||
return database
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", SPECS)
|
||||
@pytest.mark.parametrize("host", CLOUD_HOSTS)
|
||||
def test_get_oauth2_authorization_uri_uses_workspace_host(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
host: str,
|
||||
oauth2_config_no_uri: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
The authorization endpoint is the workspace host on AWS, Azure, and GCP.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
mocker.patch(
|
||||
"superset.extensions.db.session.get",
|
||||
return_value=_mock_database(mocker, host),
|
||||
)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(oauth2_config_no_uri, state)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == host
|
||||
assert parsed.path == "/oidc/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", SPECS)
|
||||
@pytest.mark.parametrize("host", CLOUD_HOSTS)
|
||||
def test_workspace_oauth2_endpoint_builds_token_uri(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
host: str,
|
||||
) -> None:
|
||||
"""
|
||||
The helper builds the matching token endpoint from the same workspace host.
|
||||
"""
|
||||
database = _mock_database(mocker, host)
|
||||
assert (
|
||||
spec._workspace_oauth2_endpoint(database, "token")
|
||||
== f"https://{host}/oidc/v1/token"
|
||||
)
|
||||
@@ -81,6 +81,7 @@ def _mock_dashboard(
|
||||
dashboard.slices = []
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.embedded = []
|
||||
return dashboard
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user