mirror of
https://github.com/apache/superset.git
synced 2026-05-07 00:44:26 +00:00
Compare commits
5 Commits
embedded-e
...
hughhhh/za
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cb8b17337 | ||
|
|
16e83455e9 | ||
|
|
957b298ae1 | ||
|
|
f29d82b3b1 | ||
|
|
3f550f166f |
9
.github/workflows/bashlib.sh
vendored
9
.github/workflows/bashlib.sh
vendored
@@ -59,15 +59,6 @@ build-assets() {
|
||||
say "::endgroup::"
|
||||
}
|
||||
|
||||
build-embedded-sdk() {
|
||||
cd "$GITHUB_WORKSPACE/superset-embedded-sdk"
|
||||
|
||||
say "::group::Build embedded SDK bundle for E2E tests"
|
||||
npm ci
|
||||
npm run build
|
||||
say "::endgroup::"
|
||||
}
|
||||
|
||||
build-instrumented-assets() {
|
||||
cd "$GITHUB_WORKSPACE/superset-frontend"
|
||||
|
||||
|
||||
6
.github/workflows/superset-e2e.yml
vendored
6
.github/workflows/superset-e2e.yml
vendored
@@ -169,7 +169,6 @@ jobs:
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
REDIS_PORT: 16379
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
SUPERSET_FEATURE_EMBEDDED_SUPERSET: "true"
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17-alpine
|
||||
@@ -240,11 +239,6 @@ jobs:
|
||||
uses: ./.github/actions/cached-dependencies
|
||||
with:
|
||||
run: build-instrumented-assets
|
||||
- name: Build embedded SDK
|
||||
if: steps.check.outputs.python || steps.check.outputs.frontend
|
||||
uses: ./.github/actions/cached-dependencies
|
||||
with:
|
||||
run: build-embedded-sdk
|
||||
- name: Install Playwright
|
||||
if: steps.check.outputs.python || steps.check.outputs.frontend
|
||||
uses: ./.github/actions/cached-dependencies
|
||||
|
||||
6
.github/workflows/superset-playwright.yml
vendored
6
.github/workflows/superset-playwright.yml
vendored
@@ -43,7 +43,6 @@ jobs:
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
REDIS_PORT: 16379
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
SUPERSET_FEATURE_EMBEDDED_SUPERSET: "true"
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17-alpine
|
||||
@@ -114,11 +113,6 @@ jobs:
|
||||
uses: ./.github/actions/cached-dependencies
|
||||
with:
|
||||
run: build-instrumented-assets
|
||||
- name: Build embedded SDK
|
||||
if: steps.check.outputs.python || steps.check.outputs.frontend
|
||||
uses: ./.github/actions/cached-dependencies
|
||||
with:
|
||||
run: build-embedded-sdk
|
||||
- name: Install Playwright
|
||||
if: steps.check.outputs.python || steps.check.outputs.frontend
|
||||
uses: ./.github/actions/cached-dependencies
|
||||
|
||||
@@ -95,7 +95,6 @@ export default defineConfig({
|
||||
testIgnore: [
|
||||
'**/tests/auth/**/*.spec.ts',
|
||||
'**/tests/sqllab/**/*.spec.ts',
|
||||
'**/tests/embedded/**/*.spec.ts',
|
||||
...(process.env.INCLUDE_EXPERIMENTAL ? [] : ['**/experimental/**']),
|
||||
],
|
||||
use: {
|
||||
@@ -133,18 +132,6 @@ export default defineConfig({
|
||||
// No storageState = clean browser with no cached cookies
|
||||
},
|
||||
},
|
||||
{
|
||||
// Embedded dashboard tests - validates the full embedding flow:
|
||||
// external app -> SDK -> iframe -> guest token -> dashboard render
|
||||
name: 'chromium-embedded',
|
||||
testMatch: '**/tests/embedded/**/*.spec.ts',
|
||||
use: {
|
||||
browserName: 'chromium',
|
||||
testIdAttribute: 'data-test',
|
||||
// Uses admin auth for API calls to configure embedding and get guest tokens
|
||||
storageState: 'playwright/.auth/user.json',
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
// Web server setup - disabled in CI (Flask started separately in workflow)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Embedded Dashboard Test App</title>
|
||||
<style>
|
||||
html, body { margin: 0; padding: 0; height: 100%; }
|
||||
#superset-container { width: 100%; height: 100vh; }
|
||||
#superset-container iframe { width: 100%; height: 100%; border: none; }
|
||||
#error { color: red; padding: 20px; display: none; }
|
||||
#status { padding: 10px; font-family: monospace; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="status">Initializing embedded dashboard...</div>
|
||||
<div id="error"></div>
|
||||
<div id="superset-container" data-test="embedded-container"></div>
|
||||
|
||||
<script src="/sdk/index.js"></script>
|
||||
<script>
|
||||
(async function () {
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const uuid = params.get('uuid');
|
||||
const supersetDomain = params.get('supersetDomain');
|
||||
|
||||
if (!uuid || !supersetDomain) {
|
||||
document.getElementById('error').style.display = 'block';
|
||||
document.getElementById('error').textContent =
|
||||
'Missing required query params: uuid, supersetDomain';
|
||||
return;
|
||||
}
|
||||
|
||||
const statusEl = document.getElementById('status');
|
||||
|
||||
// fetchGuestToken is injected by Playwright via page.exposeFunction()
|
||||
// Falls back to window.__guestToken for simple/static token injection
|
||||
async function fetchGuestToken() {
|
||||
if (typeof window.__fetchGuestToken === 'function') {
|
||||
statusEl.textContent = 'Fetching guest token...';
|
||||
const token = await window.__fetchGuestToken();
|
||||
statusEl.textContent = 'Guest token received, loading dashboard...';
|
||||
return token;
|
||||
}
|
||||
if (window.__guestToken) {
|
||||
return window.__guestToken;
|
||||
}
|
||||
throw new Error('No guest token source available');
|
||||
}
|
||||
|
||||
try {
|
||||
// Parse optional UI config from query params
|
||||
const uiConfig = {};
|
||||
if (params.get('hideTitle') === 'true') uiConfig.hideTitle = true;
|
||||
if (params.get('hideTab') === 'true') uiConfig.hideTab = true;
|
||||
if (params.get('hideChartControls') === 'true') uiConfig.hideChartControls = true;
|
||||
|
||||
const dashboard = await supersetEmbeddedSdk.embedDashboard({
|
||||
id: uuid,
|
||||
supersetDomain: supersetDomain,
|
||||
mountPoint: document.getElementById('superset-container'),
|
||||
fetchGuestToken: fetchGuestToken,
|
||||
dashboardUiConfig: Object.keys(uiConfig).length > 0 ? uiConfig : undefined,
|
||||
debug: params.get('debug') === 'true',
|
||||
});
|
||||
|
||||
statusEl.textContent = 'Dashboard embedded successfully';
|
||||
// Expose dashboard API on window for Playwright assertions
|
||||
window.__embeddedDashboard = dashboard;
|
||||
} catch (err) {
|
||||
document.getElementById('error').style.display = 'block';
|
||||
document.getElementById('error').textContent = 'Embed failed: ' + err.message;
|
||||
statusEl.textContent = 'Error';
|
||||
}
|
||||
})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -132,14 +132,26 @@ export interface DashboardResult {
|
||||
published?: boolean;
|
||||
}
|
||||
|
||||
async function getDashboardByFilter(
|
||||
/**
|
||||
* Get a dashboard by its title
|
||||
* @param page - Playwright page instance (provides authentication context)
|
||||
* @param title - The dashboard_title to search for
|
||||
* @returns Dashboard object if found, null if not found
|
||||
*/
|
||||
export async function getDashboardByName(
|
||||
page: Page,
|
||||
col: 'dashboard_title' | 'slug',
|
||||
value: string,
|
||||
title: string,
|
||||
): Promise<DashboardResult | null> {
|
||||
const queryParam = rison.encode({
|
||||
filters: [{ col, opr: 'eq', value }],
|
||||
});
|
||||
const filter = {
|
||||
filters: [
|
||||
{
|
||||
col: 'dashboard_title',
|
||||
opr: 'eq',
|
||||
value: title,
|
||||
},
|
||||
],
|
||||
};
|
||||
const queryParam = rison.encode(filter);
|
||||
const response = await apiGet(
|
||||
page,
|
||||
`${ENDPOINTS.DASHBOARD}?q=${queryParam}`,
|
||||
@@ -157,29 +169,3 @@ async function getDashboardByFilter(
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a dashboard by its title
|
||||
* @param page - Playwright page instance (provides authentication context)
|
||||
* @param title - The dashboard_title to search for
|
||||
* @returns Dashboard object if found, null if not found
|
||||
*/
|
||||
export async function getDashboardByName(
|
||||
page: Page,
|
||||
title: string,
|
||||
): Promise<DashboardResult | null> {
|
||||
return getDashboardByFilter(page, 'dashboard_title', title);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a dashboard by its slug
|
||||
* @param page - Playwright page instance (provides authentication context)
|
||||
* @param slug - The slug to search for
|
||||
* @returns Dashboard object if found, null if not found
|
||||
*/
|
||||
export async function getDashboardBySlug(
|
||||
page: Page,
|
||||
slug: string,
|
||||
): Promise<DashboardResult | null> {
|
||||
return getDashboardByFilter(page, 'slug', slug);
|
||||
}
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { Page } from '@playwright/test';
|
||||
import { apiPost, apiPut } from './requests';
|
||||
import { ENDPOINTS as DASHBOARD_ENDPOINTS } from './dashboard';
|
||||
|
||||
export const ENDPOINTS = {
|
||||
SECURITY_LOGIN: 'api/v1/security/login',
|
||||
GUEST_TOKEN: 'api/v1/security/guest_token/',
|
||||
} as const;
|
||||
|
||||
export interface EmbeddedConfig {
|
||||
uuid: string;
|
||||
allowed_domains: string[];
|
||||
dashboard_id: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable embedding on a dashboard and return the embedded UUID.
|
||||
* Uses PUT (upsert) to preserve UUID across repeated calls.
|
||||
* @param page - Playwright page instance (provides authentication context)
|
||||
* @param dashboardIdOrSlug - Numeric dashboard id or slug
|
||||
* @param allowedDomains - Domains allowed to embed; empty array allows all
|
||||
* @returns Embedded config with UUID, allowed_domains, and dashboard_id
|
||||
*/
|
||||
export async function apiEnableEmbedding(
|
||||
page: Page,
|
||||
dashboardIdOrSlug: number | string,
|
||||
allowedDomains: string[] = [],
|
||||
): Promise<EmbeddedConfig> {
|
||||
const response = await apiPut(
|
||||
page,
|
||||
`${DASHBOARD_ENDPOINTS.DASHBOARD}${dashboardIdOrSlug}/embedded`,
|
||||
{ allowed_domains: allowedDomains },
|
||||
);
|
||||
const body = await response.json();
|
||||
return body.result as EmbeddedConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a guest token for an embedded dashboard.
|
||||
* Uses the admin login flow (login → access_token → guest_token).
|
||||
* @param page - Playwright page instance (used for request context)
|
||||
* @param dashboardId - Dashboard id to grant access to
|
||||
* @param options - Optional login credentials and RLS rules
|
||||
* @returns Signed guest token string
|
||||
*/
|
||||
export async function getGuestToken(
|
||||
page: Page,
|
||||
dashboardId: number | string,
|
||||
options?: {
|
||||
username?: string;
|
||||
password?: string;
|
||||
rls?: Array<{ dataset: number; clause: string }>;
|
||||
},
|
||||
): Promise<string> {
|
||||
const username = options?.username ?? 'admin';
|
||||
const password = options?.password ?? 'general';
|
||||
const rls = options?.rls ?? [];
|
||||
|
||||
// Step 1: Login to get access token
|
||||
const loginResponse = await apiPost(
|
||||
page,
|
||||
ENDPOINTS.SECURITY_LOGIN,
|
||||
{
|
||||
username,
|
||||
password,
|
||||
provider: 'db',
|
||||
refresh: true,
|
||||
},
|
||||
{ allowMissingCsrf: true },
|
||||
);
|
||||
const loginBody = await loginResponse.json();
|
||||
const accessToken = loginBody.access_token;
|
||||
|
||||
// Step 2: Fetch guest token using the access token.
|
||||
// Uses raw page.request.post() (not apiPost) because the guest token endpoint
|
||||
// requires a JWT Bearer token rather than session+CSRF auth.
|
||||
const guestResponse = await page.request.post(ENDPOINTS.GUEST_TOKEN, {
|
||||
data: {
|
||||
user: {
|
||||
username: 'embedded_test_user',
|
||||
first_name: 'Embedded',
|
||||
last_name: 'TestUser',
|
||||
},
|
||||
resources: [{ type: 'dashboard', id: String(dashboardId) }],
|
||||
rls,
|
||||
},
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
const guestBody = await guestResponse.json();
|
||||
return guestBody.token;
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { Page, FrameLocator } from '@playwright/test';
|
||||
import { EMBEDDED } from '../utils/constants';
|
||||
|
||||
/**
|
||||
* Page object for the embedded dashboard test app.
|
||||
*
|
||||
* The test app runs on a separate origin (localhost:9000) and uses the
|
||||
* @superset-ui/embedded-sdk to render a Superset dashboard in an iframe.
|
||||
* Playwright's page.exposeFunction() bridges the guest token from Node.js
|
||||
* into the browser page.
|
||||
*/
|
||||
export class EmbeddedPage {
|
||||
private readonly page: Page;
|
||||
|
||||
private static readonly SELECTORS = {
|
||||
CONTAINER: '[data-test="embedded-container"]',
|
||||
IFRAME: 'iframe[title="Embedded Dashboard"]',
|
||||
STATUS: '#status',
|
||||
ERROR: '#error',
|
||||
} as const;
|
||||
|
||||
constructor(page: Page) {
|
||||
this.page = page;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set up the guest token bridge before navigating.
|
||||
* Must be called BEFORE goto() since embedDashboard() calls fetchGuestToken
|
||||
* immediately on page load.
|
||||
*/
|
||||
async exposeTokenFetcher(tokenFn: () => Promise<string>): Promise<void> {
|
||||
await this.page.exposeFunction('__fetchGuestToken', tokenFn);
|
||||
}
|
||||
|
||||
/**
|
||||
* Navigate to the embedded test app with the given parameters.
|
||||
*/
|
||||
async goto(params: {
|
||||
uuid: string;
|
||||
supersetDomain: string;
|
||||
hideTitle?: boolean;
|
||||
hideTab?: boolean;
|
||||
hideChartControls?: boolean;
|
||||
debug?: boolean;
|
||||
}): Promise<void> {
|
||||
const searchParams = new URLSearchParams({
|
||||
uuid: params.uuid,
|
||||
supersetDomain: params.supersetDomain,
|
||||
});
|
||||
if (params.hideTitle) searchParams.set('hideTitle', 'true');
|
||||
if (params.hideTab) searchParams.set('hideTab', 'true');
|
||||
if (params.hideChartControls) searchParams.set('hideChartControls', 'true');
|
||||
if (params.debug) searchParams.set('debug', 'true');
|
||||
|
||||
await this.page.goto(`${EMBEDDED.APP_URL}/?${searchParams.toString()}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* FrameLocator for the embedded dashboard iframe.
|
||||
*/
|
||||
get iframe(): FrameLocator {
|
||||
return this.page.frameLocator(EmbeddedPage.SELECTORS.IFRAME);
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for the iframe to appear in the DOM.
|
||||
*/
|
||||
async waitForIframe(options?: { timeout?: number }): Promise<void> {
|
||||
await this.page.locator(EmbeddedPage.SELECTORS.IFRAME).waitFor({
|
||||
state: 'attached',
|
||||
timeout: options?.timeout ?? EMBEDDED.IFRAME_LOAD,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for dashboard content to render inside the iframe.
|
||||
* Looks for the grid-container which indicates charts are loading/loaded.
|
||||
*/
|
||||
async waitForDashboardContent(options?: { timeout?: number }): Promise<void> {
|
||||
const frame = this.iframe;
|
||||
await frame
|
||||
.locator('.grid-container, [data-test="grid-container"]')
|
||||
.first()
|
||||
.waitFor({
|
||||
state: 'visible',
|
||||
timeout: options?.timeout ?? EMBEDDED.DASHBOARD_RENDER,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the status text from the test app.
|
||||
*/
|
||||
async getStatus(): Promise<string> {
|
||||
return (
|
||||
(await this.page.locator(EmbeddedPage.SELECTORS.STATUS).textContent()) ??
|
||||
''
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the error text, if any.
|
||||
*/
|
||||
async getError(): Promise<string> {
|
||||
const errorEl = this.page.locator(EmbeddedPage.SELECTORS.ERROR);
|
||||
const display = await errorEl.evaluate(el => getComputedStyle(el).display);
|
||||
if (display === 'none') return '';
|
||||
return (await errorEl.textContent()) ?? '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the dashboard title is visible inside the iframe.
|
||||
*/
|
||||
async isTitleVisible(): Promise<boolean> {
|
||||
const frame = this.iframe;
|
||||
return frame
|
||||
.locator(
|
||||
'[data-test="dashboard-header-container"] [data-test="editable-title-input"]',
|
||||
)
|
||||
.isVisible();
|
||||
}
|
||||
}
|
||||
@@ -1,288 +0,0 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { test, expect, Browser, BrowserContext, Page } from '@playwright/test';
|
||||
import { createServer, IncomingMessage, ServerResponse, Server } from 'http';
|
||||
import { readFileSync, existsSync } from 'fs';
|
||||
import { join } from 'path';
|
||||
import { apiEnableEmbedding, getGuestToken } from '../../helpers/api/embedded';
|
||||
import { getDashboardBySlug } from '../../helpers/api/dashboard';
|
||||
import { EmbeddedPage } from '../../pages/EmbeddedPage';
|
||||
import { EMBEDDED } from '../../utils/constants';
|
||||
|
||||
/**
|
||||
* Superset domain (Flask server) — set by CI or defaults to local dev
|
||||
*/
|
||||
const SUPERSET_DOMAIN = (() => {
|
||||
const url = process.env.PLAYWRIGHT_BASE_URL || 'http://localhost:8088';
|
||||
return url.replace(/\/+$/, '');
|
||||
})();
|
||||
|
||||
const SUPERSET_BASE_URL = SUPERSET_DOMAIN.endsWith('/')
|
||||
? SUPERSET_DOMAIN
|
||||
: `${SUPERSET_DOMAIN}/`;
|
||||
|
||||
/**
|
||||
* Path to the SDK bundle built from superset-embedded-sdk/
|
||||
*/
|
||||
const SDK_BUNDLE_PATH = join(
|
||||
__dirname,
|
||||
'../../../../superset-embedded-sdk/bundle/index.js',
|
||||
);
|
||||
|
||||
/**
|
||||
* Path to the embedded test app static files
|
||||
*/
|
||||
const EMBED_APP_DIR = join(__dirname, '../../embedded-app');
|
||||
|
||||
/**
|
||||
* Create a minimal static file server for the embedded test app.
|
||||
* Serves only a fixed allowlist of routes — the test app references just
|
||||
* its index.html and the SDK bundle, so anything else is 404.
|
||||
*/
|
||||
const INDEX_HTML_PATH = join(EMBED_APP_DIR, 'index.html');
|
||||
|
||||
function createEmbedAppServer(): Server {
|
||||
return createServer((req: IncomingMessage, res: ServerResponse) => {
|
||||
const urlPath = req.url?.split('?')[0] || '/';
|
||||
|
||||
if (urlPath === '/sdk/index.js') {
|
||||
if (!existsSync(SDK_BUNDLE_PATH)) {
|
||||
res.writeHead(404);
|
||||
res.end(
|
||||
'SDK bundle not found. Run: cd superset-embedded-sdk && npm ci && npm run build',
|
||||
);
|
||||
return;
|
||||
}
|
||||
res.writeHead(200, { 'Content-Type': 'text/javascript' });
|
||||
res.end(readFileSync(SDK_BUNDLE_PATH));
|
||||
return;
|
||||
}
|
||||
|
||||
if (urlPath === '/' || urlPath === '/index.html') {
|
||||
res.writeHead(200, { 'Content-Type': 'text/html' });
|
||||
res.end(readFileSync(INDEX_HTML_PATH));
|
||||
return;
|
||||
}
|
||||
|
||||
res.writeHead(404);
|
||||
res.end('Not found');
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a browser context authenticated as admin for API-only work
|
||||
* (enabling embedding, restoring config). Caller is responsible for closing.
|
||||
*/
|
||||
function createAdminContext(browser: Browser): Promise<BrowserContext> {
|
||||
return browser.newContext({
|
||||
storageState: 'playwright/.auth/user.json',
|
||||
baseURL: SUPERSET_BASE_URL,
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Test Suite ────────────────────────────────────────────────────────────
|
||||
|
||||
// Describe wrapper is needed for shared server state and serial execution:
|
||||
// all tests share a static file server on a fixed port and must not run in parallel.
|
||||
test.describe('Embedded Dashboard E2E', () => {
|
||||
test.describe.configure({ mode: 'serial' });
|
||||
|
||||
let server: Server;
|
||||
let embedUuid: string;
|
||||
let dashboardId: number;
|
||||
|
||||
/**
|
||||
* Set up a page to render the default embedded dashboard.
|
||||
* Tests that need a different UUID or UI config should not use this helper.
|
||||
*/
|
||||
async function setupEmbeddedPage(page: Page): Promise<EmbeddedPage> {
|
||||
const embeddedPage = new EmbeddedPage(page);
|
||||
await embeddedPage.exposeTokenFetcher(async () =>
|
||||
getGuestToken(page, dashboardId),
|
||||
);
|
||||
await embeddedPage.goto({
|
||||
uuid: embedUuid,
|
||||
supersetDomain: SUPERSET_DOMAIN,
|
||||
});
|
||||
await embeddedPage.waitForIframe();
|
||||
await embeddedPage.waitForDashboardContent();
|
||||
return embeddedPage;
|
||||
}
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
// Skip all tests if the SDK bundle hasn't been built
|
||||
test.skip(
|
||||
!existsSync(SDK_BUNDLE_PATH),
|
||||
'Embedded SDK bundle not found. Build it with: cd superset-embedded-sdk && npm ci && npm run build',
|
||||
);
|
||||
|
||||
// Start the embedded test app server
|
||||
server = createEmbedAppServer();
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
server.on('error', reject);
|
||||
server.listen(EMBEDDED.APP_PORT, () => resolve());
|
||||
});
|
||||
|
||||
// Use a fresh context with auth to set up test data via API
|
||||
const context = await createAdminContext(browser);
|
||||
const setupPage = await context.newPage();
|
||||
|
||||
try {
|
||||
// Find a well-known example dashboard
|
||||
const dashboard = await getDashboardBySlug(setupPage, 'world_health');
|
||||
if (!dashboard) {
|
||||
throw new Error(
|
||||
'Dashboard "world_health" not found. Ensure load_examples ran in CI setup.',
|
||||
);
|
||||
}
|
||||
dashboardId = dashboard.id;
|
||||
|
||||
// Enable embedding on the dashboard (empty allowed_domains = allow all)
|
||||
const embedded = await apiEnableEmbedding(setupPage, dashboardId);
|
||||
embedUuid = embedded.uuid;
|
||||
} finally {
|
||||
await context.close();
|
||||
}
|
||||
});
|
||||
|
||||
test.afterAll(async () => {
|
||||
if (server) {
|
||||
await new Promise<void>(resolve => server.close(() => resolve()));
|
||||
}
|
||||
});
|
||||
|
||||
test('dashboard renders in embedded iframe', async ({ page }) => {
|
||||
const embeddedPage = await setupEmbeddedPage(page);
|
||||
|
||||
// Verify the iframe src points to Superset's /embedded/ endpoint
|
||||
const iframeSrc = await page
|
||||
.locator('iframe[title="Embedded Dashboard"]')
|
||||
.getAttribute('src');
|
||||
expect(iframeSrc).toContain(`/embedded/${embedUuid}`);
|
||||
|
||||
// Verify no errors in the test app
|
||||
const error = await embeddedPage.getError();
|
||||
expect(error).toBe('');
|
||||
|
||||
// Baseline: title should be visible when hideTitle is not set
|
||||
const titleVisible = await embeddedPage.isTitleVisible();
|
||||
expect(titleVisible).toBe(true);
|
||||
});
|
||||
|
||||
test('UI config hideTitle hides dashboard title', async ({ page }) => {
|
||||
const embeddedPage = new EmbeddedPage(page);
|
||||
await embeddedPage.exposeTokenFetcher(async () =>
|
||||
getGuestToken(page, dashboardId),
|
||||
);
|
||||
await embeddedPage.goto({
|
||||
uuid: embedUuid,
|
||||
supersetDomain: SUPERSET_DOMAIN,
|
||||
hideTitle: true,
|
||||
});
|
||||
await embeddedPage.waitForIframe();
|
||||
await embeddedPage.waitForDashboardContent();
|
||||
|
||||
// The iframe URL should include uiConfig parameter
|
||||
const iframeSrc = await page
|
||||
.locator('iframe[title="Embedded Dashboard"]')
|
||||
.getAttribute('src');
|
||||
expect(iframeSrc).toContain('uiConfig=');
|
||||
|
||||
// Verify the title is actually hidden inside the iframe
|
||||
const titleVisible = await embeddedPage.isTitleVisible();
|
||||
expect(titleVisible).toBe(false);
|
||||
});
|
||||
|
||||
test('charts render inside embedded iframe', async ({ page }) => {
|
||||
const embeddedPage = await setupEmbeddedPage(page);
|
||||
|
||||
// Verify chart containers are present and visible in the iframe
|
||||
const charts = embeddedPage.iframe.locator(
|
||||
'.chart-container, [data-test="chart-container"]',
|
||||
);
|
||||
await expect(charts.first()).toBeVisible({
|
||||
timeout: EMBEDDED.DASHBOARD_RENDER,
|
||||
});
|
||||
});
|
||||
|
||||
test('allowed_domains blocks unauthorized referrer', async ({
|
||||
page,
|
||||
browser,
|
||||
}) => {
|
||||
const context = await createAdminContext(browser);
|
||||
const setupPage = await context.newPage();
|
||||
|
||||
try {
|
||||
// Restrict to a domain that is NOT localhost:9000
|
||||
const restrictedEmbed = await apiEnableEmbedding(setupPage, dashboardId, [
|
||||
'https://allowed.example.com',
|
||||
]);
|
||||
|
||||
const embeddedPage = new EmbeddedPage(page);
|
||||
await embeddedPage.exposeTokenFetcher(async () =>
|
||||
getGuestToken(page, dashboardId),
|
||||
);
|
||||
await embeddedPage.goto({
|
||||
uuid: restrictedEmbed.uuid,
|
||||
supersetDomain: SUPERSET_DOMAIN,
|
||||
});
|
||||
|
||||
// The iframe should load but get a 403 from Superset's referrer check
|
||||
await embeddedPage.waitForIframe();
|
||||
|
||||
// The dashboard content should NOT render (403 blocks the embedded page)
|
||||
const content = embeddedPage.iframe.locator(
|
||||
'.grid-container, [data-test="grid-container"]',
|
||||
);
|
||||
await expect(content).not.toBeVisible({ timeout: 5000 });
|
||||
} finally {
|
||||
// Restore the open embedding config for other tests
|
||||
await apiEnableEmbedding(setupPage, dashboardId, []);
|
||||
await context.close();
|
||||
}
|
||||
});
|
||||
|
||||
test('guest token enables dashboard data access', async ({ page }) => {
|
||||
const embeddedPage = new EmbeddedPage(page);
|
||||
|
||||
let tokenCallCount = 0;
|
||||
await embeddedPage.exposeTokenFetcher(async () => {
|
||||
tokenCallCount += 1;
|
||||
return getGuestToken(page, dashboardId);
|
||||
});
|
||||
|
||||
await embeddedPage.goto({
|
||||
uuid: embedUuid,
|
||||
supersetDomain: SUPERSET_DOMAIN,
|
||||
});
|
||||
await embeddedPage.waitForIframe();
|
||||
await embeddedPage.waitForDashboardContent();
|
||||
|
||||
// The SDK should have called fetchGuestToken at least once
|
||||
expect(tokenCallCount).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Verify charts are actually rendering data (not just loading spinners)
|
||||
const charts = embeddedPage.iframe.locator(
|
||||
'.chart-container, [data-test="chart-container"]',
|
||||
);
|
||||
const chartCount = await charts.count();
|
||||
expect(chartCount).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
@@ -75,18 +75,3 @@ export const TIMEOUT = {
|
||||
*/
|
||||
SLOW_TEST: 60000, // 60s for tests that chain multiple slow operations
|
||||
} as const;
|
||||
|
||||
/**
|
||||
* Embedded dashboard test app configuration.
|
||||
* The test app is served by a Node.js http server started in the test fixture.
|
||||
*/
|
||||
export const EMBEDDED = {
|
||||
/** Port for the embedded test app static server */
|
||||
APP_PORT: 9000,
|
||||
/** Full URL for the embedded test app */
|
||||
APP_URL: 'http://localhost:9000',
|
||||
/** Timeout for iframe to appear in the DOM */
|
||||
IFRAME_LOAD: 15000, // 15s
|
||||
/** Timeout for dashboard content to render inside the iframe */
|
||||
DASHBOARD_RENDER: 30000, // 30s
|
||||
} as const;
|
||||
|
||||
@@ -17,20 +17,31 @@
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { FC, memo, useMemo } from 'react';
|
||||
import { FC, memo, useCallback, useMemo } from 'react';
|
||||
import { t } from '@apache-superset/core/translation';
|
||||
import { DataMaskStateWithId } from '@superset-ui/core';
|
||||
import { styled } from '@apache-superset/core/theme';
|
||||
import { Loading } from '@superset-ui/core/components';
|
||||
import { RootState } from 'src/dashboard/types';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
import { FilterBarOrientation, RootState } from 'src/dashboard/types';
|
||||
import { useChartLayoutItems } from 'src/dashboard/util/useChartLayoutItems';
|
||||
import { useChartIds } from 'src/dashboard/util/charts/useChartIds';
|
||||
import { useSelector } from 'react-redux';
|
||||
import {
|
||||
getRisonFilterParam,
|
||||
parseRisonFilters,
|
||||
updateUrlWithUnmatchedFilters,
|
||||
} from 'src/dashboard/util/risonFilters';
|
||||
import FilterControls from './FilterControls/FilterControls';
|
||||
import { useChartsVerboseMaps, getFilterBarTestId } from './utils';
|
||||
import { HorizontalBarProps } from './types';
|
||||
import FilterBarSettings from './FilterBarSettings';
|
||||
import crossFiltersSelector from './CrossFilters/selectors';
|
||||
import {
|
||||
getUrlFilterIndicators,
|
||||
UrlFilterIndicator,
|
||||
} from './UrlFilters/selectors';
|
||||
import UrlFilterTag from './UrlFilters/UrlFilterTag';
|
||||
|
||||
const HorizontalBar = styled.div`
|
||||
${({ theme }) => `
|
||||
@@ -65,6 +76,28 @@ const FilterBarEmptyStateContainer = styled.div`
|
||||
`}
|
||||
`;
|
||||
|
||||
const UrlFiltersContainer = styled.div`
|
||||
${({ theme }) => `
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
gap: ${theme.sizeUnit * 2}px;
|
||||
padding: 0 ${theme.sizeUnit * 2}px;
|
||||
margin-right: ${theme.sizeUnit * 2}px;
|
||||
border-right: 1px solid ${theme.colorBorder};
|
||||
`}
|
||||
`;
|
||||
|
||||
const UrlFilterTitle = styled.div`
|
||||
${({ theme }) => `
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: ${theme.sizeUnit}px;
|
||||
font-weight: ${theme.fontWeightStrong};
|
||||
font-size: ${theme.fontSizeSM}px;
|
||||
`}
|
||||
`;
|
||||
|
||||
const HorizontalFilterBar: FC<HorizontalBarProps> = ({
|
||||
actions,
|
||||
dataMaskSelected,
|
||||
@@ -94,9 +127,47 @@ const HorizontalFilterBar: FC<HorizontalBarProps> = ({
|
||||
[chartIds, chartLayoutItems, dataMask, verboseMaps],
|
||||
);
|
||||
|
||||
const activeUrlFilters = useMemo(() => getUrlFilterIndicators(), []);
|
||||
|
||||
const handleRemoveUrlFilter = useCallback(
|
||||
(filterToRemove: UrlFilterIndicator) => {
|
||||
const risonParam = getRisonFilterParam();
|
||||
if (!risonParam) return;
|
||||
|
||||
const currentFilters = parseRisonFilters(risonParam);
|
||||
const remaining = currentFilters.filter(
|
||||
f => f.subject !== filterToRemove.filter.subject,
|
||||
);
|
||||
updateUrlWithUnmatchedFilters(remaining);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const urlFiltersComponent = useMemo(() => {
|
||||
if (activeUrlFilters.length === 0) return null;
|
||||
|
||||
return (
|
||||
<UrlFiltersContainer>
|
||||
<UrlFilterTitle>
|
||||
<Icons.LinkOutlined iconSize="s" />
|
||||
{t('URL Filters')}
|
||||
</UrlFilterTitle>
|
||||
{activeUrlFilters.map(filter => (
|
||||
<UrlFilterTag
|
||||
key={filter.subject}
|
||||
filter={filter}
|
||||
orientation={FilterBarOrientation.Horizontal}
|
||||
onRemove={handleRemoveUrlFilter}
|
||||
/>
|
||||
))}
|
||||
</UrlFiltersContainer>
|
||||
);
|
||||
}, [activeUrlFilters, handleRemoveUrlFilter]);
|
||||
|
||||
const hasFilters =
|
||||
filterValues.length > 0 ||
|
||||
selectedCrossFilters.length > 0 ||
|
||||
activeUrlFilters.length > 0 ||
|
||||
chartCustomizationValues.length > 0;
|
||||
|
||||
return (
|
||||
@@ -113,16 +184,19 @@ const HorizontalFilterBar: FC<HorizontalBarProps> = ({
|
||||
</FilterBarEmptyStateContainer>
|
||||
)}
|
||||
{hasFilters && (
|
||||
<FilterControls
|
||||
dataMaskSelected={dataMaskSelected}
|
||||
onFilterSelectionChange={onSelectionChange}
|
||||
onPendingCustomizationDataMaskChange={
|
||||
onPendingCustomizationDataMaskChange
|
||||
}
|
||||
chartCustomizationValues={chartCustomizationValues}
|
||||
clearAllTriggers={clearAllTriggers}
|
||||
onClearAllComplete={onClearAllComplete}
|
||||
/>
|
||||
<>
|
||||
{urlFiltersComponent}
|
||||
<FilterControls
|
||||
dataMaskSelected={dataMaskSelected}
|
||||
onFilterSelectionChange={onSelectionChange}
|
||||
onPendingCustomizationDataMaskChange={
|
||||
onPendingCustomizationDataMaskChange
|
||||
}
|
||||
chartCustomizationValues={chartCustomizationValues}
|
||||
clearAllTriggers={clearAllTriggers}
|
||||
onClearAllComplete={onClearAllComplete}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{actions}
|
||||
</>
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useCSSTextTruncation } from '@superset-ui/core';
|
||||
import { styled, css, useTheme } from '@apache-superset/core/theme';
|
||||
import { Tag } from 'src/components/Tag';
|
||||
import { Tooltip } from '@superset-ui/core/components';
|
||||
import { FilterBarOrientation } from 'src/dashboard/types';
|
||||
import { ellipsisCss } from '../CrossFilters/styles';
|
||||
import { UrlFilterIndicator } from './selectors';
|
||||
|
||||
const StyledValue = styled.b`
|
||||
${({ theme }) => `
|
||||
max-width: ${theme.sizeUnit * 25}px;
|
||||
`}
|
||||
${ellipsisCss}
|
||||
`;
|
||||
|
||||
const StyledColumn = styled('span')`
|
||||
${({ theme }) => `
|
||||
max-width: ${theme.sizeUnit * 25}px;
|
||||
padding-right: ${theme.sizeUnit}px;
|
||||
`}
|
||||
${ellipsisCss}
|
||||
`;
|
||||
|
||||
const StyledTag = styled(Tag)`
|
||||
${({ theme }) => `
|
||||
border: 1px solid ${theme.colorBorder};
|
||||
border-radius: 2px;
|
||||
.anticon-close {
|
||||
vertical-align: middle;
|
||||
}
|
||||
`}
|
||||
`;
|
||||
|
||||
const UrlFilterTag = (props: {
|
||||
filter: UrlFilterIndicator;
|
||||
orientation: FilterBarOrientation;
|
||||
onRemove: (filter: UrlFilterIndicator) => void;
|
||||
}) => {
|
||||
const { filter, orientation, onRemove } = props;
|
||||
const theme = useTheme();
|
||||
const [columnRef, columnIsTruncated] =
|
||||
useCSSTextTruncation<HTMLSpanElement>();
|
||||
const [valueRef, valueIsTruncated] = useCSSTextTruncation<HTMLSpanElement>();
|
||||
|
||||
return (
|
||||
<StyledTag
|
||||
css={css`
|
||||
${orientation === FilterBarOrientation.Vertical
|
||||
? `margin-top: ${theme.sizeUnit * 2}px;`
|
||||
: `margin-left: ${theme.sizeUnit * 2}px;`}
|
||||
`}
|
||||
closable
|
||||
onClose={() => onRemove(filter)}
|
||||
editable
|
||||
>
|
||||
<Tooltip title={columnIsTruncated ? filter.subject : null}>
|
||||
<StyledColumn ref={columnRef}>{filter.subject}</StyledColumn>
|
||||
</Tooltip>
|
||||
<Tooltip title={valueIsTruncated ? filter.value : null}>
|
||||
<StyledValue ref={valueRef}>{filter.value}</StyledValue>
|
||||
</Tooltip>
|
||||
</StyledTag>
|
||||
);
|
||||
};
|
||||
|
||||
export default UrlFilterTag;
|
||||
@@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useMemo } from 'react';
|
||||
import { getUrlFilterIndicators } from './selectors';
|
||||
import UrlFiltersVerticalCollapse from './VerticalCollapse';
|
||||
|
||||
const UrlFiltersVertical = () => {
|
||||
const urlFilters = useMemo(() => getUrlFilterIndicators(), []);
|
||||
|
||||
if (!urlFilters.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <UrlFiltersVerticalCollapse urlFilters={urlFilters} />;
|
||||
};
|
||||
|
||||
export default UrlFiltersVertical;
|
||||
@@ -0,0 +1,173 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useMemo, useState, useCallback } from 'react';
|
||||
import { t } from '@apache-superset/core/translation';
|
||||
import { css, useTheme, SupersetTheme } from '@apache-superset/core/theme';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
import { FilterBarOrientation } from 'src/dashboard/types';
|
||||
import {
|
||||
updateUrlWithUnmatchedFilters,
|
||||
getRisonFilterParam,
|
||||
parseRisonFilters,
|
||||
} from 'src/dashboard/util/risonFilters';
|
||||
import UrlFilterTag from './UrlFilterTag';
|
||||
import { UrlFilterIndicator } from './selectors';
|
||||
|
||||
const UrlFiltersVerticalCollapse = (props: {
|
||||
urlFilters: UrlFilterIndicator[];
|
||||
}) => {
|
||||
const { urlFilters: initialFilters } = props;
|
||||
const theme = useTheme();
|
||||
const [isOpen, setIsOpen] = useState(true);
|
||||
const [urlFilters, setUrlFilters] =
|
||||
useState<UrlFilterIndicator[]>(initialFilters);
|
||||
|
||||
const toggleSection = useCallback(() => {
|
||||
setIsOpen(prev => !prev);
|
||||
}, []);
|
||||
|
||||
const handleRemoveFilter = useCallback(
|
||||
(filterToRemove: UrlFilterIndicator) => {
|
||||
const risonParam = getRisonFilterParam();
|
||||
if (!risonParam) return;
|
||||
|
||||
const currentFilters = parseRisonFilters(risonParam);
|
||||
const remaining = currentFilters.filter(
|
||||
f => f.subject !== filterToRemove.filter.subject,
|
||||
);
|
||||
|
||||
updateUrlWithUnmatchedFilters(remaining);
|
||||
setUrlFilters(prev =>
|
||||
prev.filter(f => f.subject !== filterToRemove.subject),
|
||||
);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const sectionContainerStyle = useCallback(
|
||||
(theme: SupersetTheme) => css`
|
||||
margin-bottom: ${theme.sizeUnit * 3}px;
|
||||
padding: 0 ${theme.sizeUnit * 4}px;
|
||||
`,
|
||||
[],
|
||||
);
|
||||
|
||||
const sectionHeaderStyle = useCallback(
|
||||
(theme: SupersetTheme) => css`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: ${theme.sizeUnit * 2}px 0;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
|
||||
&:hover {
|
||||
background: ${theme.colorBgTextHover};
|
||||
margin: 0 -${theme.sizeUnit * 2}px;
|
||||
padding: ${theme.sizeUnit * 2}px;
|
||||
border-radius: ${theme.borderRadius}px;
|
||||
}
|
||||
`,
|
||||
[],
|
||||
);
|
||||
|
||||
const sectionTitleStyle = useCallback(
|
||||
(theme: SupersetTheme) => css`
|
||||
margin: 0;
|
||||
font-size: ${theme.fontSize}px;
|
||||
font-weight: ${theme.fontWeightStrong};
|
||||
color: ${theme.colorText};
|
||||
line-height: 1.3;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: ${theme.sizeUnit}px;
|
||||
`,
|
||||
[],
|
||||
);
|
||||
|
||||
const sectionContentStyle = useCallback(
|
||||
(theme: SupersetTheme) => css`
|
||||
padding: ${theme.sizeUnit * 2}px 0;
|
||||
`,
|
||||
[],
|
||||
);
|
||||
|
||||
const dividerStyle = useCallback(
|
||||
(theme: SupersetTheme) => css`
|
||||
height: 1px;
|
||||
background: ${theme.colorSplit};
|
||||
margin: ${theme.sizeUnit * 2}px 0;
|
||||
`,
|
||||
[],
|
||||
);
|
||||
|
||||
const iconStyle = useCallback(
|
||||
(open: boolean, theme: SupersetTheme) => css`
|
||||
transform: ${open ? 'rotate(0deg)' : 'rotate(180deg)'};
|
||||
transition: transform 0.2s ease;
|
||||
color: ${theme.colorTextSecondary};
|
||||
`,
|
||||
[],
|
||||
);
|
||||
|
||||
const filterIndicators = useMemo(
|
||||
() =>
|
||||
urlFilters.map(filter => (
|
||||
<UrlFilterTag
|
||||
key={filter.subject}
|
||||
filter={filter}
|
||||
orientation={FilterBarOrientation.Vertical}
|
||||
onRemove={handleRemoveFilter}
|
||||
/>
|
||||
)),
|
||||
[urlFilters, handleRemoveFilter],
|
||||
);
|
||||
|
||||
if (!urlFilters.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div css={sectionContainerStyle}>
|
||||
<div
|
||||
css={sectionHeaderStyle}
|
||||
onClick={toggleSection}
|
||||
onKeyDown={e => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
toggleSection();
|
||||
}
|
||||
}}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
>
|
||||
<h4 css={sectionTitleStyle}>
|
||||
<Icons.LinkOutlined iconSize="s" />
|
||||
{t('URL Filters')}
|
||||
</h4>
|
||||
<Icons.UpOutlined iconSize="m" css={iconStyle(isOpen, theme)} />
|
||||
</div>
|
||||
{isOpen && <div css={sectionContentStyle}>{filterIndicators}</div>}
|
||||
{isOpen && <div css={dividerStyle} data-test="url-filters-divider" />}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default UrlFiltersVerticalCollapse;
|
||||
@@ -0,0 +1,60 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import {
|
||||
getRisonFilterParam,
|
||||
parseRisonFilters,
|
||||
RisonFilter,
|
||||
} from 'src/dashboard/util/risonFilters';
|
||||
|
||||
export interface UrlFilterIndicator {
|
||||
subject: string;
|
||||
operator: string;
|
||||
value: string;
|
||||
filter: RisonFilter;
|
||||
}
|
||||
|
||||
function formatFilterValue(filter: RisonFilter): string {
|
||||
const { comparator, operator } = filter;
|
||||
|
||||
if (operator === 'BETWEEN' && Array.isArray(comparator)) {
|
||||
return `${comparator[0]} – ${comparator[1]}`;
|
||||
}
|
||||
|
||||
if (Array.isArray(comparator)) {
|
||||
return comparator.join(', ');
|
||||
}
|
||||
|
||||
return String(comparator);
|
||||
}
|
||||
|
||||
export function getUrlFilterIndicators(): UrlFilterIndicator[] {
|
||||
const risonParam = getRisonFilterParam();
|
||||
if (!risonParam) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const filters = parseRisonFilters(risonParam);
|
||||
return filters.map(filter => ({
|
||||
subject: filter.subject,
|
||||
operator: filter.operator,
|
||||
value: formatFilterValue(filter),
|
||||
filter,
|
||||
}));
|
||||
}
|
||||
@@ -45,6 +45,7 @@ import Header from './Header';
|
||||
import FilterControls from './FilterControls/FilterControls';
|
||||
import CrossFiltersVertical from './CrossFilters/Vertical';
|
||||
import crossFiltersSelector from './CrossFilters/selectors';
|
||||
import UrlFiltersVertical from './UrlFilters/Vertical';
|
||||
|
||||
enum SectionType {
|
||||
Filters = 'filters',
|
||||
@@ -301,6 +302,7 @@ const VerticalFilterBar: FC<VerticalBarProps> = ({
|
||||
) : (
|
||||
<div css={tabPaneStyle} onScroll={onScroll}>
|
||||
<>
|
||||
<UrlFiltersVertical />
|
||||
<CrossFiltersVertical hideHeader={hasOnlyOneSectionType} />
|
||||
{filterControls}
|
||||
</>
|
||||
|
||||
@@ -107,9 +107,16 @@ const publishDataMask = debounce(
|
||||
const previousParams = new URLSearchParams(search);
|
||||
const newParams = new URLSearchParams();
|
||||
let dataMaskKey: string | null;
|
||||
let risonFilterValue: string | null = null;
|
||||
|
||||
previousParams.forEach((value, key) => {
|
||||
if (!EXCLUDED_URL_PARAMS.includes(key)) {
|
||||
newParams.append(key, value);
|
||||
if (key === 'f') {
|
||||
// Preserve the original Rison filter value to avoid encoding
|
||||
risonFilterValue = value;
|
||||
} else {
|
||||
newParams.append(key, value);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -148,9 +155,16 @@ const publishDataMask = debounce(
|
||||
if (appRoot !== '/' && replacementPathname.startsWith(appRoot)) {
|
||||
replacementPathname = replacementPathname.substring(appRoot.length);
|
||||
}
|
||||
// Manually reconstruct the search string to preserve Rison filter encoding
|
||||
let searchString = newParams.toString();
|
||||
if (risonFilterValue) {
|
||||
const separator = searchString ? '&' : '';
|
||||
searchString = `${searchString}${separator}f=${risonFilterValue}`;
|
||||
}
|
||||
|
||||
history.replace({
|
||||
pathname: replacementPathname,
|
||||
search: newParams.toString(),
|
||||
search: searchString,
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
@@ -64,6 +64,15 @@ import SyncDashboardState, {
|
||||
getDashboardContextLocalStorage,
|
||||
} from '../components/SyncDashboardState';
|
||||
import { AutoRefreshProvider } from '../contexts/AutoRefreshContext';
|
||||
import { PartialFilters } from '@superset-ui/core';
|
||||
import {
|
||||
parseRisonFilters,
|
||||
risonToAdhocFilters,
|
||||
getRisonFilterParam,
|
||||
prettifyRisonFilterUrl,
|
||||
injectRisonFiltersIntelligently,
|
||||
updateUrlWithUnmatchedFilters,
|
||||
} from '../util/risonFilters';
|
||||
|
||||
export const DashboardPageIdContext = createContext('');
|
||||
|
||||
@@ -195,6 +204,61 @@ export const DashboardPage: FC<PageProps> = ({ idOrSlug }: PageProps) => {
|
||||
dataMask = isOldRison;
|
||||
}
|
||||
|
||||
// Parse Rison URL filters with intelligent native filter injection
|
||||
const risonFilterParam = getRisonFilterParam();
|
||||
if (risonFilterParam) {
|
||||
const risonFilters = parseRisonFilters(risonFilterParam);
|
||||
if (risonFilters.length > 0) {
|
||||
// Convert native filter config array to keyed object for lookup
|
||||
const filterConfigArray =
|
||||
(dashboard?.metadata
|
||||
?.native_filter_configuration as Array<Record<string, unknown> & { id: string }>) ||
|
||||
[];
|
||||
const nativeFilters: PartialFilters = {};
|
||||
filterConfigArray.forEach(filter => {
|
||||
nativeFilters[filter.id] = filter as PartialFilters[string];
|
||||
});
|
||||
const injectionResult = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
nativeFilters,
|
||||
dataMask,
|
||||
);
|
||||
|
||||
dataMask = injectionResult.updatedDataMask;
|
||||
|
||||
// For unmatched filters, fall back to adhoc filter approach
|
||||
if (injectionResult.unmatchedFilters.length > 0) {
|
||||
const unmatchedAdhocFilters = risonToAdhocFilters(
|
||||
injectionResult.unmatchedFilters,
|
||||
);
|
||||
|
||||
const risonDataMask = {
|
||||
__rison_filters__: {
|
||||
filterState: { value: unmatchedAdhocFilters },
|
||||
ownState: {},
|
||||
},
|
||||
};
|
||||
|
||||
dataMask = { ...dataMask, ...risonDataMask };
|
||||
}
|
||||
|
||||
// Clean up URL: remove matched filters, keep only unmatched ones
|
||||
const matchedCount =
|
||||
risonFilters.length - injectionResult.unmatchedFilters.length;
|
||||
if (matchedCount > 0) {
|
||||
setTimeout(
|
||||
() =>
|
||||
updateUrlWithUnmatchedFilters(injectionResult.unmatchedFilters),
|
||||
100,
|
||||
);
|
||||
}
|
||||
|
||||
if (injectionResult.unmatchedFilters.length > 0) {
|
||||
setTimeout(() => prettifyRisonFilterUrl(), 150);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (readyToRender) {
|
||||
if (!isDashboardHydrated.current) {
|
||||
isDashboardHydrated.current = true;
|
||||
|
||||
325
superset-frontend/src/dashboard/util/risonFilters.test.ts
Normal file
325
superset-frontend/src/dashboard/util/risonFilters.test.ts
Normal file
@@ -0,0 +1,325 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { PartialFilters, DataMaskStateWithId } from '@superset-ui/core';
|
||||
import {
|
||||
injectRisonFiltersIntelligently,
|
||||
RisonFilter,
|
||||
parseRisonFilters,
|
||||
risonFiltersToString,
|
||||
risonToAdhocFilters,
|
||||
} from './risonFilters';
|
||||
|
||||
const mockNativeFilters: PartialFilters = {
|
||||
filter_1: {
|
||||
id: 'filter_1',
|
||||
targets: [
|
||||
{
|
||||
column: { name: 'country' },
|
||||
datasetId: 1,
|
||||
},
|
||||
],
|
||||
filterType: 'filter_select',
|
||||
},
|
||||
filter_2: {
|
||||
id: 'filter_2',
|
||||
targets: [
|
||||
{
|
||||
column: { name: 'year' },
|
||||
datasetId: 1,
|
||||
},
|
||||
],
|
||||
filterType: 'filter_range',
|
||||
},
|
||||
filter_3: {
|
||||
id: 'filter_3',
|
||||
targets: [
|
||||
{
|
||||
column: { name: 'Country Code' },
|
||||
datasetId: 1,
|
||||
},
|
||||
],
|
||||
filterType: 'filter_select',
|
||||
},
|
||||
};
|
||||
|
||||
const mockDataMask: DataMaskStateWithId = {
|
||||
filter_1: {
|
||||
id: 'filter_1',
|
||||
filterState: { value: undefined },
|
||||
ownState: {},
|
||||
},
|
||||
};
|
||||
|
||||
test('should parse simple Rison filters', () => {
|
||||
const risonString = '(country:USA,year:2024)';
|
||||
const result = parseRisonFilters(risonString);
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0]).toEqual({
|
||||
subject: 'country',
|
||||
operator: '==',
|
||||
comparator: 'USA',
|
||||
});
|
||||
expect(result[1]).toEqual({
|
||||
subject: 'year',
|
||||
operator: '==',
|
||||
comparator: 2024,
|
||||
});
|
||||
});
|
||||
|
||||
test('should parse IN operator with array syntax', () => {
|
||||
const result = parseRisonFilters('(country:!(USA,Canada))');
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toEqual({
|
||||
subject: 'country',
|
||||
operator: 'IN',
|
||||
comparator: ['USA', 'Canada'],
|
||||
});
|
||||
});
|
||||
|
||||
test('should parse BETWEEN operator', () => {
|
||||
const result = parseRisonFilters('(msrp:(between:!(35,200)))');
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toEqual({
|
||||
subject: 'msrp',
|
||||
operator: 'BETWEEN',
|
||||
comparator: [35, 200],
|
||||
});
|
||||
});
|
||||
|
||||
test('should parse NOT operator', () => {
|
||||
const result = parseRisonFilters('(NOT:(country:USA))');
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].operator).toBe('!=');
|
||||
expect(result[0].comparator).toBe('USA');
|
||||
});
|
||||
|
||||
test('should parse comparison operators', () => {
|
||||
expect(parseRisonFilters('(sales:(gt:100000))')[0].operator).toBe('>');
|
||||
expect(parseRisonFilters('(age:(gte:18))')[0].operator).toBe('>=');
|
||||
expect(parseRisonFilters('(temp:(lt:32))')[0].operator).toBe('<');
|
||||
expect(parseRisonFilters('(price:(lte:1000))')[0].operator).toBe('<=');
|
||||
});
|
||||
|
||||
test('should return empty array for invalid Rison', () => {
|
||||
expect(parseRisonFilters('invalid rison')).toEqual([]);
|
||||
expect(parseRisonFilters('(unclosed')).toEqual([]);
|
||||
});
|
||||
|
||||
test('should match Rison filter to native filter by column name', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: '==', comparator: 'USA' },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.updatedDataMask.filter_1.filterState?.value).toEqual(['USA']);
|
||||
expect(result.unmatchedFilters).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should match column names with spaces (case-insensitive)', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'Country Code', operator: '==', comparator: 'USA' },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.updatedDataMask.filter_3.filterState?.value).toEqual(['USA']);
|
||||
expect(result.unmatchedFilters).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should match column names case-insensitively', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'country code', operator: '==', comparator: 'USA' },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.updatedDataMask.filter_3.filterState?.value).toEqual(['USA']);
|
||||
expect(result.unmatchedFilters).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should handle unmatched filters with fallback', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'region', operator: '==', comparator: 'North America' },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.unmatchedFilters).toHaveLength(1);
|
||||
expect(result.unmatchedFilters[0].subject).toBe('region');
|
||||
});
|
||||
|
||||
test('should convert values correctly for different filter types', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: '==', comparator: 'USA' },
|
||||
{ subject: 'year', operator: 'BETWEEN', comparator: [2020, 2024] },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
// Select filter should be array
|
||||
expect(result.updatedDataMask.filter_1.filterState?.value).toEqual(['USA']);
|
||||
|
||||
// Range filter should be min/max object
|
||||
expect(result.updatedDataMask.filter_2.filterState?.value).toEqual({
|
||||
min: 2020,
|
||||
max: 2024,
|
||||
});
|
||||
|
||||
expect(result.unmatchedFilters).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should set extraFormData for auto-application on select filters', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: '==', comparator: 'USA' },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.updatedDataMask.filter_1.extraFormData).toEqual({
|
||||
filters: [{ col: 'country', op: 'IN', val: ['USA'] }],
|
||||
});
|
||||
});
|
||||
|
||||
test('should set extraFormData for auto-application on IN filters', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: 'IN', comparator: ['USA', 'Canada'] },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.updatedDataMask.filter_1.filterState?.value).toEqual([
|
||||
'USA',
|
||||
'Canada',
|
||||
]);
|
||||
expect(result.updatedDataMask.filter_1.extraFormData).toEqual({
|
||||
filters: [{ col: 'country', op: 'IN', val: ['USA', 'Canada'] }],
|
||||
});
|
||||
});
|
||||
|
||||
test('should set extraFormData for auto-application on BETWEEN filters', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'year', operator: 'BETWEEN', comparator: [2020, 2024] },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.updatedDataMask.filter_2.filterState?.value).toEqual({
|
||||
min: 2020,
|
||||
max: 2024,
|
||||
});
|
||||
expect(result.updatedDataMask.filter_2.extraFormData).toEqual({
|
||||
filters: [
|
||||
{ col: 'year', op: '>=', val: 2020 },
|
||||
{ col: 'year', op: '<=', val: 2024 },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle mixed matched and unmatched filters', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: '==', comparator: 'USA' },
|
||||
{ subject: 'category', operator: '==', comparator: 'Sales' },
|
||||
];
|
||||
|
||||
const result = injectRisonFiltersIntelligently(
|
||||
risonFilters,
|
||||
mockNativeFilters,
|
||||
mockDataMask,
|
||||
);
|
||||
|
||||
expect(result.updatedDataMask.filter_1.filterState?.value).toEqual(['USA']);
|
||||
expect(result.unmatchedFilters).toHaveLength(1);
|
||||
expect(result.unmatchedFilters[0].subject).toBe('category');
|
||||
});
|
||||
|
||||
test('should convert filters to adhoc format', () => {
|
||||
const risonFilters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: '==', comparator: 'USA' },
|
||||
];
|
||||
|
||||
const adhocFilters = risonToAdhocFilters(risonFilters);
|
||||
|
||||
expect(adhocFilters).toHaveLength(1);
|
||||
expect(adhocFilters[0]).toMatchObject({
|
||||
expressionType: 'SIMPLE',
|
||||
clause: 'WHERE',
|
||||
subject: 'country',
|
||||
operator: '==',
|
||||
comparator: 'USA',
|
||||
});
|
||||
});
|
||||
|
||||
test('should convert filters to Rison string', () => {
|
||||
const filters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: '==', comparator: 'USA' },
|
||||
];
|
||||
|
||||
const result = risonFiltersToString(filters);
|
||||
expect(result).toBe('(country:USA)');
|
||||
});
|
||||
|
||||
test('should convert IN filters to Rison string', () => {
|
||||
const filters: RisonFilter[] = [
|
||||
{ subject: 'country', operator: 'IN', comparator: ['USA', 'Canada'] },
|
||||
];
|
||||
|
||||
const result = risonFiltersToString(filters);
|
||||
expect(result).toBe('(country:!(USA,Canada))');
|
||||
});
|
||||
|
||||
test('should return empty string for empty filters', () => {
|
||||
expect(risonFiltersToString([])).toBe('');
|
||||
});
|
||||
490
superset-frontend/src/dashboard/util/risonFilters.ts
Normal file
490
superset-frontend/src/dashboard/util/risonFilters.ts
Normal file
@@ -0,0 +1,490 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import {
|
||||
QueryObjectFilterClause,
|
||||
PartialFilters,
|
||||
DataMaskStateWithId,
|
||||
} from '@superset-ui/core';
|
||||
import rison from 'rison';
|
||||
|
||||
export interface RisonFilter {
|
||||
subject: string;
|
||||
operator: string;
|
||||
comparator: string | number | boolean | (string | number)[];
|
||||
}
|
||||
|
||||
export interface IntelligentRisonInjectionResult {
|
||||
updatedDataMask: DataMaskStateWithId;
|
||||
unmatchedFilters: RisonFilter[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse Rison filter syntax from URL parameter.
|
||||
* Supports formats like: (country:USA,year:2024)
|
||||
*/
|
||||
export function parseRisonFilters(risonString: string): RisonFilter[] {
|
||||
try {
|
||||
const parsed = rison.decode(risonString);
|
||||
const filters: RisonFilter[] = [];
|
||||
|
||||
if (!parsed || typeof parsed !== 'object') {
|
||||
return filters;
|
||||
}
|
||||
|
||||
const parsedObj = parsed as Record<string, unknown>;
|
||||
|
||||
// Handle OR operator: OR:!(condition1,condition2)
|
||||
if (parsedObj.OR && Array.isArray(parsedObj.OR)) {
|
||||
(parsedObj.OR as Record<string, unknown>[]).forEach(condition => {
|
||||
if (typeof condition === 'object') {
|
||||
Object.entries(condition).forEach(([key, value]) => {
|
||||
filters.push(parseFilterCondition(key, value));
|
||||
});
|
||||
}
|
||||
});
|
||||
return filters;
|
||||
}
|
||||
|
||||
// Handle NOT operator: NOT:(condition)
|
||||
if (parsedObj.NOT && typeof parsedObj.NOT === 'object') {
|
||||
Object.entries(parsedObj.NOT as Record<string, unknown>).forEach(
|
||||
([key, value]) => {
|
||||
const filter = parseFilterCondition(key, value);
|
||||
if (filter.operator === '==') {
|
||||
filter.operator = '!=';
|
||||
} else if (filter.operator === 'IN') {
|
||||
filter.operator = 'NOT IN';
|
||||
}
|
||||
filters.push(filter);
|
||||
},
|
||||
);
|
||||
return filters;
|
||||
}
|
||||
|
||||
// Handle regular filters
|
||||
Object.entries(parsedObj).forEach(([key, value]) => {
|
||||
if (key !== 'OR' && key !== 'NOT') {
|
||||
filters.push(parseFilterCondition(key, value));
|
||||
}
|
||||
});
|
||||
|
||||
return filters;
|
||||
} catch (error) {
|
||||
console.warn('Failed to parse Rison filters:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse individual filter condition
|
||||
*/
|
||||
function parseFilterCondition(key: string, value: unknown): RisonFilter {
|
||||
// Handle comparison operators: (gt:100), (between:!(1,10))
|
||||
if (typeof value === 'object' && value !== null && !Array.isArray(value)) {
|
||||
const [operator, operatorValue] = Object.entries(
|
||||
value as Record<string, unknown>,
|
||||
)[0];
|
||||
|
||||
switch (operator) {
|
||||
case 'gt':
|
||||
return {
|
||||
subject: key,
|
||||
operator: '>',
|
||||
comparator: operatorValue as string | number,
|
||||
};
|
||||
case 'gte':
|
||||
return {
|
||||
subject: key,
|
||||
operator: '>=',
|
||||
comparator: operatorValue as string | number,
|
||||
};
|
||||
case 'lt':
|
||||
return {
|
||||
subject: key,
|
||||
operator: '<',
|
||||
comparator: operatorValue as string | number,
|
||||
};
|
||||
case 'lte':
|
||||
return {
|
||||
subject: key,
|
||||
operator: '<=',
|
||||
comparator: operatorValue as string | number,
|
||||
};
|
||||
case 'between':
|
||||
return {
|
||||
subject: key,
|
||||
operator: 'BETWEEN',
|
||||
comparator: operatorValue as (string | number)[],
|
||||
};
|
||||
case 'like':
|
||||
return {
|
||||
subject: key,
|
||||
operator: 'LIKE',
|
||||
comparator: operatorValue as string,
|
||||
};
|
||||
default:
|
||||
return {
|
||||
subject: key,
|
||||
operator: '==',
|
||||
comparator: value as string | number,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Handle IN operator: !(value1,value2)
|
||||
if (Array.isArray(value)) {
|
||||
return {
|
||||
subject: key,
|
||||
operator: 'IN',
|
||||
comparator: value as (string | number)[],
|
||||
};
|
||||
}
|
||||
|
||||
// Handle simple equality
|
||||
return {
|
||||
subject: key,
|
||||
operator: '==',
|
||||
comparator: value as string | number | boolean,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert Rison filters to Superset adhoc filter format
|
||||
*/
|
||||
export function risonToAdhocFilters(
|
||||
risonFilters: RisonFilter[],
|
||||
): QueryObjectFilterClause[] {
|
||||
return risonFilters.map(
|
||||
filter =>
|
||||
({
|
||||
expressionType: 'SIMPLE' as const,
|
||||
clause: 'WHERE' as const,
|
||||
subject: filter.subject,
|
||||
operator: filter.operator,
|
||||
comparator: filter.comparator,
|
||||
}) as unknown as QueryObjectFilterClause,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Prettify Rison filter URL by replacing encoded characters.
|
||||
* Uses browser history API to update URL without page reload.
|
||||
*/
|
||||
export function prettifyRisonFilterUrl(): void {
|
||||
try {
|
||||
const currentUrl = window.location.href;
|
||||
|
||||
if (!currentUrl.includes('&f=') && !currentUrl.includes('?f=')) {
|
||||
return;
|
||||
}
|
||||
|
||||
const urlMatch = currentUrl.match(/([?&])f=([^&]*)/);
|
||||
if (!urlMatch) {
|
||||
return;
|
||||
}
|
||||
|
||||
const separator = urlMatch[1];
|
||||
let risonValue = urlMatch[2];
|
||||
|
||||
if (!risonValue.includes('%') && !risonValue.includes('+')) {
|
||||
return;
|
||||
}
|
||||
|
||||
let previousValue = '';
|
||||
let decodeAttempts = 0;
|
||||
while (risonValue !== previousValue && decodeAttempts < 5) {
|
||||
previousValue = risonValue;
|
||||
try {
|
||||
if (risonValue.includes('%')) {
|
||||
risonValue = decodeURIComponent(risonValue);
|
||||
}
|
||||
} catch {
|
||||
break;
|
||||
}
|
||||
decodeAttempts += 1;
|
||||
}
|
||||
|
||||
risonValue = risonValue.replace(/\+/g, ' ');
|
||||
|
||||
const matchIndex = urlMatch.index ?? 0;
|
||||
const beforeRison = currentUrl.substring(0, matchIndex);
|
||||
const afterRison = currentUrl.substring(matchIndex + urlMatch[0].length);
|
||||
const prettifiedUrl = `${beforeRison}${separator}f=${risonValue}${afterRison}`;
|
||||
|
||||
if (prettifiedUrl !== currentUrl) {
|
||||
window.history.replaceState(window.history.state, '', prettifiedUrl);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to prettify Rison URL:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Rison filter parameter from URL
|
||||
*/
|
||||
export function getRisonFilterParam(): string | null {
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
return params.get('f');
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert an array of RisonFilter back to Rison string format
|
||||
*/
|
||||
export function risonFiltersToString(filters: RisonFilter[]): string {
|
||||
if (filters.length === 0) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const risonObject: Record<
|
||||
string,
|
||||
string | number | boolean | (string | number)[] | Record<string, unknown>
|
||||
> = {};
|
||||
|
||||
filters.forEach(filter => {
|
||||
if (filter.operator === 'IN' && Array.isArray(filter.comparator)) {
|
||||
risonObject[filter.subject] = filter.comparator;
|
||||
} else if (filter.operator === '==') {
|
||||
risonObject[filter.subject] = filter.comparator;
|
||||
} else {
|
||||
const operatorMap: Record<string, string> = {
|
||||
'>': 'gt',
|
||||
'>=': 'gte',
|
||||
'<': 'lt',
|
||||
'<=': 'lte',
|
||||
BETWEEN: 'between',
|
||||
LIKE: 'like',
|
||||
};
|
||||
|
||||
const risonOp = operatorMap[filter.operator] || filter.operator;
|
||||
risonObject[filter.subject] = { [risonOp]: filter.comparator };
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
return rison.encode(risonObject);
|
||||
} catch (error) {
|
||||
console.warn('Failed to encode Rison filters:', error);
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the URL to remove successfully matched filters, keeping only unmatched ones
|
||||
*/
|
||||
export function updateUrlWithUnmatchedFilters(
|
||||
unmatchedFilters: RisonFilter[],
|
||||
): void {
|
||||
try {
|
||||
const currentUrl = new URL(window.location.href);
|
||||
|
||||
if (unmatchedFilters.length === 0) {
|
||||
currentUrl.searchParams.delete('f');
|
||||
} else {
|
||||
const newRisonString = risonFiltersToString(unmatchedFilters);
|
||||
if (newRisonString) {
|
||||
currentUrl.searchParams.set('f', newRisonString);
|
||||
} else {
|
||||
currentUrl.searchParams.delete('f');
|
||||
}
|
||||
}
|
||||
|
||||
window.history.replaceState(
|
||||
window.history.state,
|
||||
'',
|
||||
currentUrl.toString(),
|
||||
);
|
||||
} catch (error) {
|
||||
console.warn('Failed to update URL with unmatched filters:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a native filter that matches a Rison filter by column name.
|
||||
* Uses case-insensitive, trimmed comparison to handle column names with spaces.
|
||||
*/
|
||||
function findMatchingNativeFilter(
|
||||
risonFilter: RisonFilter,
|
||||
nativeFilters: PartialFilters,
|
||||
): string | null {
|
||||
const normalizedSubject = risonFilter.subject.trim().toLowerCase();
|
||||
|
||||
for (const [filterId, nativeFilter] of Object.entries(nativeFilters)) {
|
||||
if (!nativeFilter?.targets) continue;
|
||||
|
||||
const hasMatchingTarget = nativeFilter.targets.some(target => {
|
||||
if (typeof target === 'object' && target && 'column' in target) {
|
||||
return (
|
||||
target.column?.name?.trim().toLowerCase() === normalizedSubject
|
||||
);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
if (hasMatchingTarget) {
|
||||
return filterId;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build extraFormData filters for a given rison filter and column name
|
||||
*/
|
||||
function buildExtraFormDataFilters(
|
||||
risonFilter: RisonFilter,
|
||||
columnName: string,
|
||||
): { col: string; op: string; val: unknown }[] {
|
||||
const { operator, comparator } = risonFilter;
|
||||
|
||||
if (operator === 'IN' || (operator === '==' && Array.isArray(comparator))) {
|
||||
return [
|
||||
{
|
||||
col: columnName,
|
||||
op: 'IN',
|
||||
val: Array.isArray(comparator) ? comparator : [comparator],
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
if (operator === '==' && !Array.isArray(comparator)) {
|
||||
return [{ col: columnName, op: 'IN', val: [comparator] }];
|
||||
}
|
||||
|
||||
if (
|
||||
operator === 'BETWEEN' &&
|
||||
Array.isArray(comparator) &&
|
||||
comparator.length === 2
|
||||
) {
|
||||
return [
|
||||
{ col: columnName, op: '>=', val: comparator[0] },
|
||||
{ col: columnName, op: '<=', val: comparator[1] },
|
||||
];
|
||||
}
|
||||
|
||||
return [{ col: columnName, op: operator, val: comparator }];
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a Rison filter value to the format expected by a native filter.
|
||||
* Also returns extraFormData for auto-application.
|
||||
*/
|
||||
function convertRisonToNativeValue(
|
||||
risonFilter: RisonFilter,
|
||||
nativeFilter: { filterType?: string },
|
||||
): unknown {
|
||||
const { comparator, operator } = risonFilter;
|
||||
const filterType = nativeFilter?.filterType;
|
||||
|
||||
switch (filterType) {
|
||||
case 'filter_select':
|
||||
if (operator === 'IN' || Array.isArray(comparator)) {
|
||||
return Array.isArray(comparator) ? comparator : [comparator];
|
||||
}
|
||||
return [comparator];
|
||||
|
||||
case 'filter_range':
|
||||
if (
|
||||
operator === 'BETWEEN' &&
|
||||
Array.isArray(comparator) &&
|
||||
comparator.length === 2
|
||||
) {
|
||||
return { min: comparator[0], max: comparator[1] };
|
||||
}
|
||||
return comparator;
|
||||
|
||||
case 'filter_time_range':
|
||||
case 'filter_timecolumn':
|
||||
return comparator;
|
||||
|
||||
default:
|
||||
return Array.isArray(comparator) ? comparator : [comparator];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a complete DataMask entry for a rison filter matched to a native filter.
|
||||
* Sets both filterState.value AND extraFormData so the filter auto-applies.
|
||||
*/
|
||||
function buildDataMaskForFilter(
|
||||
risonFilter: RisonFilter,
|
||||
nativeFilter: { id: string; filterType?: string; targets?: { column?: { name?: string } }[] },
|
||||
columnName: string,
|
||||
) {
|
||||
const convertedValue = convertRisonToNativeValue(risonFilter, nativeFilter);
|
||||
|
||||
return {
|
||||
id: nativeFilter.id,
|
||||
filterState: {
|
||||
value: convertedValue,
|
||||
},
|
||||
extraFormData: {
|
||||
filters: buildExtraFormDataFilters(risonFilter, columnName),
|
||||
},
|
||||
ownState: {},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Intelligently inject Rison filters into native filters where possible,
|
||||
* falling back to brute-force injection for unmatched filters
|
||||
*/
|
||||
export function injectRisonFiltersIntelligently(
|
||||
risonFilters: RisonFilter[],
|
||||
nativeFilters: PartialFilters,
|
||||
currentDataMask: DataMaskStateWithId,
|
||||
): IntelligentRisonInjectionResult {
|
||||
const updatedDataMask = { ...currentDataMask };
|
||||
const unmatchedFilters: RisonFilter[] = [];
|
||||
|
||||
risonFilters.forEach(risonFilter => {
|
||||
const matchingFilterId = findMatchingNativeFilter(
|
||||
risonFilter,
|
||||
nativeFilters,
|
||||
);
|
||||
|
||||
if (matchingFilterId) {
|
||||
const matchedFilter = nativeFilters[matchingFilterId];
|
||||
if (matchedFilter) {
|
||||
const columnName =
|
||||
matchedFilter.targets?.[0]?.column?.name ?? risonFilter.subject;
|
||||
|
||||
const dataMaskEntry = buildDataMaskForFilter(
|
||||
risonFilter,
|
||||
matchedFilter as { id: string; filterType?: string; targets?: { column?: { name?: string } }[] },
|
||||
columnName,
|
||||
);
|
||||
|
||||
updatedDataMask[matchedFilter.id] = {
|
||||
...updatedDataMask[matchedFilter.id],
|
||||
...dataMaskEntry,
|
||||
};
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
unmatchedFilters.push(risonFilter);
|
||||
});
|
||||
|
||||
return {
|
||||
updatedDataMask,
|
||||
unmatchedFilters,
|
||||
};
|
||||
}
|
||||
@@ -590,7 +590,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
# Driver-specific params to be included in the `get_oauth2_token` request body
|
||||
oauth2_additional_token_request_params: dict[str, Any] = {}
|
||||
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||
oauth2_exception = OAuth2RedirectError
|
||||
oauth2_exception: type[Exception] | tuple[type[Exception], ...] = (
|
||||
OAuth2RedirectError
|
||||
)
|
||||
|
||||
# Does the query id related to the connection?
|
||||
# The default value is True, which means that the query id is determined when
|
||||
|
||||
@@ -31,6 +31,7 @@ from marshmallow import fields, Schema
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from requests import Session
|
||||
from shillelagh.adapters.api.gsheets.lib import SCOPES
|
||||
from shillelagh.exceptions import UnauthenticatedError
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
@@ -40,7 +41,7 @@ from superset.databases.schemas import encrypted_field_properties, EncryptedStri
|
||||
from superset.db_engine_specs.base import DatabaseCategory
|
||||
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.exceptions import OAuth2TokenRefreshError, SupersetException
|
||||
from superset.utils import json
|
||||
from superset.utils.oauth2 import get_oauth2_access_token
|
||||
|
||||
@@ -151,6 +152,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
)
|
||||
oauth2_token_request_uri = "https://oauth2.googleapis.com/token" # noqa: S105
|
||||
oauth2_exception = (UnauthenticatedError, OAuth2TokenRefreshError)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_authorization_uri(
|
||||
|
||||
@@ -62,6 +62,7 @@ Dataset Management:
|
||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
|
||||
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
|
||||
|
||||
Chart Management:
|
||||
- list_charts: List charts with advanced filters (1-based pagination)
|
||||
@@ -164,6 +165,17 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both
|
||||
together for the union. All flags can be combined with 'filters' but not
|
||||
with 'search'.
|
||||
|
||||
To query a dataset's semantic layer (metrics, dimensions):
|
||||
1. list_datasets(request={{}}) -> find a dataset
|
||||
2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics
|
||||
3. query_dataset(request={{
|
||||
"dataset_id": <id>,
|
||||
"metrics": ["count", "avg_revenue"],
|
||||
"columns": ["category"],
|
||||
"time_range": "Last 7 days",
|
||||
"row_limit": 100
|
||||
}}) -> returns tabular data using saved metrics and dimensions
|
||||
|
||||
To explore data with SQL:
|
||||
1. list_datasets(request={{}}) -> find a dataset and note its database_id
|
||||
2. execute_sql(request={{"database_id": <id>, "sql": "SELECT ..."}})
|
||||
@@ -520,6 +532,7 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
create_virtual_dataset,
|
||||
get_dataset_info,
|
||||
list_datasets,
|
||||
query_dataset,
|
||||
)
|
||||
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
|
||||
generate_explore_link,
|
||||
|
||||
@@ -70,6 +70,8 @@ SORTABLE_CHART_COLUMNS = [
|
||||
"created_on",
|
||||
]
|
||||
|
||||
_DEFAULT_LIST_CHARTS_REQUEST = ListChartsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
@@ -81,7 +83,8 @@ SORTABLE_CHART_COLUMNS = [
|
||||
),
|
||||
)
|
||||
async def list_charts(
|
||||
request: ListChartsRequest, ctx: Context
|
||||
request: ListChartsRequest | None = None,
|
||||
ctx: Context = None,
|
||||
) -> ChartList | ChartError:
|
||||
"""List charts with filtering and search.
|
||||
|
||||
@@ -91,6 +94,7 @@ async def list_charts(
|
||||
Sortable columns for order_column: id, slice_name, viz_type, description,
|
||||
changed_on, created_on
|
||||
"""
|
||||
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
|
||||
await ctx.info(
|
||||
"Listing charts: page=%s, page_size=%s, search=%s"
|
||||
% (
|
||||
|
||||
@@ -65,6 +65,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
|
||||
"created_on",
|
||||
]
|
||||
|
||||
_DEFAULT_LIST_DASHBOARDS_REQUEST = ListDashboardsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
@@ -76,7 +78,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
|
||||
),
|
||||
)
|
||||
async def list_dashboards(
|
||||
request: ListDashboardsRequest, ctx: Context
|
||||
request: ListDashboardsRequest | None = None,
|
||||
ctx: Context = None,
|
||||
) -> DashboardList:
|
||||
"""List dashboards with filtering and search. Returns dashboard metadata
|
||||
including title, slug, URL, and last modified time. Use select_columns to
|
||||
@@ -85,6 +88,7 @@ async def list_dashboards(
|
||||
Sortable columns for order_column: id, dashboard_title, slug, published,
|
||||
changed_on, created_on
|
||||
"""
|
||||
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
|
||||
await ctx.info(
|
||||
"Listing dashboards: page=%s, page_size=%s, search=%s"
|
||||
% (
|
||||
|
||||
@@ -36,10 +36,13 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
|
||||
from superset.mcp_service.common.cache_schemas import (
|
||||
CacheStatus,
|
||||
CreatedByMeMixin,
|
||||
MetadataCacheControl,
|
||||
OwnedByMeMixin,
|
||||
QueryCacheControl,
|
||||
)
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.privacy import filter_user_directory_fields
|
||||
@@ -393,6 +396,146 @@ class CreateVirtualDatasetResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
VALID_FILTER_OPS = Literal[
|
||||
"==",
|
||||
"!=",
|
||||
">",
|
||||
"<",
|
||||
">=",
|
||||
"<=",
|
||||
"LIKE",
|
||||
"NOT LIKE",
|
||||
"ILIKE",
|
||||
"NOT ILIKE",
|
||||
"IN",
|
||||
"NOT IN",
|
||||
"IS NULL",
|
||||
"IS NOT NULL",
|
||||
"IS TRUE",
|
||||
"IS FALSE",
|
||||
"TEMPORAL_RANGE",
|
||||
]
|
||||
|
||||
|
||||
class QueryDatasetFilter(BaseModel):
|
||||
"""A single filter condition for dataset queries."""
|
||||
|
||||
col: str = Field(..., description="Column name to filter on")
|
||||
op: VALID_FILTER_OPS = Field(
|
||||
...,
|
||||
description=(
|
||||
'Filter operator. Use "==" for equals, "!=" for not equals, '
|
||||
'"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", '
|
||||
'"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.'
|
||||
),
|
||||
)
|
||||
val: Any = Field(
|
||||
default=None,
|
||||
description="Filter value (omit for IS NULL/IS NOT NULL)",
|
||||
)
|
||||
|
||||
|
||||
class QueryDatasetRequest(QueryCacheControl):
|
||||
"""Request schema for query_dataset tool."""
|
||||
|
||||
dataset_id: int | str = Field(
|
||||
...,
|
||||
description="Dataset identifier — numeric ID or UUID string.",
|
||||
)
|
||||
metrics: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Saved metric names to compute (e.g. ['count', 'avg_revenue']). "
|
||||
"Use get_dataset_info to discover available metrics."
|
||||
),
|
||||
)
|
||||
columns: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Column/dimension names for GROUP BY or SELECT "
|
||||
"(e.g. ['category', 'region']). "
|
||||
"Use get_dataset_info to discover available columns."
|
||||
),
|
||||
)
|
||||
filters: List[QueryDatasetFilter] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
'Filter conditions (e.g. [{"col": "status", "op": "==", "val": "active"}]).'
|
||||
),
|
||||
)
|
||||
time_range: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Time range filter (e.g. 'Last 7 days', 'Last month', "
|
||||
"'2024-01-01 : 2024-12-31'). Requires a temporal column "
|
||||
"on the dataset."
|
||||
),
|
||||
)
|
||||
time_column: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Temporal column to apply time_range to. "
|
||||
"Defaults to the dataset's main datetime column."
|
||||
),
|
||||
)
|
||||
order_by: List[str] | None = Field(
|
||||
default=None,
|
||||
description="Column or metric names to sort results by.",
|
||||
)
|
||||
order_desc: bool = Field(
|
||||
default=True,
|
||||
description="Sort descending (True) or ascending (False).",
|
||||
)
|
||||
row_limit: int = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=50000,
|
||||
description="Maximum number of rows to return (default 1000, max 50000).",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_metrics_or_columns(self) -> "QueryDatasetRequest":
|
||||
"""At least one of metrics or columns must be provided."""
|
||||
if not self.metrics and not self.columns:
|
||||
raise ValueError(
|
||||
"At least one of 'metrics' or 'columns' must be provided. "
|
||||
"Use get_dataset_info to discover available metrics and columns."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class QueryDatasetResponse(BaseModel):
|
||||
"""Response schema for query_dataset tool."""
|
||||
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
dataset_id: int = Field(..., description="Dataset ID")
|
||||
dataset_name: str = Field(..., description="Dataset name")
|
||||
columns: List[DataColumn] = Field(
|
||||
default_factory=list, description="Column metadata for returned data"
|
||||
)
|
||||
data: List[Dict[str, Any]] = Field(
|
||||
default_factory=list, description="Query result rows"
|
||||
)
|
||||
row_count: int = Field(0, description="Number of rows returned")
|
||||
total_rows: int | None = Field(
|
||||
None, description="Total row count from the query engine"
|
||||
)
|
||||
summary: str = Field("", description="Human-readable summary of the results")
|
||||
performance: PerformanceMetadata | None = Field(
|
||||
None, description="Query performance metadata"
|
||||
)
|
||||
cache_status: CacheStatus | None = Field(
|
||||
None, description="Cache hit/miss information"
|
||||
)
|
||||
applied_filters: List[QueryDatasetFilter] = Field(
|
||||
default_factory=list, description="Filters that were applied to the query"
|
||||
)
|
||||
warnings: List[str] = Field(
|
||||
default_factory=list, description="Any warnings encountered during execution"
|
||||
)
|
||||
|
||||
|
||||
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
|
||||
"""Parse a field that may be stored as a JSON string into a dict."""
|
||||
value = getattr(obj, field_name, None)
|
||||
|
||||
@@ -18,9 +18,11 @@
|
||||
from .create_virtual_dataset import create_virtual_dataset
|
||||
from .get_dataset_info import get_dataset_info
|
||||
from .list_datasets import list_datasets
|
||||
from .query_dataset import query_dataset
|
||||
|
||||
__all__ = [
|
||||
"create_virtual_dataset",
|
||||
"list_datasets",
|
||||
"get_dataset_info",
|
||||
"query_dataset",
|
||||
]
|
||||
|
||||
489
superset/mcp_service/dataset/tool/query_dataset.py
Normal file
489
superset/mcp_service/dataset/tool/query_dataset.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: query_dataset
|
||||
|
||||
Query a dataset using its semantic layer (saved metrics, calculated columns,
|
||||
dimensions) without requiring a saved chart.
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import joinedload, subqueryload
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
|
||||
from superset.mcp_service.dataset.schemas import (
|
||||
DatasetError,
|
||||
QueryDatasetFilter,
|
||||
QueryDatasetRequest,
|
||||
QueryDatasetResponse,
|
||||
)
|
||||
from superset.mcp_service.privacy import (
|
||||
DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
requires_data_model_metadata_access,
|
||||
user_can_view_data_model_metadata,
|
||||
)
|
||||
from superset.mcp_service.utils import _is_uuid
|
||||
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
|
||||
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | None:
|
||||
"""Resolve a dataset by int ID or UUID string.
|
||||
|
||||
Replicates the identifier resolution logic from ModelGetInfoCore._find_object().
|
||||
"""
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
opts = eager_options or None
|
||||
|
||||
if isinstance(identifier, int):
|
||||
return DatasetDAO.find_by_id(identifier, query_options=opts)
|
||||
|
||||
# Try parsing as int
|
||||
try:
|
||||
id_val = int(identifier)
|
||||
return DatasetDAO.find_by_id(id_val, query_options=opts)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Try UUID
|
||||
if _is_uuid(str(identifier)):
|
||||
return DatasetDAO.find_by_id(identifier, id_column="uuid", query_options=opts)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_names(
|
||||
requested: list[str],
|
||||
valid: set[str],
|
||||
kind: str,
|
||||
) -> list[str]:
|
||||
"""Return list of error messages for names not found in *valid*.
|
||||
|
||||
Includes close-match suggestions when available.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
for name in requested:
|
||||
if name not in valid:
|
||||
suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
|
||||
msg = f"Unknown {kind}: '{name}'"
|
||||
if suggestions:
|
||||
msg += f". Did you mean: {', '.join(suggestions)}?"
|
||||
errors.append(msg)
|
||||
return errors
|
||||
|
||||
|
||||
@requires_data_model_metadata_access
|
||||
@tool(
|
||||
tags=["data"],
|
||||
class_permission_name="Dataset",
|
||||
annotations=ToolAnnotations(
|
||||
title="Query dataset",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def query_dataset( # noqa: C901
|
||||
request: QueryDatasetRequest, ctx: Context
|
||||
) -> QueryDatasetResponse | DatasetError:
|
||||
"""Query a dataset using its semantic layer (saved metrics, dimensions, filters).
|
||||
|
||||
Returns tabular data without requiring a saved chart. Use this when you want
|
||||
to compute saved metrics, group by dimensions, or apply filters directly
|
||||
against a dataset's curated semantic layer.
|
||||
|
||||
Workflow:
|
||||
1. list_datasets -> find a dataset
|
||||
2. get_dataset_info -> discover available columns and metrics
|
||||
3. query_dataset -> query using metric names and column names
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"dataset_id": 123,
|
||||
"metrics": ["count", "avg_revenue"],
|
||||
"columns": ["product_category"],
|
||||
"time_range": "Last 7 days",
|
||||
"row_limit": 100
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Starting dataset query: dataset_id=%s, metrics=%s, columns=%s, "
|
||||
"row_limit=%s"
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.metrics,
|
||||
request.columns,
|
||||
request.row_limit,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Check data-model metadata access BEFORE the dataset lookup.
|
||||
# Doing this first prevents leaking dataset existence — restricted
|
||||
# users always receive DataModelMetadataRestricted, never NotFound.
|
||||
# The decorator hides this tool from search; this check enforces
|
||||
# direct calls that bypass tool discovery.
|
||||
# ------------------------------------------------------------------
|
||||
if not user_can_view_data_model_metadata():
|
||||
await ctx.warning("Dataset metadata access blocked by privacy controls")
|
||||
return DatasetError.create(
|
||||
error=(
|
||||
"You don't have permission to access dataset details for your role."
|
||||
),
|
||||
error_type=DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Resolve dataset
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(1, 5, "Looking up dataset")
|
||||
eager_options = [
|
||||
subqueryload(SqlaTable.columns),
|
||||
subqueryload(SqlaTable.metrics),
|
||||
joinedload(SqlaTable.database),
|
||||
]
|
||||
|
||||
with event_logger.log_context(action="mcp.query_dataset.lookup"):
|
||||
dataset = _resolve_dataset(request.dataset_id, eager_options)
|
||||
|
||||
if dataset is None:
|
||||
await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
|
||||
return DatasetError.create(
|
||||
error=f"No dataset found with identifier: {request.dataset_id}",
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
dataset_name = getattr(dataset, "table_name", None) or f"Dataset {dataset.id}"
|
||||
await ctx.info(
|
||||
"Dataset found: id=%s, name=%s, columns=%s, metrics=%s"
|
||||
% (
|
||||
dataset.id,
|
||||
dataset_name,
|
||||
len(dataset.columns),
|
||||
len(dataset.metrics),
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Validate requested columns and metrics
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(2, 5, "Validating columns and metrics")
|
||||
valid_columns = {c.column_name for c in dataset.columns}
|
||||
valid_metrics = {m.metric_name for m in dataset.metrics}
|
||||
|
||||
validation_errors: list[str] = []
|
||||
validation_errors.extend(
|
||||
_validate_names(request.columns, valid_columns, "column")
|
||||
)
|
||||
validation_errors.extend(
|
||||
_validate_names(request.metrics, valid_metrics, "metric")
|
||||
)
|
||||
# Validate filter column names against dataset columns
|
||||
filter_cols = [f.col for f in request.filters]
|
||||
validation_errors.extend(
|
||||
_validate_names(filter_cols, valid_columns, "filter column")
|
||||
)
|
||||
# Validate order_by names against columns + metrics
|
||||
if request.order_by:
|
||||
valid_orderby = valid_columns | valid_metrics
|
||||
validation_errors.extend(
|
||||
_validate_names(request.order_by, valid_orderby, "order_by")
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
error_msg = "; ".join(validation_errors)
|
||||
await ctx.error("Validation failed: %s" % (error_msg,))
|
||||
return DatasetError.create(
|
||||
error=error_msg,
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 3: Build filters and time range
|
||||
# ------------------------------------------------------------------
|
||||
warnings: list[str] = []
|
||||
query_filters: list[dict[str, Any]] = [
|
||||
{"col": f.col, "op": f.op, "val": f.val} for f in request.filters
|
||||
]
|
||||
# Track all applied filters (including synthesized ones) for the response.
|
||||
effective_filters: list[QueryDatasetFilter] = list(request.filters)
|
||||
granularity: str | None = None
|
||||
|
||||
if request.time_range:
|
||||
temporal_col = request.time_column or getattr(
|
||||
dataset, "main_dttm_col", None
|
||||
)
|
||||
if not temporal_col:
|
||||
await ctx.error("time_range provided but no temporal column available")
|
||||
return DatasetError.create(
|
||||
error=(
|
||||
"time_range was provided but no temporal column is available. "
|
||||
"Either set time_column explicitly or ensure the dataset has "
|
||||
"a main datetime column configured."
|
||||
),
|
||||
error_type="ValidationError",
|
||||
)
|
||||
# Validate that the temporal column actually exists on the dataset
|
||||
if temporal_col not in valid_columns:
|
||||
await ctx.error("time_column '%s' not found on dataset" % temporal_col)
|
||||
return DatasetError.create(
|
||||
error=(
|
||||
f"time_column '{temporal_col}' does not exist on this dataset."
|
||||
),
|
||||
error_type="ValidationError",
|
||||
)
|
||||
# Warn if the chosen temporal column isn't marked as datetime
|
||||
dttm_cols = {c.column_name for c in dataset.columns if c.is_dttm}
|
||||
if temporal_col not in dttm_cols:
|
||||
warnings.append(
|
||||
f"Column '{temporal_col}' is not marked as a datetime "
|
||||
f"column on this dataset. Time filtering may not work "
|
||||
f"as expected."
|
||||
)
|
||||
|
||||
query_filters.append(
|
||||
{
|
||||
"col": temporal_col,
|
||||
"op": "TEMPORAL_RANGE",
|
||||
"val": request.time_range,
|
||||
}
|
||||
)
|
||||
effective_filters.append(
|
||||
QueryDatasetFilter(
|
||||
col=temporal_col,
|
||||
op="TEMPORAL_RANGE",
|
||||
val=request.time_range,
|
||||
)
|
||||
)
|
||||
granularity = temporal_col
|
||||
await ctx.debug(
|
||||
"Time filter: column=%s, range=%s" % (temporal_col, request.time_range)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 4: Build query dict
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(3, 5, "Building query")
|
||||
query_dict: dict[str, Any] = {
|
||||
"filters": query_filters,
|
||||
"columns": request.columns,
|
||||
"metrics": request.metrics,
|
||||
"row_limit": request.row_limit,
|
||||
"order_desc": request.order_desc,
|
||||
}
|
||||
if granularity:
|
||||
query_dict["granularity"] = granularity
|
||||
if request.order_by:
|
||||
# OrderBy = tuple[Metric | Column, bool] where bool is ascending
|
||||
query_dict["orderby"] = [
|
||||
(col, not request.order_desc) for col in request.order_by
|
||||
]
|
||||
|
||||
await ctx.debug("Query dict keys: %s" % (sorted(query_dict.keys()),))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 5: Create QueryContext and execute
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(4, 5, "Executing query")
|
||||
start_time = time.time()
|
||||
|
||||
with event_logger.log_context(action="mcp.query_dataset.execute"):
|
||||
factory = QueryContextFactory()
|
||||
# datasource_type is "table" because this tool queries SqlaTable
|
||||
# datasets (Superset's built-in semantic layer). External semantic
|
||||
# layers (dbt, Snowflake Cortex, etc.) use "semantic_view" and have
|
||||
# a different query path — see SemanticView + mapper.py.
|
||||
query_context = factory.create(
|
||||
datasource={"id": dataset.id, "type": "table"},
|
||||
queries=[query_dict],
|
||||
form_data={},
|
||||
force=not request.use_cache or request.force_refresh,
|
||||
custom_cache_timeout=request.cache_timeout,
|
||||
)
|
||||
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
query_duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if not result or "queries" not in result or len(result["queries"]) == 0:
|
||||
await ctx.warning("Query returned no results for dataset %s" % dataset.id)
|
||||
return DatasetError.create(
|
||||
error="Query returned no results.",
|
||||
error_type="EmptyQuery",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 6: Format response
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(5, 5, "Formatting results")
|
||||
query_result = result["queries"][0]
|
||||
data = query_result.get("data", [])
|
||||
raw_columns = query_result.get("colnames", [])
|
||||
|
||||
if not data:
|
||||
return QueryDatasetResponse(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset_name,
|
||||
columns=[],
|
||||
data=[],
|
||||
row_count=0,
|
||||
total_rows=0,
|
||||
summary=f"Query on '{dataset_name}' returned no data.",
|
||||
performance=PerformanceMetadata(
|
||||
query_duration_ms=query_duration_ms,
|
||||
cache_status="no_data",
|
||||
),
|
||||
cache_status=get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
),
|
||||
applied_filters=effective_filters,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
# Build column metadata in a single pass per column.
|
||||
# Cap stats computation at STATS_SAMPLE rows to avoid O(rows*cols)
|
||||
# overhead on large result sets (row_limit allows up to 50k).
|
||||
stats_sample_size = 5000
|
||||
stats_rows = data[:stats_sample_size]
|
||||
|
||||
columns_meta: list[DataColumn] = []
|
||||
for col_name in raw_columns:
|
||||
sample_values = [
|
||||
row.get(col_name) for row in data[:3] if row.get(col_name) is not None
|
||||
]
|
||||
data_type = "string"
|
||||
if sample_values:
|
||||
if all(isinstance(v, bool) for v in sample_values):
|
||||
data_type = "boolean"
|
||||
elif all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
|
||||
# Compute null_count and unique non-null values in one pass
|
||||
null_count = 0
|
||||
unique_vals: set[str] = set()
|
||||
for row in stats_rows:
|
||||
val = row.get(col_name)
|
||||
if val is None:
|
||||
null_count += 1
|
||||
else:
|
||||
unique_vals.add(str(val))
|
||||
|
||||
columns_meta.append(
|
||||
DataColumn(
|
||||
name=col_name,
|
||||
display_name=col_name.replace("_", " ").title(),
|
||||
data_type=data_type,
|
||||
sample_values=sample_values[:3],
|
||||
null_count=null_count,
|
||||
unique_count=len(unique_vals),
|
||||
)
|
||||
)
|
||||
|
||||
cache_status = get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
)
|
||||
|
||||
cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
|
||||
summary = (
|
||||
f"Dataset '{dataset_name}': {len(data)} rows, "
|
||||
f"{len(raw_columns)} columns ({cache_label})."
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Query complete: rows=%s, columns=%s, duration=%sms"
|
||||
% (len(data), len(raw_columns), query_duration_ms)
|
||||
)
|
||||
|
||||
return QueryDatasetResponse(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset_name,
|
||||
columns=columns_meta,
|
||||
data=data,
|
||||
row_count=len(data),
|
||||
total_rows=query_result.get("rowcount"),
|
||||
summary=summary,
|
||||
performance=PerformanceMetadata(
|
||||
query_duration_ms=query_duration_ms,
|
||||
cache_status=cache_label,
|
||||
),
|
||||
cache_status=cache_status,
|
||||
applied_filters=effective_filters,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
except OAuth2RedirectError as exc:
|
||||
redirect_msg = build_oauth2_redirect_message(exc)
|
||||
await ctx.error("OAuth2 redirect required: %s" % (redirect_msg,))
|
||||
return DatasetError.create(
|
||||
error=redirect_msg,
|
||||
error_type="OAuth2Redirect",
|
||||
)
|
||||
|
||||
except OAuth2Error as exc:
|
||||
await ctx.error("OAuth2 error: %s" % (str(exc),))
|
||||
return DatasetError.create(
|
||||
error=f"OAuth2 authentication error: {exc}",
|
||||
error_type="OAuth2Error",
|
||||
)
|
||||
|
||||
except (CommandException, SupersetException) as exc:
|
||||
await ctx.error("Query failed: %s" % (str(exc),))
|
||||
return DatasetError.create(
|
||||
error=f"Query execution failed: {exc}",
|
||||
error_type="QueryError",
|
||||
)
|
||||
|
||||
except SQLAlchemyError as exc:
|
||||
await ctx.error("Database error: %s" % (str(exc),))
|
||||
return DatasetError.create(
|
||||
error=f"Database error: {exc}",
|
||||
error_type="DatabaseError",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Unexpected error while querying dataset: %s: %s",
|
||||
type(exc).__name__,
|
||||
str(exc),
|
||||
)
|
||||
await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
|
||||
return DatasetError.create(
|
||||
error="An unexpected error occurred while querying the dataset.",
|
||||
error_type="UnexpectedError",
|
||||
)
|
||||
226
superset/utils/rison_filters.py
Normal file
226
superset/utils/rison_filters.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Parser for Rison URL filters that converts simplified filter syntax
|
||||
to Superset's adhoc_filters format.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import prison
|
||||
from flask import request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RisonFilterParser:
|
||||
"""
|
||||
Parse Rison filter syntax from URL parameter 'f' and convert to adhoc_filters.
|
||||
|
||||
Supports:
|
||||
- Simple equality: f=(country:USA)
|
||||
- Lists (IN): f=(country:!(USA,Canada))
|
||||
- NOT operator: f=(NOT:(country:USA))
|
||||
- OR operator: f=(OR:!(condition1,condition2))
|
||||
- Comparison operators: f=(sales:(gt:100000))
|
||||
- BETWEEN: f=(date:(between:!(2024-01-01,2024-12-31)))
|
||||
- LIKE: f=(name:(like:'%smith%'))
|
||||
"""
|
||||
|
||||
OPERATORS: dict[str, str] = {
|
||||
"gt": ">",
|
||||
"gte": ">=",
|
||||
"lt": "<",
|
||||
"lte": "<=",
|
||||
"between": "BETWEEN",
|
||||
"like": "LIKE",
|
||||
"ilike": "ILIKE",
|
||||
"ne": "!=",
|
||||
"eq": "==",
|
||||
}
|
||||
|
||||
def parse(self, filter_string: Optional[str] = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Parse Rison filter string and convert to adhoc_filters format.
|
||||
|
||||
Args:
|
||||
filter_string: Rison-encoded filter string, or None to get from request
|
||||
|
||||
Returns:
|
||||
List of adhoc_filter dictionaries
|
||||
"""
|
||||
if filter_string is None:
|
||||
filter_string = request.args.get("f")
|
||||
|
||||
if not filter_string:
|
||||
return []
|
||||
|
||||
try:
|
||||
filters_obj = prison.loads(filter_string)
|
||||
return self._convert_to_adhoc_filters(filters_obj)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to parse Rison filters: %s", filter_string, exc_info=True
|
||||
)
|
||||
return []
|
||||
|
||||
def _convert_to_adhoc_filters(
|
||||
self, filters_obj: Union[dict[str, Any], list[Any], Any]
|
||||
) -> list[dict[str, Any]]:
|
||||
if not isinstance(filters_obj, dict):
|
||||
return []
|
||||
|
||||
adhoc_filters: list[dict[str, Any]] = []
|
||||
|
||||
for key, value in filters_obj.items():
|
||||
if key == "OR":
|
||||
adhoc_filters.extend(self._handle_or_operator(value))
|
||||
elif key == "NOT":
|
||||
adhoc_filters.extend(self._handle_not_operator(value))
|
||||
else:
|
||||
filter_dict = self._create_filter(key, value)
|
||||
if filter_dict:
|
||||
adhoc_filters.append(filter_dict)
|
||||
|
||||
return adhoc_filters
|
||||
|
||||
def _create_filter(
|
||||
self, column: str, value: Any, negate: bool = False
|
||||
) -> Optional[dict[str, Any]]:
|
||||
filter_dict: dict[str, Any] = {
|
||||
"expressionType": "SIMPLE",
|
||||
"clause": "WHERE",
|
||||
"subject": column,
|
||||
}
|
||||
|
||||
if isinstance(value, list):
|
||||
filter_dict["operator"] = "NOT IN" if negate else "IN"
|
||||
filter_dict["comparator"] = value
|
||||
elif isinstance(value, dict):
|
||||
operator_info = self._parse_operator_dict(value)
|
||||
if operator_info:
|
||||
operator, comparator = operator_info
|
||||
if negate and operator == "==":
|
||||
operator = "!="
|
||||
elif negate and operator == "IN":
|
||||
operator = "NOT IN"
|
||||
filter_dict["operator"] = operator
|
||||
filter_dict["comparator"] = comparator
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
filter_dict["operator"] = "!=" if negate else "=="
|
||||
filter_dict["comparator"] = value
|
||||
|
||||
return filter_dict
|
||||
|
||||
def _parse_operator_dict(
|
||||
self, op_dict: dict[str, Any]
|
||||
) -> Optional[tuple[str, Any]]:
|
||||
if not op_dict:
|
||||
return None
|
||||
|
||||
for op_key, op_value in op_dict.items():
|
||||
if op_key in self.OPERATORS:
|
||||
operator = self.OPERATORS[op_key]
|
||||
if (
|
||||
operator == "BETWEEN"
|
||||
and isinstance(op_value, list)
|
||||
and len(op_value) == 2
|
||||
):
|
||||
return operator, op_value
|
||||
return operator, op_value
|
||||
if op_key == "in":
|
||||
return "IN", op_value if isinstance(op_value, list) else [op_value]
|
||||
if op_key == "nin":
|
||||
return "NOT IN", op_value if isinstance(op_value, list) else [op_value]
|
||||
|
||||
return None
|
||||
|
||||
def _handle_or_operator(self, or_value: Any) -> list[dict[str, Any]]:
|
||||
if not isinstance(or_value, list):
|
||||
return []
|
||||
|
||||
sql_parts: list[str] = []
|
||||
|
||||
for item in or_value:
|
||||
if isinstance(item, dict):
|
||||
for col, val in item.items():
|
||||
if col not in ("OR", "NOT"):
|
||||
sql_part = self._build_sql_condition(col, val)
|
||||
if sql_part:
|
||||
sql_parts.append(sql_part)
|
||||
|
||||
if sql_parts:
|
||||
return [
|
||||
{
|
||||
"expressionType": "SQL",
|
||||
"clause": "WHERE",
|
||||
"sqlExpression": f"({' OR '.join(sql_parts)})",
|
||||
}
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
def _build_sql_condition(self, column: str, value: Any) -> Optional[str]:
|
||||
if isinstance(value, list):
|
||||
values_str = ", ".join(
|
||||
[f"'{v}'" if isinstance(v, str) else str(v) for v in value]
|
||||
)
|
||||
return f"{column} IN ({values_str})"
|
||||
|
||||
if isinstance(value, dict):
|
||||
operator_info = self._parse_operator_dict(value)
|
||||
if operator_info:
|
||||
op, comp = operator_info
|
||||
if op == "BETWEEN" and isinstance(comp, list):
|
||||
return f"{column} BETWEEN '{comp[0]}' AND '{comp[1]}'"
|
||||
if op == "LIKE":
|
||||
return f"{column} LIKE '{comp}'"
|
||||
comp_str = f"'{comp}'" if isinstance(comp, str) else str(comp)
|
||||
return f"{column} {op} {comp_str}"
|
||||
|
||||
val_str = f"'{value}'" if isinstance(value, str) else str(value)
|
||||
return f"{column} = {val_str}"
|
||||
|
||||
def _handle_not_operator(self, not_value: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(not_value, dict):
|
||||
filters: list[dict[str, Any]] = []
|
||||
for col, val in not_value.items():
|
||||
if col not in ("OR", "NOT"):
|
||||
filter_dict = self._create_filter(col, val, negate=True)
|
||||
if filter_dict:
|
||||
filters.append(filter_dict)
|
||||
return filters
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def merge_rison_filters(form_data: dict[str, Any]) -> None:
|
||||
"""
|
||||
Merge Rison filters from 'f' parameter into form_data.
|
||||
Modifies form_data in place.
|
||||
"""
|
||||
parser = RisonFilterParser()
|
||||
|
||||
if rison_filters := parser.parse():
|
||||
existing_filters = form_data.get("adhoc_filters", [])
|
||||
form_data["adhoc_filters"] = existing_filters + rison_filters
|
||||
logger.info("Added %d filters from Rison parameter", len(rison_filters))
|
||||
@@ -78,15 +78,6 @@ FEATURE_FLAGS = {
|
||||
|
||||
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
|
||||
|
||||
# Enable CORS for embedded dashboard E2E tests (test app on port 9000)
|
||||
ENABLE_CORS = True
|
||||
CORS_OPTIONS: dict = {
|
||||
"origins": [
|
||||
"http://localhost:9000",
|
||||
],
|
||||
"supports_credentials": True,
|
||||
}
|
||||
|
||||
|
||||
def GET_FEATURE_FLAGS_FUNC(ff): # noqa: N802
|
||||
ff_copy = copy(ff)
|
||||
@@ -95,7 +86,6 @@ def GET_FEATURE_FLAGS_FUNC(ff): # noqa: N802
|
||||
|
||||
|
||||
TESTING = True
|
||||
TALISMAN_ENABLED = False
|
||||
WTF_CSRF_ENABLED = False
|
||||
|
||||
FAB_ROLES = {"TestRole": [["Security", "menu_access"], ["List Users", "menu_access"]]}
|
||||
|
||||
@@ -24,6 +24,7 @@ import pandas as pd
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from requests.exceptions import HTTPError
|
||||
from shillelagh.exceptions import UnauthenticatedError
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
@@ -789,6 +790,36 @@ def test_needs_oauth2_with_other_error(mocker: MockerFixture) -> None:
|
||||
assert GSheetsEngineSpec.needs_oauth2(ex) is False
|
||||
|
||||
|
||||
def test_needs_oauth2_with_shillelagh_unauthenticated_error(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that needs_oauth2 returns True when UnauthenticatedError is raised.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
ex = UnauthenticatedError("Token has been revoked")
|
||||
assert GSheetsEngineSpec.needs_oauth2(ex) is True
|
||||
|
||||
|
||||
def test_needs_oauth2_with_unrelated_exception_type(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that an unrelated exception type (with no matching message) returns
|
||||
False.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
assert GSheetsEngineSpec.needs_oauth2(ValueError("unrelated")) is False
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_success(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
|
||||
@@ -320,66 +320,13 @@ class TestChartDataModelMetadataPrivacy:
|
||||
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
|
||||
|
||||
|
||||
class TestListChartsCreatedByMe:
|
||||
"""Tests for the created_by_me flag on ListChartsRequest."""
|
||||
|
||||
def test_created_by_me_default_is_false(self):
|
||||
request = ListChartsRequest()
|
||||
assert request.created_by_me is False
|
||||
|
||||
def test_created_by_me_true_accepted(self):
|
||||
request = ListChartsRequest(created_by_me=True)
|
||||
assert request.created_by_me is True
|
||||
|
||||
def test_created_by_me_combined_with_filters(self):
|
||||
request = ListChartsRequest(
|
||||
created_by_me=True,
|
||||
filters=[ChartFilter(col="slice_name", opr="sw", value="My")],
|
||||
)
|
||||
assert request.created_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_created_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="created_by_me"):
|
||||
ListChartsRequest(created_by_me=True, search="My charts")
|
||||
|
||||
def test_chart_filter_rejects_created_by_fk(self):
|
||||
"""created_by_fk is not a public filter column; use created_by_me instead."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter(col="created_by_fk", opr="eq", value=1)
|
||||
|
||||
|
||||
class TestListChartsOwnedByMe:
|
||||
"""Tests for the owned_by_me flag on ListChartsRequest."""
|
||||
|
||||
def test_owned_by_me_default_is_false(self):
|
||||
request = ListChartsRequest()
|
||||
assert request.owned_by_me is False
|
||||
|
||||
def test_owned_by_me_true_accepted(self):
|
||||
request = ListChartsRequest(owned_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
|
||||
def test_owned_by_me_combined_with_filters(self):
|
||||
request = ListChartsRequest(
|
||||
owned_by_me=True,
|
||||
filters=[ChartFilter(col="slice_name", opr="sw", value="My")],
|
||||
)
|
||||
assert request.owned_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_owned_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="owned_by_me"):
|
||||
ListChartsRequest(owned_by_me=True, search="My charts")
|
||||
|
||||
def test_owned_by_me_and_created_by_me_allowed(self):
|
||||
"""Both flags together are valid (OR logic — creator or owner)."""
|
||||
request = ListChartsRequest(owned_by_me=True, created_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
assert request.created_by_me is True
|
||||
@patch("superset.daos.chart.ChartDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_charts_no_arguments(mock_list, mcp_server):
|
||||
"""Regression test: list_charts must accept zero arguments without raising
|
||||
pydantic_core.ValidationError: Missing required argument: request."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_charts", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "charts" in data
|
||||
|
||||
@@ -30,7 +30,6 @@ from flask import g
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.dashboard.schemas import (
|
||||
DashboardFilter,
|
||||
ListDashboardsRequest,
|
||||
)
|
||||
from superset.mcp_service.dashboard.tool.get_dashboard_info import (
|
||||
@@ -1355,66 +1354,13 @@ class TestDashboardSortableColumns:
|
||||
assert col in list_dashboards.__doc__
|
||||
|
||||
|
||||
class TestListDashboardsCreatedByMe:
|
||||
"""Tests for the created_by_me flag on ListDashboardsRequest."""
|
||||
|
||||
def test_created_by_me_default_is_false(self):
|
||||
request = ListDashboardsRequest()
|
||||
assert request.created_by_me is False
|
||||
|
||||
def test_created_by_me_true_accepted(self):
|
||||
request = ListDashboardsRequest(created_by_me=True)
|
||||
assert request.created_by_me is True
|
||||
|
||||
def test_created_by_me_combined_with_filters(self):
|
||||
request = ListDashboardsRequest(
|
||||
created_by_me=True,
|
||||
filters=[DashboardFilter(col="published", opr="eq", value=True)],
|
||||
)
|
||||
assert request.created_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_created_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="created_by_me"):
|
||||
ListDashboardsRequest(created_by_me=True, search="My dashboards")
|
||||
|
||||
def test_dashboard_filter_rejects_created_by_fk(self):
|
||||
"""created_by_fk is not a public filter column; use created_by_me instead."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter(col="created_by_fk", opr="eq", value=1)
|
||||
|
||||
|
||||
class TestListDashboardsOwnedByMe:
|
||||
"""Tests for the owned_by_me flag on ListDashboardsRequest."""
|
||||
|
||||
def test_owned_by_me_default_is_false(self):
|
||||
request = ListDashboardsRequest()
|
||||
assert request.owned_by_me is False
|
||||
|
||||
def test_owned_by_me_true_accepted(self):
|
||||
request = ListDashboardsRequest(owned_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
|
||||
def test_owned_by_me_combined_with_filters(self):
|
||||
request = ListDashboardsRequest(
|
||||
owned_by_me=True,
|
||||
filters=[DashboardFilter(col="published", opr="eq", value=True)],
|
||||
)
|
||||
assert request.owned_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_owned_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="owned_by_me"):
|
||||
ListDashboardsRequest(owned_by_me=True, search="My dashboards")
|
||||
|
||||
def test_owned_by_me_and_created_by_me_allowed(self):
|
||||
"""Both flags together are valid (OR logic — creator or owner)."""
|
||||
request = ListDashboardsRequest(owned_by_me=True, created_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
assert request.created_by_me is True
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_dashboards_no_arguments(mock_list, mcp_server):
|
||||
"""Regression test: list_dashboards must accept zero arguments without raising
|
||||
pydantic_core.ValidationError: Missing required argument: request."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_dashboards", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "dashboards" in data
|
||||
|
||||
831
tests/unit_tests/mcp_service/dataset/tool/test_query_dataset.py
Normal file
831
tests/unit_tests/mcp_service/dataset/tool/test_query_dataset.py
Normal file
@@ -0,0 +1,831 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for the query_dataset MCP tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client, FastMCP
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
query_dataset_module = importlib.import_module(
|
||||
"superset.mcp_service.dataset.tool.query_dataset"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> FastMCP:
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth() -> Generator[MagicMock, None, None]:
|
||||
"""Mock authentication and metadata access for all tests."""
|
||||
with (
|
||||
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_column(name: str, is_dttm: bool = False) -> MagicMock:
|
||||
"""Build a mock SqlaTable column with the given name and datetime flag."""
|
||||
col = MagicMock()
|
||||
col.column_name = name
|
||||
col.is_dttm = is_dttm
|
||||
col.verbose_name = None
|
||||
col.type = "VARCHAR"
|
||||
col.groupby = True
|
||||
col.filterable = True
|
||||
col.description = None
|
||||
return col
|
||||
|
||||
|
||||
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
|
||||
"""Build a mock SqlMetric with the given name and SQL expression."""
|
||||
metric = MagicMock()
|
||||
metric.metric_name = name
|
||||
metric.verbose_name = None
|
||||
metric.expression = expression
|
||||
metric.description = None
|
||||
metric.d3format = None
|
||||
return metric
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
dataset_id: int = 1,
|
||||
table_name: str = "orders",
|
||||
columns: list[Any] | None = None,
|
||||
metrics: list[Any] | None = None,
|
||||
main_dttm_col: str | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a mock SqlaTable dataset with default columns and metrics."""
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.table_name = table_name
|
||||
ds.uuid = f"test-uuid-{dataset_id}"
|
||||
ds.main_dttm_col = main_dttm_col
|
||||
ds.database = MagicMock()
|
||||
ds.database.database_name = "examples"
|
||||
ds.columns = columns or [
|
||||
_make_column("category"),
|
||||
_make_column("region"),
|
||||
_make_column("order_date", is_dttm=True),
|
||||
]
|
||||
ds.metrics = metrics or [
|
||||
_make_metric("count", "COUNT(*)"),
|
||||
_make_metric("total_revenue", "SUM(revenue)"),
|
||||
]
|
||||
return ds
|
||||
|
||||
|
||||
def _mock_command_result(
|
||||
data: list[dict[str, Any]] | None = None,
|
||||
colnames: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the result dict that ChartDataCommand.run() returns."""
|
||||
data = data or [
|
||||
{"category": "Electronics", "count": 42},
|
||||
{"category": "Clothing", "count": 17},
|
||||
]
|
||||
colnames = colnames or ["category", "count"]
|
||||
return {
|
||||
"queries": [
|
||||
{
|
||||
"data": data,
|
||||
"colnames": colnames,
|
||||
"rowcount": len(data),
|
||||
"cache_key": "abc123",
|
||||
"is_cached": False,
|
||||
"cached_dttm": None,
|
||||
"cache_timeout": 300,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_success(mcp_server: FastMCP) -> None:
|
||||
"""Happy path: metrics + columns returns data."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"columns": ["category"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["dataset_id"] == 1
|
||||
assert data["dataset_name"] == "orders"
|
||||
assert data["row_count"] == 2
|
||||
assert len(data["data"]) == 2
|
||||
assert data["data"][0]["category"] == "Electronics"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_not_found(mcp_server: FastMCP) -> None:
|
||||
"""Dataset ID that doesn't exist returns error."""
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=None,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 999,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "NotFound"
|
||||
assert "999" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_invalid_metric(mcp_server: FastMCP) -> None:
|
||||
"""Unknown metric name returns validation error with suggestions."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["countt"], # typo
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "countt" in data["error"]
|
||||
# Should suggest "count" as a close match
|
||||
assert "count" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_invalid_column(mcp_server: FastMCP) -> None:
|
||||
"""Unknown column name returns validation error."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"columns": ["nonexistent_col"],
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "nonexistent_col" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_no_metrics_no_columns(mcp_server: FastMCP) -> None:
|
||||
"""Providing neither metrics nor columns raises validation error."""
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="metrics.*columns"):
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": [],
|
||||
"columns": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_with_time_range(mcp_server: FastMCP) -> None:
|
||||
"""time_range is converted to TEMPORAL_RANGE filter + granularity."""
|
||||
dataset = _make_dataset(main_dttm_col="order_date")
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 7 days",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
query_dict = captured_queries[0]
|
||||
# Should have TEMPORAL_RANGE filter
|
||||
temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
|
||||
assert len(temporal_filters) == 1
|
||||
assert temporal_filters[0]["col"] == "order_date"
|
||||
assert temporal_filters[0]["val"] == "Last 7 days"
|
||||
# Should set granularity
|
||||
assert query_dict["granularity"] == "order_date"
|
||||
# applied_filters in response must include the synthesized TEMPORAL_RANGE filter
|
||||
data = json.loads(result.content[0].text)
|
||||
resp_filters = data["applied_filters"]
|
||||
temporal_resp = [f for f in resp_filters if f["op"] == "TEMPORAL_RANGE"]
|
||||
assert len(temporal_resp) == 1
|
||||
assert temporal_resp[0]["col"] == "order_date"
|
||||
assert temporal_resp[0]["val"] == "Last 7 days"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_time_range_no_temporal_column(mcp_server: FastMCP) -> None:
|
||||
"""time_range without a temporal column returns error."""
|
||||
dataset = _make_dataset(main_dttm_col=None)
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 7 days",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "temporal column" in data["error"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_with_filters(mcp_server: FastMCP) -> None:
|
||||
"""User-provided filters are passed through to the query."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"filters": [
|
||||
{"col": "category", "op": "==", "val": "Electronics"}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
filters = captured_queries[0]["filters"]
|
||||
assert len(filters) == 1
|
||||
assert filters[0]["col"] == "category"
|
||||
assert filters[0]["op"] == "=="
|
||||
assert filters[0]["val"] == "Electronics"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_empty_results(mcp_server: FastMCP) -> None:
|
||||
"""Query that returns no data gives a response with row_count=0."""
|
||||
dataset = _make_dataset()
|
||||
empty_result = {
|
||||
"queries": [
|
||||
{
|
||||
"data": [],
|
||||
"colnames": [],
|
||||
"rowcount": 0,
|
||||
"is_cached": False,
|
||||
"cached_dttm": None,
|
||||
"cache_timeout": 300,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=empty_result,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["row_count"] == 0
|
||||
assert data["data"] == []
|
||||
assert "no data" in data["summary"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_by_uuid(mcp_server: FastMCP) -> None:
|
||||
"""UUID-based lookup works."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
) as mock_resolve,
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": "a1b2c3d4-5678-90ab-cdef-1234567890ab",
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the resolve function was called with the UUID
|
||||
mock_resolve.assert_called_once()
|
||||
call_args = mock_resolve.call_args
|
||||
assert call_args[0][0] == "a1b2c3d4-5678-90ab-cdef-1234567890ab"
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["dataset_id"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_permission_denied(mcp_server: FastMCP) -> None:
|
||||
"""Permission denied from ChartDataCommand.validate() returns error."""
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
dataset = _make_dataset()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
side_effect=SupersetSecurityException(
|
||||
SupersetError(
|
||||
message="Access denied",
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.WARNING,
|
||||
)
|
||||
),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "QueryError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_order_by_valid(mcp_server: FastMCP) -> None:
|
||||
"""order_by with valid column/metric names passes through."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"columns": ["category"],
|
||||
"order_by": ["count"],
|
||||
"order_desc": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
orderby = captured_queries[0].get("orderby", [])
|
||||
assert len(orderby) == 1
|
||||
assert orderby[0][0] == "count"
|
||||
# order_desc=True -> ascending=False
|
||||
assert orderby[0][1] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_order_by_invalid(mcp_server: FastMCP) -> None:
|
||||
"""order_by with an unknown name returns validation error."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"order_by": ["nonexistent"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "nonexistent" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_time_column_override(mcp_server: FastMCP) -> None:
|
||||
"""Explicit time_column overrides dataset main_dttm_col."""
|
||||
dataset = _make_dataset(main_dttm_col="order_date")
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 30 days",
|
||||
"time_column": "order_date",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
query_dict = captured_queries[0]
|
||||
assert query_dict["granularity"] == "order_date"
|
||||
temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
|
||||
assert temporal_filters[0]["col"] == "order_date"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_non_dttm_time_column_warns(mcp_server: FastMCP) -> None:
|
||||
"""Using a non-datetime column for time_range produces a warning."""
|
||||
dataset = _make_dataset(main_dttm_col=None)
|
||||
result_data = _mock_command_result()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 7 days",
|
||||
"time_column": "category",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert len(data["warnings"]) > 0
|
||||
assert "not marked as a datetime" in data["warnings"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_invalid_filter_column(mcp_server: FastMCP) -> None:
|
||||
"""Filter on a column that doesn't exist returns validation error."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"filters": [
|
||||
{
|
||||
"col": "nonexistent",
|
||||
"op": "==",
|
||||
"val": "test",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "nonexistent" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_metadata_access_denied_no_suggestions(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Users without data-model metadata access cannot probe column/metric names.
|
||||
|
||||
The privacy gate must fire before the validation step that returns close-match
|
||||
suggestions, so restricted users cannot enumerate schema details via typos.
|
||||
"""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
# Typo that would normally trigger close-match suggestions
|
||||
"metrics": ["countt"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
# Must be denied before returning any schema suggestions
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
# Must NOT contain column/metric name suggestions
|
||||
assert "countt" not in data.get("error", "")
|
||||
assert "count" not in data.get("error", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_metadata_access_denied_nonexistent_dataset(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Metadata-restricted users must not be able to probe dataset existence.
|
||||
|
||||
The privacy gate fires before the DAO lookup, so a restricted caller
|
||||
always receives DataModelMetadataRestricted — never NotFound — regardless
|
||||
of whether the requested dataset ID exists.
|
||||
"""
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
# Use a dataset_id that does not exist
|
||||
"dataset_id": 999999,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
# Must receive restricted error, not a NotFound that leaks existence
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
assert data["error_type"] != "NotFound"
|
||||
133
tests/unit_tests/utils/test_rison_filters.py
Normal file
133
tests/unit_tests/utils/test_rison_filters.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Unit tests for Rison filter parser."""
|
||||
|
||||
|
||||
from superset.utils.rison_filters import RisonFilterParser
|
||||
|
||||
|
||||
def test_simple_equality():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(country:USA)")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["expressionType"] == "SIMPLE"
|
||||
assert result[0]["clause"] == "WHERE"
|
||||
assert result[0]["subject"] == "country"
|
||||
assert result[0]["operator"] == "=="
|
||||
assert result[0]["comparator"] == "USA"
|
||||
|
||||
|
||||
def test_multiple_filters_and():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(country:USA,year:2024)")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["subject"] == "country"
|
||||
assert result[0]["comparator"] == "USA"
|
||||
assert result[1]["subject"] == "year"
|
||||
assert result[1]["comparator"] == 2024
|
||||
|
||||
|
||||
def test_list_in_operator():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(country:!(USA,Canada))")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["subject"] == "country"
|
||||
assert result[0]["operator"] == "IN"
|
||||
assert result[0]["comparator"] == ["USA", "Canada"]
|
||||
|
||||
|
||||
def test_not_operator():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(NOT:(country:USA))")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["subject"] == "country"
|
||||
assert result[0]["operator"] == "!="
|
||||
assert result[0]["comparator"] == "USA"
|
||||
|
||||
|
||||
def test_not_in_operator():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(NOT:(country:!(USA,Canada)))")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["subject"] == "country"
|
||||
assert result[0]["operator"] == "NOT IN"
|
||||
assert result[0]["comparator"] == ["USA", "Canada"]
|
||||
|
||||
|
||||
def test_or_operator():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(OR:!((status:active),(priority:high)))")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["expressionType"] == "SQL"
|
||||
assert result[0]["clause"] == "WHERE"
|
||||
assert "status = 'active' OR priority = 'high'" in result[0]["sqlExpression"]
|
||||
|
||||
|
||||
def test_comparison_operators():
|
||||
parser = RisonFilterParser()
|
||||
|
||||
result = parser.parse("(sales:(gt:100000))")
|
||||
assert result[0]["operator"] == ">"
|
||||
assert result[0]["comparator"] == 100000
|
||||
|
||||
result = parser.parse("(age:(gte:18))")
|
||||
assert result[0]["operator"] == ">="
|
||||
assert result[0]["comparator"] == 18
|
||||
|
||||
result = parser.parse("(temp:(lt:32))")
|
||||
assert result[0]["operator"] == "<"
|
||||
assert result[0]["comparator"] == 32
|
||||
|
||||
result = parser.parse("(price:(lte:1000))")
|
||||
assert result[0]["operator"] == "<="
|
||||
assert result[0]["comparator"] == 1000
|
||||
|
||||
|
||||
def test_between_operator():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(date:(between:!('2024-01-01','2024-12-31')))")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["operator"] == "BETWEEN"
|
||||
assert result[0]["comparator"] == ["2024-01-01", "2024-12-31"]
|
||||
|
||||
|
||||
def test_like_operator():
|
||||
parser = RisonFilterParser()
|
||||
result = parser.parse("(name:(like:'%smith%'))")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["operator"] == "LIKE"
|
||||
assert result[0]["comparator"] == "%smith%"
|
||||
|
||||
|
||||
def test_empty_filter():
|
||||
parser = RisonFilterParser()
|
||||
assert parser.parse("") == []
|
||||
assert parser.parse("()") == []
|
||||
|
||||
|
||||
def test_invalid_rison():
|
||||
parser = RisonFilterParser()
|
||||
assert parser.parse("invalid rison") == []
|
||||
assert parser.parse("(unclosed") == []
|
||||
Reference in New Issue
Block a user