mirror of
https://github.com/apache/superset.git
synced 2026-06-28 10:55:36 +00:00
Compare commits
29 Commits
fix-dashbo
...
adopt/data
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
563edfdd9f | ||
|
|
98ba9da18c | ||
|
|
910e79dd8c | ||
|
|
6251280aa5 | ||
|
|
9d0a2209c6 | ||
|
|
cb7d5c8847 | ||
|
|
fe69d222bd | ||
|
|
8e0871bb2e | ||
|
|
ea1ba587f2 | ||
|
|
dc99373579 | ||
|
|
8737f010f3 | ||
|
|
278cfbb694 | ||
|
|
1de21ec5c6 | ||
|
|
2ed41ae8a6 | ||
|
|
a0fdb2aa31 | ||
|
|
25c9f3510a | ||
|
|
b8fd2e9725 | ||
|
|
78dd400ca4 | ||
|
|
7587d0778a | ||
|
|
97cb002f46 | ||
|
|
5ec0931840 | ||
|
|
3eb9185521 | ||
|
|
cd8ac41d16 | ||
|
|
21999bb772 | ||
|
|
0a18779280 | ||
|
|
a147079043 | ||
|
|
ebb32de625 | ||
|
|
1280eaee18 | ||
|
|
15626a047c |
2
.github/workflows/pre-commit.yml
vendored
2
.github/workflows/pre-commit.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
yarn install --immutable
|
||||
|
||||
- name: Cache pre-commit environments
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-v2-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
|
||||
- name: Cache npm
|
||||
if: env.HAS_TAGS
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/.npm # npm cache files are stored in `~/.npm` on Linux/macOS
|
||||
key: ${{ runner.OS }}-node-${{ hashFiles('**/package-lock.json') }}
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
run: echo "dir=$(npm config get cache)" >> $GITHUB_OUTPUT
|
||||
- name: Cache npm
|
||||
if: env.HAS_TAGS
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: npm-cache # use this to check for `cache-hit` (`steps.npm-cache.outputs.cache-hit != 'true'`)
|
||||
with:
|
||||
path: ${{ steps.npm-cache-dir-path.outputs.dir }}
|
||||
|
||||
4
.github/workflows/superset-websocket.yml
vendored
4
.github/workflows/superset-websocket.yml
vendored
@@ -28,6 +28,10 @@ jobs:
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version-file: './superset-websocket/.nvmrc'
|
||||
- name: Install dependencies
|
||||
working-directory: ./superset-websocket
|
||||
run: npm ci
|
||||
|
||||
@@ -519,6 +519,80 @@ For a connection to a SQL endpoint you need to use the HTTP path from the endpoi
|
||||
{"connect_args": {"http_path": "/sql/1.0/endpoints/****", "driver_path": "/path/to/odbc/driver"}}
|
||||
```
|
||||
|
||||
##### OAuth2 Authentication
|
||||
|
||||
Superset supports OAuth2 authentication for Databricks, allowing users to authenticate with their personal Databricks accounts instead of using shared access tokens. This provides better security and audit capabilities.
|
||||
|
||||
###### Prerequisites
|
||||
|
||||
1. Create an OAuth2 application in your Databricks account:
|
||||
- Go to your Databricks account console
|
||||
- Navigate to **Settings** → **Developer** → **OAuth apps**
|
||||
- Create a new OAuth app with the redirect URI: `http://your-superset-host:port/api/v1/database/oauth2/`
|
||||
|
||||
2. Configure OAuth2 in your `superset_config.py`:
|
||||
|
||||
```python
|
||||
from datetime import timedelta
|
||||
|
||||
# OAuth2 configuration for Databricks
|
||||
# The authorization endpoint is derived from your Databricks workspace host; the
|
||||
# token endpoint must be set explicitly (see notes below).
|
||||
DATABASE_OAUTH2_CLIENTS = {
|
||||
"Databricks (legacy)": {
|
||||
"id": "your-databricks-client-id",
|
||||
"secret": "your-databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
|
||||
},
|
||||
"Databricks": {
|
||||
"id": "your-databricks-client-id",
|
||||
"secret": "your-databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
|
||||
},
|
||||
}
|
||||
|
||||
# OAuth2 redirect URI (adjust hostname/port for your setup)
|
||||
DATABASE_OAUTH2_REDIRECT_URI = "http://your-superset-host:port/api/v1/database/oauth2/"
|
||||
|
||||
# Optional: OAuth2 timeout
|
||||
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
|
||||
```
|
||||
|
||||
Replace the following placeholders:
|
||||
- `your-databricks-client-id`: Your Databricks OAuth2 application client ID
|
||||
- `your-databricks-client-secret`: Your Databricks OAuth2 application client secret
|
||||
- `your-superset-host:port`: Your Superset instance hostname and port
|
||||
|
||||
**Multi-Cloud Provider Support**
|
||||
|
||||
Databricks fronts the user-to-machine (U2M) OAuth2 flow on every workspace at
|
||||
`https://<workspace-host>/oidc/v1/authorize` and
|
||||
`https://<workspace-host>/oidc/v1/token`, regardless of whether the workspace
|
||||
runs on AWS, Azure, or GCP. Superset derives the **authorization** endpoint
|
||||
directly from your connection's host, so no cloud provider or account/tenant
|
||||
identifier needs to be configured.
|
||||
|
||||
The **token** endpoint cannot be auto-derived (token exchange has no database
|
||||
context to read the host), so you must supply `token_request_uri` in
|
||||
`DATABASE_OAUTH2_CLIENTS`, set to `https://<workspace-host>/oidc/v1/token` for
|
||||
your workspace.
|
||||
|
||||
If you supply a fully-resolved `authorization_request_uri` (and/or
|
||||
`token_request_uri`), those values take precedence over the host-derived
|
||||
defaults.
|
||||
|
||||
###### Usage
|
||||
|
||||
Once configured, users can:
|
||||
|
||||
1. Connect to Databricks databases normally using access tokens
|
||||
2. When querying data, Superset will automatically redirect users to authenticate with Databricks if needed
|
||||
3. User-specific OAuth2 tokens will be used for database connections, providing better security and audit trails
|
||||
|
||||
This feature works with both "Databricks (legacy)" and "Databricks" engine types and automatically supports all major cloud providers (AWS, Azure, GCP).
|
||||
|
||||
#### Denodo
|
||||
|
||||
The recommended connector library for Denodo is
|
||||
|
||||
@@ -32,6 +32,7 @@ and therefore are not easily unit-testable. We have instead opted to test the sd
|
||||
This way, the tests can assert that the sdk actually mounts the iframe and communicates with it correctly.
|
||||
|
||||
At time of writing, these tests are not written yet, because we haven't yet put together the demo app that they will leverage.
|
||||
|
||||
### Things to e2e test once we have a demo app:
|
||||
|
||||
**happy path:**
|
||||
|
||||
@@ -41,12 +41,12 @@ npm install --save @superset-ui/embedded-sdk
|
||||
```
|
||||
|
||||
```js
|
||||
import { embedDashboard } from '@superset-ui/embedded-sdk';
|
||||
import { embedDashboard } from "@superset-ui/embedded-sdk";
|
||||
|
||||
embedDashboard({
|
||||
id: 'abc123', // given by the Superset embedding UI
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'), // any html element that can contain an iframe
|
||||
id: "abc123", // given by the Superset embedding UI
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"), // any html element that can contain an iframe
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
dashboardUiConfig: {
|
||||
// dashboard UI config: hideTitle, hideTab, hideChartControls, filters.visible, filters.expanded (optional), urlParams (optional)
|
||||
@@ -55,21 +55,21 @@ embedDashboard({
|
||||
expanded: true,
|
||||
},
|
||||
urlParams: {
|
||||
foo: 'value1',
|
||||
bar: 'value2',
|
||||
foo: "value1",
|
||||
bar: "value2",
|
||||
// themeMode: 'dark', // set the initial theme: 'dark' | 'system' | 'default' (default: 'default')
|
||||
// ...
|
||||
},
|
||||
},
|
||||
// optional additional iframe sandbox attributes
|
||||
iframeSandboxExtras: [
|
||||
'allow-top-navigation',
|
||||
'allow-popups-to-escape-sandbox',
|
||||
"allow-top-navigation",
|
||||
"allow-popups-to-escape-sandbox",
|
||||
],
|
||||
// optional Permissions Policy features
|
||||
iframeAllowExtras: ['clipboard-write', 'fullscreen'],
|
||||
iframeAllowExtras: ["clipboard-write", "fullscreen"],
|
||||
// optional config to enforce a particular referrerPolicy
|
||||
referrerPolicy: 'same-origin',
|
||||
referrerPolicy: "same-origin",
|
||||
// optional callback to customize permalink URLs
|
||||
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
|
||||
});
|
||||
@@ -163,13 +163,13 @@ Use the `themeMode` URL parameter to control the embedded dashboard's initial co
|
||||
|
||||
```js
|
||||
embedDashboard({
|
||||
id: 'abc123',
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'),
|
||||
id: "abc123",
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"),
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
dashboardUiConfig: {
|
||||
urlParams: {
|
||||
themeMode: 'dark', // 'dark' | 'system' | 'default' (default: 'default')
|
||||
themeMode: "dark", // 'dark' | 'system' | 'default' (default: 'default')
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -193,7 +193,7 @@ To pass additional sandbox attributes you can use `iframeSandboxExtras`:
|
||||
|
||||
```js
|
||||
// optional additional iframe sandbox attributes
|
||||
iframeSandboxExtras: ['allow-top-navigation', 'allow-popups-to-escape-sandbox'];
|
||||
iframeSandboxExtras: ["allow-top-navigation", "allow-popups-to-escape-sandbox"];
|
||||
```
|
||||
|
||||
### Permissions Policy
|
||||
@@ -202,7 +202,7 @@ To enable specific browser features within the embedded iframe, use `iframeAllow
|
||||
|
||||
```js
|
||||
// optional Permissions Policy features
|
||||
iframeAllowExtras: ['clipboard-write', 'fullscreen'];
|
||||
iframeAllowExtras: ["clipboard-write", "fullscreen"];
|
||||
```
|
||||
|
||||
Common permissions you might need:
|
||||
@@ -225,9 +225,9 @@ When users click share buttons inside an embedded dashboard, Superset generates
|
||||
|
||||
```js
|
||||
embedDashboard({
|
||||
id: 'abc123',
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'),
|
||||
id: "abc123",
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"),
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
|
||||
// Customize permalink URLs
|
||||
@@ -245,9 +245,9 @@ To restore the dashboard state from a permalink in your app:
|
||||
const permalinkKey = routeParams.key;
|
||||
|
||||
embedDashboard({
|
||||
id: 'abc123',
|
||||
supersetDomain: 'https://superset.example.com',
|
||||
mountPoint: document.getElementById('my-superset-container'),
|
||||
id: "abc123",
|
||||
supersetDomain: "https://superset.example.com",
|
||||
mountPoint: document.getElementById("my-superset-container"),
|
||||
fetchGuestToken: () => fetchGuestTokenFromBackend(),
|
||||
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
|
||||
dashboardUiConfig: {
|
||||
|
||||
@@ -18,9 +18,6 @@
|
||||
*/
|
||||
|
||||
module.exports = {
|
||||
presets: [
|
||||
"@babel/preset-typescript",
|
||||
"@babel/preset-env"
|
||||
],
|
||||
presets: ["@babel/preset-typescript", "@babel/preset-env"],
|
||||
sourceMaps: true,
|
||||
};
|
||||
|
||||
10769
superset-embedded-sdk/package-lock.json
generated
10769
superset-embedded-sdk/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@
|
||||
"scripts": {
|
||||
"build": "tsc && babel src --out-dir lib --extensions '.ts,.tsx' && webpack --mode production",
|
||||
"ci:release": "node ./release-if-necessary.js",
|
||||
"test": "jest"
|
||||
"test": "vitest --run --dir src"
|
||||
},
|
||||
"browserslist": [
|
||||
"last 3 chrome versions",
|
||||
@@ -41,12 +41,11 @@
|
||||
"@babel/core": "^7.25.2",
|
||||
"@babel/preset-env": "^7.25.4",
|
||||
"@babel/preset-typescript": "^7.24.7",
|
||||
"@types/jest": "^29.5.12",
|
||||
"@types/node": "^22.5.4",
|
||||
"@types/node": "^25.4.0",
|
||||
"babel-loader": "^9.1.3",
|
||||
"jest": "^29.7.0",
|
||||
"tscw-config": "^1.1.2",
|
||||
"typescript": "^5.6.2",
|
||||
"typescript": "^5.9.3",
|
||||
"vitest": "^4.0.18",
|
||||
"webpack": "^5.94.0",
|
||||
"webpack-cli": "^5.1.4"
|
||||
},
|
||||
|
||||
@@ -17,15 +17,15 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
const { execSync } = require('child_process');
|
||||
const { name, version } = require('./package.json');
|
||||
const { execSync } = require("child_process");
|
||||
const { name, version } = require("./package.json");
|
||||
|
||||
function log(...args) {
|
||||
console.log('[embedded-sdk-release]', ...args);
|
||||
console.log("[embedded-sdk-release]", ...args);
|
||||
}
|
||||
|
||||
function logError(...args) {
|
||||
console.error('[embedded-sdk-release]', ...args);
|
||||
console.error("[embedded-sdk-release]", ...args);
|
||||
}
|
||||
|
||||
(async () => {
|
||||
@@ -38,13 +38,13 @@ function logError(...args) {
|
||||
const { status } = await fetch(packageUrl);
|
||||
|
||||
if (status === 200) {
|
||||
log('version already exists on npm, exiting');
|
||||
log("version already exists on npm, exiting");
|
||||
} else if (status === 404) {
|
||||
log('release required, building');
|
||||
log("release required, building");
|
||||
try {
|
||||
execSync('npm run build', { stdio: 'pipe' });
|
||||
log('build successful, publishing')
|
||||
execSync('npm publish --access public', { stdio: 'pipe' });
|
||||
execSync("npm run build", { stdio: "pipe" });
|
||||
log("build successful, publishing");
|
||||
execSync("npm publish --access public", { stdio: "pipe" });
|
||||
log(`published ${version} to npm`);
|
||||
} catch (err) {
|
||||
// npm writes failure details to stderr (auth/permission/registry
|
||||
@@ -52,7 +52,7 @@ function logError(...args) {
|
||||
// the real cause in CI logs.
|
||||
if (err.stdout) console.error(String(err.stdout));
|
||||
if (err.stderr) console.error(String(err.stderr));
|
||||
logError('Encountered an error, details should be above');
|
||||
logError("Encountered an error, details should be above");
|
||||
process.exitCode = 1;
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -18,7 +18,9 @@
|
||||
*/
|
||||
|
||||
export const IFRAME_COMMS_MESSAGE_TYPE = "__embedded_comms__";
|
||||
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: { [index: string]: any } = {
|
||||
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: {
|
||||
[index: string]: any;
|
||||
} = {
|
||||
visible: "show_filters",
|
||||
expanded: "expand_filters",
|
||||
}
|
||||
};
|
||||
|
||||
@@ -24,22 +24,23 @@ import {
|
||||
DEFAULT_TOKEN_EXP_MS,
|
||||
DEFAULT_TOKEN_REFRESH_RETRY_MS,
|
||||
} from "./guestTokenRefresh";
|
||||
import { afterAll, beforeAll, it, expect, describe, vi } from "vitest";
|
||||
|
||||
describe("guest token refresh", () => {
|
||||
beforeAll(() => {
|
||||
jest.useFakeTimers();
|
||||
jest.setSystemTime(new Date("2022-03-03 01:00"));
|
||||
jest.spyOn(global, "setTimeout");
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date("2022-03-03 01:00"));
|
||||
vi.spyOn(globalThis, "setTimeout");
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
jest.useRealTimers();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
function makeFakeJWT(claims: any) {
|
||||
// not a valid jwt, but close enough for this code
|
||||
const tokenifiedClaims = Buffer.from(JSON.stringify(claims)).toString(
|
||||
"base64"
|
||||
"base64",
|
||||
);
|
||||
return `abc.${tokenifiedClaims}.xyz`;
|
||||
}
|
||||
|
||||
@@ -18,17 +18,23 @@
|
||||
*/
|
||||
import { jwtDecode } from "jwt-decode";
|
||||
|
||||
export const REFRESH_TIMING_BUFFER_MS = 5000 // refresh guest token early to avoid failed superset requests
|
||||
export const MIN_REFRESH_WAIT_MS = 10000 // avoid blasting requests as fast as the cpu can handle
|
||||
export const DEFAULT_TOKEN_EXP_MS = 300000 // (5 min) used only when parsing guest token exp fails
|
||||
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000 // wait before retrying a failed/timed-out token refresh
|
||||
export const REFRESH_TIMING_BUFFER_MS = 5000; // refresh guest token early to avoid failed superset requests
|
||||
export const MIN_REFRESH_WAIT_MS = 10000; // avoid blasting requests as fast as the cpu can handle
|
||||
export const DEFAULT_TOKEN_EXP_MS = 300000; // (5 min) used only when parsing guest token exp fails
|
||||
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000; // wait before retrying a failed/timed-out token refresh
|
||||
|
||||
// when do we refresh the guest token?
|
||||
export function getGuestTokenRefreshTiming(currentGuestToken: string) {
|
||||
const parsedJwt = jwtDecode<Record<string, any>>(currentGuestToken);
|
||||
// if exp is int, it is in seconds, but Date() takes milliseconds
|
||||
const exp = new Date(/[^0-9\.]/g.test(parsedJwt.exp) ? parsedJwt.exp : parseFloat(parsedJwt.exp) * 1000);
|
||||
const isValidDate = exp.toString() !== 'Invalid Date';
|
||||
const ttl = isValidDate ? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now()) : DEFAULT_TOKEN_EXP_MS;
|
||||
const exp = new Date(
|
||||
/[^0-9\.]/g.test(parsedJwt.exp)
|
||||
? parsedJwt.exp
|
||||
: parseFloat(parsedJwt.exp) * 1000,
|
||||
);
|
||||
const isValidDate = exp.toString() !== "Invalid Date";
|
||||
const ttl = isValidDate
|
||||
? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now())
|
||||
: DEFAULT_TOKEN_EXP_MS;
|
||||
return ttl - REFRESH_TIMING_BUFFER_MS;
|
||||
}
|
||||
|
||||
@@ -20,15 +20,15 @@
|
||||
import {
|
||||
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY,
|
||||
IFRAME_COMMS_MESSAGE_TYPE,
|
||||
} from './const';
|
||||
} from "./const";
|
||||
|
||||
// We can swap this out for the actual switchboard package once it gets published
|
||||
import { Switchboard } from '@superset-ui/switchboard';
|
||||
import { Switchboard } from "@superset-ui/switchboard";
|
||||
import {
|
||||
getGuestTokenRefreshTiming,
|
||||
DEFAULT_TOKEN_REFRESH_RETRY_MS,
|
||||
} from './guestTokenRefresh';
|
||||
import { withTimeout } from './withTimeout';
|
||||
} from "./guestTokenRefresh";
|
||||
import { withTimeout } from "./withTimeout";
|
||||
|
||||
/**
|
||||
* The function to fetch a guest token from your Host App's backend server.
|
||||
@@ -97,7 +97,7 @@ export type ObserveDataMaskCallbackFn = (
|
||||
nativeFiltersChanged: boolean;
|
||||
},
|
||||
) => void;
|
||||
export type ThemeMode = 'default' | 'dark' | 'system';
|
||||
export type ThemeMode = "default" | "dark" | "system";
|
||||
|
||||
/**
|
||||
* Callback to resolve permalink URLs.
|
||||
@@ -113,12 +113,12 @@ export type EmbeddedDashboard = {
|
||||
unmount: () => void;
|
||||
getDashboardPermalink: (anchor: string) => Promise<string>;
|
||||
getActiveTabs: () => Promise<string[]>;
|
||||
observeDataMask: (
|
||||
callbackFn: ObserveDataMaskCallbackFn,
|
||||
) => void;
|
||||
observeDataMask: (callbackFn: ObserveDataMaskCallbackFn) => void;
|
||||
getDataMask: () => Promise<Record<string, any>>;
|
||||
getChartStates: () => Promise<Record<string, any>>;
|
||||
getChartDataPayloads: (params?: { chartId?: number }) => Promise<Record<string, any>>;
|
||||
getChartDataPayloads: (params?: {
|
||||
chartId?: number;
|
||||
}) => Promise<Record<string, any>>;
|
||||
setThemeConfig: (themeConfig: Record<string, any>) => void;
|
||||
setThemeMode: (mode: ThemeMode) => void;
|
||||
};
|
||||
@@ -133,7 +133,7 @@ export async function embedDashboard({
|
||||
fetchGuestToken,
|
||||
dashboardUiConfig,
|
||||
debug = false,
|
||||
iframeTitle = 'Embedded Dashboard',
|
||||
iframeTitle = "Embedded Dashboard",
|
||||
iframeSandboxExtras = [],
|
||||
iframeAllowExtras = [],
|
||||
referrerPolicy,
|
||||
@@ -152,13 +152,13 @@ export async function embedDashboard({
|
||||
return withTimeout(
|
||||
fetchGuestToken(),
|
||||
guestTokenFetchTimeoutMs,
|
||||
'fetchGuestToken',
|
||||
"fetchGuestToken",
|
||||
);
|
||||
}
|
||||
|
||||
log('embedding');
|
||||
log("embedding");
|
||||
|
||||
if (supersetDomain.endsWith('/')) {
|
||||
if (supersetDomain.endsWith("/")) {
|
||||
supersetDomain = supersetDomain.slice(0, -1);
|
||||
}
|
||||
|
||||
@@ -185,15 +185,15 @@ export async function embedDashboard({
|
||||
}
|
||||
|
||||
async function mountIframe(): Promise<Switchboard> {
|
||||
return new Promise(resolve => {
|
||||
const iframe = document.createElement('iframe');
|
||||
return new Promise((resolve) => {
|
||||
const iframe = document.createElement("iframe");
|
||||
const dashboardConfigUrlParams = dashboardUiConfig
|
||||
? { uiConfig: `${calculateConfig()}` }
|
||||
: undefined;
|
||||
const filterConfig = dashboardUiConfig?.filters || {};
|
||||
const filterConfigKeys = Object.keys(filterConfig);
|
||||
const filterConfigUrlParams = Object.fromEntries(
|
||||
filterConfigKeys.map(key => [
|
||||
filterConfigKeys.map((key) => [
|
||||
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY[key],
|
||||
filterConfig[key],
|
||||
]),
|
||||
@@ -206,16 +206,16 @@ export async function embedDashboard({
|
||||
...dashboardUiConfig?.urlParams,
|
||||
};
|
||||
const urlParamsString = Object.keys(urlParams).length
|
||||
? '?' + new URLSearchParams(urlParams).toString()
|
||||
: '';
|
||||
? "?" + new URLSearchParams(urlParams).toString()
|
||||
: "";
|
||||
|
||||
// set up the iframe's sandbox configuration
|
||||
iframe.sandbox.add('allow-same-origin'); // needed for postMessage to work
|
||||
iframe.sandbox.add('allow-scripts'); // obviously the iframe needs scripts
|
||||
iframe.sandbox.add('allow-presentation'); // for fullscreen charts
|
||||
iframe.sandbox.add('allow-downloads'); // for downloading charts as image
|
||||
iframe.sandbox.add('allow-forms'); // for forms to submit
|
||||
iframe.sandbox.add('allow-popups'); // for exporting charts as csv
|
||||
iframe.sandbox.add("allow-same-origin"); // needed for postMessage to work
|
||||
iframe.sandbox.add("allow-scripts"); // obviously the iframe needs scripts
|
||||
iframe.sandbox.add("allow-presentation"); // for fullscreen charts
|
||||
iframe.sandbox.add("allow-downloads"); // for downloading charts as image
|
||||
iframe.sandbox.add("allow-forms"); // for forms to submit
|
||||
iframe.sandbox.add("allow-popups"); // for exporting charts as csv
|
||||
// additional sandbox props
|
||||
iframeSandboxExtras.forEach((key: string) => {
|
||||
iframe.sandbox.add(key);
|
||||
@@ -226,7 +226,7 @@ export async function embedDashboard({
|
||||
}
|
||||
|
||||
// add the event listener before setting src, to be 100% sure that we capture the load event
|
||||
iframe.addEventListener('load', () => {
|
||||
iframe.addEventListener("load", () => {
|
||||
// MessageChannel allows us to send and receive messages smoothly between our window and the iframe
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/API/Channel_Messaging_API
|
||||
const commsChannel = new MessageChannel();
|
||||
@@ -237,35 +237,35 @@ export async function embedDashboard({
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/API/Window/postMessage
|
||||
// we know the content window isn't null because we are in the load event handler.
|
||||
iframe.contentWindow!.postMessage(
|
||||
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: 'port transfer' },
|
||||
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: "port transfer" },
|
||||
supersetDomain,
|
||||
[theirPort],
|
||||
);
|
||||
log('sent message channel to the iframe');
|
||||
log("sent message channel to the iframe");
|
||||
|
||||
// return our port from the promise
|
||||
resolve(
|
||||
new Switchboard({
|
||||
port: ourPort,
|
||||
name: 'superset-embedded-sdk',
|
||||
name: "superset-embedded-sdk",
|
||||
debug,
|
||||
}),
|
||||
);
|
||||
});
|
||||
iframe.src = `${supersetDomain}/embedded/${id}${urlParamsString}`;
|
||||
iframe.title = iframeTitle;
|
||||
iframe.style.background = 'transparent';
|
||||
iframe.style.background = "transparent";
|
||||
// Permissions Policy features the embedded dashboard relies on. Modern
|
||||
// browsers gate these APIs on the iframe's `allow` attribute regardless
|
||||
// of sandbox flags, so we include them by default. Host apps can extend
|
||||
// the list via `iframeAllowExtras`.
|
||||
const allowFeatures = Array.from(
|
||||
new Set(['fullscreen', 'clipboard-write', ...iframeAllowExtras]),
|
||||
new Set(["fullscreen", "clipboard-write", ...iframeAllowExtras]),
|
||||
);
|
||||
iframe.setAttribute('allow', allowFeatures.join('; '));
|
||||
iframe.setAttribute("allow", allowFeatures.join("; "));
|
||||
//@ts-ignore
|
||||
mountPoint.replaceChildren(iframe);
|
||||
log('placed the iframe');
|
||||
log("placed the iframe");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -285,8 +285,8 @@ export async function embedDashboard({
|
||||
throw err;
|
||||
}
|
||||
|
||||
ourPort.emit('guestToken', { guestToken });
|
||||
log('sent guest token');
|
||||
ourPort.emit("guestToken", { guestToken });
|
||||
log("sent guest token");
|
||||
|
||||
// Track the pending refresh timer so it can be cancelled on unmount, and
|
||||
// stop the cycle once unmounted so it cannot leak across mount/unmount cycles.
|
||||
@@ -298,7 +298,7 @@ export async function embedDashboard({
|
||||
try {
|
||||
const newGuestToken = await fetchGuestTokenWithTimeout();
|
||||
if (unmounted) return;
|
||||
ourPort.emit('guestToken', { guestToken: newGuestToken });
|
||||
ourPort.emit("guestToken", { guestToken: newGuestToken });
|
||||
refreshTimer = setTimeout(
|
||||
refreshGuestToken,
|
||||
getGuestTokenRefreshTiming(newGuestToken),
|
||||
@@ -307,7 +307,7 @@ export async function embedDashboard({
|
||||
// A transient fetch failure or timeout must not permanently stop the
|
||||
// refresh cycle. Log it and retry so the session can recover once the
|
||||
// host callback succeeds again.
|
||||
log('failed to refresh guest token, will retry:', err);
|
||||
log("failed to refresh guest token, will retry:", err);
|
||||
if (unmounted) return;
|
||||
refreshTimer = setTimeout(
|
||||
refreshGuestToken,
|
||||
@@ -325,7 +325,7 @@ export async function embedDashboard({
|
||||
// Returns null if no callback provided or on error, allowing iframe to use default URL
|
||||
ourPort.start();
|
||||
ourPort.defineMethod(
|
||||
'resolvePermalinkUrl',
|
||||
"resolvePermalinkUrl",
|
||||
async ({ key }: { key: string }): Promise<string | null> => {
|
||||
if (!resolvePermalinkUrl) {
|
||||
return null;
|
||||
@@ -333,14 +333,14 @@ export async function embedDashboard({
|
||||
try {
|
||||
return await resolvePermalinkUrl({ key });
|
||||
} catch (error) {
|
||||
log('Error in resolvePermalinkUrl callback:', error);
|
||||
log("Error in resolvePermalinkUrl callback:", error);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
function unmount() {
|
||||
log('unmounting');
|
||||
log("unmounting");
|
||||
unmounted = true;
|
||||
if (refreshTimer !== undefined) {
|
||||
clearTimeout(refreshTimer);
|
||||
@@ -350,24 +350,25 @@ export async function embedDashboard({
|
||||
mountPoint.replaceChildren();
|
||||
}
|
||||
|
||||
const getScrollSize = () => ourPort.get<Size>('getScrollSize');
|
||||
const getScrollSize = () => ourPort.get<Size>("getScrollSize");
|
||||
const getDashboardPermalink = (anchor: string) =>
|
||||
ourPort.get<string>('getDashboardPermalink', { anchor });
|
||||
const getActiveTabs = () => ourPort.get<string[]>('getActiveTabs');
|
||||
const getDataMask = () => ourPort.get<Record<string, any>>('getDataMask');
|
||||
const getChartStates = () => ourPort.get<Record<string, any>>('getChartStates');
|
||||
ourPort.get<string>("getDashboardPermalink", { anchor });
|
||||
const getActiveTabs = () => ourPort.get<string[]>("getActiveTabs");
|
||||
const getDataMask = () => ourPort.get<Record<string, any>>("getDataMask");
|
||||
const getChartStates = () =>
|
||||
ourPort.get<Record<string, any>>("getChartStates");
|
||||
const getChartDataPayloads = (params?: { chartId?: number }) =>
|
||||
ourPort.get<Record<string, any>>('getChartDataPayloads', params);
|
||||
const observeDataMask = (
|
||||
callbackFn: ObserveDataMaskCallbackFn,
|
||||
) => {
|
||||
ourPort.defineMethod('observeDataMask', callbackFn);
|
||||
ourPort.get<Record<string, any>>("getChartDataPayloads", params);
|
||||
const observeDataMask = (callbackFn: ObserveDataMaskCallbackFn) => {
|
||||
ourPort.defineMethod("observeDataMask", callbackFn);
|
||||
};
|
||||
// TODO: Add proper types once theming branch is merged
|
||||
const setThemeConfig = async (themeConfig: Record<string, any>): Promise<void> => {
|
||||
const setThemeConfig = async (
|
||||
themeConfig: Record<string, any>,
|
||||
): Promise<void> => {
|
||||
try {
|
||||
ourPort.emit('setThemeConfig', { themeConfig });
|
||||
log('Theme config sent successfully (or at least message dispatched)');
|
||||
ourPort.emit("setThemeConfig", { themeConfig });
|
||||
log("Theme config sent successfully (or at least message dispatched)");
|
||||
} catch (error) {
|
||||
log(
|
||||
'Error sending theme config. Ensure the iframe side implements the "setThemeConfig" method.',
|
||||
@@ -378,7 +379,7 @@ export async function embedDashboard({
|
||||
|
||||
const setThemeMode = (mode: ThemeMode): void => {
|
||||
try {
|
||||
ourPort.emit('setThemeMode', { mode });
|
||||
ourPort.emit("setThemeMode", { mode });
|
||||
log(`Theme mode set to: ${mode}`);
|
||||
} catch (error) {
|
||||
log(
|
||||
|
||||
@@ -18,22 +18,23 @@
|
||||
*/
|
||||
|
||||
import { withTimeout } from "./withTimeout";
|
||||
import { test, expect } from "vitest";
|
||||
|
||||
test("resolves with the value when the promise settles in time", async () => {
|
||||
await expect(withTimeout(Promise.resolve("ok"), 1000, "fetch")).resolves.toBe(
|
||||
"ok"
|
||||
"ok",
|
||||
);
|
||||
});
|
||||
|
||||
test("rejects when the promise does not settle within the timeout", async () => {
|
||||
const never = new Promise<string>(() => {});
|
||||
await expect(withTimeout(never, 10, "fetch")).rejects.toThrow(
|
||||
/fetch did not resolve within 10ms/
|
||||
/fetch did not resolve within 10ms/,
|
||||
);
|
||||
});
|
||||
|
||||
test("passes the promise through unchanged when the timeout is disabled", async () => {
|
||||
await expect(withTimeout(Promise.resolve("ok"), 0, "fetch")).resolves.toBe(
|
||||
"ok"
|
||||
"ok",
|
||||
);
|
||||
});
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// syntax rules
|
||||
"strict": true,
|
||||
|
||||
"moduleResolution": "node",
|
||||
"moduleResolution": "bundler",
|
||||
|
||||
// environment
|
||||
"target": "es6",
|
||||
@@ -13,7 +13,9 @@
|
||||
// output
|
||||
"outDir": "./dist",
|
||||
"emitDeclarationOnly": true,
|
||||
"declaration": true
|
||||
"declaration": true,
|
||||
|
||||
"types": ["node"]
|
||||
},
|
||||
|
||||
"include": [
|
||||
@@ -21,7 +23,6 @@
|
||||
],
|
||||
|
||||
"exclude": [
|
||||
"tests",
|
||||
"dist",
|
||||
"lib",
|
||||
"node_modules"
|
||||
|
||||
@@ -17,19 +17,19 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
const path = require('path');
|
||||
const path = require("path");
|
||||
|
||||
module.exports = {
|
||||
entry: './src/index.ts',
|
||||
entry: "./src/index.ts",
|
||||
output: {
|
||||
filename: 'index.js',
|
||||
path: path.resolve(__dirname, 'bundle'),
|
||||
filename: "index.js",
|
||||
path: path.resolve(__dirname, "bundle"),
|
||||
|
||||
// this exposes the library's exports under a global variable
|
||||
library: {
|
||||
name: "supersetEmbeddedSdk",
|
||||
type: "umd"
|
||||
}
|
||||
type: "umd",
|
||||
},
|
||||
},
|
||||
devtool: "source-map",
|
||||
module: {
|
||||
@@ -38,12 +38,12 @@ module.exports = {
|
||||
test: /\.[tj]s$/,
|
||||
// babel-loader is faster than ts-loader because it ignores types.
|
||||
// We do type checking in a separate process, so that's fine.
|
||||
use: 'babel-loader',
|
||||
use: "babel-loader",
|
||||
exclude: /node_modules/,
|
||||
},
|
||||
],
|
||||
},
|
||||
resolve: {
|
||||
extensions: ['.ts', '.js'],
|
||||
extensions: [".ts", ".js"],
|
||||
},
|
||||
};
|
||||
|
||||
@@ -45,6 +45,7 @@ export const IconTooltip = forwardRef<HTMLElement, IconTooltipProps>(
|
||||
}}
|
||||
buttonStyle="link"
|
||||
className={`IconTooltip ${className}`}
|
||||
aria-label={tooltip ?? undefined}
|
||||
>
|
||||
{children}
|
||||
</Button>
|
||||
|
||||
@@ -344,6 +344,16 @@ export default function transformProps(
|
||||
data2,
|
||||
currencyCodeColumn,
|
||||
);
|
||||
const getAxisFormatterConfig = (axisIndex?: number) =>
|
||||
axisIndex === 1
|
||||
? {
|
||||
customFormatters: customFormattersSecondary,
|
||||
formatter: formatterSecondary,
|
||||
}
|
||||
: {
|
||||
customFormatters,
|
||||
formatter,
|
||||
};
|
||||
|
||||
const primarySeries = new Set<string>();
|
||||
const secondarySeries = new Set<string>();
|
||||
@@ -422,6 +432,8 @@ export default function transformProps(
|
||||
let [minSecondary, maxSecondary] = (yAxisBoundsSecondary || []).map(
|
||||
parseAxisBound,
|
||||
);
|
||||
const getAxisMax = (axisIndex?: number) =>
|
||||
axisIndex === 1 ? maxSecondary : yAxisMax;
|
||||
|
||||
const array = ensureIsArray(chartProps.rawFormData?.time_compare);
|
||||
const inverted = invert(verboseMap);
|
||||
@@ -445,10 +457,11 @@ export default function transformProps(
|
||||
// When no groupby, format as just the entry name with optional query identifier
|
||||
displayName = showQueryIdentifiers ? `${entryName} (Query A)` : entryName;
|
||||
}
|
||||
const axisFormatterConfig = getAxisFormatterConfig(yAxisIndex);
|
||||
|
||||
const seriesFormatter = getFormatter(
|
||||
customFormatters,
|
||||
formatter,
|
||||
axisFormatterConfig.customFormatters,
|
||||
axisFormatterConfig.formatter,
|
||||
metrics,
|
||||
labelMap?.[seriesName]?.[0],
|
||||
!!contributionMode,
|
||||
@@ -480,7 +493,7 @@ export default function transformProps(
|
||||
formatter:
|
||||
seriesType === EchartsTimeseriesSeriesType.Bar
|
||||
? getOverMaxHiddenFormatter({
|
||||
max: yAxisMax,
|
||||
max: getAxisMax(yAxisIndex),
|
||||
formatter: seriesFormatter,
|
||||
})
|
||||
: seriesFormatter,
|
||||
@@ -518,10 +531,11 @@ export default function transformProps(
|
||||
// When no groupby, format as just the entry name with optional query identifier
|
||||
displayName = showQueryIdentifiers ? `${entryName} (Query B)` : entryName;
|
||||
}
|
||||
const axisFormatterConfig = getAxisFormatterConfig(yAxisIndexB);
|
||||
|
||||
const seriesFormatter = getFormatter(
|
||||
customFormattersSecondary,
|
||||
formatterSecondary,
|
||||
axisFormatterConfig.customFormatters,
|
||||
axisFormatterConfig.formatter,
|
||||
metricsB,
|
||||
labelMapB?.[seriesName]?.[0],
|
||||
!!contributionMode,
|
||||
@@ -554,7 +568,7 @@ export default function transformProps(
|
||||
formatter:
|
||||
seriesTypeB === EchartsTimeseriesSeriesType.Bar
|
||||
? getOverMaxHiddenFormatter({
|
||||
max: maxSecondary,
|
||||
max: getAxisMax(yAxisIndexB),
|
||||
formatter: seriesFormatter,
|
||||
})
|
||||
: seriesFormatter,
|
||||
|
||||
@@ -35,13 +35,26 @@ import {
|
||||
} from '../../src';
|
||||
import transformProps from '../../src/MixedTimeseries/transformProps';
|
||||
import {
|
||||
DEFAULT_FORM_DATA,
|
||||
EchartsMixedTimeseriesFormData,
|
||||
EchartsMixedTimeseriesProps,
|
||||
} from '../../src/MixedTimeseries/types';
|
||||
import { DEFAULT_FORM_DATA } from '../../src/MixedTimeseries/types';
|
||||
import { createEchartsTimeseriesTestChartProps } from '../helpers';
|
||||
import type { SeriesOption } from 'echarts';
|
||||
|
||||
type LabelFormatterParams = {
|
||||
value: [number, number];
|
||||
dataIndex: number;
|
||||
seriesIndex: number;
|
||||
seriesName: string;
|
||||
};
|
||||
|
||||
type SeriesWithLabelFormatter = SeriesOption & {
|
||||
label?: {
|
||||
formatter?: (params: LabelFormatterParams) => string | number;
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a partial ChartDataResponseResult for testing.
|
||||
* Only includes the fields needed for tests, with sensible defaults for required fields.
|
||||
@@ -148,6 +161,30 @@ const queriesData: ChartDataResponseResult[] = [
|
||||
createTestQueryData(defaultQueryRows, { label_map: defaultLabelMap }),
|
||||
];
|
||||
|
||||
function getSeriesWithLabelFormatter(
|
||||
series: SeriesOption[],
|
||||
name: string,
|
||||
): SeriesWithLabelFormatter {
|
||||
const result = series.find(seriesOption => seriesOption.name === name);
|
||||
expect(result).toBeDefined();
|
||||
expect((result as SeriesWithLabelFormatter).label?.formatter).toBeDefined();
|
||||
return result as SeriesWithLabelFormatter;
|
||||
}
|
||||
|
||||
function formatSeriesLabel(
|
||||
series: SeriesWithLabelFormatter,
|
||||
value: [number, number],
|
||||
) {
|
||||
const formatter = series.label?.formatter;
|
||||
expect(formatter).toBeDefined();
|
||||
return formatter?.({
|
||||
dataIndex: 0,
|
||||
seriesIndex: 0,
|
||||
seriesName: String(series.name),
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
test('should transform chart props for viz with showQueryIdentifiers=false', () => {
|
||||
const chartProps = createEchartsTimeseriesTestChartProps<
|
||||
EchartsMixedTimeseriesFormData,
|
||||
@@ -232,6 +269,162 @@ test('should transform chart props for viz with showQueryIdentifiers=true', () =
|
||||
]);
|
||||
});
|
||||
|
||||
test('formats value labels with the formatter for the assigned y-axis', () => {
|
||||
const timestamp = 1704067200000;
|
||||
const queryAData = createTestQueryData(
|
||||
[{ __timestamp: timestamp, lineMetric: 0.25 }],
|
||||
{
|
||||
colnames: ['__timestamp', 'lineMetric'],
|
||||
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
|
||||
label_map: { lineMetric: ['lineMetric'] },
|
||||
},
|
||||
);
|
||||
const queryBData = createTestQueryData(
|
||||
[{ __timestamp: timestamp, barMetric: 0.5 }],
|
||||
{
|
||||
colnames: ['__timestamp', 'barMetric'],
|
||||
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
|
||||
label_map: { 'barMetric (1)': ['barMetric'] },
|
||||
},
|
||||
);
|
||||
const chartProps = createEchartsTimeseriesTestChartProps<
|
||||
EchartsMixedTimeseriesFormData,
|
||||
EchartsMixedTimeseriesProps
|
||||
>({
|
||||
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
|
||||
defaultQueriesData: [queryAData, queryBData],
|
||||
formData: {
|
||||
...formData,
|
||||
groupby: [],
|
||||
groupbyB: [],
|
||||
metrics: ['lineMetric'],
|
||||
metricsB: ['barMetric'],
|
||||
showValue: true,
|
||||
showValueB: true,
|
||||
stack: null,
|
||||
stackB: null,
|
||||
x_axis: '__timestamp',
|
||||
yAxisFormat: '.0%',
|
||||
yAxisFormatSecondary: ',.1f',
|
||||
yAxisIndex: 1,
|
||||
yAxisIndexB: 0,
|
||||
},
|
||||
queriesData: [queryAData, queryBData],
|
||||
});
|
||||
|
||||
const { echartOptions } = transformProps(chartProps);
|
||||
const series = echartOptions.series as SeriesOption[];
|
||||
const lineSeries = getSeriesWithLabelFormatter(series, 'lineMetric');
|
||||
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
|
||||
|
||||
expect(formatSeriesLabel(lineSeries, [timestamp, 0.25])).toBe('0.3');
|
||||
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('50%');
|
||||
});
|
||||
|
||||
test('formats value labels correctly when y-axis assignments are reversed', () => {
|
||||
const timestamp = 1704067200000;
|
||||
const queryAData = createTestQueryData(
|
||||
[{ __timestamp: timestamp, lineMetric: 0.25 }],
|
||||
{
|
||||
colnames: ['__timestamp', 'lineMetric'],
|
||||
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
|
||||
label_map: { lineMetric: ['lineMetric'] },
|
||||
},
|
||||
);
|
||||
const queryBData = createTestQueryData(
|
||||
[{ __timestamp: timestamp, barMetric: 0.5 }],
|
||||
{
|
||||
colnames: ['__timestamp', 'barMetric'],
|
||||
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
|
||||
label_map: { 'barMetric (1)': ['barMetric'] },
|
||||
},
|
||||
);
|
||||
const chartProps = createEchartsTimeseriesTestChartProps<
|
||||
EchartsMixedTimeseriesFormData,
|
||||
EchartsMixedTimeseriesProps
|
||||
>({
|
||||
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
|
||||
defaultQueriesData: [queryAData, queryBData],
|
||||
formData: {
|
||||
...formData,
|
||||
groupby: [],
|
||||
groupbyB: [],
|
||||
metrics: ['lineMetric'],
|
||||
metricsB: ['barMetric'],
|
||||
showValue: true,
|
||||
showValueB: true,
|
||||
stack: null,
|
||||
stackB: null,
|
||||
x_axis: '__timestamp',
|
||||
yAxisFormat: '.0%',
|
||||
yAxisFormatSecondary: ',.1f',
|
||||
yAxisIndex: 0,
|
||||
yAxisIndexB: 1,
|
||||
},
|
||||
queriesData: [queryAData, queryBData],
|
||||
});
|
||||
|
||||
const { echartOptions } = transformProps(chartProps);
|
||||
const series = echartOptions.series as SeriesOption[];
|
||||
const lineSeries = getSeriesWithLabelFormatter(series, 'lineMetric');
|
||||
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
|
||||
|
||||
expect(formatSeriesLabel(lineSeries, [timestamp, 0.25])).toBe('25%');
|
||||
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('0.5');
|
||||
});
|
||||
|
||||
test('keeps bar value label clipping aligned with the assigned y-axis', () => {
|
||||
const timestamp = 1704067200000;
|
||||
const queryAData = createTestQueryData(
|
||||
[{ __timestamp: timestamp, lineMetric: 0.25 }],
|
||||
{
|
||||
colnames: ['__timestamp', 'lineMetric'],
|
||||
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
|
||||
label_map: { lineMetric: ['lineMetric'] },
|
||||
},
|
||||
);
|
||||
const queryBData = createTestQueryData(
|
||||
[{ __timestamp: timestamp, barMetric: 0.5 }],
|
||||
{
|
||||
colnames: ['__timestamp', 'barMetric'],
|
||||
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
|
||||
label_map: { 'barMetric (1)': ['barMetric'] },
|
||||
},
|
||||
);
|
||||
const chartProps = createEchartsTimeseriesTestChartProps<
|
||||
EchartsMixedTimeseriesFormData,
|
||||
EchartsMixedTimeseriesProps
|
||||
>({
|
||||
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
|
||||
defaultQueriesData: [queryAData, queryBData],
|
||||
formData: {
|
||||
...formData,
|
||||
groupby: [],
|
||||
groupbyB: [],
|
||||
metrics: ['lineMetric'],
|
||||
metricsB: ['barMetric'],
|
||||
showValue: true,
|
||||
showValueB: true,
|
||||
stack: null,
|
||||
stackB: null,
|
||||
x_axis: '__timestamp',
|
||||
yAxisBounds: [undefined, 1],
|
||||
yAxisBoundsSecondary: [undefined, 0.1],
|
||||
yAxisFormat: '.0%',
|
||||
yAxisFormatSecondary: ',.1f',
|
||||
yAxisIndex: 0,
|
||||
yAxisIndexB: 1,
|
||||
},
|
||||
queriesData: [queryAData, queryBData],
|
||||
});
|
||||
|
||||
const { echartOptions } = transformProps(chartProps);
|
||||
const series = echartOptions.series as SeriesOption[];
|
||||
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
|
||||
|
||||
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('');
|
||||
});
|
||||
|
||||
describe('legend sorting', () => {
|
||||
const getChartProps = (overrides = {}) =>
|
||||
createEchartsTimeseriesTestChartProps<
|
||||
|
||||
@@ -44,13 +44,13 @@ const fakeTableApiResult = {
|
||||
result: [
|
||||
{
|
||||
id: 1,
|
||||
value: 'fake api result1',
|
||||
value: 'fake_api_result1',
|
||||
label: 'fake api label1',
|
||||
type: 'table',
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
value: 'fake api result2',
|
||||
value: 'fake_api_result2',
|
||||
label: 'fake api label2',
|
||||
type: 'table',
|
||||
},
|
||||
@@ -152,6 +152,64 @@ test('returns keywords including fetched function_names data', async () => {
|
||||
});
|
||||
});
|
||||
|
||||
test('quotes table identifiers that require quoting in the inserted value', async () => {
|
||||
const dbFunctionNamesApiRoute = `glob:*/api/v1/database/${expectDbId}/function_names/`;
|
||||
fetchMock.get(dbFunctionNamesApiRoute, fakeFunctionNamesApiResult);
|
||||
|
||||
act(() => {
|
||||
store.dispatch(
|
||||
tableApiUtil.upsertQueryData(
|
||||
'tables',
|
||||
{ dbId: expectDbId, schema: expectSchema },
|
||||
{
|
||||
options: [
|
||||
{ value: 'COVID Vaccines', label: 'COVID Vaccines', type: 'table' },
|
||||
{ value: 'simple_table', label: 'simple_table', type: 'table' },
|
||||
],
|
||||
hasMore: false,
|
||||
},
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
const { result } = renderHook(
|
||||
() =>
|
||||
useKeywords({
|
||||
queryEditorId: 'testqueryid',
|
||||
dbId: expectDbId,
|
||||
schema: expectSchema,
|
||||
}),
|
||||
{
|
||||
wrapper: createWrapper({
|
||||
useRedux: true,
|
||||
store,
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
await waitFor(() =>
|
||||
expect(fetchMock.callHistory.calls(dbFunctionNamesApiRoute).length).toBe(1),
|
||||
);
|
||||
|
||||
// A name that needs quoting is inserted as a double-quoted identifier,
|
||||
// while its display name stays human-readable.
|
||||
expect(result.current).toContainEqual(
|
||||
expect.objectContaining({
|
||||
name: 'COVID Vaccines',
|
||||
value: '"COVID Vaccines"',
|
||||
meta: 'table',
|
||||
}),
|
||||
);
|
||||
// A simple identifier is inserted as-is, without quotes.
|
||||
expect(result.current).toContainEqual(
|
||||
expect.objectContaining({
|
||||
name: 'simple_table',
|
||||
value: 'simple_table',
|
||||
meta: 'table',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('skip fetching if autocomplete skipped', () => {
|
||||
const { result } = renderHook(
|
||||
() =>
|
||||
|
||||
@@ -53,6 +53,14 @@ const getHelperText = (value: string) =>
|
||||
detail: value,
|
||||
};
|
||||
|
||||
// Names that aren't simple identifiers (spaces, punctuation, leading digits)
|
||||
// must be double-quoted to be valid SQL, with embedded quotes doubled.
|
||||
const SIMPLE_IDENTIFIER_RE = /^[A-Za-z_][A-Za-z0-9_]*$/;
|
||||
const quoteIdentifier = (identifier: string) =>
|
||||
SIMPLE_IDENTIFIER_RE.test(identifier)
|
||||
? identifier
|
||||
: `"${identifier.replace(/"/g, '""')}"`;
|
||||
|
||||
const extensionsRegistry = getExtensionsRegistry();
|
||||
|
||||
export function useKeywords(
|
||||
@@ -197,7 +205,7 @@ export function useKeywords(
|
||||
() =>
|
||||
allCachedTables.map(({ value, label, schema: tableSchema }) => ({
|
||||
name: label,
|
||||
value,
|
||||
value: quoteIdentifier(value),
|
||||
schema: tableSchema,
|
||||
score: TABLE_AUTOCOMPLETE_SCORE,
|
||||
meta: 'table',
|
||||
|
||||
@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Superset WebSocket Server
|
||||
|
||||
A Node.js WebSocket server for sending async event data to the Superset web application frontend.
|
||||
@@ -164,4 +165,4 @@ HEAD /health
|
||||
|
||||
## Containerization
|
||||
|
||||
*TODO: containerize websocket server*
|
||||
_TODO: containerize websocket server_
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
module.exports = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
};
|
||||
10617
superset-websocket/package-lock.json
generated
10617
superset-websocket/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,7 @@
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"start": "node dist/index.js start",
|
||||
"test": "NODE_ENV=test jest -i spec",
|
||||
"test": "npx vitest --run --dir spec",
|
||||
"type": "tsc --noEmit",
|
||||
"eslint": "eslint",
|
||||
"lint": "npm run eslint -- . && npm run type",
|
||||
@@ -28,24 +28,22 @@
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.25.1",
|
||||
"@types/eslint__js": "^8.42.3",
|
||||
"@types/jest": "^29.5.14",
|
||||
"@types/jsonwebtoken": "^9.0.10",
|
||||
"@types/lodash": "^4.17.24",
|
||||
"@types/node": "^25.9.3",
|
||||
"@types/ws": "^8.18.1",
|
||||
"@typescript-eslint/eslint-plugin": "^8.61.1",
|
||||
"@typescript-eslint/parser": "^8.61.1",
|
||||
"@typescript-eslint/eslint-plugin": "^8.62.0",
|
||||
"@typescript-eslint/parser": "^8.62.0",
|
||||
"eslint": "^10.5.0",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-lodash": "^8.0.0",
|
||||
"globals": "^17.6.0",
|
||||
"jest": "^29.7.0",
|
||||
"prettier": "^3.8.4",
|
||||
"ts-jest": "^29.4.11",
|
||||
"ts-node": "^10.9.2",
|
||||
"tscw-config": "^1.1.2",
|
||||
"typescript": "^6.0.3",
|
||||
"typescript-eslint": "^8.61.1"
|
||||
"typescript-eslint": "^8.62.0",
|
||||
"vitest": "^4.1.5"
|
||||
},
|
||||
"engines": {
|
||||
"node": "^24.16.0",
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
* under the License.
|
||||
*/
|
||||
import { buildConfig } from '../src/config';
|
||||
import { expect, test } from '@jest/globals';
|
||||
import { expect, test } from 'vitest';
|
||||
|
||||
test('buildConfig() builds configuration and applies env var overrides', () => {
|
||||
let config = buildConfig();
|
||||
|
||||
@@ -25,29 +25,29 @@ import {
|
||||
test,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
jest,
|
||||
} from '@jest/globals';
|
||||
vi,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import { WebSocket } from 'ws';
|
||||
import * as server from '../src/index';
|
||||
import { statsd } from '../src/index';
|
||||
|
||||
interface MockedRedisXrange {
|
||||
(): Promise<server.StreamResult[]>;
|
||||
}
|
||||
|
||||
// NOTE: these mock variables needs to start with "mock" due to
|
||||
// calls to `jest.mock` being hoisted to the top of the file.
|
||||
// https://jestjs.io/docs/es6-class-mocks#calling-jestmock-with-the-module-factory-parameter
|
||||
const mockRedisXrange = jest.fn() as jest.MockedFunction<MockedRedisXrange>;
|
||||
|
||||
jest.mock('ws');
|
||||
jest.mock('ioredis', () => {
|
||||
return jest.fn().mockImplementation(() => {
|
||||
return { xrange: mockRedisXrange, on: jest.fn() };
|
||||
});
|
||||
const { mockRedisXrange } = vi.hoisted(() => {
|
||||
return { mockRedisXrange: vi.fn() };
|
||||
});
|
||||
|
||||
const wsMock = WebSocket as jest.Mocked<typeof WebSocket>;
|
||||
vi.mock('ws');
|
||||
vi.mock('ioredis', () => {
|
||||
return {
|
||||
Redis: vi.fn().mockImplementation(function () {
|
||||
return { xrange: mockRedisXrange, on: vi.fn() };
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const wsMock = WebSocket as unknown as Mock<typeof WebSocket>;
|
||||
const channelId = 'bc9e040c-7b4a-4817-99b9-292832d97ec7';
|
||||
const streamReturnValue: server.StreamResult[] = [
|
||||
[
|
||||
@@ -66,16 +66,13 @@ const streamReturnValue: server.StreamResult[] = [
|
||||
],
|
||||
];
|
||||
|
||||
import * as server from '../src/index';
|
||||
import { statsd } from '../src/index';
|
||||
|
||||
describe('server', () => {
|
||||
let statsdIncrementMock: jest.SpiedFunction<typeof statsd.increment>;
|
||||
let statsdIncrementMock: Mock<typeof statsd.increment>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockRedisXrange.mockClear();
|
||||
server.resetState();
|
||||
statsdIncrementMock = jest.spyOn(statsd, 'increment').mockReturnValue();
|
||||
statsdIncrementMock = vi.spyOn(statsd, 'increment').mockReturnValue();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -84,8 +81,8 @@ describe('server', () => {
|
||||
|
||||
describe('HTTP requests', () => {
|
||||
test('services health checks', () => {
|
||||
const endMock = jest.fn();
|
||||
const writeHeadMock = jest.fn();
|
||||
const endMock = vi.fn();
|
||||
const writeHeadMock = vi.fn();
|
||||
|
||||
const request = {
|
||||
url: '/health',
|
||||
@@ -113,8 +110,8 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('responds with a 404 when not found', () => {
|
||||
const endMock = jest.fn();
|
||||
const writeHeadMock = jest.fn();
|
||||
const endMock = vi.fn();
|
||||
const writeHeadMock = vi.fn();
|
||||
|
||||
const request = {
|
||||
url: '/unsupported',
|
||||
@@ -245,7 +242,7 @@ describe('server', () => {
|
||||
describe('processStreamResults', () => {
|
||||
test('sends data to channel', async () => {
|
||||
const ws = new wsMock('localhost');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
|
||||
|
||||
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
|
||||
@@ -267,7 +264,7 @@ describe('server', () => {
|
||||
|
||||
test('channel not present', async () => {
|
||||
const ws = new wsMock('localhost');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
|
||||
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
|
||||
server.processStreamResults(streamReturnValue);
|
||||
@@ -278,10 +275,9 @@ describe('server', () => {
|
||||
|
||||
test('error sending data to client', async () => {
|
||||
const ws = new wsMock('localhost');
|
||||
const sendMock = jest.spyOn(ws, 'send').mockImplementation(() => {
|
||||
const sendMock = vi.spyOn(ws, 'send').mockImplementation(() => {
|
||||
throw new Error();
|
||||
});
|
||||
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
|
||||
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
|
||||
|
||||
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
|
||||
@@ -300,9 +296,7 @@ describe('server', () => {
|
||||
);
|
||||
|
||||
expect(sendMock).toHaveBeenCalled();
|
||||
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
|
||||
|
||||
cleanChannelMock.mockRestore();
|
||||
expect(Object.keys(server.channels)).toHaveLength(0);
|
||||
});
|
||||
|
||||
const makeItem = (i: number): server.StreamResult =>
|
||||
@@ -330,8 +324,8 @@ describe('server', () => {
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const setImmediateSpy = jest.spyOn(global, 'setImmediate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
const setImmediateSpy = vi.spyOn(global, 'setImmediate');
|
||||
|
||||
const results = [0, 1, 2, 3, 4].map(makeItem);
|
||||
await server.processStreamResults(results);
|
||||
@@ -351,7 +345,7 @@ describe('server', () => {
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
|
||||
const results = [0, 1, 2, 3, 4].map(makeItem);
|
||||
await server.processStreamResults(results);
|
||||
@@ -372,16 +366,16 @@ describe('server', () => {
|
||||
server.opts.maxSocketBufferBytes = 0;
|
||||
// Restore any spies (e.g. on server.cleanChannel) so they don't leak
|
||||
// across tests and cause order-dependent failures.
|
||||
jest.restoreAllMocks();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
test('does not terminate when cap disabled (0)', () => {
|
||||
server.opts.maxSocketBufferBytes = 0;
|
||||
const ws = new wsMock('localhost');
|
||||
// simulate a large outbound buffer
|
||||
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 10_000_000;
|
||||
const terminateMock = jest.spyOn(ws, 'terminate');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(10_000_000);
|
||||
const terminateMock = vi.spyOn(ws, 'terminate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -397,10 +391,9 @@ describe('server', () => {
|
||||
test('terminates a slow client whose buffer exceeds the cap', () => {
|
||||
server.opts.maxSocketBufferBytes = 1024;
|
||||
const ws = new wsMock('localhost');
|
||||
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 2048;
|
||||
const terminateMock = jest.spyOn(ws, 'terminate');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
|
||||
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(2048);
|
||||
const terminateMock = vi.spyOn(ws, 'terminate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -414,15 +407,15 @@ describe('server', () => {
|
||||
expect(statsdIncrementMock).toHaveBeenCalledWith(
|
||||
'ws_client_backpressure_disconnect',
|
||||
);
|
||||
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
|
||||
expect(Object.keys(server.channels)).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('keeps sending to a client within the cap', () => {
|
||||
server.opts.maxSocketBufferBytes = 1024;
|
||||
const ws = new wsMock('localhost');
|
||||
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 16;
|
||||
const terminateMock = jest.spyOn(ws, 'terminate');
|
||||
const sendMock = jest.spyOn(ws, 'send');
|
||||
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(16);
|
||||
const terminateMock = vi.spyOn(ws, 'terminate');
|
||||
const sendMock = vi.spyOn(ws, 'send');
|
||||
server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -443,7 +436,7 @@ describe('server', () => {
|
||||
|
||||
test('success with results', async () => {
|
||||
mockRedisXrange.mockResolvedValueOnce(streamReturnValue);
|
||||
const cb = jest.fn() as jest.MockedFunction<
|
||||
const cb = vi.fn() as Mock<
|
||||
(results: server.StreamResult[]) => void | Promise<void>
|
||||
>;
|
||||
await server.fetchRangeFromStream({
|
||||
@@ -462,7 +455,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('success no results', async () => {
|
||||
const cb = jest.fn() as jest.MockedFunction<
|
||||
const cb = vi.fn() as Mock<
|
||||
(results: server.StreamResult[]) => void | Promise<void>
|
||||
>;
|
||||
await server.fetchRangeFromStream({
|
||||
@@ -481,7 +474,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('error', async () => {
|
||||
const cb = jest.fn() as jest.MockedFunction<
|
||||
const cb = vi.fn() as Mock<
|
||||
(results: server.StreamResult[]) => void | Promise<void>
|
||||
>;
|
||||
mockRedisXrange.mockRejectedValueOnce(new Error());
|
||||
@@ -503,12 +496,8 @@ describe('server', () => {
|
||||
|
||||
describe('wsConnection', () => {
|
||||
let ws: WebSocket;
|
||||
let wsEventMock: jest.SpiedFunction<typeof ws.on>;
|
||||
let trackClientSpy: jest.SpiedFunction<typeof server.trackClient>;
|
||||
let fetchRangeFromStreamSpy: jest.SpiedFunction<
|
||||
typeof server.fetchRangeFromStream
|
||||
>;
|
||||
let dateNowSpy: jest.SpiedFunction<typeof Date.now>;
|
||||
let wsEventMock: Mock<typeof ws.on>;
|
||||
let dateNowSpy: Mock<typeof Date.now>;
|
||||
let socketInstanceExpected: server.SocketInstance;
|
||||
|
||||
const getRequest = (token: string, url: string): http.IncomingMessage => {
|
||||
@@ -521,10 +510,8 @@ describe('server', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
ws = new wsMock('localhost');
|
||||
wsEventMock = jest.spyOn(ws, 'on');
|
||||
trackClientSpy = jest.spyOn(server, 'trackClient');
|
||||
fetchRangeFromStreamSpy = jest.spyOn(server, 'fetchRangeFromStream');
|
||||
dateNowSpy = jest
|
||||
wsEventMock = vi.spyOn(ws, 'on');
|
||||
dateNowSpy = vi
|
||||
.spyOn(global.Date, 'now')
|
||||
.mockImplementation(() =>
|
||||
new Date('2021-03-10T11:01:58.135Z').valueOf(),
|
||||
@@ -537,10 +524,8 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
wsEventMock.mockRestore();
|
||||
trackClientSpy.mockRestore();
|
||||
fetchRangeFromStreamSpy.mockRestore();
|
||||
dateNowSpy.mockRestore();
|
||||
wsEventMock?.mockRestore();
|
||||
dateNowSpy?.mockRestore();
|
||||
});
|
||||
|
||||
test('invalid JWT', async () => {
|
||||
@@ -558,11 +543,14 @@ describe('server', () => {
|
||||
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
expect(mockRedisXrange).not.toHaveBeenCalled();
|
||||
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
|
||||
});
|
||||
|
||||
@@ -575,12 +563,15 @@ describe('server', () => {
|
||||
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
|
||||
// Malformed last_id must not trigger a stream range fetch.
|
||||
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
});
|
||||
|
||||
test('valid JWT, with lastId', async () => {
|
||||
@@ -593,16 +584,18 @@ describe('server', () => {
|
||||
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
|
||||
sessionId: channelId,
|
||||
startId: '1615426152415-1',
|
||||
endId: '+',
|
||||
listener: server.processStreamResults,
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
expect(mockRedisXrange).toHaveBeenCalledWith(
|
||||
expect.stringContaining(channelId),
|
||||
'1615426152415-1',
|
||||
'+',
|
||||
);
|
||||
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
|
||||
});
|
||||
|
||||
@@ -618,16 +611,18 @@ describe('server', () => {
|
||||
server.setLastFirehoseId(lastFirehoseId);
|
||||
server.wsConnection(ws, request);
|
||||
|
||||
expect(trackClientSpy).toHaveBeenCalledWith(
|
||||
channelId,
|
||||
socketInstanceExpected,
|
||||
);
|
||||
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
|
||||
sessionId: channelId,
|
||||
startId: '1615426152415-1',
|
||||
endId: lastFirehoseId,
|
||||
listener: server.processStreamResults,
|
||||
const channelSockets = server.channels[channelId];
|
||||
expect(channelSockets).toEqual({
|
||||
sockets: expect.any(Array<string>),
|
||||
});
|
||||
expect(channelSockets.sockets).toHaveLength(1);
|
||||
const socketId = channelSockets.sockets[0];
|
||||
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
|
||||
expect(mockRedisXrange).toHaveBeenCalledWith(
|
||||
expect.stringContaining(channelId),
|
||||
'1615426152415-1',
|
||||
lastFirehoseId,
|
||||
);
|
||||
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
|
||||
});
|
||||
});
|
||||
@@ -662,7 +657,7 @@ describe('server', () => {
|
||||
test('total connection limit reached', () => {
|
||||
server.opts.maxTotalConnections = 1;
|
||||
const ws = new wsMock('localhost');
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const socketInstance = {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -677,7 +672,7 @@ describe('server', () => {
|
||||
test('per-channel connection limit reached', () => {
|
||||
server.opts.maxConnectionsPerChannel = 1;
|
||||
const ws = new wsMock('localhost');
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const socketInstance = {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -699,7 +694,7 @@ describe('server', () => {
|
||||
};
|
||||
server.trackClient(channelId, socketInstance);
|
||||
// simulate the socket having closed but not yet been GC'd
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
expect(server.connectionLimitReason('some-other-channel')).toBeNull();
|
||||
});
|
||||
|
||||
@@ -713,13 +708,13 @@ describe('server', () => {
|
||||
};
|
||||
server.trackClient(channelId, socketInstance);
|
||||
// simulate the socket having closed but not yet been GC'd
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
expect(server.connectionLimitReason(channelId)).toBeNull();
|
||||
});
|
||||
|
||||
test('isSocketActive reflects the socket readyState', () => {
|
||||
const ws = new wsMock('localhost');
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const socketId = server.trackClient(channelId, {
|
||||
ws,
|
||||
channel: channelId,
|
||||
@@ -727,9 +722,9 @@ describe('server', () => {
|
||||
});
|
||||
expect(server.isSocketActive(socketId)).toBe(true);
|
||||
// CONNECTING is also considered active (see SOCKET_ACTIVE_STATES)
|
||||
setReadyState(ws, WebSocket.CONNECTING);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CONNECTING);
|
||||
expect(server.isSocketActive(socketId)).toBe(true);
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
expect(server.isSocketActive(socketId)).toBe(false);
|
||||
// unknown socket ids are never active
|
||||
expect(server.isSocketActive('does-not-exist')).toBe(false);
|
||||
@@ -737,14 +732,14 @@ describe('server', () => {
|
||||
|
||||
test('activeSocketCount counts only active sockets', () => {
|
||||
const openWs = new wsMock('localhost');
|
||||
setReadyState(openWs, WebSocket.OPEN);
|
||||
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, {
|
||||
ws: openWs,
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const closedWs = new wsMock('localhost');
|
||||
setReadyState(closedWs, WebSocket.CLOSED);
|
||||
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
server.trackClient(channelId, {
|
||||
ws: closedWs,
|
||||
channel: channelId,
|
||||
@@ -755,14 +750,14 @@ describe('server', () => {
|
||||
|
||||
test('activeChannelSocketCount counts only active sockets on the channel', () => {
|
||||
const openWs = new wsMock('localhost');
|
||||
setReadyState(openWs, WebSocket.OPEN);
|
||||
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, {
|
||||
ws: openWs,
|
||||
channel: channelId,
|
||||
pongTs: Date.now(),
|
||||
});
|
||||
const closedWs = new wsMock('localhost');
|
||||
setReadyState(closedWs, WebSocket.CLOSED);
|
||||
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
server.trackClient(channelId, {
|
||||
ws: closedWs,
|
||||
channel: channelId,
|
||||
@@ -776,7 +771,7 @@ describe('server', () => {
|
||||
test('wsConnection refuses over-limit connection without tracking', () => {
|
||||
server.opts.maxConnectionsPerChannel = 1;
|
||||
const existingWs = new wsMock('localhost');
|
||||
setReadyState(existingWs, WebSocket.OPEN);
|
||||
vi.spyOn(existingWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
const existing = {
|
||||
ws: existingWs,
|
||||
channel: channelId,
|
||||
@@ -784,7 +779,7 @@ describe('server', () => {
|
||||
};
|
||||
server.trackClient(channelId, existing);
|
||||
|
||||
const trackClientSpy = jest.spyOn(server, 'trackClient');
|
||||
const trackClientSpy = vi.spyOn(server, 'trackClient');
|
||||
const ws = new wsMock('localhost');
|
||||
const validToken = jwt.sign({ channel: channelId }, config.jwtSecret);
|
||||
server.wsConnection(ws, getRequest(validToken, 'http://localhost'));
|
||||
@@ -800,8 +795,8 @@ describe('server', () => {
|
||||
|
||||
describe('httpUpgrade', () => {
|
||||
let socket: net.Socket;
|
||||
let socketDestroySpy: jest.SpiedFunction<typeof socket.destroy>;
|
||||
let wssUpgradeSpy: jest.SpiedFunction<typeof server.wss.handleUpgrade>;
|
||||
let socketDestroySpy: Mock<typeof socket.destroy>;
|
||||
let wssUpgradeSpy: Mock<typeof server.wss.handleUpgrade>;
|
||||
|
||||
const getRequest = (token: string, url: string): http.IncomingMessage => {
|
||||
const request = new http.IncomingMessage(new net.Socket());
|
||||
@@ -813,8 +808,8 @@ describe('server', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
socket = new net.Socket();
|
||||
socketDestroySpy = jest.spyOn(socket, 'destroy');
|
||||
wssUpgradeSpy = jest.spyOn(server.wss, 'handleUpgrade');
|
||||
socketDestroySpy = vi.spyOn(socket, 'destroy');
|
||||
wssUpgradeSpy = vi.spyOn(server.wss, 'handleUpgrade');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -952,33 +947,21 @@ describe('server', () => {
|
||||
});
|
||||
});
|
||||
|
||||
const setReadyState = (ws: WebSocket, value: typeof ws.readyState) => {
|
||||
// workaround for not being able to do
|
||||
// spyOn(instance,'readyState','get').and.returnValue(value);
|
||||
// See for details: https://github.com/facebook/jest/issues/9675
|
||||
Object.defineProperty(ws, 'readyState', {
|
||||
configurable: true,
|
||||
get() {
|
||||
return value;
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
describe('checkSockets', () => {
|
||||
let ws: WebSocket;
|
||||
let pingSpy: jest.SpiedFunction<typeof ws.ping>;
|
||||
let terminateSpy: jest.SpiedFunction<typeof ws.terminate>;
|
||||
let pingSpy: Mock<typeof ws.ping>;
|
||||
let terminateSpy: Mock<typeof ws.terminate>;
|
||||
let socketInstance: server.SocketInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
ws = new wsMock('localhost');
|
||||
pingSpy = jest.spyOn(ws, 'ping');
|
||||
terminateSpy = jest.spyOn(ws, 'terminate');
|
||||
pingSpy = vi.spyOn(ws, 'ping');
|
||||
terminateSpy = vi.spyOn(ws, 'terminate');
|
||||
socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
|
||||
});
|
||||
|
||||
test('active sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.checkSockets();
|
||||
@@ -989,7 +972,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('stale sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
socketInstance.pongTs = Date.now() - 60000;
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
@@ -1001,7 +984,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('closed sockets', () => {
|
||||
setReadyState(ws, WebSocket.CLOSED);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.checkSockets();
|
||||
@@ -1027,7 +1010,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('active sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.cleanChannel(channelId);
|
||||
@@ -1036,7 +1019,7 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('closing sockets', () => {
|
||||
setReadyState(ws, WebSocket.CLOSING);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSING);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
server.cleanChannel(channelId);
|
||||
@@ -1045,11 +1028,12 @@ describe('server', () => {
|
||||
});
|
||||
|
||||
test('multiple sockets', () => {
|
||||
setReadyState(ws, WebSocket.OPEN);
|
||||
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
|
||||
server.trackClient(channelId, socketInstance);
|
||||
|
||||
const ws2 = new wsMock('localhost');
|
||||
setReadyState(ws2, WebSocket.OPEN);
|
||||
const readyStateSpy = vi.spyOn(ws2, 'readyState', 'get');
|
||||
readyStateSpy.mockReturnValue(WebSocket.OPEN);
|
||||
const socketInstance2 = {
|
||||
ws: ws2,
|
||||
channel: channelId,
|
||||
@@ -1061,7 +1045,7 @@ describe('server', () => {
|
||||
|
||||
expect(server.channels[channelId].sockets.length).toBe(2);
|
||||
|
||||
setReadyState(ws2, WebSocket.CLOSED);
|
||||
readyStateSpy.mockReturnValue(WebSocket.CLOSED);
|
||||
server.cleanChannel(channelId);
|
||||
|
||||
expect(server.channels[channelId].sockets.length).toBe(1);
|
||||
|
||||
@@ -19,11 +19,11 @@
|
||||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import { inspect } from 'util';
|
||||
import WebSocket, { WebSocketServer } from 'ws';
|
||||
import { WebSocket, WebSocketServer } from 'ws';
|
||||
import { randomUUID } from 'crypto';
|
||||
import jwt, { Algorithm } from 'jsonwebtoken';
|
||||
import { parse } from 'cookie';
|
||||
import Redis, { RedisOptions } from 'ioredis';
|
||||
import { Redis, RedisOptions } from 'ioredis';
|
||||
import StatsD from 'hot-shots';
|
||||
|
||||
import { createLogger } from './logger';
|
||||
|
||||
@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Test & development utilities
|
||||
|
||||
The files provided here are for testing and development only, and are not required to run the WebSocket server application.
|
||||
|
||||
@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Test client application
|
||||
|
||||
This Express web application is provided for testing the WebSocket server. It is not required for running the server application, and is provided here for testing and development purposes only.
|
||||
|
||||
@@ -100,7 +100,10 @@ class CacheRestApi(BaseSupersetModelRestApi):
|
||||
)
|
||||
cache_keys = [c.cache_key for c in cache_key_objs]
|
||||
if cache_key_objs:
|
||||
all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)
|
||||
# Chart query results live in ``data_cache``, not the default
|
||||
# ``cache`` — using the wrong backend silently misses the Redis
|
||||
# keys when ``CACHE_KEY_PREFIX`` differs between the two configs.
|
||||
all_keys_deleted = cache_manager.data_cache.delete_many(*cache_keys)
|
||||
|
||||
if not all_keys_deleted:
|
||||
# expected behavior as keys may expire and cache is not a
|
||||
|
||||
@@ -23,7 +23,7 @@ import sys
|
||||
import urllib
|
||||
from datetime import datetime
|
||||
from re import Pattern
|
||||
from typing import Any, TYPE_CHECKING, TypedDict
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from apispec import APISpec
|
||||
@@ -83,6 +83,97 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
# BigQuery string escape sequences keyed off documented escapes in
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals.
|
||||
# Backslash MUST be first so subsequent escapes don't double-escape their own
|
||||
# backslash. ``\?``, ``\"`` and ``\``` are valid BigQuery escapes but
|
||||
# intentionally omitted because those characters do not require escaping
|
||||
# inside a single-quoted literal. ``\0`` is NOT a valid BigQuery escape
|
||||
# (octal escapes require exactly three digits); the null byte instead falls
|
||||
# through to the ``\xhh`` fallback below.
|
||||
_BIGQUERY_STRING_ESCAPES = {
|
||||
"\\": "\\\\",
|
||||
"'": "\\'",
|
||||
"\n": "\\n",
|
||||
"\r": "\\r",
|
||||
"\t": "\\t",
|
||||
"\b": "\\b",
|
||||
"\f": "\\f",
|
||||
"\v": "\\v",
|
||||
"\a": "\\a",
|
||||
}
|
||||
|
||||
|
||||
def _process_string_literal(value: str) -> str:
|
||||
"""
|
||||
Escape a string value for use as a BigQuery SQL literal.
|
||||
|
||||
BigQuery requires backslash escaping for single quotes inside string
|
||||
literals (``'O\\'Brien'``). Doubled single quotes (``'O''Brien'``) are
|
||||
**not** valid — BigQuery parses them as two concatenated string literals
|
||||
without whitespace, causing a syntax error:
|
||||
``concatenated string literals must be separated by whitespace``.
|
||||
|
||||
BigQuery also forbids literal newlines, carriage returns, and other
|
||||
control characters inside a quoted string; those must be written using
|
||||
escape sequences (``\\n``, ``\\r``, ``\\t`` …). Control characters
|
||||
without a named escape are emitted as a ``\\xhh`` hex escape; printable
|
||||
Unicode passes through unchanged because BigQuery accepts UTF-8 inside
|
||||
string literals.
|
||||
|
||||
The upstream ``sqlalchemy-bigquery`` dialect relies on Python's ``repr()``
|
||||
to quote values, which switches to double-quote delimiters when the
|
||||
string contains an apostrophe (e.g. ``repr("O'Brien")`` → ``"O'Brien"``).
|
||||
Double-quoted tokens inside compiled SQL would be parsed as identifiers,
|
||||
so the query also fails. This helper always produces a single-quoted
|
||||
literal.
|
||||
"""
|
||||
parts = []
|
||||
for ch in value:
|
||||
escape = _BIGQUERY_STRING_ESCAPES.get(ch)
|
||||
if escape is not None:
|
||||
parts.append(escape)
|
||||
elif ord(ch) < 0x20 or ord(ch) == 0x7F:
|
||||
parts.append(f"\\x{ord(ch):02x}")
|
||||
else:
|
||||
parts.append(ch)
|
||||
return f"'{''.join(parts)}'"
|
||||
|
||||
|
||||
def _monkeypatch_bigquery_string_literal() -> None:
|
||||
"""
|
||||
Patch the sqlalchemy-bigquery dialect so that string literals containing
|
||||
apostrophes are rendered correctly when ``literal_binds=True``.
|
||||
|
||||
Without this patch, a filter value like ``O'Brien`` is compiled as the
|
||||
double-quoted identifier ``"O'Brien"`` instead of the single-quoted literal
|
||||
``'O\\'Brien'``, causing BigQuery to return a syntax error.
|
||||
|
||||
This follows the same pattern used for the Databricks dialect fix in
|
||||
``superset/db_engine_specs/databricks.py``.
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy_bigquery import BigQueryDialect
|
||||
|
||||
class BigQuerySafeString(types.TypeDecorator):
|
||||
impl = types.String
|
||||
cache_ok = True
|
||||
|
||||
def literal_processor(self, dialect: Any) -> Callable[[str], str]:
|
||||
if dialect.name == "bigquery":
|
||||
return _process_string_literal
|
||||
return super().literal_processor(dialect)
|
||||
|
||||
BigQueryDialect.colspecs[types.String] = BigQuerySafeString
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
_monkeypatch_bigquery_string_literal()
|
||||
|
||||
|
||||
CONNECTION_DATABASE_PERMISSIONS_REGEX = re.compile(
|
||||
"Access Denied: Project (?P<project_name>.+?): User does not have "
|
||||
+ "bigquery.jobs.create permission in project (?P<project>.+?)"
|
||||
|
||||
@@ -17,10 +17,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union
|
||||
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask import g
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.validate import Range
|
||||
@@ -38,12 +39,18 @@ from superset.db_engine_specs.base import (
|
||||
)
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error
|
||||
from superset.utils import json
|
||||
from superset.utils.core import get_user_agent, QuerySource
|
||||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
OAuth2State,
|
||||
OAuth2TokenResponse,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
@@ -277,6 +284,135 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
|
||||
"port": "port",
|
||||
}
|
||||
|
||||
# The Databricks SQL driver has no dedicated authentication exception, so an
|
||||
# expired or missing token surfaces as a generic driver error. These case-
|
||||
# insensitive substrings flag the errors that should bootstrap a re-auth.
|
||||
oauth2_auth_failure_signals = (
|
||||
"http 401",
|
||||
"unauthorized",
|
||||
"unauthenticated",
|
||||
"invalid access token",
|
||||
"invalid token",
|
||||
"expired token",
|
||||
"token expired",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _workspace_oauth2_endpoint(cls, database: Database, path: str) -> str:
|
||||
"""
|
||||
Build a Databricks OAuth2 (U2M) endpoint from the workspace host.
|
||||
|
||||
Databricks fronts the user-to-machine OAuth2 flow on every workspace at
|
||||
``https://<workspace-host>/oidc/v1/{authorize,token}`` across AWS, Azure
|
||||
and GCP, so the endpoints derive directly from the connection host and
|
||||
need no account or tenant identifier.
|
||||
"""
|
||||
host = database.url_object.host
|
||||
if not host:
|
||||
raise OAuth2Error(
|
||||
"Databricks OAuth2 endpoint could not be resolved: the database "
|
||||
"connection has no host."
|
||||
)
|
||||
return f"https://{host}/oidc/v1/{path}"
|
||||
|
||||
@classmethod
|
||||
def needs_oauth2(cls, ex: Exception) -> bool:
|
||||
"""
|
||||
Identify driver errors that should trigger the OAuth2 dance.
|
||||
|
||||
Unlike Trino (``TrinoAuthError``) or GSheets (``UnauthenticatedError``),
|
||||
the Databricks driver raises no dedicated auth exception, so in addition
|
||||
to the base ``isinstance`` check we match the auth signals above on the
|
||||
error message (mirrors ``GSheetsEngineSpec.needs_oauth2``).
|
||||
"""
|
||||
if not (g and hasattr(g, "user")):
|
||||
return False
|
||||
if isinstance(ex, cls.oauth2_exception):
|
||||
return True
|
||||
message = str(ex).lower()
|
||||
return any(signal in message for signal in cls.oauth2_auth_failure_signals)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_authorization_uri(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
state: "OAuth2State",
|
||||
code_verifier: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Return the URI for the initial OAuth2 request.
|
||||
|
||||
A fully-resolved ``authorization_request_uri`` from
|
||||
``DATABASE_OAUTH2_CLIENTS`` is preserved; otherwise the endpoint is
|
||||
derived from the workspace host (``https://<host>/oidc/v1/authorize``),
|
||||
which is valid on AWS, Azure and GCP.
|
||||
"""
|
||||
if not config.get("authorization_request_uri"):
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
|
||||
database_id = state["database_id"]
|
||||
if database := db.session.get(Database, database_id):
|
||||
config = cast(
|
||||
"OAuth2ClientConfig",
|
||||
dict(config)
|
||||
| {
|
||||
"authorization_request_uri": cls._workspace_oauth2_endpoint(
|
||||
database, "authorize"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
return super().get_oauth2_authorization_uri(config, state, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_token(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
code: str,
|
||||
code_verifier: str | None = None,
|
||||
) -> "OAuth2TokenResponse":
|
||||
"""
|
||||
Exchange the authorization code for refresh/access tokens.
|
||||
|
||||
Token exchange runs in a separate request with no database context, so
|
||||
the workspace host is not available to derive the endpoint here. Require
|
||||
a configured ``token_request_uri``
|
||||
(``https://<workspace-host>/oidc/v1/token``) and fail fast rather than
|
||||
POST to an unresolved endpoint.
|
||||
"""
|
||||
if not config.get("token_request_uri"):
|
||||
raise OAuth2Error(
|
||||
"Databricks OAuth2 token endpoint is not configured: set "
|
||||
"`token_request_uri` to https://<workspace-host>/oidc/v1/token "
|
||||
"in DATABASE_OAUTH2_CLIENTS."
|
||||
)
|
||||
|
||||
return super().get_oauth2_token(config, code, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def impersonate_user(
|
||||
cls,
|
||||
database: Database,
|
||||
username: str | None,
|
||||
user_token: str | None,
|
||||
url: URL,
|
||||
engine_kwargs: dict[str, Any],
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
"""
|
||||
Update connection with OAuth2 access token for user impersonation.
|
||||
"""
|
||||
if user_token:
|
||||
# Replace the access token in the URL with the user's OAuth2 token
|
||||
url = url.set(password=user_token)
|
||||
|
||||
# Also update connect_args if they contain access token
|
||||
connect_args = engine_kwargs.setdefault("connect_args", {})
|
||||
if "access_token" in connect_args:
|
||||
connect_args["access_token"] = user_token
|
||||
|
||||
return url, engine_kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(
|
||||
database: Database, source: QuerySource | None = None
|
||||
@@ -474,6 +610,16 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
supports_dynamic_catalog = True
|
||||
supports_cross_catalog_queries = True
|
||||
|
||||
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
|
||||
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
|
||||
supports_oauth2 = True
|
||||
oauth2_scope = "sql"
|
||||
|
||||
# Authorization endpoint is derived from the workspace host at runtime; the
|
||||
# token endpoint must be configured (no DB context at exchange time).
|
||||
oauth2_authorization_request_uri = ""
|
||||
oauth2_token_request_uri = ""
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_uri( # type: ignore
|
||||
cls, parameters: DatabricksNativeParametersType, *_
|
||||
@@ -685,6 +831,16 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
|
||||
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
|
||||
|
||||
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
|
||||
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
|
||||
supports_oauth2 = True
|
||||
oauth2_scope = "sql"
|
||||
|
||||
# Authorization endpoint is derived from the workspace host at runtime; the
|
||||
# token endpoint must be configured (no DB context at exchange time).
|
||||
oauth2_authorization_request_uri = ""
|
||||
oauth2_token_request_uri = ""
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_uri( # type: ignore
|
||||
cls, parameters: DatabricksPythonConnectorParametersType, *_
|
||||
|
||||
@@ -303,6 +303,7 @@ DEFAULT_GET_DASHBOARD_INFO_COLUMNS: List[str] = [
|
||||
"created_on",
|
||||
"changed_on",
|
||||
"uuid",
|
||||
"embedded_uuid",
|
||||
"url",
|
||||
"created_on_humanized",
|
||||
"changed_on_humanized",
|
||||
@@ -427,6 +428,18 @@ class DashboardInfo(BaseModel):
|
||||
created_on: str | datetime | None = None
|
||||
changed_on: str | datetime | None = None
|
||||
uuid: str | None = None
|
||||
embedded_uuid: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Embedded UUID for this dashboard. This is the UUID required when "
|
||||
"generating guest tokens for embedded dashboards "
|
||||
"(resources[].id in the guest token payload). "
|
||||
"Only present when the dashboard has been configured for embedding "
|
||||
"via the Embed Dashboard UI. Distinct from `uuid` (the internal "
|
||||
"dashboard UUID) — using the wrong one causes 403 errors in guest "
|
||||
"token validation."
|
||||
),
|
||||
)
|
||||
url: str | None = None
|
||||
created_on_humanized: str | None = None
|
||||
changed_on_humanized: str | None = None
|
||||
@@ -1352,6 +1365,9 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
|
||||
created_on=dashboard.created_on,
|
||||
changed_on=dashboard.changed_on,
|
||||
uuid=str(dashboard.uuid) if dashboard.uuid else None,
|
||||
embedded_uuid=str(dashboard.embedded[0].uuid)
|
||||
if dashboard.embedded
|
||||
else None,
|
||||
url=absolute_url,
|
||||
created_on_humanized=dashboard.created_on_humanized,
|
||||
changed_on_humanized=dashboard.changed_on_humanized,
|
||||
|
||||
@@ -155,10 +155,11 @@ async def get_dashboard_info(
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
|
||||
# Eager load slices and tags to avoid N+1 queries during serialization.
|
||||
# Eager load slices, tags, and embedded to avoid N+1 queries.
|
||||
eager_options = [
|
||||
subqueryload(Dashboard.slices).subqueryload(Slice.tags),
|
||||
subqueryload(Dashboard.tags),
|
||||
subqueryload(Dashboard.embedded),
|
||||
]
|
||||
|
||||
with event_logger.log_context(action="mcp.get_dashboard_info.lookup"):
|
||||
|
||||
@@ -2963,6 +2963,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
tp = self.get_template_processor()
|
||||
processed_expression = self._process_expression_template(expression, tp)
|
||||
|
||||
# Apply the same parsing policy used for stored adhoc column and
|
||||
# metric expressions (single statement, no set operations, and no
|
||||
# sub-queries unless ALLOW_ADHOC_SUBQUERY is enabled), so expression
|
||||
# validation follows one policy across the query pipeline. Imported
|
||||
# locally to avoid a circular import with the connectors package.
|
||||
from superset.connectors.sqla.models import validate_stored_expression
|
||||
|
||||
validate_stored_expression(
|
||||
self.database, self.catalog, self.schema or "", processed_expression
|
||||
)
|
||||
|
||||
# Build validation query
|
||||
tbl, cte = self.get_from_clause(tp)
|
||||
validation_query = self._build_validation_query(
|
||||
|
||||
@@ -3164,9 +3164,6 @@ msgstr "Culoare după"
|
||||
msgid "Color for breakpoint"
|
||||
msgstr "Culoare pentru punct de întrerupere"
|
||||
|
||||
msgid "Color Metric"
|
||||
msgstr "Indicator culoare"
|
||||
|
||||
msgid "Color of the source location"
|
||||
msgstr "Culoarea locației sursă"
|
||||
|
||||
@@ -4644,7 +4641,7 @@ msgid "Deleted %(num)d theme"
|
||||
msgid_plural "Deleted %(num)d themes"
|
||||
msgstr[0] "Ștearsă %(num)d temă"
|
||||
msgstr[1] "Șterse %(num)d teme"
|
||||
msgstr[2] "Șterse %(num)d de teme""
|
||||
msgstr[2] "Șterse %(num)d de teme"
|
||||
|
||||
#, python-format
|
||||
msgid "Deleted %s"
|
||||
@@ -10218,7 +10215,7 @@ msgid "Saved expressions"
|
||||
msgstr "Expresii salvate"
|
||||
|
||||
msgid "Saved metric"
|
||||
msgstr "Indicator salvat""
|
||||
msgstr "Indicator salvat"
|
||||
|
||||
msgid "Saved queries"
|
||||
msgstr "Interogări salvate"
|
||||
@@ -13422,7 +13419,7 @@ msgid "This was triggered by:"
|
||||
msgid_plural "This may be triggered by:"
|
||||
msgstr[0] "Aceasta a fost declanșată de:"
|
||||
msgstr[1] "Acestea pot fi declanșate de:"
|
||||
msgstr[2] "Acestea pot fi declanșate de către:""
|
||||
msgstr[2] "Acestea pot fi declanșate de către:"
|
||||
|
||||
msgid ""
|
||||
"This will be applied to the whole table. Arrows (↑ and ↓) will be added "
|
||||
@@ -14447,7 +14444,7 @@ msgstr ""
|
||||
"Vizualizează un indicator corelat pentru perechi de grupuri. Hartile "
|
||||
"termice (Heatmaps) excelează în evidențierea corelației sau a "
|
||||
"intensității dintre două grupuri. Culoarea este utilizată pentru a "
|
||||
"sublinia intensitatea legăturii dintre fiecare pereche de grupuri.""
|
||||
"sublinia intensitatea legăturii dintre fiecare pereche de grupuri."
|
||||
|
||||
msgid ""
|
||||
"Visualize geospatial data like 3D buildings, landscapes, or objects in "
|
||||
|
||||
@@ -54,6 +54,13 @@ NUMPY_FUNCTIONS: dict[str, Callable[..., Any]] = {
|
||||
"var": np.var,
|
||||
}
|
||||
|
||||
# Operators that pandas GroupBy.agg accepts as string names. Passing the string
|
||||
# avoids a FutureWarning raised when pandas receives a numpy callable it internally
|
||||
# maps to its own method (e.g. np.mean → SeriesGroupBy.mean).
|
||||
_PANDAS_STRING_AGGREGATORS: frozenset[str] = frozenset(
|
||||
{"max", "mean", "median", "min", "prod", "std", "sum", "var"}
|
||||
)
|
||||
|
||||
DENYLIST_ROLLING_FUNCTIONS = (
|
||||
"count",
|
||||
"corr",
|
||||
@@ -166,7 +173,7 @@ def _get_aggregate_funcs(
|
||||
)
|
||||
operator = agg_obj["operator"]
|
||||
if callable(operator):
|
||||
aggfunc = operator
|
||||
aggfunc: str | Callable[..., Any] = operator
|
||||
else:
|
||||
func = NUMPY_FUNCTIONS.get(operator)
|
||||
if not func:
|
||||
@@ -177,7 +184,10 @@ def _get_aggregate_funcs(
|
||||
)
|
||||
)
|
||||
options = agg_obj.get("options", {})
|
||||
aggfunc = partial(func, **options)
|
||||
if not options and operator in _PANDAS_STRING_AGGREGATORS:
|
||||
aggfunc = operator
|
||||
else:
|
||||
aggfunc = partial(func, **options)
|
||||
agg_funcs[name] = NamedAgg(column=column, aggfunc=aggfunc)
|
||||
|
||||
return agg_funcs
|
||||
|
||||
@@ -440,9 +440,9 @@ class BaseViz: # pylint: disable=too-many-public-methods
|
||||
"metrics": metrics,
|
||||
"row_limit": row_limit,
|
||||
"filter": self.form_data.get("filters", []),
|
||||
"timeseries_limit": limit,
|
||||
"series_limit": limit,
|
||||
"extras": extras,
|
||||
"timeseries_limit_metric": timeseries_limit_metric,
|
||||
"series_limit_metric": timeseries_limit_metric,
|
||||
"order_desc": order_desc,
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
"""Unit tests for Superset"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -52,17 +53,43 @@ def test_invalidate_cache(invalidate):
|
||||
def test_invalidate_existing_cache(invalidate):
|
||||
db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
|
||||
db.session.commit()
|
||||
cache_manager.cache.set("cache_key", "value")
|
||||
cache_manager.data_cache.set("cache_key", "value")
|
||||
|
||||
rv = invalidate({"datasource_uids": ["3__table"]})
|
||||
|
||||
assert rv.status_code == 201
|
||||
assert cache_manager.cache.get("cache_key") is None # noqa: E711
|
||||
assert cache_manager.data_cache.get("cache_key") is None # noqa: E711
|
||||
assert (
|
||||
not db.session.query(CacheKey).filter(CacheKey.cache_key == "cache_key").first()
|
||||
)
|
||||
|
||||
|
||||
def test_invalidate_uses_data_cache_not_default_cache(invalidate):
|
||||
"""Regression test for #40489.
|
||||
|
||||
Chart query results are written through ``cache_manager.data_cache``
|
||||
(``DATA_CACHE_CONFIG``). When ``CACHE_CONFIG`` and ``DATA_CACHE_CONFIG``
|
||||
use distinct ``CACHE_KEY_PREFIX`` values, deleting via the default
|
||||
``cache_manager.cache`` silently misses the underlying Redis keys
|
||||
because flask-caching prepends the wrong prefix to the DEL call.
|
||||
"""
|
||||
db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
|
||||
db.session.commit()
|
||||
|
||||
with (
|
||||
patch.object(cache_manager.data_cache, "delete_many") as data_delete,
|
||||
patch.object(cache_manager.cache, "delete_many") as default_delete,
|
||||
):
|
||||
data_delete.return_value = True
|
||||
rv = invalidate({"datasource_uids": ["3__table"]})
|
||||
|
||||
assert rv.status_code == 201
|
||||
# Chart-data cache backend (the one that wrote the keys) must be hit.
|
||||
data_delete.assert_called_once_with("cache_key")
|
||||
# The default cache must NOT be touched — that's the #40489 regression.
|
||||
default_delete.assert_not_called()
|
||||
|
||||
|
||||
def test_invalidate_cache_empty_input(invalidate):
|
||||
rv = invalidate({"datasource_uids": []})
|
||||
assert rv.status_code == 201
|
||||
@@ -111,10 +138,10 @@ def test_invalidate_existing_caches(invalidate):
|
||||
db.session.add(CacheKey(cache_key="cache_keyX", datasource_uid="X__table"))
|
||||
db.session.commit()
|
||||
|
||||
cache_manager.cache.set("cache_key1", "value")
|
||||
cache_manager.cache.set("cache_key2", "value")
|
||||
cache_manager.cache.set("cache_key4", "value")
|
||||
cache_manager.cache.set("cache_keyX", "value")
|
||||
cache_manager.data_cache.set("cache_key1", "value")
|
||||
cache_manager.data_cache.set("cache_key2", "value")
|
||||
cache_manager.data_cache.set("cache_key4", "value")
|
||||
cache_manager.data_cache.set("cache_keyX", "value")
|
||||
|
||||
rv = invalidate(
|
||||
{
|
||||
@@ -155,10 +182,10 @@ def test_invalidate_existing_caches(invalidate):
|
||||
)
|
||||
|
||||
assert rv.status_code == 201
|
||||
assert cache_manager.cache.get("cache_key1") is None
|
||||
assert cache_manager.cache.get("cache_key2") is None
|
||||
assert cache_manager.cache.get("cache_key4") is None
|
||||
assert cache_manager.cache.get("cache_keyX") == "value"
|
||||
assert cache_manager.data_cache.get("cache_key1") is None
|
||||
assert cache_manager.data_cache.get("cache_key2") is None
|
||||
assert cache_manager.data_cache.get("cache_key4") is None
|
||||
assert cache_manager.data_cache.get("cache_keyX") == "value"
|
||||
assert (
|
||||
not db.session.query(CacheKey)
|
||||
.filter(CacheKey.cache_key.in_({"cache_key1", "cache_key2", "cache_key4"}))
|
||||
|
||||
@@ -767,3 +767,265 @@ def test_fetch_data_converts_bigquery_row_objects(mocker: MockerFixture) -> None
|
||||
|
||||
assert result == [(1, "a"), (2, "b")]
|
||||
assert flask_g.bq_memory_limited is False
|
||||
|
||||
|
||||
def test_string_literal_with_apostrophe() -> None:
|
||||
"""
|
||||
Test that string literals containing apostrophes are properly escaped
|
||||
for BigQuery using backslash escaping.
|
||||
|
||||
BigQuery requires backslash escaping for single quotes ('O\\'Brien').
|
||||
Doubled single quotes ('O''Brien') are NOT valid — BigQuery parses them
|
||||
as two concatenated string literals, causing a syntax error.
|
||||
The upstream sqlalchemy-bigquery dialect uses ``repr()`` which switches
|
||||
to double-quote delimiters when the value contains an apostrophe.
|
||||
Double-quoted tokens are identifiers in BigQuery, causing syntax errors.
|
||||
"""
|
||||
from sqlalchemy import column as sa_column
|
||||
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
|
||||
|
||||
# Trigger module load to ensure the monkey-patch is applied
|
||||
assert BigQueryEngineSpec is not None
|
||||
|
||||
dialect = BigQueryDialect()
|
||||
|
||||
stmt = select(sa_column("name")).where(sa_column("name") == "Fernando's")
|
||||
compiled_sql = str(
|
||||
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
|
||||
)
|
||||
|
||||
# The compiled SQL must use single-quoted literal with backslash-escaped
|
||||
# apostrophes. Doubled single quotes are NOT valid in BigQuery.
|
||||
assert "= 'Fernando\\'s'" in compiled_sql
|
||||
# Must NOT contain doubled-quote escaping (BigQuery rejects this)
|
||||
assert "''" not in compiled_sql
|
||||
# Must NOT contain double-quoted identifiers
|
||||
assert '\\"' not in compiled_sql
|
||||
|
||||
|
||||
def test_string_literal_without_apostrophe() -> None:
|
||||
"""
|
||||
Test that normal string literals (without apostrophes) still compile
|
||||
correctly after the monkey-patch.
|
||||
"""
|
||||
from sqlalchemy import column as sa_column
|
||||
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
|
||||
|
||||
assert BigQueryEngineSpec is not None
|
||||
|
||||
dialect = BigQueryDialect()
|
||||
|
||||
stmt = select(sa_column("name")).where(sa_column("name") == "Fernando")
|
||||
compiled_sql = str(
|
||||
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
|
||||
)
|
||||
|
||||
assert "= 'Fernando'" in compiled_sql
|
||||
|
||||
|
||||
def test_string_literal_in_filter_with_apostrophe() -> None:
|
||||
"""
|
||||
Test that IN filters with apostrophes in values compile correctly
|
||||
using backslash escaping.
|
||||
"""
|
||||
from sqlalchemy import column as sa_column
|
||||
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
|
||||
|
||||
assert BigQueryEngineSpec is not None
|
||||
|
||||
dialect = BigQueryDialect()
|
||||
|
||||
stmt = select(sa_column("name")).where(
|
||||
sa_column("name").in_(["Fernando's", "O'Brien"])
|
||||
)
|
||||
compiled_sql = str(
|
||||
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
|
||||
)
|
||||
|
||||
assert "'Fernando\\'s'" in compiled_sql
|
||||
assert "'O\\'Brien'" in compiled_sql
|
||||
# Must NOT contain doubled-quote escaping
|
||||
assert "''" not in compiled_sql
|
||||
|
||||
|
||||
def test_process_string_literal_directly() -> None:
|
||||
"""
|
||||
Test _process_string_literal covers backslash escaping for apostrophes,
|
||||
control-character escaping (newline/CR/tab/etc.), the ``\\xhh`` fallback
|
||||
for control chars without a named escape, and pass-through for printable
|
||||
Unicode and other characters BigQuery accepts unescaped.
|
||||
"""
|
||||
from superset.db_engine_specs.bigquery import _process_string_literal
|
||||
|
||||
# Plain values
|
||||
assert _process_string_literal("hello") == "'hello'"
|
||||
assert _process_string_literal("") == "''"
|
||||
|
||||
# Apostrophes (the original fix)
|
||||
assert _process_string_literal("O'Brien") == "'O\\'Brien'"
|
||||
assert _process_string_literal("it's a test") == "'it\\'s a test'"
|
||||
|
||||
# Backslashes must be escaped before apostrophes
|
||||
assert _process_string_literal("C:\\path") == "'C:\\\\path'"
|
||||
assert _process_string_literal("it's C:\\path") == "'it\\'s C:\\\\path'"
|
||||
|
||||
# Literal backslash followed by 'n' (two characters, not a newline)
|
||||
# must produce the two-char sequence '\\n' (escaped backslash + n) so
|
||||
# BigQuery does not misread it as a newline escape.
|
||||
assert _process_string_literal("\\n") == "'\\\\n'"
|
||||
|
||||
# Control characters must be escaped using named escapes — BigQuery
|
||||
# rejects literal control characters inside quoted strings.
|
||||
assert _process_string_literal("foo\nbar") == "'foo\\nbar'"
|
||||
assert _process_string_literal("foo\rbar") == "'foo\\rbar'"
|
||||
assert _process_string_literal("foo\tbar") == "'foo\\tbar'"
|
||||
assert _process_string_literal("a\bb\fc\vd\ae") == "'a\\bb\\fc\\vd\\ae'"
|
||||
|
||||
# Control characters without a named escape fall through to ``\\xhh``.
|
||||
assert _process_string_literal("null\0byte") == "'null\\x00byte'"
|
||||
assert _process_string_literal("a\x01b") == "'a\\x01b'"
|
||||
assert _process_string_literal("a\x1bb") == "'a\\x1bb'"
|
||||
assert _process_string_literal("a\x7fb") == "'a\\x7fb'"
|
||||
|
||||
# Double quotes do NOT need escaping in single-quoted BigQuery literals.
|
||||
assert _process_string_literal('say "hello"') == "'say \"hello\"'"
|
||||
|
||||
# Printable Unicode and percent signs pass through unchanged.
|
||||
assert _process_string_literal("café") == "'café'"
|
||||
assert _process_string_literal("日本") == "'日本'"
|
||||
assert _process_string_literal("100%") == "'100%'"
|
||||
|
||||
# Combined: apostrophe + newline + backslash + unicode.
|
||||
assert _process_string_literal("it's\nC:\\café") == "'it\\'s\\nC:\\\\café'"
|
||||
|
||||
|
||||
def test_process_string_literal_no_literal_control_chars() -> None:
|
||||
"""
|
||||
Regression test for the issue raised in PR #38835 review: BigQuery
|
||||
rejects literal control characters inside quoted string literals, so the
|
||||
output must never contain them as literal characters.
|
||||
"""
|
||||
from superset.db_engine_specs.bigquery import _process_string_literal
|
||||
|
||||
for char in ["\n", "\r", "\t", "\b", "\f", "\v", "\a", "\0", "\x01", "\x7f"]:
|
||||
result = _process_string_literal(f"prefix{char}suffix")
|
||||
assert char not in result, (
|
||||
f"Literal {char!r} leaked into output {result!r}; "
|
||||
"BigQuery would reject this literal."
|
||||
)
|
||||
|
||||
|
||||
def test_string_literal_with_newline_in_filter() -> None:
|
||||
"""
|
||||
End-to-end regression test for @rusackas's review feedback on PR #38835:
|
||||
a filter value containing a newline must compile to valid BigQuery SQL
|
||||
using the ``\\n`` escape sequence, not a literal newline.
|
||||
"""
|
||||
from sqlalchemy import column as sa_column
|
||||
|
||||
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
|
||||
|
||||
assert BigQueryEngineSpec is not None
|
||||
|
||||
dialect = BigQueryDialect()
|
||||
stmt = select(sa_column("note")).where(sa_column("note") == "line1\nline2")
|
||||
compiled_sql = str(
|
||||
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
|
||||
)
|
||||
|
||||
# Must use the escape sequence form, not a literal newline.
|
||||
assert "'line1\\nline2'" in compiled_sql
|
||||
assert "\n" not in compiled_sql.split("note")[-1]
|
||||
|
||||
|
||||
def test_literal_processor_non_bigquery_dialect() -> None:
|
||||
"""
|
||||
Test that BigQuerySafeString.literal_processor falls back to the parent
|
||||
implementation when used with a non-BigQuery dialect.
|
||||
"""
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from superset.db_engine_specs.bigquery import (
|
||||
_monkeypatch_bigquery_string_literal, # noqa: F811
|
||||
)
|
||||
|
||||
_monkeypatch_bigquery_string_literal()
|
||||
|
||||
safe_cls = BigQueryDialect.colspecs[sqltypes.String]
|
||||
instance = safe_cls()
|
||||
|
||||
# Use a non-BigQuery dialect (sqlite)
|
||||
sqlite_dialect = create_engine("sqlite://").dialect
|
||||
processor = instance.literal_processor(sqlite_dialect)
|
||||
|
||||
# The fallback processor should still produce a valid quoted string
|
||||
assert processor is not None
|
||||
|
||||
|
||||
def test_monkeypatch_is_applied() -> None:
|
||||
"""
|
||||
Test that _monkeypatch_bigquery_string_literal installs the custom
|
||||
type decorator into BigQueryDialect.colspecs.
|
||||
"""
|
||||
from sqlalchemy.sql import sqltypes as sa_sqltypes
|
||||
|
||||
from superset.db_engine_specs.bigquery import (
|
||||
BigQueryEngineSpec, # noqa: F811
|
||||
)
|
||||
|
||||
assert BigQueryEngineSpec is not None
|
||||
|
||||
colspecs = BigQueryDialect.colspecs
|
||||
assert sa_sqltypes.String in colspecs
|
||||
safe_cls = colspecs[sa_sqltypes.String]
|
||||
assert safe_cls.__name__ == "BigQuerySafeString"
|
||||
|
||||
|
||||
def test_literal_processor_returns_process_string_literal_for_bigquery() -> None:
|
||||
"""
|
||||
Test that BigQuerySafeString.literal_processor returns the
|
||||
_process_string_literal function when given a BigQuery dialect,
|
||||
and that calling it produces correctly escaped output.
|
||||
"""
|
||||
from superset.db_engine_specs.bigquery import (
|
||||
_monkeypatch_bigquery_string_literal,
|
||||
_process_string_literal,
|
||||
)
|
||||
|
||||
_monkeypatch_bigquery_string_literal()
|
||||
|
||||
safe_cls = BigQueryDialect.colspecs[sqltypes.String]
|
||||
instance = safe_cls()
|
||||
|
||||
dialect = BigQueryDialect()
|
||||
processor = instance.literal_processor(dialect)
|
||||
|
||||
assert processor is _process_string_literal
|
||||
assert processor("O'Brien") == "'O\\'Brien'"
|
||||
assert processor("plain") == "'plain'"
|
||||
|
||||
|
||||
def test_monkeypatch_handles_missing_bigquery_package() -> None:
|
||||
"""
|
||||
Test that _monkeypatch_bigquery_string_literal gracefully handles
|
||||
the case where sqlalchemy_bigquery is not installed.
|
||||
"""
|
||||
import builtins
|
||||
|
||||
from superset.db_engine_specs.bigquery import (
|
||||
_monkeypatch_bigquery_string_literal,
|
||||
)
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
if name == "sqlalchemy_bigquery":
|
||||
raise ImportError("mocked missing package")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with mock.patch("builtins.__import__", side_effect=mock_import):
|
||||
# Should not raise — the except ImportError branch handles it
|
||||
_monkeypatch_bigquery_string_literal()
|
||||
|
||||
@@ -17,14 +17,23 @@
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
from superset.db_engine_specs.databricks import (
|
||||
DatabricksNativeEngineSpec,
|
||||
DatabricksPythonConnectorEngineSpec,
|
||||
)
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils import json
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from tests.unit_tests.fixtures.common import dttm # noqa: F401
|
||||
|
||||
@@ -291,3 +300,595 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"USE CATALOG `evil`` USE CATALOG bad`",
|
||||
"USE SCHEMA `evil`` USE SCHEMA bad`",
|
||||
]
|
||||
|
||||
|
||||
# OAuth2 Tests
|
||||
|
||||
|
||||
def test_oauth2_attributes() -> None:
|
||||
"""
|
||||
Test that OAuth2 attributes are properly set for both engine specs.
|
||||
"""
|
||||
# Test DatabricksNativeEngineSpec
|
||||
assert DatabricksNativeEngineSpec.supports_oauth2 is True
|
||||
assert DatabricksNativeEngineSpec.oauth2_scope == "sql"
|
||||
# The authorization endpoint is derived from the workspace host at runtime;
|
||||
# the token endpoint must be configured explicitly.
|
||||
assert DatabricksNativeEngineSpec.oauth2_authorization_request_uri == ""
|
||||
assert DatabricksNativeEngineSpec.oauth2_token_request_uri == ""
|
||||
|
||||
# Test DatabricksPythonConnectorEngineSpec
|
||||
assert DatabricksPythonConnectorEngineSpec.supports_oauth2 is True
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_scope == "sql"
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_authorization_request_uri == ""
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_token_request_uri == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"message",
|
||||
[
|
||||
"Error during request to server: HTTP 401 Unauthorized",
|
||||
"Invalid access token",
|
||||
"The access token expired",
|
||||
"UNAUTHENTICATED: token is no longer valid",
|
||||
],
|
||||
)
|
||||
def test_needs_oauth2_detects_auth_failure_from_message(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""
|
||||
The Databricks driver has no dedicated auth exception, so `needs_oauth2`
|
||||
matches auth-failure signals in the error message to bootstrap a re-auth.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.databricks.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
assert spec.needs_oauth2(Exception(message)) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"message",
|
||||
[
|
||||
"Table not found",
|
||||
# A bare 401 in an unrelated position must not look like an auth failure.
|
||||
"Query failed at line 401: syntax error",
|
||||
],
|
||||
)
|
||||
def test_needs_oauth2_ignores_unrelated_errors(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""
|
||||
A non-auth driver error must not trigger the OAuth2 dance.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.databricks.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
assert spec.needs_oauth2(Exception(message)) is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_needs_oauth2_matches_oauth2_redirect_error(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
The inherited `isinstance` check against `oauth2_exception` still holds.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.databricks.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
ex = OAuth2RedirectError("https://example/authorize", "tab", "redirect")
|
||||
assert spec.needs_oauth2(ex) is True
|
||||
|
||||
|
||||
def test_impersonate_user_with_token(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method with OAuth2 token for DatabricksNativeEngineSpec.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks+connector://token:original-token@host:443/database"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test with user token
|
||||
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token="user-oauth-token", # noqa: S106
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that the password (token) was updated in the URL
|
||||
assert url.password == "user-oauth-token" # noqa: S105
|
||||
# Check that access_token was updated in connect_args
|
||||
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
|
||||
|
||||
|
||||
def test_impersonate_user_without_token(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method without OAuth2 token.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks+connector://token:original-token@host:443/database"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test without user token
|
||||
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token=None,
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that nothing was changed
|
||||
assert url.password == "original-token" # noqa: S105
|
||||
assert kwargs["connect_args"]["access_token"] == "original-token" # noqa: S105
|
||||
|
||||
|
||||
def test_impersonate_user_python_connector(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method for DatabricksPythonConnectorEngineSpec.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks://token:original-token@host:443?http_path=path&catalog=main&schema=default"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test with user token
|
||||
url, kwargs = DatabricksPythonConnectorEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token="user-oauth-token", # noqa: S106
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that the password (token) was updated in the URL
|
||||
assert url.password == "user-oauth-token" # noqa: S105
|
||||
# Check that access_token was updated in connect_args
|
||||
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_native() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks Native OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_python() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks Python Connector OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_no_config_native(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is not configured for Native engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={"DATABASE_OAUTH2_CLIENTS": {}},
|
||||
)
|
||||
|
||||
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is False
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_config_native(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is configured for Native engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={
|
||||
"DATABASE_OAUTH2_CLIENTS": {
|
||||
"Databricks (legacy)": {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is True
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_no_config_python(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is not configured for Python Connector engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={"DATABASE_OAUTH2_CLIENTS": {}},
|
||||
)
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is False
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_config_python(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is configured for Python Connector engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={
|
||||
"DATABASE_OAUTH2_CLIENTS": {
|
||||
"Databricks": {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is True
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_authorization_uri` for Native engine.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_native, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.cloud.databricks.com"
|
||||
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_authorization_uri` for Python Connector engine.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_python, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.cloud.databricks.com"
|
||||
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_token_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token` for Native engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_oauth2_token(
|
||||
oauth2_config_native, "authorization-code"
|
||||
) == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"code": "authorization-code",
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_token_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token` for Python Connector engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.get_oauth2_token(
|
||||
oauth2_config_python, "authorization-code"
|
||||
) == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"code": "authorization-code",
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_fresh_token` for Native engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_oauth2_fresh_token(
|
||||
oauth2_config_native, "old-refresh-token"
|
||||
) == {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def _oauth2_state() -> OAuth2State:
|
||||
"""
|
||||
Build the default OAuth2 state shared by the OAuth2 tests.
|
||||
"""
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
return state
|
||||
|
||||
|
||||
def _unresolved_oauth2_config() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config as built by `get_oauth2_config` when no endpoints are overridden:
|
||||
the URIs default to the spec's empty class attributes.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "",
|
||||
"token_request_uri": "",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"host",
|
||||
[
|
||||
"dbc-abc.cloud.databricks.com",
|
||||
"adb-123456789.12.azuredatabricks.net",
|
||||
"123456789.gcp.databricks.com",
|
||||
],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_derives_from_workspace_host(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
host: str,
|
||||
) -> None:
|
||||
"""
|
||||
With no configured `authorization_request_uri`, the endpoint is derived from
|
||||
the workspace host (`https://<host>/oidc/v1/authorize`) on every cloud, with
|
||||
no account/tenant identifier required.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = host
|
||||
mocker.patch("superset.db.session.get", return_value=database)
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(
|
||||
_unresolved_oauth2_config(), _oauth2_state()
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == host
|
||||
assert parsed.path == "/oidc/v1/authorize"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_preserves_configured(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
A fully-resolved `authorization_request_uri` is never overwritten by the
|
||||
host-derived endpoint, and no database lookup is needed.
|
||||
"""
|
||||
session_get = mocker.patch("superset.db.session.get")
|
||||
config = _unresolved_oauth2_config()
|
||||
config["authorization_request_uri"] = (
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/override/v1/authorize"
|
||||
)
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(config, _oauth2_state())
|
||||
assert urlparse(url).path == "/oidc/accounts/override/v1/authorize"
|
||||
session_get.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_fails_without_host(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
When the endpoint must be derived but the connection has no host, fail fast
|
||||
instead of emitting an invalid `https:///oidc/v1/authorize` URL.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = None
|
||||
mocker.patch("superset.db.session.get", return_value=database)
|
||||
|
||||
with pytest.raises(OAuth2Error):
|
||||
spec.get_oauth2_authorization_uri(_unresolved_oauth2_config(), _oauth2_state())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_token_fails_without_uri(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Token exchange has no database context to auto-detect the endpoint, so a
|
||||
missing `token_request_uri` fails fast rather than POSTing to `.../{}/...`.
|
||||
"""
|
||||
with pytest.raises(OAuth2Error):
|
||||
spec.get_oauth2_token(_unresolved_oauth2_config(), "authorization-code")
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_fresh_token` for Python Connector engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.get_oauth2_fresh_token(
|
||||
oauth2_config_python, "old-refresh-token"
|
||||
) == {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
127
tests/unit_tests/db_engine_specs/test_databricks_multi_cloud.py
Normal file
127
tests/unit_tests/db_engine_specs/test_databricks_multi_cloud.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.db_engine_specs.databricks import (
|
||||
DatabricksNativeEngineSpec,
|
||||
DatabricksPythonConnectorEngineSpec,
|
||||
)
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
|
||||
# Multi-Cloud Provider Tests
|
||||
#
|
||||
# Databricks fronts the user-to-machine OAuth2 flow on every workspace at
|
||||
# `https://<workspace-host>/oidc/v1/{authorize,token}`, regardless of cloud, so
|
||||
# the authorization endpoint derives from the connection host with no per-cloud
|
||||
# account/tenant identifier.
|
||||
|
||||
SPECS = [DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec]
|
||||
|
||||
# Representative workspace hosts for each cloud provider.
|
||||
CLOUD_HOSTS = [
|
||||
"my-cluster.cloud.databricks.com", # AWS
|
||||
"adb-123456789.12.azuredatabricks.net", # Azure
|
||||
"123456789.gcp.databricks.com", # GCP
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_no_uri() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks OAuth2 without a pre-configured endpoint, so the
|
||||
authorization endpoint is derived from the workspace host.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "",
|
||||
"token_request_uri": "",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
def _mock_database(mocker: MockerFixture, host: str) -> MagicMock:
|
||||
"""
|
||||
Build a mock database whose URL resolves to the given workspace host.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = host
|
||||
database.id = 1
|
||||
return database
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", SPECS)
|
||||
@pytest.mark.parametrize("host", CLOUD_HOSTS)
|
||||
def test_get_oauth2_authorization_uri_uses_workspace_host(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
host: str,
|
||||
oauth2_config_no_uri: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
The authorization endpoint is the workspace host on AWS, Azure, and GCP.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
mocker.patch(
|
||||
"superset.extensions.db.session.get",
|
||||
return_value=_mock_database(mocker, host),
|
||||
)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(oauth2_config_no_uri, state)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == host
|
||||
assert parsed.path == "/oidc/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", SPECS)
|
||||
@pytest.mark.parametrize("host", CLOUD_HOSTS)
|
||||
def test_workspace_oauth2_endpoint_builds_token_uri(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
host: str,
|
||||
) -> None:
|
||||
"""
|
||||
The helper builds the matching token endpoint from the same workspace host.
|
||||
"""
|
||||
database = _mock_database(mocker, host)
|
||||
assert (
|
||||
spec._workspace_oauth2_endpoint(database, "token")
|
||||
== f"https://{host}/oidc/v1/token"
|
||||
)
|
||||
@@ -96,6 +96,7 @@ async def test_list_dashboards_basic(mock_list, mcp_server):
|
||||
dashboard.uuid = "test-dashboard-uuid-1"
|
||||
dashboard.thumbnail_url = None
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
dashboard._mapping = {
|
||||
"id": dashboard.id,
|
||||
@@ -162,6 +163,7 @@ async def test_list_dashboards_with_filters(mock_list, mcp_server):
|
||||
dashboard.uuid = None
|
||||
dashboard.thumbnail_url = None
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
dashboard._mapping = {
|
||||
"id": dashboard.id,
|
||||
@@ -257,6 +259,7 @@ async def test_list_dashboards_with_search(mock_list, mcp_server):
|
||||
dashboard.uuid = None
|
||||
dashboard.thumbnail_url = None
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
dashboard._mapping = {
|
||||
"id": dashboard.id,
|
||||
@@ -351,6 +354,7 @@ async def test_get_dashboard_info_success(
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
dashboard._mapping = {
|
||||
"id": dashboard.id,
|
||||
@@ -429,6 +433,7 @@ async def test_get_dashboard_info_permalink_does_not_double_sanitize(
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
mock_info.return_value = dashboard
|
||||
permalink_value = {
|
||||
@@ -521,6 +526,7 @@ async def test_get_dashboard_info_permalink_key_includes_filter_state(
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
mock_info.return_value = dashboard
|
||||
|
||||
@@ -767,6 +773,7 @@ async def test_get_dashboard_info_does_not_expose_access_list_or_roles(
|
||||
dashboard.owners = [owner]
|
||||
dashboard.tags = []
|
||||
dashboard.roles = [dashboard_role]
|
||||
dashboard.embedded = []
|
||||
|
||||
mock_info.return_value = dashboard
|
||||
|
||||
@@ -838,6 +845,7 @@ async def test_get_dashboard_info_restricted_user_redacts_data_model_metadata(
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
|
||||
mock_info.return_value = dashboard
|
||||
|
||||
@@ -890,6 +898,7 @@ async def test_get_dashboard_info_restricted_user_redacts_permalink_filter_state
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
|
||||
mock_info.return_value = dashboard
|
||||
|
||||
@@ -1012,6 +1021,88 @@ async def test_list_dashboards_omits_requested_user_directory_fields(
|
||||
assert field not in data["columns_available"]
|
||||
|
||||
|
||||
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_dashboard_info_includes_embedded_uuid(mock_find_object, mcp_server):
|
||||
"""Test that get_dashboard_info returns embedded_uuid when set."""
|
||||
from superset.models.embedded_dashboard import EmbeddedDashboard
|
||||
|
||||
dashboard = Mock()
|
||||
dashboard.id = 1
|
||||
dashboard.dashboard_title = "Embedded Dashboard"
|
||||
dashboard.slug = ""
|
||||
dashboard.description = None
|
||||
dashboard.css = None
|
||||
dashboard.certified_by = None
|
||||
dashboard.certification_details = None
|
||||
dashboard.json_metadata = "{}"
|
||||
dashboard.published = True
|
||||
dashboard.is_managed_externally = False
|
||||
dashboard.external_url = None
|
||||
dashboard.created_on = None
|
||||
dashboard.changed_on = None
|
||||
dashboard.created_on_humanized = None
|
||||
dashboard.changed_on_humanized = None
|
||||
dashboard.uuid = "94b826a5-dbd5-473d-ab58-1af676ee07e4"
|
||||
dashboard.url = "/dashboard/1"
|
||||
dashboard.slices = []
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
|
||||
embedded = Mock(spec=EmbeddedDashboard)
|
||||
embedded.uuid = "37c56048-d3f1-452d-b3ae-0879802dcb1f"
|
||||
dashboard.embedded = [embedded]
|
||||
|
||||
mock_find_object.return_value = dashboard
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_dashboard_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
assert result.data["uuid"] == "94b826a5-dbd5-473d-ab58-1af676ee07e4"
|
||||
assert result.data["embedded_uuid"] == "37c56048-d3f1-452d-b3ae-0879802dcb1f"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_dashboard_info_embedded_uuid_none_when_not_embedded(
|
||||
mock_find_object, mcp_server
|
||||
):
|
||||
"""Test that embedded_uuid is None when the dashboard has not been configured
|
||||
for embedding."""
|
||||
dashboard = Mock()
|
||||
dashboard.id = 2
|
||||
dashboard.dashboard_title = "Non-Embedded Dashboard"
|
||||
dashboard.slug = ""
|
||||
dashboard.description = None
|
||||
dashboard.css = None
|
||||
dashboard.certified_by = None
|
||||
dashboard.certification_details = None
|
||||
dashboard.json_metadata = "{}"
|
||||
dashboard.published = True
|
||||
dashboard.is_managed_externally = False
|
||||
dashboard.external_url = None
|
||||
dashboard.created_on = None
|
||||
dashboard.changed_on = None
|
||||
dashboard.created_on_humanized = None
|
||||
dashboard.changed_on_humanized = None
|
||||
dashboard.uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
dashboard.url = "/dashboard/2"
|
||||
dashboard.slices = []
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
|
||||
mock_find_object.return_value = dashboard
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_dashboard_info", {"request": {"identifier": 2}}
|
||||
)
|
||||
assert result.data.get("embedded_uuid") is None
|
||||
|
||||
|
||||
# TODO (Phase 3+): Add tests for get_dashboard_available_filters tool
|
||||
|
||||
|
||||
@@ -1044,6 +1135,7 @@ async def test_get_dashboard_info_by_uuid(mock_find_object, mcp_server):
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
|
||||
mock_find_object.return_value = dashboard
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -1083,6 +1175,7 @@ async def test_get_dashboard_info_by_slug(mock_find_object, mcp_server):
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
|
||||
mock_find_object.return_value = dashboard
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -1122,6 +1215,7 @@ async def test_list_dashboards_custom_uuid_slug_columns(mock_list, mcp_server):
|
||||
dashboard.external_url = None
|
||||
dashboard.thumbnail_url = None
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
dashboard._mapping = {
|
||||
"id": dashboard.id,
|
||||
@@ -1203,6 +1297,7 @@ async def test_list_dashboards_sanitizes_dashboard_descriptions_and_filter_text(
|
||||
dashboard.external_url = None
|
||||
dashboard.thumbnail_url = None
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
dashboard._mapping = {
|
||||
"id": dashboard.id,
|
||||
@@ -1343,6 +1438,7 @@ class TestDashboardDefaultColumnFiltering:
|
||||
dashboard.external_url = None
|
||||
dashboard.thumbnail_url = None
|
||||
dashboard.roles = []
|
||||
dashboard.embedded = []
|
||||
dashboard.charts = []
|
||||
mock_list.return_value = ([dashboard], 1)
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ def _mock_dashboard(
|
||||
dashboard.slices = []
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
dashboard.embedded = []
|
||||
return dashboard
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ class TestValidateExpression:
|
||||
self.table.schema = "test_schema"
|
||||
self.table.catalog = None
|
||||
self.table.database = MagicMock()
|
||||
self.table.database.backend = "sqlite"
|
||||
self.table.database.db_engine_spec = MagicMock()
|
||||
self.table.database.db_engine_spec.make_sqla_column_compatible = lambda x, _: x
|
||||
self.table.columns = []
|
||||
@@ -105,10 +106,8 @@ class TestValidateExpression:
|
||||
|
||||
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
|
||||
def test_validate_invalid_expression(self, mock_execute):
|
||||
"""Test validation of invalid SQL expressions"""
|
||||
# Mock _execute_validation_query to raise an exception
|
||||
mock_execute.side_effect = Exception("Invalid SQL syntax")
|
||||
|
||||
"""Unparseable SQL is rejected by the shared expression parser before the
|
||||
validation query is built or executed."""
|
||||
result = self.table.validate_expression(
|
||||
expression="INVALID SQL HERE",
|
||||
expression_type=SqlExpressionType.COLUMN,
|
||||
@@ -116,7 +115,8 @@ class TestValidateExpression:
|
||||
|
||||
assert result["valid"] is False
|
||||
assert len(result["errors"]) == 1
|
||||
assert "Invalid SQL syntax" in result["errors"][0]["message"]
|
||||
assert result["errors"][0]["message"]
|
||||
mock_execute.assert_not_called()
|
||||
|
||||
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
|
||||
def test_validate_having_with_non_aggregated_column(self, mock_execute):
|
||||
@@ -152,6 +152,38 @@ class TestValidateExpression:
|
||||
# The actual error message will come from the exception
|
||||
assert "empty" in result["errors"][0]["message"].lower()
|
||||
|
||||
@patch("superset.models.helpers.is_feature_enabled", return_value=False)
|
||||
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
|
||||
def test_validate_expression_rejects_subquery(
|
||||
self, mock_execute: MagicMock, mock_ff: MagicMock
|
||||
) -> None:
|
||||
"""A sub-query expression is rejected by the same validate_adhoc_subquery
|
||||
gate used for stored adhoc expressions, before any validation query is
|
||||
built or run (with ALLOW_ADHOC_SUBQUERY off, the default). Locks in that
|
||||
expression validation never executes the sub-query."""
|
||||
result = self.table.validate_expression(
|
||||
expression="(SELECT 1) IS NOT NULL OR 1 = 1",
|
||||
expression_type=SqlExpressionType.WHERE,
|
||||
)
|
||||
|
||||
assert result["valid"] is False
|
||||
mock_execute.assert_not_called()
|
||||
|
||||
@patch("superset.models.helpers.is_feature_enabled", return_value=False)
|
||||
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
|
||||
def test_validate_expression_rejects_set_operation(
|
||||
self, mock_execute: MagicMock, mock_ff: MagicMock
|
||||
) -> None:
|
||||
"""A set-operation expression is rejected before the validation query is
|
||||
built or run, matching the stored-adhoc-expression policy."""
|
||||
result = self.table.validate_expression(
|
||||
expression="1 UNION SELECT 1",
|
||||
expression_type=SqlExpressionType.WHERE,
|
||||
)
|
||||
|
||||
assert result["valid"] is False
|
||||
mock_execute.assert_not_called()
|
||||
|
||||
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
|
||||
def test_validate_expression_with_rls(self, mock_execute):
|
||||
"""Test that RLS filters are applied during validation"""
|
||||
|
||||
@@ -38,3 +38,31 @@ def test_aggregate():
|
||||
assert series_to_list(df["asc sum"])[0] == 5050
|
||||
assert series_to_list(df["asc q2"])[0] == 75
|
||||
assert series_to_list(df["desc q1"])[0] == 25
|
||||
|
||||
|
||||
def test_aggregate_string_operators():
|
||||
"""mean, median, and other operators in _PANDAS_STRING_AGGREGATORS use the
|
||||
pandas string path; verify results match expected values on asc_idx [0..100]."""
|
||||
aggregates = {
|
||||
"asc mean": {"column": "asc_idx", "operator": "mean"},
|
||||
"asc median": {"column": "asc_idx", "operator": "median"},
|
||||
"asc max": {"column": "asc_idx", "operator": "max"},
|
||||
"asc min": {"column": "asc_idx", "operator": "min"},
|
||||
}
|
||||
df = aggregate(df=categories_df, groupby=["constant"], aggregates=aggregates)
|
||||
assert series_to_list(df["asc mean"])[0] == 50.0
|
||||
assert series_to_list(df["asc median"])[0] == 50.0
|
||||
assert series_to_list(df["asc max"])[0] == 100
|
||||
assert series_to_list(df["asc min"])[0] == 0
|
||||
|
||||
|
||||
def test_aggregate_count_includes_nulls():
|
||||
"""'count' operator uses np.ma.count, which counts all rows including NaN.
|
||||
It is intentionally excluded from _PANDAS_STRING_AGGREGATORS to preserve this
|
||||
behavior (pandas SeriesGroupBy.count excludes NaN)."""
|
||||
aggregates = {
|
||||
"null_count": {"column": "idx_nulls", "operator": "count"},
|
||||
}
|
||||
df = aggregate(df=categories_df, groupby=["constant"], aggregates=aggregates)
|
||||
# idx_nulls has 101 rows total; np.ma.count returns all 101 (NaN included)
|
||||
assert series_to_list(df["null_count"])[0] == 101
|
||||
|
||||
Reference in New Issue
Block a user