Compare commits

..

8 Commits

Author SHA1 Message Date
Joe Li
e49fd50489 fix(database): stop saving after validation errors 2026-06-25 15:59:15 -07:00
Joe Li
f012128700 test(databases): migrate database modal Cypress tests to RTL
Port the deprecated Cypress database-modal E2E suite to Jest + React
Testing Library, continuing the testing-strategy modernization (prefer
unit/integration over E2E; Cypress is being removed).

The two original 'error alert' cases relied on a real backend connection
attempt (real DNS / socket behaviour), which is what made them flaky.
They are reproduced as RTL tests that mock the validate_parameters
response, exercising the frontend's actual responsibility: mapping an
extra.invalid field error onto the matching form field. Whether a bad
host/port truly fails to connect is a backend concern, covered by
backend tests -- noted here so a reviewer can veto the trade-off.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-25 14:41:49 -07:00
Greg Neighbors
d8bcc66472 feat(mcp): dashboard layout, theme, and CSS control + update_dashboard tool (#40399)
Co-authored-by: gkneighb <26003+gkneighb@users.noreply.github.com>
Co-authored-by: Greg Neighbors <gregneighbors@Gregs-MacBook-Air-2.local>
Co-authored-by: Greg Neighbors <gregneighbors@Gregs-Air-2.lan>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: Evan Rusackas <evan@rusackas.com>
2026-06-25 10:41:07 -07:00
Evan Rusackas
4b9b8187b3 fix(config): make Swagger UI opt-in (off by default) (#41300)
Co-authored-by: Amin Ghadersohi <amin.ghadersohi@gmail.com>
Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
2026-06-25 10:34:28 -07:00
Evan Rusackas
83f7dc9d5b chore(codeowners): add translation maintainers (#41429)
Co-authored-by: Amin Ghadersohi <amin.ghadersohi@gmail.com>
Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
2026-06-25 10:09:16 -07:00
Elizabeth Thompson
baca76ebe0 fix(slack): fix indented triple-quoted string in v1 API deprecation warning (#41393) 2026-06-25 09:54:33 -07:00
Mehmet Salih Yavuz
9a11c15a33 feat(explore): add full-range option for time-shift comparison (#41334) 2026-06-25 18:30:33 +03:00
Michael S. Molina
a90c8e0347 feat(extensions): add Chat contribution type (SIP-214) (#41205)
Co-authored-by: Enzo Martellucci <52219496+EnxDev@users.noreply.github.com>
Co-authored-by: Enzo Martellucci <enzomartellucci@gmail.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-25 08:57:30 -03:00
71 changed files with 4108 additions and 1930 deletions

2
.github/CODEOWNERS vendored
View File

@@ -38,7 +38,7 @@
# Notify translation maintainers of changes to translations
/superset/translations/ @sfirke @rusackas
/superset/translations/ @sfirke @rusackas @villebro @sadpandajoe @hainenber
# Notify PMC members of changes to extension-related files

View File

@@ -32,6 +32,10 @@ The `rls` rules passed to `POST /api/v1/security/guest_token/` are now validated
When the MCP service has JWT auth enabled (`MCP_AUTH_ENABLED = True`), an audience must be configured via `MCP_JWT_AUDIENCE` so issued tokens are bound to this service. The service now fails to start with a clear configuration error when the audience is unset, instead of starting with audience validation skipped. Deployments that enable MCP JWT auth must set `MCP_JWT_AUDIENCE` to the audience value their identity provider issues for the MCP service. API-key-only MCP deployments (JWT auth disabled) are unaffected.
### Swagger UI is opt-in (off by default)
`FAB_API_SWAGGER_UI` now defaults to `False` and is driven by the `SUPERSET_ENABLE_SWAGGER_UI` environment variable. The interactive Swagger UI / OpenAPI documentation endpoints (e.g. `/swagger/v1`) are therefore no longer exposed by default. To enable them, set `SUPERSET_ENABLE_SWAGGER_UI=true` (the bundled Docker development environment sets this) or override `FAB_API_SWAGGER_UI = True` in `superset_config.py`.
### Pivot table First/Last aggregations follow data order
The pivot table chart's `First` and `Last` aggregations now return the first and last value in data (query result) order, instead of effectively returning the minimum and maximum. Existing pivot tables that use these aggregations for totals/subtotals may show different values after upgrading. For deterministic results, ensure the underlying query has a stable sort order.

View File

@@ -70,6 +70,8 @@ SUPERSET_LOG_LEVEL=info
SUPERSET_APP_ROOT="/"
SUPERSET_ENV=development
# Swagger UI is opt-in (off by default); enable it for local development.
SUPERSET_ENABLE_SWAGGER_UI=true
SUPERSET_LOAD_EXAMPLES=yes
CYPRESS_CONFIG=false
SUPERSET_PORT=8088

View File

@@ -34,15 +34,14 @@ Frontend contribution types allow extensions to extend Superset's user interface
Extensions can add new views or panels to the host application, such as custom SQL Lab panels, dashboards, or other UI components. Contribution areas are uniquely identified (e.g., `sqllab.panels` for SQL Lab panels), enabling seamless integration into specific parts of the application.
```tsx
import React from 'react';
```typescript
import { views } from '@apache-superset/core';
import MyPanel from './MyPanel';
views.registerView(
{ id: 'my-extension.main', name: 'My Panel Name' },
'sqllab.panels',
() => <MyPanel />,
MyPanel,
);
```
@@ -112,6 +111,24 @@ editors.registerEditor(
See [Editors Extension Point](./extension-points/editors.md) for implementation details.
### Chat
Extensions can add a chat interface to Superset by registering a trigger component and a panel component. The host owns the layout, open/close state, and display mode — the extension only provides the UI. The panel can be displayed as a floating overlay or docked as a resizable sidebar beside the page content, and the user's preference is persisted across reloads.
```tsx
import { chat } from '@apache-superset/core';
import ChatTrigger from './ChatTrigger';
import ChatPanel from './ChatPanel';
chat.registerChat(
{ id: 'my-org.my-chat', name: 'My Chat' },
ChatTrigger,
ChatPanel,
);
```
See [Chat](./extension-points/chat.md) for implementation details.
## Backend
Backend contribution types allow extensions to extend Superset's server-side capabilities. Backend contributions are registered at startup via classes and functions imported from the auto-discovered `entrypoint.py` file.

View File

@@ -0,0 +1,141 @@
---
title: Chat
sidebar_position: 3
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# Chat Contributions
Extensions can add a chat interface to Superset by registering a trigger and a panel. The host owns the layout, open/close state, and display mode — the extension only needs to provide the UI components.
## Overview
A chat registration consists of two React components:
| Component | Role |
|-----------|------|
| **Trigger** | Always-visible entry point (e.g., a floating button). Rendered in the bottom-right corner in floating mode, or as a fixed overlay in panel mode. |
| **Panel** | The chat UI itself (message list, input, etc.). Mounted by the host in the active display mode. |
## Display Modes
The host supports two display modes, switchable by the user or the extension at runtime:
| Mode | Behavior |
|------|----------|
| `floating` | Panel floats above page content, anchored to the bottom-right corner. |
| `panel` | Panel is docked to the right side of the application as a resizable sidebar, sitting beside the page content. |
The user's last selected mode and open/closed state are persisted across page reloads.
## Registering a Chat
Call `chat.registerChat` from your extension's entry point with a descriptor, a trigger factory, and a panel factory:
```tsx
import { chat } from '@apache-superset/core';
import ChatTrigger from './ChatTrigger';
import ChatPanel from './ChatPanel';
chat.registerChat(
{ id: 'my-org.my-chat', name: 'My Chat' },
ChatTrigger,
ChatPanel,
);
```
Only one chat registration is active at a time. If a second extension calls `registerChat`, it replaces the first and a warning is logged.
## Opening and Closing the Chat
The trigger component is responsible for toggling the panel. Use `chat.isOpen()`, `chat.open()`, and `chat.close()` to control visibility:
```tsx
import { chat } from '@apache-superset/core';
export default function ChatTrigger() {
return (
<button onClick={() => (chat.isOpen() ? chat.close() : chat.open())}>
💬
</button>
);
}
```
You can also subscribe to open/close events from any component:
```tsx
useEffect(() => {
const { dispose } = chat.onDidOpen(() => console.log('chat opened'));
return dispose;
}, []);
```
## Changing the Display Mode
Call `chat.setDisplayMode` to switch between `'floating'` and `'panel'` modes. In your panel component, subscribe to `onDidChangeDisplayMode` to react to changes (including those triggered by the user):
```tsx
import { useState, useEffect } from 'react';
import { chat } from '@apache-superset/core';
export default function ChatPanel() {
const [mode, setMode] = useState(chat.getDisplayMode());
useEffect(() => {
const { dispose } = chat.onDidChangeDisplayMode(m => setMode(m));
return dispose;
}, []);
return (
<div style={{ height: mode === 'panel' ? '100%' : '80vh' }}>
<button onClick={() => chat.setDisplayMode(mode === 'panel' ? 'floating' : 'panel')}>
{mode === 'panel' ? 'Float' : 'Dock'}
</button>
{/* message list and input */}
</div>
);
}
```
## Chat API Reference
All methods are available on the `chat` namespace from `@apache-superset/core`:
| Method / Event | Description |
|----------------|-------------|
| `registerChat(descriptor, trigger, panel)` | Register a chat extension. Returns a `Disposable` to unregister. |
| `open()` | Open the chat panel. No-op if already open or no registration. |
| `close()` | Close the chat panel. |
| `isOpen()` | Returns `true` if the panel is currently open. |
| `getDisplayMode()` | Returns the current display mode (`'floating'` or `'panel'`). |
| `setDisplayMode(mode)` | Switch between `'floating'` and `'panel'` mode. |
| `onDidOpen(listener)` | Subscribe to panel open events. Returns a `Disposable`. |
| `onDidClose(listener)` | Subscribe to panel close events. Returns a `Disposable`. |
| `onDidChangeDisplayMode(listener)` | Subscribe to display mode changes. Returns a `Disposable`. |
| `onDidRegisterChat(listener)` | Subscribe to registration events. |
| `onDidUnregisterChat(listener)` | Subscribe to unregistration events. |
| `onDidResizePanel(listener)` | Subscribe to panel resize events (panel mode only). Not all hosts provide a resizer — do not rely on this firing. Returns a `Disposable`. |
## Next Steps
- **[Contribution Types](../contribution-types.md)** — Explore other contribution types
- **[Development](../development.md)** — Set up your development environment

View File

@@ -47,6 +47,8 @@ module.exports = {
collapsed: true,
items: [
'extensions/extension-points/sqllab',
'extensions/extension-points/editors',
'extensions/extension-points/chat',
],
},
'extensions/development',

View File

@@ -519,104 +519,6 @@ For a connection to a SQL endpoint you need to use the HTTP path from the endpoi
{"connect_args": {"http_path": "/sql/1.0/endpoints/****", "driver_path": "/path/to/odbc/driver"}}
```
##### OAuth2 Authentication
Superset supports OAuth2 authentication for Databricks, allowing users to authenticate with their personal Databricks accounts instead of using shared access tokens. This provides better security and audit capabilities.
###### Prerequisites
1. Create an OAuth2 application in your Databricks account:
- Go to your Databricks account console
- Navigate to **Settings** → **Developer** → **OAuth apps**
- Create a new OAuth app with the redirect URI: `http://your-superset-host:port/api/v1/database/oauth2/`
2. Configure OAuth2 in your `superset_config.py`:
```python
from datetime import timedelta
# OAuth2 configuration for Databricks
# OAuth2 endpoints are automatically detected based on your Databricks cloud provider
DATABASE_OAUTH2_CLIENTS = {
"Databricks (legacy)": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
# The authorization endpoint is auto-detected from the hostname; the
# token endpoint must be set explicitly (no DB context at exchange):
# AWS: "authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/{account_id}/v1/authorize",
# Azure: "authorization_request_uri": "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/authorize",
# GCP: "authorization_request_uri": "https://accounts.gcp.databricks.com/oidc/accounts/{account_id}/v1/authorize",
# "token_request_uri": "https://<provider-token-endpoint>",
},
"Databricks": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
# Authorization endpoint auto-detected from hostname; set
# "token_request_uri" explicitly for the token exchange.
},
}
# OAuth2 redirect URI (adjust hostname/port for your setup)
DATABASE_OAUTH2_REDIRECT_URI = "http://your-superset-host:port/api/v1/database/oauth2/"
# Optional: OAuth2 timeout
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
```
Replace the following placeholders:
- `your-databricks-client-id`: Your Databricks OAuth2 application client ID
- `your-databricks-client-secret`: Your Databricks OAuth2 application client secret
- `your-superset-host:port`: Your Superset instance hostname and port
**Multi-Cloud Provider Support**
Superset automatically detects your Databricks cloud provider and uses the appropriate OAuth2 endpoints:
- **AWS**: Detected from hostnames containing `cloud.databricks.com`
- **Azure**: Detected from hostnames containing `azure` or `azuredatabricks`
- **GCP**: Detected from hostnames containing `gcp` or `googleusercontent`
You can also explicitly specify the cloud provider, along with the account
identifier used to build the OAuth2 endpoints, in your database configuration
under **Advanced** → **Other** → **ENGINE PARAMETERS**:
```json
{
"cloud_provider": "azure",
"tenant_id": "your-azure-tenant-id"
}
```
For AWS and GCP, supply `account_id` instead:
```json
{
"cloud_provider": "aws",
"account_id": "your-databricks-account-id"
}
```
Valid cloud provider values are: `aws`, `azure`, `gcp`. The **authorization**
endpoint is auto-detected: Superset substitutes this identifier into the
provider's authorization template. The **token** endpoint is not auto-resolved
(token exchange has no database context to detect the provider), so for the
auto-detected flow you must still supply a fully-resolved `token_request_uri`
in `DATABASE_OAUTH2_CLIENTS`. If you supply fully-resolved
`authorization_request_uri` and `token_request_uri` values, those take
precedence and no `account_id`/`tenant_id` is required.
###### Usage
Once configured, users can:
1. Connect to Databricks databases normally using access tokens
2. When querying data, Superset will automatically redirect users to authenticate with Databricks if needed
3. User-specific OAuth2 tokens will be used for database connections, providing better security and audit trails
This feature works with both "Databricks (legacy)" and "Databricks" engine types and automatically supports all major cloud providers (AWS, Azure, GCP).
#### Denodo
The recommended connector library for Denodo is

View File

@@ -1,118 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { DATABASE_LIST } from 'cypress/utils/urls';
function closeModal() {
cy.get('body').then($body => {
if ($body.find('[data-test="database-modal"]').length) {
cy.get('[aria-label="Close"]').eq(1).click();
}
});
}
describe('Add database', () => {
before(() => {
cy.visit(DATABASE_LIST);
});
beforeEach(() => {
cy.intercept('POST', '**/api/v1/database/validate_parameters/**').as(
'validateParams',
);
cy.intercept('POST', '**/api/v1/database/').as('createDb');
closeModal();
cy.getBySel('btn-create-database').click();
});
it('should open dynamic form', () => {
cy.get('.preferred > :nth-child(1)').click();
cy.get('input[name="host"]').should('have.value', '');
cy.get('input[name="port"]').should('have.value', '');
cy.get('input[name="database"]').should('have.value', '');
cy.get('input[name="username"]').should('have.value', '');
cy.get('input[name="password"]').should('have.value', '');
cy.get('input[name="database_name"]').should('have.value', '');
});
it('should open sqlalchemy form', () => {
cy.get('.preferred > :nth-child(1)').click();
cy.getBySel('sqla-connect-btn').click();
cy.getBySel('database-name-input').should('be.visible');
cy.getBySel('sqlalchemy-uri-input').should('be.visible');
});
it('show error alerts on dynamic form for bad host', () => {
cy.get('.preferred > :nth-child(1)').click();
cy.get('input[name="host"]').type('badhost', { force: true });
cy.get('input[name="port"]').type('5432', { force: true });
cy.get('input[name="username"]').type('testusername', { force: true });
cy.get('input[name="database"]').type('testdb', { force: true });
cy.get('input[name="password"]').type('testpass', { force: true });
cy.get('body').click(0, 0);
cy.wait('@validateParams', { timeout: 30000 });
cy.getBySel('btn-submit-connection').should('not.be.disabled');
cy.getBySel('btn-submit-connection').click({ force: true });
cy.wait('@validateParams', { timeout: 30000 }).then(() => {
cy.wait('@createDb', { timeout: 60000 }).then(() => {
cy.contains(
'.ant-form-item-explain-error',
"The hostname provided can't be resolved",
).should('exist');
});
});
});
it('show error alerts on dynamic form for bad port', () => {
cy.get('.preferred > :nth-child(1)').click();
cy.get('input[name="host"]').type('localhost', { force: true });
cy.get('body').click(0, 0);
cy.wait('@validateParams', { timeout: 30000 });
cy.get('input[name="port"]').type('5430', { force: true });
cy.get('input[name="database"]').type('testdb', { force: true });
cy.get('input[name="username"]').type('testusername', { force: true });
cy.wait('@validateParams', { timeout: 30000 });
cy.get('input[name="password"]').type('testpass', { force: true });
cy.wait('@validateParams');
cy.getBySel('btn-submit-connection').should('not.be.disabled');
cy.getBySel('btn-submit-connection').click({ force: true });
cy.wait('@validateParams', { timeout: 30000 }).then(() => {
cy.get('body').click(0, 0);
cy.getBySel('btn-submit-connection').click({ force: true });
cy.wait('@createDb', { timeout: 60000 }).then(() => {
cy.contains(
'.ant-form-item-explain-error',
'The port is closed',
).should('exist');
});
});
});
});

View File

@@ -26533,6 +26533,21 @@
}
}
},
"node_modules/jsdom/node_modules/@noble/hashes": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-2.2.0.tgz",
"integrity": "sha512-IYqDGiTXab6FniAgnSdZwgWbomxpy9FtYvLKs7wCUs2a8RkITG+DFGO1DM9cr+E3/RgADRpFjrKVaJ1z6sjtEg==",
"dev": true,
"license": "MIT",
"optional": true,
"peer": true,
"engines": {
"node": ">= 20.19.0"
},
"funding": {
"url": "https://paulmillr.com/funding/"
}
},
"node_modules/jsdom/node_modules/css-tree": {
"version": "3.2.1",
"resolved": "https://registry.npmjs.org/css-tree/-/css-tree-3.2.1.tgz",
@@ -43570,6 +43585,21 @@
}
}
},
"node_modules/whatwg-url/node_modules/@noble/hashes": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-2.2.0.tgz",
"integrity": "sha512-IYqDGiTXab6FniAgnSdZwgWbomxpy9FtYvLKs7wCUs2a8RkITG+DFGO1DM9cr+E3/RgADRpFjrKVaJ1z6sjtEg==",
"dev": true,
"license": "MIT",
"optional": true,
"peer": true,
"engines": {
"node": ">= 20.19.0"
},
"funding": {
"url": "https://paulmillr.com/funding/"
}
},
"node_modules/whatwg-url/node_modules/webidl-conversions": {
"version": "8.0.1",
"resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-8.0.1.tgz",

View File

@@ -18,6 +18,14 @@
"types": "./lib/authentication/index.d.ts",
"default": "./lib/authentication/index.js"
},
"./chat": {
"types": "./lib/chat/index.d.ts",
"default": "./lib/chat/index.js"
},
"./navigation": {
"types": "./lib/navigation/index.d.ts",
"default": "./lib/navigation/index.js"
},
"./commands": {
"types": "./lib/commands/index.d.ts",
"default": "./lib/commands/index.js"

View File

@@ -0,0 +1,156 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* @fileoverview Chat contribution API for Superset extensions.
*
* Chat is a dedicated contribution type: an extension registers
* a chat via {@link registerChat} and the host owns where and how it is
* mounted. The host applies singleton resolution — multiple chat extensions
* may register, but exactly one is active at a time.
*
* @example
* ```typescript
* import { chat } from '@apache-superset/core';
*
* chat.registerChat(
* { id: 'acme.chat', name: 'Acme Chat' },
* AcmeTrigger,
* AcmePanel,
* );
* ```
*/
import { ComponentType } from 'react';
import type { Disposable, Event } from '../common';
export interface Chat {
/** The unique identifier for the chat. */
id: string;
/** The display name of the chat. */
name: string;
/** Optional description of the chat. */
description?: string;
}
export type DisplayMode = 'floating' | 'panel';
/**
* Registers a chat provider. Only one chat is active at a time; the most
* recently registered chat wins. Disposing the returned Disposable unregisters
* the chat.
*
* @param chat The chat descriptor (id, name).
* @param trigger The trigger component — the collapsed bubble entry point.
* Owns dynamic state such as unread counts.
* @param panel The panel component, rendered in either display mode. In
* 'floating' mode it appears as an overlay; in 'panel' mode it is docked
* alongside the main content.
* @returns A Disposable that unregisters the chat when disposed.
*
* @example
* ```typescript
* chat.registerChat(
* { id: 'acme.chat', name: 'Acme Chat' },
* AcmeTrigger,
* AcmePanel,
* );
* ```
*/
export declare function registerChat(
chat: Chat,
trigger: ComponentType,
panel: ComponentType,
): Disposable;
/**
* Returns the active chat descriptor, or undefined if none is registered.
*/
export declare function getChat(): Chat | undefined;
/**
* Event fired when a chat is registered.
*/
export declare const onDidRegisterChat: Event<Chat>;
/**
* Event fired when a chat is unregistered.
*/
export declare const onDidUnregisterChat: Event<Chat>;
/**
* Opens the active chat's panel.
*
* Acts on whichever chat is active, regardless of which extension calls it.
* No-op when no chat is registered or the panel is already open.
*/
export declare function open(): void;
/**
* Closes the active chat's panel.
*
* Acts on whichever chat is active, regardless of which extension calls it.
* No-op when the panel is not open.
*/
export declare function close(): void;
/**
* Returns whether the active chat's panel is currently open.
*/
export declare function isOpen(): boolean;
/**
* Event fired when the chat panel opens. Also fired by the host's own
* controls, not only by an extension's open() call.
*/
export declare const onDidOpen: Event<void>;
/**
* Event fired when the chat panel closes, whether triggered by an extension
* or by the host.
*/
export declare const onDidClose: Event<void>;
/**
* Returns the current display mode.
*/
export declare function getDisplayMode(): DisplayMode;
/**
* Sets the display mode. The mode is host-global and applies to whichever
* chat is active. Use {@link onDidChangeDisplayMode} to observe all changes,
* including those triggered by the host.
*/
export declare function setDisplayMode(displayMode: DisplayMode): void;
/**
* Event fired when the display mode changes, whether triggered by an
* extension via setDisplayMode() or by host-provided controls.
*/
export declare const onDidChangeDisplayMode: Event<DisplayMode>;
/**
* Event fired when the panel is resized in panel mode. Not all hosts provide
* a resizer — do not rely on this event firing.
*/
export declare const onDidResizePanel: Event<{ width: number }>;
// TODO: client actions API — tool availability functions will be added here
// once the client_actions SIP is finalized. The chat namespace is the
// intended integration point between the two SIPs.

View File

@@ -223,8 +223,6 @@ export interface Extension {
dependencies: string[];
/** Human-readable description of the extension */
description: string;
/** List of other extensions that this extension depends on */
extensionDependencies: string[];
/** Unique identifier for the extension */
id: string;
/** Human-readable name of the extension */

View File

@@ -23,9 +23,10 @@
* This module defines the aggregate interfaces used by the extension.json
* manifest and the `superset-extensions` build command. Individual metadata
* types are defined in their respective namespace modules (commands, views,
* menus, editors) and re-exported here for the manifest schema.
* menus, editors, chat) and re-exported here for the manifest schema.
*/
import { Chat } from '../chat';
import { Command } from '../commands';
import { View } from '../views';
import { Menu } from '../menus';
@@ -71,7 +72,8 @@ export interface MenuContributions {
}
/**
* Aggregates all contributions (commands, menus, views, and editors) provided by an extension or module.
* Aggregates all contributions (commands, menus, views, editors, and chat)
* provided by an extension or module.
*/
export interface Contributions {
/** List of commands. */
@@ -82,4 +84,10 @@ export interface Contributions {
views: ViewContributions;
/** List of editors. */
editors?: Editor[];
/**
* The chat contributed by the extension — at most one per extension, since
* the host applies singleton resolution and renders exactly one active
* chat at a time.
*/
chat?: Chat;
}

View File

@@ -18,10 +18,12 @@
*/
export * as common from './common';
export * as authentication from './authentication';
export * as chat from './chat';
export * as commands from './commands';
export * as editors from './editors';
export * as extensions from './extensions';
export * as menus from './menus';
export * as navigation from './navigation';
export * as sqlLab from './sqlLab';
export * as views from './views';
export * as contributions from './contributions';

View File

@@ -0,0 +1,81 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* @fileoverview Navigation namespace for Superset extensions.
*
* Exposes the current application surface so extensions can react to route
* changes without polling. Entity-level context (chart, dashboard, dataset)
* is intentionally not included here — surface-specific namespaces that
* resolve entity payloads are introduced in later phases.
*/
import { Event } from '../common';
/**
* The set of top-level application surfaces.
*
* `'explore'`, `'dashboard'` and `'dataset'` are the single-entity
* editing/viewing surfaces. `'chart_list'`, `'dashboard_list'` and
* `'dataset_list'` are the browse/list surfaces, distinct from those because no
* single entity is active. `'sqllab'` is the SQL editor where
* `sqlLab.getCurrentTab()` resolves; `'query_history'` and `'saved_queries'`
* are the related SQL Lab browse pages, which are not the editor. `'home'` is
* the welcome surface and the fallback for any route not explicitly enumerated.
*/
export type Page =
| 'dashboard'
| 'dashboard_list'
| 'explore'
| 'chart_list'
| 'sqllab'
| 'query_history'
| 'saved_queries'
| 'dataset'
| 'dataset_list'
| 'home';
/**
* Returns the current page surface.
*
* @example
* ```typescript
* const page = navigation.getPage();
* if (page === 'dashboard') {
* // react to being on a dashboard surface
* }
* ```
*/
export declare function getPage(): Page;
/**
* Event fired whenever the user navigates to a different surface.
*
* @example
* ```typescript
* const sub = navigation.onDidChangePage(page => {
* if (page === 'dashboard') {
* // react to navigating onto a dashboard surface
* }
* });
* // later:
* sub.dispose();
* ```
*/
export declare const onDidChangePage: Event<Page>;

View File

@@ -30,12 +30,12 @@
*
* views.registerView(
* { id: 'my_ext.result_stats', name: 'Result Stats', location: 'sqllab.panels' },
* () => <ResultStatsPanel />,
* ResultStatsPanel,
* );
* ```
*/
import { ReactElement } from 'react';
import { ComponentType } from 'react';
import { Disposable, Event } from '../common';
/**
@@ -58,7 +58,7 @@ export interface View {
*
* @param view The view descriptor (id and name).
* @param location The location where this view should appear (e.g. "sqllab.panels").
* @param provider A function that returns the React element to render.
* @param component The React component to render at that location.
* @returns A Disposable that unregisters the view when disposed.
*
* @example
@@ -66,14 +66,14 @@ export interface View {
* views.registerView(
* { id: 'my_ext.result_stats', name: 'Result Stats' },
* 'sqllab.panels',
* () => <ResultStatsPanel />,
* ResultStatsPanel,
* );
* ```
*/
export declare function registerView(
view: View,
location: string,
provider: () => ReactElement,
component: ComponentType,
): Disposable;
/**

View File

@@ -132,6 +132,26 @@ export const advancedAnalyticsControls: ControlPanelSectionConfig = {
},
},
],
[
{
name: 'time_compare_full_range',
config: {
type: 'CheckboxControl',
label: t('Show full range for time shift'),
default: false,
description: t(
'Plot each time-shifted series across its full time range instead ' +
'of truncating it to the main series. Useful for comparing a ' +
'partial current period (e.g. today so far) against complete ' +
'prior periods (e.g. all of yesterday).',
),
visibility: ({ controls }) =>
Boolean(controls?.time_compare?.value) &&
(!Array.isArray(controls?.time_compare?.value) ||
controls.time_compare.value.length > 0),
},
},
],
[
{
name: 'comparison_type',

View File

@@ -318,14 +318,25 @@ function createAdvancedAnalyticsSection(
): ControlPanelSectionConfig {
const aaWithSuffix = cloneDeep(sections.advancedAnalyticsControls);
aaWithSuffix.label = label;
// `time_compare_full_range` is only wired into the regular timeseries query
// builder, not the mixed-timeseries one, so drop it here to avoid showing a
// control that has no effect.
aaWithSuffix.controlSetRows = aaWithSuffix.controlSetRows
.map(row =>
row.filter(
control =>
(control as CustomControlItem)?.name !== 'time_compare_full_range',
),
)
.filter(row => row.length > 0);
if (!controlSuffix) {
return aaWithSuffix;
}
aaWithSuffix.controlSetRows.forEach(row =>
row.forEach((control: CustomControlItem) => {
if (control?.name) {
// eslint-disable-next-line no-param-reassign
control.name = `${control.name}${controlSuffix}`;
row.forEach(control => {
const item = control as CustomControlItem;
if (item?.name) {
item.name = `${item.name}${controlSuffix}`;
}
}),
);

View File

@@ -82,6 +82,11 @@ export default function buildQuery(formData: QueryFormData) {
? formData.time_compare
: [];
// When comparing against prior periods, optionally keep each shifted series at
// its full time range instead of truncating it to the main series' range.
const time_compare_full_range =
time_offsets.length > 0 && Boolean(formData.time_compare_full_range);
return [
{
...baseQueryObject,
@@ -92,6 +97,7 @@ export default function buildQuery(formData: QueryFormData) {
// todo: move `normalizeOrderBy to extractQueryFields`
orderby: normalizeOrderBy(baseQueryObject).orderby,
time_offsets,
time_compare_full_range,
/* Note that:
1. The resample, rolling, cum, timeCompare operators should be after pivot.
2. Resample must come before rolling so that imputed values are

View File

@@ -381,6 +381,15 @@ export default function transformProps(
const array = ensureIsArray(chartProps.rawFormData?.time_compare);
const inverted = invert(verboseMap);
// With the "full range" time-shift option, offset series are outer-joined onto
// the main series, which inserts null rows into the main series wherever the
// comparison period has data the current period lacks. Connect nulls so the
// main line stays continuous (matching the default left-join appearance) rather
// than fragmenting at every inserted gap.
const timeCompareFullRange = Boolean(
chartProps.rawFormData?.time_compare_full_range,
);
const offsetLineWidths: { [key: string]: number } = {};
// For horizontal bar charts, calculate min/max from data to avoid cutting off labels
@@ -478,7 +487,7 @@ export default function transformProps(
colorScaleKey,
{
area,
connectNulls: derivedSeries,
connectNulls: derivedSeries || timeCompareFullRange,
filterState,
seriesContexts,
markerEnabled,

View File

@@ -0,0 +1,277 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { act, render, screen } from 'spec/helpers/testing-library';
import { chat } from 'src/core/chat';
import ChatProvider from './ChatProvider';
import { ChatFloatingHost as ChatHost, ChatPanelHost } from './ChatHost';
beforeEach(() => {
ChatProvider.getInstance().reset();
});
test('renders nothing when no chat extension is registered', () => {
render(<ChatHost />);
expect(screen.queryByTestId('chat-mount')).not.toBeInTheDocument();
});
test('renders the trigger bubble of the registered chat', () => {
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <button type="button">Acme Bubble</button>,
() => <div>Acme Panel</div>,
);
render(<ChatHost />);
expect(screen.getByTestId('chat-mount')).toBeInTheDocument();
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
// The panel stays unmounted until the chat is opened.
expect(screen.queryByText('Acme Panel')).not.toBeInTheDocument();
});
test('mounts the panel when the chat opens and unmounts it on close', () => {
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <button type="button">Acme Bubble</button>,
() => <div>Acme Panel</div>,
);
render(<ChatHost />);
act(() => chat.open());
expect(screen.getByText('Acme Panel')).toBeInTheDocument();
// In floating mode the trigger stays mounted alongside the open panel.
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
act(() => chat.close());
expect(screen.queryByText('Acme Panel')).not.toBeInTheDocument();
});
test('renders the last-registered chat when several are installed', () => {
jest.spyOn(console, 'warn').mockImplementation(() => {});
chat.registerChat(
{ id: 'first.chat', name: 'First Chat' },
() => <div>First Bubble</div>,
() => <div>First Panel</div>,
);
chat.registerChat(
{ id: 'second.chat', name: 'Second Chat' },
() => <div>Second Bubble</div>,
() => <div>Second Panel</div>,
);
jest.restoreAllMocks();
render(<ChatHost />);
// Last-loaded wins: the second registration takes over the singleton slot.
expect(screen.getByText('Second Bubble')).toBeInTheDocument();
expect(screen.queryByText('First Bubble')).not.toBeInTheDocument();
});
test('reacts to a chat registering after the initial render', () => {
render(<ChatHost />);
expect(screen.queryByTestId('chat-mount')).not.toBeInTheDocument();
act(() => {
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <button type="button">Acme Bubble</button>,
() => <div>Acme Panel</div>,
);
});
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
});
test('a takeover mounts the incoming chat closed', () => {
chat.registerChat(
{ id: 'first.chat', name: 'First Chat' },
() => <div>First Bubble</div>,
() => <div>First Panel</div>,
);
render(<ChatHost />);
act(() => chat.open());
expect(screen.getByText('First Panel')).toBeInTheDocument();
act(() => {
jest.spyOn(console, 'warn').mockImplementation(() => {});
chat.registerChat(
{ id: 'second.chat', name: 'Second Chat' },
() => <div>Second Bubble</div>,
() => <div>Second Panel</div>,
);
jest.restoreAllMocks();
});
// The displaced chat's open state must not leak into the winner.
expect(screen.getByText('Second Bubble')).toBeInTheDocument();
expect(screen.queryByText('Second Panel')).not.toBeInTheDocument();
expect(screen.queryByText('First Panel')).not.toBeInTheDocument();
});
test('ChatPanelHost renders the panel when open in panel mode', () => {
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <button type="button">Acme Bubble</button>,
() => <div>Acme Panel</div>,
);
render(<ChatPanelHost />);
act(() => {
chat.setDisplayMode('panel');
chat.open();
});
expect(screen.getByText('Acme Panel')).toBeInTheDocument();
});
test('ChatFloatingHost suppresses the floating panel in panel mode but keeps the trigger', () => {
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <button type="button">Acme Bubble</button>,
() => <div>Acme Panel</div>,
);
render(<ChatHost />);
act(() => {
chat.setDisplayMode('panel');
chat.open();
});
// In panel mode the floating panel is suppressed (ChatPanelHost owns that slot).
expect(screen.queryByText('Acme Panel')).not.toBeInTheDocument();
// The trigger stays rendered so the user can reopen after collapsing.
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
act(() => chat.close());
// Trigger remains visible even when closed — it's the user's only way back.
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
});
test('a crashing panel does not take the trigger down with it', () => {
const FailingPanel = () => {
throw new Error('panel blew up');
};
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <button type="button">Acme Bubble</button>,
() => <FailingPanel />,
);
render(<ChatHost />);
act(() => chat.open());
// The panel's boundary contains the crash; the trigger keeps rendering so
// the user is not stranded without a way back.
expect(screen.queryByText('panel blew up')).not.toBeInTheDocument();
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
});
test('isolates a failing trigger so it does not crash the host', () => {
const FailingTrigger = () => {
throw new Error('chat blew up');
};
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <FailingTrigger />,
() => <div>Acme Panel</div>,
);
// The host-owned error boundary catches the failure; render does not throw.
expect(() => render(<ChatHost />)).not.toThrow();
// The mount slot still renders (the boundary lives inside it), confirming
// the provider was actually exercised and contained.
expect(screen.getByTestId('chat-mount')).toBeInTheDocument();
});
test('isolates a component that throws during render', () => {
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => {
throw new Error('provider blew up');
},
() => <div>Acme Panel</div>,
);
expect(() => render(<ChatHost />)).not.toThrow();
expect(screen.getByTestId('chat-mount')).toBeInTheDocument();
});
test('recovers from a crashed chat when a different chat takes over', () => {
const FailingTrigger = () => {
throw new Error('first chat blew up');
};
chat.registerChat(
{ id: 'first.chat', name: 'First Chat' },
() => <FailingTrigger />,
() => <div>First Panel</div>,
);
render(<ChatHost />);
expect(screen.queryByText('Second Bubble')).not.toBeInTheDocument();
act(() => {
jest.spyOn(console, 'warn').mockImplementation(() => {});
chat.registerChat(
{ id: 'second.chat', name: 'Second Chat' },
() => <div>Second Bubble</div>,
() => <div>Second Panel</div>,
);
jest.restoreAllMocks();
});
// The boundary is keyed per registration, so the latched crash from the
// first chat does not blank the second one.
expect(screen.getByText('Second Bubble')).toBeInTheDocument();
});
test('recovers from a crashed chat when a different id takes over', () => {
const FailingTrigger = () => {
throw new Error('broken release');
};
chat.registerChat(
{ id: 'acme.chat', name: 'Acme Chat' },
() => <FailingTrigger />,
() => <div>Acme Panel</div>,
);
render(<ChatHost />);
act(() => {
jest.spyOn(console, 'warn').mockImplementation(() => {});
chat.registerChat(
{ id: 'fixed.chat', name: 'Fixed Chat' },
() => <div>Fixed Bubble</div>,
() => <div>Fixed Panel</div>,
);
jest.restoreAllMocks();
});
// Different id: boundary key changes, latch resets, fix renders.
expect(screen.getByText('Fixed Bubble')).toBeInTheDocument();
});

View File

@@ -0,0 +1,133 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { type ComponentType, useRef } from 'react';
import { t } from '@apache-superset/core/translation';
import { logging } from '@apache-superset/core/utils';
import { css, useTheme } from '@apache-superset/core/theme';
import { ErrorBoundary } from 'src/components/ErrorBoundary';
import { addDangerToast } from 'src/components/MessageToasts/actions';
import { store } from 'src/views/store';
import { useChat } from '.';
const CHAT_EDGE_MARGIN = 24;
/**
* Returns an onError handler that shows a toast on crash, once per chat id.
*/
function useCrashNotifier(chatId: string | undefined) {
const notifiedFor = useRef<string | undefined>(undefined);
return (error: Error) => {
if (!chatId) return;
logging.error('[chat] provider crashed', error);
if (notifiedFor.current !== chatId) {
notifiedFor.current = chatId;
store.dispatch(addDangerToast(t('The chat failed to load.')));
}
};
}
/**
* Wraps a component in an ErrorBoundary, keyed by chat id so the boundary
* resets when a different chat takes over.
*/
const ChatBoundary = ({
component: Component,
onError,
}: {
component: ComponentType;
onError: (error: Error) => void;
}) => (
<ErrorBoundary showMessage={false} onError={onError}>
<Component />
</ErrorBoundary>
);
/**
* Renders the chat panel content in panel mode. Fills its container height.
*/
export const ChatPanelHost = () => {
const { chat, panel } = useChat();
const onError = useCrashNotifier(chat?.id);
if (!chat || !panel) {
return null;
}
return (
<div
data-test="chat-mount"
css={css`
display: flex;
flex-direction: column;
height: 100%;
`}
>
<ChatBoundary key={chat.id} component={panel} onError={onError} />
</div>
);
};
/**
* Renders the chat trigger and, when the panel is open in floating mode, the
* floating panel overlay. The trigger is always visible when a chat is
* registered; the panel overlay is suppressed in panel mode.
*/
export const ChatFloatingHost = () => {
const theme = useTheme();
const { open: panelOpen, mode, chat, trigger, panel } = useChat();
const onError = useCrashNotifier(chat?.id);
if (!chat || !trigger || !panel) {
return null;
}
return (
<div
data-test="chat-mount"
css={css`
position: fixed;
right: ${CHAT_EDGE_MARGIN}px;
bottom: ${CHAT_EDGE_MARGIN}px;
display: flex;
flex-direction: column;
align-items: flex-end;
gap: ${theme.sizeUnit * 2}px;
/* Above dashboard content and the toast layer, below modal dialogs. */
z-index: ${theme.zIndexPopupBase + 2};
`}
>
{/*
Separate boundaries so a crashing panel cannot take the trigger down
with it — the trigger is the user's only way back.
*/}
{panelOpen && mode !== 'panel' && (
<ChatBoundary
key={`panel-${chat.id}`}
component={panel}
onError={onError}
/>
)}
<ChatBoundary
key={`trigger-${chat.id}`}
component={trigger}
onError={onError}
/>
</div>
);
};

View File

@@ -0,0 +1,257 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { createElement } from 'react';
import ChatProvider from './ChatProvider';
const trigger = () => createElement('button', null, 'Bubble');
const panel = () => createElement('div', null, 'Panel');
beforeEach(() => {
ChatProvider.getInstance().reset();
});
test('returns the singleton instance', () => {
expect(ChatProvider.getInstance()).toBe(ChatProvider.getInstance());
});
test('getChat returns undefined when no chat is registered', () => {
expect(ChatProvider.getInstance().getChat()).toBeUndefined();
});
test('registerChat sets the registration and returns the descriptor copy', () => {
const provider = ChatProvider.getInstance();
const descriptor = { id: 'acme.chat', name: 'Acme Chat' };
const disposable = provider.registerChat(descriptor, trigger, panel);
expect(provider.getChat()).toEqual(descriptor);
disposable.dispose();
});
test('the last-registered chat wins and logs a warning', () => {
const provider = ChatProvider.getInstance();
const warn = jest.spyOn(console, 'warn').mockImplementation(() => {});
provider.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
provider.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
expect(provider.getChat()?.id).toBe('second.chat');
expect(warn).toHaveBeenCalledTimes(1);
expect(warn.mock.calls[0][0]).toContain('second.chat');
expect(warn.mock.calls[0][0]).toContain('first.chat');
warn.mockRestore();
});
test('re-registering with a different id replaces the active chat', () => {
const provider = ChatProvider.getInstance();
jest.spyOn(console, 'warn').mockImplementation(() => {});
provider.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
expect(provider.getChat()?.id).toBe('first.chat');
provider.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
expect(provider.getChat()?.id).toBe('second.chat');
jest.restoreAllMocks();
});
test('disposing the registration clears it', () => {
const provider = ChatProvider.getInstance();
const disposable = provider.registerChat(
{ id: 'acme.chat', name: 'Acme' },
trigger,
panel,
);
disposable.dispose();
expect(provider.getChat()).toBeUndefined();
});
test('disposing twice fires unregister only once', () => {
const provider = ChatProvider.getInstance();
const unregistered = jest.fn();
provider.onDidUnregisterChat(unregistered);
const disposable = provider.registerChat(
{ id: 'acme.chat', name: 'Acme' },
trigger,
panel,
);
disposable.dispose();
disposable.dispose();
expect(unregistered).toHaveBeenCalledTimes(1);
});
test('onDidRegisterChat and onDidUnregisterChat fire with the descriptor', () => {
const provider = ChatProvider.getInstance();
const registered = jest.fn();
const unregistered = jest.fn();
provider.onDidRegisterChat(registered);
provider.onDidUnregisterChat(unregistered);
const descriptor = { id: 'acme.chat', name: 'Acme' };
const disposable = provider.registerChat(descriptor, trigger, panel);
expect(registered).toHaveBeenCalledWith(descriptor);
expect(unregistered).not.toHaveBeenCalled();
disposable.dispose();
expect(unregistered).toHaveBeenCalledWith(descriptor);
});
test('open and close toggle the panel state', () => {
const provider = ChatProvider.getInstance();
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
expect(provider.isOpen()).toBe(false);
provider.open();
expect(provider.isOpen()).toBe(true);
provider.close();
expect(provider.isOpen()).toBe(false);
});
test('open fires once; duplicate open is a no-op', () => {
const provider = ChatProvider.getInstance();
const opened = jest.fn();
provider.onDidOpen(opened);
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
provider.open();
provider.open();
expect(opened).toHaveBeenCalledTimes(1);
});
test('close fires once; duplicate close is a no-op', () => {
const provider = ChatProvider.getInstance();
const closed = jest.fn();
provider.onDidClose(closed);
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
provider.open();
provider.close();
provider.close();
expect(closed).toHaveBeenCalledTimes(1);
});
test('open is a no-op when no chat is registered', () => {
const provider = ChatProvider.getInstance();
const opened = jest.fn();
provider.onDidOpen(opened);
provider.open();
expect(provider.isOpen()).toBe(false);
expect(opened).not.toHaveBeenCalled();
});
test('registering a second chat while open closes the panel', () => {
const provider = ChatProvider.getInstance();
const closed = jest.fn();
provider.onDidClose(closed);
jest.spyOn(console, 'warn').mockImplementation(() => {});
provider.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
provider.open();
provider.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
expect(provider.isOpen()).toBe(false);
expect(closed).toHaveBeenCalledTimes(1);
jest.restoreAllMocks();
});
test('disposing the active chat while open closes the panel', () => {
const provider = ChatProvider.getInstance();
const closed = jest.fn();
provider.onDidClose(closed);
const disposable = provider.registerChat(
{ id: 'acme.chat', name: 'Acme' },
trigger,
panel,
);
provider.open();
disposable.dispose();
expect(provider.isOpen()).toBe(false);
expect(closed).toHaveBeenCalledTimes(1);
});
test('a late registration does not inherit a stale open state', () => {
const provider = ChatProvider.getInstance();
const disposable = provider.registerChat(
{ id: 'acme.chat', name: 'Acme' },
trigger,
panel,
);
provider.open();
disposable.dispose();
provider.registerChat({ id: 'late.chat', name: 'Late' }, trigger, panel);
expect(provider.isOpen()).toBe(false);
});
test('getDisplayMode defaults to floating', () => {
expect(ChatProvider.getInstance().getDisplayMode()).toBe('floating');
});
test('setDisplayMode updates mode and fires event only on change', () => {
const provider = ChatProvider.getInstance();
const modeChanged = jest.fn();
provider.onDidChangeDisplayMode(modeChanged);
provider.setDisplayMode('floating');
expect(modeChanged).not.toHaveBeenCalled();
provider.setDisplayMode('panel');
expect(provider.getDisplayMode()).toBe('panel');
expect(modeChanged).toHaveBeenCalledWith('panel');
});
test('state reflects changes after registration and open', () => {
const provider = ChatProvider.getInstance();
expect(provider.getChat()).toBeUndefined();
expect(provider.isOpen()).toBe(false);
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
provider.open();
expect(provider.isOpen()).toBe(true);
expect(provider.getChat()?.id).toBe('acme.chat');
});
test('reset clears all state', () => {
const provider = ChatProvider.getInstance();
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
provider.open();
provider.setDisplayMode('panel');
provider.reset();
expect(provider.getChat()).toBeUndefined();
expect(provider.isOpen()).toBe(false);
expect(provider.getDisplayMode()).toBe('floating');
});

View File

@@ -0,0 +1,209 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { ComponentType } from 'react';
import type { chat as chatApi } from '@apache-superset/core';
import {
LocalStorageKeys,
getItem,
setItem,
} from 'src/utils/localStorageHelpers';
import { Disposable } from '../models';
import { createValueEventEmitter, createEventEmitter } from '../utils';
type Chat = chatApi.Chat;
type DisplayMode = chatApi.DisplayMode;
/**
* Singleton manager for the chat provider.
* Handles registration, open/close state, and display mode.
*/
class ChatProvider {
private static instance: ChatProvider;
private chat: Chat | undefined;
private trigger: ComponentType | undefined;
private panel: ComponentType | undefined;
private opened: boolean;
private stateSubscribers = new Set<() => void>();
private registerEmitter = createEventEmitter<Chat>();
private unregisterEmitter = createEventEmitter<Chat>();
private openEmitter = createEventEmitter<void>();
private closeEmitter = createEventEmitter<void>();
private resizePanelEmitter = createEventEmitter<{ width: number }>();
private modeEmitter: ReturnType<typeof createValueEventEmitter<DisplayMode>>;
private constructor() {
const persisted = getItem(LocalStorageKeys.ChatState, {
open: false,
mode: 'floating',
});
const mode = (
persisted.mode === 'panel' ? 'panel' : 'floating'
) as DisplayMode;
this.opened = persisted.open === true;
this.modeEmitter = createValueEventEmitter<DisplayMode>(mode);
}
public static getInstance(): ChatProvider {
if (!ChatProvider.instance) {
ChatProvider.instance = new ChatProvider();
}
return ChatProvider.instance;
}
public subscribe = (listener: () => void): (() => void) => {
this.stateSubscribers.add(listener);
return () => this.stateSubscribers.delete(listener);
};
private notifyState(): void {
setItem(LocalStorageKeys.ChatState, {
open: this.opened,
mode: this.modeEmitter.getCurrent(),
});
this.stateSubscribers.forEach(fn => fn());
}
private closePanel(): void {
this.opened = false;
this.closeEmitter.fire();
}
public registerChat(
chat: Chat,
trigger: ComponentType,
panel: ComponentType,
): Disposable {
if (this.chat) {
// eslint-disable-next-line no-console
console.warn(
`[Superset] Multiple chat extensions registered. Using "${chat.id}"; discarding "${this.chat.id}".`,
);
this.unregisterEmitter.fire(this.chat);
if (this.opened) this.closePanel();
}
this.chat = chat;
this.trigger = trigger;
this.panel = panel;
this.registerEmitter.fire(chat);
this.notifyState();
return new Disposable(() => {
if (this.chat !== chat) return;
this.chat = undefined;
this.trigger = undefined;
this.panel = undefined;
this.unregisterEmitter.fire(chat);
if (this.opened) this.closePanel();
this.notifyState();
});
}
public getChat(): Chat | undefined {
return this.chat;
}
public getTrigger(): ComponentType | undefined {
return this.trigger;
}
public getPanel(): ComponentType | undefined {
return this.panel;
}
public open(): void {
if (this.opened || !this.chat) return;
this.opened = true;
this.openEmitter.fire();
this.notifyState();
}
public close(): void {
if (!this.opened || !this.chat) return;
this.closePanel();
this.notifyState();
}
public isOpen(): boolean {
return this.opened;
}
public getDisplayMode(): DisplayMode {
return this.modeEmitter.getCurrent();
}
public setDisplayMode(displayMode: DisplayMode): void {
if (displayMode === this.modeEmitter.getCurrent()) return;
this.modeEmitter.fire(displayMode);
this.notifyState();
}
public get onDidRegisterChat() {
return this.registerEmitter.subscribe;
}
public get onDidUnregisterChat() {
return this.unregisterEmitter.subscribe;
}
public get onDidOpen() {
return this.openEmitter.subscribe;
}
public get onDidClose() {
return this.closeEmitter.subscribe;
}
public get onDidChangeDisplayMode() {
return this.modeEmitter.subscribe;
}
public get onDidResizePanel() {
return this.resizePanelEmitter.subscribe;
}
public reset(): void {
this.chat = undefined;
this.trigger = undefined;
this.panel = undefined;
this.opened = false;
this.registerEmitter = createEventEmitter<Chat>();
this.unregisterEmitter = createEventEmitter<Chat>();
this.openEmitter = createEventEmitter<void>();
this.closeEmitter = createEventEmitter<void>();
this.resizePanelEmitter = createEventEmitter<{ width: number }>();
this.modeEmitter = createValueEventEmitter<DisplayMode>('floating');
this.stateSubscribers.clear();
setItem(LocalStorageKeys.ChatState, { open: false, mode: 'floating' });
}
}
export default ChatProvider;

View File

@@ -0,0 +1,68 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { createElement } from 'react';
import { chat } from './index';
import ChatProvider from './ChatProvider';
const trigger = () => createElement('button', null, 'Bubble');
const panel = () => createElement('div', null, 'Panel');
beforeEach(() => {
ChatProvider.getInstance().reset();
});
test('getChat returns undefined when no chat is registered', () => {
expect(chat.getChat()).toBeUndefined();
});
test('registerChat makes the chat retrievable via getChat', () => {
const descriptor = { id: 'acme.chat', name: 'Acme Chat' };
chat.registerChat(descriptor, trigger, panel);
expect(chat.getChat()).toEqual(descriptor);
});
test('the last-registered chat wins when multiple are registered', () => {
jest.spyOn(console, 'warn').mockImplementation(() => {});
chat.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
chat.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
expect(chat.getChat()?.id).toBe('second.chat');
jest.restoreAllMocks();
});
test('open and close toggle isOpen', () => {
chat.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
expect(chat.isOpen()).toBe(false);
chat.open();
expect(chat.isOpen()).toBe(true);
chat.close();
expect(chat.isOpen()).toBe(false);
});
test('getDisplayMode defaults to floating', () => {
expect(chat.getDisplayMode()).toBe('floating');
});
test('setDisplayMode updates the display mode', () => {
chat.setDisplayMode('panel');
expect(chat.getDisplayMode()).toBe('panel');
});

View File

@@ -0,0 +1,82 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* @fileoverview Host implementation of the `chat` contribution type.
*
* Extensions register via the public `chat.registerChat()` and the host owns
* mounting, open/close state, and the display mode. Only the last-registered
* chat is active at a time.
*
* The public namespace (`chat`) is exposed to extensions on `window.superset`.
* `useChat` is host-internal and NOT part of the public `@apache-superset/core` API.
*/
import { useSyncExternalStore } from 'react';
import memoizeOne from 'memoize-one';
import type { chat as chatApi } from '@apache-superset/core';
import ChatProvider from './ChatProvider';
export { ChatFloatingHost, ChatPanelHost } from './ChatHost';
const provider = ChatProvider.getInstance();
const buildSnapshot = memoizeOne(
(
open: boolean,
mode: chatApi.DisplayMode,
chat: chatApi.Chat | undefined,
trigger: ReturnType<typeof provider.getTrigger>,
panel: ReturnType<typeof provider.getPanel>,
) => ({ open, mode, chat, trigger, panel }),
);
const getSnapshot = () =>
buildSnapshot(
provider.isOpen(),
provider.getDisplayMode(),
provider.getChat(),
provider.getTrigger(),
provider.getPanel(),
);
/**
* Host-internal hook. Returns the current open/mode state and the active chat
* (trigger, panel, descriptor).
*/
export const useChat = () =>
useSyncExternalStore(provider.subscribe, getSnapshot);
export const chat: typeof chatApi = {
registerChat: provider.registerChat.bind(provider),
getChat: provider.getChat.bind(provider),
onDidRegisterChat: provider.onDidRegisterChat,
onDidUnregisterChat: provider.onDidUnregisterChat,
open: provider.open.bind(provider),
close: provider.close.bind(provider),
isOpen: provider.isOpen.bind(provider),
onDidOpen: provider.onDidOpen,
onDidClose: provider.onDidClose,
getDisplayMode: provider.getDisplayMode.bind(provider),
setDisplayMode: provider.setDisplayMode.bind(provider),
onDidChangeDisplayMode: provider.onDidChangeDisplayMode,
// The host fires this from its panel resizer; until that chrome exists the
// event is exposed but never fires.
onDidResizePanel: provider.onDidResizePanel,
};

View File

@@ -254,33 +254,6 @@ test('event listeners can be disposed', () => {
expect(listener).toHaveBeenCalledTimes(1); // Still only 1 call
});
test('handles errors in event listeners gracefully', () => {
const manager = EditorProviders.getInstance();
const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation();
const errorListener = jest.fn(() => {
throw new Error('Listener error');
});
const successListener = jest.fn();
manager.onDidRegister(errorListener);
manager.onDidRegister(successListener);
manager.registerProvider(createMockEditor(), createMockEditorComponent());
// Both listeners should have been called
expect(errorListener).toHaveBeenCalledTimes(1);
expect(successListener).toHaveBeenCalledTimes(1);
// Error should have been logged
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error in event listener:',
expect.any(Error),
);
consoleErrorSpy.mockRestore();
});
test('reset clears all providers and language mappings', () => {
const manager = EditorProviders.getInstance();

View File

@@ -19,6 +19,7 @@
import type { editors } from '@apache-superset/core';
import { Disposable } from '../models';
import { createEventEmitter } from '../utils';
type EditorLanguage = editors.EditorLanguage;
type EditorProvider = editors.EditorProvider;
@@ -27,45 +28,8 @@ type EditorComponent = editors.EditorComponent;
type EditorRegisteredEvent = editors.EditorRegisteredEvent;
type EditorUnregisteredEvent = editors.EditorUnregisteredEvent;
/**
* Listener function type for events.
*/
type Listener<T> = (e: T) => void;
/**
* Simple event emitter for editor provider lifecycle events.
*/
class EventEmitter<T> {
private listeners: Set<Listener<T>> = new Set();
/**
* Subscribe to this event.
* @param listener The listener function to call when the event is fired.
* @returns A Disposable to unsubscribe from the event.
*/
subscribe(listener: Listener<T>): Disposable {
this.listeners.add(listener);
return new Disposable(() => {
this.listeners.delete(listener);
});
}
/**
* Fire the event with the given data.
* @param data The event data to pass to listeners.
*/
fire(data: T): void {
this.listeners.forEach(listener => {
try {
listener(data);
} catch (error) {
// eslint-disable-next-line no-console
console.error('Error in event listener:', error);
}
});
}
}
/**
* Singleton manager for editor providers.
* Handles registration, resolution, and lifecycle of custom editor implementations.
@@ -83,15 +47,9 @@ class EditorProviders {
*/
private languageToProvider: Map<EditorLanguage, string> = new Map();
/**
* Event emitter for provider registration events.
*/
private registerEmitter = new EventEmitter<EditorRegisteredEvent>();
private registerEmitter = createEventEmitter<EditorRegisteredEvent>();
/**
* Event emitter for provider unregistration events.
*/
private unregisterEmitter = new EventEmitter<EditorUnregisteredEvent>();
private unregisterEmitter = createEventEmitter<EditorUnregisteredEvent>();
private syncListeners: Set<() => void> = new Set();
@@ -226,8 +184,11 @@ class EditorProviders {
* @param listener The listener function.
* @returns A Disposable to unsubscribe.
*/
public onDidRegister(listener: Listener<EditorRegisteredEvent>): Disposable {
return this.registerEmitter.subscribe(listener);
public onDidRegister(
listener: Listener<EditorRegisteredEvent>,
thisArgs?: unknown,
): Disposable {
return this.registerEmitter.subscribe(listener, thisArgs);
}
/**
@@ -237,8 +198,9 @@ class EditorProviders {
*/
public onDidUnregister(
listener: Listener<EditorUnregisteredEvent>,
thisArgs?: unknown,
): Disposable {
return this.unregisterEmitter.subscribe(listener);
return this.unregisterEmitter.subscribe(listener, thisArgs);
}
/**
@@ -248,6 +210,8 @@ class EditorProviders {
this.providers.clear();
this.languageToProvider.clear();
this.syncListeners.clear();
this.registerEmitter = createEventEmitter<EditorRegisteredEvent>();
this.unregisterEmitter = createEventEmitter<EditorUnregisteredEvent>();
}
}

View File

@@ -18,130 +18,39 @@
*/
/**
* @fileoverview Implementation of the editors API for Superset.
* @fileoverview Host implementation of the `editors` contribution type.
*
* This module provides the runtime implementation of the editor registration
* and resolution functions declared in the API types.
* Extensions register via the public `editors.registerEditor()` and the host
* resolves the appropriate provider per language, falling back to the built-in
* AceEditorProvider when no extension is registered.
*
* The public namespace (`editors`) is exposed to extensions on `window.superset`.
* `EditorHost` is the host-internal component for rendering editors and is NOT
* part of the public `@apache-superset/core` API.
*/
import { useSyncExternalStore } from 'react';
import { editors as editorsApi } from '@apache-superset/core';
import { Disposable } from '../models';
import EditorProviders from './EditorProviders';
type EditorLanguage = editorsApi.EditorLanguage;
type Editor = editorsApi.Editor;
type EditorProvider = editorsApi.EditorProvider;
type EditorComponent = editorsApi.EditorComponent;
type EditorRegisteredEvent = editorsApi.EditorRegisteredEvent;
type EditorUnregisteredEvent = editorsApi.EditorUnregisteredEvent;
export type { EditorHostProps } from './EditorHost';
export { default as EditorHost } from './EditorHost';
export { default as AceEditorProvider } from './AceEditorProvider';
/**
* Register an editor provider as a module-level side effect.
* Takes the editor descriptor directly rather than looking it up
* from a manifest by ID.
*
* @param editor The editor descriptor.
* @param component The React component implementing the editor.
* @returns A Disposable to unregister the provider.
*/
export const registerEditor = (
editor: Editor,
component: EditorComponent,
): Disposable => {
const providers = EditorProviders.getInstance();
return providers.registerProvider(editor, component);
};
const provider = EditorProviders.getInstance();
/**
* Get the editor provider for a specific language.
* Returns the extension's editor if registered, otherwise undefined.
*
* @param language The language to get an editor for
* @returns The editor provider or undefined if no extension provides one
*/
export const getEditor = (
language: EditorLanguage,
): EditorProvider | undefined => {
const manager = EditorProviders.getInstance();
return manager.getProvider(language);
};
/**
* Check if an extension has registered an editor for a language.
*
* @param language The language to check
* @returns True if an extension provides an editor for this language
*/
export const hasEditor = (language: EditorLanguage): boolean => {
const manager = EditorProviders.getInstance();
return manager.hasProvider(language);
};
/**
* Get all registered editor providers.
*
* @returns Array of all registered editor providers
*/
export const getAllEditors = (): EditorProvider[] => {
const manager = EditorProviders.getInstance();
return manager.getAllProviders();
};
/**
* Event fired when an editor is registered.
* Subscribe to this event to react when extensions register new editors.
*/
export const onDidRegisterEditor = (
listener: (e: EditorRegisteredEvent) => void,
): Disposable => {
const manager = EditorProviders.getInstance();
return manager.onDidRegister(listener);
};
/**
* Event fired when an editor is unregistered.
* Subscribe to this event to react when extensions unregister editors.
*/
export const onDidUnregisterEditor = (
listener: (e: EditorUnregisteredEvent) => void,
): Disposable => {
const manager = EditorProviders.getInstance();
return manager.onDidUnregister(listener);
};
/**
* Hook that returns the editor provider for a specific language and re-renders when it changes.
*
* @param language The language to get an editor for
* @returns The editor provider or undefined if no extension provides one
*/
export const useEditor = (
language: EditorLanguage,
): EditorProvider | undefined => {
const manager = EditorProviders.getInstance();
return useSyncExternalStore(
manager.subscribe,
() => manager.getProvider(language),
export const useEditor = (language: editorsApi.EditorLanguage) =>
useSyncExternalStore(
provider.subscribe,
() => provider.getProvider(language),
() => undefined,
);
};
/**
* Editors API object for use in the extension system.
*/
export const editors: typeof editorsApi = {
registerEditor,
getEditor,
hasEditor,
getAllEditors,
onDidRegisterEditor,
onDidUnregisterEditor,
registerEditor: provider.registerProvider.bind(provider),
getEditor: provider.getProvider.bind(provider),
hasEditor: provider.hasProvider.bind(provider),
getAllEditors: provider.getAllProviders.bind(provider),
onDidRegisterEditor: provider.onDidRegister.bind(provider),
onDidUnregisterEditor: provider.onDidUnregister.bind(provider),
};
export { EditorProviders };
// Component exports
export { default as EditorHost } from './EditorHost';
export type { EditorHostProps } from './EditorHost';
export { default as AceEditorProvider } from './AceEditorProvider';

View File

@@ -27,11 +27,13 @@ export const core: typeof coreType = {
};
export * from './authentication';
export * from './chat';
export * from './commands';
export * from './editors';
export * from './extensions';
export * from './menus';
export * from './models';
export * from './navigation';
export * from './sqlLab';
export * from './utils';
export * from './views';

View File

@@ -27,6 +27,7 @@
import { useSyncExternalStore } from 'react';
import type { menus as menusApi } from '@apache-superset/core';
import { Disposable } from '../models';
import { createEventEmitter } from '../utils';
type MenuItem = menusApi.MenuItem;
type Menu = menusApi.Menu;
@@ -47,19 +48,19 @@ const subscribe = (listener: () => void) => {
return () => syncListeners.delete(listener);
};
const registerListeners = new Set<(e: MenuItemRegisteredEvent) => void>();
const unregisterListeners = new Set<(e: MenuItemUnregisteredEvent) => void>();
const registerEmitter = createEventEmitter<MenuItemRegisteredEvent>();
const unregisterEmitter = createEventEmitter<MenuItemUnregisteredEvent>();
const menuCache = new Map<string, Menu | undefined>();
const notifyRegister = (event: MenuItemRegisteredEvent) => {
menuCache.clear();
syncListeners.forEach(l => l());
registerListeners.forEach(l => l(event));
registerEmitter.fire(event);
};
const notifyUnregister = (event: MenuItemUnregisteredEvent) => {
menuCache.clear();
syncListeners.forEach(l => l());
unregisterListeners.forEach(l => l(event));
unregisterEmitter.fire(event);
};
const registerMenuItem: typeof menusApi.registerMenuItem = (
@@ -117,16 +118,14 @@ export const useMenu = (location: string): Menu | undefined =>
export const onDidRegisterMenuItem: typeof menusApi.onDidRegisterMenuItem = (
listener: (e: MenuItemRegisteredEvent) => void,
): Disposable => {
registerListeners.add(listener);
return new Disposable(() => registerListeners.delete(listener));
};
thisArgs?: unknown,
): Disposable => registerEmitter.subscribe(listener, thisArgs);
export const onDidUnregisterMenuItem: typeof menusApi.onDidUnregisterMenuItem =
(listener: (e: MenuItemUnregisteredEvent) => void): Disposable => {
unregisterListeners.add(listener);
return new Disposable(() => unregisterListeners.delete(listener));
};
(
listener: (e: MenuItemUnregisteredEvent) => void,
thisArgs?: unknown,
): Disposable => unregisterEmitter.subscribe(listener, thisArgs);
export const menus: typeof menusApi = {
registerMenuItem,

View File

@@ -0,0 +1,124 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// Reset module state between tests so currentPage is re-initialized.
beforeEach(() => {
jest.resetModules();
Object.defineProperty(window, 'location', {
writable: true,
value: { pathname: '/' },
});
});
async function importNavigation() {
const mod = await import('./index');
return mod;
}
test('getPage falls back to "home" for the welcome page and unknown pathnames', async () => {
const { navigation, notifyLocationChanged } = await importNavigation();
// The default pathname ('/') is not enumerated and falls back to home.
expect(navigation.getPage()).toBe('home');
notifyLocationChanged('/superset/welcome/');
expect(navigation.getPage()).toBe('home');
});
test('getPage derives the page from window.location.pathname', async () => {
window.location.pathname = '/superset/dashboard/42/';
const { navigation } = await importNavigation();
expect(navigation.getPage()).toBe('dashboard');
});
test('notifyLocationChanged updates the current page type', async () => {
const { navigation, notifyLocationChanged } = await importNavigation();
notifyLocationChanged('/explore/?form_data={}');
expect(navigation.getPage()).toBe('explore');
});
test('notifyLocationChanged fires listeners on page type change', async () => {
const { navigation, notifyLocationChanged } = await importNavigation();
const listener = jest.fn();
const disposable = navigation.onDidChangePage(listener);
notifyLocationChanged('/superset/dashboard/1/');
expect(listener).toHaveBeenCalledWith('dashboard');
disposable.dispose();
});
test('notifyLocationChanged does not fire listeners when page type is unchanged', async () => {
window.location.pathname = '/superset/dashboard/1/';
const { navigation, notifyLocationChanged } = await importNavigation();
const listener = jest.fn();
navigation.onDidChangePage(listener);
notifyLocationChanged('/superset/dashboard/2/');
expect(listener).not.toHaveBeenCalled();
});
test('onDidChangePage listener is removed after dispose', async () => {
const { navigation, notifyLocationChanged } = await importNavigation();
const listener = jest.fn();
const disposable = navigation.onDidChangePage(listener);
disposable.dispose();
notifyLocationChanged('/superset/dashboard/1/');
expect(listener).not.toHaveBeenCalled();
});
test('sqllab path is matched with and without trailing slash', async () => {
const { notifyLocationChanged, navigation } = await importNavigation();
notifyLocationChanged('/sqllab');
expect(navigation.getPage()).toBe('sqllab');
notifyLocationChanged('/explore/');
notifyLocationChanged('/sqllab/');
expect(navigation.getPage()).toBe('sqllab');
});
test('chart and dashboard list pages get their own page types', async () => {
const { notifyLocationChanged, navigation } = await importNavigation();
notifyLocationChanged('/chart/list/');
expect(navigation.getPage()).toBe('chart_list');
notifyLocationChanged('/dashboard/list/');
expect(navigation.getPage()).toBe('dashboard_list');
});
test('dataset list and single-dataset pages get distinct page types', async () => {
const { notifyLocationChanged, navigation } = await importNavigation();
notifyLocationChanged('/tablemodelview/list/');
expect(navigation.getPage()).toBe('dataset_list');
notifyLocationChanged('/dataset/42');
expect(navigation.getPage()).toBe('dataset');
});
test('sqllab editor, query history, and saved queries get distinct page types', async () => {
const { notifyLocationChanged, navigation } = await importNavigation();
notifyLocationChanged('/sqllab/');
expect(navigation.getPage()).toBe('sqllab');
notifyLocationChanged('/sqllab/history/');
expect(navigation.getPage()).toBe('query_history');
notifyLocationChanged('/savedqueryview/list/');
expect(navigation.getPage()).toBe('saved_queries');
});
test('chart/add resolves to explore, not chart_list', async () => {
const { notifyLocationChanged, navigation } = await importNavigation();
notifyLocationChanged('/chart/add');
expect(navigation.getPage()).toBe('explore');
});

View File

@@ -0,0 +1,94 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* @fileoverview Host-internal implementation of the `navigation` namespace.
*
* Derives the current {@link Page} from the browser location by matching
* against {@link RoutePaths}. Call {@link useNavigationTracker} once in the
* app shell to keep the page in sync with React Router.
*/
import { useEffect, useRef } from 'react';
import { useLocation, matchPath } from 'react-router-dom';
import type { navigation as navigationApi } from '@apache-superset/core';
import { RoutePaths } from '../../views/routePaths';
import { Disposable } from '../models';
import { createValueEventEmitter } from '../utils';
type Page = navigationApi.Page;
/** Maps route path patterns to their corresponding Page type. */
const PAGE_ROUTES: { path: string; page: Page }[] = [
{ path: RoutePaths.DASHBOARD, page: 'dashboard' },
{ path: RoutePaths.DASHBOARD_LIST, page: 'dashboard_list' },
{ path: RoutePaths.QUERY_HISTORY, page: 'query_history' },
{ path: RoutePaths.SAVED_QUERIES, page: 'saved_queries' },
{ path: RoutePaths.SQLLAB, page: 'sqllab' },
{ path: RoutePaths.CHART_ADD, page: 'explore' },
{ path: RoutePaths.CHART_LIST, page: 'chart_list' },
{ path: RoutePaths.EXPLORE, page: 'explore' },
{ path: RoutePaths.EXPLORE_PERMALINK, page: 'explore' },
{ path: RoutePaths.DATASET_LIST, page: 'dataset_list' },
{ path: RoutePaths.DATASET_ADD, page: 'dataset' },
{ path: RoutePaths.DATASET, page: 'dataset' },
];
function derivePage(pathname: string): Page {
for (const { path, page } of PAGE_ROUTES) {
if (matchPath(pathname, { path, exact: false })) return page;
}
return 'home';
}
const pageEmitter = createValueEventEmitter<Page>(
derivePage(window.location.pathname),
);
/** Updates the current page from a pathname. No-op when the page is unchanged. */
export const notifyLocationChanged = (pathname: string): void => {
const next = derivePage(pathname);
if (next === pageEmitter.getCurrent()) return;
pageEmitter.fire(next);
};
const getPage: typeof navigationApi.getPage = () => pageEmitter.getCurrent();
const onDidChangePage: typeof navigationApi.onDidChangePage = (
listener: (page: Page) => void,
thisArgs?: unknown,
): Disposable => pageEmitter.subscribe(listener, thisArgs);
/** Synchronizes the navigation module with React Router. Call once in the app shell. */
export const useNavigationTracker = () => {
const location = useLocation();
const prevPathname = useRef<string | null>(null);
useEffect(() => {
if (prevPathname.current !== location.pathname) {
prevPathname.current = location.pathname;
notifyLocationChanged(location.pathname);
}
}, [location.pathname]);
};
export const navigation: typeof navigationApi = {
getPage,
onDidChangePage,
};

View File

@@ -48,7 +48,7 @@ import { AnyListenerPredicate } from '@reduxjs/toolkit';
import type { QueryEditor, SqlLabRootState } from 'src/SqlLab/types';
import { newQueryTabName } from 'src/SqlLab/utils/newQueryTabName';
import { Database, Disposable } from '../models';
import { createActionListener } from '../utils';
import { createActionListener } from '../storeUtils';
import {
Panel,
Tab,

View File

@@ -0,0 +1,48 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import type { common as core } from '@apache-superset/core';
import { listenerMiddleware, RootState, store } from 'src/views/store';
import { AnyListenerPredicate } from '@reduxjs/toolkit';
export function createActionListener<V, A = unknown>(
predicate: AnyListenerPredicate<RootState>,
listener: (v: V) => void,
valueParser: (action: A, state: RootState) => V | null | undefined,
thisArgs?: unknown,
): core.Disposable {
const boundListener = thisArgs ? listener.bind(thisArgs as object) : listener;
const unsubscribe = listenerMiddleware.startListening({
predicate,
effect: action => {
const state = store.getState();
// The predicate already ensures the action matches type A at runtime.
const value = valueParser(action as unknown as A, state);
if (value != null) {
boundListener(value);
}
},
});
return {
dispose: () => {
unsubscribe();
},
};
}

View File

@@ -17,33 +17,54 @@
* under the License.
*/
import type { common as core } from '@apache-superset/core';
import { AnyAction } from 'redux';
import { listenerMiddleware, RootState, store } from 'src/views/store';
import { AnyListenerPredicate } from '@reduxjs/toolkit';
export function createActionListener<V>(
predicate: AnyListenerPredicate<RootState>,
listener: (v: V) => void,
valueParser: (action: AnyAction, state: RootState) => V | null | undefined,
thisArgs?: any,
): core.Disposable {
const boundListener = thisArgs ? listener.bind(thisArgs) : listener;
type Listener<T> = (e: T) => unknown;
const unsubscribe = listenerMiddleware.startListening({
predicate,
effect: (action: AnyAction) => {
const state = store.getState();
const value = valueParser(action, state);
// Skip calling listener if valueParser returns null/undefined
if (value != null) {
boundListener(value);
}
},
});
/** A stateless event emitter exposing a VS Code-style `event` subscriber. */
export interface EventEmitter<T> {
/** Notifies every current subscriber with `value`. */
fire(value: T): void;
/** Registers a listener; returns a Disposable that removes it. */
subscribe: core.Event<T>;
}
/** An event emitter that also retains the last fired value. */
export interface ValueEventEmitter<T> extends EventEmitter<T> {
/** Returns the value last passed to {@link fire} (or the initial value). */
getCurrent(): T;
}
/**
* Creates a stateless event emitter. Listeners registered via `event` receive
* every subsequent `fire`; a returned Disposable removes the listener.
*/
export function createEventEmitter<T>(): EventEmitter<T> {
const listeners = new Set<Listener<T>>();
const subscribe: core.Event<T> = (listener, thisArgs) => {
const bound = thisArgs ? listener.bind(thisArgs) : listener;
listeners.add(bound);
return { dispose: () => listeners.delete(bound) };
};
return {
dispose: () => {
unsubscribe();
},
fire: value => listeners.forEach(fn => fn(value)),
subscribe,
};
}
/**
* Creates a value event emitter seeded with `initial`. Behaves like
* {@link createEventEmitter} but also tracks the last fired value, readable
* via `getCurrent` — useful for state that is both observed and queried.
*/
export function createValueEventEmitter<T>(initial: T): ValueEventEmitter<T> {
const { fire, subscribe } = createEventEmitter<T>();
let current = initial;
return {
fire: value => {
current = value;
fire(value);
},
subscribe,
getCurrent: () => current,
};
}

View File

@@ -24,11 +24,12 @@
* Extensions register views as side effects at import time.
*/
import React, { ReactElement, useSyncExternalStore } from 'react';
import React, { ComponentType, useSyncExternalStore } from 'react';
import type { views as viewsApi } from '@apache-superset/core';
import { ErrorBoundary } from 'src/components/ErrorBoundary';
import ExtensionPlaceholder from 'src/extensions/ExtensionPlaceholder';
import { Disposable } from '../models';
import { createEventEmitter } from '../utils';
type View = viewsApi.View;
type ViewRegisteredEvent = viewsApi.ViewRegisteredEvent;
@@ -36,7 +37,7 @@ type ViewUnregisteredEvent = viewsApi.ViewUnregisteredEvent;
const viewRegistry: Map<
string,
{ view: View; location: string; provider: () => ReactElement }
{ view: View; location: string; component: ComponentType }
> = new Map();
const locationIndex: Map<string, Set<string>> = new Map();
@@ -47,29 +48,29 @@ const subscribe = (listener: () => void) => {
return () => syncListeners.delete(listener);
};
const registerListeners = new Set<(e: ViewRegisteredEvent) => void>();
const unregisterListeners = new Set<(e: ViewUnregisteredEvent) => void>();
const registerEmitter = createEventEmitter<ViewRegisteredEvent>();
const unregisterEmitter = createEventEmitter<ViewUnregisteredEvent>();
const viewsCache = new Map<string, View[] | undefined>();
const notifyRegister = (event: ViewRegisteredEvent) => {
viewsCache.clear();
syncListeners.forEach(l => l());
registerListeners.forEach(l => l(event));
registerEmitter.fire(event);
};
const notifyUnregister = (event: ViewUnregisteredEvent) => {
viewsCache.clear();
syncListeners.forEach(l => l());
unregisterListeners.forEach(l => l(event));
unregisterEmitter.fire(event);
};
const registerView: typeof viewsApi.registerView = (
view: View,
location: string,
provider: () => ReactElement,
component: ComponentType,
): Disposable => {
const { id } = view;
viewRegistry.set(id, { view, location, provider });
viewRegistry.set(id, { view, location, component });
const ids = locationIndex.get(location) ?? new Set();
ids.add(id);
@@ -83,12 +84,16 @@ const registerView: typeof viewsApi.registerView = (
});
};
export const resolveView = (id: string): ReactElement => {
const provider = viewRegistry.get(id)?.provider;
if (!provider) {
export const resolveView = (id: string): React.ReactElement => {
const entry = viewRegistry.get(id);
if (!entry) {
return React.createElement(ExtensionPlaceholder, { id });
}
return React.createElement(ErrorBoundary, null, provider());
return React.createElement(
ErrorBoundary,
null,
React.createElement(entry.component),
);
};
const getViews: typeof viewsApi.getViews = (
@@ -116,17 +121,11 @@ export const useViews = (location: string): View[] | undefined =>
export const onDidRegisterView: typeof viewsApi.onDidRegisterView = (
listener: (e: ViewRegisteredEvent) => void,
): Disposable => {
registerListeners.add(listener);
return new Disposable(() => registerListeners.delete(listener));
};
): Disposable => registerEmitter.subscribe(listener);
export const onDidUnregisterView: typeof viewsApi.onDidUnregisterView = (
listener: (e: ViewUnregisteredEvent) => void,
): Disposable => {
unregisterListeners.add(listener);
return new Disposable(() => unregisterListeners.delete(listener));
};
): Disposable => unregisterEmitter.subscribe(listener);
export const views: typeof viewsApi = {
registerView,

View File

@@ -31,7 +31,6 @@ function createMockExtension(overrides: Partial<Extension> = {}): Extension {
version: '1.0.0',
dependencies: [],
remoteEntry: '',
extensionDependencies: [],
...overrides,
};
}

View File

@@ -72,6 +72,7 @@ afterEach(() => {
test('renders without crashing', () => {
render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialState,
});
@@ -88,6 +89,7 @@ test('sets up global superset object when user is logged in', async () => {
render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialState,
});
@@ -95,6 +97,7 @@ test('sets up global superset object when user is logged in', async () => {
// Verify the global superset object is set up
expect((window as any).superset).toBeDefined();
expect((window as any).superset.authentication).toBeDefined();
expect((window as any).superset.chat).toBeDefined();
expect((window as any).superset.core).toBeDefined();
expect((window as any).superset.commands).toBeDefined();
expect((window as any).superset.extensions).toBeDefined();
@@ -109,6 +112,7 @@ test('sets up global superset object when user is logged in', async () => {
test('does not set up global superset object when user is not logged in', async () => {
render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialStateNoUser,
});
@@ -127,6 +131,7 @@ test('initializes ExtensionsLoader when user is logged in', async () => {
render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialState,
});
@@ -144,6 +149,7 @@ test('initializes ExtensionsLoader when user is logged in', async () => {
test('does not initialize ExtensionsLoader when user is not logged in', async () => {
render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialStateNoUser,
});
@@ -169,6 +175,7 @@ test('only initializes once even with multiple renders', async () => {
const { rerender } = render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialState,
});
@@ -205,6 +212,7 @@ test('initializes ExtensionsLoader when EnableExtensions feature flag is enabled
render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialState,
});
@@ -234,6 +242,7 @@ test('does not initialize ExtensionsLoader when EnableExtensions feature flag is
render(<ExtensionsStartup />, {
useRedux: true,
useRouter: true,
initialState: mockInitialState,
});
@@ -268,6 +277,7 @@ test('continues rendering children even when ExtensionsLoader initialization fai
</ExtensionsStartup>,
{
useRedux: true,
useRouter: true,
initialState: mockInitialState,
},
);

View File

@@ -17,41 +17,32 @@
* under the License.
*/
import { useEffect } from 'react';
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
// eslint-disable-next-line no-restricted-syntax
import * as supersetCore from '@apache-superset/core';
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
import {
authentication,
chat,
core,
commands,
editors,
extensions,
menus,
navigation,
useNavigationTracker,
sqlLab,
views,
} from 'src/core';
import { useSelector } from 'react-redux';
import { RootState } from 'src/views/store';
import ExtensionsLoader from './ExtensionsLoader';
declare global {
interface Window {
superset: {
authentication: typeof authentication;
core: typeof core;
commands: typeof commands;
editors: typeof editors;
extensions: typeof extensions;
menus: typeof menus;
sqlLab: typeof sqlLab;
views: typeof views;
};
}
}
import 'src/extensions/Namespaces';
const ExtensionsStartup: React.FC<{ children?: React.ReactNode }> = ({
children,
}) => {
useNavigationTracker();
const userId = useSelector<RootState, number | undefined>(
({ user }) => user.userId,
);
@@ -59,15 +50,19 @@ const ExtensionsStartup: React.FC<{ children?: React.ReactNode }> = ({
useEffect(() => {
if (userId == null) return;
// Provide the implementations for @apache-superset/core
// Provide the implementations for @apache-superset/core.
// Namespaces are listed explicitly — do not spread the core package here,
// as that would leak un-contracted symbols onto window.superset.
window.superset = {
...supersetCore,
authentication,
chat,
core,
commands,
editors,
extensions,
menus,
navigation,
sqlLab,
views,
};

View File

@@ -0,0 +1,60 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Global `window.superset` type augmentation.
*
* Lives in its own module (rather than inline in ExtensionsStartup) so every
* file that reads or writes `window.superset` — notably ExtensionsLoader —
* sees the type regardless of how files are batched during compilation. Both
* the startup component and the loader import this module for its side effect.
*/
import type {
authentication,
chat,
commands,
core,
editors,
extensions,
menus,
navigation,
sqlLab,
views,
} from 'src/core';
/** The host namespaces exposed to extensions on `window.superset`. */
export interface Namespaces {
authentication: typeof authentication;
core: typeof core;
chat: typeof chat;
commands: typeof commands;
editors: typeof editors;
extensions: typeof extensions;
menus: typeof menus;
navigation: typeof navigation;
sqlLab: typeof sqlLab;
views: typeof views;
}
declare global {
interface Window {
superset: Namespaces;
}
}

View File

@@ -77,18 +77,22 @@ const databaseFixture: DatabaseObject = {
// eslint-disable-next-line no-restricted-globals -- TODO: Migrate from describe blocks
describe('DatabaseModal', () => {
beforeEach(() => {
fetchMock.post(DATABASE_CONNECT_ENDPOINT, {
id: 10,
result: {
configuration_method: 'sqlalchemy_form',
database_name: 'Other2',
driver: 'apsw',
expose_in_sqllab: true,
extra: '{"allows_virtual_table_explore":true}',
sqlalchemy_uri: 'gsheets://',
fetchMock.post(
DATABASE_CONNECT_ENDPOINT,
{
id: 10,
result: {
configuration_method: 'sqlalchemy_form',
database_name: 'Other2',
driver: 'apsw',
expose_in_sqllab: true,
extra: '{"allows_virtual_table_explore":true}',
sqlalchemy_uri: 'gsheets://',
},
json: 'foo',
},
json: 'foo',
});
{ name: 'database-connect' },
);
fetchMock.get(DATABASE_FETCH_ENDPOINT, {
result: {
@@ -311,9 +315,13 @@ describe('DatabaseModal', () => {
},
],
});
fetchMock.post(VALIDATE_PARAMS_ENDPOINT, {
message: 'OK',
});
fetchMock.post(
VALIDATE_PARAMS_ENDPOINT,
{
message: 'OK',
},
{ name: 'validate-params' },
);
});
beforeEach(() => {
@@ -1367,14 +1375,16 @@ describe('DatabaseModal', () => {
const textboxes = screen.getAllByRole('textbox');
const hostField = textboxes[0];
const portField = screen.getByRole('spinbutton');
const databaseNameField = textboxes[1];
// textboxes[1] is the connection `database` field; the engine display
// name (`database_name`) auto-fills and is asserted separately.
const databaseField = textboxes[1];
const usernameField = textboxes[2];
const passwordField = textboxes[3];
const connectButton = screen.getByRole('button', { name: 'Connect' });
expect(hostField).toHaveValue('');
expect(portField).toHaveValue(null);
expect(databaseNameField).toHaveValue('');
expect(databaseField).toHaveValue('');
expect(usernameField).toHaveValue('');
expect(passwordField).toHaveValue('');
@@ -1382,7 +1392,7 @@ describe('DatabaseModal', () => {
userEvent.type(hostField, 'localhost');
userEvent.type(portField, '5432');
userEvent.type(databaseNameField, 'postgres');
userEvent.type(databaseField, 'postgres');
userEvent.type(usernameField, 'testdb');
userEvent.type(passwordField, 'demoPassword');
@@ -1391,7 +1401,7 @@ describe('DatabaseModal', () => {
expect(await screen.findByDisplayValue(/5432/i)).toBeInTheDocument();
expect(hostField).toHaveValue('localhost');
expect(portField).toHaveValue(5432);
expect(databaseNameField).toHaveValue('postgres');
expect(databaseField).toHaveValue('postgres');
expect(usernameField).toHaveValue('testdb');
expect(passwordField).toHaveValue('demoPassword');
@@ -1581,6 +1591,135 @@ describe('DatabaseModal', () => {
).toBeInTheDocument();
});
});
// Tests migrated from the deprecated Cypress suite
// (cypress-base/cypress/e2e/database/modal.test.ts). The two "error alert"
// cases originally relied on a real backend connection attempt (real DNS /
// socket behaviour), which made them flaky. They are reproduced here by
// mocking the validate_parameters response: the frontend's responsibility is
// to map an `extra.invalid` field error onto the matching form field, which
// is exactly what these assertions exercise. Whether a bad host/port really
// fails to connect is a backend concern, covered by backend tests.
const selectPostgres = async () => {
userEvent.click(await screen.findByRole('button', { name: /postgresql/i }));
// Dynamic form (step 2 of 3) is now visible
expect(await screen.findByText(/step 2 of 3/i)).toBeInTheDocument();
};
// The modal renders into a portal on document.body, so fields are queried
// from the document rather than the render container.
const fieldByName = (name: string) =>
document.querySelector(`input[name="${name}"]`) as HTMLInputElement;
const fillDynamicForm = () => {
const values: Record<string, string> = {
host: 'badhost',
port: '5432',
database: 'testdb',
username: 'testusername',
password: 'testpass',
};
Object.entries(values).forEach(([name, value]) =>
userEvent.type(fieldByName(name), value),
);
};
test('defaults the display name to the selected engine', async () => {
setup();
await selectPostgres();
// The display name field auto-fills with the selected engine's name. The
// empty initial state of the connection fields (host/port/database/…) is
// already covered by the "enters form credentials" test above.
expect(fieldByName('database_name')).toHaveValue('PostgreSQL');
});
test('switches to the SQLAlchemy URI form via the connect link', async () => {
setup();
await selectPostgres();
userEvent.click(screen.getByTestId('sqla-connect-btn'));
expect(await screen.findByTestId('database-name-input')).toBeVisible();
expect(screen.getByTestId('sqlalchemy-uri-input')).toBeVisible();
});
test.each([
{
field: 'host',
errorType: 'CONNECTION_INVALID_HOSTNAME_ERROR',
message: "The hostname provided can't be resolved.",
},
{
field: 'port',
errorType: 'CONNECTION_PORT_CLOSED_ERROR',
message: 'The port is closed.',
},
])(
'surfaces a $field validation error returned by validate_parameters',
async ({ field, errorType, message }) => {
const createResource = jest.fn();
const useSingleViewResourceMock = jest
.spyOn(hooks, 'useSingleViewResource')
.mockReturnValue({
state: {
loading: false,
resource: null,
error: null,
},
fetchResource: jest.fn(),
createResource,
updateResource: jest.fn(),
clearError: jest.fn(),
setResource: jest.fn(),
});
setup();
await selectPostgres();
fillDynamicForm();
const submitButton = screen.getByTestId('btn-submit-connection');
// Blur the last field and let the (default, passing) validation settle
// so its async onBlur result can't race with — and overwrite — the
// submit-time validation below. Mirrors the Cypress `body.click(0, 0)`.
userEvent.click(document.body);
await waitFor(() => expect(submitButton).toBeEnabled());
// Make validation fail the way a real backend would for an unreachable
// host / closed port. The frontend's job is to map `extra.invalid:
// [field]` onto the matching form field. The 422 short-circuits the
// submit before the create (database-connect) request fires, so only
// this validate_parameters mock drives the rendered error.
fetchMock.modifyRoute('validate-params', {
response: {
status: 422,
body: {
errors: [
{
message,
error_type: errorType,
level: 'error',
extra: { invalid: [field] },
},
],
},
},
});
userEvent.click(submitButton);
// Wait for the async error to render, then confirm it surfaced as an
// antd inline field error (not a general alert)...
const errorText = await screen.findByText(message);
expect(errorText.closest('.ant-form-item-explain-error')).not.toBeNull();
// ...and that it belongs to the form item for the field named in
// `extra.invalid` — i.e. that input lives in the same `.ant-form-item`.
const formItem = errorText.closest('.ant-form-item') as HTMLElement;
expect(formItem.querySelector(`input[name="${field}"]`)).not.toBeNull();
expect(createResource).not.toHaveBeenCalled();
useSingleViewResourceMock.mockRestore();
},
);
});
test('handleChangeWithValidation function clears validation errors when called', () => {

View File

@@ -940,7 +940,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
}
const errors = await getValidation(dbToUpdate, true);
if (!isEmpty(validationErrors) || errors?.length) {
if (!isEmpty(validationErrors) || !isEmpty(errors)) {
addDangerToast(
t('Connection failed, please check your connection settings.'),
);

View File

@@ -57,6 +57,7 @@ export enum LocalStorageKeys {
DashboardExploreContext = 'dashboard__explore_context',
DashboardEditorShowOnlyMyCharts = 'dashboard__editor_show_only_my_charts',
CommonResizableSidebarWidths = 'common__resizable_sidebar_widths',
ChatState = 'chat__state',
}
export type LocalStorageValues = {
@@ -78,6 +79,7 @@ export type LocalStorageValues = {
dashboard__explore_context: Record<string, DashboardContextForExplore>;
dashboard__editor_show_only_my_charts: boolean;
common__resizable_sidebar_widths: Record<string, number>;
chat__state: { open: boolean; mode: string };
};
/*

View File

@@ -25,8 +25,8 @@ import {
useLocation,
} from 'react-router-dom';
import { bindActionCreators } from 'redux';
import { css } from '@apache-superset/core/theme';
import { Layout, Loading } from '@superset-ui/core/components';
import { css, useTheme } from '@apache-superset/core/theme';
import { Flex, Layout, Loading } from '@superset-ui/core/components';
import { setupAGGridModules } from '@superset-ui/core/components/ThemedAgGridReact';
import { ErrorBoundary } from 'src/components';
import Menu from 'src/features/home/Menu';
@@ -39,7 +39,12 @@ import { Logger, LOG_ACTIONS_SPA_NAVIGATION } from 'src/logger/LogUtils';
import setupCodeOverrides from 'src/setup/setupCodeOverrides';
import { logEvent } from 'src/logger/actions';
import { store } from 'src/views/store';
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
import { isUser } from 'src/types/bootstrapTypes';
import ExtensionsStartup from 'src/extensions/ExtensionsStartup';
import { Splitter } from 'src/components/Splitter';
import { ChatFloatingHost, ChatPanelHost, useChat } from 'src/core/chat';
import useStoredSidebarWidth from 'src/components/ResizableSidebar/useStoredSidebarWidth';
import { RootContextProviders } from './RootContextProviders';
import { ScrollToTop } from './ScrollToTop';
@@ -79,42 +84,139 @@ const LocationPathnameLogger = () => {
return <></>;
};
const CHAT_PANEL_DEFAULT_WIDTH = 400;
const CHAT_PANEL_MIN_WIDTH = 280;
const RouteSwitch = () => {
const theme = useTheme();
return (
<Switch>
{routes.map(({ path, Component, props = {}, Fallback = Loading }) => (
<Route path={path} key={path}>
<Suspense fallback={<Fallback />}>
<ErrorBoundary
css={css`
margin: ${theme.sizeUnit * 4}px;
`}
>
<Component user={bootstrapData.user} {...props} />
</ErrorBoundary>
</Suspense>
</Route>
))}
<Redirect from="/" to="/superset/welcome/" exact />
</Switch>
);
};
const layoutCss = css`
flex: 1;
min-height: 0;
overflow: hidden;
`;
const contentCss = css`
display: flex;
flex-direction: column;
min-height: 0;
overflow-y: auto;
position: relative;
`;
/**
* Renders the main content area. When the chat panel is open in panel mode,
* wraps <Layout> and <ChatPanelContent> in a Splitter so they sit side-by-side
* with a lazy drag bar (blue preview line, resize committed on mouseup).
* The full <Layout> tree lives inside the first panel so its internal flex
* context is preserved — SQL Lab, Explore, and other pages are unaffected.
*/
const AppContent = () => {
const isAuthenticated =
isUser(bootstrapData.user) && !bootstrapData.user.isAnonymous;
const chatExtensionsEnabled =
isFeatureEnabled(FeatureFlag.EnableExtensions) && isAuthenticated;
const { open: panelOpen, mode, chat } = useChat();
const hasChatExtension = chatExtensionsEnabled && !!chat;
const isPanelOpen = hasChatExtension && mode === 'panel' && panelOpen;
const [storedWidth, setStoredWidth] = useStoredSidebarWidth(
'chat:panel',
CHAT_PANEL_DEFAULT_WIDTH,
);
const layoutContent = (
<Layout css={layoutCss}>
<Layout.Content css={contentCss}>
<RouteSwitch />
</Layout.Content>
</Layout>
);
if (!isPanelOpen) {
return (
<>
{layoutContent}
{hasChatExtension && <ChatFloatingHost />}
</>
);
}
return (
<Splitter
lazy
onResizeEnd={sizes => {
const chatWidth = sizes[sizes.length - 1];
if (
typeof chatWidth === 'number' &&
chatWidth >= CHAT_PANEL_MIN_WIDTH
) {
setStoredWidth(chatWidth);
}
}}
css={css`
flex: 1;
min-height: 0;
overflow: hidden;
/*
* Splitter.Panel is not a flex container by default, so flex:1 on
* children (Layout, ChatPanelHost) has no height effect and
* panels collapse. Making them flex columns lets children fill them.
*/
& > .ant-splitter-panel {
display: flex !important;
flex-direction: column;
}
`}
>
<Splitter.Panel>{layoutContent}</Splitter.Panel>
<Splitter.Panel size={storedWidth} min={CHAT_PANEL_MIN_WIDTH}>
<ChatPanelHost />
</Splitter.Panel>
</Splitter>
);
};
const App = () => (
<Router basename={applicationRoot()}>
<ScrollToTop />
<LocationPathnameLogger />
<RootContextProviders>
<Menu
data={bootstrapData.common.menu_data}
isFrontendRoute={isFrontendRoute}
/>
<ExtensionsStartup>
<Switch>
{routes.map(({ path, Component, props = {}, Fallback = Loading }) => (
<Route path={path} key={path}>
<Suspense fallback={<Fallback />}>
<Layout>
<Layout.Content
css={css`
display: flex;
flex-direction: column;
`}
>
<ErrorBoundary
css={css`
margin: 16px;
`}
>
<Component user={bootstrapData.user} {...props} />
</ErrorBoundary>
</Layout.Content>
</Layout>
</Suspense>
</Route>
))}
<Redirect from="/" to="/superset/welcome/" exact />
</Switch>
</ExtensionsStartup>
<Flex
vertical
css={css`
height: 100vh;
overflow: hidden;
`}
>
<Menu
data={bootstrapData.common.menu_data}
isFrontendRoute={isFrontendRoute}
/>
<ExtensionsStartup>
<AppContent />
</ExtensionsStartup>
</Flex>
<ToastContainer />
</RootContextProviders>
</Router>

View File

@@ -0,0 +1,60 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
export const RoutePaths = {
REDIRECT: '/redirect/',
LOGIN: '/login/',
REGISTER_ACTIVATION: '/register/activation/:activationHash',
REGISTER: '/register/',
LOGOUT: '/logout/',
HOME: '/superset/welcome/',
FILE_HANDLER: '/superset/file-handler',
DASHBOARD: '/superset/dashboard/:idOrSlug/',
DASHBOARD_LIST: '/dashboard/list/',
CHART_ADD: '/chart/add',
CHART_LIST: '/chart/list/',
DATASET_LIST: '/tablemodelview/list/',
DATABASE_LIST: '/databaseview/list/',
SAVED_QUERIES: '/savedqueryview/list/',
CSS_TEMPLATES: '/csstemplatemodelview/list/',
THEMES: '/theme/list/',
ANNOTATION_LAYERS: '/annotationlayer/list/',
ANNOTATION_LIST: '/annotationlayer/:annotationLayerId/annotation/',
QUERY_HISTORY: '/sqllab/history/',
ALERTS: '/alert/list/',
REPORTS: '/report/list/',
ALERT_LOG: '/alert/:alertId/log/',
REPORT_LOG: '/report/:alertId/log/',
EXPLORE: '/explore/',
EXPLORE_PERMALINK: '/superset/explore/p',
DATASET_ADD: '/dataset/add/',
DATASET: '/dataset/:datasetId',
ROW_LEVEL_SECURITY: '/rowlevelsecurity/list',
TASKS: '/tasks/list/',
SQLLAB: '/sqllab/',
USER_INFO: '/user_info/',
ACTION_LOG: '/actionlog/list',
REGISTRATIONS: '/registrations/',
ALL_ENTITIES: '/superset/all_entities/',
TAGS: '/superset/tags/',
ROLES: '/roles/',
USERS: '/users/',
GROUPS: '/list_groups/',
EXTENSIONS: '/extensions/list/',
} as const;

View File

@@ -26,6 +26,7 @@ import {
} from 'react';
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
import getBootstrapData from 'src/utils/getBootstrapData';
import { RoutePaths } from './routePaths';
// not lazy loaded since this is the home page.
import Home from 'src/pages/Home';
@@ -189,158 +190,58 @@ const RedirectWarning = lazy(
type Routes = {
path: string;
Component: ComponentType;
Fallback?: ComponentType;
Component: ComponentType<any>;
Fallback?: ComponentType<any>;
props?: ComponentProps<any>;
}[];
export const routes: Routes = [
{ path: RoutePaths.REDIRECT, Component: RedirectWarning },
{ path: RoutePaths.LOGIN, Component: Login },
{ path: RoutePaths.REGISTER_ACTIVATION, Component: Register },
{ path: RoutePaths.REGISTER, Component: Register },
{ path: RoutePaths.LOGOUT, Component: Login },
{ path: RoutePaths.HOME, Component: Home },
{ path: RoutePaths.FILE_HANDLER, Component: FileHandler },
{ path: RoutePaths.DASHBOARD_LIST, Component: DashboardList },
{ path: RoutePaths.DASHBOARD, Component: Dashboard },
{ path: RoutePaths.CHART_ADD, Component: ChartCreation },
{ path: RoutePaths.CHART_LIST, Component: ChartList },
{ path: RoutePaths.DATASET_LIST, Component: DatasetList },
{ path: RoutePaths.DATABASE_LIST, Component: DatabaseList },
{ path: RoutePaths.SAVED_QUERIES, Component: SavedQueryList },
{ path: RoutePaths.CSS_TEMPLATES, Component: CssTemplateList },
{ path: RoutePaths.THEMES, Component: ThemeList },
{ path: RoutePaths.ANNOTATION_LAYERS, Component: AnnotationLayerList },
{ path: RoutePaths.ANNOTATION_LIST, Component: AnnotationList },
{ path: RoutePaths.QUERY_HISTORY, Component: QueryHistoryList },
{ path: RoutePaths.ALERTS, Component: AlertReportList },
{
path: '/redirect/',
Component: RedirectWarning,
},
{
path: '/login/',
Component: Login,
},
{
path: '/register/activation/:activationHash',
Component: Register,
},
{
path: '/register/',
Component: Register,
},
{
path: '/logout/',
Component: Login,
},
{
path: '/superset/welcome/',
Component: Home,
},
{
path: '/superset/file-handler',
Component: FileHandler,
},
{
path: '/dashboard/list/',
Component: DashboardList,
},
{
path: '/superset/dashboard/:idOrSlug/',
Component: Dashboard,
},
{
path: '/chart/add',
Component: ChartCreation,
},
{
path: '/chart/list/',
Component: ChartList,
},
{
path: '/tablemodelview/list/',
Component: DatasetList,
},
{
path: '/databaseview/list/',
Component: DatabaseList,
},
{
path: '/savedqueryview/list/',
Component: SavedQueryList,
},
{
path: '/csstemplatemodelview/list/',
Component: CssTemplateList,
},
{
path: '/theme/list/',
Component: ThemeList,
},
{
path: '/annotationlayer/list/',
Component: AnnotationLayerList,
},
{
path: '/annotationlayer/:annotationLayerId/annotation/',
Component: AnnotationList,
},
{
path: '/sqllab/history/',
Component: QueryHistoryList,
},
{
path: '/alert/list/',
path: RoutePaths.REPORTS,
Component: AlertReportList,
props: { isReportEnabled: true },
},
{ path: RoutePaths.ALERT_LOG, Component: ExecutionLogList },
{
path: '/report/list/',
Component: AlertReportList,
props: {
isReportEnabled: true,
},
},
{
path: '/alert/:alertId/log/',
path: RoutePaths.REPORT_LOG,
Component: ExecutionLogList,
props: { isReportEnabled: true },
},
{
path: '/report/:alertId/log/',
Component: ExecutionLogList,
props: {
isReportEnabled: true,
},
},
{
path: '/explore/',
Component: Chart,
},
{
path: '/superset/explore/p',
Component: Chart,
},
{
path: '/dataset/add/',
Component: DatasetCreation,
},
{
path: '/dataset/:datasetId',
Component: DatasetCreation,
},
{
path: '/rowlevelsecurity/list',
Component: RowLevelSecurityList,
},
{
path: '/tasks/list/',
Component: TaskList,
},
{
path: '/sqllab/',
Component: SqlLab,
},
{ path: '/user_info/', Component: UserInfo },
{
path: '/actionlog/list',
Component: ActionLogList,
},
{
path: '/registrations/',
Component: UserRegistrations,
},
{ path: RoutePaths.EXPLORE, Component: Chart },
{ path: RoutePaths.EXPLORE_PERMALINK, Component: Chart },
{ path: RoutePaths.DATASET_ADD, Component: DatasetCreation },
{ path: RoutePaths.DATASET, Component: DatasetCreation },
{ path: RoutePaths.ROW_LEVEL_SECURITY, Component: RowLevelSecurityList },
{ path: RoutePaths.TASKS, Component: TaskList },
{ path: RoutePaths.SQLLAB, Component: SqlLab },
{ path: RoutePaths.USER_INFO, Component: UserInfo },
{ path: RoutePaths.ACTION_LOG, Component: ActionLogList },
{ path: RoutePaths.REGISTRATIONS, Component: UserRegistrations },
];
if (isFeatureEnabled(FeatureFlag.TaggingSystem)) {
routes.push({
path: '/superset/all_entities/',
Component: AllEntities,
});
routes.push({
path: '/superset/tags/',
Component: Tags,
});
routes.push({ path: RoutePaths.ALL_ENTITIES, Component: AllEntities });
routes.push({ path: RoutePaths.TAGS, Component: Tags });
}
const user = getBootstrapData()?.user;
@@ -350,33 +251,18 @@ const isAdmin = isUserAdmin(user);
if (isAdmin) {
routes.push(
{
path: '/roles/',
Component: RolesList,
},
{
path: '/users/',
Component: UsersList,
},
{
path: '/list_groups/',
Component: GroupsList,
},
{ path: RoutePaths.ROLES, Component: RolesList },
{ path: RoutePaths.USERS, Component: UsersList },
{ path: RoutePaths.GROUPS, Component: GroupsList },
);
if (isFeatureEnabled(FeatureFlag.EnableExtensions)) {
routes.push({
path: '/extensions/list/',
Component: Extensions,
});
routes.push({ path: RoutePaths.EXTENSIONS, Component: Extensions });
}
}
if (authRegistrationEnabled) {
routes.push({
path: '/registrations/',
Component: UserRegistrations,
});
routes.push({ path: RoutePaths.REGISTRATIONS, Component: UserRegistrations });
}
const frontEndRoutes: Record<string, boolean> = routes

View File

@@ -1455,6 +1455,18 @@ class ChartDataQueryObjectSchema(Schema):
fields.String(),
allow_none=True,
)
time_compare_full_range = fields.Boolean(
required=False,
allow_none=True,
metadata={
"description": (
"When using a time comparison (time_offsets), plot each shifted "
"series across its full time range instead of truncating it to the "
"main series' range. Useful for comparing a partial current period "
"against complete prior periods."
)
},
)
@post_load
def rename_deprecated_fields(

View File

@@ -105,6 +105,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
series_limit: int
series_limit_metric: Metric | None
time_offsets: list[str]
time_compare_full_range: bool
time_shift: str | None
time_range: str | None
to_dttm: datetime | None
@@ -162,6 +163,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self.to_dttm = kwargs.get("to_dttm")
self.result_type = kwargs.get("result_type")
self.time_offsets = kwargs.get("time_offsets", [])
self.time_compare_full_range = kwargs.get("time_compare_full_range", False)
self.inner_from_dttm = kwargs.get("inner_from_dttm")
self.inner_to_dttm = kwargs.get("inner_to_dttm")
self._rename_deprecated_fields(kwargs)
@@ -410,6 +412,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
"group_others_when_limit_reached": self.group_others_when_limit_reached,
"to_dttm": self.to_dttm,
"time_shift": self.time_shift,
"time_compare_full_range": self.time_compare_full_range,
}
return query_object_dict

View File

@@ -17,7 +17,7 @@
from __future__ import annotations
import datetime
from typing import Any, TYPE_CHECKING
from typing import Any, Literal, TYPE_CHECKING
import numpy as np
import pandas as pd
@@ -32,9 +32,14 @@ def left_join_df(
join_keys: list[str],
lsuffix: str = "",
rsuffix: str = "",
how: Literal["left", "right", "inner", "outer", "cross"] = "left",
) -> pd.DataFrame:
# `how` defaults to "left" so callers that only want the left frame's rows are
# unaffected. Passing how="outer" keeps right-only rows, which is used by the
# time-comparison "full range" option so historical series are not truncated to
# the main series' time range.
df = left_df.set_index(join_keys).join(
right_df.set_index(join_keys), lsuffix=lsuffix, rsuffix=rsuffix
right_df.set_index(join_keys), how=how, lsuffix=lsuffix, rsuffix=rsuffix
)
df.reset_index(inplace=True)
return df

View File

@@ -434,7 +434,13 @@ LOGO_RIGHT_TEXT: Callable[[], str] | str = ""
# Enables SWAGGER UI for superset openapi spec
# ex: http://localhost:8080/swagger/v1
FAB_API_SWAGGER_UI = True
# Disabled by default so the interactive API documentation surface is opt-in.
# Enable it by setting the SUPERSET_ENABLE_SWAGGER_UI environment variable
# (e.g. for local development) or by overriding FAB_API_SWAGGER_UI in
# superset_config.py.
FAB_API_SWAGGER_UI = utils.cast_to_boolean(
os.environ.get("SUPERSET_ENABLE_SWAGGER_UI", False)
)
# ----------------------------------------------------
# AUTHENTICATION CONFIG

View File

@@ -38,18 +38,12 @@ from superset.db_engine_specs.base import (
)
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
from superset.models.core import Database
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
OAuth2TokenResponse,
)
try:
@@ -283,105 +277,6 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
"port": "port",
}
# OAuth2 endpoints for different cloud providers
_oauth2_endpoints = {
"aws": {
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/{}/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/{}/v1/token",
},
"azure": {
"authorization_request_uri": "https://login.microsoftonline.com/{}/oauth2/v2.0/authorize",
"token_request_uri": "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
},
"gcp": {
"authorization_request_uri": "https://accounts.gcp.databricks.com/oidc/accounts/{}/v1/authorize",
"token_request_uri": "https://accounts.gcp.databricks.com/oidc/accounts/{}/v1/token",
},
}
@classmethod
def _detect_cloud_provider(cls, database: Database) -> str:
"""
Detect the cloud provider based on the database configuration.
Returns:
str: The cloud provider ('aws', 'azure', or 'gcp')
"""
# Check if cloud provider is explicitly configured in extra
if isinstance(
(cloud_provider := cls.get_extra_params(database).get("cloud_provider")),
str,
):
provider = cloud_provider.lower()
if provider in cls._oauth2_endpoints:
return provider
# Try to detect from hostname
hostname = database.url_object.host or ""
hostname = hostname.lower()
if "azure" in hostname or "azuredatabricks" in hostname:
return "azure"
elif "gcp" in hostname or "googleusercontent" in hostname:
return "gcp"
else:
# Default to AWS for compatibility
return "aws"
@classmethod
def _resolve_oauth2_endpoint(
cls,
database: Database,
provider: str,
endpoint_key: str,
) -> str:
"""
Build a fully-resolved OAuth2 endpoint for the detected cloud provider.
The per-provider templates carry a single ``{}`` placeholder for the
Databricks account id (or Azure tenant id), read from the database's
``extra`` (``account_id``, or ``tenant_id`` for Azure). Raising when it
is absent keeps the flow from issuing a request to an unresolved
``.../{}/...`` endpoint.
"""
template = cls._oauth2_endpoints[provider][endpoint_key]
if "{}" not in template:
return template
extra = cls.get_extra_params(database)
account_id = extra.get("account_id") or extra.get("tenant_id")
if not account_id:
raise OAuth2Error(
"Databricks OAuth2 endpoints could not be resolved: set "
"`account_id` (or `tenant_id` for Azure) in the database's "
"engine parameters, or provide a fully-resolved "
f"`{endpoint_key}` in DATABASE_OAUTH2_CLIENTS."
)
return template.format(account_id)
@classmethod
def impersonate_user(
cls,
database: Database,
username: str | None,
user_token: str | None,
url: URL,
engine_kwargs: dict[str, Any],
) -> tuple[URL, dict[str, Any]]:
"""
Update connection with OAuth2 access token for user impersonation.
"""
if user_token:
# Replace the access token in the URL with the user's OAuth2 token
url = url.set(password=user_token)
# Also update connect_args if they contain access token
connect_args = engine_kwargs.setdefault("connect_args", {})
if "access_token" in connect_args:
connect_args["access_token"] = user_token
return url, engine_kwargs
@staticmethod
def get_extra_params(
database: Database, source: QuerySource | None = None
@@ -579,74 +474,6 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_catalog = True
supports_cross_catalog_queries = True
# OAuth 2.0 support
supports_oauth2 = True
oauth2_exception = OAuth2RedirectError
oauth2_scope = "sql"
# OAuth2 endpoints are determined dynamically based on cloud provider
oauth2_authorization_request_uri = "" # Set dynamically
oauth2_token_request_uri = "" # Set dynamically
@classmethod
def get_oauth2_authorization_uri(
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request with dynamic endpoint detection.
A fully-resolved `authorization_request_uri` from `DATABASE_OAUTH2_CLIENTS`
is preserved; only fall back to the auto-detected, account-resolved
endpoint when none is configured.
"""
if not config.get("authorization_request_uri"):
from superset import db
from superset.models.core import Database
# Get the database to detect cloud provider
database_id = state["database_id"]
if database := db.session.get(Database, database_id):
provider = cls._detect_cloud_provider(database)
from typing import cast
config = cast(
"OAuth2ClientConfig",
dict(config)
| {
"authorization_request_uri": cls._resolve_oauth2_endpoint(
database, provider, "authorization_request_uri"
)
},
)
return super().get_oauth2_authorization_uri(config, state, code_verifier)
@classmethod
def get_oauth2_token(
cls,
config: "OAuth2ClientConfig",
code: str,
code_verifier: str | None = None,
) -> "OAuth2TokenResponse":
"""
Exchange authorization code for refresh/access tokens.
The token request URI is resolved when the OAuth2 config is built (see
`get_oauth2_config`) and already targets the correct cloud provider and
account. There is no database context here to auto-detect it, so fail
fast rather than POST to an unresolved endpoint when it is missing.
"""
if not config.get("token_request_uri"):
raise OAuth2Error(
"Databricks OAuth2 token endpoint is not configured: provide a "
"fully-resolved `token_request_uri` in DATABASE_OAUTH2_CLIENTS."
)
return super().get_oauth2_token(config, code, code_verifier)
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksNativeParametersType, *_
@@ -858,74 +685,6 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
# OAuth 2.0 support
supports_oauth2 = True
oauth2_exception = OAuth2RedirectError
oauth2_scope = "sql"
# OAuth2 endpoints are determined dynamically based on cloud provider
oauth2_authorization_request_uri = "" # Set dynamically
oauth2_token_request_uri = "" # Set dynamically
@classmethod
def get_oauth2_authorization_uri(
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request with dynamic endpoint detection.
A fully-resolved `authorization_request_uri` from `DATABASE_OAUTH2_CLIENTS`
is preserved; only fall back to the auto-detected, account-resolved
endpoint when none is configured.
"""
if not config.get("authorization_request_uri"):
from superset import db
from superset.models.core import Database
# Get the database to detect cloud provider
database_id = state["database_id"]
if database := db.session.get(Database, database_id):
provider = cls._detect_cloud_provider(database)
from typing import cast
config = cast(
"OAuth2ClientConfig",
dict(config)
| {
"authorization_request_uri": cls._resolve_oauth2_endpoint(
database, provider, "authorization_request_uri"
)
},
)
return super().get_oauth2_authorization_uri(config, state, code_verifier)
@classmethod
def get_oauth2_token(
cls,
config: "OAuth2ClientConfig",
code: str,
code_verifier: str | None = None,
) -> "OAuth2TokenResponse":
"""
Exchange authorization code for refresh/access tokens.
The token request URI is resolved when the OAuth2 config is built (see
`get_oauth2_config`) and already targets the correct cloud provider and
account. There is no database context here to auto-detect it, so fail
fast rather than POST to an unresolved endpoint when it is missing.
"""
if not config.get("token_request_uri"):
raise OAuth2Error(
"Databricks OAuth2 token endpoint is not configured: provide a "
"fully-resolved `token_request_uri` in DATABASE_OAUTH2_CLIENTS."
)
return super().get_oauth2_token(config, code, code_verifier)
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksPythonConnectorParametersType, *_

View File

@@ -238,6 +238,7 @@ def build_extension_data(extension: LoadedExtension) -> dict[str, Any]:
manifest = extension.manifest
extension_data: dict[str, Any] = {
"id": manifest.id,
"publisher": manifest.publisher,
"name": extension.name,
"version": extension.version,
"description": manifest.description or "",

View File

@@ -129,6 +129,7 @@ Dashboard Management:
- get_dashboard_info: Get detailed dashboard information by ID
- get_dashboard_layout: Get parsed tabs and chart positions for a dashboard (companion to get_dashboard_info when its omitted_fields hint flags position_json)
- generate_dashboard: Create a dashboard from chart IDs (requires write access)
- update_dashboard: Update an existing dashboard's title/description/slug/published/layout/theme/CSS (requires write access; ownership-checked per-instance)
- add_chart_to_existing_dashboard: Add a chart to an existing dashboard (requires write access)
Annotation Layers:
@@ -694,6 +695,7 @@ from superset.mcp_service.dashboard.tool import ( # noqa: F401, E402
get_dashboard_info,
get_dashboard_layout,
list_dashboards,
update_dashboard,
)
from superset.mcp_service.database.tool import ( # noqa: F401, E402
get_database_info,

View File

@@ -103,11 +103,15 @@ from superset.mcp_service.utils import (
escape_llm_context_delimiters,
sanitize_for_llm_context,
)
from superset.mcp_service.utils.response_utils import humanize_timestamp
from superset.mcp_service.utils.response_utils import (
humanize_timestamp,
OmittedFieldsBuilder,
)
from superset.mcp_service.utils.sanitization import (
sanitize_user_input,
sanitize_user_input_with_changes,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.utils.json import loads as json_loads
@@ -129,8 +133,6 @@ class DashboardError(BaseModel):
@classmethod
def create(cls, error: str, error_type: str) -> "DashboardError":
"""Create a standardized DashboardError with timestamp."""
from datetime import datetime
return cls(error=error, error_type=error_type, timestamp=datetime.now())
@@ -618,6 +620,48 @@ class GenerateDashboardRequest(BaseModel):
published: bool = Field(
default=False, description="Whether to publish the dashboard"
)
slug: str | None = Field(
None,
max_length=255,
description=(
"Optional URL slug for the dashboard. When set, the dashboard "
"is reachable at /superset/dashboard/<slug>/ instead of "
"/superset/dashboard/<id>/. Must be unique across the instance."
),
)
position_json: Dict[str, Any] | None = Field(
None,
description=(
"Optional explicit dashboard layout (Superset's position_json "
"dict). When set, replaces the auto-generated layout entirely. "
"Pass this when you need custom row composition, MARKDOWN "
"blocks, HEADER components, or specific chart widths/heights. "
"Omit to let the tool auto-generate a packed grid from chart_ids."
),
)
json_metadata_overrides: Dict[str, Any] | None = Field(
None,
description=(
"Optional overrides applied on top of the default "
"json_metadata. Common fields: label_colors (per-series brand "
"palette, e.g. {'Electronics': '#4C78A8'}), color_scheme "
"(named Superset palette), cross_filters_enabled (bool, "
"default False — set True for interactive dashboards), "
"shared_label_colors (list of label names for cross-chart "
"color consistency). Merged shallowly into the defaults; pass "
"only the keys you want to override."
),
)
css: str | None = Field(
None,
max_length=50000,
description=(
"Optional dashboard-level CSS. Useful for hiding chart chrome "
"(kebab menus, cross-filter chips) on print-ready dashboards, "
"or tweaking padding/typography. Applied as-is to the "
"dashboard's css field."
),
)
sanitization_warnings: List[str] = Field(
default_factory=list,
description=(
@@ -690,6 +734,168 @@ class GenerateDashboardRequest(BaseModel):
)
class UpdateDashboardRequest(BaseModel):
"""Request schema for updating an existing dashboard's layout/theme/style.
All fields are optional; only the fields explicitly passed are applied.
Use to retroactively set a custom layout, brand palette, or CSS on a
dashboard that was created via ``generate_dashboard`` (or earlier via
the REST API) without a full re-create.
"""
model_config = ConfigDict(populate_by_name=True)
identifier: int | str = Field(
...,
description=(
"Dashboard ID (integer), UUID, or slug. Same identifier shape "
"accepted by ``get_dashboard_info``."
),
)
dashboard_title: str | None = Field(
None,
max_length=500,
description="Optional new dashboard title.",
validation_alias=AliasChoices("dashboard_title", "title", "name"),
)
description: str | None = Field(
None,
description="Optional new dashboard description.",
)
slug: str | None = Field(
None,
max_length=255,
description=("Optional new URL slug. Pass empty string to clear a slug."),
)
published: bool | None = Field(
None,
description="Optional published flag.",
)
position_json: Dict[str, Any] | None = Field(
None,
description=(
"Optional replacement layout (Superset's position_json dict). "
"When set, fully replaces the existing layout. Get the current "
"layout via ``get_dashboard_info`` first if you want to make "
"incremental changes."
),
)
json_metadata_overrides: Dict[str, Any] | None = Field(
None,
description=(
"Optional overrides applied on top of the existing "
"json_metadata. Merged shallowly — pass only the keys you "
"want to change (e.g. label_colors, color_scheme, "
"cross_filters_enabled)."
),
)
css: str | None = Field(
None,
max_length=50000,
description=(
"Optional new dashboard CSS. Pass empty string to clear existing CSS."
),
)
sanitization_warnings: List[str] = Field(
default_factory=list,
description=(
"Internal: warnings emitted when user input was altered by "
"sanitization. Populated by the ``mode='before'`` validator "
"before dashboard_title is rewritten."
),
)
@model_validator(mode="before")
@classmethod
def _detect_dashboard_title_sanitization(cls, data: Any) -> Any:
"""Reject empty-after-sanitization titles and warn on partial strip.
Mirrors the same guard ``GenerateDashboardRequest`` applies so a
prompt-injected LLM cannot push XSS payloads through the update
path that the create path already rejects. Server-only
``sanitization_warnings`` is reset here so a caller cannot inject
warning text.
"""
if not isinstance(data, dict):
return data
data["sanitization_warnings"] = []
# Must match every AliasChoice on ``dashboard_title`` — otherwise
# an XSS payload supplied via a different key (e.g. ``name``)
# would bypass this ``mode='before'`` guard and slip through to
# Pydantic's alias resolution unsanitized.
for key in ("dashboard_title", "title", "name"):
if key in data:
raw = data[key]
break
else:
raw = None
if not isinstance(raw, str) or not raw.strip():
return data
sanitized, was_modified = sanitize_user_input_with_changes(
raw, "Dashboard title", max_length=500, allow_empty=True
)
if was_modified and not sanitized:
raise ValueError(
"dashboard_title contained only disallowed content "
"(HTML/script/URL schemes) and was removed entirely by "
"sanitization. Provide a dashboard_title with plain text."
)
if was_modified:
data["sanitization_warnings"].append(
"dashboard_title was modified during sanitization to "
"remove potentially unsafe content; the stored title "
"differs from the input."
)
return data
@field_validator("dashboard_title")
@classmethod
def sanitize_dashboard_title(cls, v: str | None) -> str | None:
"""Sanitize dashboard title to prevent XSS."""
if v is None or v == "":
return v
return sanitize_user_input(
v, "Dashboard title", max_length=500, allow_empty=True
)
class UpdateDashboardResponse(BaseModel):
"""Response schema for ``update_dashboard``.
Distinct from ``GenerateDashboardResponse`` because the semantics
differ: this response reports which fields actually changed on an
existing dashboard, rather than describing a newly created one.
"""
dashboard: DashboardInfo | None = Field(
None, description="The updated dashboard info, if successful"
)
dashboard_url: str | None = Field(None, description="URL to view the dashboard")
error: str | None = Field(None, description="Error message, if update failed")
permission_denied: bool = Field(
default=False,
description=(
"True when the user lacks edit rights on the target "
"dashboard. When True, ``error`` carries the human-readable "
"explanation and the response is otherwise empty."
),
)
changed_fields: List[str] = Field(
default_factory=list,
description=(
"Names of fields that were actually applied. Empty when the "
"request was a no-op or failed before any field was applied."
),
)
warnings: List[str] = Field(
default_factory=list,
description=(
"Non-fatal advisory messages — for example, that the supplied "
"title was altered by sanitization."
),
)
class GenerateDashboardResponse(BaseModel):
"""Response schema for dashboard generation."""
@@ -972,7 +1178,6 @@ def _build_omitted_fields(
Uses the shared OmittedFieldsBuilder utility so the pattern is consistent
across all MCP tool serializers.
"""
from superset.mcp_service.utils.response_utils import OmittedFieldsBuilder
return (
OmittedFieldsBuilder()
@@ -1005,7 +1210,6 @@ def serialize_chart_summary(
"""Serialize a chart to a lightweight summary for dashboard context."""
if not chart:
return None
from superset.mcp_service.utils.url_utils import get_superset_base_url
chart_id = getattr(chart, "id", None)
chart_url = None
@@ -1112,9 +1316,20 @@ def _sanitize_dashboard_info_for_llm_context(
return DashboardInfo.model_validate(payload)
def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
from superset.mcp_service.utils.url_utils import get_superset_base_url
def _safe_user_label(value: Any) -> str | None:
"""Coerce a `*_by_name` model attribute to a display string or None.
The Dashboard model exposes ``created_by_name`` / ``changed_by_name``
as plain strings, but some serializer call sites pass through
objects (User instances, Mocks in tests) — defensive coercion keeps
the response a valid string and avoids leaking ``repr(user)``.
"""
if isinstance(value, str) and value:
return value
return None
def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
include_data_model_metadata = user_can_view_data_model_metadata()
base_url = get_superset_base_url()
relative_url = dashboard.url # e.g. "/superset/dashboard/{slug_or_id}/"
@@ -1175,7 +1390,6 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
def serialize_dashboard_object(dashboard: Any) -> DashboardInfo:
"""Simple dashboard serializer that safely handles object attributes."""
from superset.mcp_service.utils.url_utils import get_superset_base_url
# Construct URL from id/slug (the model's @property isn't available on
# column-only query tuples returned by DAO.list with select_columns)

View File

@@ -20,6 +20,7 @@ from .generate_dashboard import generate_dashboard
from .get_dashboard_info import get_dashboard_info
from .get_dashboard_layout import get_dashboard_layout
from .list_dashboards import list_dashboards
from .update_dashboard import update_dashboard
__all__ = [
"list_dashboards",
@@ -27,4 +28,5 @@ __all__ = [
"get_dashboard_layout",
"generate_dashboard",
"add_chart_to_existing_dashboard",
"update_dashboard",
]

View File

@@ -26,9 +26,11 @@ from typing import Any, Dict, List
from fastmcp import Context
from flask import g
from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.dashboard.constants import (
generate_id,
GRID_COLUMN_COUNT,
@@ -38,8 +40,11 @@ from superset.mcp_service.dashboard.schemas import (
DashboardInfo,
GenerateDashboardRequest,
GenerateDashboardResponse,
serialize_chart_summary,
serialize_tag_object,
)
from superset.mcp_service.privacy import user_can_view_data_model_metadata
from superset.mcp_service.utils.response_utils import humanize_timestamp
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.utils import json
@@ -197,22 +202,24 @@ def generate_dashboard( # noqa: C901
- To add a chart to an EXISTING dashboard, use add_chart_to_existing_dashboard.
Never use this tool as a fallback when add_chart_to_existing_dashboard fails.
- All charts must exist and be accessible to current user
- Charts arranged automatically in 2-column grid layout
- Layout: by default, charts are arranged in an auto-generated 2-column
grid. When ``position_json`` is supplied, that explicit layout is
written verbatim and the auto-generated grid is skipped — use this to
compose custom rows, header bands, or MARKDOWN/HEADER components.
Returns:
- Dashboard ID and URL
"""
from pydantic import ValidationError
from sqlalchemy.exc import SQLAlchemyError
# Advisory messages (e.g. title sanitization) surfaced to the caller
# alongside the created dashboard so they can tell when their input
# was altered.
sanitization_warnings = list(getattr(request, "sanitization_warnings", []) or [])
try:
# Get chart objects from IDs (required for SQLAlchemy relationships)
from superset import db
# avoids ImportError before Flask app initialisation:
# `Exception: App not initialized yet. Please call init_app first`
# raised from superset.utils.encrypt when Slice's encrypted Column
# types are instantiated at model-class definition time.
from superset.models.slice import Slice
with event_logger.log_context(action="mcp.generate_dashboard.chart_validation"):
@@ -253,9 +260,14 @@ def generate_dashboard( # noqa: C901
),
)
# Create dashboard layout with chart objects
# Create dashboard layout with chart objects.
# If the caller provided an explicit position_json, use it verbatim;
# otherwise auto-generate a packed-grid layout from the chart ids.
with event_logger.log_context(action="mcp.generate_dashboard.layout"):
layout = _create_dashboard_layout(chart_objects)
if request.position_json:
layout = request.position_json
else:
layout = _create_dashboard_layout(chart_objects)
# Resolve dashboard title: use provided title or derive from chart names
dashboard_title = (
@@ -276,27 +288,32 @@ def generate_dashboard( # noqa: C901
from superset.models.dashboard import Dashboard
with event_logger.log_context(action="mcp.generate_dashboard.db_write"):
json_metadata = json.dumps(
{
"filter_scopes": {},
"expanded_slices": {},
"refresh_frequency": 0,
"timed_refresh_immune_slices": [],
"color_scheme": None,
"label_colors": {},
"shared_label_colors": {},
"color_scheme_domain": [],
"cross_filters_enabled": False,
"native_filter_configuration": [],
"global_chart_configuration": {
"scope": {
"rootPath": ["ROOT_ID"],
"excluded": [],
}
},
"chart_configuration": {},
}
)
# Build the default json_metadata that every new dashboard gets.
# When the caller supplied json_metadata_overrides, merge those
# in shallowly so the LLM can override label_colors / color_scheme
# / cross_filters_enabled without re-specifying the whole shape.
json_metadata_dict: Dict[str, Any] = {
"filter_scopes": {},
"expanded_slices": {},
"refresh_frequency": 0,
"timed_refresh_immune_slices": [],
"color_scheme": None,
"label_colors": {},
"shared_label_colors": {},
"color_scheme_domain": [],
"cross_filters_enabled": False,
"native_filter_configuration": [],
"global_chart_configuration": {
"scope": {
"rootPath": ["ROOT_ID"],
"excluded": [],
}
},
"chart_configuration": {},
}
if request.json_metadata_overrides:
json_metadata_dict.update(request.json_metadata_overrides)
json_metadata = json.dumps(json_metadata_dict)
try:
dashboard = Dashboard()
@@ -308,6 +325,12 @@ def generate_dashboard( # noqa: C901
if request.description:
dashboard.description = request.description
if request.slug:
dashboard.slug = request.slug
if request.css:
dashboard.css = request.css
# Re-query the current user and charts directly in the
# current db.session. g.user was loaded in a Flask
# app_context that has since been torn down (the
@@ -344,6 +367,38 @@ def generate_dashboard( # noqa: C901
dashboard.id,
exc_info=True,
)
except IntegrityError as db_err:
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling",
exc_info=True,
)
logger.error("Dashboard creation failed: %s", db_err, exc_info=True)
# Slug uniqueness is the only IntegrityError a caller
# can fix on retry; surface a clear, structured message
# so the LLM can propose a different slug. Detection
# scans both the Postgres constraint name
# (``dashboards_slug_key``) and the SQLite phrasing
# (``UNIQUE constraint failed: dashboards.slug``).
err_text = str(db_err).lower()
if request.slug and "slug" in err_text:
return GenerateDashboardResponse(
dashboard=None,
dashboard_url=None,
error=(
f"Slug {request.slug!r} is already in use by "
"another dashboard. Choose a different slug "
"and retry, or omit the slug to get a "
"generated URL."
),
)
return GenerateDashboardResponse(
dashboard=None,
dashboard_url=None,
error=("Failed to create dashboard due to a database constraint."),
)
except SQLAlchemyError as db_err:
try:
db.session.rollback() # pylint: disable=consider-using-transaction
@@ -363,6 +418,17 @@ def generate_dashboard( # noqa: C901
error="Failed to create dashboard due to a database error.",
)
# ``dashboard.id`` is fixed at create-commit time; the post-commit
# re-fetch below either returns the same row or fails — in either
# outcome the URL doesn't change. Bind it once so the three
# downstream consumers (partial response, DashboardInfo.url,
# response.dashboard_url) share a single source. Prefer the slug
# over the id to match ``update_dashboard``'s canonical URL shape.
dashboard_url = (
f"{get_superset_base_url()}/superset/dashboard/"
f"{dashboard.slug or dashboard.id}/"
)
# Re-fetch with eager-loaded relationships for serialization.
# The preceding commit may invalidate the session in multi-tenant
# environments, causing "Can't reconnect until invalid transaction
@@ -396,9 +462,6 @@ def generate_dashboard( # noqa: C901
"Database rollback failed during dashboard re-fetch error handling",
exc_info=True,
)
dashboard_url = (
f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/"
)
return GenerateDashboardResponse(
dashboard=DashboardInfo(
id=dashboard.id,
@@ -409,16 +472,15 @@ def generate_dashboard( # noqa: C901
),
dashboard_url=dashboard_url,
error=None,
warnings=sanitization_warnings,
warnings=sanitization_warnings
+ [
"Dashboard created but response metadata is partial "
"(post-create refresh failed); some fields are omitted. "
"Call get_dashboard_info to retrieve the full record."
],
)
# Convert to our response format
from superset.mcp_service.dashboard.schemas import (
serialize_chart_summary,
serialize_tag_object,
)
from superset.mcp_service.utils.response_utils import humanize_timestamp
include_data_model_metadata = user_can_view_data_model_metadata()
dashboard_info = DashboardInfo(
id=dashboard.id,
@@ -433,7 +495,7 @@ def generate_dashboard( # noqa: C901
created_by=dashboard.created_by_name or None,
changed_by=dashboard.changed_by_name or None,
uuid=str(dashboard.uuid) if dashboard.uuid else None,
url=f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/",
url=dashboard_url,
chart_count=len(request.chart_ids),
tags=[
serialize_tag_object(tag)
@@ -453,8 +515,6 @@ def generate_dashboard( # noqa: C901
],
)
dashboard_url = f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/"
logger.info(
"Created dashboard %s with %s charts", dashboard.id, len(request.chart_ids)
)
@@ -467,17 +527,19 @@ def generate_dashboard( # noqa: C901
)
except (SQLAlchemyError, ValueError, AttributeError, ValidationError) as e:
from superset import db
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling", exc_info=True
)
# ``str(e)`` on SQLAlchemyError frequently contains table/column/
# constraint names that should not leak to the MCP response.
# The raw exception is captured above via ``logger.error`` with
# ``exc_info=True``; the response surfaces a generic message.
logger.error("Error creating dashboard: %s", e, exc_info=True)
return GenerateDashboardResponse(
dashboard=None,
dashboard_url=None,
error=f"Failed to create dashboard: {str(e)}",
error="Failed to create dashboard due to an internal error.",
)

View File

@@ -0,0 +1,269 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Update dashboard FastMCP tool
This module contains the FastMCP tool for updating an existing dashboard's
layout, theme, and styling. Companion to ``generate_dashboard`` for
incremental edits without re-creating the dashboard.
"""
import logging
from typing import Any
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.exceptions import SupersetSecurityException
from superset.extensions import db, event_logger
from superset.mcp_service.dashboard.schemas import (
dashboard_serializer,
DashboardError,
UpdateDashboardRequest,
UpdateDashboardResponse,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.utils import json
logger = logging.getLogger(__name__)
def _build_dashboard_url(dashboard: Any) -> str:
"""Build the user-facing dashboard URL, preferring slug over id."""
return (
f"{get_superset_base_url()}/superset/dashboard/"
f"{dashboard.slug or dashboard.id}/"
)
def _find_and_authorize_dashboard(
identifier: int | str,
) -> tuple[Any, UpdateDashboardResponse | DashboardError | None]:
"""Return (dashboard, None) on success or (None, error_response) on failure.
Mirrors the helper in ``add_chart_to_existing_dashboard``: combines
the not-found and forbidden cases so the main tool body has a single
pre-condition branch. Returns ``DashboardError`` on not-found and
``UpdateDashboardResponse`` (with ``permission_denied=True``) on
ownership failure — the two shapes carry different information for
the caller.
"""
# avoids ImportError before Flask app initialisation:
# `Exception: App not initialized yet. Please call init_app first`
# raised from superset.utils.encrypt when DashboardDAO is imported
# (via Slice's encrypted Column types). `security_manager` is a
# LocalProxy that needs the same app context to resolve at call
# time, so it is co-located with the DAO it accompanies.
from superset import security_manager
from superset.daos.dashboard import DashboardDAO
try:
dashboard = DashboardDAO.get_by_id_or_slug(identifier)
except (DashboardNotFoundError, SQLAlchemyError):
return None, DashboardError(
error=f"Dashboard not found: {identifier!r}",
error_type="DashboardNotFound",
)
if dashboard is None:
return None, DashboardError(
error=f"Dashboard not found: {identifier!r}",
error_type="DashboardNotFound",
)
try:
security_manager.raise_for_ownership(dashboard)
except SupersetSecurityException:
return None, UpdateDashboardResponse(
permission_denied=True,
error=(
f"You don't have permission to edit dashboard "
f"'{dashboard.dashboard_title}' (ID: {dashboard.id})."
),
)
return dashboard, None
def _merge_json_metadata(dashboard: Any, overrides: dict[str, Any]) -> str:
"""Shallow-merge ``overrides`` onto the dashboard's existing metadata.
Parses defensively: a row may carry malformed JSON or a non-object
payload (e.g. ``"[]"``) from an older migration or manual edit. Either
would raise out of the caller's ``SQLAlchemyError`` handler, so fall
back to an empty object and overlay the overrides on top.
"""
existing: dict[str, Any] = {}
if dashboard.json_metadata:
try:
parsed = json.loads(dashboard.json_metadata)
if isinstance(parsed, dict):
existing = parsed
except (ValueError, TypeError):
pass
existing.update(overrides)
return json.dumps(existing)
def _apply_field_updates(dashboard: Any, request: UpdateDashboardRequest) -> list[str]:
"""Apply each explicitly-passed field to the dashboard.
Returns the names of fields actually changed. Mutates ``dashboard``
in place. ``json_metadata_overrides`` is merged shallowly with the
existing ``json_metadata``; an empty string in ``slug`` or ``css``
clears the underlying value.
"""
changed: list[str] = []
if request.dashboard_title is not None:
dashboard.dashboard_title = request.dashboard_title
changed.append("dashboard_title")
if request.description is not None:
dashboard.description = request.description
changed.append("description")
if request.slug is not None:
dashboard.slug = request.slug or None
changed.append("slug")
if request.published is not None:
dashboard.published = request.published
changed.append("published")
if request.position_json is not None:
dashboard.position_json = json.dumps(request.position_json)
changed.append("position_json")
if request.json_metadata_overrides is not None:
dashboard.json_metadata = _merge_json_metadata(
dashboard, request.json_metadata_overrides
)
changed.append("json_metadata")
if request.css is not None:
dashboard.css = request.css or None
changed.append("css")
return changed
@tool(
tags=["mutate"],
class_permission_name="Dashboard",
method_permission_name="write",
annotations=ToolAnnotations(
title="Update dashboard layout/theme/CSS",
readOnlyHint=False,
destructiveHint=False,
),
)
def update_dashboard(
request: UpdateDashboardRequest, ctx: Context
) -> UpdateDashboardResponse | DashboardError:
"""Patch an existing dashboard's layout, theme, or styling.
Companion to ``generate_dashboard`` for incremental edits. Accepts
the same layout/theme/CSS fields that ``generate_dashboard`` does, so
an LLM can:
- Set or replace ``position_json`` after auto-generation
- Apply brand ``label_colors`` and ``color_scheme`` via
``json_metadata_overrides``
- Toggle ``cross_filters_enabled`` via ``json_metadata_overrides``
- Inject ``css`` to hide chrome on print-ready dashboards
- Update ``dashboard_title``, ``description``, ``slug``, ``published``
Only the fields explicitly passed are applied; other fields are left
unchanged. ``json_metadata_overrides`` is merged shallowly with the
existing json_metadata — pass only the keys you want to change.
Example::
update_dashboard(request={
"identifier": 42,
"json_metadata_overrides": {
"label_colors": {"Electronics": "#4C78A8"},
"cross_filters_enabled": False,
},
"css": ".header-controls {display: none;}",
})
"""
ctx.info(f"Updating dashboard: identifier={request.identifier}")
dashboard, auth_error = _find_and_authorize_dashboard(request.identifier)
if auth_error is not None:
return auth_error
changed_fields: list[str] = []
warnings: list[str] = list(request.sanitization_warnings)
try:
with event_logger.log_context(action="mcp.update_dashboard.apply"):
changed_fields = _apply_field_updates(dashboard, request)
if not changed_fields:
warnings.append("No fields provided; dashboard unchanged.")
return UpdateDashboardResponse(
dashboard=dashboard_serializer(dashboard),
dashboard_url=_build_dashboard_url(dashboard),
error=None,
changed_fields=[],
warnings=warnings,
)
db.session.commit() # pylint: disable=consider-using-transaction
try:
db.session.refresh(dashboard)
except SQLAlchemyError:
logger.warning(
"Dashboard %s updated but refresh failed; "
"continuing with current values",
dashboard.id,
exc_info=True,
)
warnings.append(
"Dashboard updated but post-update refresh failed; "
"returned values may not reflect database state."
)
except SQLAlchemyError as db_err:
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling",
exc_info=True,
)
logger.error("Dashboard update failed: %s", db_err, exc_info=True)
return DashboardError(
error="Failed to update dashboard due to a database error.",
error_type="DatabaseError",
)
ctx.info(f"Dashboard {dashboard.id} updated: changed={changed_fields}")
return UpdateDashboardResponse(
dashboard=dashboard_serializer(dashboard),
dashboard_url=_build_dashboard_url(dashboard),
error=None,
changed_fields=changed_fields,
warnings=warnings,
)

View File

@@ -33,6 +33,7 @@ from typing import (
Callable,
cast,
ClassVar,
Literal,
NamedTuple,
Optional,
TYPE_CHECKING,
@@ -130,7 +131,11 @@ from superset.utils.core import (
SqlExpressionType,
TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
from superset.utils.date_parser import (
get_past_or_future,
normalize_time_delta,
TimeDeltaAmbiguousError,
)
from superset.utils.dates import datetime_to_epoch
from superset.utils.rls import apply_rls
@@ -2013,6 +2018,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
offset_dfs,
time_grain,
join_keys,
full_range=getattr(query_object, "time_compare_full_range", False),
)
return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys)
@@ -2210,7 +2216,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return offset_df, join_keys
def _perform_join(
self, df: pd.DataFrame, offset_df: pd.DataFrame, actual_join_keys: list[str]
self,
df: pd.DataFrame,
offset_df: pd.DataFrame,
actual_join_keys: list[str],
how: Literal["left", "right", "inner", "outer", "cross"] = "left",
) -> pd.DataFrame:
"""Perform the appropriate join operation."""
if actual_join_keys:
@@ -2219,6 +2229,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
right_df=offset_df,
join_keys=actual_join_keys,
rsuffix=R_SUFFIX,
how=how,
)
else:
temp_key = "__temp_join_key__"
@@ -2230,6 +2241,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
right_df=offset_df,
join_keys=[temp_key],
rsuffix=R_SUFFIX,
how=how,
)
# Remove temporary join keys
@@ -2245,6 +2257,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
offset_dfs: dict[str, pd.DataFrame],
time_grain: str | None,
join_keys: list[str],
full_range: bool = False,
) -> pd.DataFrame:
"""
Join offset DataFrames with the main DataFrame.
@@ -2253,6 +2266,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
:param offset_dfs: A list of offset DataFrames.
:param time_grain: The time grain used to calculate the temporal join key.
:param join_keys: The keys to join on.
:param full_range: When True, time-shifted (offset) series keep their full
time range instead of being truncated to the main series' range. This
uses an outer join so offset-only rows (e.g. the rest of a prior day when
the current day is still in progress) are preserved.
"""
join_column_producer = app.config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get(
time_grain
@@ -2280,13 +2297,60 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
join_column_producer,
)
df = self._perform_join(df, offset_df, actual_join_keys)
# The full-range option is only meaningful for relative offsets aligned
# on a temporal join column (time_grain set). Date-range offsets and the
# grain-less path keep the existing left-join behavior.
use_outer_join = (
full_range
and bool(time_grain)
and not is_date_range_offset
and bool(join_keys)
)
how: Literal["left", "outer"] = "outer" if use_outer_join else "left"
df = self._perform_join(df, offset_df, actual_join_keys, how=how)
if use_outer_join:
df = self._coalesce_offset_index(df, offset, join_keys)
df = self._apply_cleanup_logic(
df, offset, time_grain, join_keys, is_date_range_offset
)
return df
def _coalesce_offset_index(
self,
df: pd.DataFrame,
offset: str,
join_keys: list[str],
) -> pd.DataFrame:
"""
Rebuild the temporal x-axis after an outer join with an offset DataFrame.
Offset-only rows (those with no matching row in the main series) have a null
x-axis value because the join happens on the normalized offset join column,
not the raw temporal column. Their real timestamp lives in the suffixed
right-hand column, expressed in the offset's own time range (e.g. "yesterday
15:00"). Shifting it forward by the offset places it on the main series'
axis (e.g. "today 15:00") so the comparison line spans the full period.
"""
x_axis = join_keys[0]
offset_x_axis = f"{x_axis}{R_SUFFIX}"
if x_axis not in df.columns or offset_x_axis not in df.columns:
return df
# normalize_time_delta returns a negative delta for "... ago" offsets, so
# subtracting it shifts the historical timestamp forward onto the main axis.
try:
forward_shift = DateOffset(**normalize_time_delta(offset))
except (ValueError, TimeDeltaAmbiguousError):
return df
shifted = df[offset_x_axis] - forward_shift
df[x_axis] = df[x_axis].fillna(shifted)
return df
def add_offset_join_column(
self,
df: pd.DataFrame,

View File

@@ -221,6 +221,7 @@ class QueryObjectDict(TypedDict, total=False):
group_others_when_limit_reached: bool
to_dttm: datetime | None
time_shift: str | None
time_compare_full_range: bool
post_processing: list[dict[str, Any]]
# Additional fields used throughout the codebase

View File

@@ -194,9 +194,9 @@ def should_use_v2_api() -> bool:
except SlackApiError:
# use the v1 api but warn with a deprecation message
logger.warning(
"""Your current Slack scopes are missing `channels:read`. Please add
this to your Slack app in order to continue using the v1 API. Support
for the old Slack API will be removed in Superset version 6.0.0."""
"Your current Slack scopes are missing `channels:read`. Please add "
"this to your Slack app in order to continue using the v1 API. Support "
"for the old Slack API will be removed in Superset version 6.0.0."
)
return False

View File

@@ -135,6 +135,10 @@ ALERT_REPORTS_QUERY_EXECUTION_MAX_TRIES = 3
FAB_ADD_SECURITY_API = True
# Swagger UI / OpenAPI spec is opt-in in the base config; enable it for tests
# that exercise the /api/v1/_openapi spec endpoint.
FAB_API_SWAGGER_UI = True
class CeleryConfig:
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"

View File

@@ -53,6 +53,9 @@ _datasource._perform_join = ExploreMixin._perform_join.__get__(_datasource)
_datasource._apply_cleanup_logic = ExploreMixin._apply_cleanup_logic.__get__(
_datasource
)
_datasource._coalesce_offset_index = ExploreMixin._coalesce_offset_index.__get__(
_datasource
)
# Static methods don't need binding - assign directly
_datasource.generate_join_column = ExploreMixin.generate_join_column
_datasource.is_valid_date_range_static = ExploreMixin.is_valid_date_range_static
@@ -211,6 +214,91 @@ def test_join_offset_dfs_with_month_granularity():
assert_frame_equal(expected, result)
def test_join_offset_dfs_full_range_keeps_historical_tail():
"""
With full_range=True the offset (historical) series keeps its full time range
even when the main series ends earlier.
Simulates "today so far" (main, ends at 01:00) compared against "1 day ago"
(a complete prior day, runs to 02:00). The 02:00 historical point must survive
and be aligned onto today's axis, with the main metric left null there.
"""
# Main series: today, only two hours of data so far.
df = DataFrame(
{
"A": [Timestamp("2021-01-02 00:00"), Timestamp("2021-01-02 01:00")],
"V": [1.0, 2.0],
}
)
# Offset series: the full prior day (already renamed metric column "B").
offset_df = DataFrame(
{
"A": [
Timestamp("2021-01-01 00:00"),
Timestamp("2021-01-01 01:00"),
Timestamp("2021-01-01 02:00"),
],
"B": [10.0, 20.0, 30.0],
}
)
offset_dfs = {"1 day ago": offset_df}
time_grain = TimeGrain.HOUR
join_keys = ["A"]
expected = DataFrame(
{
"A": [
Timestamp("2021-01-02 00:00"),
Timestamp("2021-01-02 01:00"),
Timestamp("2021-01-02 02:00"),
],
"V": [1.0, 2.0, None],
"B": [10.0, 20.0, 30.0],
}
)
result = query_context_processor.join_offset_dfs(
df, offset_dfs, time_grain, join_keys, full_range=True
)
assert_frame_equal(expected, result)
def test_join_offset_dfs_full_range_disabled_truncates_historical():
"""The default (full_range=False) left join drops the historical 02:00 point."""
df = DataFrame(
{
"A": [Timestamp("2021-01-02 00:00"), Timestamp("2021-01-02 01:00")],
"V": [1.0, 2.0],
}
)
offset_df = DataFrame(
{
"A": [
Timestamp("2021-01-01 00:00"),
Timestamp("2021-01-01 01:00"),
Timestamp("2021-01-01 02:00"),
],
"B": [10.0, 20.0, 30.0],
}
)
offset_dfs = {"1 day ago": offset_df}
expected = DataFrame(
{
"A": [Timestamp("2021-01-02 00:00"), Timestamp("2021-01-02 01:00")],
"V": [1.0, 2.0],
"B": [10.0, 20.0],
}
)
result = query_context_processor.join_offset_dfs(
df, offset_dfs, TimeGrain.HOUR, ["A"], full_range=False
)
assert_frame_equal(expected, result)
def test_join_offset_dfs_totals_query_no_dimensions():
"""
Test time offset join for totals query with no dimension columns.

View File

@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for environment-driven Swagger UI config defaults.
``superset.config`` is imported in a fresh subprocess for each case so the
module is evaluated under a controlled environment without reloading (and
mutating) the config module shared by the rest of the test session.
"""
import os
import subprocess
import sys
import pytest
def _resolve_swagger_default(env_value: str | None) -> str:
"""Resolve ``FAB_API_SWAGGER_UI`` for a given ``SUPERSET_ENABLE_SWAGGER_UI``.
Evaluates ``superset.config`` in a fresh subprocess under a controlled
environment so the result reflects only the supplied env var. Config-path
overrides are stripped so a local ``superset_config`` cannot taint the
default, and only the final stdout line is read in case config loading
emits banner output.
"""
env = dict(os.environ)
for var in (
"SUPERSET_ENABLE_SWAGGER_UI",
"SUPERSET_CONFIG_PATH",
"SUPERSET_CONFIG",
):
env.pop(var, None)
if env_value is not None:
env["SUPERSET_ENABLE_SWAGGER_UI"] = env_value
result = subprocess.run( # noqa: S603
[
sys.executable,
"-c",
"import superset.config as c; print(c.FAB_API_SWAGGER_UI)",
],
env=env,
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip().splitlines()[-1]
@pytest.mark.parametrize(
"env_value, expected",
[
(None, "False"), # unset -> off by default
("true", "True"),
("True", "True"),
("false", "False"),
("", "False"),
],
)
def test_fab_api_swagger_ui_is_env_driven_and_off_by_default(
env_value: str | None, expected: str
) -> None:
"""Swagger UI defaults to off and follows ``SUPERSET_ENABLE_SWAGGER_UI``."""
assert _resolve_swagger_default(env_value) == expected

View File

@@ -17,23 +17,14 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from datetime import datetime
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
from typing import Optional
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine.url import make_url
from superset.db_engine_specs.base import OAuth2State
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.superset_typing import OAuth2ClientConfig
from superset.utils import json
from superset.utils.oauth2 import decode_oauth2_state
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@@ -300,541 +291,3 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
"USE CATALOG `evil`` USE CATALOG bad`",
"USE SCHEMA `evil`` USE SCHEMA bad`",
]
# OAuth2 Tests
def test_oauth2_attributes() -> None:
"""
Test that OAuth2 attributes are properly set for both engine specs.
"""
# Test DatabricksNativeEngineSpec
assert DatabricksNativeEngineSpec.supports_oauth2 is True
assert DatabricksNativeEngineSpec.oauth2_exception is OAuth2RedirectError
assert DatabricksNativeEngineSpec.oauth2_scope == "sql"
# OAuth2 endpoints are now dynamic and set at runtime
assert DatabricksNativeEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksNativeEngineSpec.oauth2_token_request_uri == ""
# Test DatabricksPythonConnectorEngineSpec
assert DatabricksPythonConnectorEngineSpec.supports_oauth2 is True
assert DatabricksPythonConnectorEngineSpec.oauth2_exception is OAuth2RedirectError
assert DatabricksPythonConnectorEngineSpec.oauth2_scope == "sql"
# OAuth2 endpoints are now dynamic and set at runtime
assert DatabricksPythonConnectorEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksPythonConnectorEngineSpec.oauth2_token_request_uri == ""
def test_impersonate_user_with_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method with OAuth2 token for DatabricksNativeEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
def test_impersonate_user_without_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method without OAuth2 token.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test without user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token=None,
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that nothing was changed
assert url.password == "original-token" # noqa: S105
assert kwargs["connect_args"]["access_token"] == "original-token" # noqa: S105
def test_impersonate_user_python_connector(mocker: MockerFixture) -> None:
"""
Test impersonate_user method for DatabricksPythonConnectorEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks://token:original-token@host:443?http_path=path&catalog=main&schema=default"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksPythonConnectorEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
@pytest.fixture
def oauth2_config_native() -> OAuth2ClientConfig:
"""
Config for Databricks Native OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
@pytest.fixture
def oauth2_config_python() -> OAuth2ClientConfig:
"""
Config for Databricks Python Connector OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
def test_is_oauth2_enabled_no_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks (legacy)": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is True
def test_is_oauth2_enabled_no_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is True
def test_get_oauth2_authorization_uri_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Native engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
oauth2_config_native, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_authorization_uri_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Python Connector engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
oauth2_config_python, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_token(
oauth2_config_native, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_token(
oauth2_config_python, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_fresh_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_fresh_token(
oauth2_config_native, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)
def _oauth2_state() -> OAuth2State:
"""
Build the default OAuth2 state shared by the OAuth2 tests.
"""
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
return state
def _unresolved_oauth2_config() -> OAuth2ClientConfig:
"""
Config as built by `get_oauth2_config` when no endpoints are overridden:
the URIs default to the spec's empty class attributes.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"extra, netloc, path",
[
(
{"cloud_provider": "aws", "account_id": "acct-999"},
"accounts.cloud.databricks.com",
"/oidc/accounts/acct-999/v1/authorize",
),
(
{"cloud_provider": "azure", "tenant_id": "tenant-abc"},
"login.microsoftonline.com",
"/tenant-abc/oauth2/v2.0/authorize",
),
(
{"cloud_provider": "gcp", "account_id": "acct-gcp"},
"accounts.gcp.databricks.com",
"/oidc/accounts/acct-gcp/v1/authorize",
),
],
)
def test_get_oauth2_authorization_uri_autodetects_and_resolves(
mocker: MockerFixture,
spec: Any,
extra: dict[str, Any],
netloc: str,
path: str,
) -> None:
"""
With no configured `authorization_request_uri`, the endpoint is auto-detected
per cloud provider and the `account_id`/`tenant_id` placeholder is resolved
from the database `extra`.
"""
database = mocker.MagicMock()
database.extra = json.dumps(extra)
database.url_object.host = "dbc-abc.cloud.databricks.com"
mocker.patch("superset.db.session.get", return_value=database)
url = spec.get_oauth2_authorization_uri(
_unresolved_oauth2_config(), _oauth2_state()
)
parsed = urlparse(url)
assert parsed.netloc == netloc
assert parsed.path == path
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_preserves_configured(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
A fully-resolved `authorization_request_uri` is never overwritten by the
auto-detected template, and no database lookup is needed.
"""
session_get = mocker.patch("superset.db.session.get")
config = _unresolved_oauth2_config()
config["authorization_request_uri"] = (
"https://accounts.cloud.databricks.com/oidc/accounts/override/v1/authorize"
)
url = spec.get_oauth2_authorization_uri(config, _oauth2_state())
assert urlparse(url).path == "/oidc/accounts/override/v1/authorize"
session_get.assert_not_called()
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_fails_without_account_id(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
When the endpoint must be auto-detected but no `account_id`/`tenant_id` is
available, fail fast instead of emitting an unresolved `.../{}/...` URL.
"""
database = mocker.MagicMock()
database.extra = json.dumps({"cloud_provider": "aws"})
database.url_object.host = "dbc-abc.cloud.databricks.com"
mocker.patch("superset.db.session.get", return_value=database)
with pytest.raises(OAuth2Error):
spec.get_oauth2_authorization_uri(_unresolved_oauth2_config(), _oauth2_state())
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_token_fails_without_uri(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
Token exchange has no database context to auto-detect the endpoint, so a
missing `token_request_uri` fails fast rather than POSTing to `.../{}/...`.
"""
with pytest.raises(OAuth2Error):
spec.get_oauth2_token(_unresolved_oauth2_config(), "authorization-code")
def test_get_oauth2_fresh_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_fresh_token(
oauth2_config_python, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)

View File

@@ -1,327 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from unittest.mock import MagicMock
from urllib.parse import parse_qs, urlparse
import pytest
from pytest_mock import MockerFixture
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.superset_typing import OAuth2ClientConfig
from superset.utils.oauth2 import decode_oauth2_state
# Multi-Cloud Provider Tests
@pytest.fixture
def mock_database_aws(mocker: MockerFixture) -> MagicMock:
"""
Mock database with AWS hostname.
"""
database = mocker.MagicMock()
database.url_object.host = "my-cluster.cloud.databricks.com"
database.extra = "{}"
database.id = 1
return database
@pytest.fixture
def mock_database_azure(mocker: MockerFixture) -> MagicMock:
"""
Mock database with Azure hostname.
"""
database = mocker.MagicMock()
database.url_object.host = "adb-123456789.12.azuredatabricks.net"
database.extra = '{"tenant_id": "azure-tenant-id"}'
database.id = 2
return database
@pytest.fixture
def mock_database_gcp(mocker: MockerFixture) -> MagicMock:
"""
Mock database with GCP hostname.
"""
database = mocker.MagicMock()
database.url_object.host = "123456789.gcp.databricks.com"
database.extra = '{"account_id": "12345"}'
database.id = 3
return database
@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
"""
Config for Databricks OAuth2 with a fully-resolved endpoint configured.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
@pytest.fixture
def oauth2_config_no_uri() -> OAuth2ClientConfig:
"""
Config for Databricks OAuth2 without a pre-configured endpoint, so the
per-provider endpoint is auto-detected and account-resolved.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
def test_cloud_provider_detection_aws(mock_database_aws: MagicMock) -> None:
"""
Test cloud provider detection for AWS.
"""
provider = DatabricksNativeEngineSpec._detect_cloud_provider(mock_database_aws)
assert provider == "aws"
def test_cloud_provider_detection_azure(mock_database_azure: MagicMock) -> None:
"""
Test cloud provider detection for Azure.
"""
provider = DatabricksNativeEngineSpec._detect_cloud_provider(mock_database_azure)
assert provider == "azure"
def test_cloud_provider_detection_gcp(mock_database_gcp: MagicMock) -> None:
"""
Test cloud provider detection for GCP.
"""
provider = DatabricksNativeEngineSpec._detect_cloud_provider(mock_database_gcp)
assert provider == "gcp"
def test_cloud_provider_detection_explicit_config(mocker: MockerFixture) -> None:
"""
Test cloud provider detection with explicit configuration.
"""
database = mocker.MagicMock()
database.url_object.host = "generic-host.com"
# Mock get_extra_params to return explicit cloud provider
mocker.patch.object(
DatabricksNativeEngineSpec,
"get_extra_params",
return_value={"cloud_provider": "azure"},
)
provider = DatabricksNativeEngineSpec._detect_cloud_provider(database)
assert provider == "azure"
def test_cloud_provider_detection_invalid_config_falls_back_to_hostname(
mocker: MockerFixture,
) -> None:
"""
An unrecognized explicit `cloud_provider` is ignored and detection falls
back to the hostname rather than raising or returning the bad value.
"""
database = mocker.MagicMock()
database.url_object.host = "adb-123456789.12.azuredatabricks.net"
mocker.patch.object(
DatabricksNativeEngineSpec,
"get_extra_params",
return_value={"cloud_provider": "oracle"},
)
provider = DatabricksNativeEngineSpec._detect_cloud_provider(database)
assert provider == "azure"
def test_cloud_provider_detection_non_string_falls_back_to_hostname(
mocker: MockerFixture,
) -> None:
"""
A non-string `cloud_provider` (e.g. a boolean from malformed JSON) is
ignored without raising and detection falls back to the hostname.
"""
database = mocker.MagicMock()
database.url_object.host = "adb-123456789.12.azuredatabricks.net"
mocker.patch.object(
DatabricksNativeEngineSpec,
"get_extra_params",
return_value={"cloud_provider": True},
)
provider = DatabricksNativeEngineSpec._detect_cloud_provider(database)
assert provider == "azure"
def test_get_oauth2_authorization_uri_aws(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,
mock_database_aws: MagicMock,
) -> None:
"""
Test OAuth2 authorization URI generation for AWS provider.
"""
from superset.db_engine_specs.base import OAuth2State
# Mock the database query
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_aws)
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(oauth2_config, state)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert "/oidc/accounts/" in parsed.path
assert "/v1/authorize" in parsed.path
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_authorization_uri_azure(
mocker: MockerFixture,
oauth2_config_no_uri: OAuth2ClientConfig,
mock_database_azure: MagicMock,
) -> None:
"""
Test OAuth2 authorization URI generation for Azure provider.
"""
from superset.db_engine_specs.base import OAuth2State
# Mock the database query
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_azure)
state: OAuth2State = {
"database_id": 2,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
oauth2_config_no_uri, state
)
parsed = urlparse(url)
assert parsed.netloc == "login.microsoftonline.com"
assert "/oauth2/v2.0/authorize" in parsed.path
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_authorization_uri_gcp(
mocker: MockerFixture,
oauth2_config_no_uri: OAuth2ClientConfig,
mock_database_gcp: MagicMock,
) -> None:
"""
Test OAuth2 authorization URI generation for GCP provider.
"""
from superset.db_engine_specs.base import OAuth2State
# Mock the database query
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_gcp)
state: OAuth2State = {
"database_id": 3,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
oauth2_config_no_uri, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.gcp.databricks.com"
assert "/oidc/accounts/" in parsed.path
assert "/v1/authorize" in parsed.path
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_python_connector_cloud_provider_detection_azure(
mock_database_azure: MagicMock,
) -> None:
"""
Test cloud provider detection for Python Connector with Azure.
"""
provider = DatabricksPythonConnectorEngineSpec._detect_cloud_provider(
mock_database_azure
)
assert provider == "azure"
def test_python_connector_oauth2_authorization_uri_azure(
mocker: MockerFixture,
oauth2_config_no_uri: OAuth2ClientConfig,
mock_database_azure: MagicMock,
) -> None:
"""
Test OAuth2 authorization URI generation for Python Connector with Azure provider.
"""
from superset.db_engine_specs.base import OAuth2State
# Mock the database query
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_azure)
state: OAuth2State = {
"database_id": 2,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
oauth2_config_no_uri, state
)
parsed = urlparse(url)
assert parsed.netloc == "login.microsoftonline.com"
assert "/oauth2/v2.0/authorize" in parsed.path
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state

View File

@@ -30,10 +30,12 @@ from pydantic import ValidationError
from superset.mcp_service.dashboard.schemas import (
_extract_cross_filters_enabled,
_extract_native_filters,
_safe_user_label,
dashboard_serializer,
GenerateDashboardRequest,
serialize_chart_summary,
serialize_dashboard_object,
UpdateDashboardRequest,
)
from superset.mcp_service.utils.sanitization import (
LLM_CONTEXT_CLOSE_DELIMITER,
@@ -87,8 +89,8 @@ def _mock_dashboard(
class TestSerializeDashboardObject:
"""Tests for serialize_dashboard_object slug handling."""
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_slug_none_returns_empty_string(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_slug_none_returns_empty_string(self, mock_base_url) -> None:
"""Dashboards with slug=None should return slug="" for consistency
with dashboard_serializer."""
mock_base_url.return_value = "http://localhost:8088"
@@ -98,8 +100,8 @@ class TestSerializeDashboardObject:
assert result.slug == ""
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_slug_empty_string_returns_empty_string(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_slug_empty_string_returns_empty_string(self, mock_base_url) -> None:
"""Dashboards with slug="" should return slug=""."""
mock_base_url.return_value = "http://localhost:8088"
@@ -108,8 +110,8 @@ class TestSerializeDashboardObject:
assert result.slug == ""
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_slug_with_value_preserved(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_slug_with_value_preserved(self, mock_base_url) -> None:
"""Dashboards with a real slug should preserve it."""
mock_base_url.return_value = "http://localhost:8088"
@@ -118,8 +120,8 @@ class TestSerializeDashboardObject:
assert result.slug == "my-dashboard"
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_url_uses_id_when_no_slug(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_url_uses_id_when_no_slug(self, mock_base_url) -> None:
"""URL should use dashboard id when slug is None."""
mock_base_url.return_value = "http://localhost:8088"
@@ -128,8 +130,8 @@ class TestSerializeDashboardObject:
assert result.url == "http://localhost:8088/superset/dashboard/42/"
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_url_uses_slug_when_available(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_url_uses_slug_when_available(self, mock_base_url) -> None:
"""URL should use slug when available."""
mock_base_url.return_value = "http://localhost:8088"
@@ -138,8 +140,8 @@ class TestSerializeDashboardObject:
assert result.url == "http://localhost:8088/superset/dashboard/my-dashboard/"
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_no_json_metadata_or_position_json_in_response(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_no_json_metadata_or_position_json_in_response(self, mock_base_url) -> None:
"""DashboardInfo should not contain json_metadata or position_json."""
mock_base_url.return_value = "http://localhost:8088"
@@ -150,7 +152,7 @@ class TestSerializeDashboardObject:
assert not hasattr(result, "position_json")
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_native_filters_extracted_from_json_metadata(
self,
mock_base_url,
@@ -196,7 +198,7 @@ class TestSerializeDashboardObject:
assert result.cross_filters_enabled is True
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_restricted_user_redacts_native_filter_targets(
self,
mock_base_url,
@@ -230,7 +232,7 @@ class TestSerializeDashboardObject:
assert result.cross_filters_enabled is True
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_chart_summaries_are_lightweight(
self,
mock_base_url,
@@ -262,7 +264,7 @@ class TestSerializeDashboardObject:
assert not hasattr(result.charts[0], "owners")
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_restricted_user_redacts_chart_datasource_name(
self,
mock_base_url,
@@ -288,7 +290,7 @@ class TestSerializeDashboardObject:
assert result.charts[0].url == "http://localhost:8088/explore/?slice_id=5"
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_dashboard_serializer_restricted_user_redacts_data_model_metadata(
self,
mock_base_url,
@@ -327,7 +329,7 @@ class TestSerializeDashboardObject:
assert result.native_filters[0].targets == []
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_descriptive_fields_are_sanitized(
self,
mock_base_url: MagicMock,
@@ -393,22 +395,22 @@ class TestSerializeDashboardObject:
class TestExtractNativeFilters:
"""Tests for _extract_native_filters helper."""
def test_none_input(self):
def test_none_input(self) -> None:
assert _extract_native_filters(None) == []
def test_empty_string(self):
def test_empty_string(self) -> None:
assert _extract_native_filters("") == []
def test_invalid_json(self):
def test_invalid_json(self) -> None:
assert _extract_native_filters("not json") == []
def test_no_filter_config(self):
def test_no_filter_config(self) -> None:
assert _extract_native_filters("{}") == []
def test_non_list_filter_config(self):
def test_non_list_filter_config(self) -> None:
assert _extract_native_filters('{"native_filter_configuration": "bad"}') == []
def test_valid_filters(self):
def test_valid_filters(self) -> None:
metadata = json_dumps(
{
"native_filter_configuration": [
@@ -428,7 +430,7 @@ class TestExtractNativeFilters:
assert result[0].filter_type == "filter_select"
assert result[0].targets == []
def test_valid_filters_include_targets_when_metadata_allowed(self):
def test_valid_filters_include_targets_when_metadata_allowed(self) -> None:
metadata = json_dumps(
{
"native_filter_configuration": [
@@ -447,14 +449,14 @@ class TestExtractNativeFilters:
)
assert result[0].targets == [{"column": {"name": "col1"}}]
def test_skips_non_dict_entries(self):
def test_skips_non_dict_entries(self) -> None:
metadata = json_dumps(
{"native_filter_configuration": [{"id": "f1", "name": "ok"}, "bad", 123]}
)
result = _extract_native_filters(metadata)
assert len(result) == 1
def test_non_dict_top_level_json(self):
def test_non_dict_top_level_json(self) -> None:
"""json_metadata that parses to a list/number should return empty."""
assert _extract_native_filters("[]") == []
assert _extract_native_filters("123") == []
@@ -464,26 +466,26 @@ class TestExtractNativeFilters:
class TestExtractCrossFiltersEnabled:
"""Tests for _extract_cross_filters_enabled helper."""
def test_none_input(self):
def test_none_input(self) -> None:
assert _extract_cross_filters_enabled(None) is None
def test_empty_json(self):
def test_empty_json(self) -> None:
assert _extract_cross_filters_enabled("{}") is None
def test_true(self):
def test_true(self) -> None:
assert _extract_cross_filters_enabled('{"cross_filters_enabled": true}') is True
def test_false(self):
def test_false(self) -> None:
assert (
_extract_cross_filters_enabled('{"cross_filters_enabled": false}') is False
)
def test_non_bool_value(self):
def test_non_bool_value(self) -> None:
assert (
_extract_cross_filters_enabled('{"cross_filters_enabled": "yes"}') is None
)
def test_non_dict_top_level_json(self):
def test_non_dict_top_level_json(self) -> None:
"""json_metadata that parses to a list/number should return None."""
assert _extract_cross_filters_enabled("[]") is None
assert _extract_cross_filters_enabled("123") is None
@@ -493,7 +495,7 @@ class TestExtractCrossFiltersEnabled:
class TestSerializeChartSummary:
"""Tests for serialize_chart_summary helper."""
def test_datasource_name_redacted_by_default(self):
def test_datasource_name_redacted_by_default(self) -> None:
chart = MagicMock()
chart.id = 5
chart.slice_name = "Revenue Chart"
@@ -510,7 +512,7 @@ class TestSerializeChartSummary:
class TestOmittedFieldsBuilder:
"""Tests for the shared OmittedFieldsBuilder utility."""
def test_builder_basic(self):
def test_builder_basic(self) -> None:
from superset.mcp_service.utils.response_utils import OmittedFieldsBuilder
result = (
@@ -525,7 +527,7 @@ class TestOmittedFieldsBuilder:
assert "meta_field" in result
assert "extracted" in result["meta_field"]
def test_builder_none_values(self):
def test_builder_none_values(self) -> None:
from superset.mcp_service.utils.response_utils import OmittedFieldsBuilder
result = (
@@ -537,8 +539,8 @@ class TestOmittedFieldsBuilder:
assert "empty" in result["empty_field"]
assert "empty" in result["also_empty"]
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_omitted_fields_in_serialized_dashboard(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_omitted_fields_in_serialized_dashboard(self, mock_base_url) -> None:
"""omitted_fields should describe what was stripped and include sizes."""
mock_base_url.return_value = "http://localhost:8088"
@@ -555,8 +557,8 @@ class TestOmittedFieldsBuilder:
assert "extracted" in result.omitted_fields["json_metadata"]
assert "layout tree" in result.omitted_fields["position_json"].lower()
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
def test_omitted_fields_with_none_values(self, mock_base_url):
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
def test_omitted_fields_with_none_values(self, mock_base_url) -> None:
"""omitted_fields should still be present when raw fields are None."""
mock_base_url.return_value = "http://localhost:8088"
@@ -625,3 +627,212 @@ class TestGenerateDashboardRequestTitleSanitization:
assert len(req.sanitization_warnings) == 1
assert "dashboard_title" in req.sanitization_warnings[0]
assert "injected" not in req.sanitization_warnings[0]
class TestGenerateDashboardRequestLayoutTheme:
"""generate_dashboard accepts optional position_json, theme overrides, CSS."""
def test_layout_theme_css_fields_default_to_none(self) -> None:
req = GenerateDashboardRequest(chart_ids=[1])
assert req.position_json is None
assert req.json_metadata_overrides is None
assert req.css is None
assert req.slug is None
def test_position_json_accepted(self) -> None:
position = {
"ROOT_ID": {"children": ["GRID_ID"], "type": "ROOT"},
"GRID_ID": {"children": ["ROW-1"], "type": "GRID"},
}
req = GenerateDashboardRequest(chart_ids=[1], position_json=position)
assert req.position_json == position
def test_json_metadata_overrides_accepted(self) -> None:
overrides = {
"label_colors": {"Electronics": "#4C78A8"},
"cross_filters_enabled": False,
}
req = GenerateDashboardRequest(chart_ids=[1], json_metadata_overrides=overrides)
assert req.json_metadata_overrides == overrides
def test_css_accepted(self) -> None:
req = GenerateDashboardRequest(
chart_ids=[1], css=".header-controls{display:none}"
)
assert req.css == ".header-controls{display:none}"
def test_slug_accepted(self) -> None:
req = GenerateDashboardRequest(chart_ids=[1], slug="my-dashboard")
assert req.slug == "my-dashboard"
def test_css_max_length_enforced(self) -> None:
with pytest.raises(ValidationError, match="at most 50000"):
GenerateDashboardRequest(chart_ids=[1], css="x" * 50001)
def test_title_alias_accepted(self) -> None:
"""``title`` is one of the AliasChoices for ``dashboard_title``
— JSON callers using either name resolve to the same field."""
req = GenerateDashboardRequest(chart_ids=[1], title="Q4 Review")
assert req.dashboard_title == "Q4 Review"
def test_name_alias_accepted(self) -> None:
"""``name`` is the third AliasChoice for ``dashboard_title``
and must resolve identically to ``title`` and ``dashboard_title``."""
req = GenerateDashboardRequest(chart_ids=[1], name="Q4 Review")
assert req.dashboard_title == "Q4 Review"
def test_published_defaults_to_false(self) -> None:
"""``published`` defaults to False — newly generated dashboards
are drafts by default to avoid accidentally publishing partial
work-in-progress to all users."""
req = GenerateDashboardRequest(chart_ids=[1])
assert req.published is False
def test_published_true_accepted(self) -> None:
"""An explicit ``published=True`` is preserved."""
req = GenerateDashboardRequest(chart_ids=[1], published=True)
assert req.published is True
class TestUpdateDashboardRequest:
"""Schema validation for update_dashboard's request."""
def test_identifier_required(self) -> None:
with pytest.raises(ValidationError, match="Field required"):
UpdateDashboardRequest()
def test_int_identifier_accepted(self) -> None:
req = UpdateDashboardRequest(identifier=42)
assert req.identifier == 42
def test_string_identifier_accepted(self) -> None:
req = UpdateDashboardRequest(identifier="my-slug")
assert req.identifier == "my-slug"
def test_all_optional_fields_default_to_none(self) -> None:
req = UpdateDashboardRequest(identifier=1)
assert req.dashboard_title is None
assert req.description is None
assert req.slug is None
assert req.published is None
assert req.position_json is None
assert req.json_metadata_overrides is None
assert req.css is None
def test_position_json_and_overrides_and_css(self) -> None:
req = UpdateDashboardRequest(
identifier=42,
position_json={"ROOT_ID": {"type": "ROOT"}},
json_metadata_overrides={"cross_filters_enabled": True},
css=".x{}",
)
assert req.position_json == {"ROOT_ID": {"type": "ROOT"}}
assert req.json_metadata_overrides == {"cross_filters_enabled": True}
assert req.css == ".x{}"
def test_title_alias_accepted(self) -> None:
"""`title` is accepted as an alias for `dashboard_title`."""
req = UpdateDashboardRequest(identifier=1, title="New Title")
assert req.dashboard_title == "New Title"
def test_name_alias_accepted(self) -> None:
"""`name` is accepted as an alias for `dashboard_title` — mirrors
``GenerateDashboardRequest``'s third AliasChoice so callers can
use the same key name on both create and update paths."""
req = UpdateDashboardRequest(identifier=1, name="New Title")
assert req.dashboard_title == "New Title"
def test_css_max_length_enforced(self) -> None:
with pytest.raises(ValidationError, match="at most 50000"):
UpdateDashboardRequest(identifier=1, css="x" * 50001)
def test_title_partial_strip_emits_warning(self) -> None:
"""Mirror of ``test_title_partial_strip_emits_warning`` on the
create path — sanitization removes the HTML, the title survives,
and a warning records that the input was altered."""
req = UpdateDashboardRequest(identifier=1, dashboard_title="Q1 <b>Review</b>")
assert req.dashboard_title == "Q1 Review"
assert len(req.sanitization_warnings) == 1
assert "dashboard_title" in req.sanitization_warnings[0]
def test_client_supplied_warnings_are_discarded(self) -> None:
"""``sanitization_warnings`` is server-only on the update path
too — caller-supplied entries are dropped so an attacker cannot
smuggle warning text through the response."""
req = UpdateDashboardRequest(
identifier=1,
dashboard_title="Clean Title",
sanitization_warnings=["<script>injected</script>"],
)
assert req.sanitization_warnings == []
def test_title_xss_only_rejected_at_schema_level(self) -> None:
"""An XSS-only title is rejected by the Pydantic validator
before the tool ever runs — matches the create path's guard."""
with pytest.raises(ValidationError, match="removed entirely"):
UpdateDashboardRequest(
identifier=1,
dashboard_title="<script>alert(1)</script>",
)
def test_title_alias_xss_rejected(self) -> None:
"""The ``title`` alias resolves to the same sanitized field, so
XSS-only input supplied via the alias must be rejected with the
same guard. Otherwise an attacker could bypass sanitization just
by choosing a different request key."""
with pytest.raises(ValidationError, match="removed entirely"):
UpdateDashboardRequest(
identifier=1,
title="<script>alert(1)</script>",
)
def test_name_alias_xss_rejected(self) -> None:
"""Same as ``test_title_alias_xss_rejected`` for the ``name``
AliasChoice — every alias funnels through the same validator,
not just the canonical field name."""
with pytest.raises(ValidationError, match="removed entirely"):
UpdateDashboardRequest(
identifier=1,
name="<script>alert(1)</script>",
)
class TestSafeUserLabel:
"""``_safe_user_label`` defensively coerces ``*_by_name`` attributes
so dashboard serialization never leaks a ``repr(user)`` or trips
Pydantic with a non-string value."""
def test_plain_string_passes_through(self) -> None:
assert _safe_user_label("alice") == "alice"
def test_empty_string_returns_none(self) -> None:
"""Empty string is collapsed to None so the response carries
an explicit "no author" signal rather than a misleading ""."""
assert _safe_user_label("") is None
def test_none_returns_none(self) -> None:
assert _safe_user_label(None) is None
def test_mock_object_returns_none(self) -> None:
"""Mocks (and anything else non-string) become None — this is
the case the helper was specifically introduced to handle."""
from unittest.mock import MagicMock
assert _safe_user_label(MagicMock()) is None
def test_user_object_returns_none(self) -> None:
"""A User instance also coerces to None rather than leaking
``repr(user)`` (which can contain memory addresses, hashes, or
internal id fields). Callers that want a user display name
should resolve it explicitly via ``created_by_name`` upstream."""
class _User:
def __repr__(self) -> str:
return "<User id=42 username='alice'>"
assert _safe_user_label(_User()) is None
def test_integer_returns_none(self) -> None:
"""Numbers and other non-string scalars also coerce to None
rather than being str-cast and silently leaking the value."""
assert _safe_user_label(42) is None

View File

@@ -206,8 +206,10 @@ class TestGenerateDashboard:
== "Analytics Dashboard"
)
assert result.structured_content["dashboard"]["chart_count"] == 2
# URL prefers the slug over the id, matching update_dashboard.
assert (
"/superset/dashboard/10/" in result.structured_content["dashboard_url"]
"/superset/dashboard/test-dashboard-10/"
in result.structured_content["dashboard_url"]
)
@patch("superset.models.dashboard.Dashboard")
@@ -461,6 +463,117 @@ class TestGenerateDashboard:
# rollback called by tool + event_logger error handling
assert mock_db_session.rollback.call_count >= 1
@patch("superset.models.dashboard.Dashboard")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_duplicate_slug_returns_actionable_error(
self,
mock_db_session,
mock_dashboard_cls,
mcp_server,
) -> None:
"""When the supplied slug collides with an existing dashboard,
``commit()`` raises ``IntegrityError``. The tool catches it,
recognises the slug-uniqueness violation, and returns a
structured error naming the offending slug so the LLM can
propose a different one — instead of the generic
"internal error" message used for other DB failures."""
from sqlalchemy.exc import IntegrityError
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_query.filter_by.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=1)]
mock_filter.first.return_value = Mock(
id=1,
username="admin",
first_name="Admin",
last_name="User",
email="admin@example.com",
active=True,
)
mock_db_session.query.return_value = mock_query
# Postgres-shaped IntegrityError naming the slug constraint.
mock_db_session.commit.side_effect = IntegrityError(
statement="INSERT INTO dashboards ...",
params={},
orig=Exception(
"duplicate key value violates unique constraint "
'"dashboards_slug_key"\n'
"DETAIL: Key (slug)=(my-slug) already exists."
),
)
mock_dashboard_cls.return_value = _mock_dashboard(id=100)
request = {
"chart_ids": [1],
"dashboard_title": "Q4 Review",
"slug": "my-slug",
}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
err = result.structured_content["error"]
assert err is not None
assert "my-slug" in err
assert "already in use" in err
assert "different slug" in err or "Choose a different" in err
assert result.structured_content["dashboard"] is None
assert mock_db_session.rollback.call_count >= 1
@patch("superset.models.dashboard.Dashboard")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_integrity_error_unrelated_to_slug(
self,
mock_db_session,
mock_dashboard_cls,
mcp_server,
) -> None:
"""An IntegrityError that is NOT about the slug (e.g. an FK
violation on chart_id) gets the generic constraint message
rather than the slug-specific one — slug detection must not
match every IntegrityError indiscriminately."""
from sqlalchemy.exc import IntegrityError
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_query.filter_by.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=1)]
mock_filter.first.return_value = Mock(
id=1,
username="admin",
first_name="Admin",
last_name="User",
email="admin@example.com",
active=True,
)
mock_db_session.query.return_value = mock_query
mock_db_session.commit.side_effect = IntegrityError(
statement="INSERT INTO dashboard_slices ...",
params={},
orig=Exception(
'violates foreign key constraint "dashboard_slices_slice_id_fkey"'
),
)
mock_dashboard_cls.return_value = _mock_dashboard(id=101)
# No slug → the slug-detection branch must not match.
request = {"chart_ids": [1], "dashboard_title": "No Slug Dashboard"}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
err = result.structured_content["error"]
assert err is not None
assert "constraint" in err
assert "slug" not in err.lower()
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@@ -575,6 +688,160 @@ class TestGenerateDashboard:
created = mock_dashboard_cls.return_value
assert created.dashboard_title == ""
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_position_json_override(
self,
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
mcp_server,
) -> None:
"""An explicit ``position_json`` replaces the auto-generated layout
in full — the tool serializes the caller's dict verbatim into the
dashboard's ``position_json`` column rather than calling the
layout-builder helper."""
from superset.utils import json
charts = [_mock_chart(id=1, slice_name="Sales")]
mock_dashboard = _mock_dashboard(id=70, title="Custom Layout Dashboard")
_setup_generate_dashboard_mocks(
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
charts,
mock_dashboard,
)
custom_layout = {
"ROOT_ID": {"type": "ROOT", "children": ["GRID_ID"]},
"GRID_ID": {"type": "GRID", "children": ["ROW-custom"]},
"ROW-custom": {
"type": "ROW",
"children": ["CHART-1"],
"meta": {"background": "BACKGROUND_TRANSPARENT"},
},
"CHART-1": {
"type": "CHART",
"children": [],
"meta": {"chartId": 1, "width": 12, "height": 100},
},
}
request = {
"chart_ids": [1],
"dashboard_title": "Custom Layout Dashboard",
"position_json": custom_layout,
}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.structured_content["error"] is None
created = mock_dashboard_cls.return_value
# position_json on the model is a JSON string; round-trip it
# and verify the caller's layout — not the auto-generated 2-col
# grid — was written.
stored = json.loads(created.position_json)
assert stored == custom_layout
# The auto-generated layout's HEADER/ROW ids wouldn't match
# `ROW-custom`; this sanity-check guards against regressions
# where the override silently merges with the default.
assert "ROW-custom" in stored
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_json_metadata_overrides_shallow_merged(
self,
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
mcp_server,
) -> None:
"""``json_metadata_overrides`` is shallow-merged on top of the
defaults: caller-supplied keys win, defaults for keys the caller
didn't touch survive."""
from superset.utils import json
charts = [_mock_chart(id=1, slice_name="Sales")]
mock_dashboard = _mock_dashboard(id=71, title="Themed Dashboard")
_setup_generate_dashboard_mocks(
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
charts,
mock_dashboard,
)
overrides = {
"label_colors": {"Electronics": "#4C78A8"},
"cross_filters_enabled": True,
}
request = {
"chart_ids": [1],
"dashboard_title": "Themed Dashboard",
"json_metadata_overrides": overrides,
}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.structured_content["error"] is None
created = mock_dashboard_cls.return_value
merged = json.loads(created.json_metadata)
# Caller overrides land verbatim
assert merged["label_colors"] == {"Electronics": "#4C78A8"}
assert merged["cross_filters_enabled"] is True
# Defaults for keys the caller did NOT supply must still be
# present — this is what makes it a *shallow merge* rather
# than a full replace.
assert "timed_refresh_immune_slices" in merged
assert "refresh_frequency" in merged
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_slug_and_css_applied(
self,
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
mcp_server,
) -> None:
"""Both ``slug`` and ``css`` land on the created dashboard model
when supplied. Verifying via the model directly — not just the
response — confirms the tool wrote the fields, not merely echoed
them back."""
charts = [_mock_chart(id=1, slice_name="Sales")]
mock_dashboard = _mock_dashboard(id=72, title="Branded Dashboard")
_setup_generate_dashboard_mocks(
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
charts,
mock_dashboard,
)
css_value = ".header-controls { display: none; }"
request = {
"chart_ids": [1],
"dashboard_title": "Branded Dashboard",
"slug": "branded-q4",
"css": css_value,
}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.structured_content["error"] is None
created = mock_dashboard_cls.return_value
assert created.slug == "branded-q4"
assert created.css == css_value
class TestAddChartToExistingDashboard:
"""Tests for add_chart_to_existing_dashboard MCP tool."""

View File

@@ -0,0 +1,311 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the update_dashboard MCP tool."""
from datetime import datetime
from unittest.mock import Mock, patch
import pytest
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.utils import json
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
"""Mock authentication for all tests in this module."""
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
with patch("superset.security_manager.raise_for_ownership"):
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
def _mock_dashboard(
id: int = 42,
title: str = "Test Dashboard",
slug: str | None = "test-slug",
published: bool = True,
css: str | None = None,
json_metadata: str | None = None,
position_json: str | None = None,
):
"""Build a Mock with EVERY field the DashboardInfo serializer touches
explicitly set. Without this, Mock returns auto-Mock objects for
unset attributes, which Pydantic rejects as wrong-type."""
dashboard = Mock()
dashboard.id = id
dashboard.dashboard_title = title
dashboard.slug = slug
dashboard.description = "desc"
dashboard.published = published
dashboard.css = css
dashboard.json_metadata = json_metadata or json.dumps({"label_colors": {}})
dashboard.position_json = position_json
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = datetime(2024, 1, 1)
dashboard.changed_on = datetime(2024, 1, 1)
dashboard.created_by = Mock(username="admin")
dashboard.changed_by = Mock(username="admin")
dashboard.created_by_name = "admin"
dashboard.changed_by_name = "admin"
dashboard.created_on_humanized = "a day ago"
dashboard.changed_on_humanized = "a day ago"
dashboard.uuid = f"dashboard-uuid-{id}"
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
return dashboard
class TestUpdateDashboard:
"""update_dashboard patches existing dashboard layout/theme/CSS."""
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
@patch("superset.extensions.db.session")
@pytest.mark.asyncio
async def test_update_layout_theme_and_css(
self, mock_session, mock_get, mcp_server
) -> None:
dash = _mock_dashboard(
id=42,
json_metadata=json.dumps(
{"label_colors": {"Old": "#000"}, "cross_filters_enabled": True}
),
)
mock_get.return_value = dash
position = {"ROOT_ID": {"type": "ROOT", "children": ["GRID_ID"]}}
overrides = {
"label_colors": {"Electronics": "#4C78A8"},
"cross_filters_enabled": False,
}
async with Client(mcp_server) as client:
result = await client.call_tool(
"update_dashboard",
{
"request": {
"identifier": 42,
"position_json": position,
"json_metadata_overrides": overrides,
"css": ".x{color:red}",
}
},
)
# All three top-level writes applied to the model
assert dash.position_json == json.dumps(position)
assert dash.css == ".x{color:red}"
# json_metadata is shallow-merged: label_colors REPLACED (top-level
# key), but other keys not in overrides preserved
merged = json.loads(dash.json_metadata)
assert merged["label_colors"] == {"Electronics": "#4C78A8"}
assert merged["cross_filters_enabled"] is False
assert mock_session.commit.call_count >= 1
# changed_fields enumerates what actually changed.
# StructuredContentStripperMiddleware strips structured_content;
# the JSON-encoded response lives in content[0].text.
payload = json.loads(result.content[0].text)
changed = set(payload.get("changed_fields") or [])
assert {"position_json", "json_metadata", "css"} <= changed
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
@patch("superset.extensions.db.session")
@pytest.mark.asyncio
async def test_update_with_no_fields_is_noop(
self, mock_session, mock_get, mcp_server
) -> None:
dash = _mock_dashboard(id=42)
original_css = dash.css
original_title = dash.dashboard_title
mock_get.return_value = dash
async with Client(mcp_server) as client:
result = await client.call_tool(
"update_dashboard", {"request": {"identifier": 42}}
)
# Nothing modified — dashboard fields unchanged, warning emitted
assert dash.css == original_css
assert dash.dashboard_title == original_title
# StructuredContentStripperMiddleware strips structured_content;
# the JSON-encoded response lives in content[0].text.
payload = json.loads(result.content[0].text)
warnings = payload.get("warnings") or []
assert any("No fields provided" in w for w in warnings)
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
@pytest.mark.asyncio
async def test_update_missing_dashboard_returns_error(
self, mock_get, mcp_server
) -> None:
# get_by_id_or_slug raises when not found; the tool catches and
# returns a structured DashboardError
from superset.commands.dashboard.exceptions import (
DashboardNotFoundError,
)
mock_get.side_effect = DashboardNotFoundError()
async with Client(mcp_server) as client:
result = await client.call_tool(
"update_dashboard", {"request": {"identifier": 999999}}
)
payload = json.loads(result.content[0].text)
assert "not found" in (payload.get("error") or "").lower()
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
@patch("superset.extensions.db.session")
@pytest.mark.asyncio
async def test_update_title_and_slug_and_published(
self, mock_session, mock_get, mcp_server
) -> None:
dash = _mock_dashboard(id=42, published=False)
mock_get.return_value = dash
async with Client(mcp_server) as client:
await client.call_tool(
"update_dashboard",
{
"request": {
"identifier": 42,
"dashboard_title": "Renamed",
"slug": "renamed-slug",
"published": True,
}
},
)
assert dash.dashboard_title == "Renamed"
assert dash.slug == "renamed-slug"
assert dash.published is True
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
@patch("superset.extensions.db.session")
@pytest.mark.asyncio
async def test_update_description(self, mock_session, mock_get, mcp_server) -> None:
"""A description-only update writes ``description`` and reports
it in ``changed_fields`` without touching other fields."""
dash = _mock_dashboard(id=42)
original_title = dash.dashboard_title
original_slug = dash.slug
mock_get.return_value = dash
async with Client(mcp_server) as client:
result = await client.call_tool(
"update_dashboard",
{
"request": {
"identifier": 42,
"description": "Q4 executive review — refreshed weekly.",
}
},
)
assert dash.description == "Q4 executive review — refreshed weekly."
# Other fields untouched
assert dash.dashboard_title == original_title
assert dash.slug == original_slug
payload = json.loads(result.content[0].text)
assert "description" in payload.get("changed_fields", [])
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
@patch("superset.extensions.db.session")
@pytest.mark.asyncio
async def test_empty_slug_clears_slug(
self, mock_session, mock_get, mcp_server
) -> None:
"""An explicit empty string clears the slug."""
dash = _mock_dashboard(id=42, slug="had-a-slug")
mock_get.return_value = dash
async with Client(mcp_server) as client:
await client.call_tool(
"update_dashboard",
{"request": {"identifier": 42, "slug": ""}},
)
assert dash.slug is None
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
@pytest.mark.asyncio
async def test_non_owner_gets_permission_denied(self, mock_get, mcp_server) -> None:
"""A user without ownership on the dashboard receives a
permission_denied response — the class-level Dashboard.write
permission is not enough on its own.
"""
from superset.exceptions import SupersetSecurityException
dash = _mock_dashboard(id=42)
mock_get.return_value = dash
# mock_auth fixture patches raise_for_ownership to a no-op for the
# whole module; override here so this one test sees the real
# ownership rejection path.
with patch(
"superset.security_manager.raise_for_ownership",
side_effect=SupersetSecurityException(Mock(message="forbidden")),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"update_dashboard",
{
"request": {
"identifier": 42,
"dashboard_title": "Hostile rename",
}
},
)
payload = json.loads(result.content[0].text)
assert payload.get("permission_denied") is True
assert "permission" in (payload.get("error") or "").lower()
# Title was NOT applied
assert dash.dashboard_title == "Test Dashboard"
@pytest.mark.asyncio
async def test_xss_only_title_is_rejected(self, mcp_server) -> None:
"""A dashboard_title that sanitizes to an empty string raises at
the Pydantic layer — same guard as ``generate_dashboard``. The
update path must not be a backdoor for XSS payloads."""
from fastmcp.exceptions import ToolError
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="dashboard_title"):
await client.call_tool(
"update_dashboard",
{
"request": {
"identifier": 42,
"dashboard_title": "<script>alert(1)</script>",
}
},
)

View File

@@ -59,6 +59,7 @@ def test_default_query_object_to_dict():
"series_limit": 0,
"series_limit_metric": None,
"time_shift": None,
"time_compare_full_range": False,
"to_dttm": None,
}

View File

@@ -25,7 +25,9 @@ from superset.security.api import RlsRuleSchema
@pytest.mark.parametrize(
"app",
[{"WTF_CSRF_ENABLED": True}],
# Enable the Swagger UI / OpenAPI spec (opt-in, off by default) so the
# OpenApi blueprint is registered and included in the exempt set below.
[{"WTF_CSRF_ENABLED": True, "FAB_API_SWAGGER_UI": True}],
indirect=True,
)
def test_csrf_exempt_blueprints(app_context: None) -> None: