Compare commits

..

5 Commits

Author SHA1 Message Date
Hugh A. Miles II
3cb8b17337 Merge branch 'master' into hughhhh/zagreb-embed-page 2026-04-30 18:35:24 -04:00
Hugh A Miles II
16e83455e9 feat(dashboard): add Rison-encoded URL filter support
Adds ?f=(...) URL parameter support for hydrating dashboard filters from
human-readable Rison syntax. Matches URL filters to native filters by column
name (case-insensitive, handles spaces) and auto-applies them by populating
both filterState.value and extraFormData. Unmatched filters render as
removable chips in a "URL Filters" section above cross-filters.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-30 18:31:05 -04:00
Amin Ghadersohi
957b298ae1 fix(mcp): add default request parameter to list_charts and list_dashboards (#39730) 2026-04-30 18:04:39 -04:00
Amin Ghadersohi
f29d82b3b1 feat(mcp): add query_dataset tool to query datasets using semantic layer (#39727) 2026-04-30 18:03:41 -04:00
Vitor Avila
3f550f166f fix(GSheets OAuth2): Re-add UnauthenticatedError (#39785) 2026-04-30 18:57:00 -03:00
35 changed files with 3257 additions and 873 deletions

View File

@@ -59,15 +59,6 @@ build-assets() {
say "::endgroup::"
}
build-embedded-sdk() {
cd "$GITHUB_WORKSPACE/superset-embedded-sdk"
say "::group::Build embedded SDK bundle for E2E tests"
npm ci
npm run build
say "::endgroup::"
}
build-instrumented-assets() {
cd "$GITHUB_WORKSPACE/superset-frontend"

View File

@@ -169,7 +169,6 @@ jobs:
PYTHONPATH: ${{ github.workspace }}
REDIS_PORT: 16379
GITHUB_TOKEN: ${{ github.token }}
SUPERSET_FEATURE_EMBEDDED_SUPERSET: "true"
services:
postgres:
image: postgres:17-alpine
@@ -240,11 +239,6 @@ jobs:
uses: ./.github/actions/cached-dependencies
with:
run: build-instrumented-assets
- name: Build embedded SDK
if: steps.check.outputs.python || steps.check.outputs.frontend
uses: ./.github/actions/cached-dependencies
with:
run: build-embedded-sdk
- name: Install Playwright
if: steps.check.outputs.python || steps.check.outputs.frontend
uses: ./.github/actions/cached-dependencies

View File

@@ -43,7 +43,6 @@ jobs:
PYTHONPATH: ${{ github.workspace }}
REDIS_PORT: 16379
GITHUB_TOKEN: ${{ github.token }}
SUPERSET_FEATURE_EMBEDDED_SUPERSET: "true"
services:
postgres:
image: postgres:17-alpine
@@ -114,11 +113,6 @@ jobs:
uses: ./.github/actions/cached-dependencies
with:
run: build-instrumented-assets
- name: Build embedded SDK
if: steps.check.outputs.python || steps.check.outputs.frontend
uses: ./.github/actions/cached-dependencies
with:
run: build-embedded-sdk
- name: Install Playwright
if: steps.check.outputs.python || steps.check.outputs.frontend
uses: ./.github/actions/cached-dependencies

View File

@@ -95,7 +95,6 @@ export default defineConfig({
testIgnore: [
'**/tests/auth/**/*.spec.ts',
'**/tests/sqllab/**/*.spec.ts',
'**/tests/embedded/**/*.spec.ts',
...(process.env.INCLUDE_EXPERIMENTAL ? [] : ['**/experimental/**']),
],
use: {
@@ -133,18 +132,6 @@ export default defineConfig({
// No storageState = clean browser with no cached cookies
},
},
{
// Embedded dashboard tests - validates the full embedding flow:
// external app -> SDK -> iframe -> guest token -> dashboard render
name: 'chromium-embedded',
testMatch: '**/tests/embedded/**/*.spec.ts',
use: {
browserName: 'chromium',
testIdAttribute: 'data-test',
// Uses admin auth for API calls to configure embedding and get guest tokens
storageState: 'playwright/.auth/user.json',
},
},
],
// Web server setup - disabled in CI (Flask started separately in workflow)

View File

@@ -1,96 +0,0 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Embedded Dashboard Test App</title>
<style>
html, body { margin: 0; padding: 0; height: 100%; }
#superset-container { width: 100%; height: 100vh; }
#superset-container iframe { width: 100%; height: 100%; border: none; }
#error { color: red; padding: 20px; display: none; }
#status { padding: 10px; font-family: monospace; font-size: 12px; }
</style>
</head>
<body>
<div id="status">Initializing embedded dashboard...</div>
<div id="error"></div>
<div id="superset-container" data-test="embedded-container"></div>
<script src="/sdk/index.js"></script>
<script>
(async function () {
const params = new URLSearchParams(window.location.search);
const uuid = params.get('uuid');
const supersetDomain = params.get('supersetDomain');
if (!uuid || !supersetDomain) {
document.getElementById('error').style.display = 'block';
document.getElementById('error').textContent =
'Missing required query params: uuid, supersetDomain';
return;
}
const statusEl = document.getElementById('status');
// fetchGuestToken is injected by Playwright via page.exposeFunction()
// Falls back to window.__guestToken for simple/static token injection
async function fetchGuestToken() {
if (typeof window.__fetchGuestToken === 'function') {
statusEl.textContent = 'Fetching guest token...';
const token = await window.__fetchGuestToken();
statusEl.textContent = 'Guest token received, loading dashboard...';
return token;
}
if (window.__guestToken) {
return window.__guestToken;
}
throw new Error('No guest token source available');
}
try {
// Parse optional UI config from query params
const uiConfig = {};
if (params.get('hideTitle') === 'true') uiConfig.hideTitle = true;
if (params.get('hideTab') === 'true') uiConfig.hideTab = true;
if (params.get('hideChartControls') === 'true') uiConfig.hideChartControls = true;
const dashboard = await supersetEmbeddedSdk.embedDashboard({
id: uuid,
supersetDomain: supersetDomain,
mountPoint: document.getElementById('superset-container'),
fetchGuestToken: fetchGuestToken,
dashboardUiConfig: Object.keys(uiConfig).length > 0 ? uiConfig : undefined,
debug: params.get('debug') === 'true',
});
statusEl.textContent = 'Dashboard embedded successfully';
// Expose dashboard API on window for Playwright assertions
window.__embeddedDashboard = dashboard;
} catch (err) {
document.getElementById('error').style.display = 'block';
document.getElementById('error').textContent = 'Embed failed: ' + err.message;
statusEl.textContent = 'Error';
}
})();
</script>
</body>
</html>

View File

@@ -132,14 +132,26 @@ export interface DashboardResult {
published?: boolean;
}
async function getDashboardByFilter(
/**
* Get a dashboard by its title
* @param page - Playwright page instance (provides authentication context)
* @param title - The dashboard_title to search for
* @returns Dashboard object if found, null if not found
*/
export async function getDashboardByName(
page: Page,
col: 'dashboard_title' | 'slug',
value: string,
title: string,
): Promise<DashboardResult | null> {
const queryParam = rison.encode({
filters: [{ col, opr: 'eq', value }],
});
const filter = {
filters: [
{
col: 'dashboard_title',
opr: 'eq',
value: title,
},
],
};
const queryParam = rison.encode(filter);
const response = await apiGet(
page,
`${ENDPOINTS.DASHBOARD}?q=${queryParam}`,
@@ -157,29 +169,3 @@ async function getDashboardByFilter(
return null;
}
/**
* Get a dashboard by its title
* @param page - Playwright page instance (provides authentication context)
* @param title - The dashboard_title to search for
* @returns Dashboard object if found, null if not found
*/
export async function getDashboardByName(
page: Page,
title: string,
): Promise<DashboardResult | null> {
return getDashboardByFilter(page, 'dashboard_title', title);
}
/**
* Get a dashboard by its slug
* @param page - Playwright page instance (provides authentication context)
* @param slug - The slug to search for
* @returns Dashboard object if found, null if not found
*/
export async function getDashboardBySlug(
page: Page,
slug: string,
): Promise<DashboardResult | null> {
return getDashboardByFilter(page, 'slug', slug);
}

View File

@@ -1,113 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { Page } from '@playwright/test';
import { apiPost, apiPut } from './requests';
import { ENDPOINTS as DASHBOARD_ENDPOINTS } from './dashboard';
export const ENDPOINTS = {
SECURITY_LOGIN: 'api/v1/security/login',
GUEST_TOKEN: 'api/v1/security/guest_token/',
} as const;
export interface EmbeddedConfig {
uuid: string;
allowed_domains: string[];
dashboard_id: number;
}
/**
* Enable embedding on a dashboard and return the embedded UUID.
* Uses PUT (upsert) to preserve UUID across repeated calls.
* @param page - Playwright page instance (provides authentication context)
* @param dashboardIdOrSlug - Numeric dashboard id or slug
* @param allowedDomains - Domains allowed to embed; empty array allows all
* @returns Embedded config with UUID, allowed_domains, and dashboard_id
*/
export async function apiEnableEmbedding(
page: Page,
dashboardIdOrSlug: number | string,
allowedDomains: string[] = [],
): Promise<EmbeddedConfig> {
const response = await apiPut(
page,
`${DASHBOARD_ENDPOINTS.DASHBOARD}${dashboardIdOrSlug}/embedded`,
{ allowed_domains: allowedDomains },
);
const body = await response.json();
return body.result as EmbeddedConfig;
}
/**
* Get a guest token for an embedded dashboard.
* Uses the admin login flow (login → access_token → guest_token).
* @param page - Playwright page instance (used for request context)
* @param dashboardId - Dashboard id to grant access to
* @param options - Optional login credentials and RLS rules
* @returns Signed guest token string
*/
export async function getGuestToken(
page: Page,
dashboardId: number | string,
options?: {
username?: string;
password?: string;
rls?: Array<{ dataset: number; clause: string }>;
},
): Promise<string> {
const username = options?.username ?? 'admin';
const password = options?.password ?? 'general';
const rls = options?.rls ?? [];
// Step 1: Login to get access token
const loginResponse = await apiPost(
page,
ENDPOINTS.SECURITY_LOGIN,
{
username,
password,
provider: 'db',
refresh: true,
},
{ allowMissingCsrf: true },
);
const loginBody = await loginResponse.json();
const accessToken = loginBody.access_token;
// Step 2: Fetch guest token using the access token.
// Uses raw page.request.post() (not apiPost) because the guest token endpoint
// requires a JWT Bearer token rather than session+CSRF auth.
const guestResponse = await page.request.post(ENDPOINTS.GUEST_TOKEN, {
data: {
user: {
username: 'embedded_test_user',
first_name: 'Embedded',
last_name: 'TestUser',
},
resources: [{ type: 'dashboard', id: String(dashboardId) }],
rls,
},
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${accessToken}`,
},
});
const guestBody = await guestResponse.json();
return guestBody.token;
}

View File

@@ -1,140 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { Page, FrameLocator } from '@playwright/test';
import { EMBEDDED } from '../utils/constants';
/**
* Page object for the embedded dashboard test app.
*
* The test app runs on a separate origin (localhost:9000) and uses the
* @superset-ui/embedded-sdk to render a Superset dashboard in an iframe.
* Playwright's page.exposeFunction() bridges the guest token from Node.js
* into the browser page.
*/
export class EmbeddedPage {
private readonly page: Page;
private static readonly SELECTORS = {
CONTAINER: '[data-test="embedded-container"]',
IFRAME: 'iframe[title="Embedded Dashboard"]',
STATUS: '#status',
ERROR: '#error',
} as const;
constructor(page: Page) {
this.page = page;
}
/**
* Set up the guest token bridge before navigating.
* Must be called BEFORE goto() since embedDashboard() calls fetchGuestToken
* immediately on page load.
*/
async exposeTokenFetcher(tokenFn: () => Promise<string>): Promise<void> {
await this.page.exposeFunction('__fetchGuestToken', tokenFn);
}
/**
* Navigate to the embedded test app with the given parameters.
*/
async goto(params: {
uuid: string;
supersetDomain: string;
hideTitle?: boolean;
hideTab?: boolean;
hideChartControls?: boolean;
debug?: boolean;
}): Promise<void> {
const searchParams = new URLSearchParams({
uuid: params.uuid,
supersetDomain: params.supersetDomain,
});
if (params.hideTitle) searchParams.set('hideTitle', 'true');
if (params.hideTab) searchParams.set('hideTab', 'true');
if (params.hideChartControls) searchParams.set('hideChartControls', 'true');
if (params.debug) searchParams.set('debug', 'true');
await this.page.goto(`${EMBEDDED.APP_URL}/?${searchParams.toString()}`);
}
/**
* FrameLocator for the embedded dashboard iframe.
*/
get iframe(): FrameLocator {
return this.page.frameLocator(EmbeddedPage.SELECTORS.IFRAME);
}
/**
* Wait for the iframe to appear in the DOM.
*/
async waitForIframe(options?: { timeout?: number }): Promise<void> {
await this.page.locator(EmbeddedPage.SELECTORS.IFRAME).waitFor({
state: 'attached',
timeout: options?.timeout ?? EMBEDDED.IFRAME_LOAD,
});
}
/**
* Wait for dashboard content to render inside the iframe.
* Looks for the grid-container which indicates charts are loading/loaded.
*/
async waitForDashboardContent(options?: { timeout?: number }): Promise<void> {
const frame = this.iframe;
await frame
.locator('.grid-container, [data-test="grid-container"]')
.first()
.waitFor({
state: 'visible',
timeout: options?.timeout ?? EMBEDDED.DASHBOARD_RENDER,
});
}
/**
* Get the status text from the test app.
*/
async getStatus(): Promise<string> {
return (
(await this.page.locator(EmbeddedPage.SELECTORS.STATUS).textContent()) ??
''
);
}
/**
* Get the error text, if any.
*/
async getError(): Promise<string> {
const errorEl = this.page.locator(EmbeddedPage.SELECTORS.ERROR);
const display = await errorEl.evaluate(el => getComputedStyle(el).display);
if (display === 'none') return '';
return (await errorEl.textContent()) ?? '';
}
/**
* Check if the dashboard title is visible inside the iframe.
*/
async isTitleVisible(): Promise<boolean> {
const frame = this.iframe;
return frame
.locator(
'[data-test="dashboard-header-container"] [data-test="editable-title-input"]',
)
.isVisible();
}
}

View File

@@ -1,288 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { test, expect, Browser, BrowserContext, Page } from '@playwright/test';
import { createServer, IncomingMessage, ServerResponse, Server } from 'http';
import { readFileSync, existsSync } from 'fs';
import { join } from 'path';
import { apiEnableEmbedding, getGuestToken } from '../../helpers/api/embedded';
import { getDashboardBySlug } from '../../helpers/api/dashboard';
import { EmbeddedPage } from '../../pages/EmbeddedPage';
import { EMBEDDED } from '../../utils/constants';
/**
* Superset domain (Flask server) — set by CI or defaults to local dev
*/
const SUPERSET_DOMAIN = (() => {
const url = process.env.PLAYWRIGHT_BASE_URL || 'http://localhost:8088';
return url.replace(/\/+$/, '');
})();
const SUPERSET_BASE_URL = SUPERSET_DOMAIN.endsWith('/')
? SUPERSET_DOMAIN
: `${SUPERSET_DOMAIN}/`;
/**
* Path to the SDK bundle built from superset-embedded-sdk/
*/
const SDK_BUNDLE_PATH = join(
__dirname,
'../../../../superset-embedded-sdk/bundle/index.js',
);
/**
* Path to the embedded test app static files
*/
const EMBED_APP_DIR = join(__dirname, '../../embedded-app');
/**
* Create a minimal static file server for the embedded test app.
* Serves only a fixed allowlist of routes — the test app references just
* its index.html and the SDK bundle, so anything else is 404.
*/
const INDEX_HTML_PATH = join(EMBED_APP_DIR, 'index.html');
function createEmbedAppServer(): Server {
return createServer((req: IncomingMessage, res: ServerResponse) => {
const urlPath = req.url?.split('?')[0] || '/';
if (urlPath === '/sdk/index.js') {
if (!existsSync(SDK_BUNDLE_PATH)) {
res.writeHead(404);
res.end(
'SDK bundle not found. Run: cd superset-embedded-sdk && npm ci && npm run build',
);
return;
}
res.writeHead(200, { 'Content-Type': 'text/javascript' });
res.end(readFileSync(SDK_BUNDLE_PATH));
return;
}
if (urlPath === '/' || urlPath === '/index.html') {
res.writeHead(200, { 'Content-Type': 'text/html' });
res.end(readFileSync(INDEX_HTML_PATH));
return;
}
res.writeHead(404);
res.end('Not found');
});
}
/**
* Create a browser context authenticated as admin for API-only work
* (enabling embedding, restoring config). Caller is responsible for closing.
*/
function createAdminContext(browser: Browser): Promise<BrowserContext> {
return browser.newContext({
storageState: 'playwright/.auth/user.json',
baseURL: SUPERSET_BASE_URL,
});
}
// ─── Test Suite ────────────────────────────────────────────────────────────
// Describe wrapper is needed for shared server state and serial execution:
// all tests share a static file server on a fixed port and must not run in parallel.
test.describe('Embedded Dashboard E2E', () => {
test.describe.configure({ mode: 'serial' });
let server: Server;
let embedUuid: string;
let dashboardId: number;
/**
* Set up a page to render the default embedded dashboard.
* Tests that need a different UUID or UI config should not use this helper.
*/
async function setupEmbeddedPage(page: Page): Promise<EmbeddedPage> {
const embeddedPage = new EmbeddedPage(page);
await embeddedPage.exposeTokenFetcher(async () =>
getGuestToken(page, dashboardId),
);
await embeddedPage.goto({
uuid: embedUuid,
supersetDomain: SUPERSET_DOMAIN,
});
await embeddedPage.waitForIframe();
await embeddedPage.waitForDashboardContent();
return embeddedPage;
}
test.beforeAll(async ({ browser }) => {
// Skip all tests if the SDK bundle hasn't been built
test.skip(
!existsSync(SDK_BUNDLE_PATH),
'Embedded SDK bundle not found. Build it with: cd superset-embedded-sdk && npm ci && npm run build',
);
// Start the embedded test app server
server = createEmbedAppServer();
await new Promise<void>((resolve, reject) => {
server.on('error', reject);
server.listen(EMBEDDED.APP_PORT, () => resolve());
});
// Use a fresh context with auth to set up test data via API
const context = await createAdminContext(browser);
const setupPage = await context.newPage();
try {
// Find a well-known example dashboard
const dashboard = await getDashboardBySlug(setupPage, 'world_health');
if (!dashboard) {
throw new Error(
'Dashboard "world_health" not found. Ensure load_examples ran in CI setup.',
);
}
dashboardId = dashboard.id;
// Enable embedding on the dashboard (empty allowed_domains = allow all)
const embedded = await apiEnableEmbedding(setupPage, dashboardId);
embedUuid = embedded.uuid;
} finally {
await context.close();
}
});
test.afterAll(async () => {
if (server) {
await new Promise<void>(resolve => server.close(() => resolve()));
}
});
test('dashboard renders in embedded iframe', async ({ page }) => {
const embeddedPage = await setupEmbeddedPage(page);
// Verify the iframe src points to Superset's /embedded/ endpoint
const iframeSrc = await page
.locator('iframe[title="Embedded Dashboard"]')
.getAttribute('src');
expect(iframeSrc).toContain(`/embedded/${embedUuid}`);
// Verify no errors in the test app
const error = await embeddedPage.getError();
expect(error).toBe('');
// Baseline: title should be visible when hideTitle is not set
const titleVisible = await embeddedPage.isTitleVisible();
expect(titleVisible).toBe(true);
});
test('UI config hideTitle hides dashboard title', async ({ page }) => {
const embeddedPage = new EmbeddedPage(page);
await embeddedPage.exposeTokenFetcher(async () =>
getGuestToken(page, dashboardId),
);
await embeddedPage.goto({
uuid: embedUuid,
supersetDomain: SUPERSET_DOMAIN,
hideTitle: true,
});
await embeddedPage.waitForIframe();
await embeddedPage.waitForDashboardContent();
// The iframe URL should include uiConfig parameter
const iframeSrc = await page
.locator('iframe[title="Embedded Dashboard"]')
.getAttribute('src');
expect(iframeSrc).toContain('uiConfig=');
// Verify the title is actually hidden inside the iframe
const titleVisible = await embeddedPage.isTitleVisible();
expect(titleVisible).toBe(false);
});
test('charts render inside embedded iframe', async ({ page }) => {
const embeddedPage = await setupEmbeddedPage(page);
// Verify chart containers are present and visible in the iframe
const charts = embeddedPage.iframe.locator(
'.chart-container, [data-test="chart-container"]',
);
await expect(charts.first()).toBeVisible({
timeout: EMBEDDED.DASHBOARD_RENDER,
});
});
test('allowed_domains blocks unauthorized referrer', async ({
page,
browser,
}) => {
const context = await createAdminContext(browser);
const setupPage = await context.newPage();
try {
// Restrict to a domain that is NOT localhost:9000
const restrictedEmbed = await apiEnableEmbedding(setupPage, dashboardId, [
'https://allowed.example.com',
]);
const embeddedPage = new EmbeddedPage(page);
await embeddedPage.exposeTokenFetcher(async () =>
getGuestToken(page, dashboardId),
);
await embeddedPage.goto({
uuid: restrictedEmbed.uuid,
supersetDomain: SUPERSET_DOMAIN,
});
// The iframe should load but get a 403 from Superset's referrer check
await embeddedPage.waitForIframe();
// The dashboard content should NOT render (403 blocks the embedded page)
const content = embeddedPage.iframe.locator(
'.grid-container, [data-test="grid-container"]',
);
await expect(content).not.toBeVisible({ timeout: 5000 });
} finally {
// Restore the open embedding config for other tests
await apiEnableEmbedding(setupPage, dashboardId, []);
await context.close();
}
});
test('guest token enables dashboard data access', async ({ page }) => {
const embeddedPage = new EmbeddedPage(page);
let tokenCallCount = 0;
await embeddedPage.exposeTokenFetcher(async () => {
tokenCallCount += 1;
return getGuestToken(page, dashboardId);
});
await embeddedPage.goto({
uuid: embedUuid,
supersetDomain: SUPERSET_DOMAIN,
});
await embeddedPage.waitForIframe();
await embeddedPage.waitForDashboardContent();
// The SDK should have called fetchGuestToken at least once
expect(tokenCallCount).toBeGreaterThanOrEqual(1);
// Verify charts are actually rendering data (not just loading spinners)
const charts = embeddedPage.iframe.locator(
'.chart-container, [data-test="chart-container"]',
);
const chartCount = await charts.count();
expect(chartCount).toBeGreaterThan(0);
});
});

View File

@@ -75,18 +75,3 @@ export const TIMEOUT = {
*/
SLOW_TEST: 60000, // 60s for tests that chain multiple slow operations
} as const;
/**
* Embedded dashboard test app configuration.
* The test app is served by a Node.js http server started in the test fixture.
*/
export const EMBEDDED = {
/** Port for the embedded test app static server */
APP_PORT: 9000,
/** Full URL for the embedded test app */
APP_URL: 'http://localhost:9000',
/** Timeout for iframe to appear in the DOM */
IFRAME_LOAD: 15000, // 15s
/** Timeout for dashboard content to render inside the iframe */
DASHBOARD_RENDER: 30000, // 30s
} as const;

View File

@@ -17,20 +17,31 @@
* under the License.
*/
import { FC, memo, useMemo } from 'react';
import { FC, memo, useCallback, useMemo } from 'react';
import { t } from '@apache-superset/core/translation';
import { DataMaskStateWithId } from '@superset-ui/core';
import { styled } from '@apache-superset/core/theme';
import { Loading } from '@superset-ui/core/components';
import { RootState } from 'src/dashboard/types';
import { Icons } from '@superset-ui/core/components/Icons';
import { FilterBarOrientation, RootState } from 'src/dashboard/types';
import { useChartLayoutItems } from 'src/dashboard/util/useChartLayoutItems';
import { useChartIds } from 'src/dashboard/util/charts/useChartIds';
import { useSelector } from 'react-redux';
import {
getRisonFilterParam,
parseRisonFilters,
updateUrlWithUnmatchedFilters,
} from 'src/dashboard/util/risonFilters';
import FilterControls from './FilterControls/FilterControls';
import { useChartsVerboseMaps, getFilterBarTestId } from './utils';
import { HorizontalBarProps } from './types';
import FilterBarSettings from './FilterBarSettings';
import crossFiltersSelector from './CrossFilters/selectors';
import {
getUrlFilterIndicators,
UrlFilterIndicator,
} from './UrlFilters/selectors';
import UrlFilterTag from './UrlFilters/UrlFilterTag';
const HorizontalBar = styled.div`
${({ theme }) => `
@@ -65,6 +76,28 @@ const FilterBarEmptyStateContainer = styled.div`
`}
`;
const UrlFiltersContainer = styled.div`
${({ theme }) => `
display: flex;
flex-direction: row;
align-items: center;
gap: ${theme.sizeUnit * 2}px;
padding: 0 ${theme.sizeUnit * 2}px;
margin-right: ${theme.sizeUnit * 2}px;
border-right: 1px solid ${theme.colorBorder};
`}
`;
const UrlFilterTitle = styled.div`
${({ theme }) => `
display: flex;
align-items: center;
gap: ${theme.sizeUnit}px;
font-weight: ${theme.fontWeightStrong};
font-size: ${theme.fontSizeSM}px;
`}
`;
const HorizontalFilterBar: FC<HorizontalBarProps> = ({
actions,
dataMaskSelected,
@@ -94,9 +127,47 @@ const HorizontalFilterBar: FC<HorizontalBarProps> = ({
[chartIds, chartLayoutItems, dataMask, verboseMaps],
);
const activeUrlFilters = useMemo(() => getUrlFilterIndicators(), []);
const handleRemoveUrlFilter = useCallback(
(filterToRemove: UrlFilterIndicator) => {
const risonParam = getRisonFilterParam();
if (!risonParam) return;
const currentFilters = parseRisonFilters(risonParam);
const remaining = currentFilters.filter(
f => f.subject !== filterToRemove.filter.subject,
);
updateUrlWithUnmatchedFilters(remaining);
},
[],
);
const urlFiltersComponent = useMemo(() => {
if (activeUrlFilters.length === 0) return null;
return (
<UrlFiltersContainer>
<UrlFilterTitle>
<Icons.LinkOutlined iconSize="s" />
{t('URL Filters')}
</UrlFilterTitle>
{activeUrlFilters.map(filter => (
<UrlFilterTag
key={filter.subject}
filter={filter}
orientation={FilterBarOrientation.Horizontal}
onRemove={handleRemoveUrlFilter}
/>
))}
</UrlFiltersContainer>
);
}, [activeUrlFilters, handleRemoveUrlFilter]);
const hasFilters =
filterValues.length > 0 ||
selectedCrossFilters.length > 0 ||
activeUrlFilters.length > 0 ||
chartCustomizationValues.length > 0;
return (
@@ -113,16 +184,19 @@ const HorizontalFilterBar: FC<HorizontalBarProps> = ({
</FilterBarEmptyStateContainer>
)}
{hasFilters && (
<FilterControls
dataMaskSelected={dataMaskSelected}
onFilterSelectionChange={onSelectionChange}
onPendingCustomizationDataMaskChange={
onPendingCustomizationDataMaskChange
}
chartCustomizationValues={chartCustomizationValues}
clearAllTriggers={clearAllTriggers}
onClearAllComplete={onClearAllComplete}
/>
<>
{urlFiltersComponent}
<FilterControls
dataMaskSelected={dataMaskSelected}
onFilterSelectionChange={onSelectionChange}
onPendingCustomizationDataMaskChange={
onPendingCustomizationDataMaskChange
}
chartCustomizationValues={chartCustomizationValues}
clearAllTriggers={clearAllTriggers}
onClearAllComplete={onClearAllComplete}
/>
</>
)}
{actions}
</>

View File

@@ -0,0 +1,85 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useCSSTextTruncation } from '@superset-ui/core';
import { styled, css, useTheme } from '@apache-superset/core/theme';
import { Tag } from 'src/components/Tag';
import { Tooltip } from '@superset-ui/core/components';
import { FilterBarOrientation } from 'src/dashboard/types';
import { ellipsisCss } from '../CrossFilters/styles';
import { UrlFilterIndicator } from './selectors';
const StyledValue = styled.b`
${({ theme }) => `
max-width: ${theme.sizeUnit * 25}px;
`}
${ellipsisCss}
`;
const StyledColumn = styled('span')`
${({ theme }) => `
max-width: ${theme.sizeUnit * 25}px;
padding-right: ${theme.sizeUnit}px;
`}
${ellipsisCss}
`;
const StyledTag = styled(Tag)`
${({ theme }) => `
border: 1px solid ${theme.colorBorder};
border-radius: 2px;
.anticon-close {
vertical-align: middle;
}
`}
`;
const UrlFilterTag = (props: {
filter: UrlFilterIndicator;
orientation: FilterBarOrientation;
onRemove: (filter: UrlFilterIndicator) => void;
}) => {
const { filter, orientation, onRemove } = props;
const theme = useTheme();
const [columnRef, columnIsTruncated] =
useCSSTextTruncation<HTMLSpanElement>();
const [valueRef, valueIsTruncated] = useCSSTextTruncation<HTMLSpanElement>();
return (
<StyledTag
css={css`
${orientation === FilterBarOrientation.Vertical
? `margin-top: ${theme.sizeUnit * 2}px;`
: `margin-left: ${theme.sizeUnit * 2}px;`}
`}
closable
onClose={() => onRemove(filter)}
editable
>
<Tooltip title={columnIsTruncated ? filter.subject : null}>
<StyledColumn ref={columnRef}>{filter.subject}</StyledColumn>
</Tooltip>
<Tooltip title={valueIsTruncated ? filter.value : null}>
<StyledValue ref={valueRef}>{filter.value}</StyledValue>
</Tooltip>
</StyledTag>
);
};
export default UrlFilterTag;

View File

@@ -0,0 +1,34 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useMemo } from 'react';
import { getUrlFilterIndicators } from './selectors';
import UrlFiltersVerticalCollapse from './VerticalCollapse';
const UrlFiltersVertical = () => {
const urlFilters = useMemo(() => getUrlFilterIndicators(), []);
if (!urlFilters.length) {
return null;
}
return <UrlFiltersVerticalCollapse urlFilters={urlFilters} />;
};
export default UrlFiltersVertical;

View File

@@ -0,0 +1,173 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useMemo, useState, useCallback } from 'react';
import { t } from '@apache-superset/core/translation';
import { css, useTheme, SupersetTheme } from '@apache-superset/core/theme';
import { Icons } from '@superset-ui/core/components/Icons';
import { FilterBarOrientation } from 'src/dashboard/types';
import {
updateUrlWithUnmatchedFilters,
getRisonFilterParam,
parseRisonFilters,
} from 'src/dashboard/util/risonFilters';
import UrlFilterTag from './UrlFilterTag';
import { UrlFilterIndicator } from './selectors';
/**
 * Collapsible "URL Filters" section for the vertical filter bar.
 *
 * Renders one removable tag per unmatched URL (Rison) filter. Removing a
 * tag rewrites the `f` URL parameter via history.replaceState (so the
 * filter does not reappear on reload) and drops it from local state.
 */
const UrlFiltersVerticalCollapse = (props: {
  urlFilters: UrlFilterIndicator[];
}) => {
  const { urlFilters: initialFilters } = props;
  const theme = useTheme();
  // Section starts expanded.
  const [isOpen, setIsOpen] = useState(true);
  // Local copy so removals update the UI immediately.
  // NOTE(review): state is seeded once from props and never resyncs if the
  // parent passes a different urlFilters array — confirm that is intended.
  const [urlFilters, setUrlFilters] =
    useState<UrlFilterIndicator[]>(initialFilters);
  const toggleSection = useCallback(() => {
    setIsOpen(prev => !prev);
  }, []);
  // Remove one filter from both the URL's `f` param and local state.
  const handleRemoveFilter = useCallback(
    (filterToRemove: UrlFilterIndicator) => {
      const risonParam = getRisonFilterParam();
      if (!risonParam) return;
      // Re-parse the current URL rather than trusting local state, in case
      // the parameter has changed since mount.
      const currentFilters = parseRisonFilters(risonParam);
      const remaining = currentFilters.filter(
        f => f.subject !== filterToRemove.filter.subject,
      );
      updateUrlWithUnmatchedFilters(remaining);
      setUrlFilters(prev =>
        prev.filter(f => f.subject !== filterToRemove.subject),
      );
    },
    [],
  );
  // Emotion style callbacks below are memoized so the `css` prop receives
  // stable references across renders.
  const sectionContainerStyle = useCallback(
    (theme: SupersetTheme) => css`
      margin-bottom: ${theme.sizeUnit * 3}px;
      padding: 0 ${theme.sizeUnit * 4}px;
    `,
    [],
  );
  const sectionHeaderStyle = useCallback(
    (theme: SupersetTheme) => css`
      display: flex;
      align-items: center;
      justify-content: space-between;
      padding: ${theme.sizeUnit * 2}px 0;
      cursor: pointer;
      user-select: none;
      &:hover {
        background: ${theme.colorBgTextHover};
        margin: 0 -${theme.sizeUnit * 2}px;
        padding: ${theme.sizeUnit * 2}px;
        border-radius: ${theme.borderRadius}px;
      }
    `,
    [],
  );
  const sectionTitleStyle = useCallback(
    (theme: SupersetTheme) => css`
      margin: 0;
      font-size: ${theme.fontSize}px;
      font-weight: ${theme.fontWeightStrong};
      color: ${theme.colorText};
      line-height: 1.3;
      display: flex;
      align-items: center;
      gap: ${theme.sizeUnit}px;
    `,
    [],
  );
  const sectionContentStyle = useCallback(
    (theme: SupersetTheme) => css`
      padding: ${theme.sizeUnit * 2}px 0;
    `,
    [],
  );
  const dividerStyle = useCallback(
    (theme: SupersetTheme) => css`
      height: 1px;
      background: ${theme.colorSplit};
      margin: ${theme.sizeUnit * 2}px 0;
    `,
    [],
  );
  // Chevron points up when open, flips when collapsed.
  const iconStyle = useCallback(
    (open: boolean, theme: SupersetTheme) => css`
      transform: ${open ? 'rotate(0deg)' : 'rotate(180deg)'};
      transition: transform 0.2s ease;
      color: ${theme.colorTextSecondary};
    `,
    [],
  );
  const filterIndicators = useMemo(
    () =>
      urlFilters.map(filter => (
        <UrlFilterTag
          key={filter.subject}
          filter={filter}
          orientation={FilterBarOrientation.Vertical}
          onRemove={handleRemoveFilter}
        />
      )),
    [urlFilters, handleRemoveFilter],
  );
  // Hide the whole section once every filter has been removed.
  if (!urlFilters.length) {
    return null;
  }
  return (
    <div css={sectionContainerStyle}>
      {/* Header toggles the section; keyboard-accessible via Enter/Space. */}
      <div
        css={sectionHeaderStyle}
        onClick={toggleSection}
        onKeyDown={e => {
          if (e.key === 'Enter' || e.key === ' ') {
            e.preventDefault();
            toggleSection();
          }
        }}
        role="button"
        tabIndex={0}
      >
        <h4 css={sectionTitleStyle}>
          <Icons.LinkOutlined iconSize="s" />
          {t('URL Filters')}
        </h4>
        <Icons.UpOutlined iconSize="m" css={iconStyle(isOpen, theme)} />
      </div>
      {isOpen && <div css={sectionContentStyle}>{filterIndicators}</div>}
      {isOpen && <div css={dividerStyle} data-test="url-filters-divider" />}
    </div>
  );
};
export default UrlFiltersVerticalCollapse;

View File

@@ -0,0 +1,60 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import {
getRisonFilterParam,
parseRisonFilters,
RisonFilter,
} from 'src/dashboard/util/risonFilters';
/**
 * Display-ready representation of one URL (Rison) filter, used to render
 * removable chips in the filter bar's "URL Filters" section.
 */
export interface UrlFilterIndicator {
  // Column name the filter applies to.
  subject: string;
  // SQL-style operator, e.g. '==', 'IN', 'BETWEEN'.
  operator: string;
  // Human-readable value string (arrays joined, ranges flattened).
  value: string;
  // The original parsed filter, kept for URL round-tripping on removal.
  filter: RisonFilter;
}
/**
 * Format a parsed Rison filter's comparator for display in a chip.
 * BETWEEN ranges render as "low – high", other arrays as a comma-separated
 * list, and scalars via String().
 */
function formatFilterValue(filter: RisonFilter): string {
  const { comparator, operator } = filter;
  if (operator === 'BETWEEN' && Array.isArray(comparator)) {
    // Use an explicit dash so a range like [35, 200] reads as "35 – 200"
    // rather than the ambiguous space-joined "35 200".
    return `${comparator[0]} – ${comparator[1]}`;
  }
  if (Array.isArray(comparator)) {
    return comparator.join(', ');
  }
  return String(comparator);
}
/**
 * Read the `f` URL parameter and convert each parsed Rison filter into a
 * display-ready indicator. Returns an empty list when the parameter is
 * absent or the payload cannot be parsed.
 */
export function getUrlFilterIndicators(): UrlFilterIndicator[] {
  const risonParam = getRisonFilterParam();
  if (!risonParam) {
    return [];
  }
  return parseRisonFilters(risonParam).map(filter => ({
    subject: filter.subject,
    operator: filter.operator,
    value: formatFilterValue(filter),
    filter,
  }));
}

View File

@@ -45,6 +45,7 @@ import Header from './Header';
import FilterControls from './FilterControls/FilterControls';
import CrossFiltersVertical from './CrossFilters/Vertical';
import crossFiltersSelector from './CrossFilters/selectors';
import UrlFiltersVertical from './UrlFilters/Vertical';
enum SectionType {
Filters = 'filters',
@@ -301,6 +302,7 @@ const VerticalFilterBar: FC<VerticalBarProps> = ({
) : (
<div css={tabPaneStyle} onScroll={onScroll}>
<>
<UrlFiltersVertical />
<CrossFiltersVertical hideHeader={hasOnlyOneSectionType} />
{filterControls}
</>

View File

@@ -107,9 +107,16 @@ const publishDataMask = debounce(
const previousParams = new URLSearchParams(search);
const newParams = new URLSearchParams();
let dataMaskKey: string | null;
let risonFilterValue: string | null = null;
previousParams.forEach((value, key) => {
if (!EXCLUDED_URL_PARAMS.includes(key)) {
newParams.append(key, value);
if (key === 'f') {
// Preserve the original Rison filter value to avoid encoding
risonFilterValue = value;
} else {
newParams.append(key, value);
}
}
});
@@ -148,9 +155,16 @@ const publishDataMask = debounce(
if (appRoot !== '/' && replacementPathname.startsWith(appRoot)) {
replacementPathname = replacementPathname.substring(appRoot.length);
}
// Manually reconstruct the search string to preserve Rison filter encoding
let searchString = newParams.toString();
if (risonFilterValue) {
const separator = searchString ? '&' : '';
searchString = `${searchString}${separator}f=${risonFilterValue}`;
}
history.replace({
pathname: replacementPathname,
search: newParams.toString(),
search: searchString,
});
}
},

View File

@@ -64,6 +64,15 @@ import SyncDashboardState, {
getDashboardContextLocalStorage,
} from '../components/SyncDashboardState';
import { AutoRefreshProvider } from '../contexts/AutoRefreshContext';
import { PartialFilters } from '@superset-ui/core';
import {
parseRisonFilters,
risonToAdhocFilters,
getRisonFilterParam,
prettifyRisonFilterUrl,
injectRisonFiltersIntelligently,
updateUrlWithUnmatchedFilters,
} from '../util/risonFilters';
export const DashboardPageIdContext = createContext('');
@@ -195,6 +204,61 @@ export const DashboardPage: FC<PageProps> = ({ idOrSlug }: PageProps) => {
dataMask = isOldRison;
}
// Parse Rison URL filters with intelligent native filter injection
const risonFilterParam = getRisonFilterParam();
if (risonFilterParam) {
const risonFilters = parseRisonFilters(risonFilterParam);
if (risonFilters.length > 0) {
// Convert native filter config array to keyed object for lookup
const filterConfigArray =
(dashboard?.metadata
?.native_filter_configuration as Array<Record<string, unknown> & { id: string }>) ||
[];
const nativeFilters: PartialFilters = {};
filterConfigArray.forEach(filter => {
nativeFilters[filter.id] = filter as PartialFilters[string];
});
const injectionResult = injectRisonFiltersIntelligently(
risonFilters,
nativeFilters,
dataMask,
);
dataMask = injectionResult.updatedDataMask;
// For unmatched filters, fall back to adhoc filter approach
if (injectionResult.unmatchedFilters.length > 0) {
const unmatchedAdhocFilters = risonToAdhocFilters(
injectionResult.unmatchedFilters,
);
const risonDataMask = {
__rison_filters__: {
filterState: { value: unmatchedAdhocFilters },
ownState: {},
},
};
dataMask = { ...dataMask, ...risonDataMask };
}
// Clean up URL: remove matched filters, keep only unmatched ones
const matchedCount =
risonFilters.length - injectionResult.unmatchedFilters.length;
if (matchedCount > 0) {
setTimeout(
() =>
updateUrlWithUnmatchedFilters(injectionResult.unmatchedFilters),
100,
);
}
if (injectionResult.unmatchedFilters.length > 0) {
setTimeout(() => prettifyRisonFilterUrl(), 150);
}
}
}
if (readyToRender) {
if (!isDashboardHydrated.current) {
isDashboardHydrated.current = true;

View File

@@ -0,0 +1,325 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { PartialFilters, DataMaskStateWithId } from '@superset-ui/core';
import {
injectRisonFiltersIntelligently,
RisonFilter,
parseRisonFilters,
risonFiltersToString,
risonToAdhocFilters,
} from './risonFilters';
// Fixtures: three native filters — a select filter, a range filter, and a
// select filter whose column name contains a space (exercises fuzzy match).
const mockNativeFilters: PartialFilters = {
  filter_1: {
    id: 'filter_1',
    targets: [
      {
        column: { name: 'country' },
        datasetId: 1,
      },
    ],
    filterType: 'filter_select',
  },
  filter_2: {
    id: 'filter_2',
    targets: [
      {
        column: { name: 'year' },
        datasetId: 1,
      },
    ],
    filterType: 'filter_range',
  },
  filter_3: {
    id: 'filter_3',
    targets: [
      {
        column: { name: 'Country Code' },
        datasetId: 1,
      },
    ],
    filterType: 'filter_select',
  },
};
// Minimal data mask with one pristine entry, mirroring an untouched dashboard.
const mockDataMask: DataMaskStateWithId = {
  filter_1: {
    id: 'filter_1',
    filterState: { value: undefined },
    ownState: {},
  },
};
// --- parseRisonFilters: decoding the `f` URL parameter ---
test('should parse simple Rison filters', () => {
  const risonString = '(country:USA,year:2024)';
  const result = parseRisonFilters(risonString);
  expect(result).toHaveLength(2);
  expect(result[0]).toEqual({
    subject: 'country',
    operator: '==',
    comparator: 'USA',
  });
  expect(result[1]).toEqual({
    subject: 'year',
    operator: '==',
    comparator: 2024,
  });
});
test('should parse IN operator with array syntax', () => {
  const result = parseRisonFilters('(country:!(USA,Canada))');
  expect(result).toHaveLength(1);
  expect(result[0]).toEqual({
    subject: 'country',
    operator: 'IN',
    comparator: ['USA', 'Canada'],
  });
});
test('should parse BETWEEN operator', () => {
  const result = parseRisonFilters('(msrp:(between:!(35,200)))');
  expect(result).toHaveLength(1);
  expect(result[0]).toEqual({
    subject: 'msrp',
    operator: 'BETWEEN',
    comparator: [35, 200],
  });
});
test('should parse NOT operator', () => {
  const result = parseRisonFilters('(NOT:(country:USA))');
  expect(result).toHaveLength(1);
  expect(result[0].operator).toBe('!=');
  expect(result[0].comparator).toBe('USA');
});
test('should parse comparison operators', () => {
  expect(parseRisonFilters('(sales:(gt:100000))')[0].operator).toBe('>');
  expect(parseRisonFilters('(age:(gte:18))')[0].operator).toBe('>=');
  expect(parseRisonFilters('(temp:(lt:32))')[0].operator).toBe('<');
  expect(parseRisonFilters('(price:(lte:1000))')[0].operator).toBe('<=');
});
test('should return empty array for invalid Rison', () => {
  expect(parseRisonFilters('invalid rison')).toEqual([]);
  expect(parseRisonFilters('(unclosed')).toEqual([]);
});
// --- injectRisonFiltersIntelligently: matching URL filters to natives ---
test('should match Rison filter to native filter by column name', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'country', operator: '==', comparator: 'USA' },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.updatedDataMask.filter_1.filterState?.value).toEqual(['USA']);
  expect(result.unmatchedFilters).toHaveLength(0);
});
test('should match column names with spaces (case-insensitive)', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'Country Code', operator: '==', comparator: 'USA' },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.updatedDataMask.filter_3.filterState?.value).toEqual(['USA']);
  expect(result.unmatchedFilters).toHaveLength(0);
});
test('should match column names case-insensitively', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'country code', operator: '==', comparator: 'USA' },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.updatedDataMask.filter_3.filterState?.value).toEqual(['USA']);
  expect(result.unmatchedFilters).toHaveLength(0);
});
test('should handle unmatched filters with fallback', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'region', operator: '==', comparator: 'North America' },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.unmatchedFilters).toHaveLength(1);
  expect(result.unmatchedFilters[0].subject).toBe('region');
});
test('should convert values correctly for different filter types', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'country', operator: '==', comparator: 'USA' },
    { subject: 'year', operator: 'BETWEEN', comparator: [2020, 2024] },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  // Select filter should be array
  expect(result.updatedDataMask.filter_1.filterState?.value).toEqual(['USA']);
  // Range filter should be min/max object
  expect(result.updatedDataMask.filter_2.filterState?.value).toEqual({
    min: 2020,
    max: 2024,
  });
  expect(result.unmatchedFilters).toHaveLength(0);
});
// --- extraFormData population (makes matched filters auto-apply) ---
test('should set extraFormData for auto-application on select filters', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'country', operator: '==', comparator: 'USA' },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.updatedDataMask.filter_1.extraFormData).toEqual({
    filters: [{ col: 'country', op: 'IN', val: ['USA'] }],
  });
});
test('should set extraFormData for auto-application on IN filters', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'country', operator: 'IN', comparator: ['USA', 'Canada'] },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.updatedDataMask.filter_1.filterState?.value).toEqual([
    'USA',
    'Canada',
  ]);
  expect(result.updatedDataMask.filter_1.extraFormData).toEqual({
    filters: [{ col: 'country', op: 'IN', val: ['USA', 'Canada'] }],
  });
});
test('should set extraFormData for auto-application on BETWEEN filters', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'year', operator: 'BETWEEN', comparator: [2020, 2024] },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.updatedDataMask.filter_2.filterState?.value).toEqual({
    min: 2020,
    max: 2024,
  });
  expect(result.updatedDataMask.filter_2.extraFormData).toEqual({
    filters: [
      { col: 'year', op: '>=', val: 2020 },
      { col: 'year', op: '<=', val: 2024 },
    ],
  });
});
test('should handle mixed matched and unmatched filters', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'country', operator: '==', comparator: 'USA' },
    { subject: 'category', operator: '==', comparator: 'Sales' },
  ];
  const result = injectRisonFiltersIntelligently(
    risonFilters,
    mockNativeFilters,
    mockDataMask,
  );
  expect(result.updatedDataMask.filter_1.filterState?.value).toEqual(['USA']);
  expect(result.unmatchedFilters).toHaveLength(1);
  expect(result.unmatchedFilters[0].subject).toBe('category');
});
// --- conversions: adhoc clause fallback and Rison round-tripping ---
test('should convert filters to adhoc format', () => {
  const risonFilters: RisonFilter[] = [
    { subject: 'country', operator: '==', comparator: 'USA' },
  ];
  const adhocFilters = risonToAdhocFilters(risonFilters);
  expect(adhocFilters).toHaveLength(1);
  expect(adhocFilters[0]).toMatchObject({
    expressionType: 'SIMPLE',
    clause: 'WHERE',
    subject: 'country',
    operator: '==',
    comparator: 'USA',
  });
});
test('should convert filters to Rison string', () => {
  const filters: RisonFilter[] = [
    { subject: 'country', operator: '==', comparator: 'USA' },
  ];
  const result = risonFiltersToString(filters);
  expect(result).toBe('(country:USA)');
});
test('should convert IN filters to Rison string', () => {
  const filters: RisonFilter[] = [
    { subject: 'country', operator: 'IN', comparator: ['USA', 'Canada'] },
  ];
  const result = risonFiltersToString(filters);
  expect(result).toBe('(country:!(USA,Canada))');
});
test('should return empty string for empty filters', () => {
  expect(risonFiltersToString([])).toBe('');
});

View File

@@ -0,0 +1,490 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import {
QueryObjectFilterClause,
PartialFilters,
DataMaskStateWithId,
} from '@superset-ui/core';
import rison from 'rison';
/** One filter parsed from the `f` URL parameter's Rison payload. */
export interface RisonFilter {
  // Column name the filter targets.
  subject: string;
  // SQL-style operator ('==', '!=', 'IN', 'BETWEEN', '>', '>=', ...).
  operator: string;
  // Right-hand side of the filter; an array for IN/BETWEEN.
  comparator: string | number | boolean | (string | number)[];
}
/** Result of matching Rison URL filters against native dashboard filters. */
export interface IntelligentRisonInjectionResult {
  // Data mask with matched filters applied (filterState + extraFormData).
  updatedDataMask: DataMaskStateWithId;
  // Filters with no matching native filter; rendered as removable chips.
  unmatchedFilters: RisonFilter[];
}
/**
 * Parse Rison filter syntax from a URL parameter value.
 *
 * Supports simple equality `(country:USA,year:2024)`, arrays for IN
 * `(country:!(USA,Canada))`, operator objects `(msrp:(between:!(35,200)))`,
 * plus top-level `OR:!(...)` and `NOT:(...)` wrappers.
 * Returns [] for anything that fails to decode.
 */
export function parseRisonFilters(risonString: string): RisonFilter[] {
  try {
    const parsed = rison.decode(risonString);
    const filters: RisonFilter[] = [];
    if (!parsed || typeof parsed !== 'object') {
      return filters;
    }
    const parsedObj = parsed as Record<string, unknown>;

    // Handle OR operator: OR:!(condition1,condition2)
    if (parsedObj.OR && Array.isArray(parsedObj.OR)) {
      (parsedObj.OR as Record<string, unknown>[]).forEach(condition => {
        // Guard against null entries: `typeof null === 'object'` would
        // otherwise crash Object.entries below and abort the whole parse.
        if (condition && typeof condition === 'object') {
          Object.entries(condition).forEach(([key, value]) => {
            filters.push(parseFilterCondition(key, value));
          });
        }
      });
      return filters;
    }

    // Handle NOT operator: NOT:(condition) — negate the parsed operator.
    // Arrays are excluded: Object.entries on an array would yield indices,
    // not column names.
    if (
      parsedObj.NOT &&
      typeof parsedObj.NOT === 'object' &&
      !Array.isArray(parsedObj.NOT)
    ) {
      Object.entries(parsedObj.NOT as Record<string, unknown>).forEach(
        ([key, value]) => {
          const filter = parseFilterCondition(key, value);
          if (filter.operator === '==') {
            filter.operator = '!=';
          } else if (filter.operator === 'IN') {
            filter.operator = 'NOT IN';
          }
          filters.push(filter);
        },
      );
      return filters;
    }

    // Handle regular filters
    Object.entries(parsedObj).forEach(([key, value]) => {
      if (key !== 'OR' && key !== 'NOT') {
        filters.push(parseFilterCondition(key, value));
      }
    });
    return filters;
  } catch (error) {
    console.warn('Failed to parse Rison filters:', error);
    return [];
  }
}
/**
 * Parse a single `column: value` pair from a decoded Rison object into a
 * RisonFilter.
 *
 * Operator objects (`{gt|gte|lt|lte|between|like: x}`) map to comparison
 * operators, arrays map to IN, and scalars map to equality.
 */
function parseFilterCondition(key: string, value: unknown): RisonFilter {
  // Handle comparison operators: (gt:100), (between:!(1,10))
  if (typeof value === 'object' && value !== null && !Array.isArray(value)) {
    const entries = Object.entries(value as Record<string, unknown>);
    // Guard: an empty object has no operator entry. Destructuring
    // `entries[0]` used to throw here, which the upstream try/catch turned
    // into dropping ALL filters; fall through to equality handling instead.
    if (entries.length > 0) {
      const [operator, operatorValue] = entries[0];
      switch (operator) {
        case 'gt':
          return {
            subject: key,
            operator: '>',
            comparator: operatorValue as string | number,
          };
        case 'gte':
          return {
            subject: key,
            operator: '>=',
            comparator: operatorValue as string | number,
          };
        case 'lt':
          return {
            subject: key,
            operator: '<',
            comparator: operatorValue as string | number,
          };
        case 'lte':
          return {
            subject: key,
            operator: '<=',
            comparator: operatorValue as string | number,
          };
        case 'between':
          return {
            subject: key,
            operator: 'BETWEEN',
            comparator: operatorValue as (string | number)[],
          };
        case 'like':
          return {
            subject: key,
            operator: 'LIKE',
            comparator: operatorValue as string,
          };
        default:
          // Unknown operator key: treat the whole object as an equality
          // comparator (preserves the original behavior).
          return {
            subject: key,
            operator: '==',
            comparator: value as string | number,
          };
      }
    }
  }
  // Handle IN operator: !(value1,value2)
  if (Array.isArray(value)) {
    return {
      subject: key,
      operator: 'IN',
      comparator: value as (string | number)[],
    };
  }
  // Handle simple equality
  return {
    subject: key,
    operator: '==',
    comparator: value as string | number | boolean,
  };
}
/**
 * Convert Rison filters to Superset adhoc (SIMPLE/WHERE) filter clauses,
 * used as a fallback for filters that match no native filter.
 */
export function risonToAdhocFilters(
  risonFilters: RisonFilter[],
): QueryObjectFilterClause[] {
  return risonFilters.map(({ subject, operator, comparator }) => {
    const clause = {
      expressionType: 'SIMPLE' as const,
      clause: 'WHERE' as const,
      subject,
      operator,
      comparator,
    };
    return clause as unknown as QueryObjectFilterClause;
  });
}
/**
 * Prettify the `f` query parameter by URL-decoding it in place so the
 * Rison filter reads as human-friendly text in the address bar.
 * Uses history.replaceState, so no navigation or reload occurs.
 */
export function prettifyRisonFilterUrl(): void {
  try {
    const currentUrl = window.location.href;
    // Fast path: nothing to do when there is no `f` parameter at all.
    if (!currentUrl.includes('&f=') && !currentUrl.includes('?f=')) {
      return;
    }
    const urlMatch = currentUrl.match(/([?&])f=([^&]*)/);
    if (!urlMatch) {
      return;
    }
    const separator = urlMatch[1];
    let risonValue = urlMatch[2];
    // Already pretty: no percent-escapes or '+' to unfold.
    if (!risonValue.includes('%') && !risonValue.includes('+')) {
      return;
    }
    // Repeatedly decode (bounded to 5 rounds) to unwrap values that were
    // percent-encoded more than once; stop once decoding is a no-op.
    // NOTE(review): a literal '%' inside a filter value could be decoded
    // one round too far — confirm values never contain a raw '%'.
    let previousValue = '';
    let decodeAttempts = 0;
    while (risonValue !== previousValue && decodeAttempts < 5) {
      previousValue = risonValue;
      try {
        if (risonValue.includes('%')) {
          risonValue = decodeURIComponent(risonValue);
        }
      } catch {
        // Malformed escape sequence: keep the last successfully decoded value.
        break;
      }
      decodeAttempts += 1;
    }
    // Unfold form-encoded spaces.
    risonValue = risonValue.replace(/\+/g, ' ');
    // Splice the decoded value back into the original URL string.
    const matchIndex = urlMatch.index ?? 0;
    const beforeRison = currentUrl.substring(0, matchIndex);
    const afterRison = currentUrl.substring(matchIndex + urlMatch[0].length);
    const prettifiedUrl = `${beforeRison}${separator}f=${risonValue}${afterRison}`;
    if (prettifiedUrl !== currentUrl) {
      window.history.replaceState(window.history.state, '', prettifiedUrl);
    }
  } catch (error) {
    console.warn('Failed to prettify Rison URL:', error);
  }
}
/**
 * Read the raw Rison filter payload from the `f` URL query parameter.
 * Returns null when the parameter is absent.
 */
export function getRisonFilterParam(): string | null {
  const { search } = window.location;
  return new URLSearchParams(search).get('f');
}
/**
 * Serialize filters back into Rison string form (the inverse of
 * parseRisonFilters for the supported operators).
 * Returns '' for an empty list or when encoding fails.
 */
export function risonFiltersToString(filters: RisonFilter[]): string {
  if (!filters.length) {
    return '';
  }
  // Maps SQL-style operators back to their Rison operator keys.
  const operatorMap: Record<string, string> = {
    '>': 'gt',
    '>=': 'gte',
    '<': 'lt',
    '<=': 'lte',
    BETWEEN: 'between',
    LIKE: 'like',
  };
  const risonObject: Record<
    string,
    string | number | boolean | (string | number)[] | Record<string, unknown>
  > = {};
  filters.forEach(({ subject, operator, comparator }) => {
    if (
      operator === '==' ||
      (operator === 'IN' && Array.isArray(comparator))
    ) {
      // Equality and IN serialize as bare values / arrays.
      risonObject[subject] = comparator;
    } else {
      // Everything else becomes an operator object, e.g. (gt:100).
      risonObject[subject] = {
        [operatorMap[operator] || operator]: comparator,
      };
    }
  });
  try {
    return rison.encode(risonObject);
  } catch (error) {
    console.warn('Failed to encode Rison filters:', error);
    return '';
  }
}
/**
 * Rewrite the `f` URL parameter to contain only the given unmatched
 * filters, dropping it entirely when none remain. Uses replaceState, so no
 * navigation occurs.
 */
export function updateUrlWithUnmatchedFilters(
  unmatchedFilters: RisonFilter[],
): void {
  try {
    const url = new URL(window.location.href);
    const encoded =
      unmatchedFilters.length > 0
        ? risonFiltersToString(unmatchedFilters)
        : '';
    if (encoded) {
      url.searchParams.set('f', encoded);
    } else {
      // No filters left (or encoding failed): remove the parameter.
      url.searchParams.delete('f');
    }
    window.history.replaceState(window.history.state, '', url.toString());
  } catch (error) {
    console.warn('Failed to update URL with unmatched filters:', error);
  }
}
/**
 * Find the id of the native filter whose target column matches the Rison
 * filter's subject. Comparison is trimmed and case-insensitive so column
 * names with spaces or different casing still match.
 * Returns null when no native filter matches.
 */
function findMatchingNativeFilter(
  risonFilter: RisonFilter,
  nativeFilters: PartialFilters,
): string | null {
  const wanted = risonFilter.subject.trim().toLowerCase();
  const targetMatches = (target: unknown): boolean =>
    typeof target === 'object' &&
    !!target &&
    'column' in target &&
    (target as { column?: { name?: string } }).column?.name
      ?.trim()
      .toLowerCase() === wanted;
  const match = Object.entries(nativeFilters).find(([, nativeFilter]) =>
    Boolean(nativeFilter?.targets?.some(targetMatches)),
  );
  return match ? match[0] : null;
}
/**
 * Translate one Rison filter into the extraFormData `filters` clauses that
 * make a native filter apply immediately.
 *
 * - equality and IN both collapse to a single IN clause (values wrapped in
 *   an array when scalar)
 * - a 2-element BETWEEN expands to a >=/<= pair
 * - any other operator passes through unchanged
 */
function buildExtraFormDataFilters(
  risonFilter: RisonFilter,
  columnName: string,
): { col: string; op: string; val: unknown }[] {
  const { operator, comparator } = risonFilter;
  const isList = Array.isArray(comparator);
  if (operator === 'IN' || operator === '==') {
    return [
      { col: columnName, op: 'IN', val: isList ? comparator : [comparator] },
    ];
  }
  if (operator === 'BETWEEN' && isList && comparator.length === 2) {
    return [
      { col: columnName, op: '>=', val: comparator[0] },
      { col: columnName, op: '<=', val: comparator[1] },
    ];
  }
  return [{ col: columnName, op: operator, val: comparator }];
}
/**
 * Convert a Rison comparator into the filterState.value shape the matched
 * native filter expects: a list for select filters (and unknown types), a
 * {min, max} object for range filters, and the raw value for time filters.
 */
function convertRisonToNativeValue(
  risonFilter: RisonFilter,
  nativeFilter: { filterType?: string },
): unknown {
  const { comparator, operator } = risonFilter;
  const asArray = Array.isArray(comparator) ? comparator : [comparator];
  switch (nativeFilter?.filterType) {
    case 'filter_range':
      // Range filters take {min, max} for a 2-element BETWEEN; anything
      // else passes through untouched.
      if (
        operator === 'BETWEEN' &&
        Array.isArray(comparator) &&
        comparator.length === 2
      ) {
        return { min: comparator[0], max: comparator[1] };
      }
      return comparator;
    case 'filter_time_range':
    case 'filter_timecolumn':
      // Time filters consume the comparator verbatim.
      return comparator;
    case 'filter_select':
    default:
      // Select-style filters (and unknown types) take a list of values.
      return asArray;
  }
}
/**
 * Build the DataMask entry for a Rison filter matched to a native filter.
 * Populates both filterState.value (so the control shows the selection)
 * and extraFormData (so the filter is applied immediately).
 */
function buildDataMaskForFilter(
  risonFilter: RisonFilter,
  nativeFilter: {
    id: string;
    filterType?: string;
    targets?: { column?: { name?: string } }[];
  },
  columnName: string,
) {
  return {
    id: nativeFilter.id,
    filterState: {
      value: convertRisonToNativeValue(risonFilter, nativeFilter),
    },
    extraFormData: {
      filters: buildExtraFormDataFilters(risonFilter, columnName),
    },
    ownState: {},
  };
}
/**
 * Apply Rison URL filters to the dashboard's native filters wherever a
 * column match exists. Returns the updated data mask along with the
 * filters that could not be matched (the caller falls back to adhoc
 * filters for those).
 */
export function injectRisonFiltersIntelligently(
  risonFilters: RisonFilter[],
  nativeFilters: PartialFilters,
  currentDataMask: DataMaskStateWithId,
): IntelligentRisonInjectionResult {
  const updatedDataMask: DataMaskStateWithId = { ...currentDataMask };
  const unmatchedFilters: RisonFilter[] = [];
  for (const risonFilter of risonFilters) {
    const filterId = findMatchingNativeFilter(risonFilter, nativeFilters);
    const matchedFilter = filterId ? nativeFilters[filterId] : undefined;
    if (!matchedFilter) {
      unmatchedFilters.push(risonFilter);
      continue;
    }
    // Prefer the native filter's own column name for extraFormData; fall
    // back to the URL subject when the target carries no name.
    const columnName =
      matchedFilter.targets?.[0]?.column?.name ?? risonFilter.subject;
    updatedDataMask[matchedFilter.id] = {
      ...updatedDataMask[matchedFilter.id],
      ...buildDataMaskForFilter(
        risonFilter,
        matchedFilter as {
          id: string;
          filterType?: string;
          targets?: { column?: { name?: string } }[];
        },
        columnName,
      ),
    };
  }
  return {
    updatedDataMask,
    unmatchedFilters,
  };
}

View File

@@ -590,7 +590,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Driver-specific params to be included in the `get_oauth2_token` request body
oauth2_additional_token_request_params: dict[str, Any] = {}
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
oauth2_exception: type[Exception] | tuple[type[Exception], ...] = (
OAuth2RedirectError
)
# Does the query id related to the connection?
# The default value is True, which means that the query id is determined when

View File

@@ -31,6 +31,7 @@ from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from requests import Session
from shillelagh.adapters.api.gsheets.lib import SCOPES
from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
@@ -40,7 +41,7 @@ from superset.databases.schemas import encrypted_field_properties, EncryptedStri
from superset.db_engine_specs.base import DatabaseCategory
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.exceptions import OAuth2TokenRefreshError, SupersetException
from superset.utils import json
from superset.utils.oauth2 import get_oauth2_access_token
@@ -151,6 +152,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
"https://accounts.google.com/o/oauth2/v2/auth"
)
oauth2_token_request_uri = "https://oauth2.googleapis.com/token" # noqa: S105
oauth2_exception = (UnauthenticatedError, OAuth2TokenRefreshError)
@classmethod
def get_oauth2_authorization_uri(

View File

@@ -62,6 +62,7 @@ Dataset Management:
- list_datasets: List datasets with advanced filters (1-based pagination)
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
Chart Management:
- list_charts: List charts with advanced filters (1-based pagination)
@@ -164,6 +165,17 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both
together for the union. All flags can be combined with 'filters' but not
with 'search'.
To query a dataset's semantic layer (metrics, dimensions):
1. list_datasets(request={{}}) -> find a dataset
2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics
3. query_dataset(request={{
"dataset_id": <id>,
"metrics": ["count", "avg_revenue"],
"columns": ["category"],
"time_range": "Last 7 days",
"row_limit": 100
}}) -> returns tabular data using saved metrics and dimensions
To explore data with SQL:
1. list_datasets(request={{}}) -> find a dataset and note its database_id
2. execute_sql(request={{"database_id": <id>, "sql": "SELECT ..."}})
@@ -520,6 +532,7 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
create_virtual_dataset,
get_dataset_info,
list_datasets,
query_dataset,
)
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
generate_explore_link,

View File

@@ -70,6 +70,8 @@ SORTABLE_CHART_COLUMNS = [
"created_on",
]
_DEFAULT_LIST_CHARTS_REQUEST = ListChartsRequest()
@tool(
tags=["core"],
@@ -81,7 +83,8 @@ SORTABLE_CHART_COLUMNS = [
),
)
async def list_charts(
request: ListChartsRequest, ctx: Context
request: ListChartsRequest | None = None,
ctx: Context = None,
) -> ChartList | ChartError:
"""List charts with filtering and search.
@@ -91,6 +94,7 @@ async def list_charts(
Sortable columns for order_column: id, slice_name, viz_type, description,
changed_on, created_on
"""
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing charts: page=%s, page_size=%s, search=%s"
% (

View File

@@ -65,6 +65,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
"created_on",
]
_DEFAULT_LIST_DASHBOARDS_REQUEST = ListDashboardsRequest()
@tool(
tags=["core"],
@@ -76,7 +78,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
),
)
async def list_dashboards(
request: ListDashboardsRequest, ctx: Context
request: ListDashboardsRequest | None = None,
ctx: Context = None,
) -> DashboardList:
"""List dashboards with filtering and search. Returns dashboard metadata
including title, slug, URL, and last modified time. Use select_columns to
@@ -85,6 +88,7 @@ async def list_dashboards(
Sortable columns for order_column: id, dashboard_title, slug, published,
changed_on, created_on
"""
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing dashboards: page=%s, page_size=%s, search=%s"
% (

View File

@@ -36,10 +36,13 @@ from pydantic import (
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.common.cache_schemas import (
CacheStatus,
CreatedByMeMixin,
MetadataCacheControl,
OwnedByMeMixin,
QueryCacheControl,
)
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.privacy import filter_user_directory_fields
@@ -393,6 +396,146 @@ class CreateVirtualDatasetResponse(BaseModel):
)
# Operators accepted by QueryDatasetFilter.op. These mirror the filter
# operators understood by Superset's query object; "TEMPORAL_RANGE" is the
# special operator used for time-range filters on temporal columns.
VALID_FILTER_OPS = Literal[
    "==",
    "!=",
    ">",
    "<",
    ">=",
    "<=",
    "LIKE",
    "NOT LIKE",
    "ILIKE",
    "NOT ILIKE",
    "IN",
    "NOT IN",
    "IS NULL",
    "IS NOT NULL",
    "IS TRUE",
    "IS FALSE",
    "TEMPORAL_RANGE",
]
class QueryDatasetFilter(BaseModel):
    """A single filter condition for dataset queries."""

    # Column the predicate applies to; query_dataset validates it against the
    # dataset's columns before execution.
    col: str = Field(..., description="Column name to filter on")
    op: VALID_FILTER_OPS = Field(
        ...,
        description=(
            'Filter operator. Use "==" for equals, "!=" for not equals, '
            '"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", '
            '"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.'
        ),
    )
    # Comparison value; stays None for unary operators such as IS NULL.
    val: Any = Field(
        default=None,
        description="Filter value (omit for IS NULL/IS NOT NULL)",
    )
class QueryDatasetRequest(QueryCacheControl):
    """Request schema for query_dataset tool.

    Inherits cache-control knobs (use_cache, force_refresh, cache_timeout)
    from QueryCacheControl.
    """

    # Accepts either a numeric primary key or a UUID string; resolution is
    # handled by the tool (see _resolve_dataset in query_dataset.py).
    dataset_id: int | str = Field(
        ...,
        description="Dataset identifier — numeric ID or UUID string.",
    )
    metrics: List[str] = Field(
        default_factory=list,
        description=(
            "Saved metric names to compute (e.g. ['count', 'avg_revenue']). "
            "Use get_dataset_info to discover available metrics."
        ),
    )
    columns: List[str] = Field(
        default_factory=list,
        description=(
            "Column/dimension names for GROUP BY or SELECT "
            "(e.g. ['category', 'region']). "
            "Use get_dataset_info to discover available columns."
        ),
    )
    filters: List[QueryDatasetFilter] = Field(
        default_factory=list,
        description=(
            'Filter conditions (e.g. [{"col": "status", "op": "==", "val": "active"}]).'
        ),
    )
    # time_range requires a temporal column: either time_column below or the
    # dataset's configured main datetime column.
    time_range: str | None = Field(
        default=None,
        description=(
            "Time range filter (e.g. 'Last 7 days', 'Last month', "
            "'2024-01-01 : 2024-12-31'). Requires a temporal column "
            "on the dataset."
        ),
    )
    time_column: str | None = Field(
        default=None,
        description=(
            "Temporal column to apply time_range to. "
            "Defaults to the dataset's main datetime column."
        ),
    )
    order_by: List[str] | None = Field(
        default=None,
        description="Column or metric names to sort results by.",
    )
    order_desc: bool = Field(
        default=True,
        description="Sort descending (True) or ascending (False).",
    )
    row_limit: int = Field(
        default=1000,
        ge=1,
        le=50000,
        description="Maximum number of rows to return (default 1000, max 50000).",
    )

    @model_validator(mode="after")
    def validate_metrics_or_columns(self) -> "QueryDatasetRequest":
        """At least one of metrics or columns must be provided.

        An empty query (no metrics, no columns) has nothing to SELECT, so it
        is rejected at schema-validation time rather than at query time.
        """
        if not self.metrics and not self.columns:
            raise ValueError(
                "At least one of 'metrics' or 'columns' must be provided. "
                "Use get_dataset_info to discover available metrics and columns."
            )
        return self
class QueryDatasetResponse(BaseModel):
    """Response schema for query_dataset tool."""

    # Serialize timedelta values as ISO-8601 strings in JSON output.
    model_config = ConfigDict(ser_json_timedelta="iso8601")

    dataset_id: int = Field(..., description="Dataset ID")
    dataset_name: str = Field(..., description="Dataset name")
    columns: List[DataColumn] = Field(
        default_factory=list, description="Column metadata for returned data"
    )
    data: List[Dict[str, Any]] = Field(
        default_factory=list, description="Query result rows"
    )
    # row_count is len(data); total_rows comes from the query engine's
    # rowcount and may be None when the engine does not report it.
    row_count: int = Field(0, description="Number of rows returned")
    total_rows: int | None = Field(
        None, description="Total row count from the query engine"
    )
    summary: str = Field("", description="Human-readable summary of the results")
    performance: PerformanceMetadata | None = Field(
        None, description="Query performance metadata"
    )
    cache_status: CacheStatus | None = Field(
        None, description="Cache hit/miss information"
    )
    # Includes synthesized filters (e.g. the TEMPORAL_RANGE filter derived
    # from time_range), not just the filters the caller passed in.
    applied_filters: List[QueryDatasetFilter] = Field(
        default_factory=list, description="Filters that were applied to the query"
    )
    warnings: List[str] = Field(
        default_factory=list, description="Any warnings encountered during execution"
    )
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
"""Parse a field that may be stored as a JSON string into a dict."""
value = getattr(obj, field_name, None)

View File

@@ -18,9 +18,11 @@
from .create_virtual_dataset import create_virtual_dataset
from .get_dataset_info import get_dataset_info
from .list_datasets import list_datasets
from .query_dataset import query_dataset
__all__ = [
"create_virtual_dataset",
"list_datasets",
"get_dataset_info",
"query_dataset",
]

View File

@@ -0,0 +1,489 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
MCP tool: query_dataset
Query a dataset using its semantic layer (saved metrics, calculated columns,
dimensions) without requiring a saved chart.
"""
import difflib
import logging
import time
from typing import Any
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload, subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.dataset.schemas import (
DatasetError,
QueryDatasetFilter,
QueryDatasetRequest,
QueryDatasetResponse,
)
from superset.mcp_service.privacy import (
DATA_MODEL_METADATA_ERROR_TYPE,
requires_data_model_metadata_access,
user_can_view_data_model_metadata,
)
from superset.mcp_service.utils import _is_uuid
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
logger = logging.getLogger(__name__)
def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | None:
    """Resolve a dataset by int ID or UUID string.

    Replicates the identifier resolution logic from ModelGetInfoCore._find_object().
    Resolution order: native int -> numeric string -> UUID string. Returns
    None when the identifier matches none of these forms.
    """
    from superset.daos.dataset import DatasetDAO

    # An empty eager-load list means "use the DAO's default loading".
    opts = eager_options or None
    if isinstance(identifier, int):
        return DatasetDAO.find_by_id(identifier, query_options=opts)
    # Try parsing as int — numeric strings resolve by primary key.
    # NOTE(review): the DAO call sits inside the try, so a ValueError/TypeError
    # raised by find_by_id itself would also fall through to the UUID branch —
    # presumably intentional best-effort resolution; confirm if tightening.
    try:
        id_val = int(identifier)
        return DatasetDAO.find_by_id(id_val, query_options=opts)
    except (ValueError, TypeError):
        pass
    # Try UUID as the last resort.
    if _is_uuid(str(identifier)):
        return DatasetDAO.find_by_id(identifier, id_column="uuid", query_options=opts)
    return None
def _validate_names(
requested: list[str],
valid: set[str],
kind: str,
) -> list[str]:
"""Return list of error messages for names not found in *valid*.
Includes close-match suggestions when available.
"""
errors: list[str] = []
for name in requested:
if name not in valid:
suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
msg = f"Unknown {kind}: '{name}'"
if suggestions:
msg += f". Did you mean: {', '.join(suggestions)}?"
errors.append(msg)
return errors
@requires_data_model_metadata_access
@tool(
    tags=["data"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Query dataset",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def query_dataset(  # noqa: C901
    request: QueryDatasetRequest, ctx: Context
) -> QueryDatasetResponse | DatasetError:
    """Query a dataset using its semantic layer (saved metrics, dimensions, filters).

    Returns tabular data without requiring a saved chart. Use this when you want
    to compute saved metrics, group by dimensions, or apply filters directly
    against a dataset's curated semantic layer.

    Workflow:
    1. list_datasets -> find a dataset
    2. get_dataset_info -> discover available columns and metrics
    3. query_dataset -> query using metric names and column names

    Example:
    ```json
    {
        "dataset_id": 123,
        "metrics": ["count", "avg_revenue"],
        "columns": ["product_category"],
        "time_range": "Last 7 days",
        "row_limit": 100
    }
    ```
    """
    await ctx.info(
        "Starting dataset query: dataset_id=%s, metrics=%s, columns=%s, "
        "row_limit=%s"
        % (
            request.dataset_id,
            request.metrics,
            request.columns,
            request.row_limit,
        )
    )
    try:
        # Imported lazily to keep MCP tool module import cheap.
        from superset.commands.chart.data.get_data_command import ChartDataCommand
        from superset.common.query_context_factory import QueryContextFactory
        from superset.connectors.sqla.models import SqlaTable

        # ------------------------------------------------------------------
        # Step 1: Check data-model metadata access BEFORE the dataset lookup.
        # Doing this first prevents leaking dataset existence — restricted
        # users always receive DataModelMetadataRestricted, never NotFound.
        # The decorator hides this tool from search; this check enforces
        # direct calls that bypass tool discovery.
        # ------------------------------------------------------------------
        if not user_can_view_data_model_metadata():
            await ctx.warning("Dataset metadata access blocked by privacy controls")
            return DatasetError.create(
                error=(
                    "You don't have permission to access dataset details for your role."
                ),
                error_type=DATA_MODEL_METADATA_ERROR_TYPE,
            )
        # ------------------------------------------------------------------
        # Step 2: Resolve dataset (numeric ID or UUID), eager-loading its
        # columns, metrics and parent database in one query.
        # ------------------------------------------------------------------
        await ctx.report_progress(1, 5, "Looking up dataset")
        eager_options = [
            subqueryload(SqlaTable.columns),
            subqueryload(SqlaTable.metrics),
            joinedload(SqlaTable.database),
        ]
        with event_logger.log_context(action="mcp.query_dataset.lookup"):
            dataset = _resolve_dataset(request.dataset_id, eager_options)
        if dataset is None:
            await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
            return DatasetError.create(
                error=f"No dataset found with identifier: {request.dataset_id}",
                error_type="NotFound",
            )
        dataset_name = getattr(dataset, "table_name", None) or f"Dataset {dataset.id}"
        await ctx.info(
            "Dataset found: id=%s, name=%s, columns=%s, metrics=%s"
            % (
                dataset.id,
                dataset_name,
                len(dataset.columns),
                len(dataset.metrics),
            )
        )
        # ------------------------------------------------------------------
        # Step 3: Validate requested columns and metrics
        # ------------------------------------------------------------------
        await ctx.report_progress(2, 5, "Validating columns and metrics")
        valid_columns = {c.column_name for c in dataset.columns}
        valid_metrics = {m.metric_name for m in dataset.metrics}
        validation_errors: list[str] = []
        validation_errors.extend(
            _validate_names(request.columns, valid_columns, "column")
        )
        validation_errors.extend(
            _validate_names(request.metrics, valid_metrics, "metric")
        )
        # Validate filter column names against dataset columns
        filter_cols = [f.col for f in request.filters]
        validation_errors.extend(
            _validate_names(filter_cols, valid_columns, "filter column")
        )
        # Validate order_by names against columns + metrics
        if request.order_by:
            valid_orderby = valid_columns | valid_metrics
            validation_errors.extend(
                _validate_names(request.order_by, valid_orderby, "order_by")
            )
        if validation_errors:
            # All validation failures are aggregated into a single message so
            # the caller can fix everything in one round trip.
            error_msg = "; ".join(validation_errors)
            await ctx.error("Validation failed: %s" % (error_msg,))
            return DatasetError.create(
                error=error_msg,
                error_type="ValidationError",
            )
        # ------------------------------------------------------------------
        # Step 4: Build filters and time range
        # ------------------------------------------------------------------
        warnings: list[str] = []
        query_filters: list[dict[str, Any]] = [
            {"col": f.col, "op": f.op, "val": f.val} for f in request.filters
        ]
        # Track all applied filters (including synthesized ones) for the response.
        effective_filters: list[QueryDatasetFilter] = list(request.filters)
        granularity: str | None = None
        if request.time_range:
            temporal_col = request.time_column or getattr(
                dataset, "main_dttm_col", None
            )
            if not temporal_col:
                await ctx.error("time_range provided but no temporal column available")
                return DatasetError.create(
                    error=(
                        "time_range was provided but no temporal column is available. "
                        "Either set time_column explicitly or ensure the dataset has "
                        "a main datetime column configured."
                    ),
                    error_type="ValidationError",
                )
            # Validate that the temporal column actually exists on the dataset
            if temporal_col not in valid_columns:
                await ctx.error("time_column '%s' not found on dataset" % temporal_col)
                return DatasetError.create(
                    error=(
                        f"time_column '{temporal_col}' does not exist on this dataset."
                    ),
                    error_type="ValidationError",
                )
            # Warn (but proceed) if the chosen temporal column isn't marked
            # as datetime — the query may still work for castable columns.
            dttm_cols = {c.column_name for c in dataset.columns if c.is_dttm}
            if temporal_col not in dttm_cols:
                warnings.append(
                    f"Column '{temporal_col}' is not marked as a datetime "
                    f"column on this dataset. Time filtering may not work "
                    f"as expected."
                )
            query_filters.append(
                {
                    "col": temporal_col,
                    "op": "TEMPORAL_RANGE",
                    "val": request.time_range,
                }
            )
            effective_filters.append(
                QueryDatasetFilter(
                    col=temporal_col,
                    op="TEMPORAL_RANGE",
                    val=request.time_range,
                )
            )
            granularity = temporal_col
            await ctx.debug(
                "Time filter: column=%s, range=%s" % (temporal_col, request.time_range)
            )
        # ------------------------------------------------------------------
        # Step 5: Build query dict
        # ------------------------------------------------------------------
        await ctx.report_progress(3, 5, "Building query")
        query_dict: dict[str, Any] = {
            "filters": query_filters,
            "columns": request.columns,
            "metrics": request.metrics,
            "row_limit": request.row_limit,
            "order_desc": request.order_desc,
        }
        if granularity:
            query_dict["granularity"] = granularity
        if request.order_by:
            # OrderBy = tuple[Metric | Column, bool] where bool is ascending
            query_dict["orderby"] = [
                (col, not request.order_desc) for col in request.order_by
            ]
        await ctx.debug("Query dict keys: %s" % (sorted(query_dict.keys()),))
        # ------------------------------------------------------------------
        # Step 6: Create QueryContext and execute
        # ------------------------------------------------------------------
        await ctx.report_progress(4, 5, "Executing query")
        start_time = time.time()
        with event_logger.log_context(action="mcp.query_dataset.execute"):
            factory = QueryContextFactory()
            # datasource_type is "table" because this tool queries SqlaTable
            # datasets (Superset's built-in semantic layer). External semantic
            # layers (dbt, Snowflake Cortex, etc.) use "semantic_view" and have
            # a different query path — see SemanticView + mapper.py.
            query_context = factory.create(
                datasource={"id": dataset.id, "type": "table"},
                queries=[query_dict],
                form_data={},
                force=not request.use_cache or request.force_refresh,
                custom_cache_timeout=request.cache_timeout,
            )
            command = ChartDataCommand(query_context)
            command.validate()
            result = command.run()
        query_duration_ms = int((time.time() - start_time) * 1000)
        if not result or "queries" not in result or len(result["queries"]) == 0:
            await ctx.warning("Query returned no results for dataset %s" % dataset.id)
            return DatasetError.create(
                error="Query returned no results.",
                error_type="EmptyQuery",
            )
        # ------------------------------------------------------------------
        # Step 7: Format response
        # ------------------------------------------------------------------
        await ctx.report_progress(5, 5, "Formatting results")
        # Only the first query's result is used — exactly one query was sent.
        query_result = result["queries"][0]
        data = query_result.get("data", [])
        raw_columns = query_result.get("colnames", [])
        if not data:
            return QueryDatasetResponse(
                dataset_id=dataset.id,
                dataset_name=dataset_name,
                columns=[],
                data=[],
                row_count=0,
                total_rows=0,
                summary=f"Query on '{dataset_name}' returned no data.",
                performance=PerformanceMetadata(
                    query_duration_ms=query_duration_ms,
                    cache_status="no_data",
                ),
                cache_status=get_cache_status_from_result(
                    query_result, force_refresh=request.force_refresh
                ),
                applied_filters=effective_filters,
                warnings=warnings,
            )
        # Build column metadata in a single pass per column.
        # Cap stats computation at stats_sample_size rows to avoid O(rows*cols)
        # overhead on large result sets (row_limit allows up to 50k).
        stats_sample_size = 5000
        stats_rows = data[:stats_sample_size]
        columns_meta: list[DataColumn] = []
        for col_name in raw_columns:
            sample_values = [
                row.get(col_name) for row in data[:3] if row.get(col_name) is not None
            ]
            # Infer a coarse data type from the first few non-null values;
            # bool is checked first because bool is a subclass of int.
            data_type = "string"
            if sample_values:
                if all(isinstance(v, bool) for v in sample_values):
                    data_type = "boolean"
                elif all(isinstance(v, (int, float)) for v in sample_values):
                    data_type = "numeric"
            # Compute null_count and unique non-null values in one pass
            null_count = 0
            unique_vals: set[str] = set()
            for row in stats_rows:
                val = row.get(col_name)
                if val is None:
                    null_count += 1
                else:
                    unique_vals.add(str(val))
            columns_meta.append(
                DataColumn(
                    name=col_name,
                    display_name=col_name.replace("_", " ").title(),
                    data_type=data_type,
                    sample_values=sample_values[:3],
                    null_count=null_count,
                    unique_count=len(unique_vals),
                )
            )
        cache_status = get_cache_status_from_result(
            query_result, force_refresh=request.force_refresh
        )
        cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
        summary = (
            f"Dataset '{dataset_name}': {len(data)} rows, "
            f"{len(raw_columns)} columns ({cache_label})."
        )
        await ctx.info(
            "Query complete: rows=%s, columns=%s, duration=%sms"
            % (len(data), len(raw_columns), query_duration_ms)
        )
        return QueryDatasetResponse(
            dataset_id=dataset.id,
            dataset_name=dataset_name,
            columns=columns_meta,
            data=data,
            row_count=len(data),
            total_rows=query_result.get("rowcount"),
            summary=summary,
            performance=PerformanceMetadata(
                query_duration_ms=query_duration_ms,
                cache_status=cache_label,
            ),
            cache_status=cache_status,
            applied_filters=effective_filters,
            warnings=warnings,
        )
    except OAuth2RedirectError as exc:
        # Surfaced as an actionable error containing the OAuth2 redirect URL.
        redirect_msg = build_oauth2_redirect_message(exc)
        await ctx.error("OAuth2 redirect required: %s" % (redirect_msg,))
        return DatasetError.create(
            error=redirect_msg,
            error_type="OAuth2Redirect",
        )
    except OAuth2Error as exc:
        await ctx.error("OAuth2 error: %s" % (str(exc),))
        return DatasetError.create(
            error=f"OAuth2 authentication error: {exc}",
            error_type="OAuth2Error",
        )
    except (CommandException, SupersetException) as exc:
        await ctx.error("Query failed: %s" % (str(exc),))
        return DatasetError.create(
            error=f"Query execution failed: {exc}",
            error_type="QueryError",
        )
    except SQLAlchemyError as exc:
        await ctx.error("Database error: %s" % (str(exc),))
        return DatasetError.create(
            error=f"Database error: {exc}",
            error_type="DatabaseError",
        )
    except Exception as exc:
        # Catch-all: log details server-side, return a generic message so
        # internals never leak to the MCP client.
        logger.exception(
            "Unexpected error while querying dataset: %s: %s",
            type(exc).__name__,
            str(exc),
        )
        await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
        return DatasetError.create(
            error="An unexpected error occurred while querying the dataset.",
            error_type="UnexpectedError",
        )

View File

@@ -0,0 +1,226 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Parser for Rison URL filters that converts simplified filter syntax
to Superset's adhoc_filters format.
"""
from __future__ import annotations
import logging
from typing import Any, Optional, Union
import prison
from flask import request
logger = logging.getLogger(__name__)
class RisonFilterParser:
    """
    Parse Rison filter syntax from URL parameter 'f' and convert to adhoc_filters.

    Supports:
    - Simple equality: f=(country:USA)
    - Lists (IN): f=(country:!(USA,Canada))
    - NOT operator: f=(NOT:(country:USA))
    - OR operator: f=(OR:!(condition1,condition2))
    - Comparison operators: f=(sales:(gt:100000))
    - BETWEEN: f=(date:(between:!(2024-01-01,2024-12-31)))
    - LIKE: f=(name:(like:'%smith%'))

    Security note: OR conditions are rendered as a raw SQL expression. Since
    both column names and values come straight from the URL, column names are
    restricted to plain identifiers and string values have embedded single
    quotes escaped to prevent SQL injection.
    """

    # Mapping from Rison shorthand operator keys to Superset adhoc operators.
    OPERATORS: dict[str, str] = {
        "gt": ">",
        "gte": ">=",
        "lt": "<",
        "lte": "<=",
        "between": "BETWEEN",
        "like": "LIKE",
        "ilike": "ILIKE",
        "ne": "!=",
        "eq": "==",
    }

    @staticmethod
    def _is_safe_identifier(name: Any) -> bool:
        """True when *name* looks like a plain column identifier.

        Only letters, digits and underscores are allowed, and the first
        character must not be a digit. Anything else (quotes, spaces,
        semicolons, ...) is rejected to keep it out of raw SQL.
        """
        if not isinstance(name, str) or not name:
            return False
        return not name[0].isdigit() and all(
            ch.isalnum() or ch == "_" for ch in name
        )

    @staticmethod
    def _quote_sql_value(value: Any) -> str:
        """Render a scalar as a SQL literal, escaping single quotes in strings."""
        if isinstance(value, str):
            escaped = value.replace("'", "''")
            return f"'{escaped}'"
        return str(value)

    def parse(self, filter_string: Optional[str] = None) -> list[dict[str, Any]]:
        """
        Parse Rison filter string and convert to adhoc_filters format.

        Args:
            filter_string: Rison-encoded filter string, or None to get from request

        Returns:
            List of adhoc_filter dictionaries (empty on missing or malformed input)
        """
        if filter_string is None:
            filter_string = request.args.get("f")
        if not filter_string:
            return []
        try:
            filters_obj = prison.loads(filter_string)
            return self._convert_to_adhoc_filters(filters_obj)
        except Exception:
            # Malformed URL input must never break page rendering; log and
            # behave as if no filters were supplied.
            logger.warning(
                "Failed to parse Rison filters: %s", filter_string, exc_info=True
            )
            return []

    def _convert_to_adhoc_filters(
        self, filters_obj: Union[dict[str, Any], list[Any], Any]
    ) -> list[dict[str, Any]]:
        """Convert a decoded Rison object to a list of adhoc filter dicts."""
        if not isinstance(filters_obj, dict):
            return []
        adhoc_filters: list[dict[str, Any]] = []
        for key, value in filters_obj.items():
            if key == "OR":
                adhoc_filters.extend(self._handle_or_operator(value))
            elif key == "NOT":
                adhoc_filters.extend(self._handle_not_operator(value))
            else:
                filter_dict = self._create_filter(key, value)
                if filter_dict:
                    adhoc_filters.append(filter_dict)
        return adhoc_filters

    def _create_filter(
        self, column: str, value: Any, negate: bool = False
    ) -> Optional[dict[str, Any]]:
        """Build a SIMPLE adhoc filter for one column/value pair.

        Lists become IN/NOT IN, operator dicts are translated via
        _parse_operator_dict, and scalars become ==/!=. Returns None when the
        operator dict cannot be interpreted.
        """
        filter_dict: dict[str, Any] = {
            "expressionType": "SIMPLE",
            "clause": "WHERE",
            "subject": column,
        }
        if isinstance(value, list):
            filter_dict["operator"] = "NOT IN" if negate else "IN"
            filter_dict["comparator"] = value
        elif isinstance(value, dict):
            operator_info = self._parse_operator_dict(value)
            if operator_info:
                operator, comparator = operator_info
                # Only equality/membership have a direct negated form.
                if negate and operator == "==":
                    operator = "!="
                elif negate and operator == "IN":
                    operator = "NOT IN"
                filter_dict["operator"] = operator
                filter_dict["comparator"] = comparator
            else:
                return None
        else:
            filter_dict["operator"] = "!=" if negate else "=="
            filter_dict["comparator"] = value
        return filter_dict

    def _parse_operator_dict(
        self, op_dict: dict[str, Any]
    ) -> Optional[tuple[str, Any]]:
        """Translate an operator dict like {"gt": 100} into (operator, value).

        Only the first recognized key is used. "in"/"nin" coerce scalar
        values into one-element lists. Returns None for unknown operators.
        """
        if not op_dict:
            return None
        for op_key, op_value in op_dict.items():
            if op_key in self.OPERATORS:
                return self.OPERATORS[op_key], op_value
            if op_key == "in":
                return "IN", op_value if isinstance(op_value, list) else [op_value]
            if op_key == "nin":
                return "NOT IN", op_value if isinstance(op_value, list) else [op_value]
        return None

    def _handle_or_operator(self, or_value: Any) -> list[dict[str, Any]]:
        """Render OR:!(...) as a single SQL-expression adhoc filter.

        Superset's SIMPLE filters are AND-ed together, so OR semantics
        require one combined raw SQL expression.
        """
        if not isinstance(or_value, list):
            return []
        sql_parts: list[str] = []
        for item in or_value:
            if isinstance(item, dict):
                for col, val in item.items():
                    # Nested OR/NOT inside an OR list is not supported.
                    if col not in ("OR", "NOT"):
                        sql_part = self._build_sql_condition(col, val)
                        if sql_part:
                            sql_parts.append(sql_part)
        if sql_parts:
            return [
                {
                    "expressionType": "SQL",
                    "clause": "WHERE",
                    "sqlExpression": f"({' OR '.join(sql_parts)})",
                }
            ]
        return []

    def _build_sql_condition(self, column: str, value: Any) -> Optional[str]:
        """Build one raw SQL condition for an OR expression, or None.

        Returns None (condition dropped) when the column is not a safe
        identifier, the value list is empty, the operator dict is unknown,
        or BETWEEN does not have exactly two bounds.
        """
        # Column names come from the URL — refuse anything that is not a
        # plain identifier rather than interpolating it into SQL.
        if not self._is_safe_identifier(column):
            return None
        if isinstance(value, list):
            if not value:
                # "IN ()" is invalid SQL; drop the condition.
                return None
            values_str = ", ".join(self._quote_sql_value(v) for v in value)
            return f"{column} IN ({values_str})"
        if isinstance(value, dict):
            operator_info = self._parse_operator_dict(value)
            if not operator_info:
                return None
            op, comp = operator_info
            if op == "BETWEEN":
                if isinstance(comp, list) and len(comp) == 2:
                    low = self._quote_sql_value(str(comp[0]))
                    high = self._quote_sql_value(str(comp[1]))
                    return f"{column} BETWEEN {low} AND {high}"
                return None
            if op == "LIKE":
                return f"{column} LIKE {self._quote_sql_value(str(comp))}"
            return f"{column} {op} {self._quote_sql_value(comp)}"
        return f"{column} = {self._quote_sql_value(value)}"

    def _handle_not_operator(self, not_value: Any) -> list[dict[str, Any]]:
        """Render NOT:(...) by negating each contained simple condition."""
        if isinstance(not_value, dict):
            filters: list[dict[str, Any]] = []
            for col, val in not_value.items():
                if col not in ("OR", "NOT"):
                    filter_dict = self._create_filter(col, val, negate=True)
                    if filter_dict:
                        filters.append(filter_dict)
            return filters
        return []
def merge_rison_filters(form_data: dict[str, Any]) -> None:
    """
    Merge Rison filters from 'f' parameter into form_data.

    Modifies form_data in place; a missing or unparsable 'f' parameter
    leaves form_data untouched.
    """
    rison_filters = RisonFilterParser().parse()
    if not rison_filters:
        return
    form_data["adhoc_filters"] = [
        *form_data.get("adhoc_filters", []),
        *rison_filters,
    ]
    logger.info("Added %d filters from Rison parameter", len(rison_filters))

View File

@@ -78,15 +78,6 @@ FEATURE_FLAGS = {
WEBDRIVER_BASEURL = "http://0.0.0.0:8081/"
# Enable CORS for embedded dashboard E2E tests (test app on port 9000)
ENABLE_CORS = True
CORS_OPTIONS: dict = {
"origins": [
"http://localhost:9000",
],
"supports_credentials": True,
}
def GET_FEATURE_FLAGS_FUNC(ff): # noqa: N802
ff_copy = copy(ff)
@@ -95,7 +86,6 @@ def GET_FEATURE_FLAGS_FUNC(ff): # noqa: N802
TESTING = True
TALISMAN_ENABLED = False
WTF_CSRF_ENABLED = False
FAB_ROLES = {"TestRole": [["Security", "menu_access"], ["List Users", "menu_access"]]}

View File

@@ -24,6 +24,7 @@ import pandas as pd
import pytest
from pytest_mock import MockerFixture
from requests.exceptions import HTTPError
from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine.url import make_url
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -789,6 +790,36 @@ def test_needs_oauth2_with_other_error(mocker: MockerFixture) -> None:
assert GSheetsEngineSpec.needs_oauth2(ex) is False
def test_needs_oauth2_with_shillelagh_unauthenticated_error(
    mocker: MockerFixture,
) -> None:
    """
    Test that needs_oauth2 returns True when UnauthenticatedError is raised.
    """
    from superset.db_engine_specs.gsheets import GSheetsEngineSpec

    # needs_oauth2 requires a logged-in user, so stub flask.g with a user.
    g = mocker.patch("superset.db_engine_specs.gsheets.g")
    g.user = mocker.MagicMock()
    ex = UnauthenticatedError("Token has been revoked")
    assert GSheetsEngineSpec.needs_oauth2(ex) is True
def test_needs_oauth2_with_unrelated_exception_type(
    mocker: MockerFixture,
) -> None:
    """
    Test that an unrelated exception type (with no matching message) returns
    False.
    """
    from superset.db_engine_specs.gsheets import GSheetsEngineSpec

    # Even with a logged-in user, a non-OAuth2 exception must not trigger
    # the OAuth2 flow.
    g = mocker.patch("superset.db_engine_specs.gsheets.g")
    g.user = mocker.MagicMock()
    assert GSheetsEngineSpec.needs_oauth2(ValueError("unrelated")) is False
def test_get_oauth2_fresh_token_success(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,

View File

@@ -320,66 +320,13 @@ class TestChartDataModelMetadataPrivacy:
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
class TestListChartsCreatedByMe:
"""Tests for the created_by_me flag on ListChartsRequest."""
def test_created_by_me_default_is_false(self):
request = ListChartsRequest()
assert request.created_by_me is False
def test_created_by_me_true_accepted(self):
request = ListChartsRequest(created_by_me=True)
assert request.created_by_me is True
def test_created_by_me_combined_with_filters(self):
request = ListChartsRequest(
created_by_me=True,
filters=[ChartFilter(col="slice_name", opr="sw", value="My")],
)
assert request.created_by_me is True
assert len(request.filters) == 1
def test_created_by_me_with_search_raises(self):
from pydantic import ValidationError
with pytest.raises(ValidationError, match="created_by_me"):
ListChartsRequest(created_by_me=True, search="My charts")
def test_chart_filter_rejects_created_by_fk(self):
"""created_by_fk is not a public filter column; use created_by_me instead."""
from pydantic import ValidationError
with pytest.raises(ValidationError):
ChartFilter(col="created_by_fk", opr="eq", value=1)
class TestListChartsOwnedByMe:
    """Tests for the owned_by_me flag on ListChartsRequest."""

    def test_owned_by_me_default_is_false(self):
        assert ListChartsRequest().owned_by_me is False

    def test_owned_by_me_true_accepted(self):
        assert ListChartsRequest(owned_by_me=True).owned_by_me is True

    def test_owned_by_me_combined_with_filters(self):
        name_filter = ChartFilter(col="slice_name", opr="sw", value="My")
        request = ListChartsRequest(owned_by_me=True, filters=[name_filter])
        assert request.owned_by_me is True
        assert len(request.filters) == 1

    def test_owned_by_me_with_search_raises(self):
        from pydantic import ValidationError

        with pytest.raises(ValidationError, match="owned_by_me"):
            ListChartsRequest(owned_by_me=True, search="My charts")

    def test_owned_by_me_and_created_by_me_allowed(self):
        """Both flags together are valid (OR logic — creator or owner)."""
        both = ListChartsRequest(owned_by_me=True, created_by_me=True)
        assert both.owned_by_me is True
        assert both.created_by_me is True
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_no_arguments(mock_list, mcp_server):
    """Regression test: list_charts must accept zero arguments without raising
    pydantic_core.ValidationError: Missing required argument: request."""
    # An empty DAO result is sufficient — the regression under test is in the
    # tool's argument handling, not in the listing logic itself.
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        result = await client.call_tool("list_charts", {})
        data = json.loads(result.content[0].text)
        assert "charts" in data

View File

@@ -30,7 +30,6 @@ from flask import g
from superset.mcp_service.app import mcp
from superset.mcp_service.dashboard.schemas import (
DashboardFilter,
ListDashboardsRequest,
)
from superset.mcp_service.dashboard.tool.get_dashboard_info import (
@@ -1355,66 +1354,13 @@ class TestDashboardSortableColumns:
assert col in list_dashboards.__doc__
class TestListDashboardsCreatedByMe:
    """Tests for the created_by_me flag on ListDashboardsRequest."""

    def test_created_by_me_default_is_false(self):
        assert ListDashboardsRequest().created_by_me is False

    def test_created_by_me_true_accepted(self):
        assert ListDashboardsRequest(created_by_me=True).created_by_me is True

    def test_created_by_me_combined_with_filters(self):
        published_filter = DashboardFilter(col="published", opr="eq", value=True)
        request = ListDashboardsRequest(
            created_by_me=True, filters=[published_filter]
        )
        assert request.created_by_me is True
        assert len(request.filters) == 1

    def test_created_by_me_with_search_raises(self):
        from pydantic import ValidationError

        with pytest.raises(ValidationError, match="created_by_me"):
            ListDashboardsRequest(created_by_me=True, search="My dashboards")

    def test_dashboard_filter_rejects_created_by_fk(self):
        """created_by_fk is not a public filter column; use created_by_me instead."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            DashboardFilter(col="created_by_fk", opr="eq", value=1)
class TestListDashboardsOwnedByMe:
    """Tests for the owned_by_me flag on ListDashboardsRequest."""

    def test_owned_by_me_default_is_false(self):
        assert ListDashboardsRequest().owned_by_me is False

    def test_owned_by_me_true_accepted(self):
        assert ListDashboardsRequest(owned_by_me=True).owned_by_me is True

    def test_owned_by_me_combined_with_filters(self):
        published_filter = DashboardFilter(col="published", opr="eq", value=True)
        request = ListDashboardsRequest(
            owned_by_me=True, filters=[published_filter]
        )
        assert request.owned_by_me is True
        assert len(request.filters) == 1

    def test_owned_by_me_with_search_raises(self):
        from pydantic import ValidationError

        with pytest.raises(ValidationError, match="owned_by_me"):
            ListDashboardsRequest(owned_by_me=True, search="My dashboards")

    def test_owned_by_me_and_created_by_me_allowed(self):
        """Both flags together are valid (OR logic — creator or owner)."""
        both = ListDashboardsRequest(owned_by_me=True, created_by_me=True)
        assert both.owned_by_me is True
        assert both.created_by_me is True
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_no_arguments(mock_list, mcp_server):
    """Regression test: list_dashboards must accept zero arguments without raising
    pydantic_core.ValidationError: Missing required argument: request."""
    # An empty DAO result is sufficient — the regression under test is in the
    # tool's argument handling, not in the listing logic itself.
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        result = await client.call_tool("list_dashboards", {})
        data = json.loads(result.content[0].text)
        assert "dashboards" in data

View File

@@ -0,0 +1,831 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the query_dataset MCP tool."""
from __future__ import annotations
import importlib
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client, FastMCP
from superset.mcp_service.app import mcp
from superset.utils import json
query_dataset_module = importlib.import_module(
"superset.mcp_service.dataset.tool.query_dataset"
)
@pytest.fixture
def mcp_server() -> FastMCP:
    # Expose the module-level FastMCP application to each test.
    return mcp
@pytest.fixture(autouse=True)
def mock_auth() -> Generator[MagicMock, None, None]:
    """Mock authentication and metadata access for all tests."""
    with (
        patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
        patch.object(
            query_dataset_module,
            "user_can_view_data_model_metadata",
            return_value=True,
        ),
    ):
        # Every tool call runs as a minimal "admin" user. Individual tests
        # re-patch user_can_view_data_model_metadata to exercise the deny path.
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user
def _make_column(name: str, is_dttm: bool = False) -> MagicMock:
"""Build a mock SqlaTable column with the given name and datetime flag."""
col = MagicMock()
col.column_name = name
col.is_dttm = is_dttm
col.verbose_name = None
col.type = "VARCHAR"
col.groupby = True
col.filterable = True
col.description = None
return col
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
"""Build a mock SqlMetric with the given name and SQL expression."""
metric = MagicMock()
metric.metric_name = name
metric.verbose_name = None
metric.expression = expression
metric.description = None
metric.d3format = None
return metric
def _make_dataset(
dataset_id: int = 1,
table_name: str = "orders",
columns: list[Any] | None = None,
metrics: list[Any] | None = None,
main_dttm_col: str | None = None,
) -> MagicMock:
"""Build a mock SqlaTable dataset with default columns and metrics."""
ds = MagicMock()
ds.id = dataset_id
ds.table_name = table_name
ds.uuid = f"test-uuid-{dataset_id}"
ds.main_dttm_col = main_dttm_col
ds.database = MagicMock()
ds.database.database_name = "examples"
ds.columns = columns or [
_make_column("category"),
_make_column("region"),
_make_column("order_date", is_dttm=True),
]
ds.metrics = metrics or [
_make_metric("count", "COUNT(*)"),
_make_metric("total_revenue", "SUM(revenue)"),
]
return ds
def _mock_command_result(
data: list[dict[str, Any]] | None = None,
colnames: list[str] | None = None,
) -> dict[str, Any]:
"""Build the result dict that ChartDataCommand.run() returns."""
data = data or [
{"category": "Electronics", "count": 42},
{"category": "Clothing", "count": 17},
]
colnames = colnames or ["category", "count"]
return {
"queries": [
{
"data": data,
"colnames": colnames,
"rowcount": len(data),
"cache_key": "abc123",
"is_cached": False,
"cached_dttm": None,
"cache_timeout": 300,
}
]
}
@pytest.mark.asyncio
async def test_query_dataset_success(mcp_server: FastMCP) -> None:
    """Happy path: metrics + columns returns data."""
    dataset = _make_dataset()
    result_data = _mock_command_result()

    # Stub dataset resolution and the whole chart-data command pipeline;
    # only the tool's request/response mapping is under test here.
    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "columns": ["category"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["dataset_id"] == 1
            assert data["dataset_name"] == "orders"
            assert data["row_count"] == 2
            assert len(data["data"]) == 2
            assert data["data"][0]["category"] == "Electronics"
@pytest.mark.asyncio
async def test_query_dataset_not_found(mcp_server: FastMCP) -> None:
    """Dataset ID that doesn't exist returns error."""
    # _resolve_dataset returning None is the tool's "dataset missing" signal.
    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=None,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 999,
                        "metrics": ["count"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["error_type"] == "NotFound"
            # The error message must name the missing dataset id.
            assert "999" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_invalid_metric(mcp_server: FastMCP) -> None:
    """Unknown metric name returns validation error with suggestions."""
    dataset = _make_dataset()

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["countt"],  # typo
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["error_type"] == "ValidationError"
            assert "countt" in data["error"]
            # Should suggest "count" as a close match. Strip the echoed typo
            # first: "countt" contains "count" as a substring, so a plain
            # `"count" in data["error"]` would pass vacuously whenever the
            # previous assertion passes.
            assert "count" in data["error"].replace("countt", "")
@pytest.mark.asyncio
async def test_query_dataset_invalid_column(mcp_server: FastMCP) -> None:
    """Unknown column name returns validation error."""
    dataset = _make_dataset()

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "columns": ["nonexistent_col"],
                        "metrics": ["count"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["error_type"] == "ValidationError"
            # The offending column name must be echoed back to the caller.
            assert "nonexistent_col" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_no_metrics_no_columns(mcp_server: FastMCP) -> None:
    """Providing neither metrics nor columns raises validation error."""
    from fastmcp.exceptions import ToolError

    # No dataset resolution is stubbed here — presumably request-model
    # validation rejects the call before any lookup (TODO confirm), so the
    # client surfaces a ToolError rather than a structured error payload.
    async with Client(mcp_server) as client:
        with pytest.raises(ToolError, match="metrics.*columns"):
            await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": [],
                        "columns": [],
                    }
                },
            )
@pytest.mark.asyncio
async def test_query_dataset_with_time_range(mcp_server: FastMCP) -> None:
    """time_range is converted to TEMPORAL_RANGE filter + granularity."""
    dataset = _make_dataset(main_dttm_col="order_date")
    result_data = _mock_command_result()
    captured_queries: list[dict[str, Any]] = []

    # Capture the query dicts passed to QueryContextFactory.create so the
    # synthesized filter can be inspected after the tool call.
    def capture_create(**kwargs):
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=capture_create,
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "time_range": "Last 7 days",
                    }
                },
            )

    assert len(captured_queries) == 1
    query_dict = captured_queries[0]
    # Should have TEMPORAL_RANGE filter
    temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
    assert len(temporal_filters) == 1
    assert temporal_filters[0]["col"] == "order_date"
    assert temporal_filters[0]["val"] == "Last 7 days"
    # Should set granularity
    assert query_dict["granularity"] == "order_date"
    # applied_filters in response must include the synthesized TEMPORAL_RANGE filter
    data = json.loads(result.content[0].text)
    resp_filters = data["applied_filters"]
    temporal_resp = [f for f in resp_filters if f["op"] == "TEMPORAL_RANGE"]
    assert len(temporal_resp) == 1
    assert temporal_resp[0]["col"] == "order_date"
    assert temporal_resp[0]["val"] == "Last 7 days"
@pytest.mark.asyncio
async def test_query_dataset_time_range_no_temporal_column(mcp_server: FastMCP) -> None:
    """time_range without a temporal column returns error."""
    # main_dttm_col=None and no time_column in the request: there is nothing
    # to anchor the TEMPORAL_RANGE filter to.
    dataset = _make_dataset(main_dttm_col=None)

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "time_range": "Last 7 days",
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["error_type"] == "ValidationError"
            assert "temporal column" in data["error"].lower()
@pytest.mark.asyncio
async def test_query_dataset_with_filters(mcp_server: FastMCP) -> None:
    """User-provided filters are passed through to the query."""
    dataset = _make_dataset()
    result_data = _mock_command_result()
    captured_queries: list[dict[str, Any]] = []

    # Capture the query dicts passed to QueryContextFactory.create so the
    # filter pass-through can be inspected after the tool call.
    def capture_create(**kwargs):
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=capture_create,
        ),
    ):
        async with Client(mcp_server) as client:
            await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "filters": [
                            {"col": "category", "op": "==", "val": "Electronics"}
                        ],
                    }
                },
            )

    assert len(captured_queries) == 1
    filters = captured_queries[0]["filters"]
    assert len(filters) == 1
    assert filters[0]["col"] == "category"
    assert filters[0]["op"] == "=="
    assert filters[0]["val"] == "Electronics"
@pytest.mark.asyncio
async def test_query_dataset_empty_results(mcp_server: FastMCP) -> None:
    """Query that returns no data gives a response with row_count=0."""
    dataset = _make_dataset()
    # Hand-built command result with an empty payload (note: deliberately
    # not via _mock_command_result, whose falsy-input fallback would
    # substitute the sample rows).
    empty_result = {
        "queries": [
            {
                "data": [],
                "colnames": [],
                "rowcount": 0,
                "is_cached": False,
                "cached_dttm": None,
                "cache_timeout": 300,
            }
        ]
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=empty_result,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["row_count"] == 0
            assert data["data"] == []
            assert "no data" in data["summary"].lower()
@pytest.mark.asyncio
async def test_query_dataset_by_uuid(mcp_server: FastMCP) -> None:
    """UUID-based lookup works."""
    dataset = _make_dataset()
    result_data = _mock_command_result()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ) as mock_resolve,
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": "a1b2c3d4-5678-90ab-cdef-1234567890ab",
                        "metrics": ["count"],
                    }
                },
            )
            # Verify the resolve function was called with the UUID
            # (the UUID string must survive the request model untouched).
            mock_resolve.assert_called_once()
            call_args = mock_resolve.call_args
            assert call_args[0][0] == "a1b2c3d4-5678-90ab-cdef-1234567890ab"
            data = json.loads(result.content[0].text)
            assert data["dataset_id"] == 1
@pytest.mark.asyncio
async def test_query_dataset_permission_denied(mcp_server: FastMCP) -> None:
    """Permission denied from ChartDataCommand.validate() returns error."""
    from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
    from superset.exceptions import SupersetSecurityException

    dataset = _make_dataset()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
        # validate() raising the security exception models a user who can
        # see the dataset but lacks data access.
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
            side_effect=SupersetSecurityException(
                SupersetError(
                    message="Access denied",
                    error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
                    level=ErrorLevel.WARNING,
                )
            ),
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["error_type"] == "QueryError"
@pytest.mark.asyncio
async def test_query_dataset_order_by_valid(mcp_server: FastMCP) -> None:
    """order_by with valid column/metric names passes through."""
    dataset = _make_dataset()
    result_data = _mock_command_result()
    captured_queries: list[dict[str, Any]] = []

    # Capture the query dicts passed to QueryContextFactory.create so the
    # orderby translation can be inspected after the tool call.
    def capture_create(**kwargs):
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=capture_create,
        ),
    ):
        async with Client(mcp_server) as client:
            await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "columns": ["category"],
                        "order_by": ["count"],
                        "order_desc": True,
                    }
                },
            )

    assert len(captured_queries) == 1
    orderby = captured_queries[0].get("orderby", [])
    assert len(orderby) == 1
    assert orderby[0][0] == "count"
    # order_desc=True -> ascending=False
    assert orderby[0][1] is False
@pytest.mark.asyncio
async def test_query_dataset_order_by_invalid(mcp_server: FastMCP) -> None:
    """order_by with an unknown name returns validation error."""
    dataset = _make_dataset()

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "order_by": ["nonexistent"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["error_type"] == "ValidationError"
            # The offending order_by name must be echoed back to the caller.
            assert "nonexistent" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_time_column_override(mcp_server: FastMCP) -> None:
    """Explicit time_column overrides dataset main_dttm_col."""
    # Here the override equals main_dttm_col; the assertions confirm the
    # explicit request value is the one applied.
    dataset = _make_dataset(main_dttm_col="order_date")
    result_data = _mock_command_result()
    captured_queries: list[dict[str, Any]] = []

    # Capture the query dicts passed to QueryContextFactory.create.
    def capture_create(**kwargs):
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=capture_create,
        ),
    ):
        async with Client(mcp_server) as client:
            await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "time_range": "Last 30 days",
                        "time_column": "order_date",
                    }
                },
            )

    assert len(captured_queries) == 1
    query_dict = captured_queries[0]
    assert query_dict["granularity"] == "order_date"
    temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
    assert temporal_filters[0]["col"] == "order_date"
@pytest.mark.asyncio
async def test_query_dataset_non_dttm_time_column_warns(mcp_server: FastMCP) -> None:
    """Using a non-datetime column for time_range produces a warning."""
    # "category" exists on the mock dataset but has is_dttm=False, so the
    # query still runs — it just carries a warning in the response.
    dataset = _make_dataset(main_dttm_col=None)
    result_data = _mock_command_result()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "time_range": "Last 7 days",
                        "time_column": "category",
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert len(data["warnings"]) > 0
            assert "not marked as a datetime" in data["warnings"][0]
@pytest.mark.asyncio
async def test_query_dataset_invalid_filter_column(mcp_server: FastMCP) -> None:
    """Filter on a column that doesn't exist returns validation error."""
    dataset = _make_dataset()

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "filters": [
                            {
                                "col": "nonexistent",
                                "op": "==",
                                "val": "test",
                            }
                        ],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            assert data["error_type"] == "ValidationError"
            # The offending filter column must be echoed back to the caller.
            assert "nonexistent" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_metadata_access_denied_no_suggestions(
    mcp_server: FastMCP,
) -> None:
    """Users without data-model metadata access cannot probe column/metric names.

    The privacy gate must fire before the validation step that returns close-match
    suggestions, so restricted users cannot enumerate schema details via typos.
    """
    dataset = _make_dataset()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        # Override the autouse mock_auth fixture's allow-all metadata patch.
        patch.object(
            query_dataset_module,
            "user_can_view_data_model_metadata",
            return_value=False,
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        # Typo that would normally trigger close-match suggestions
                        "metrics": ["countt"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            # Must be denied before returning any schema suggestions
            assert data["error_type"] == "DataModelMetadataRestricted"
            # Must NOT contain column/metric name suggestions
            # (NOTE(review): the second check subsumes the first, since
            # "countt" contains "count" as a substring).
            assert "countt" not in data.get("error", "")
            assert "count" not in data.get("error", "")
@pytest.mark.asyncio
async def test_query_dataset_metadata_access_denied_nonexistent_dataset(
    mcp_server: FastMCP,
) -> None:
    """Metadata-restricted users must not be able to probe dataset existence.

    The privacy gate fires before the DAO lookup, so a restricted caller
    always receives DataModelMetadataRestricted — never NotFound — regardless
    of whether the requested dataset ID exists.
    """
    # _resolve_dataset is deliberately NOT patched: the gate must trip
    # before any dataset lookup occurs.
    with patch.object(
        query_dataset_module,
        "user_can_view_data_model_metadata",
        return_value=False,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        # Use a dataset_id that does not exist
                        "dataset_id": 999999,
                        "metrics": ["count"],
                    }
                },
            )
            data = json.loads(result.content[0].text)
            # Must receive restricted error, not a NotFound that leaks existence
            # (NOTE(review): the second assertion is implied by the first).
            assert data["error_type"] == "DataModelMetadataRestricted"
            assert data["error_type"] != "NotFound"

View File

@@ -0,0 +1,133 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for Rison filter parser."""
from superset.utils.rison_filters import RisonFilterParser
def test_simple_equality():
    """A bare key:value pair becomes a SIMPLE equality filter."""
    parsed = RisonFilterParser().parse("(country:USA)")
    assert len(parsed) == 1
    clause = parsed[0]
    assert clause["expressionType"] == "SIMPLE"
    assert clause["clause"] == "WHERE"
    assert clause["subject"] == "country"
    assert clause["operator"] == "=="
    assert clause["comparator"] == "USA"
def test_multiple_filters_and():
    """Comma-separated pairs yield one filter each (implicit AND)."""
    parsed = RisonFilterParser().parse("(country:USA,year:2024)")
    assert len(parsed) == 2
    first, second = parsed
    assert first["subject"] == "country"
    assert first["comparator"] == "USA"
    assert second["subject"] == "year"
    assert second["comparator"] == 2024
def test_list_in_operator():
    """A Rison list value maps to the IN operator."""
    parsed = RisonFilterParser().parse("(country:!(USA,Canada))")
    assert len(parsed) == 1
    clause = parsed[0]
    assert clause["subject"] == "country"
    assert clause["operator"] == "IN"
    assert clause["comparator"] == ["USA", "Canada"]
def test_not_operator():
    """NOT wrapping a scalar equality inverts it to !=."""
    parsed = RisonFilterParser().parse("(NOT:(country:USA))")
    assert len(parsed) == 1
    clause = parsed[0]
    assert clause["subject"] == "country"
    assert clause["operator"] == "!="
    assert clause["comparator"] == "USA"
def test_not_in_operator():
    """NOT wrapping a list value inverts IN to NOT IN."""
    parsed = RisonFilterParser().parse("(NOT:(country:!(USA,Canada)))")
    assert len(parsed) == 1
    clause = parsed[0]
    assert clause["subject"] == "country"
    assert clause["operator"] == "NOT IN"
    assert clause["comparator"] == ["USA", "Canada"]
def test_or_operator():
    """OR over sub-filters collapses into a single SQL expression filter."""
    parsed = RisonFilterParser().parse("(OR:!((status:active),(priority:high)))")
    assert len(parsed) == 1
    clause = parsed[0]
    assert clause["expressionType"] == "SQL"
    assert clause["clause"] == "WHERE"
    assert "status = 'active' OR priority = 'high'" in clause["sqlExpression"]
def test_comparison_operators():
    """gt/gte/lt/lte map to their SQL comparison operators."""
    parser = RisonFilterParser()
    cases = [
        ("(sales:(gt:100000))", ">", 100000),
        ("(age:(gte:18))", ">=", 18),
        ("(temp:(lt:32))", "<", 32),
        ("(price:(lte:1000))", "<=", 1000),
    ]
    for rison, expected_op, expected_val in cases:
        parsed = parser.parse(rison)
        assert parsed[0]["operator"] == expected_op
        assert parsed[0]["comparator"] == expected_val
def test_between_operator():
    """between with a two-element list maps to BETWEEN."""
    parsed = RisonFilterParser().parse("(date:(between:!('2024-01-01','2024-12-31')))")
    assert len(parsed) == 1
    clause = parsed[0]
    assert clause["operator"] == "BETWEEN"
    assert clause["comparator"] == ["2024-01-01", "2024-12-31"]
def test_like_operator():
    """like with a pattern string maps to LIKE."""
    parsed = RisonFilterParser().parse("(name:(like:'%smith%'))")
    assert len(parsed) == 1
    clause = parsed[0]
    assert clause["operator"] == "LIKE"
    assert clause["comparator"] == "%smith%"
def test_empty_filter():
    """Empty and empty-group inputs both parse to no filters."""
    parser = RisonFilterParser()
    for empty_input in ("", "()"):
        assert parser.parse(empty_input) == []
def test_invalid_rison():
    """Malformed Rison degrades gracefully to an empty filter list."""
    parser = RisonFilterParser()
    for bad_input in ("invalid rison", "(unclosed"):
        assert parser.parse(bad_input) == []