Compare commits

29 Commits

Author SHA1 Message Date
Evan
563edfdd9f fix(databricks): tighten 401 oauth2 signal to avoid false positives
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
98ba9da18c test(databricks): add docstring to mock database helper
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
910e79dd8c fix(databricks): correct OAuth2 trigger and derive endpoints from workspace host
Addresses review feedback on the Databricks OAuth2 flow:

- `oauth2_exception` was set to `OAuth2RedirectError` (Superset's own redirect
  signal, also the base default), so `needs_oauth2()` never matched a real
  Databricks token failure and the dance never auto-started. The driver has no
  dedicated auth exception, so detect auth failures from the error message
  instead (mirrors `GSheetsEngineSpec.needs_oauth2`).

- The per-cloud endpoint templates pointed Azure at Entra ID directly
  (`login.microsoftonline.com`) and required `account_id`/`tenant_id`
  substitution. Databricks fronts the U2M flow on every workspace at
  `https://<host>/oidc/v1/{authorize,token}` across AWS/Azure/GCP, so the
  authorization endpoint now derives from the workspace host with no account
  identifier. The token endpoint still requires explicit config (no DB context
  at exchange time); the error and docs now point at the workspace-host URL.

Shared OAuth logic is consolidated onto `DatabricksDynamicBaseEngineSpec`,
removing the duplicated overrides in both engine specs.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
6251280aa5 fix(databricks): guard non-string cloud_provider before .lower()
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan Rusackas
9d0a2209c6 test(databricks): cover invalid cloud_provider fallback to hostname
The only uncovered branch in `_detect_cloud_provider`: an unrecognized explicit
`cloud_provider` should be ignored and detection should fall back to hostname
sniffing rather than returning the bad value.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
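The fallback this test covers can be sketched roughly as below. This is an illustration only, not the code in the PR: the function name `detect_cloud_provider`, the `VALID_PROVIDERS` set, and the hostname suffixes used for sniffing are all assumptions made for the example. It also shows the non-string guard mentioned in 6251280aa5.

```python
from typing import Any

# Assumed provider names, for illustration only.
VALID_PROVIDERS = {"aws", "azure", "gcp"}


def detect_cloud_provider(hostname: str, cloud_provider: Any = None) -> str:
    """Return a valid explicit provider, otherwise sniff the hostname."""
    # Guard non-string values before calling .lower(); an unrecognized
    # explicit value is ignored rather than returned.
    if isinstance(cloud_provider, str):
        explicit = cloud_provider.lower()
        if explicit in VALID_PROVIDERS:
            return explicit

    # Hostname sniffing (assumed suffixes).
    host = hostname.lower()
    if host.endswith(".azuredatabricks.net"):
        return "azure"
    if host.endswith(".gcp.databricks.com"):
        return "gcp"
    # Assumption: anything else is treated as AWS.
    return "aws"
```

With this sketch, `detect_cloud_provider("adb-1.2.azuredatabricks.net", "oracle")` ignores the bad explicit value and resolves from the hostname instead.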
Evan
cb7d5c8847 docs(databricks): clarify token endpoint is not auto-detected
The authorization endpoint auto-resolves from the hostname, but the token
exchange has no database context, so token_request_uri must be supplied for
the auto-detected flow. Docs implied both endpoints auto-detect.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
fe69d222bd test(databricks): docstring the shared OAuth2 state helper
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
8e0871bb2e test(databricks): exercise provider detection without pre-set OAuth2 URI
The multi-cloud OAuth2 URI tests passed a config with a fully-resolved
authorization_request_uri, which the engine spec now preserves. Drop the
URI for the Azure/GCP detection cases (and give those mock databases an
account_id/tenant_id) so the per-provider endpoint is actually resolved.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan Rusackas
ea1ba587f2 fix(databricks): resolve account_id in OAuth2 endpoints, preserve configured URIs
The per-cloud OAuth2 endpoint templates carry a `{}` placeholder for the
Databricks account id (or Azure tenant id) that was never substituted, so
auto-detected authorize/token URLs were emitted as `.../accounts/{}/v1/...`.
The authorization-URI methods also unconditionally overwrote a fully-resolved
`authorization_request_uri` supplied via DATABASE_OAUTH2_CLIENTS.

- Add `_resolve_oauth2_endpoint`: substitutes `account_id`/`tenant_id` from the
  database extra into the template, or raises OAuth2Error when absent instead of
  issuing a request to an unresolved endpoint.
- Preserve a configured `authorization_request_uri`; only auto-detect/resolve
  when none is set.
- `get_oauth2_token` has no database context to auto-detect, so fail fast on a
  missing `token_request_uri` rather than POST to `.../{}/v1/token`.
- Cover auto-detect/resolve, preserve-configured, and fail-fast paths for both
  the native and Python-connector specs; document `account_id`/`tenant_id`.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
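A minimal sketch of the substitute-or-fail behaviour described for `_resolve_oauth2_endpoint`, assuming the template carries a single `{}` placeholder and the database extra is a plain dict. The function signature and the local `OAuth2Error` class are stand-ins so the example runs on its own; they are not Superset's actual definitions.

```python
class OAuth2Error(Exception):
    """Local stand-in for Superset's OAuth2Error (assumption for this sketch)."""


def resolve_oauth2_endpoint(template: str, extra: dict) -> str:
    """Fill the `{}` placeholder from account_id/tenant_id, or fail fast."""
    if "{}" not in template:
        # Already fully resolved; return unchanged.
        return template

    identifier = extra.get("account_id") or extra.get("tenant_id")
    if not identifier:
        # Raise instead of issuing a request to an unresolved endpoint.
        raise OAuth2Error(
            "account_id or tenant_id is required to resolve the OAuth2 endpoint"
        )
    return template.replace("{}", str(identifier))
```

Raising early keeps a malformed URL such as `.../accounts/{}/v1/token` from ever reaching the network layer.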
Evan
dc99373579 test(databricks): add return/param type annotations to multi-cloud OAuth fixtures
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
Evan
8737f010f3 fix(databricks): preserve resolved OAuth2 token request URI
get_oauth2_token clobbered the config's already-resolved token_request_uri
with the AWS template that still contained an unsubstituted account-id
placeholder, so the token exchange POSTed to .../accounts/{}/v1/token. Only
fall back to the AWS endpoint when no token_request_uri is configured.

Co-authored-by: fabian_zse <fabian@zalando.de>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 21:33:16 -07:00
fabian_zse
278cfbb694 cloud providers test
2026-06-27 21:33:16 -07:00
fabian_zse
1de21ec5c6 support all cloud providers
2026-06-27 21:33:15 -07:00
Fabian Halkivaha
2ed41ae8a6 fix docs slightly
2026-06-27 21:33:15 -07:00
fabian_zse
a0fdb2aa31 add databricks oauth support
2026-06-27 21:33:15 -07:00
ʈᵃᵢ
25c9f3510a test(mcp): set embedded on update_dashboard test mock (#41495)
2026-06-28 11:19:01 +07:00
Đỗ Trọng Hải
b8fd2e9725 feat(websocket,embedded-sdk): replace Jest with modern Vitest (#38308)
Signed-off-by: hainenber <dotronghai96@gmail.com>
Co-authored-by: codeant-ai-for-open-source[bot] <244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
2026-06-28 11:12:37 +07:00
Evan Rusackas
78dd400ca4 chore(ci): correct actions/cache version comment to match pinned SHA (#41483)
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-28 10:56:33 +07:00
Evan Rusackas
7587d0778a chore(ci): correct actions/cache version comment to v5.0.5 (#41484)
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-28 10:56:08 +07:00
Elizabeth Thompson
97cb002f46 fix(a11y): propagate tooltip string as aria-label on IconTooltip button (#41493)
2026-06-27 15:01:46 -07:00
Elizabeth Thompson
5ec0931840 fix(pandas_postprocessing): pass string operator names to GroupBy.agg to avoid FutureWarning (#41025)
2026-06-27 15:01:43 -07:00
Elizabeth Thompson
3eb9185521 fix(viz): use series_limit/series_limit_metric in query_obj dict (#41002)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-27 15:01:40 -07:00
Shaitan
cd8ac41d16 fix(datasource): validate expressions through the shared adhoc-expression checks (#41427)
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 19:47:59 +01:00
Evan Rusackas
21999bb772 fix(i18n): repair corrupted Romanian catalog so it parses again (#41467)
Co-authored-by: Claude Code <noreply@anthropic.com>
2026-06-27 09:13:39 -04:00
innovark
0a18779280 fix(echarts): format mixed timeseries value labels by assigned axis (#40420)
Co-authored-by: Evan Rusackas <evan@preset.io>
2026-06-27 01:27:41 -07:00
Krishna Chaitanya
a147079043 fix(bigquery): backslash-escape apostrophes in filter values (#38835)
BigQuery rejects filter values containing apostrophes (e.g. O'Brien): the
sqlalchemy-bigquery dialect renders string literals via repr(), which switches
to double-quote delimiters that BigQuery parses as identifiers, causing a
syntax error.

Monkey-patch the dialect's colspecs with a TypeDecorator whose literal_processor
emits single-quoted literals using backslash escaping ('O\'Brien'). Doubled
single quotes ('O''Brien') are NOT valid in BigQuery (parsed as concatenated
literals). Control characters are emitted as named escapes with a \xhh fallback,
since BigQuery forbids literal control chars in quoted strings. Follows the
existing Databricks dialect pattern.

Fixes #35857
2026-06-27 00:55:47 -07:00
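The escaping rule in this commit message can be illustrated with a small standalone function. This is a sketch of the technique only, not the `TypeDecorator`/`literal_processor` from the PR; the name `quote_bigquery_literal` and the exact set of named escapes are assumptions.

```python
# Named escapes for common control characters (assumed subset);
# any other control character falls back to \xhh.
_NAMED_ESCAPES = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}


def quote_bigquery_literal(value: str) -> str:
    """Render a single-quoted string literal using backslash escaping."""
    parts = []
    for ch in value:
        if ch == "\\":
            parts.append("\\\\")  # escape the escape character itself
        elif ch == "'":
            parts.append("\\'")  # backslash-escape, never double the quote
        elif ch in _NAMED_ESCAPES:
            parts.append(_NAMED_ESCAPES[ch])
        elif ord(ch) < 0x20 or ord(ch) == 0x7F:
            parts.append(f"\\x{ord(ch):02x}")  # \xhh fallback
        else:
            parts.append(ch)
    return "'" + "".join(parts) + "'"


print(quote_bigquery_literal("O'Brien"))  # prints 'O\'Brien'
```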
Abdul Rehman
ebb32de625 fix(cachekey): use data_cache for chart query result invalidation (#40493)
2026-06-26 18:01:14 -07:00
Onur Taşhan
1280eaee18 fix(mcp): include embedded_uuid in get_dashboard_info response (#41195)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-26 18:00:10 -07:00
jesperct
15626a047c fix(sqllab): quote autocomplete table names that need it (#41199)
2026-06-26 17:58:05 -07:00
48 changed files with 4832 additions and 19005 deletions

View File

@@ -63,7 +63,7 @@ jobs:
yarn install --immutable
- name: Cache pre-commit environments
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/.cache/pre-commit
key: pre-commit-v2-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('.pre-commit-config.yaml') }}

View File

@@ -56,7 +56,7 @@ jobs:
- name: Cache npm
if: env.HAS_TAGS
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/.npm # npm cache files are stored in `~/.npm` on Linux/macOS
key: ${{ runner.OS }}-node-${{ hashFiles('**/package-lock.json') }}
@@ -70,7 +70,7 @@ jobs:
run: echo "dir=$(npm config get cache)" >> $GITHUB_OUTPUT
- name: Cache npm
if: env.HAS_TAGS
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: npm-cache # use this to check for `cache-hit` (`steps.npm-cache.outputs.cache-hit != 'true'`)
with:
path: ${{ steps.npm-cache-dir-path.outputs.dir }}

View File

@@ -28,6 +28,10 @@ jobs:
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
persist-credentials: false
- name: Setup Node.js
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
node-version-file: './superset-websocket/.nvmrc'
- name: Install dependencies
working-directory: ./superset-websocket
run: npm ci

View File

@@ -519,6 +519,80 @@ For a connection to a SQL endpoint you need to use the HTTP path from the endpoi
{"connect_args": {"http_path": "/sql/1.0/endpoints/****", "driver_path": "/path/to/odbc/driver"}}
```
##### OAuth2 Authentication
Superset supports OAuth2 authentication for Databricks, allowing users to authenticate with their personal Databricks accounts instead of using shared access tokens. Because connections then use each user's own token, this improves security and makes activity easier to audit.
###### Prerequisites
1. Create an OAuth2 application in your Databricks account:
- Go to your Databricks account console
- Navigate to **Settings** → **Developer** → **OAuth apps**
- Create a new OAuth app with the redirect URI: `http://your-superset-host:port/api/v1/database/oauth2/`
2. Configure OAuth2 in your `superset_config.py`:
```python
from datetime import timedelta
# OAuth2 configuration for Databricks
# The authorization endpoint is derived from your Databricks workspace host; the
# token endpoint must be set explicitly (see notes below).
DATABASE_OAUTH2_CLIENTS = {
"Databricks (legacy)": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
},
"Databricks": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
},
}
# OAuth2 redirect URI (adjust hostname/port for your setup)
DATABASE_OAUTH2_REDIRECT_URI = "http://your-superset-host:port/api/v1/database/oauth2/"
# Optional: OAuth2 timeout
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
```
Replace the following placeholders:
- `your-databricks-client-id`: Your Databricks OAuth2 application client ID
- `your-databricks-client-secret`: Your Databricks OAuth2 application client secret
- `your-superset-host:port`: Your Superset instance hostname and port
**Multi-Cloud Provider Support**
Databricks fronts the user-to-machine (U2M) OAuth2 flow on every workspace at
`https://<workspace-host>/oidc/v1/authorize` and
`https://<workspace-host>/oidc/v1/token`, regardless of whether the workspace
runs on AWS, Azure, or GCP. Superset derives the **authorization** endpoint
directly from your connection's host, so no cloud provider or account/tenant
identifier needs to be configured.
The **token** endpoint cannot be auto-derived (token exchange has no database
context to read the host), so you must supply `token_request_uri` in
`DATABASE_OAUTH2_CLIENTS`, set to `https://<workspace-host>/oidc/v1/token` for
your workspace.
If you supply a fully-resolved `authorization_request_uri` (and/or
`token_request_uri`), those values take precedence over the host-derived
defaults.
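Since the token endpoint has to be written out for each workspace, a small helper in `superset_config.py` can build the client entry from the workspace host. The sketch below is one possible way to do that; the helper name `databricks_oauth2_client` and the example host are hypothetical, and it assumes the host is passed without a scheme.

```python
def databricks_oauth2_client(
    workspace_host: str, client_id: str, client_secret: str
) -> dict:
    """Build a DATABASE_OAUTH2_CLIENTS entry for one Databricks workspace."""
    host = workspace_host.strip().rstrip("/")
    return {
        "id": client_id,
        "secret": client_secret,
        "scope": "sql",
        # The token endpoint is not auto-derived, so set it explicitly.
        "token_request_uri": f"https://{host}/oidc/v1/token",
    }


# Hypothetical host and credentials, for illustration only.
_client = databricks_oauth2_client(
    "your-workspace-host", "your-databricks-client-id", "your-databricks-client-secret"
)
DATABASE_OAUTH2_CLIENTS = {
    "Databricks (legacy)": _client,
    "Databricks": _client,
}
```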
###### Usage
Once configured, the flow works as follows:
1. Users connect to Databricks databases normally using access tokens.
2. When a query needs authentication, Superset automatically redirects the user to authenticate with Databricks.
3. The user's own OAuth2 token is then used for database connections, providing better security and audit trails.
This feature works with both the "Databricks (legacy)" and "Databricks" engine types and supports workspaces on all major cloud providers (AWS, Azure, GCP).
#### Denodo
The recommended connector library for Denodo is

View File

@@ -32,6 +32,7 @@ and therefore are not easily unit-testable. We have instead opted to test the sd
This way, the tests can assert that the sdk actually mounts the iframe and communicates with it correctly.
At time of writing, these tests are not written yet, because we haven't yet put together the demo app that they will leverage.
### Things to e2e test once we have a demo app:
**happy path:**

View File

@@ -41,12 +41,12 @@ npm install --save @superset-ui/embedded-sdk
```
```js
import { embedDashboard } from '@superset-ui/embedded-sdk';
import { embedDashboard } from "@superset-ui/embedded-sdk";
embedDashboard({
id: 'abc123', // given by the Superset embedding UI
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'), // any html element that can contain an iframe
id: "abc123", // given by the Superset embedding UI
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"), // any html element that can contain an iframe
fetchGuestToken: () => fetchGuestTokenFromBackend(),
dashboardUiConfig: {
// dashboard UI config: hideTitle, hideTab, hideChartControls, filters.visible, filters.expanded (optional), urlParams (optional)
@@ -55,21 +55,21 @@ embedDashboard({
expanded: true,
},
urlParams: {
foo: 'value1',
bar: 'value2',
foo: "value1",
bar: "value2",
// themeMode: 'dark', // set the initial theme: 'dark' | 'system' | 'default' (default: 'default')
// ...
},
},
// optional additional iframe sandbox attributes
iframeSandboxExtras: [
'allow-top-navigation',
'allow-popups-to-escape-sandbox',
"allow-top-navigation",
"allow-popups-to-escape-sandbox",
],
// optional Permissions Policy features
iframeAllowExtras: ['clipboard-write', 'fullscreen'],
iframeAllowExtras: ["clipboard-write", "fullscreen"],
// optional config to enforce a particular referrerPolicy
referrerPolicy: 'same-origin',
referrerPolicy: "same-origin",
// optional callback to customize permalink URLs
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
});
@@ -163,13 +163,13 @@ Use the `themeMode` URL parameter to control the embedded dashboard's initial co
```js
embedDashboard({
id: 'abc123',
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'),
id: "abc123",
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"),
fetchGuestToken: () => fetchGuestTokenFromBackend(),
dashboardUiConfig: {
urlParams: {
themeMode: 'dark', // 'dark' | 'system' | 'default' (default: 'default')
themeMode: "dark", // 'dark' | 'system' | 'default' (default: 'default')
},
},
});
@@ -193,7 +193,7 @@ To pass additional sandbox attributes you can use `iframeSandboxExtras`:
```js
// optional additional iframe sandbox attributes
iframeSandboxExtras: ['allow-top-navigation', 'allow-popups-to-escape-sandbox'];
iframeSandboxExtras: ["allow-top-navigation", "allow-popups-to-escape-sandbox"];
```
### Permissions Policy
@@ -202,7 +202,7 @@ To enable specific browser features within the embedded iframe, use `iframeAllow
```js
// optional Permissions Policy features
iframeAllowExtras: ['clipboard-write', 'fullscreen'];
iframeAllowExtras: ["clipboard-write", "fullscreen"];
```
Common permissions you might need:
@@ -225,9 +225,9 @@ When users click share buttons inside an embedded dashboard, Superset generates
```js
embedDashboard({
id: 'abc123',
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'),
id: "abc123",
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"),
fetchGuestToken: () => fetchGuestTokenFromBackend(),
// Customize permalink URLs
@@ -245,9 +245,9 @@ To restore the dashboard state from a permalink in your app:
const permalinkKey = routeParams.key;
embedDashboard({
id: 'abc123',
supersetDomain: 'https://superset.example.com',
mountPoint: document.getElementById('my-superset-container'),
id: "abc123",
supersetDomain: "https://superset.example.com",
mountPoint: document.getElementById("my-superset-container"),
fetchGuestToken: () => fetchGuestTokenFromBackend(),
resolvePermalinkUrl: ({ key }) => `https://my-app.com/analytics/share/${key}`,
dashboardUiConfig: {

View File

@@ -18,9 +18,6 @@
*/
module.exports = {
presets: [
"@babel/preset-typescript",
"@babel/preset-env"
],
presets: ["@babel/preset-typescript", "@babel/preset-env"],
sourceMaps: true,
};

File diff suppressed because it is too large

View File

@@ -24,7 +24,7 @@
"scripts": {
"build": "tsc && babel src --out-dir lib --extensions '.ts,.tsx' && webpack --mode production",
"ci:release": "node ./release-if-necessary.js",
"test": "jest"
"test": "vitest --run --dir src"
},
"browserslist": [
"last 3 chrome versions",
@@ -41,12 +41,11 @@
"@babel/core": "^7.25.2",
"@babel/preset-env": "^7.25.4",
"@babel/preset-typescript": "^7.24.7",
"@types/jest": "^29.5.12",
"@types/node": "^22.5.4",
"@types/node": "^25.4.0",
"babel-loader": "^9.1.3",
"jest": "^29.7.0",
"tscw-config": "^1.1.2",
"typescript": "^5.6.2",
"typescript": "^5.9.3",
"vitest": "^4.0.18",
"webpack": "^5.94.0",
"webpack-cli": "^5.1.4"
},

View File

@@ -17,15 +17,15 @@
* under the License.
*/
const { execSync } = require('child_process');
const { name, version } = require('./package.json');
const { execSync } = require("child_process");
const { name, version } = require("./package.json");
function log(...args) {
console.log('[embedded-sdk-release]', ...args);
console.log("[embedded-sdk-release]", ...args);
}
function logError(...args) {
console.error('[embedded-sdk-release]', ...args);
console.error("[embedded-sdk-release]", ...args);
}
(async () => {
@@ -38,13 +38,13 @@ function logError(...args) {
const { status } = await fetch(packageUrl);
if (status === 200) {
log('version already exists on npm, exiting');
log("version already exists on npm, exiting");
} else if (status === 404) {
log('release required, building');
log("release required, building");
try {
execSync('npm run build', { stdio: 'pipe' });
log('build successful, publishing')
execSync('npm publish --access public', { stdio: 'pipe' });
execSync("npm run build", { stdio: "pipe" });
log("build successful, publishing");
execSync("npm publish --access public", { stdio: "pipe" });
log(`published ${version} to npm`);
} catch (err) {
// npm writes failure details to stderr (auth/permission/registry
@@ -52,7 +52,7 @@ function logError(...args) {
// the real cause in CI logs.
if (err.stdout) console.error(String(err.stdout));
if (err.stderr) console.error(String(err.stderr));
logError('Encountered an error, details should be above');
logError("Encountered an error, details should be above");
process.exitCode = 1;
}
} else {

View File

@@ -18,7 +18,9 @@
*/
export const IFRAME_COMMS_MESSAGE_TYPE = "__embedded_comms__";
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: { [index: string]: any } = {
export const DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY: {
[index: string]: any;
} = {
visible: "show_filters",
expanded: "expand_filters",
}
};

View File

@@ -24,22 +24,23 @@ import {
DEFAULT_TOKEN_EXP_MS,
DEFAULT_TOKEN_REFRESH_RETRY_MS,
} from "./guestTokenRefresh";
import { afterAll, beforeAll, it, expect, describe, vi } from "vitest";
describe("guest token refresh", () => {
beforeAll(() => {
jest.useFakeTimers();
jest.setSystemTime(new Date("2022-03-03 01:00"));
jest.spyOn(global, "setTimeout");
vi.useFakeTimers();
vi.setSystemTime(new Date("2022-03-03 01:00"));
vi.spyOn(globalThis, "setTimeout");
});
afterAll(() => {
jest.useRealTimers();
vi.useRealTimers();
});
function makeFakeJWT(claims: any) {
// not a valid jwt, but close enough for this code
const tokenifiedClaims = Buffer.from(JSON.stringify(claims)).toString(
"base64"
"base64",
);
return `abc.${tokenifiedClaims}.xyz`;
}

View File

@@ -18,17 +18,23 @@
*/
import { jwtDecode } from "jwt-decode";
export const REFRESH_TIMING_BUFFER_MS = 5000 // refresh guest token early to avoid failed superset requests
export const MIN_REFRESH_WAIT_MS = 10000 // avoid blasting requests as fast as the cpu can handle
export const DEFAULT_TOKEN_EXP_MS = 300000 // (5 min) used only when parsing guest token exp fails
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000 // wait before retrying a failed/timed-out token refresh
export const REFRESH_TIMING_BUFFER_MS = 5000; // refresh guest token early to avoid failed superset requests
export const MIN_REFRESH_WAIT_MS = 10000; // avoid blasting requests as fast as the cpu can handle
export const DEFAULT_TOKEN_EXP_MS = 300000; // (5 min) used only when parsing guest token exp fails
export const DEFAULT_TOKEN_REFRESH_RETRY_MS = 10000; // wait before retrying a failed/timed-out token refresh
// when do we refresh the guest token?
export function getGuestTokenRefreshTiming(currentGuestToken: string) {
const parsedJwt = jwtDecode<Record<string, any>>(currentGuestToken);
// if exp is int, it is in seconds, but Date() takes milliseconds
const exp = new Date(/[^0-9\.]/g.test(parsedJwt.exp) ? parsedJwt.exp : parseFloat(parsedJwt.exp) * 1000);
const isValidDate = exp.toString() !== 'Invalid Date';
const ttl = isValidDate ? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now()) : DEFAULT_TOKEN_EXP_MS;
const exp = new Date(
/[^0-9\.]/g.test(parsedJwt.exp)
? parsedJwt.exp
: parseFloat(parsedJwt.exp) * 1000,
);
const isValidDate = exp.toString() !== "Invalid Date";
const ttl = isValidDate
? Math.max(MIN_REFRESH_WAIT_MS, exp.getTime() - Date.now())
: DEFAULT_TOKEN_EXP_MS;
return ttl - REFRESH_TIMING_BUFFER_MS;
}

View File

@@ -20,15 +20,15 @@
import {
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY,
IFRAME_COMMS_MESSAGE_TYPE,
} from './const';
} from "./const";
// We can swap this out for the actual switchboard package once it gets published
import { Switchboard } from '@superset-ui/switchboard';
import { Switchboard } from "@superset-ui/switchboard";
import {
getGuestTokenRefreshTiming,
DEFAULT_TOKEN_REFRESH_RETRY_MS,
} from './guestTokenRefresh';
import { withTimeout } from './withTimeout';
} from "./guestTokenRefresh";
import { withTimeout } from "./withTimeout";
/**
* The function to fetch a guest token from your Host App's backend server.
@@ -97,7 +97,7 @@ export type ObserveDataMaskCallbackFn = (
nativeFiltersChanged: boolean;
},
) => void;
export type ThemeMode = 'default' | 'dark' | 'system';
export type ThemeMode = "default" | "dark" | "system";
/**
* Callback to resolve permalink URLs.
@@ -113,12 +113,12 @@ export type EmbeddedDashboard = {
unmount: () => void;
getDashboardPermalink: (anchor: string) => Promise<string>;
getActiveTabs: () => Promise<string[]>;
observeDataMask: (
callbackFn: ObserveDataMaskCallbackFn,
) => void;
observeDataMask: (callbackFn: ObserveDataMaskCallbackFn) => void;
getDataMask: () => Promise<Record<string, any>>;
getChartStates: () => Promise<Record<string, any>>;
getChartDataPayloads: (params?: { chartId?: number }) => Promise<Record<string, any>>;
getChartDataPayloads: (params?: {
chartId?: number;
}) => Promise<Record<string, any>>;
setThemeConfig: (themeConfig: Record<string, any>) => void;
setThemeMode: (mode: ThemeMode) => void;
};
@@ -133,7 +133,7 @@ export async function embedDashboard({
fetchGuestToken,
dashboardUiConfig,
debug = false,
iframeTitle = 'Embedded Dashboard',
iframeTitle = "Embedded Dashboard",
iframeSandboxExtras = [],
iframeAllowExtras = [],
referrerPolicy,
@@ -152,13 +152,13 @@ export async function embedDashboard({
return withTimeout(
fetchGuestToken(),
guestTokenFetchTimeoutMs,
'fetchGuestToken',
"fetchGuestToken",
);
}
log('embedding');
log("embedding");
if (supersetDomain.endsWith('/')) {
if (supersetDomain.endsWith("/")) {
supersetDomain = supersetDomain.slice(0, -1);
}
@@ -185,15 +185,15 @@ export async function embedDashboard({
}
async function mountIframe(): Promise<Switchboard> {
return new Promise(resolve => {
const iframe = document.createElement('iframe');
return new Promise((resolve) => {
const iframe = document.createElement("iframe");
const dashboardConfigUrlParams = dashboardUiConfig
? { uiConfig: `${calculateConfig()}` }
: undefined;
const filterConfig = dashboardUiConfig?.filters || {};
const filterConfigKeys = Object.keys(filterConfig);
const filterConfigUrlParams = Object.fromEntries(
filterConfigKeys.map(key => [
filterConfigKeys.map((key) => [
DASHBOARD_UI_FILTER_CONFIG_URL_PARAM_KEY[key],
filterConfig[key],
]),
@@ -206,16 +206,16 @@ export async function embedDashboard({
...dashboardUiConfig?.urlParams,
};
const urlParamsString = Object.keys(urlParams).length
? '?' + new URLSearchParams(urlParams).toString()
: '';
? "?" + new URLSearchParams(urlParams).toString()
: "";
// set up the iframe's sandbox configuration
iframe.sandbox.add('allow-same-origin'); // needed for postMessage to work
iframe.sandbox.add('allow-scripts'); // obviously the iframe needs scripts
iframe.sandbox.add('allow-presentation'); // for fullscreen charts
iframe.sandbox.add('allow-downloads'); // for downloading charts as image
iframe.sandbox.add('allow-forms'); // for forms to submit
iframe.sandbox.add('allow-popups'); // for exporting charts as csv
iframe.sandbox.add("allow-same-origin"); // needed for postMessage to work
iframe.sandbox.add("allow-scripts"); // obviously the iframe needs scripts
iframe.sandbox.add("allow-presentation"); // for fullscreen charts
iframe.sandbox.add("allow-downloads"); // for downloading charts as image
iframe.sandbox.add("allow-forms"); // for forms to submit
iframe.sandbox.add("allow-popups"); // for exporting charts as csv
// additional sandbox props
iframeSandboxExtras.forEach((key: string) => {
iframe.sandbox.add(key);
@@ -226,7 +226,7 @@ export async function embedDashboard({
}
// add the event listener before setting src, to be 100% sure that we capture the load event
iframe.addEventListener('load', () => {
iframe.addEventListener("load", () => {
// MessageChannel allows us to send and receive messages smoothly between our window and the iframe
// See https://developer.mozilla.org/en-US/docs/Web/API/Channel_Messaging_API
const commsChannel = new MessageChannel();
@@ -237,35 +237,35 @@ export async function embedDashboard({
// See https://developer.mozilla.org/en-US/docs/Web/API/Window/postMessage
// we know the content window isn't null because we are in the load event handler.
iframe.contentWindow!.postMessage(
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: 'port transfer' },
{ type: IFRAME_COMMS_MESSAGE_TYPE, handshake: "port transfer" },
supersetDomain,
[theirPort],
);
log('sent message channel to the iframe');
log("sent message channel to the iframe");
// return our port from the promise
resolve(
new Switchboard({
port: ourPort,
name: 'superset-embedded-sdk',
name: "superset-embedded-sdk",
debug,
}),
);
});
iframe.src = `${supersetDomain}/embedded/${id}${urlParamsString}`;
iframe.title = iframeTitle;
iframe.style.background = 'transparent';
iframe.style.background = "transparent";
// Permissions Policy features the embedded dashboard relies on. Modern
// browsers gate these APIs on the iframe's `allow` attribute regardless
// of sandbox flags, so we include them by default. Host apps can extend
// the list via `iframeAllowExtras`.
const allowFeatures = Array.from(
new Set(['fullscreen', 'clipboard-write', ...iframeAllowExtras]),
new Set(["fullscreen", "clipboard-write", ...iframeAllowExtras]),
);
iframe.setAttribute('allow', allowFeatures.join('; '));
iframe.setAttribute("allow", allowFeatures.join("; "));
//@ts-ignore
mountPoint.replaceChildren(iframe);
log('placed the iframe');
log("placed the iframe");
});
}
@@ -285,8 +285,8 @@ export async function embedDashboard({
throw err;
}
ourPort.emit('guestToken', { guestToken });
log('sent guest token');
ourPort.emit("guestToken", { guestToken });
log("sent guest token");
// Track the pending refresh timer so it can be cancelled on unmount, and
// stop the cycle once unmounted so it cannot leak across mount/unmount cycles.
@@ -298,7 +298,7 @@ export async function embedDashboard({
try {
const newGuestToken = await fetchGuestTokenWithTimeout();
if (unmounted) return;
ourPort.emit('guestToken', { guestToken: newGuestToken });
ourPort.emit("guestToken", { guestToken: newGuestToken });
refreshTimer = setTimeout(
refreshGuestToken,
getGuestTokenRefreshTiming(newGuestToken),
@@ -307,7 +307,7 @@ export async function embedDashboard({
// A transient fetch failure or timeout must not permanently stop the
// refresh cycle. Log it and retry so the session can recover once the
// host callback succeeds again.
log('failed to refresh guest token, will retry:', err);
log("failed to refresh guest token, will retry:", err);
if (unmounted) return;
refreshTimer = setTimeout(
refreshGuestToken,
@@ -325,7 +325,7 @@ export async function embedDashboard({
// Returns null if no callback provided or on error, allowing iframe to use default URL
ourPort.start();
ourPort.defineMethod(
'resolvePermalinkUrl',
"resolvePermalinkUrl",
async ({ key }: { key: string }): Promise<string | null> => {
if (!resolvePermalinkUrl) {
return null;
@@ -333,14 +333,14 @@ export async function embedDashboard({
try {
return await resolvePermalinkUrl({ key });
} catch (error) {
log('Error in resolvePermalinkUrl callback:', error);
log("Error in resolvePermalinkUrl callback:", error);
return null;
}
},
);
function unmount() {
log('unmounting');
log("unmounting");
unmounted = true;
if (refreshTimer !== undefined) {
clearTimeout(refreshTimer);
@@ -350,24 +350,25 @@ export async function embedDashboard({
mountPoint.replaceChildren();
}
const getScrollSize = () => ourPort.get<Size>('getScrollSize');
const getScrollSize = () => ourPort.get<Size>("getScrollSize");
const getDashboardPermalink = (anchor: string) =>
ourPort.get<string>('getDashboardPermalink', { anchor });
const getActiveTabs = () => ourPort.get<string[]>('getActiveTabs');
const getDataMask = () => ourPort.get<Record<string, any>>('getDataMask');
const getChartStates = () => ourPort.get<Record<string, any>>('getChartStates');
ourPort.get<string>("getDashboardPermalink", { anchor });
const getActiveTabs = () => ourPort.get<string[]>("getActiveTabs");
const getDataMask = () => ourPort.get<Record<string, any>>("getDataMask");
const getChartStates = () =>
ourPort.get<Record<string, any>>("getChartStates");
const getChartDataPayloads = (params?: { chartId?: number }) =>
ourPort.get<Record<string, any>>('getChartDataPayloads', params);
const observeDataMask = (
callbackFn: ObserveDataMaskCallbackFn,
) => {
ourPort.defineMethod('observeDataMask', callbackFn);
ourPort.get<Record<string, any>>("getChartDataPayloads", params);
const observeDataMask = (callbackFn: ObserveDataMaskCallbackFn) => {
ourPort.defineMethod("observeDataMask", callbackFn);
};
// TODO: Add proper types once theming branch is merged
const setThemeConfig = async (themeConfig: Record<string, any>): Promise<void> => {
const setThemeConfig = async (
themeConfig: Record<string, any>,
): Promise<void> => {
try {
ourPort.emit('setThemeConfig', { themeConfig });
log('Theme config sent successfully (or at least message dispatched)');
ourPort.emit("setThemeConfig", { themeConfig });
log("Theme config sent successfully (or at least message dispatched)");
} catch (error) {
log(
'Error sending theme config. Ensure the iframe side implements the "setThemeConfig" method.',
@@ -378,7 +379,7 @@ export async function embedDashboard({
const setThemeMode = (mode: ThemeMode): void => {
try {
ourPort.emit('setThemeMode', { mode });
ourPort.emit("setThemeMode", { mode });
log(`Theme mode set to: ${mode}`);
} catch (error) {
log(

@@ -18,22 +18,23 @@
*/
import { withTimeout } from "./withTimeout";
import { test, expect } from "vitest";
test("resolves with the value when the promise settles in time", async () => {
await expect(withTimeout(Promise.resolve("ok"), 1000, "fetch")).resolves.toBe(
"ok"
"ok",
);
});
test("rejects when the promise does not settle within the timeout", async () => {
const never = new Promise<string>(() => {});
await expect(withTimeout(never, 10, "fetch")).rejects.toThrow(
/fetch did not resolve within 10ms/
/fetch did not resolve within 10ms/,
);
});
test("passes the promise through unchanged when the timeout is disabled", async () => {
await expect(withTimeout(Promise.resolve("ok"), 0, "fetch")).resolves.toBe(
"ok"
"ok",
);
});
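The three tests above pin down a contract for `withTimeout`: resolve with the value when the promise settles in time, reject with a `<label> did not resolve within <n>ms` error otherwise, and hand the promise back untouched when the timeout is `0`. The implementation is not part of this diff, so the following is only a minimal sketch of a helper that would satisfy those tests. The parameter names and the choice to treat every non-positive timeout as "disabled" are assumptions.

```typescript
// Hypothetical sketch: the real helper in ./withTimeout may differ.
function withTimeout<T>(
  promise: Promise<T>,
  timeoutMs: number,
  label: string,
): Promise<T> {
  // Assumption: a non-positive timeout disables the guard entirely,
  // so the original promise is returned unchanged.
  if (timeoutMs <= 0) {
    return promise;
  }
  return new Promise<T>((resolve, reject) => {
    const timer = setTimeout(() => {
      reject(new Error(`${label} did not resolve within ${timeoutMs}ms`));
    }, timeoutMs);
    promise.then(
      value => {
        // Clear the timer so a settled promise does not keep it pending.
        clearTimeout(timer);
        resolve(value);
      },
      error => {
        clearTimeout(timer);
        reject(error);
      },
    );
  });
}
```

Clearing the timer on both settle paths keeps a fast promise from leaving a stray timer behind, which matters for test runners that wait on open handles.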

@@ -3,7 +3,7 @@
// syntax rules
"strict": true,
"moduleResolution": "node",
"moduleResolution": "bundler",
// environment
"target": "es6",
@@ -13,7 +13,9 @@
// output
"outDir": "./dist",
"emitDeclarationOnly": true,
"declaration": true
"declaration": true,
"types": ["node"]
},
"include": [
@@ -21,7 +23,6 @@
],
"exclude": [
"tests",
"dist",
"lib",
"node_modules"

@@ -17,19 +17,19 @@
* under the License.
*/
const path = require('path');
const path = require("path");
module.exports = {
entry: './src/index.ts',
entry: "./src/index.ts",
output: {
filename: 'index.js',
path: path.resolve(__dirname, 'bundle'),
filename: "index.js",
path: path.resolve(__dirname, "bundle"),
// this exposes the library's exports under a global variable
library: {
name: "supersetEmbeddedSdk",
type: "umd"
}
type: "umd",
},
},
devtool: "source-map",
module: {
@@ -38,12 +38,12 @@ module.exports = {
test: /\.[tj]s$/,
// babel-loader is faster than ts-loader because it ignores types.
// We do type checking in a separate process, so that's fine.
use: 'babel-loader',
use: "babel-loader",
exclude: /node_modules/,
},
],
},
resolve: {
extensions: ['.ts', '.js'],
extensions: [".ts", ".js"],
},
};

@@ -45,6 +45,7 @@ export const IconTooltip = forwardRef<HTMLElement, IconTooltipProps>(
}}
buttonStyle="link"
className={`IconTooltip ${className}`}
aria-label={tooltip ?? undefined}
>
{children}
</Button>

@@ -344,6 +344,16 @@ export default function transformProps(
data2,
currencyCodeColumn,
);
const getAxisFormatterConfig = (axisIndex?: number) =>
axisIndex === 1
? {
customFormatters: customFormattersSecondary,
formatter: formatterSecondary,
}
: {
customFormatters,
formatter,
};
const primarySeries = new Set<string>();
const secondarySeries = new Set<string>();
@@ -422,6 +432,8 @@ export default function transformProps(
let [minSecondary, maxSecondary] = (yAxisBoundsSecondary || []).map(
parseAxisBound,
);
const getAxisMax = (axisIndex?: number) =>
axisIndex === 1 ? maxSecondary : yAxisMax;
const array = ensureIsArray(chartProps.rawFormData?.time_compare);
const inverted = invert(verboseMap);
@@ -445,10 +457,11 @@ export default function transformProps(
// When no groupby, format as just the entry name with optional query identifier
displayName = showQueryIdentifiers ? `${entryName} (Query A)` : entryName;
}
const axisFormatterConfig = getAxisFormatterConfig(yAxisIndex);
const seriesFormatter = getFormatter(
customFormatters,
formatter,
axisFormatterConfig.customFormatters,
axisFormatterConfig.formatter,
metrics,
labelMap?.[seriesName]?.[0],
!!contributionMode,
@@ -480,7 +493,7 @@ export default function transformProps(
formatter:
seriesType === EchartsTimeseriesSeriesType.Bar
? getOverMaxHiddenFormatter({
max: yAxisMax,
max: getAxisMax(yAxisIndex),
formatter: seriesFormatter,
})
: seriesFormatter,
@@ -518,10 +531,11 @@ export default function transformProps(
// When no groupby, format as just the entry name with optional query identifier
displayName = showQueryIdentifiers ? `${entryName} (Query B)` : entryName;
}
const axisFormatterConfig = getAxisFormatterConfig(yAxisIndexB);
const seriesFormatter = getFormatter(
customFormattersSecondary,
formatterSecondary,
axisFormatterConfig.customFormatters,
axisFormatterConfig.formatter,
metricsB,
labelMapB?.[seriesName]?.[0],
!!contributionMode,
@@ -554,7 +568,7 @@ export default function transformProps(
formatter:
seriesTypeB === EchartsTimeseriesSeriesType.Bar
? getOverMaxHiddenFormatter({
max: maxSecondary,
max: getAxisMax(yAxisIndexB),
formatter: seriesFormatter,
})
: seriesFormatter,
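The hunks above route both the value formatter and the over-max clipping bound through the y-axis a query is assigned to, rather than through the query (A or B) itself. The sketch below isolates that idea with stand-in values. `primaryFormatter`, `secondaryFormatter`, `hideOverMax`, and `formatBarLabel` are hypothetical names for illustration, not the helpers used in `transformProps`, and the formatters only approximate the `.0%` and `,.1f` formats.

```typescript
type ValueFormatter = (value: number) => string;

// Hypothetical stand-ins for the primary ('.0%') and secondary (',.1f') formats.
const primaryFormatter: ValueFormatter = value => `${Math.round(value * 100)}%`;
const secondaryFormatter: ValueFormatter = value => value.toFixed(1);

// Assumed axis bounds, mirroring the values used in the clipping test.
const primaryMax: number | undefined = 1;
const secondaryMax: number | undefined = 0.1;

// Axis index 1 selects the secondary axis; anything else is the primary axis.
const getAxisFormatter = (axisIndex?: number): ValueFormatter =>
  axisIndex === 1 ? secondaryFormatter : primaryFormatter;

const getAxisMax = (axisIndex?: number): number | undefined =>
  axisIndex === 1 ? secondaryMax : primaryMax;

// Hypothetical clipping wrapper: hide the label when the value exceeds the bound.
const hideOverMax =
  (max: number | undefined, formatter: ValueFormatter): ValueFormatter =>
  value =>
    max !== undefined && value > max ? '' : formatter(value);

const formatBarLabel = (value: number, axisIndex?: number): string =>
  hideOverMax(getAxisMax(axisIndex), getAxisFormatter(axisIndex))(value);
```

With these stand-ins, a bar value of `0.5` on the secondary axis is hidden because it exceeds the assumed bound of `0.1`, while the same value on the primary axis is formatted as a percentage.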

@@ -35,13 +35,26 @@ import {
} from '../../src';
import transformProps from '../../src/MixedTimeseries/transformProps';
import {
DEFAULT_FORM_DATA,
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps,
} from '../../src/MixedTimeseries/types';
import { DEFAULT_FORM_DATA } from '../../src/MixedTimeseries/types';
import { createEchartsTimeseriesTestChartProps } from '../helpers';
import type { SeriesOption } from 'echarts';
type LabelFormatterParams = {
value: [number, number];
dataIndex: number;
seriesIndex: number;
seriesName: string;
};
type SeriesWithLabelFormatter = SeriesOption & {
label?: {
formatter?: (params: LabelFormatterParams) => string | number;
};
};
/**
* Creates a partial ChartDataResponseResult for testing.
* Only includes the fields needed for tests, with sensible defaults for required fields.
@@ -148,6 +161,30 @@ const queriesData: ChartDataResponseResult[] = [
createTestQueryData(defaultQueryRows, { label_map: defaultLabelMap }),
];
function getSeriesWithLabelFormatter(
series: SeriesOption[],
name: string,
): SeriesWithLabelFormatter {
const result = series.find(seriesOption => seriesOption.name === name);
expect(result).toBeDefined();
expect((result as SeriesWithLabelFormatter).label?.formatter).toBeDefined();
return result as SeriesWithLabelFormatter;
}
function formatSeriesLabel(
series: SeriesWithLabelFormatter,
value: [number, number],
) {
const formatter = series.label?.formatter;
expect(formatter).toBeDefined();
return formatter?.({
dataIndex: 0,
seriesIndex: 0,
seriesName: String(series.name),
value,
});
}
test('should transform chart props for viz with showQueryIdentifiers=false', () => {
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
@@ -232,6 +269,162 @@ test('should transform chart props for viz with showQueryIdentifiers=true', () =
]);
});
test('formats value labels with the formatter for the assigned y-axis', () => {
const timestamp = 1704067200000;
const queryAData = createTestQueryData(
[{ __timestamp: timestamp, lineMetric: 0.25 }],
{
colnames: ['__timestamp', 'lineMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { lineMetric: ['lineMetric'] },
},
);
const queryBData = createTestQueryData(
[{ __timestamp: timestamp, barMetric: 0.5 }],
{
colnames: ['__timestamp', 'barMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { 'barMetric (1)': ['barMetric'] },
},
);
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps
>({
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
defaultQueriesData: [queryAData, queryBData],
formData: {
...formData,
groupby: [],
groupbyB: [],
metrics: ['lineMetric'],
metricsB: ['barMetric'],
showValue: true,
showValueB: true,
stack: null,
stackB: null,
x_axis: '__timestamp',
yAxisFormat: '.0%',
yAxisFormatSecondary: ',.1f',
yAxisIndex: 1,
yAxisIndexB: 0,
},
queriesData: [queryAData, queryBData],
});
const { echartOptions } = transformProps(chartProps);
const series = echartOptions.series as SeriesOption[];
const lineSeries = getSeriesWithLabelFormatter(series, 'lineMetric');
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
expect(formatSeriesLabel(lineSeries, [timestamp, 0.25])).toBe('0.3');
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('50%');
});
test('formats value labels correctly when y-axis assignments are reversed', () => {
const timestamp = 1704067200000;
const queryAData = createTestQueryData(
[{ __timestamp: timestamp, lineMetric: 0.25 }],
{
colnames: ['__timestamp', 'lineMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { lineMetric: ['lineMetric'] },
},
);
const queryBData = createTestQueryData(
[{ __timestamp: timestamp, barMetric: 0.5 }],
{
colnames: ['__timestamp', 'barMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { 'barMetric (1)': ['barMetric'] },
},
);
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps
>({
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
defaultQueriesData: [queryAData, queryBData],
formData: {
...formData,
groupby: [],
groupbyB: [],
metrics: ['lineMetric'],
metricsB: ['barMetric'],
showValue: true,
showValueB: true,
stack: null,
stackB: null,
x_axis: '__timestamp',
yAxisFormat: '.0%',
yAxisFormatSecondary: ',.1f',
yAxisIndex: 0,
yAxisIndexB: 1,
},
queriesData: [queryAData, queryBData],
});
const { echartOptions } = transformProps(chartProps);
const series = echartOptions.series as SeriesOption[];
const lineSeries = getSeriesWithLabelFormatter(series, 'lineMetric');
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
expect(formatSeriesLabel(lineSeries, [timestamp, 0.25])).toBe('25%');
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('0.5');
});
test('keeps bar value label clipping aligned with the assigned y-axis', () => {
const timestamp = 1704067200000;
const queryAData = createTestQueryData(
[{ __timestamp: timestamp, lineMetric: 0.25 }],
{
colnames: ['__timestamp', 'lineMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { lineMetric: ['lineMetric'] },
},
);
const queryBData = createTestQueryData(
[{ __timestamp: timestamp, barMetric: 0.5 }],
{
colnames: ['__timestamp', 'barMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { 'barMetric (1)': ['barMetric'] },
},
);
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps
>({
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
defaultQueriesData: [queryAData, queryBData],
formData: {
...formData,
groupby: [],
groupbyB: [],
metrics: ['lineMetric'],
metricsB: ['barMetric'],
showValue: true,
showValueB: true,
stack: null,
stackB: null,
x_axis: '__timestamp',
yAxisBounds: [undefined, 1],
yAxisBoundsSecondary: [undefined, 0.1],
yAxisFormat: '.0%',
yAxisFormatSecondary: ',.1f',
yAxisIndex: 0,
yAxisIndexB: 1,
},
queriesData: [queryAData, queryBData],
});
const { echartOptions } = transformProps(chartProps);
const series = echartOptions.series as SeriesOption[];
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('');
});
describe('legend sorting', () => {
const getChartProps = (overrides = {}) =>
createEchartsTimeseriesTestChartProps<

@@ -44,13 +44,13 @@ const fakeTableApiResult = {
result: [
{
id: 1,
value: 'fake api result1',
value: 'fake_api_result1',
label: 'fake api label1',
type: 'table',
},
{
id: 2,
value: 'fake api result2',
value: 'fake_api_result2',
label: 'fake api label2',
type: 'table',
},
@@ -152,6 +152,64 @@ test('returns keywords including fetched function_names data', async () => {
});
});
test('quotes table identifiers that require quoting in the inserted value', async () => {
const dbFunctionNamesApiRoute = `glob:*/api/v1/database/${expectDbId}/function_names/`;
fetchMock.get(dbFunctionNamesApiRoute, fakeFunctionNamesApiResult);
act(() => {
store.dispatch(
tableApiUtil.upsertQueryData(
'tables',
{ dbId: expectDbId, schema: expectSchema },
{
options: [
{ value: 'COVID Vaccines', label: 'COVID Vaccines', type: 'table' },
{ value: 'simple_table', label: 'simple_table', type: 'table' },
],
hasMore: false,
},
),
);
});
const { result } = renderHook(
() =>
useKeywords({
queryEditorId: 'testqueryid',
dbId: expectDbId,
schema: expectSchema,
}),
{
wrapper: createWrapper({
useRedux: true,
store,
}),
},
);
await waitFor(() =>
expect(fetchMock.callHistory.calls(dbFunctionNamesApiRoute).length).toBe(1),
);
// A name that needs quoting is inserted as a double-quoted identifier,
// while its display name stays human-readable.
expect(result.current).toContainEqual(
expect.objectContaining({
name: 'COVID Vaccines',
value: '"COVID Vaccines"',
meta: 'table',
}),
);
// A simple identifier is inserted as-is, without quotes.
expect(result.current).toContainEqual(
expect.objectContaining({
name: 'simple_table',
value: 'simple_table',
meta: 'table',
}),
);
});
test('skip fetching if autocomplete skipped', () => {
const { result } = renderHook(
() =>

@@ -53,6 +53,14 @@ const getHelperText = (value: string) =>
detail: value,
};
// Names that aren't simple identifiers (spaces, punctuation, leading digits)
// must be double-quoted to be valid SQL, with embedded quotes doubled.
const SIMPLE_IDENTIFIER_RE = /^[A-Za-z_][A-Za-z0-9_]*$/;
const quoteIdentifier = (identifier: string) =>
SIMPLE_IDENTIFIER_RE.test(identifier)
? identifier
: `"${identifier.replace(/"/g, '""')}"`;
const extensionsRegistry = getExtensionsRegistry();
export function useKeywords(
@@ -197,7 +205,7 @@ export function useKeywords(
() =>
allCachedTables.map(({ value, label, schema: tableSchema }) => ({
name: label,
value,
value: quoteIdentifier(value),
schema: tableSchema,
score: TABLE_AUTOCOMPLETE_SCORE,
meta: 'table',

@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Superset WebSocket Server
A Node.js WebSocket server for sending async event data to the Superset web application frontend.
@@ -164,4 +165,4 @@ HEAD /health
## Containerization
*TODO: containerize websocket server*
_TODO: containerize websocket server_

@@ -1,22 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
};

File diff suppressed because it is too large

@@ -5,7 +5,7 @@
"main": "index.js",
"scripts": {
"start": "node dist/index.js start",
"test": "NODE_ENV=test jest -i spec",
"test": "npx vitest --run --dir spec",
"type": "tsc --noEmit",
"eslint": "eslint",
"lint": "npm run eslint -- . && npm run type",
@@ -28,24 +28,22 @@
"devDependencies": {
"@eslint/js": "^9.25.1",
"@types/eslint__js": "^8.42.3",
"@types/jest": "^29.5.14",
"@types/jsonwebtoken": "^9.0.10",
"@types/lodash": "^4.17.24",
"@types/node": "^25.9.3",
"@types/ws": "^8.18.1",
"@typescript-eslint/eslint-plugin": "^8.61.1",
"@typescript-eslint/parser": "^8.61.1",
"@typescript-eslint/eslint-plugin": "^8.62.0",
"@typescript-eslint/parser": "^8.62.0",
"eslint": "^10.5.0",
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-lodash": "^8.0.0",
"globals": "^17.6.0",
"jest": "^29.7.0",
"prettier": "^3.8.4",
"ts-jest": "^29.4.11",
"ts-node": "^10.9.2",
"tscw-config": "^1.1.2",
"typescript": "^6.0.3",
"typescript-eslint": "^8.61.1"
"typescript-eslint": "^8.62.0",
"vitest": "^4.1.5"
},
"engines": {
"node": "^24.16.0",

@@ -17,7 +17,7 @@
* under the License.
*/
import { buildConfig } from '../src/config';
import { expect, test } from '@jest/globals';
import { expect, test } from 'vitest';
test('buildConfig() builds configuration and applies env var overrides', () => {
let config = buildConfig();

@@ -25,29 +25,29 @@ import {
test,
beforeEach,
afterEach,
jest,
} from '@jest/globals';
vi,
type Mock,
} from 'vitest';
import * as http from 'http';
import * as net from 'net';
import { WebSocket } from 'ws';
import * as server from '../src/index';
import { statsd } from '../src/index';
interface MockedRedisXrange {
(): Promise<server.StreamResult[]>;
}
// NOTE: these mock variables needs to start with "mock" due to
// calls to `jest.mock` being hoisted to the top of the file.
// https://jestjs.io/docs/es6-class-mocks#calling-jestmock-with-the-module-factory-parameter
const mockRedisXrange = jest.fn() as jest.MockedFunction<MockedRedisXrange>;
jest.mock('ws');
jest.mock('ioredis', () => {
return jest.fn().mockImplementation(() => {
return { xrange: mockRedisXrange, on: jest.fn() };
});
const { mockRedisXrange } = vi.hoisted(() => {
return { mockRedisXrange: vi.fn() };
});
const wsMock = WebSocket as jest.Mocked<typeof WebSocket>;
vi.mock('ws');
vi.mock('ioredis', () => {
return {
Redis: vi.fn().mockImplementation(function () {
return { xrange: mockRedisXrange, on: vi.fn() };
}),
};
});
const wsMock = WebSocket as unknown as Mock<typeof WebSocket>;
const channelId = 'bc9e040c-7b4a-4817-99b9-292832d97ec7';
const streamReturnValue: server.StreamResult[] = [
[
@@ -66,16 +66,13 @@ const streamReturnValue: server.StreamResult[] = [
],
];
import * as server from '../src/index';
import { statsd } from '../src/index';
describe('server', () => {
let statsdIncrementMock: jest.SpiedFunction<typeof statsd.increment>;
let statsdIncrementMock: Mock<typeof statsd.increment>;
beforeEach(() => {
mockRedisXrange.mockClear();
server.resetState();
statsdIncrementMock = jest.spyOn(statsd, 'increment').mockReturnValue();
statsdIncrementMock = vi.spyOn(statsd, 'increment').mockReturnValue();
});
afterEach(() => {
@@ -84,8 +81,8 @@ describe('server', () => {
describe('HTTP requests', () => {
test('services health checks', () => {
const endMock = jest.fn();
const writeHeadMock = jest.fn();
const endMock = vi.fn();
const writeHeadMock = vi.fn();
const request = {
url: '/health',
@@ -113,8 +110,8 @@ describe('server', () => {
});
test('responds with a 404 when not found', () => {
const endMock = jest.fn();
const writeHeadMock = jest.fn();
const endMock = vi.fn();
const writeHeadMock = vi.fn();
const request = {
url: '/unsupported',
@@ -245,7 +242,7 @@ describe('server', () => {
describe('processStreamResults', () => {
test('sends data to channel', async () => {
const ws = new wsMock('localhost');
const sendMock = jest.spyOn(ws, 'send');
const sendMock = vi.spyOn(ws, 'send');
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
@@ -267,7 +264,7 @@ describe('server', () => {
test('channel not present', async () => {
const ws = new wsMock('localhost');
const sendMock = jest.spyOn(ws, 'send');
const sendMock = vi.spyOn(ws, 'send');
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
server.processStreamResults(streamReturnValue);
@@ -278,10 +275,9 @@ describe('server', () => {
test('error sending data to client', async () => {
const ws = new wsMock('localhost');
const sendMock = jest.spyOn(ws, 'send').mockImplementation(() => {
const sendMock = vi.spyOn(ws, 'send').mockImplementation(() => {
throw new Error();
});
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
const socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
expect(statsdIncrementMock).toHaveBeenCalledTimes(0);
@@ -300,9 +296,7 @@ describe('server', () => {
);
expect(sendMock).toHaveBeenCalled();
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
cleanChannelMock.mockRestore();
expect(Object.keys(server.channels)).toHaveLength(0);
});
const makeItem = (i: number): server.StreamResult =>
@@ -330,8 +324,8 @@ describe('server', () => {
channel: channelId,
pongTs: Date.now(),
});
const sendMock = jest.spyOn(ws, 'send');
const setImmediateSpy = jest.spyOn(global, 'setImmediate');
const sendMock = vi.spyOn(ws, 'send');
const setImmediateSpy = vi.spyOn(global, 'setImmediate');
const results = [0, 1, 2, 3, 4].map(makeItem);
await server.processStreamResults(results);
@@ -351,7 +345,7 @@ describe('server', () => {
channel: channelId,
pongTs: Date.now(),
});
const sendMock = jest.spyOn(ws, 'send');
const sendMock = vi.spyOn(ws, 'send');
const results = [0, 1, 2, 3, 4].map(makeItem);
await server.processStreamResults(results);
@@ -372,16 +366,16 @@ describe('server', () => {
server.opts.maxSocketBufferBytes = 0;
// Restore any spies (e.g. on server.cleanChannel) so they don't leak
// across tests and cause order-dependent failures.
jest.restoreAllMocks();
vi.restoreAllMocks();
});
test('does not terminate when cap disabled (0)', () => {
server.opts.maxSocketBufferBytes = 0;
const ws = new wsMock('localhost');
// simulate a large outbound buffer
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 10_000_000;
const terminateMock = jest.spyOn(ws, 'terminate');
const sendMock = jest.spyOn(ws, 'send');
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(10_000_000);
const terminateMock = vi.spyOn(ws, 'terminate');
const sendMock = vi.spyOn(ws, 'send');
server.trackClient(channelId, {
ws,
channel: channelId,
@@ -397,10 +391,9 @@ describe('server', () => {
test('terminates a slow client whose buffer exceeds the cap', () => {
server.opts.maxSocketBufferBytes = 1024;
const ws = new wsMock('localhost');
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 2048;
const terminateMock = jest.spyOn(ws, 'terminate');
const sendMock = jest.spyOn(ws, 'send');
const cleanChannelMock = jest.spyOn(server, 'cleanChannel');
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(2048);
const terminateMock = vi.spyOn(ws, 'terminate');
const sendMock = vi.spyOn(ws, 'send');
server.trackClient(channelId, {
ws,
channel: channelId,
@@ -414,15 +407,15 @@ describe('server', () => {
expect(statsdIncrementMock).toHaveBeenCalledWith(
'ws_client_backpressure_disconnect',
);
expect(cleanChannelMock).toHaveBeenCalledWith(channelId);
expect(Object.keys(server.channels)).toHaveLength(0);
});
test('keeps sending to a client within the cap', () => {
server.opts.maxSocketBufferBytes = 1024;
const ws = new wsMock('localhost');
(ws as unknown as { bufferedAmount: number }).bufferedAmount = 16;
const terminateMock = jest.spyOn(ws, 'terminate');
const sendMock = jest.spyOn(ws, 'send');
vi.spyOn(ws, 'bufferedAmount', 'get').mockReturnValueOnce(16);
const terminateMock = vi.spyOn(ws, 'terminate');
const sendMock = vi.spyOn(ws, 'send');
server.trackClient(channelId, {
ws,
channel: channelId,
@@ -443,7 +436,7 @@ describe('server', () => {
test('success with results', async () => {
mockRedisXrange.mockResolvedValueOnce(streamReturnValue);
const cb = jest.fn() as jest.MockedFunction<
const cb = vi.fn() as Mock<
(results: server.StreamResult[]) => void | Promise<void>
>;
await server.fetchRangeFromStream({
@@ -462,7 +455,7 @@ describe('server', () => {
});
test('success no results', async () => {
const cb = jest.fn() as jest.MockedFunction<
const cb = vi.fn() as Mock<
(results: server.StreamResult[]) => void | Promise<void>
>;
await server.fetchRangeFromStream({
@@ -481,7 +474,7 @@ describe('server', () => {
});
test('error', async () => {
const cb = jest.fn() as jest.MockedFunction<
const cb = vi.fn() as Mock<
(results: server.StreamResult[]) => void | Promise<void>
>;
mockRedisXrange.mockRejectedValueOnce(new Error());
@@ -503,12 +496,8 @@ describe('server', () => {
describe('wsConnection', () => {
let ws: WebSocket;
let wsEventMock: jest.SpiedFunction<typeof ws.on>;
let trackClientSpy: jest.SpiedFunction<typeof server.trackClient>;
let fetchRangeFromStreamSpy: jest.SpiedFunction<
typeof server.fetchRangeFromStream
>;
let dateNowSpy: jest.SpiedFunction<typeof Date.now>;
let wsEventMock: Mock<typeof ws.on>;
let dateNowSpy: Mock<typeof Date.now>;
let socketInstanceExpected: server.SocketInstance;
const getRequest = (token: string, url: string): http.IncomingMessage => {
@@ -521,10 +510,8 @@ describe('server', () => {
beforeEach(() => {
ws = new wsMock('localhost');
wsEventMock = jest.spyOn(ws, 'on');
trackClientSpy = jest.spyOn(server, 'trackClient');
fetchRangeFromStreamSpy = jest.spyOn(server, 'fetchRangeFromStream');
dateNowSpy = jest
wsEventMock = vi.spyOn(ws, 'on');
dateNowSpy = vi
.spyOn(global.Date, 'now')
.mockImplementation(() =>
new Date('2021-03-10T11:01:58.135Z').valueOf(),
@@ -537,10 +524,8 @@ describe('server', () => {
});
afterEach(() => {
wsEventMock.mockRestore();
trackClientSpy.mockRestore();
fetchRangeFromStreamSpy.mockRestore();
dateNowSpy.mockRestore();
wsEventMock?.mockRestore();
dateNowSpy?.mockRestore();
});
test('invalid JWT', async () => {
@@ -558,11 +543,14 @@ describe('server', () => {
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
expect(mockRedisXrange).not.toHaveBeenCalled();
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
});
@@ -575,12 +563,15 @@ describe('server', () => {
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
// Malformed last_id must not trigger a stream range fetch.
expect(fetchRangeFromStreamSpy).not.toHaveBeenCalled();
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
});
test('valid JWT, with lastId', async () => {
@@ -593,16 +584,18 @@ describe('server', () => {
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
sessionId: channelId,
startId: '1615426152415-1',
endId: '+',
listener: server.processStreamResults,
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
expect(mockRedisXrange).toHaveBeenCalledWith(
expect.stringContaining(channelId),
'1615426152415-1',
'+',
);
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
});
@@ -618,16 +611,18 @@ describe('server', () => {
server.setLastFirehoseId(lastFirehoseId);
server.wsConnection(ws, request);
expect(trackClientSpy).toHaveBeenCalledWith(
channelId,
socketInstanceExpected,
);
expect(fetchRangeFromStreamSpy).toHaveBeenCalledWith({
sessionId: channelId,
startId: '1615426152415-1',
endId: lastFirehoseId,
listener: server.processStreamResults,
const channelSockets = server.channels[channelId];
expect(channelSockets).toEqual({
sockets: expect.any(Array<string>),
});
expect(channelSockets.sockets).toHaveLength(1);
const socketId = channelSockets.sockets[0];
expect(server.sockets[socketId]).toEqual(socketInstanceExpected);
expect(mockRedisXrange).toHaveBeenCalledWith(
expect.stringContaining(channelId),
'1615426152415-1',
lastFirehoseId,
);
expect(wsEventMock).toHaveBeenCalledWith('pong', expect.any(Function));
});
});
@@ -662,7 +657,7 @@ describe('server', () => {
test('total connection limit reached', () => {
server.opts.maxTotalConnections = 1;
const ws = new wsMock('localhost');
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const socketInstance = {
ws,
channel: channelId,
@@ -677,7 +672,7 @@ describe('server', () => {
test('per-channel connection limit reached', () => {
server.opts.maxConnectionsPerChannel = 1;
const ws = new wsMock('localhost');
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const socketInstance = {
ws,
channel: channelId,
@@ -699,7 +694,7 @@ describe('server', () => {
};
server.trackClient(channelId, socketInstance);
// simulate the socket having closed but not yet been GC'd
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
expect(server.connectionLimitReason('some-other-channel')).toBeNull();
});
@@ -713,13 +708,13 @@ describe('server', () => {
};
server.trackClient(channelId, socketInstance);
// simulate the socket having closed but not yet been GC'd
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
expect(server.connectionLimitReason(channelId)).toBeNull();
});
test('isSocketActive reflects the socket readyState', () => {
const ws = new wsMock('localhost');
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const socketId = server.trackClient(channelId, {
ws,
channel: channelId,
@@ -727,9 +722,9 @@ describe('server', () => {
});
expect(server.isSocketActive(socketId)).toBe(true);
// CONNECTING is also considered active (see SOCKET_ACTIVE_STATES)
setReadyState(ws, WebSocket.CONNECTING);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CONNECTING);
expect(server.isSocketActive(socketId)).toBe(true);
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
expect(server.isSocketActive(socketId)).toBe(false);
// unknown socket ids are never active
expect(server.isSocketActive('does-not-exist')).toBe(false);
@@ -737,14 +732,14 @@ describe('server', () => {
test('activeSocketCount counts only active sockets', () => {
const openWs = new wsMock('localhost');
setReadyState(openWs, WebSocket.OPEN);
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, {
ws: openWs,
channel: channelId,
pongTs: Date.now(),
});
const closedWs = new wsMock('localhost');
setReadyState(closedWs, WebSocket.CLOSED);
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
server.trackClient(channelId, {
ws: closedWs,
channel: channelId,
@@ -755,14 +750,14 @@ describe('server', () => {
test('activeChannelSocketCount counts only active sockets on the channel', () => {
const openWs = new wsMock('localhost');
setReadyState(openWs, WebSocket.OPEN);
vi.spyOn(openWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, {
ws: openWs,
channel: channelId,
pongTs: Date.now(),
});
const closedWs = new wsMock('localhost');
setReadyState(closedWs, WebSocket.CLOSED);
vi.spyOn(closedWs, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
server.trackClient(channelId, {
ws: closedWs,
channel: channelId,
@@ -776,7 +771,7 @@ describe('server', () => {
test('wsConnection refuses over-limit connection without tracking', () => {
server.opts.maxConnectionsPerChannel = 1;
const existingWs = new wsMock('localhost');
setReadyState(existingWs, WebSocket.OPEN);
vi.spyOn(existingWs, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
const existing = {
ws: existingWs,
channel: channelId,
@@ -784,7 +779,7 @@ describe('server', () => {
};
server.trackClient(channelId, existing);
const trackClientSpy = jest.spyOn(server, 'trackClient');
const trackClientSpy = vi.spyOn(server, 'trackClient');
const ws = new wsMock('localhost');
const validToken = jwt.sign({ channel: channelId }, config.jwtSecret);
server.wsConnection(ws, getRequest(validToken, 'http://localhost'));
@@ -800,8 +795,8 @@ describe('server', () => {
describe('httpUpgrade', () => {
let socket: net.Socket;
let socketDestroySpy: jest.SpiedFunction<typeof socket.destroy>;
let wssUpgradeSpy: jest.SpiedFunction<typeof server.wss.handleUpgrade>;
let socketDestroySpy: Mock<typeof socket.destroy>;
let wssUpgradeSpy: Mock<typeof server.wss.handleUpgrade>;
const getRequest = (token: string, url: string): http.IncomingMessage => {
const request = new http.IncomingMessage(new net.Socket());
@@ -813,8 +808,8 @@ describe('server', () => {
beforeEach(() => {
socket = new net.Socket();
socketDestroySpy = jest.spyOn(socket, 'destroy');
wssUpgradeSpy = jest.spyOn(server.wss, 'handleUpgrade');
socketDestroySpy = vi.spyOn(socket, 'destroy');
wssUpgradeSpy = vi.spyOn(server.wss, 'handleUpgrade');
});
afterEach(() => {
@@ -952,33 +947,21 @@ describe('server', () => {
});
});
const setReadyState = (ws: WebSocket, value: typeof ws.readyState) => {
// workaround for not being able to do
// spyOn(instance,'readyState','get').and.returnValue(value);
// See for details: https://github.com/facebook/jest/issues/9675
Object.defineProperty(ws, 'readyState', {
configurable: true,
get() {
return value;
},
});
};
describe('checkSockets', () => {
let ws: WebSocket;
let pingSpy: jest.SpiedFunction<typeof ws.ping>;
let terminateSpy: jest.SpiedFunction<typeof ws.terminate>;
let pingSpy: Mock<typeof ws.ping>;
let terminateSpy: Mock<typeof ws.terminate>;
let socketInstance: server.SocketInstance;
beforeEach(() => {
ws = new wsMock('localhost');
pingSpy = jest.spyOn(ws, 'ping');
terminateSpy = jest.spyOn(ws, 'terminate');
pingSpy = vi.spyOn(ws, 'ping');
terminateSpy = vi.spyOn(ws, 'terminate');
socketInstance = { ws: ws, channel: channelId, pongTs: Date.now() };
});
test('active sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, socketInstance);
server.checkSockets();
@@ -989,7 +972,7 @@ describe('server', () => {
});
test('stale sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
socketInstance.pongTs = Date.now() - 60000;
server.trackClient(channelId, socketInstance);
@@ -1001,7 +984,7 @@ describe('server', () => {
});
test('closed sockets', () => {
setReadyState(ws, WebSocket.CLOSED);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSED);
server.trackClient(channelId, socketInstance);
server.checkSockets();
@@ -1027,7 +1010,7 @@ describe('server', () => {
});
test('active sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, socketInstance);
server.cleanChannel(channelId);
@@ -1036,7 +1019,7 @@ describe('server', () => {
});
test('closing sockets', () => {
setReadyState(ws, WebSocket.CLOSING);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.CLOSING);
server.trackClient(channelId, socketInstance);
server.cleanChannel(channelId);
@@ -1045,11 +1028,12 @@ describe('server', () => {
});
test('multiple sockets', () => {
setReadyState(ws, WebSocket.OPEN);
vi.spyOn(ws, 'readyState', 'get').mockReturnValue(WebSocket.OPEN);
server.trackClient(channelId, socketInstance);
const ws2 = new wsMock('localhost');
setReadyState(ws2, WebSocket.OPEN);
const readyStateSpy = vi.spyOn(ws2, 'readyState', 'get');
readyStateSpy.mockReturnValue(WebSocket.OPEN);
const socketInstance2 = {
ws: ws2,
channel: channelId,
@@ -1061,7 +1045,7 @@ describe('server', () => {
expect(server.channels[channelId].sockets.length).toBe(2);
setReadyState(ws2, WebSocket.CLOSED);
readyStateSpy.mockReturnValue(WebSocket.CLOSED);
server.cleanChannel(channelId);
expect(server.channels[channelId].sockets.length).toBe(1);


@@ -19,11 +19,11 @@
import * as http from 'http';
import * as net from 'net';
import { inspect } from 'util';
import WebSocket, { WebSocketServer } from 'ws';
import { WebSocket, WebSocketServer } from 'ws';
import { randomUUID } from 'crypto';
import jwt, { Algorithm } from 'jsonwebtoken';
import { parse } from 'cookie';
import Redis, { RedisOptions } from 'ioredis';
import { Redis, RedisOptions } from 'ioredis';
import StatsD from 'hot-shots';
import { createLogger } from './logger';


@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Test & development utilities
The files provided here are for testing and development only, and are not required to run the WebSocket server application.


@@ -16,6 +16,7 @@ KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Test client application
This Express web application is provided for testing and development against the WebSocket server only. It is not required to run the server application.


@@ -100,7 +100,10 @@ class CacheRestApi(BaseSupersetModelRestApi):
)
cache_keys = [c.cache_key for c in cache_key_objs]
if cache_key_objs:
all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)
# Chart query results live in ``data_cache``, not the default
# ``cache`` — using the wrong backend silently misses the Redis
# keys when ``CACHE_KEY_PREFIX`` differs between the two configs.
all_keys_deleted = cache_manager.data_cache.delete_many(*cache_keys)
if not all_keys_deleted:
# expected behavior as keys may expire and cache is not a
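
Note on the hunk above: the comment says that deleting through the wrong backend silently misses keys when the two configs use different `CACHE_KEY_PREFIX` values. The sketch below illustrates that failure mode with a toy, dict-backed cache. `PrefixedCache`, the prefixes, and the boolean return of `delete_many` are all assumptions made for illustration; this is not flask-caching's or Superset's actual implementation.

```python
from typing import Dict, Optional


class PrefixedCache:
    """Toy stand-in for a cache backend that namespaces keys with a prefix."""

    def __init__(self, store: Dict[str, str], key_prefix: str) -> None:
        self._store = store  # shared keyspace, standing in for one Redis database
        self._prefix = key_prefix

    def set(self, key: str, value: str) -> None:
        self._store[self._prefix + key] = value

    def get(self, key: str) -> Optional[str]:
        return self._store.get(self._prefix + key)

    def delete_many(self, *keys: str) -> bool:
        # True only if every requested key existed and was removed.
        removed = [self._store.pop(self._prefix + k, None) is not None for k in keys]
        return all(removed)


shared_store: Dict[str, str] = {}
default_cache = PrefixedCache(shared_store, "superset_")      # hypothetical prefix
data_cache = PrefixedCache(shared_store, "superset_data_")    # hypothetical prefix

# Chart data is written through the data cache.
data_cache.set("cache_key", "value")

missed = default_cache.delete_many("cache_key")  # targets "superset_cache_key"
survivor = data_cache.get("cache_key")           # the entry is still present
deleted = data_cache.delete_many("cache_key")    # targets "superset_data_cache_key"
```

With identical prefixes both backends would address the same key, which is presumably why the mismatch went unnoticed in single-prefix setups.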


@@ -23,7 +23,7 @@ import sys
import urllib
from datetime import datetime
from re import Pattern
from typing import Any, TYPE_CHECKING, TypedDict
from typing import Any, Callable, TYPE_CHECKING, TypedDict
import pandas as pd
from apispec import APISpec
@@ -83,6 +83,97 @@ if TYPE_CHECKING:
logger = logging.getLogger()
# BigQuery string escape sequences keyed off documented escapes in
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals.
# Backslash MUST be first so subsequent escapes don't double-escape their own
# backslash. ``\?``, ``\"`` and ``\``` are valid BigQuery escapes but
# intentionally omitted because those characters do not require escaping
# inside a single-quoted literal. ``\0`` is NOT a valid BigQuery escape
# (octal escapes require exactly three digits); the null byte instead falls
# through to the ``\xhh`` fallback below.
_BIGQUERY_STRING_ESCAPES = {
"\\": "\\\\",
"'": "\\'",
"\n": "\\n",
"\r": "\\r",
"\t": "\\t",
"\b": "\\b",
"\f": "\\f",
"\v": "\\v",
"\a": "\\a",
}
def _process_string_literal(value: str) -> str:
"""
Escape a string value for use as a BigQuery SQL literal.
BigQuery requires backslash escaping for single quotes inside string
literals (``'O\\'Brien'``). Doubled single quotes (``'O''Brien'``) are
**not** valid — BigQuery parses them as two concatenated string literals
without whitespace, causing a syntax error:
``concatenated string literals must be separated by whitespace``.
BigQuery also forbids literal newlines, carriage returns, and other
control characters inside a quoted string; those must be written using
escape sequences (``\\n``, ``\\r``, ``\\t`` …). Control characters
without a named escape are emitted as a ``\\xhh`` hex escape; printable
Unicode passes through unchanged because BigQuery accepts UTF-8 inside
string literals.
The upstream ``sqlalchemy-bigquery`` dialect relies on Python's ``repr()``
to quote values, which switches to double-quote delimiters when the
string contains an apostrophe (e.g. ``repr("O'Brien")`` → ``"O'Brien"``).
Double-quoted tokens inside compiled SQL would be parsed as identifiers,
so the query also fails. This helper always produces a single-quoted
literal.
"""
parts = []
for ch in value:
escape = _BIGQUERY_STRING_ESCAPES.get(ch)
if escape is not None:
parts.append(escape)
elif ord(ch) < 0x20 or ord(ch) == 0x7F:
parts.append(f"\\x{ord(ch):02x}")
else:
parts.append(ch)
return f"'{''.join(parts)}'"
def _monkeypatch_bigquery_string_literal() -> None:
"""
Patch the sqlalchemy-bigquery dialect so that string literals containing
apostrophes are rendered correctly when ``literal_binds=True``.
Without this patch, a filter value like ``O'Brien`` is compiled as the
double-quoted identifier ``"O'Brien"`` instead of the single-quoted literal
``'O\\'Brien'``, causing BigQuery to return a syntax error.
This follows the same pattern used for the Databricks dialect fix in
``superset/db_engine_specs/databricks.py``.
"""
try:
from sqlalchemy_bigquery import BigQueryDialect
class BigQuerySafeString(types.TypeDecorator):
impl = types.String
cache_ok = True
def literal_processor(self, dialect: Any) -> Callable[[str], str]:
if dialect.name == "bigquery":
return _process_string_literal
return super().literal_processor(dialect)
BigQueryDialect.colspecs[types.String] = BigQuerySafeString
except ImportError:
pass
_monkeypatch_bigquery_string_literal()
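
Note on the helper above: the docstring attributes the original failure to `repr()` switching delimiters, and the dict comment asks for the backslash to be handled first. The sketch below checks the `repr()` behaviour (standard Python) and shows why ordering matters in a chained-`replace` variant. `backslash_quote` and `wrong_order_quote` are hypothetical helpers written for this note, not part of the patch; the per-character loop in the hunk looks each character up once, so it does not appear to depend on dict order.

```python
def backslash_quote(value: str) -> str:
    """Single-quote a value, escaping backslashes first and apostrophes second."""
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def wrong_order_quote(value: str) -> str:
    """Same replacements in the wrong order: the apostrophe's backslash gets doubled."""
    escaped = value.replace("'", "\\'").replace("\\", "\\\\")
    return f"'{escaped}'"


# repr() chooses its delimiter from the content of the string.
plain = repr("Fernando")         # stays single-quoted
apostrophe = repr("Fernando's")  # switches to double quotes

good = backslash_quote("O'Brien")
bad = wrong_order_quote("O'Brien")

print(plain, apostrophe, good, bad)
```

Only control-character handling is left out here; the patch itself covers that with named escapes and the hex fallback.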
CONNECTION_DATABASE_PERMISSIONS_REGEX = re.compile(
"Access Denied: Project (?P<project_name>.+?): User does not have "
+ "bigquery.jobs.create permission in project (?P<project>.+?)"


@@ -17,10 +17,11 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union
from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask import g
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
@@ -38,12 +39,18 @@ from superset.db_engine_specs.base import (
)
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
from superset.models.core import Database
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
OAuth2TokenResponse,
)
try:
@@ -277,6 +284,135 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
"port": "port",
}
# The Databricks SQL driver has no dedicated authentication exception, so an
# expired or missing token surfaces as a generic driver error. These case-
# insensitive substrings flag the errors that should bootstrap a re-auth.
oauth2_auth_failure_signals = (
"http 401",
"unauthorized",
"unauthenticated",
"invalid access token",
"invalid token",
"expired token",
"token expired",
)
@classmethod
def _workspace_oauth2_endpoint(cls, database: Database, path: str) -> str:
"""
Build a Databricks OAuth2 (U2M) endpoint from the workspace host.
Databricks fronts the user-to-machine OAuth2 flow on every workspace at
``https://<workspace-host>/oidc/v1/{authorize,token}`` across AWS, Azure
and GCP, so the endpoints derive directly from the connection host and
need no account or tenant identifier.
"""
host = database.url_object.host
if not host:
raise OAuth2Error(
"Databricks OAuth2 endpoint could not be resolved: the database "
"connection has no host."
)
return f"https://{host}/oidc/v1/{path}"
@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
"""
Identify driver errors that should trigger the OAuth2 dance.
Unlike Trino (``TrinoAuthError``) or GSheets (``UnauthenticatedError``),
the Databricks driver raises no dedicated auth exception, so in addition
to the base ``isinstance`` check we match the auth signals above on the
error message (mirrors ``GSheetsEngineSpec.needs_oauth2``).
"""
if not (g and hasattr(g, "user")):
return False
if isinstance(ex, cls.oauth2_exception):
return True
message = str(ex).lower()
return any(signal in message for signal in cls.oauth2_auth_failure_signals)
@classmethod
def get_oauth2_authorization_uri(
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
code_verifier: str | None = None,
) -> str:
"""
Return the URI for the initial OAuth2 request.
A fully-resolved ``authorization_request_uri`` from
``DATABASE_OAUTH2_CLIENTS`` is preserved; otherwise the endpoint is
derived from the workspace host (``https://<host>/oidc/v1/authorize``),
which is valid on AWS, Azure and GCP.
"""
if not config.get("authorization_request_uri"):
from superset import db
from superset.models.core import Database
database_id = state["database_id"]
if database := db.session.get(Database, database_id):
config = cast(
"OAuth2ClientConfig",
dict(config)
| {
"authorization_request_uri": cls._workspace_oauth2_endpoint(
database, "authorize"
)
},
)
return super().get_oauth2_authorization_uri(config, state, code_verifier)
@classmethod
def get_oauth2_token(
cls,
config: "OAuth2ClientConfig",
code: str,
code_verifier: str | None = None,
) -> "OAuth2TokenResponse":
"""
Exchange the authorization code for refresh/access tokens.
Token exchange runs in a separate request with no database context, so
the workspace host is not available to derive the endpoint here. Require
a configured ``token_request_uri``
(``https://<workspace-host>/oidc/v1/token``) and fail fast rather than
POST to an unresolved endpoint.
"""
if not config.get("token_request_uri"):
raise OAuth2Error(
"Databricks OAuth2 token endpoint is not configured: set "
"`token_request_uri` to https://<workspace-host>/oidc/v1/token "
"in DATABASE_OAUTH2_CLIENTS."
)
return super().get_oauth2_token(config, code, code_verifier)
@classmethod
def impersonate_user(
cls,
database: Database,
username: str | None,
user_token: str | None,
url: URL,
engine_kwargs: dict[str, Any],
) -> tuple[URL, dict[str, Any]]:
"""
Update connection with OAuth2 access token for user impersonation.
"""
if user_token:
# Replace the access token in the URL with the user's OAuth2 token
url = url.set(password=user_token)
# Also update connect_args if they contain access token
connect_args = engine_kwargs.setdefault("connect_args", {})
if "access_token" in connect_args:
connect_args["access_token"] = user_token
return url, engine_kwargs
@staticmethod
def get_extra_params(
database: Database, source: QuerySource | None = None
@@ -474,6 +610,16 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_catalog = True
supports_cross_catalog_queries = True
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
supports_oauth2 = True
oauth2_scope = "sql"
# Authorization endpoint is derived from the workspace host at runtime; the
# token endpoint must be configured (no DB context at exchange time).
oauth2_authorization_request_uri = ""
oauth2_token_request_uri = ""
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksNativeParametersType, *_
@@ -685,6 +831,16 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
supports_oauth2 = True
oauth2_scope = "sql"
# Authorization endpoint is derived from the workspace host at runtime; the
# token endpoint must be configured (no DB context at exchange time).
oauth2_authorization_request_uri = ""
oauth2_token_request_uri = ""
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksPythonConnectorParametersType, *_
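
Note on the Databricks hunks above: the two pieces of logic the shared base class adds (deriving `/oidc/v1/...` endpoints from the workspace host, and detecting auth failures from the error message) can be sketched in isolation. This is a standalone approximation under stated assumptions: the function names, the `ValueError`, and the example host are hypothetical, and the real methods work on a `Database` object and raise `OAuth2Error`.

```python
from typing import Optional

# Case-insensitive substrings copied from the hunk above.
AUTH_FAILURE_SIGNALS = (
    "http 401",
    "unauthorized",
    "unauthenticated",
    "invalid access token",
    "invalid token",
    "expired token",
    "token expired",
)


def workspace_oauth2_endpoint(host: Optional[str], path: str) -> str:
    """Build https://<host>/oidc/v1/<path>; refuse to guess when the host is missing."""
    if not host:
        raise ValueError("the database connection has no host")
    return f"https://{host}/oidc/v1/{path}"


def looks_like_auth_failure(ex: Exception) -> bool:
    """Match the lower-cased error message against the auth-failure substrings."""
    message = str(ex).lower()
    return any(signal in message for signal in AUTH_FAILURE_SIGNALS)


example_host = "dbc-example.cloud.databricks.com"  # hypothetical workspace host
authorize_url = workspace_oauth2_endpoint(example_host, "authorize")
token_url = workspace_oauth2_endpoint(example_host, "token")

auth_error = looks_like_auth_failure(Exception("Error during request: HTTP 401 Unauthorized"))
unrelated = looks_like_auth_failure(Exception("Row 401 not found in table"))
```

Matching on `"http 401"` rather than a bare `"401"` is what keeps the second message from triggering a re-auth, in line with the "tighten 401 oauth2 signal" commit in the list above.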


@@ -303,6 +303,7 @@ DEFAULT_GET_DASHBOARD_INFO_COLUMNS: List[str] = [
"created_on",
"changed_on",
"uuid",
"embedded_uuid",
"url",
"created_on_humanized",
"changed_on_humanized",
@@ -427,6 +428,18 @@ class DashboardInfo(BaseModel):
created_on: str | datetime | None = None
changed_on: str | datetime | None = None
uuid: str | None = None
embedded_uuid: str | None = Field(
None,
description=(
"Embedded UUID for this dashboard. This is the UUID required when "
"generating guest tokens for embedded dashboards "
"(resources[].id in the guest token payload). "
"Only present when the dashboard has been configured for embedding "
"via the Embed Dashboard UI. Distinct from `uuid` (the internal "
"dashboard UUID) — using the wrong one causes 403 errors in guest "
"token validation."
),
)
url: str | None = None
created_on_humanized: str | None = None
changed_on_humanized: str | None = None
@@ -1352,6 +1365,9 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
created_on=dashboard.created_on,
changed_on=dashboard.changed_on,
uuid=str(dashboard.uuid) if dashboard.uuid else None,
embedded_uuid=str(dashboard.embedded[0].uuid)
if dashboard.embedded
else None,
url=absolute_url,
created_on_humanized=dashboard.created_on_humanized,
changed_on_humanized=dashboard.changed_on_humanized,
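
Note on the new field above: the description says `embedded_uuid`, not `uuid`, is the value expected in `resources[].id` of a guest token payload. A consumer of `DashboardInfo` might select it as sketched below. The helper name, the plain-dict input, and the sample UUIDs are assumptions for illustration; the `{"type": "dashboard", "id": ...}` shape follows the description's wording and should be checked against the guest token API before use.

```python
from typing import Any, Dict, List


def guest_token_resources(dashboard_info: Dict[str, Any]) -> List[Dict[str, str]]:
    """Return the resources list for a guest token, using the embedded UUID only."""
    embedded_uuid = dashboard_info.get("embedded_uuid")
    if embedded_uuid is None:
        # The field is only present once embedding has been configured.
        raise ValueError("dashboard is not configured for embedding")
    return [{"type": "dashboard", "id": embedded_uuid}]


# Hypothetical serialized dashboard with both identifiers present.
info = {
    "uuid": "11111111-1111-1111-1111-111111111111",
    "embedded_uuid": "22222222-2222-2222-2222-222222222222",
}
resources = guest_token_resources(info)
```

Failing loudly when `embedded_uuid` is absent avoids silently falling back to `uuid`, which the description says leads to 403 errors.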


@@ -155,10 +155,11 @@ async def get_dashboard_info(
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
# Eager load slices and tags to avoid N+1 queries during serialization.
# Eager load slices, tags, and embedded to avoid N+1 queries.
eager_options = [
subqueryload(Dashboard.slices).subqueryload(Slice.tags),
subqueryload(Dashboard.tags),
subqueryload(Dashboard.embedded),
]
with event_logger.log_context(action="mcp.get_dashboard_info.lookup"):


@@ -2963,6 +2963,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
tp = self.get_template_processor()
processed_expression = self._process_expression_template(expression, tp)
# Apply the same parsing policy used for stored adhoc column and
# metric expressions (single statement, no set operations, and no
# sub-queries unless ALLOW_ADHOC_SUBQUERY is enabled), so expression
# validation follows one policy across the query pipeline. Imported
# locally to avoid a circular import with the connectors package.
from superset.connectors.sqla.models import validate_stored_expression
validate_stored_expression(
self.database, self.catalog, self.schema or "", processed_expression
)
# Build validation query
tbl, cte = self.get_from_clause(tp)
validation_query = self._build_validation_query(


@@ -3164,9 +3164,6 @@ msgstr "Culoare după"
msgid "Color for breakpoint"
msgstr "Culoare pentru punct de întrerupere"
msgid "Color Metric"
msgstr "Indicator culoare"
msgid "Color of the source location"
msgstr "Culoarea locației sursă"
@@ -4644,7 +4641,7 @@ msgid "Deleted %(num)d theme"
msgid_plural "Deleted %(num)d themes"
msgstr[0] "Ștearsă %(num)d temă"
msgstr[1] "Șterse %(num)d teme"
msgstr[2] "Șterse %(num)d de teme""
msgstr[2] "Șterse %(num)d de teme"
#, python-format
msgid "Deleted %s"
@@ -10218,7 +10215,7 @@ msgid "Saved expressions"
msgstr "Expresii salvate"
msgid "Saved metric"
msgstr "Indicator salvat""
msgstr "Indicator salvat"
msgid "Saved queries"
msgstr "Interogări salvate"
@@ -13422,7 +13419,7 @@ msgid "This was triggered by:"
msgid_plural "This may be triggered by:"
msgstr[0] "Aceasta a fost declanșată de:"
msgstr[1] "Acestea pot fi declanșate de:"
msgstr[2] "Acestea pot fi declanșate de către:""
msgstr[2] "Acestea pot fi declanșate de către:"
msgid ""
"This will be applied to the whole table. Arrows (↑ and ↓) will be added "
@@ -14447,7 +14444,7 @@ msgstr ""
"Vizualizează un indicator corelat pentru perechi de grupuri. Hartile "
"termice (Heatmaps) excelează în evidențierea corelației sau a "
"intensității dintre două grupuri. Culoarea este utilizată pentru a "
"sublinia intensitatea legăturii dintre fiecare pereche de grupuri.""
"sublinia intensitatea legăturii dintre fiecare pereche de grupuri."
msgid ""
"Visualize geospatial data like 3D buildings, landscapes, or objects in "


@@ -54,6 +54,13 @@ NUMPY_FUNCTIONS: dict[str, Callable[..., Any]] = {
"var": np.var,
}
# Operators that pandas GroupBy.agg accepts as string names. Passing the string
# avoids a FutureWarning raised when pandas receives a numpy callable it internally
# maps to its own method (e.g. np.mean → SeriesGroupBy.mean).
_PANDAS_STRING_AGGREGATORS: frozenset[str] = frozenset(
{"max", "mean", "median", "min", "prod", "std", "sum", "var"}
)
DENYLIST_ROLLING_FUNCTIONS = (
"count",
"corr",
@@ -166,7 +173,7 @@ def _get_aggregate_funcs(
)
operator = agg_obj["operator"]
if callable(operator):
aggfunc = operator
aggfunc: str | Callable[..., Any] = operator
else:
func = NUMPY_FUNCTIONS.get(operator)
if not func:
@@ -177,7 +184,10 @@ def _get_aggregate_funcs(
)
)
options = agg_obj.get("options", {})
aggfunc = partial(func, **options)
if not options and operator in _PANDAS_STRING_AGGREGATORS:
aggfunc = operator
else:
aggfunc = partial(func, **options)
agg_funcs[name] = NamedAgg(column=column, aggfunc=aggfunc)
return agg_funcs
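
Note on the hunk above: the change passes the operator name as a string when there are no options and the name is one pandas accepts, and otherwise keeps the `partial`. The selection rule can be sketched without pandas or numpy. In this standalone version the functions in `FUNCTIONS` are standard-library stand-ins for the numpy callables, and `pick_aggfunc`, `var`, and `value_range` are hypothetical names.

```python
import statistics
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Union


def var(values: Sequence[float], ddof: int = 0) -> float:
    """Variance with a numpy-style ``ddof`` option (stand-in for ``np.var``)."""
    mean = statistics.fmean(values)
    return sum((v - mean) ** 2 for v in values) / (len(values) - ddof)


def value_range(values: Sequence[float]) -> float:
    """An operator with no string equivalent in the allow-list below."""
    return max(values) - min(values)


FUNCTIONS: Dict[str, Callable[..., Any]] = {
    "max": max,
    "mean": statistics.fmean,
    "var": var,
    "range": value_range,
}

# Same names as ``_PANDAS_STRING_AGGREGATORS`` in the hunk above.
STRING_AGGREGATORS = frozenset(
    {"max", "mean", "median", "min", "prod", "std", "sum", "var"}
)


def pick_aggfunc(
    operator: str, options: Optional[Dict[str, Any]] = None
) -> Union[str, Callable[..., Any]]:
    """Return the operator name when it is safe to do so, else a bound callable."""
    options = options or {}
    func = FUNCTIONS.get(operator)
    if func is None:
        raise ValueError(f"Invalid operator: {operator}")
    if not options and operator in STRING_AGGREGATORS:
        return operator
    return partial(func, **options)


by_name = pick_aggfunc("mean")
with_options = pick_aggfunc("var", {"ddof": 1})
not_listed = pick_aggfunc("range")
```

Options force the callable path because a bare string cannot carry keyword arguments.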


@@ -440,9 +440,9 @@ class BaseViz: # pylint: disable=too-many-public-methods
"metrics": metrics,
"row_limit": row_limit,
"filter": self.form_data.get("filters", []),
"timeseries_limit": limit,
"series_limit": limit,
"extras": extras,
"timeseries_limit_metric": timeseries_limit_metric,
"series_limit_metric": timeseries_limit_metric,
"order_desc": order_desc,
}


@@ -18,6 +18,7 @@
"""Unit tests for Superset"""
from typing import Any
from unittest.mock import patch
import pytest
@@ -52,17 +53,43 @@ def test_invalidate_cache(invalidate):
def test_invalidate_existing_cache(invalidate):
db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
db.session.commit()
cache_manager.cache.set("cache_key", "value")
cache_manager.data_cache.set("cache_key", "value")
rv = invalidate({"datasource_uids": ["3__table"]})
assert rv.status_code == 201
assert cache_manager.cache.get("cache_key") is None # noqa: E711
assert cache_manager.data_cache.get("cache_key") is None # noqa: E711
assert (
not db.session.query(CacheKey).filter(CacheKey.cache_key == "cache_key").first()
)
def test_invalidate_uses_data_cache_not_default_cache(invalidate):
"""Regression test for #40489.
Chart query results are written through ``cache_manager.data_cache``
(``DATA_CACHE_CONFIG``). When ``CACHE_CONFIG`` and ``DATA_CACHE_CONFIG``
use distinct ``CACHE_KEY_PREFIX`` values, deleting via the default
``cache_manager.cache`` silently misses the underlying Redis keys
because flask-caching prepends the wrong prefix to the DEL call.
"""
db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
db.session.commit()
with (
patch.object(cache_manager.data_cache, "delete_many") as data_delete,
patch.object(cache_manager.cache, "delete_many") as default_delete,
):
data_delete.return_value = True
rv = invalidate({"datasource_uids": ["3__table"]})
assert rv.status_code == 201
# Chart-data cache backend (the one that wrote the keys) must be hit.
data_delete.assert_called_once_with("cache_key")
# The default cache must NOT be touched — that's the #40489 regression.
default_delete.assert_not_called()
def test_invalidate_cache_empty_input(invalidate):
rv = invalidate({"datasource_uids": []})
assert rv.status_code == 201
@@ -111,10 +138,10 @@ def test_invalidate_existing_caches(invalidate):
db.session.add(CacheKey(cache_key="cache_keyX", datasource_uid="X__table"))
db.session.commit()
cache_manager.cache.set("cache_key1", "value")
cache_manager.cache.set("cache_key2", "value")
cache_manager.cache.set("cache_key4", "value")
cache_manager.cache.set("cache_keyX", "value")
cache_manager.data_cache.set("cache_key1", "value")
cache_manager.data_cache.set("cache_key2", "value")
cache_manager.data_cache.set("cache_key4", "value")
cache_manager.data_cache.set("cache_keyX", "value")
rv = invalidate(
{
@@ -155,10 +182,10 @@ def test_invalidate_existing_caches(invalidate):
)
assert rv.status_code == 201
assert cache_manager.cache.get("cache_key1") is None
assert cache_manager.cache.get("cache_key2") is None
assert cache_manager.cache.get("cache_key4") is None
assert cache_manager.cache.get("cache_keyX") == "value"
assert cache_manager.data_cache.get("cache_key1") is None
assert cache_manager.data_cache.get("cache_key2") is None
assert cache_manager.data_cache.get("cache_key4") is None
assert cache_manager.data_cache.get("cache_keyX") == "value"
assert (
not db.session.query(CacheKey)
.filter(CacheKey.cache_key.in_({"cache_key1", "cache_key2", "cache_key4"}))


@@ -767,3 +767,265 @@ def test_fetch_data_converts_bigquery_row_objects(mocker: MockerFixture) -> None
assert result == [(1, "a"), (2, "b")]
assert flask_g.bq_memory_limited is False
def test_string_literal_with_apostrophe() -> None:
"""
Test that string literals containing apostrophes are properly escaped
for BigQuery using backslash escaping.
BigQuery requires backslash escaping for single quotes ('O\\'Brien').
Doubled single quotes ('O''Brien') are NOT valid — BigQuery parses them
as two concatenated string literals, causing a syntax error.
The upstream sqlalchemy-bigquery dialect uses ``repr()`` which switches
to double-quote delimiters when the value contains an apostrophe.
Double-quoted tokens are identifiers in BigQuery, causing syntax errors.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
# Trigger module load to ensure the monkey-patch is applied
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("name")).where(sa_column("name") == "Fernando's")
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
# The compiled SQL must use single-quoted literal with backslash-escaped
# apostrophes. Doubled single quotes are NOT valid in BigQuery.
assert "= 'Fernando\\'s'" in compiled_sql
# Must NOT contain doubled-quote escaping (BigQuery rejects this)
assert "''" not in compiled_sql
# Must NOT render the value as a double-quoted token (the repr() failure mode)
assert '"Fernando' not in compiled_sql
def test_string_literal_without_apostrophe() -> None:
"""
Test that normal string literals (without apostrophes) still compile
correctly after the monkey-patch.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("name")).where(sa_column("name") == "Fernando")
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
assert "= 'Fernando'" in compiled_sql
def test_string_literal_in_filter_with_apostrophe() -> None:
"""
Test that IN filters with apostrophes in values compile correctly
using backslash escaping.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("name")).where(
sa_column("name").in_(["Fernando's", "O'Brien"])
)
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
assert "'Fernando\\'s'" in compiled_sql
assert "'O\\'Brien'" in compiled_sql
# Must NOT contain doubled-quote escaping
assert "''" not in compiled_sql
def test_process_string_literal_directly() -> None:
"""
Test _process_string_literal covers backslash escaping for apostrophes,
control-character escaping (newline/CR/tab/etc.), the ``\\xhh`` fallback
for control chars without a named escape, and pass-through for printable
Unicode and other characters BigQuery accepts unescaped.
"""
from superset.db_engine_specs.bigquery import _process_string_literal
# Plain values
assert _process_string_literal("hello") == "'hello'"
assert _process_string_literal("") == "''"
# Apostrophes (the original fix)
assert _process_string_literal("O'Brien") == "'O\\'Brien'"
assert _process_string_literal("it's a test") == "'it\\'s a test'"
# Backslashes must be escaped before apostrophes
assert _process_string_literal("C:\\path") == "'C:\\\\path'"
assert _process_string_literal("it's C:\\path") == "'it\\'s C:\\\\path'"
# Literal backslash followed by 'n' (two characters, not a newline)
# must produce the two-char sequence '\\n' (escaped backslash + n) so
# BigQuery does not misread it as a newline escape.
assert _process_string_literal("\\n") == "'\\\\n'"
# Control characters must be escaped using named escapes — BigQuery
# rejects literal control characters inside quoted strings.
assert _process_string_literal("foo\nbar") == "'foo\\nbar'"
assert _process_string_literal("foo\rbar") == "'foo\\rbar'"
assert _process_string_literal("foo\tbar") == "'foo\\tbar'"
assert _process_string_literal("a\bb\fc\vd\ae") == "'a\\bb\\fc\\vd\\ae'"
# Control characters without a named escape fall through to ``\\xhh``.
assert _process_string_literal("null\0byte") == "'null\\x00byte'"
assert _process_string_literal("a\x01b") == "'a\\x01b'"
assert _process_string_literal("a\x1bb") == "'a\\x1bb'"
assert _process_string_literal("a\x7fb") == "'a\\x7fb'"
# Double quotes do NOT need escaping in single-quoted BigQuery literals.
assert _process_string_literal('say "hello"') == "'say \"hello\"'"
# Printable Unicode and percent signs pass through unchanged.
assert _process_string_literal("café") == "'café'"
assert _process_string_literal("日本") == "'日本'"
assert _process_string_literal("100%") == "'100%'"
# Combined: apostrophe + newline + backslash + unicode.
assert _process_string_literal("it's\nC:\\café") == "'it\\'s\\nC:\\\\café'"
def test_process_string_literal_no_literal_control_chars() -> None:
"""
Regression test for the issue raised in PR #38835 review: BigQuery
rejects literal control characters inside quoted string literals, so the
output must never contain them as literal characters.
"""
from superset.db_engine_specs.bigquery import _process_string_literal
for char in ["\n", "\r", "\t", "\b", "\f", "\v", "\a", "\0", "\x01", "\x7f"]:
result = _process_string_literal(f"prefix{char}suffix")
assert char not in result, (
f"Literal {char!r} leaked into output {result!r}; "
"BigQuery would reject this literal."
)
def test_string_literal_with_newline_in_filter() -> None:
"""
End-to-end regression test for @rusackas's review feedback on PR #38835:
a filter value containing a newline must compile to valid BigQuery SQL
using the ``\\n`` escape sequence, not a literal newline.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("note")).where(sa_column("note") == "line1\nline2")
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
# Must use the escape sequence form, not a literal newline.
assert "'line1\\nline2'" in compiled_sql
assert "\n" not in compiled_sql.split("note")[-1]
def test_literal_processor_non_bigquery_dialect() -> None:
"""
Test that BigQuerySafeString.literal_processor falls back to the parent
implementation when used with a non-BigQuery dialect.
"""
from sqlalchemy import create_engine
from superset.db_engine_specs.bigquery import (
_monkeypatch_bigquery_string_literal, # noqa: F811
)
_monkeypatch_bigquery_string_literal()
safe_cls = BigQueryDialect.colspecs[sqltypes.String]
instance = safe_cls()
# Use a non-BigQuery dialect (sqlite)
sqlite_dialect = create_engine("sqlite://").dialect
processor = instance.literal_processor(sqlite_dialect)
# The fallback processor should still produce a valid quoted string
assert processor is not None
def test_monkeypatch_is_applied() -> None:
"""
Test that _monkeypatch_bigquery_string_literal installs the custom
type decorator into BigQueryDialect.colspecs.
"""
from sqlalchemy.sql import sqltypes as sa_sqltypes
from superset.db_engine_specs.bigquery import (
BigQueryEngineSpec, # noqa: F811
)
assert BigQueryEngineSpec is not None
colspecs = BigQueryDialect.colspecs
assert sa_sqltypes.String in colspecs
safe_cls = colspecs[sa_sqltypes.String]
assert safe_cls.__name__ == "BigQuerySafeString"
def test_literal_processor_returns_process_string_literal_for_bigquery() -> None:
"""
Test that BigQuerySafeString.literal_processor returns the
_process_string_literal function when given a BigQuery dialect,
and that calling it produces correctly escaped output.
"""
from superset.db_engine_specs.bigquery import (
_monkeypatch_bigquery_string_literal,
_process_string_literal,
)
_monkeypatch_bigquery_string_literal()
safe_cls = BigQueryDialect.colspecs[sqltypes.String]
instance = safe_cls()
dialect = BigQueryDialect()
processor = instance.literal_processor(dialect)
assert processor is _process_string_literal
assert processor("O'Brien") == "'O\\'Brien'"
assert processor("plain") == "'plain'"
def test_monkeypatch_handles_missing_bigquery_package() -> None:
"""
Test that _monkeypatch_bigquery_string_literal gracefully handles
the case where sqlalchemy_bigquery is not installed.
"""
import builtins
from superset.db_engine_specs.bigquery import (
_monkeypatch_bigquery_string_literal,
)
original_import = builtins.__import__
def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "sqlalchemy_bigquery":
raise ImportError("mocked missing package")
return original_import(name, *args, **kwargs)
with mock.patch("builtins.__import__", side_effect=mock_import):
# Should not raise — the except ImportError branch handles it
_monkeypatch_bigquery_string_literal()
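
The escaping behaviour these BigQuery tests pin down can be illustrated with a small standalone sketch. It is not Superset's `_process_string_literal`: `escape_string_literal` and `_ESCAPES` are hypothetical names, and the sketch assumes BigQuery accepts the backslash escapes for newline, carriage return, tab, backslash, and single quote, plus the two-digit `\xhh` form, inside quoted string literals.

```python
# Hypothetical sketch, not the Superset implementation.
# Named escapes assumed to be valid inside a BigQuery string literal.
_ESCAPES = {
    "\\": "\\\\",
    "'": "\\'",
    "\n": "\\n",
    "\r": "\\r",
    "\t": "\\t",
}


def escape_string_literal(value: str) -> str:
    """Return ``value`` as a single-quoted literal with no raw control chars."""
    parts = []
    for ch in value:
        if ch in _ESCAPES:
            parts.append(_ESCAPES[ch])
        elif ord(ch) < 0x20 or ord(ch) == 0x7F:
            # Remaining ASCII control characters become two-digit hex escapes.
            parts.append(f"\\x{ord(ch):02x}")
        else:
            parts.append(ch)
    return "'" + "".join(parts) + "'"


print(escape_string_literal("line1\nline2"))
```

The same properties the tests assert hold for the sketch: quotes are backslash-escaped and no control character from the input survives verbatim in the output.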


@@ -17,14 +17,23 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from datetime import datetime
from typing import Optional
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine.url import make_url
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from superset.db_engine_specs.base import OAuth2State
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.superset_typing import OAuth2ClientConfig
from superset.utils import json
from superset.utils.oauth2 import decode_oauth2_state
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@@ -291,3 +300,595 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
"USE CATALOG `evil`` USE CATALOG bad`",
"USE SCHEMA `evil`` USE SCHEMA bad`",
]
# OAuth2 Tests
def test_oauth2_attributes() -> None:
"""
Test that OAuth2 attributes are properly set for both engine specs.
"""
# Test DatabricksNativeEngineSpec
assert DatabricksNativeEngineSpec.supports_oauth2 is True
assert DatabricksNativeEngineSpec.oauth2_scope == "sql"
# The authorization endpoint is derived from the workspace host at runtime;
# the token endpoint must be configured explicitly.
assert DatabricksNativeEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksNativeEngineSpec.oauth2_token_request_uri == ""
# Test DatabricksPythonConnectorEngineSpec
assert DatabricksPythonConnectorEngineSpec.supports_oauth2 is True
assert DatabricksPythonConnectorEngineSpec.oauth2_scope == "sql"
assert DatabricksPythonConnectorEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksPythonConnectorEngineSpec.oauth2_token_request_uri == ""
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"message",
[
"Error during request to server: HTTP 401 Unauthorized",
"Invalid access token",
"The access token expired",
"UNAUTHENTICATED: token is no longer valid",
],
)
def test_needs_oauth2_detects_auth_failure_from_message(
mocker: MockerFixture,
spec: Any,
message: str,
) -> None:
"""
The Databricks driver has no dedicated auth exception, so `needs_oauth2`
matches auth-failure signals in the error message to bootstrap a re-auth.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
assert spec.needs_oauth2(Exception(message)) is True
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"message",
[
"Table not found",
# A bare 401 in an unrelated position must not look like an auth failure.
"Query failed at line 401: syntax error",
],
)
def test_needs_oauth2_ignores_unrelated_errors(
mocker: MockerFixture,
spec: Any,
message: str,
) -> None:
"""
A non-auth driver error must not trigger the OAuth2 dance.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
assert spec.needs_oauth2(Exception(message)) is False
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_needs_oauth2_matches_oauth2_redirect_error(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
The inherited `isinstance` check against `oauth2_exception` still holds.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
ex = OAuth2RedirectError("https://example/authorize", "tab", "redirect")
assert spec.needs_oauth2(ex) is True
def test_impersonate_user_with_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method with OAuth2 token for DatabricksNativeEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
def test_impersonate_user_without_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method without OAuth2 token.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test without user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token=None,
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that nothing was changed
assert url.password == "original-token" # noqa: S105
assert kwargs["connect_args"]["access_token"] == "original-token" # noqa: S105
def test_impersonate_user_python_connector(mocker: MockerFixture) -> None:
"""
Test impersonate_user method for DatabricksPythonConnectorEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks://token:original-token@host:443?http_path=path&catalog=main&schema=default"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksPythonConnectorEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
@pytest.fixture
def oauth2_config_native() -> OAuth2ClientConfig:
"""
Config for Databricks Native OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
@pytest.fixture
def oauth2_config_python() -> OAuth2ClientConfig:
"""
Config for Databricks Python Connector OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
def test_is_oauth2_enabled_no_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks (legacy)": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is True
def test_is_oauth2_enabled_no_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is True
def test_get_oauth2_authorization_uri_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Native engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
oauth2_config_native, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_authorization_uri_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Python Connector engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
oauth2_config_python, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_token(
oauth2_config_native, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_token(
oauth2_config_python, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_fresh_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_fresh_token(
oauth2_config_native, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)
def _oauth2_state() -> OAuth2State:
"""
Build the default OAuth2 state shared by the OAuth2 tests.
"""
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
return state
def _unresolved_oauth2_config() -> OAuth2ClientConfig:
"""
Config as built by `get_oauth2_config` when no endpoints are overridden:
the URIs default to the spec's empty class attributes.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"host",
[
"dbc-abc.cloud.databricks.com",
"adb-123456789.12.azuredatabricks.net",
"123456789.gcp.databricks.com",
],
)
def test_get_oauth2_authorization_uri_derives_from_workspace_host(
mocker: MockerFixture,
spec: Any,
host: str,
) -> None:
"""
With no configured `authorization_request_uri`, the endpoint is derived from
the workspace host (`https://<host>/oidc/v1/authorize`) on every cloud, with
no account/tenant identifier required.
"""
database = mocker.MagicMock()
database.url_object.host = host
mocker.patch("superset.db.session.get", return_value=database)
url = spec.get_oauth2_authorization_uri(
_unresolved_oauth2_config(), _oauth2_state()
)
parsed = urlparse(url)
assert parsed.netloc == host
assert parsed.path == "/oidc/v1/authorize"
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_preserves_configured(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
A fully-resolved `authorization_request_uri` is never overwritten by the
host-derived endpoint, and no database lookup is needed.
"""
session_get = mocker.patch("superset.db.session.get")
config = _unresolved_oauth2_config()
config["authorization_request_uri"] = (
"https://accounts.cloud.databricks.com/oidc/accounts/override/v1/authorize"
)
url = spec.get_oauth2_authorization_uri(config, _oauth2_state())
assert urlparse(url).path == "/oidc/accounts/override/v1/authorize"
session_get.assert_not_called()
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_fails_without_host(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
When the endpoint must be derived but the connection has no host, fail fast
instead of emitting an invalid `https:///oidc/v1/authorize` URL.
"""
database = mocker.MagicMock()
database.url_object.host = None
mocker.patch("superset.db.session.get", return_value=database)
with pytest.raises(OAuth2Error):
spec.get_oauth2_authorization_uri(_unresolved_oauth2_config(), _oauth2_state())
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_token_fails_without_uri(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
Token exchange has no database context to auto-detect the endpoint, so a
missing `token_request_uri` fails fast rather than POSTing to an empty URL.
"""
with pytest.raises(OAuth2Error):
spec.get_oauth2_token(_unresolved_oauth2_config(), "authorization-code")
def test_get_oauth2_fresh_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_fresh_token(
oauth2_config_python, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)
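
The message-based detection exercised by `test_needs_oauth2_detects_auth_failure_from_message` and `test_needs_oauth2_ignores_unrelated_errors` can be approximated as below. This is only a sketch under assumptions: `AUTH_FAILURE_PATTERNS` and `looks_like_auth_failure` are hypothetical names, the patterns are chosen to fit the messages used in the tests, and the real `needs_oauth2` may match differently.

```python
import re

# Hypothetical patterns, picked to fit the messages in the tests above.
AUTH_FAILURE_PATTERNS = [
    re.compile(r"\bHTTP 401\b", re.IGNORECASE),
    re.compile(r"\bunauthenticated\b", re.IGNORECASE),
    re.compile(r"\baccess token\b", re.IGNORECASE),
]


def looks_like_auth_failure(ex: Exception) -> bool:
    """Return True when the exception text carries an auth-failure signal."""
    message = str(ex)
    return any(pattern.search(message) for pattern in AUTH_FAILURE_PATTERNS)


print(looks_like_auth_failure(Exception("Invalid access token")))
print(looks_like_auth_failure(Exception("Query failed at line 401: syntax error")))
```

Requiring `HTTP` next to `401` is what keeps a bare `401` in an unrelated position, such as a line number, from being treated as an authentication failure.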


@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from typing import Any
from unittest.mock import MagicMock
from urllib.parse import parse_qs, urlparse
import pytest
from pytest_mock import MockerFixture
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.superset_typing import OAuth2ClientConfig
from superset.utils.oauth2 import decode_oauth2_state
# Multi-Cloud Provider Tests
#
# Databricks fronts the user-to-machine OAuth2 flow on every workspace at
# `https://<workspace-host>/oidc/v1/{authorize,token}`, regardless of cloud, so
# the authorization endpoint derives from the connection host with no per-cloud
# account/tenant identifier.
SPECS = [DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec]
# Representative workspace hosts for each cloud provider.
CLOUD_HOSTS = [
"my-cluster.cloud.databricks.com", # AWS
"adb-123456789.12.azuredatabricks.net", # Azure
"123456789.gcp.databricks.com", # GCP
]
@pytest.fixture
def oauth2_config_no_uri() -> OAuth2ClientConfig:
"""
Config for Databricks OAuth2 without a pre-configured endpoint, so the
authorization endpoint is derived from the workspace host.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
def _mock_database(mocker: MockerFixture, host: str) -> MagicMock:
"""
Build a mock database whose URL resolves to the given workspace host.
"""
database = mocker.MagicMock()
database.url_object.host = host
database.id = 1
return database
@pytest.mark.parametrize("spec", SPECS)
@pytest.mark.parametrize("host", CLOUD_HOSTS)
def test_get_oauth2_authorization_uri_uses_workspace_host(
mocker: MockerFixture,
spec: Any,
host: str,
oauth2_config_no_uri: OAuth2ClientConfig,
) -> None:
"""
The authorization endpoint is the workspace host on AWS, Azure, and GCP.
"""
from superset.db_engine_specs.base import OAuth2State
mocker.patch(
"superset.extensions.db.session.get",
return_value=_mock_database(mocker, host),
)
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = spec.get_oauth2_authorization_uri(oauth2_config_no_uri, state)
parsed = urlparse(url)
assert parsed.netloc == host
assert parsed.path == "/oidc/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
@pytest.mark.parametrize("spec", SPECS)
@pytest.mark.parametrize("host", CLOUD_HOSTS)
def test_workspace_oauth2_endpoint_builds_token_uri(
mocker: MockerFixture,
spec: Any,
host: str,
) -> None:
"""
The helper builds the matching token endpoint from the same workspace host.
"""
database = _mock_database(mocker, host)
assert (
spec._workspace_oauth2_endpoint(database, "token")
== f"https://{host}/oidc/v1/token"
)
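
The host-derived endpoints asserted in this file follow one URL shape on every cloud. A minimal sketch of that derivation is shown below; `workspace_oauth2_endpoint` is a hypothetical free function that takes the host directly (the spec's `_workspace_oauth2_endpoint` takes a database object), and it raises `ValueError` where the tests expect Superset to raise `OAuth2Error`.

```python
from typing import Optional
from urllib.parse import urlparse


def workspace_oauth2_endpoint(host: Optional[str], endpoint: str) -> str:
    """Build https://<host>/oidc/v1/<endpoint> for a workspace host."""
    if not host:
        # Fail fast instead of emitting an invalid https:///oidc/v1/... URL.
        raise ValueError("workspace host is required to derive the endpoint")
    return f"https://{host}/oidc/v1/{endpoint}"


for host in (
    "my-cluster.cloud.databricks.com",  # AWS
    "adb-123456789.12.azuredatabricks.net",  # Azure
    "123456789.gcp.databricks.com",  # GCP
):
    parsed = urlparse(workspace_oauth2_endpoint(host, "authorize"))
    print(parsed.netloc, parsed.path)
```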


@@ -96,6 +96,7 @@ async def test_list_dashboards_basic(mock_list, mcp_server):
dashboard.uuid = "test-dashboard-uuid-1"
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -162,6 +163,7 @@ async def test_list_dashboards_with_filters(mock_list, mcp_server):
dashboard.uuid = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -257,6 +259,7 @@ async def test_list_dashboards_with_search(mock_list, mcp_server):
dashboard.uuid = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -351,6 +354,7 @@ async def test_get_dashboard_info_success(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -429,6 +433,7 @@ async def test_get_dashboard_info_permalink_does_not_double_sanitize(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
mock_info.return_value = dashboard
permalink_value = {
@@ -521,6 +526,7 @@ async def test_get_dashboard_info_permalink_key_includes_filter_state(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
mock_info.return_value = dashboard
@@ -767,6 +773,7 @@ async def test_get_dashboard_info_does_not_expose_access_list_or_roles(
dashboard.owners = [owner]
dashboard.tags = []
dashboard.roles = [dashboard_role]
dashboard.embedded = []
mock_info.return_value = dashboard
@@ -838,6 +845,7 @@ async def test_get_dashboard_info_restricted_user_redacts_data_model_metadata(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_info.return_value = dashboard
@@ -890,6 +898,7 @@ async def test_get_dashboard_info_restricted_user_redacts_permalink_filter_state
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_info.return_value = dashboard
@@ -1012,6 +1021,88 @@ async def test_list_dashboards_omits_requested_user_directory_fields(
assert field not in data["columns_available"]
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_includes_embedded_uuid(mock_find_object, mcp_server):
"""Test that get_dashboard_info returns embedded_uuid when set."""
from superset.models.embedded_dashboard import EmbeddedDashboard
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "Embedded Dashboard"
dashboard.slug = ""
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = "{}"
dashboard.published = True
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = None
dashboard.changed_on = None
dashboard.created_on_humanized = None
dashboard.changed_on_humanized = None
dashboard.uuid = "94b826a5-dbd5-473d-ab58-1af676ee07e4"
dashboard.url = "/dashboard/1"
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
embedded = Mock(spec=EmbeddedDashboard)
embedded.uuid = "37c56048-d3f1-452d-b3ae-0879802dcb1f"
dashboard.embedded = [embedded]
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": 1}}
)
assert result.data["uuid"] == "94b826a5-dbd5-473d-ab58-1af676ee07e4"
assert result.data["embedded_uuid"] == "37c56048-d3f1-452d-b3ae-0879802dcb1f"
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_embedded_uuid_none_when_not_embedded(
mock_find_object, mcp_server
):
"""Test that embedded_uuid is None when the dashboard has not been configured
for embedding."""
dashboard = Mock()
dashboard.id = 2
dashboard.dashboard_title = "Non-Embedded Dashboard"
dashboard.slug = ""
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = "{}"
dashboard.published = True
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = None
dashboard.changed_on = None
dashboard.created_on_humanized = None
dashboard.changed_on_humanized = None
dashboard.uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
dashboard.url = "/dashboard/2"
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": 2}}
)
assert result.data.get("embedded_uuid") is None
# TODO (Phase 3+): Add tests for get_dashboard_available_filters tool
@@ -1044,6 +1135,7 @@ async def test_get_dashboard_info_by_uuid(mock_find_object, mcp_server):
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
@@ -1083,6 +1175,7 @@ async def test_get_dashboard_info_by_slug(mock_find_object, mcp_server):
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
@@ -1122,6 +1215,7 @@ async def test_list_dashboards_custom_uuid_slug_columns(mock_list, mcp_server):
dashboard.external_url = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -1203,6 +1297,7 @@ async def test_list_dashboards_sanitizes_dashboard_descriptions_and_filter_text(
dashboard.external_url = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -1343,6 +1438,7 @@ class TestDashboardDefaultColumnFiltering:
dashboard.external_url = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
mock_list.return_value = ([dashboard], 1)
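
The `dashboard.embedded = []` lines added to the mocks in this file make sense once the serializer reads that attribute. The sketch below is an assumption about the shape of that logic, not the MCP service code: `first_embedded_uuid` is a hypothetical helper. It also shows why a bare `Mock()` is presumably not enough, since an auto-created mock attribute is truthy but cannot be indexed.

```python
from types import SimpleNamespace
from typing import Any, Optional
from unittest.mock import Mock


def first_embedded_uuid(dashboard: Any) -> Optional[str]:
    """Return the UUID of the first embedded config, or None if there is none."""
    embedded = getattr(dashboard, "embedded", None) or []
    return str(embedded[0].uuid) if embedded else None


configured = SimpleNamespace(
    embedded=[SimpleNamespace(uuid="37c56048-d3f1-452d-b3ae-0879802dcb1f")]
)
not_configured = SimpleNamespace(embedded=[])
print(first_embedded_uuid(configured))
print(first_embedded_uuid(not_configured))

# A bare Mock auto-creates `embedded` as another Mock: truthy, not subscriptable.
bare = Mock()
try:
    first_embedded_uuid(bare)
except TypeError as ex:
    print(f"bare Mock fails: {ex}")
```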


@@ -81,6 +81,7 @@ def _mock_dashboard(
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.embedded = []
return dashboard


@@ -33,6 +33,7 @@ class TestValidateExpression:
self.table.schema = "test_schema"
self.table.catalog = None
self.table.database = MagicMock()
self.table.database.backend = "sqlite"
self.table.database.db_engine_spec = MagicMock()
self.table.database.db_engine_spec.make_sqla_column_compatible = lambda x, _: x
self.table.columns = []
@@ -105,10 +106,8 @@ class TestValidateExpression:
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
def test_validate_invalid_expression(self, mock_execute):
"""Test validation of invalid SQL expressions"""
# Mock _execute_validation_query to raise an exception
mock_execute.side_effect = Exception("Invalid SQL syntax")
"""Unparseable SQL is rejected by the shared expression parser before the
validation query is built or executed."""
result = self.table.validate_expression(
expression="INVALID SQL HERE",
expression_type=SqlExpressionType.COLUMN,
@@ -116,7 +115,8 @@ class TestValidateExpression:
assert result["valid"] is False
assert len(result["errors"]) == 1
assert "Invalid SQL syntax" in result["errors"][0]["message"]
assert result["errors"][0]["message"]
mock_execute.assert_not_called()
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
def test_validate_having_with_non_aggregated_column(self, mock_execute):
@@ -152,6 +152,38 @@ class TestValidateExpression:
# The actual error message will come from the exception
assert "empty" in result["errors"][0]["message"].lower()
@patch("superset.models.helpers.is_feature_enabled", return_value=False)
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
def test_validate_expression_rejects_subquery(
self, mock_execute: MagicMock, mock_ff: MagicMock
) -> None:
"""A sub-query expression is rejected by the same validate_adhoc_subquery
gate used for stored adhoc expressions, before any validation query is
built or run (with ALLOW_ADHOC_SUBQUERY off, the default). Locks in that
expression validation never executes the sub-query."""
result = self.table.validate_expression(
expression="(SELECT 1) IS NOT NULL OR 1 = 1",
expression_type=SqlExpressionType.WHERE,
)
assert result["valid"] is False
mock_execute.assert_not_called()
@patch("superset.models.helpers.is_feature_enabled", return_value=False)
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
def test_validate_expression_rejects_set_operation(
self, mock_execute: MagicMock, mock_ff: MagicMock
) -> None:
"""A set-operation expression is rejected before the validation query is
built or run, matching the stored-adhoc-expression policy."""
result = self.table.validate_expression(
expression="1 UNION SELECT 1",
expression_type=SqlExpressionType.WHERE,
)
assert result["valid"] is False
mock_execute.assert_not_called()
@patch("superset.connectors.sqla.models.SqlaTable._execute_validation_query")
def test_validate_expression_with_rls(self, mock_execute):
"""Test that RLS filters are applied during validation"""


@@ -38,3 +38,31 @@ def test_aggregate():
assert series_to_list(df["asc sum"])[0] == 5050
assert series_to_list(df["asc q2"])[0] == 75
assert series_to_list(df["desc q1"])[0] == 25
def test_aggregate_string_operators():
"""mean, median, and other operators in _PANDAS_STRING_AGGREGATORS use the
pandas string path; verify results match expected values on asc_idx [0..100]."""
aggregates = {
"asc mean": {"column": "asc_idx", "operator": "mean"},
"asc median": {"column": "asc_idx", "operator": "median"},
"asc max": {"column": "asc_idx", "operator": "max"},
"asc min": {"column": "asc_idx", "operator": "min"},
}
df = aggregate(df=categories_df, groupby=["constant"], aggregates=aggregates)
assert series_to_list(df["asc mean"])[0] == 50.0
assert series_to_list(df["asc median"])[0] == 50.0
assert series_to_list(df["asc max"])[0] == 100
assert series_to_list(df["asc min"])[0] == 0
def test_aggregate_count_includes_nulls():
"""'count' operator uses np.ma.count, which counts all rows including NaN.
It is intentionally excluded from _PANDAS_STRING_AGGREGATORS to preserve this
behavior (pandas SeriesGroupBy.count excludes NaN)."""
aggregates = {
"null_count": {"column": "idx_nulls", "operator": "count"},
}
df = aggregate(df=categories_df, groupby=["constant"], aggregates=aggregates)
# idx_nulls has 101 rows total; np.ma.count returns all 101 (NaN included)
assert series_to_list(df["null_count"])[0] == 101
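
The difference that `test_aggregate_count_includes_nulls` protects can be seen on a three-element series. This is a small illustration that assumes `numpy` and `pandas` are installed; the variable names are only for the example.

```python
import numpy as np
import pandas as pd

values = pd.Series([1.0, np.nan, 3.0])

# np.ma.count counts unmasked elements; NaN is not masked, so it is counted.
numpy_count = int(np.ma.count(values.to_numpy()))

# pandas' count skips NaN.
pandas_count = int(values.count())

print(numpy_count, pandas_count)
```

Switching `count` to the pandas string path would therefore change results for any column containing nulls, which is why the docstring describes the exclusion as intentional.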