mirror of
https://github.com/apache/superset.git
synced 2026-06-26 09:59:21 +00:00
Compare commits
8 Commits
adopt/data
...
move-datab
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e49fd50489 | ||
|
|
f012128700 | ||
|
|
d8bcc66472 | ||
|
|
4b9b8187b3 | ||
|
|
83f7dc9d5b | ||
|
|
baca76ebe0 | ||
|
|
9a11c15a33 | ||
|
|
a90c8e0347 |
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
@@ -38,7 +38,7 @@
|
||||
|
||||
# Notify translation maintainers of changes to translations
|
||||
|
||||
/superset/translations/ @sfirke @rusackas
|
||||
/superset/translations/ @sfirke @rusackas @villebro @sadpandajoe @hainenber
|
||||
|
||||
# Notify PMC members of changes to extension-related files
|
||||
|
||||
|
||||
@@ -32,6 +32,10 @@ The `rls` rules passed to `POST /api/v1/security/guest_token/` are now validated
|
||||
|
||||
When the MCP service has JWT auth enabled (`MCP_AUTH_ENABLED = True`), an audience must be configured via `MCP_JWT_AUDIENCE` so issued tokens are bound to this service. The service now fails to start with a clear configuration error when the audience is unset, instead of starting with audience validation skipped. Deployments that enable MCP JWT auth must set `MCP_JWT_AUDIENCE` to the audience value their identity provider issues for the MCP service. API-key-only MCP deployments (JWT auth disabled) are unaffected.
|
||||
|
||||
### Swagger UI is opt-in (off by default)
|
||||
|
||||
`FAB_API_SWAGGER_UI` now defaults to `False` and is driven by the `SUPERSET_ENABLE_SWAGGER_UI` environment variable. The interactive Swagger UI / OpenAPI documentation endpoints (e.g. `/swagger/v1`) are therefore no longer exposed by default. To enable them, set `SUPERSET_ENABLE_SWAGGER_UI=true` (the bundled Docker development environment sets this) or override `FAB_API_SWAGGER_UI = True` in `superset_config.py`.
|
||||
|
||||
### Pivot table First/Last aggregations follow data order
|
||||
|
||||
The pivot table chart's `First` and `Last` aggregations now return the first and last value in data (query result) order, instead of effectively returning the minimum and maximum. Existing pivot tables that use these aggregations for totals/subtotals may show different values after upgrading. For deterministic results, ensure the underlying query has a stable sort order.
|
||||
|
||||
@@ -70,6 +70,8 @@ SUPERSET_LOG_LEVEL=info
|
||||
|
||||
SUPERSET_APP_ROOT="/"
|
||||
SUPERSET_ENV=development
|
||||
# Swagger UI is opt-in (off by default); enable it for local development.
|
||||
SUPERSET_ENABLE_SWAGGER_UI=true
|
||||
SUPERSET_LOAD_EXAMPLES=yes
|
||||
CYPRESS_CONFIG=false
|
||||
SUPERSET_PORT=8088
|
||||
|
||||
@@ -34,15 +34,14 @@ Frontend contribution types allow extensions to extend Superset's user interface
|
||||
|
||||
Extensions can add new views or panels to the host application, such as custom SQL Lab panels, dashboards, or other UI components. Contribution areas are uniquely identified (e.g., `sqllab.panels` for SQL Lab panels), enabling seamless integration into specific parts of the application.
|
||||
|
||||
```tsx
|
||||
import React from 'react';
|
||||
```typescript
|
||||
import { views } from '@apache-superset/core';
|
||||
import MyPanel from './MyPanel';
|
||||
|
||||
views.registerView(
|
||||
{ id: 'my-extension.main', name: 'My Panel Name' },
|
||||
'sqllab.panels',
|
||||
() => <MyPanel />,
|
||||
MyPanel,
|
||||
);
|
||||
```
|
||||
|
||||
@@ -112,6 +111,24 @@ editors.registerEditor(
|
||||
|
||||
See [Editors Extension Point](./extension-points/editors.md) for implementation details.
|
||||
|
||||
### Chat
|
||||
|
||||
Extensions can add a chat interface to Superset by registering a trigger component and a panel component. The host owns the layout, open/close state, and display mode — the extension only provides the UI. The panel can be displayed as a floating overlay or docked as a resizable sidebar beside the page content, and the user's preference is persisted across reloads.
|
||||
|
||||
```tsx
|
||||
import { chat } from '@apache-superset/core';
|
||||
import ChatTrigger from './ChatTrigger';
|
||||
import ChatPanel from './ChatPanel';
|
||||
|
||||
chat.registerChat(
|
||||
{ id: 'my-org.my-chat', name: 'My Chat' },
|
||||
ChatTrigger,
|
||||
ChatPanel,
|
||||
);
|
||||
```
|
||||
|
||||
See [Chat](./extension-points/chat.md) for implementation details.
|
||||
|
||||
## Backend
|
||||
|
||||
Backend contribution types allow extensions to extend Superset's server-side capabilities. Backend contributions are registered at startup via classes and functions imported from the auto-discovered `entrypoint.py` file.
|
||||
|
||||
141
docs/developer_docs/extensions/extension-points/chat.md
Normal file
141
docs/developer_docs/extensions/extension-points/chat.md
Normal file
@@ -0,0 +1,141 @@
|
||||
---
|
||||
title: Chat
|
||||
sidebar_position: 3
|
||||
---
|
||||
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
# Chat Contributions
|
||||
|
||||
Extensions can add a chat interface to Superset by registering a trigger and a panel. The host owns the layout, open/close state, and display mode — the extension only needs to provide the UI components.
|
||||
|
||||
## Overview
|
||||
|
||||
A chat registration consists of two React components:
|
||||
|
||||
| Component | Role |
|
||||
|-----------|------|
|
||||
| **Trigger** | Always-visible entry point (e.g., a floating button). Rendered in the bottom-right corner in floating mode, or as a fixed overlay in panel mode. |
|
||||
| **Panel** | The chat UI itself (message list, input, etc.). Mounted by the host in the active display mode. |
|
||||
|
||||
## Display Modes
|
||||
|
||||
The host supports two display modes, switchable by the user or the extension at runtime:
|
||||
|
||||
| Mode | Behavior |
|
||||
|------|----------|
|
||||
| `floating` | Panel floats above page content, anchored to the bottom-right corner. |
|
||||
| `panel` | Panel is docked to the right side of the application as a resizable sidebar, sitting beside the page content. |
|
||||
|
||||
The user's last selected mode and open/closed state are persisted across page reloads.
|
||||
|
||||
## Registering a Chat
|
||||
|
||||
Call `chat.registerChat` from your extension's entry point with a descriptor, a trigger factory, and a panel factory:
|
||||
|
||||
```tsx
|
||||
import { chat } from '@apache-superset/core';
|
||||
import ChatTrigger from './ChatTrigger';
|
||||
import ChatPanel from './ChatPanel';
|
||||
|
||||
chat.registerChat(
|
||||
{ id: 'my-org.my-chat', name: 'My Chat' },
|
||||
ChatTrigger,
|
||||
ChatPanel,
|
||||
);
|
||||
```
|
||||
|
||||
Only one chat registration is active at a time. If a second extension calls `registerChat`, it replaces the first and a warning is logged.
|
||||
|
||||
## Opening and Closing the Chat
|
||||
|
||||
The trigger component is responsible for toggling the panel. Use `chat.isOpen()`, `chat.open()`, and `chat.close()` to control visibility:
|
||||
|
||||
```tsx
|
||||
import { chat } from '@apache-superset/core';
|
||||
|
||||
export default function ChatTrigger() {
|
||||
return (
|
||||
<button onClick={() => (chat.isOpen() ? chat.close() : chat.open())}>
|
||||
💬
|
||||
</button>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
You can also subscribe to open/close events from any component:
|
||||
|
||||
```tsx
|
||||
useEffect(() => {
|
||||
const { dispose } = chat.onDidOpen(() => console.log('chat opened'));
|
||||
return dispose;
|
||||
}, []);
|
||||
```
|
||||
|
||||
## Changing the Display Mode
|
||||
|
||||
Call `chat.setDisplayMode` to switch between `'floating'` and `'panel'` modes. In your panel component, subscribe to `onDidChangeDisplayMode` to react to changes (including those triggered by the user):
|
||||
|
||||
```tsx
|
||||
import { useState, useEffect } from 'react';
|
||||
import { chat } from '@apache-superset/core';
|
||||
|
||||
export default function ChatPanel() {
|
||||
const [mode, setMode] = useState(chat.getDisplayMode());
|
||||
|
||||
useEffect(() => {
|
||||
const { dispose } = chat.onDidChangeDisplayMode(m => setMode(m));
|
||||
return dispose;
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div style={{ height: mode === 'panel' ? '100%' : '80vh' }}>
|
||||
<button onClick={() => chat.setDisplayMode(mode === 'panel' ? 'floating' : 'panel')}>
|
||||
{mode === 'panel' ? 'Float' : 'Dock'}
|
||||
</button>
|
||||
{/* message list and input */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
## Chat API Reference
|
||||
|
||||
All methods are available on the `chat` namespace from `@apache-superset/core`:
|
||||
|
||||
| Method / Event | Description |
|
||||
|----------------|-------------|
|
||||
| `registerChat(descriptor, trigger, panel)` | Register a chat extension. Returns a `Disposable` to unregister. |
|
||||
| `open()` | Open the chat panel. No-op if already open or no registration. |
|
||||
| `close()` | Close the chat panel. |
|
||||
| `isOpen()` | Returns `true` if the panel is currently open. |
|
||||
| `getDisplayMode()` | Returns the current display mode (`'floating'` or `'panel'`). |
|
||||
| `setDisplayMode(mode)` | Switch between `'floating'` and `'panel'` mode. |
|
||||
| `onDidOpen(listener)` | Subscribe to panel open events. Returns a `Disposable`. |
|
||||
| `onDidClose(listener)` | Subscribe to panel close events. Returns a `Disposable`. |
|
||||
| `onDidChangeDisplayMode(listener)` | Subscribe to display mode changes. Returns a `Disposable`. |
|
||||
| `onDidRegisterChat(listener)` | Subscribe to registration events. |
|
||||
| `onDidUnregisterChat(listener)` | Subscribe to unregistration events. |
|
||||
| `onDidResizePanel(listener)` | Subscribe to panel resize events (panel mode only). Not all hosts provide a resizer — do not rely on this firing. Returns a `Disposable`. |
|
||||
|
||||
## Next Steps
|
||||
|
||||
- **[Contribution Types](../contribution-types.md)** — Explore other contribution types
|
||||
- **[Development](../development.md)** — Set up your development environment
|
||||
@@ -47,6 +47,8 @@ module.exports = {
|
||||
collapsed: true,
|
||||
items: [
|
||||
'extensions/extension-points/sqllab',
|
||||
'extensions/extension-points/editors',
|
||||
'extensions/extension-points/chat',
|
||||
],
|
||||
},
|
||||
'extensions/development',
|
||||
|
||||
@@ -519,104 +519,6 @@ For a connection to a SQL endpoint you need to use the HTTP path from the endpoi
|
||||
{"connect_args": {"http_path": "/sql/1.0/endpoints/****", "driver_path": "/path/to/odbc/driver"}}
|
||||
```
|
||||
|
||||
##### OAuth2 Authentication
|
||||
|
||||
Superset supports OAuth2 authentication for Databricks, allowing users to authenticate with their personal Databricks accounts instead of using shared access tokens. This provides better security and audit capabilities.
|
||||
|
||||
###### Prerequisites
|
||||
|
||||
1. Create an OAuth2 application in your Databricks account:
|
||||
- Go to your Databricks account console
|
||||
- Navigate to **Settings** → **Developer** → **OAuth apps**
|
||||
- Create a new OAuth app with the redirect URI: `http://your-superset-host:port/api/v1/database/oauth2/`
|
||||
|
||||
2. Configure OAuth2 in your `superset_config.py`:
|
||||
|
||||
```python
|
||||
from datetime import timedelta
|
||||
|
||||
# OAuth2 configuration for Databricks
|
||||
# OAuth2 endpoints are automatically detected based on your Databricks cloud provider
|
||||
DATABASE_OAUTH2_CLIENTS = {
|
||||
"Databricks (legacy)": {
|
||||
"id": "your-databricks-client-id",
|
||||
"secret": "your-databricks-client-secret",
|
||||
"scope": "sql",
|
||||
# The authorization endpoint is auto-detected from the hostname; the
|
||||
# token endpoint must be set explicitly (no DB context at exchange):
|
||||
# AWS: "authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/{account_id}/v1/authorize",
|
||||
# Azure: "authorization_request_uri": "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/authorize",
|
||||
# GCP: "authorization_request_uri": "https://accounts.gcp.databricks.com/oidc/accounts/{account_id}/v1/authorize",
|
||||
# "token_request_uri": "https://<provider-token-endpoint>",
|
||||
},
|
||||
"Databricks": {
|
||||
"id": "your-databricks-client-id",
|
||||
"secret": "your-databricks-client-secret",
|
||||
"scope": "sql",
|
||||
# Authorization endpoint auto-detected from hostname; set
|
||||
# "token_request_uri" explicitly for the token exchange.
|
||||
},
|
||||
}
|
||||
|
||||
# OAuth2 redirect URI (adjust hostname/port for your setup)
|
||||
DATABASE_OAUTH2_REDIRECT_URI = "http://your-superset-host:port/api/v1/database/oauth2/"
|
||||
|
||||
# Optional: OAuth2 timeout
|
||||
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
|
||||
```
|
||||
|
||||
Replace the following placeholders:
|
||||
- `your-databricks-client-id`: Your Databricks OAuth2 application client ID
|
||||
- `your-databricks-client-secret`: Your Databricks OAuth2 application client secret
|
||||
- `your-superset-host:port`: Your Superset instance hostname and port
|
||||
|
||||
**Multi-Cloud Provider Support**
|
||||
|
||||
Superset automatically detects your Databricks cloud provider and uses the appropriate OAuth2 endpoints:
|
||||
|
||||
- **AWS**: Detected from hostnames containing `cloud.databricks.com`
|
||||
- **Azure**: Detected from hostnames containing `azure` or `azuredatabricks`
|
||||
- **GCP**: Detected from hostnames containing `gcp` or `googleusercontent`
|
||||
|
||||
You can also explicitly specify the cloud provider, along with the account
|
||||
identifier used to build the OAuth2 endpoints, in your database configuration
|
||||
under **Advanced** → **Other** → **ENGINE PARAMETERS**:
|
||||
|
||||
```json
|
||||
{
|
||||
"cloud_provider": "azure",
|
||||
"tenant_id": "your-azure-tenant-id"
|
||||
}
|
||||
```
|
||||
|
||||
For AWS and GCP, supply `account_id` instead:
|
||||
|
||||
```json
|
||||
{
|
||||
"cloud_provider": "aws",
|
||||
"account_id": "your-databricks-account-id"
|
||||
}
|
||||
```
|
||||
|
||||
Valid cloud provider values are: `aws`, `azure`, `gcp`. The **authorization**
|
||||
endpoint is auto-detected: Superset substitutes this identifier into the
|
||||
provider's authorization template. The **token** endpoint is not auto-resolved
|
||||
(token exchange has no database context to detect the provider), so for the
|
||||
auto-detected flow you must still supply a fully-resolved `token_request_uri`
|
||||
in `DATABASE_OAUTH2_CLIENTS`. If you supply fully-resolved
|
||||
`authorization_request_uri` and `token_request_uri` values, those take
|
||||
precedence and no `account_id`/`tenant_id` is required.
|
||||
|
||||
###### Usage
|
||||
|
||||
Once configured, users can:
|
||||
|
||||
1. Connect to Databricks databases normally using access tokens
|
||||
2. When querying data, Superset will automatically redirect users to authenticate with Databricks if needed
|
||||
3. User-specific OAuth2 tokens will be used for database connections, providing better security and audit trails
|
||||
|
||||
This feature works with both "Databricks (legacy)" and "Databricks" engine types and automatically supports all major cloud providers (AWS, Azure, GCP).
|
||||
|
||||
#### Denodo
|
||||
|
||||
The recommended connector library for Denodo is
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { DATABASE_LIST } from 'cypress/utils/urls';
|
||||
|
||||
function closeModal() {
|
||||
cy.get('body').then($body => {
|
||||
if ($body.find('[data-test="database-modal"]').length) {
|
||||
cy.get('[aria-label="Close"]').eq(1).click();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
describe('Add database', () => {
|
||||
before(() => {
|
||||
cy.visit(DATABASE_LIST);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
cy.intercept('POST', '**/api/v1/database/validate_parameters/**').as(
|
||||
'validateParams',
|
||||
);
|
||||
cy.intercept('POST', '**/api/v1/database/').as('createDb');
|
||||
|
||||
closeModal();
|
||||
cy.getBySel('btn-create-database').click();
|
||||
});
|
||||
|
||||
it('should open dynamic form', () => {
|
||||
cy.get('.preferred > :nth-child(1)').click();
|
||||
|
||||
cy.get('input[name="host"]').should('have.value', '');
|
||||
cy.get('input[name="port"]').should('have.value', '');
|
||||
cy.get('input[name="database"]').should('have.value', '');
|
||||
cy.get('input[name="username"]').should('have.value', '');
|
||||
cy.get('input[name="password"]').should('have.value', '');
|
||||
cy.get('input[name="database_name"]').should('have.value', '');
|
||||
});
|
||||
|
||||
it('should open sqlalchemy form', () => {
|
||||
cy.get('.preferred > :nth-child(1)').click();
|
||||
cy.getBySel('sqla-connect-btn').click();
|
||||
|
||||
cy.getBySel('database-name-input').should('be.visible');
|
||||
cy.getBySel('sqlalchemy-uri-input').should('be.visible');
|
||||
});
|
||||
|
||||
it('show error alerts on dynamic form for bad host', () => {
|
||||
cy.get('.preferred > :nth-child(1)').click();
|
||||
|
||||
cy.get('input[name="host"]').type('badhost', { force: true });
|
||||
cy.get('input[name="port"]').type('5432', { force: true });
|
||||
cy.get('input[name="username"]').type('testusername', { force: true });
|
||||
cy.get('input[name="database"]').type('testdb', { force: true });
|
||||
cy.get('input[name="password"]').type('testpass', { force: true });
|
||||
|
||||
cy.get('body').click(0, 0);
|
||||
|
||||
cy.wait('@validateParams', { timeout: 30000 });
|
||||
|
||||
cy.getBySel('btn-submit-connection').should('not.be.disabled');
|
||||
cy.getBySel('btn-submit-connection').click({ force: true });
|
||||
|
||||
cy.wait('@validateParams', { timeout: 30000 }).then(() => {
|
||||
cy.wait('@createDb', { timeout: 60000 }).then(() => {
|
||||
cy.contains(
|
||||
'.ant-form-item-explain-error',
|
||||
"The hostname provided can't be resolved",
|
||||
).should('exist');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('show error alerts on dynamic form for bad port', () => {
|
||||
cy.get('.preferred > :nth-child(1)').click();
|
||||
|
||||
cy.get('input[name="host"]').type('localhost', { force: true });
|
||||
cy.get('body').click(0, 0);
|
||||
cy.wait('@validateParams', { timeout: 30000 });
|
||||
|
||||
cy.get('input[name="port"]').type('5430', { force: true });
|
||||
cy.get('input[name="database"]').type('testdb', { force: true });
|
||||
cy.get('input[name="username"]').type('testusername', { force: true });
|
||||
|
||||
cy.wait('@validateParams', { timeout: 30000 });
|
||||
|
||||
cy.get('input[name="password"]').type('testpass', { force: true });
|
||||
cy.wait('@validateParams');
|
||||
|
||||
cy.getBySel('btn-submit-connection').should('not.be.disabled');
|
||||
cy.getBySel('btn-submit-connection').click({ force: true });
|
||||
cy.wait('@validateParams', { timeout: 30000 }).then(() => {
|
||||
cy.get('body').click(0, 0);
|
||||
cy.getBySel('btn-submit-connection').click({ force: true });
|
||||
cy.wait('@createDb', { timeout: 60000 }).then(() => {
|
||||
cy.contains(
|
||||
'.ant-form-item-explain-error',
|
||||
'The port is closed',
|
||||
).should('exist');
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
30
superset-frontend/package-lock.json
generated
30
superset-frontend/package-lock.json
generated
@@ -26533,6 +26533,21 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/jsdom/node_modules/@noble/hashes": {
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-2.2.0.tgz",
|
||||
"integrity": "sha512-IYqDGiTXab6FniAgnSdZwgWbomxpy9FtYvLKs7wCUs2a8RkITG+DFGO1DM9cr+E3/RgADRpFjrKVaJ1z6sjtEg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">= 20.19.0"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://paulmillr.com/funding/"
|
||||
}
|
||||
},
|
||||
"node_modules/jsdom/node_modules/css-tree": {
|
||||
"version": "3.2.1",
|
||||
"resolved": "https://registry.npmjs.org/css-tree/-/css-tree-3.2.1.tgz",
|
||||
@@ -43570,6 +43585,21 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/whatwg-url/node_modules/@noble/hashes": {
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-2.2.0.tgz",
|
||||
"integrity": "sha512-IYqDGiTXab6FniAgnSdZwgWbomxpy9FtYvLKs7wCUs2a8RkITG+DFGO1DM9cr+E3/RgADRpFjrKVaJ1z6sjtEg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">= 20.19.0"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://paulmillr.com/funding/"
|
||||
}
|
||||
},
|
||||
"node_modules/whatwg-url/node_modules/webidl-conversions": {
|
||||
"version": "8.0.1",
|
||||
"resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-8.0.1.tgz",
|
||||
|
||||
@@ -18,6 +18,14 @@
|
||||
"types": "./lib/authentication/index.d.ts",
|
||||
"default": "./lib/authentication/index.js"
|
||||
},
|
||||
"./chat": {
|
||||
"types": "./lib/chat/index.d.ts",
|
||||
"default": "./lib/chat/index.js"
|
||||
},
|
||||
"./navigation": {
|
||||
"types": "./lib/navigation/index.d.ts",
|
||||
"default": "./lib/navigation/index.js"
|
||||
},
|
||||
"./commands": {
|
||||
"types": "./lib/commands/index.d.ts",
|
||||
"default": "./lib/commands/index.js"
|
||||
|
||||
156
superset-frontend/packages/superset-core/src/chat/index.ts
Normal file
156
superset-frontend/packages/superset-core/src/chat/index.ts
Normal file
@@ -0,0 +1,156 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @fileoverview Chat contribution API for Superset extensions.
|
||||
*
|
||||
* Chat is a dedicated contribution type: an extension registers
|
||||
* a chat via {@link registerChat} and the host owns where and how it is
|
||||
* mounted. The host applies singleton resolution — multiple chat extensions
|
||||
* may register, but exactly one is active at a time.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* import { chat } from '@apache-superset/core';
|
||||
*
|
||||
* chat.registerChat(
|
||||
* { id: 'acme.chat', name: 'Acme Chat' },
|
||||
* AcmeTrigger,
|
||||
* AcmePanel,
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
|
||||
import { ComponentType } from 'react';
|
||||
import type { Disposable, Event } from '../common';
|
||||
|
||||
export interface Chat {
|
||||
/** The unique identifier for the chat. */
|
||||
id: string;
|
||||
/** The display name of the chat. */
|
||||
name: string;
|
||||
/** Optional description of the chat. */
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export type DisplayMode = 'floating' | 'panel';
|
||||
|
||||
/**
|
||||
* Registers a chat provider. Only one chat is active at a time; the most
|
||||
* recently registered chat wins. Disposing the returned Disposable unregisters
|
||||
* the chat.
|
||||
*
|
||||
* @param chat The chat descriptor (id, name).
|
||||
* @param trigger The trigger component — the collapsed bubble entry point.
|
||||
* Owns dynamic state such as unread counts.
|
||||
* @param panel The panel component, rendered in either display mode. In
|
||||
* 'floating' mode it appears as an overlay; in 'panel' mode it is docked
|
||||
* alongside the main content.
|
||||
* @returns A Disposable that unregisters the chat when disposed.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* chat.registerChat(
|
||||
* { id: 'acme.chat', name: 'Acme Chat' },
|
||||
* AcmeTrigger,
|
||||
* AcmePanel,
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
export declare function registerChat(
|
||||
chat: Chat,
|
||||
trigger: ComponentType,
|
||||
panel: ComponentType,
|
||||
): Disposable;
|
||||
|
||||
/**
|
||||
* Returns the active chat descriptor, or undefined if none is registered.
|
||||
*/
|
||||
export declare function getChat(): Chat | undefined;
|
||||
|
||||
/**
|
||||
* Event fired when a chat is registered.
|
||||
*/
|
||||
export declare const onDidRegisterChat: Event<Chat>;
|
||||
|
||||
/**
|
||||
* Event fired when a chat is unregistered.
|
||||
*/
|
||||
export declare const onDidUnregisterChat: Event<Chat>;
|
||||
|
||||
/**
|
||||
* Opens the active chat's panel.
|
||||
*
|
||||
* Acts on whichever chat is active, regardless of which extension calls it.
|
||||
* No-op when no chat is registered or the panel is already open.
|
||||
*/
|
||||
export declare function open(): void;
|
||||
|
||||
/**
|
||||
* Closes the active chat's panel.
|
||||
*
|
||||
* Acts on whichever chat is active, regardless of which extension calls it.
|
||||
* No-op when the panel is not open.
|
||||
*/
|
||||
export declare function close(): void;
|
||||
|
||||
/**
|
||||
* Returns whether the active chat's panel is currently open.
|
||||
*/
|
||||
export declare function isOpen(): boolean;
|
||||
|
||||
/**
|
||||
* Event fired when the chat panel opens. Also fired by the host's own
|
||||
* controls, not only by an extension's open() call.
|
||||
*/
|
||||
export declare const onDidOpen: Event<void>;
|
||||
|
||||
/**
|
||||
* Event fired when the chat panel closes, whether triggered by an extension
|
||||
* or by the host.
|
||||
*/
|
||||
export declare const onDidClose: Event<void>;
|
||||
|
||||
/**
|
||||
* Returns the current display mode.
|
||||
*/
|
||||
export declare function getDisplayMode(): DisplayMode;
|
||||
|
||||
/**
|
||||
* Sets the display mode. The mode is host-global and applies to whichever
|
||||
* chat is active. Use {@link onDidChangeDisplayMode} to observe all changes,
|
||||
* including those triggered by the host.
|
||||
*/
|
||||
export declare function setDisplayMode(displayMode: DisplayMode): void;
|
||||
|
||||
/**
|
||||
* Event fired when the display mode changes, whether triggered by an
|
||||
* extension via setDisplayMode() or by host-provided controls.
|
||||
*/
|
||||
export declare const onDidChangeDisplayMode: Event<DisplayMode>;
|
||||
|
||||
/**
|
||||
* Event fired when the panel is resized in panel mode. Not all hosts provide
|
||||
* a resizer — do not rely on this event firing.
|
||||
*/
|
||||
export declare const onDidResizePanel: Event<{ width: number }>;
|
||||
|
||||
// TODO: client actions API — tool availability functions will be added here
|
||||
// once the client_actions SIP is finalized. The chat namespace is the
|
||||
// intended integration point between the two SIPs.
|
||||
@@ -223,8 +223,6 @@ export interface Extension {
|
||||
dependencies: string[];
|
||||
/** Human-readable description of the extension */
|
||||
description: string;
|
||||
/** List of other extensions that this extension depends on */
|
||||
extensionDependencies: string[];
|
||||
/** Unique identifier for the extension */
|
||||
id: string;
|
||||
/** Human-readable name of the extension */
|
||||
|
||||
@@ -23,9 +23,10 @@
|
||||
* This module defines the aggregate interfaces used by the extension.json
|
||||
* manifest and the `superset-extensions` build command. Individual metadata
|
||||
* types are defined in their respective namespace modules (commands, views,
|
||||
* menus, editors) and re-exported here for the manifest schema.
|
||||
* menus, editors, chat) and re-exported here for the manifest schema.
|
||||
*/
|
||||
|
||||
import { Chat } from '../chat';
|
||||
import { Command } from '../commands';
|
||||
import { View } from '../views';
|
||||
import { Menu } from '../menus';
|
||||
@@ -71,7 +72,8 @@ export interface MenuContributions {
|
||||
}
|
||||
|
||||
/**
|
||||
* Aggregates all contributions (commands, menus, views, and editors) provided by an extension or module.
|
||||
* Aggregates all contributions (commands, menus, views, editors, and chat)
|
||||
* provided by an extension or module.
|
||||
*/
|
||||
export interface Contributions {
|
||||
/** List of commands. */
|
||||
@@ -82,4 +84,10 @@ export interface Contributions {
|
||||
views: ViewContributions;
|
||||
/** List of editors. */
|
||||
editors?: Editor[];
|
||||
/**
|
||||
* The chat contributed by the extension — at most one per extension, since
|
||||
* the host applies singleton resolution and renders exactly one active
|
||||
* chat at a time.
|
||||
*/
|
||||
chat?: Chat;
|
||||
}
|
||||
|
||||
@@ -18,10 +18,12 @@
|
||||
*/
|
||||
export * as common from './common';
|
||||
export * as authentication from './authentication';
|
||||
export * as chat from './chat';
|
||||
export * as commands from './commands';
|
||||
export * as editors from './editors';
|
||||
export * as extensions from './extensions';
|
||||
export * as menus from './menus';
|
||||
export * as navigation from './navigation';
|
||||
export * as sqlLab from './sqlLab';
|
||||
export * as views from './views';
|
||||
export * as contributions from './contributions';
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @fileoverview Navigation namespace for Superset extensions.
|
||||
*
|
||||
* Exposes the current application surface so extensions can react to route
|
||||
* changes without polling. Entity-level context (chart, dashboard, dataset)
|
||||
* is intentionally not included here — surface-specific namespaces that
|
||||
* resolve entity payloads are introduced in later phases.
|
||||
*/
|
||||
|
||||
import { Event } from '../common';
|
||||
|
||||
/**
|
||||
* The set of top-level application surfaces.
|
||||
*
|
||||
* `'explore'`, `'dashboard'` and `'dataset'` are the single-entity
|
||||
* editing/viewing surfaces. `'chart_list'`, `'dashboard_list'` and
|
||||
* `'dataset_list'` are the browse/list surfaces, distinct from those because no
|
||||
* single entity is active. `'sqllab'` is the SQL editor where
|
||||
* `sqlLab.getCurrentTab()` resolves; `'query_history'` and `'saved_queries'`
|
||||
* are the related SQL Lab browse pages, which are not the editor. `'home'` is
|
||||
* the welcome surface and the fallback for any route not explicitly enumerated.
|
||||
*/
|
||||
export type Page =
|
||||
| 'dashboard'
|
||||
| 'dashboard_list'
|
||||
| 'explore'
|
||||
| 'chart_list'
|
||||
| 'sqllab'
|
||||
| 'query_history'
|
||||
| 'saved_queries'
|
||||
| 'dataset'
|
||||
| 'dataset_list'
|
||||
| 'home';
|
||||
|
||||
/**
|
||||
* Returns the current page surface.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const page = navigation.getPage();
|
||||
* if (page === 'dashboard') {
|
||||
* // react to being on a dashboard surface
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export declare function getPage(): Page;
|
||||
|
||||
/**
|
||||
* Event fired whenever the user navigates to a different surface.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const sub = navigation.onDidChangePage(page => {
|
||||
* if (page === 'dashboard') {
|
||||
* // react to navigating onto a dashboard surface
|
||||
* }
|
||||
* });
|
||||
* // later:
|
||||
* sub.dispose();
|
||||
* ```
|
||||
*/
|
||||
export declare const onDidChangePage: Event<Page>;
|
||||
@@ -30,12 +30,12 @@
|
||||
*
|
||||
* views.registerView(
|
||||
* { id: 'my_ext.result_stats', name: 'Result Stats', location: 'sqllab.panels' },
|
||||
* () => <ResultStatsPanel />,
|
||||
* ResultStatsPanel,
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
|
||||
import { ReactElement } from 'react';
|
||||
import { ComponentType } from 'react';
|
||||
import { Disposable, Event } from '../common';
|
||||
|
||||
/**
|
||||
@@ -58,7 +58,7 @@ export interface View {
|
||||
*
|
||||
* @param view The view descriptor (id and name).
|
||||
* @param location The location where this view should appear (e.g. "sqllab.panels").
|
||||
* @param provider A function that returns the React element to render.
|
||||
* @param component The React component to render at that location.
|
||||
* @returns A Disposable that unregisters the view when disposed.
|
||||
*
|
||||
* @example
|
||||
@@ -66,14 +66,14 @@ export interface View {
|
||||
* views.registerView(
|
||||
* { id: 'my_ext.result_stats', name: 'Result Stats' },
|
||||
* 'sqllab.panels',
|
||||
* () => <ResultStatsPanel />,
|
||||
* ResultStatsPanel,
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
export declare function registerView(
|
||||
view: View,
|
||||
location: string,
|
||||
provider: () => ReactElement,
|
||||
component: ComponentType,
|
||||
): Disposable;
|
||||
|
||||
/**
|
||||
|
||||
@@ -132,6 +132,26 @@ export const advancedAnalyticsControls: ControlPanelSectionConfig = {
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
name: 'time_compare_full_range',
|
||||
config: {
|
||||
type: 'CheckboxControl',
|
||||
label: t('Show full range for time shift'),
|
||||
default: false,
|
||||
description: t(
|
||||
'Plot each time-shifted series across its full time range instead ' +
|
||||
'of truncating it to the main series. Useful for comparing a ' +
|
||||
'partial current period (e.g. today so far) against complete ' +
|
||||
'prior periods (e.g. all of yesterday).',
|
||||
),
|
||||
visibility: ({ controls }) =>
|
||||
Boolean(controls?.time_compare?.value) &&
|
||||
(!Array.isArray(controls?.time_compare?.value) ||
|
||||
controls.time_compare.value.length > 0),
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
name: 'comparison_type',
|
||||
|
||||
@@ -318,14 +318,25 @@ function createAdvancedAnalyticsSection(
|
||||
): ControlPanelSectionConfig {
|
||||
const aaWithSuffix = cloneDeep(sections.advancedAnalyticsControls);
|
||||
aaWithSuffix.label = label;
|
||||
// `time_compare_full_range` is only wired into the regular timeseries query
|
||||
// builder, not the mixed-timeseries one, so drop it here to avoid showing a
|
||||
// control that has no effect.
|
||||
aaWithSuffix.controlSetRows = aaWithSuffix.controlSetRows
|
||||
.map(row =>
|
||||
row.filter(
|
||||
control =>
|
||||
(control as CustomControlItem)?.name !== 'time_compare_full_range',
|
||||
),
|
||||
)
|
||||
.filter(row => row.length > 0);
|
||||
if (!controlSuffix) {
|
||||
return aaWithSuffix;
|
||||
}
|
||||
aaWithSuffix.controlSetRows.forEach(row =>
|
||||
row.forEach((control: CustomControlItem) => {
|
||||
if (control?.name) {
|
||||
// eslint-disable-next-line no-param-reassign
|
||||
control.name = `${control.name}${controlSuffix}`;
|
||||
row.forEach(control => {
|
||||
const item = control as CustomControlItem;
|
||||
if (item?.name) {
|
||||
item.name = `${item.name}${controlSuffix}`;
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -82,6 +82,11 @@ export default function buildQuery(formData: QueryFormData) {
|
||||
? formData.time_compare
|
||||
: [];
|
||||
|
||||
// When comparing against prior periods, optionally keep each shifted series at
|
||||
// its full time range instead of truncating it to the main series' range.
|
||||
const time_compare_full_range =
|
||||
time_offsets.length > 0 && Boolean(formData.time_compare_full_range);
|
||||
|
||||
return [
|
||||
{
|
||||
...baseQueryObject,
|
||||
@@ -92,6 +97,7 @@ export default function buildQuery(formData: QueryFormData) {
|
||||
// todo: move `normalizeOrderBy to extractQueryFields`
|
||||
orderby: normalizeOrderBy(baseQueryObject).orderby,
|
||||
time_offsets,
|
||||
time_compare_full_range,
|
||||
/* Note that:
|
||||
1. The resample, rolling, cum, timeCompare operators should be after pivot.
|
||||
2. Resample must come before rolling so that imputed values are
|
||||
|
||||
@@ -381,6 +381,15 @@ export default function transformProps(
|
||||
const array = ensureIsArray(chartProps.rawFormData?.time_compare);
|
||||
const inverted = invert(verboseMap);
|
||||
|
||||
// With the "full range" time-shift option, offset series are outer-joined onto
|
||||
// the main series, which inserts null rows into the main series wherever the
|
||||
// comparison period has data the current period lacks. Connect nulls so the
|
||||
// main line stays continuous (matching the default left-join appearance) rather
|
||||
// than fragmenting at every inserted gap.
|
||||
const timeCompareFullRange = Boolean(
|
||||
chartProps.rawFormData?.time_compare_full_range,
|
||||
);
|
||||
|
||||
const offsetLineWidths: { [key: string]: number } = {};
|
||||
|
||||
// For horizontal bar charts, calculate min/max from data to avoid cutting off labels
|
||||
@@ -478,7 +487,7 @@ export default function transformProps(
|
||||
colorScaleKey,
|
||||
{
|
||||
area,
|
||||
connectNulls: derivedSeries,
|
||||
connectNulls: derivedSeries || timeCompareFullRange,
|
||||
filterState,
|
||||
seriesContexts,
|
||||
markerEnabled,
|
||||
|
||||
277
superset-frontend/src/core/chat/ChatHost.test.tsx
Normal file
277
superset-frontend/src/core/chat/ChatHost.test.tsx
Normal file
@@ -0,0 +1,277 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { act, render, screen } from 'spec/helpers/testing-library';
|
||||
import { chat } from 'src/core/chat';
|
||||
import ChatProvider from './ChatProvider';
|
||||
import { ChatFloatingHost as ChatHost, ChatPanelHost } from './ChatHost';
|
||||
|
||||
beforeEach(() => {
|
||||
ChatProvider.getInstance().reset();
|
||||
});
|
||||
|
||||
test('renders nothing when no chat extension is registered', () => {
|
||||
render(<ChatHost />);
|
||||
|
||||
expect(screen.queryByTestId('chat-mount')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('renders the trigger bubble of the registered chat', () => {
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <button type="button">Acme Bubble</button>,
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
|
||||
render(<ChatHost />);
|
||||
|
||||
expect(screen.getByTestId('chat-mount')).toBeInTheDocument();
|
||||
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
|
||||
// The panel stays unmounted until the chat is opened.
|
||||
expect(screen.queryByText('Acme Panel')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('mounts the panel when the chat opens and unmounts it on close', () => {
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <button type="button">Acme Bubble</button>,
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
|
||||
render(<ChatHost />);
|
||||
|
||||
act(() => chat.open());
|
||||
|
||||
expect(screen.getByText('Acme Panel')).toBeInTheDocument();
|
||||
// In floating mode the trigger stays mounted alongside the open panel.
|
||||
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
|
||||
|
||||
act(() => chat.close());
|
||||
|
||||
expect(screen.queryByText('Acme Panel')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('renders the last-registered chat when several are installed', () => {
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
|
||||
chat.registerChat(
|
||||
{ id: 'first.chat', name: 'First Chat' },
|
||||
() => <div>First Bubble</div>,
|
||||
() => <div>First Panel</div>,
|
||||
);
|
||||
chat.registerChat(
|
||||
{ id: 'second.chat', name: 'Second Chat' },
|
||||
() => <div>Second Bubble</div>,
|
||||
() => <div>Second Panel</div>,
|
||||
);
|
||||
|
||||
jest.restoreAllMocks();
|
||||
render(<ChatHost />);
|
||||
|
||||
// Last-loaded wins: the second registration takes over the singleton slot.
|
||||
expect(screen.getByText('Second Bubble')).toBeInTheDocument();
|
||||
expect(screen.queryByText('First Bubble')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('reacts to a chat registering after the initial render', () => {
|
||||
render(<ChatHost />);
|
||||
|
||||
expect(screen.queryByTestId('chat-mount')).not.toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <button type="button">Acme Bubble</button>,
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
});
|
||||
|
||||
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('a takeover mounts the incoming chat closed', () => {
|
||||
chat.registerChat(
|
||||
{ id: 'first.chat', name: 'First Chat' },
|
||||
() => <div>First Bubble</div>,
|
||||
() => <div>First Panel</div>,
|
||||
);
|
||||
|
||||
render(<ChatHost />);
|
||||
act(() => chat.open());
|
||||
expect(screen.getByText('First Panel')).toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
chat.registerChat(
|
||||
{ id: 'second.chat', name: 'Second Chat' },
|
||||
() => <div>Second Bubble</div>,
|
||||
() => <div>Second Panel</div>,
|
||||
);
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
// The displaced chat's open state must not leak into the winner.
|
||||
expect(screen.getByText('Second Bubble')).toBeInTheDocument();
|
||||
expect(screen.queryByText('Second Panel')).not.toBeInTheDocument();
|
||||
expect(screen.queryByText('First Panel')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('ChatPanelHost renders the panel when open in panel mode', () => {
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <button type="button">Acme Bubble</button>,
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
|
||||
render(<ChatPanelHost />);
|
||||
|
||||
act(() => {
|
||||
chat.setDisplayMode('panel');
|
||||
chat.open();
|
||||
});
|
||||
|
||||
expect(screen.getByText('Acme Panel')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('ChatFloatingHost suppresses the floating panel in panel mode but keeps the trigger', () => {
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <button type="button">Acme Bubble</button>,
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
|
||||
render(<ChatHost />);
|
||||
|
||||
act(() => {
|
||||
chat.setDisplayMode('panel');
|
||||
chat.open();
|
||||
});
|
||||
|
||||
// In panel mode the floating panel is suppressed (ChatPanelHost owns that slot).
|
||||
expect(screen.queryByText('Acme Panel')).not.toBeInTheDocument();
|
||||
// The trigger stays rendered so the user can reopen after collapsing.
|
||||
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
|
||||
|
||||
act(() => chat.close());
|
||||
|
||||
// Trigger remains visible even when closed — it's the user's only way back.
|
||||
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('a crashing panel does not take the trigger down with it', () => {
|
||||
const FailingPanel = () => {
|
||||
throw new Error('panel blew up');
|
||||
};
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <button type="button">Acme Bubble</button>,
|
||||
() => <FailingPanel />,
|
||||
);
|
||||
|
||||
render(<ChatHost />);
|
||||
act(() => chat.open());
|
||||
|
||||
// The panel's boundary contains the crash; the trigger keeps rendering so
|
||||
// the user is not stranded without a way back.
|
||||
expect(screen.queryByText('panel blew up')).not.toBeInTheDocument();
|
||||
expect(screen.getByText('Acme Bubble')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('isolates a failing trigger so it does not crash the host', () => {
|
||||
const FailingTrigger = () => {
|
||||
throw new Error('chat blew up');
|
||||
};
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <FailingTrigger />,
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
|
||||
// The host-owned error boundary catches the failure; render does not throw.
|
||||
expect(() => render(<ChatHost />)).not.toThrow();
|
||||
// The mount slot still renders (the boundary lives inside it), confirming
|
||||
// the provider was actually exercised and contained.
|
||||
expect(screen.getByTestId('chat-mount')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('isolates a component that throws during render', () => {
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => {
|
||||
throw new Error('provider blew up');
|
||||
},
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
|
||||
expect(() => render(<ChatHost />)).not.toThrow();
|
||||
expect(screen.getByTestId('chat-mount')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('recovers from a crashed chat when a different chat takes over', () => {
|
||||
const FailingTrigger = () => {
|
||||
throw new Error('first chat blew up');
|
||||
};
|
||||
chat.registerChat(
|
||||
{ id: 'first.chat', name: 'First Chat' },
|
||||
() => <FailingTrigger />,
|
||||
() => <div>First Panel</div>,
|
||||
);
|
||||
|
||||
render(<ChatHost />);
|
||||
expect(screen.queryByText('Second Bubble')).not.toBeInTheDocument();
|
||||
|
||||
act(() => {
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
chat.registerChat(
|
||||
{ id: 'second.chat', name: 'Second Chat' },
|
||||
() => <div>Second Bubble</div>,
|
||||
() => <div>Second Panel</div>,
|
||||
);
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
// The boundary is keyed per registration, so the latched crash from the
|
||||
// first chat does not blank the second one.
|
||||
expect(screen.getByText('Second Bubble')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('recovers from a crashed chat when a different id takes over', () => {
|
||||
const FailingTrigger = () => {
|
||||
throw new Error('broken release');
|
||||
};
|
||||
chat.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme Chat' },
|
||||
() => <FailingTrigger />,
|
||||
() => <div>Acme Panel</div>,
|
||||
);
|
||||
|
||||
render(<ChatHost />);
|
||||
|
||||
act(() => {
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
chat.registerChat(
|
||||
{ id: 'fixed.chat', name: 'Fixed Chat' },
|
||||
() => <div>Fixed Bubble</div>,
|
||||
() => <div>Fixed Panel</div>,
|
||||
);
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
// Different id: boundary key changes, latch resets, fix renders.
|
||||
expect(screen.getByText('Fixed Bubble')).toBeInTheDocument();
|
||||
});
|
||||
133
superset-frontend/src/core/chat/ChatHost.tsx
Normal file
133
superset-frontend/src/core/chat/ChatHost.tsx
Normal file
@@ -0,0 +1,133 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { type ComponentType, useRef } from 'react';
|
||||
import { t } from '@apache-superset/core/translation';
|
||||
import { logging } from '@apache-superset/core/utils';
|
||||
import { css, useTheme } from '@apache-superset/core/theme';
|
||||
import { ErrorBoundary } from 'src/components/ErrorBoundary';
|
||||
import { addDangerToast } from 'src/components/MessageToasts/actions';
|
||||
import { store } from 'src/views/store';
|
||||
import { useChat } from '.';
|
||||
|
||||
const CHAT_EDGE_MARGIN = 24;
|
||||
|
||||
/**
|
||||
* Returns an onError handler that shows a toast on crash, once per chat id.
|
||||
*/
|
||||
function useCrashNotifier(chatId: string | undefined) {
|
||||
const notifiedFor = useRef<string | undefined>(undefined);
|
||||
return (error: Error) => {
|
||||
if (!chatId) return;
|
||||
logging.error('[chat] provider crashed', error);
|
||||
if (notifiedFor.current !== chatId) {
|
||||
notifiedFor.current = chatId;
|
||||
store.dispatch(addDangerToast(t('The chat failed to load.')));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps a component in an ErrorBoundary, keyed by chat id so the boundary
|
||||
* resets when a different chat takes over.
|
||||
*/
|
||||
const ChatBoundary = ({
|
||||
component: Component,
|
||||
onError,
|
||||
}: {
|
||||
component: ComponentType;
|
||||
onError: (error: Error) => void;
|
||||
}) => (
|
||||
<ErrorBoundary showMessage={false} onError={onError}>
|
||||
<Component />
|
||||
</ErrorBoundary>
|
||||
);
|
||||
|
||||
/**
|
||||
* Renders the chat panel content in panel mode. Fills its container height.
|
||||
*/
|
||||
export const ChatPanelHost = () => {
|
||||
const { chat, panel } = useChat();
|
||||
const onError = useCrashNotifier(chat?.id);
|
||||
|
||||
if (!chat || !panel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
data-test="chat-mount"
|
||||
css={css`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100%;
|
||||
`}
|
||||
>
|
||||
<ChatBoundary key={chat.id} component={panel} onError={onError} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the chat trigger and, when the panel is open in floating mode, the
|
||||
* floating panel overlay. The trigger is always visible when a chat is
|
||||
* registered; the panel overlay is suppressed in panel mode.
|
||||
*/
|
||||
export const ChatFloatingHost = () => {
|
||||
const theme = useTheme();
|
||||
const { open: panelOpen, mode, chat, trigger, panel } = useChat();
|
||||
const onError = useCrashNotifier(chat?.id);
|
||||
|
||||
if (!chat || !trigger || !panel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
data-test="chat-mount"
|
||||
css={css`
|
||||
position: fixed;
|
||||
right: ${CHAT_EDGE_MARGIN}px;
|
||||
bottom: ${CHAT_EDGE_MARGIN}px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-end;
|
||||
gap: ${theme.sizeUnit * 2}px;
|
||||
/* Above dashboard content and the toast layer, below modal dialogs. */
|
||||
z-index: ${theme.zIndexPopupBase + 2};
|
||||
`}
|
||||
>
|
||||
{/*
|
||||
Separate boundaries so a crashing panel cannot take the trigger down
|
||||
with it — the trigger is the user's only way back.
|
||||
*/}
|
||||
{panelOpen && mode !== 'panel' && (
|
||||
<ChatBoundary
|
||||
key={`panel-${chat.id}`}
|
||||
component={panel}
|
||||
onError={onError}
|
||||
/>
|
||||
)}
|
||||
<ChatBoundary
|
||||
key={`trigger-${chat.id}`}
|
||||
component={trigger}
|
||||
onError={onError}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
257
superset-frontend/src/core/chat/ChatProvider.test.ts
Normal file
257
superset-frontend/src/core/chat/ChatProvider.test.ts
Normal file
@@ -0,0 +1,257 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { createElement } from 'react';
|
||||
import ChatProvider from './ChatProvider';
|
||||
|
||||
const trigger = () => createElement('button', null, 'Bubble');
|
||||
const panel = () => createElement('div', null, 'Panel');
|
||||
|
||||
beforeEach(() => {
|
||||
ChatProvider.getInstance().reset();
|
||||
});
|
||||
|
||||
test('returns the singleton instance', () => {
|
||||
expect(ChatProvider.getInstance()).toBe(ChatProvider.getInstance());
|
||||
});
|
||||
|
||||
test('getChat returns undefined when no chat is registered', () => {
|
||||
expect(ChatProvider.getInstance().getChat()).toBeUndefined();
|
||||
});
|
||||
|
||||
test('registerChat sets the registration and returns the descriptor copy', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const descriptor = { id: 'acme.chat', name: 'Acme Chat' };
|
||||
const disposable = provider.registerChat(descriptor, trigger, panel);
|
||||
|
||||
expect(provider.getChat()).toEqual(descriptor);
|
||||
disposable.dispose();
|
||||
});
|
||||
|
||||
test('the last-registered chat wins and logs a warning', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const warn = jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
|
||||
provider.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
|
||||
provider.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
|
||||
|
||||
expect(provider.getChat()?.id).toBe('second.chat');
|
||||
expect(warn).toHaveBeenCalledTimes(1);
|
||||
expect(warn.mock.calls[0][0]).toContain('second.chat');
|
||||
expect(warn.mock.calls[0][0]).toContain('first.chat');
|
||||
warn.mockRestore();
|
||||
});
|
||||
|
||||
test('re-registering with a different id replaces the active chat', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
|
||||
provider.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
|
||||
expect(provider.getChat()?.id).toBe('first.chat');
|
||||
|
||||
provider.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
|
||||
expect(provider.getChat()?.id).toBe('second.chat');
|
||||
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
test('disposing the registration clears it', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const disposable = provider.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme' },
|
||||
trigger,
|
||||
panel,
|
||||
);
|
||||
|
||||
disposable.dispose();
|
||||
|
||||
expect(provider.getChat()).toBeUndefined();
|
||||
});
|
||||
|
||||
test('disposing twice fires unregister only once', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const unregistered = jest.fn();
|
||||
provider.onDidUnregisterChat(unregistered);
|
||||
|
||||
const disposable = provider.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme' },
|
||||
trigger,
|
||||
panel,
|
||||
);
|
||||
disposable.dispose();
|
||||
disposable.dispose();
|
||||
|
||||
expect(unregistered).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
test('onDidRegisterChat and onDidUnregisterChat fire with the descriptor', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const registered = jest.fn();
|
||||
const unregistered = jest.fn();
|
||||
provider.onDidRegisterChat(registered);
|
||||
provider.onDidUnregisterChat(unregistered);
|
||||
|
||||
const descriptor = { id: 'acme.chat', name: 'Acme' };
|
||||
const disposable = provider.registerChat(descriptor, trigger, panel);
|
||||
|
||||
expect(registered).toHaveBeenCalledWith(descriptor);
|
||||
expect(unregistered).not.toHaveBeenCalled();
|
||||
|
||||
disposable.dispose();
|
||||
|
||||
expect(unregistered).toHaveBeenCalledWith(descriptor);
|
||||
});
|
||||
|
||||
test('open and close toggle the panel state', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
|
||||
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
|
||||
provider.open();
|
||||
expect(provider.isOpen()).toBe(true);
|
||||
|
||||
provider.close();
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
});
|
||||
|
||||
test('open fires once; duplicate open is a no-op', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const opened = jest.fn();
|
||||
provider.onDidOpen(opened);
|
||||
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
|
||||
|
||||
provider.open();
|
||||
provider.open();
|
||||
|
||||
expect(opened).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
test('close fires once; duplicate close is a no-op', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const closed = jest.fn();
|
||||
provider.onDidClose(closed);
|
||||
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
|
||||
|
||||
provider.open();
|
||||
provider.close();
|
||||
provider.close();
|
||||
|
||||
expect(closed).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
test('open is a no-op when no chat is registered', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const opened = jest.fn();
|
||||
provider.onDidOpen(opened);
|
||||
|
||||
provider.open();
|
||||
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
expect(opened).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('registering a second chat while open closes the panel', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const closed = jest.fn();
|
||||
provider.onDidClose(closed);
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
|
||||
provider.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
|
||||
provider.open();
|
||||
provider.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
|
||||
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
expect(closed).toHaveBeenCalledTimes(1);
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
test('disposing the active chat while open closes the panel', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const closed = jest.fn();
|
||||
provider.onDidClose(closed);
|
||||
|
||||
const disposable = provider.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme' },
|
||||
trigger,
|
||||
panel,
|
||||
);
|
||||
provider.open();
|
||||
disposable.dispose();
|
||||
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
expect(closed).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
test('a late registration does not inherit a stale open state', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const disposable = provider.registerChat(
|
||||
{ id: 'acme.chat', name: 'Acme' },
|
||||
trigger,
|
||||
panel,
|
||||
);
|
||||
provider.open();
|
||||
disposable.dispose();
|
||||
|
||||
provider.registerChat({ id: 'late.chat', name: 'Late' }, trigger, panel);
|
||||
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
});
|
||||
|
||||
test('getDisplayMode defaults to floating', () => {
|
||||
expect(ChatProvider.getInstance().getDisplayMode()).toBe('floating');
|
||||
});
|
||||
|
||||
test('setDisplayMode updates mode and fires event only on change', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
const modeChanged = jest.fn();
|
||||
provider.onDidChangeDisplayMode(modeChanged);
|
||||
|
||||
provider.setDisplayMode('floating');
|
||||
expect(modeChanged).not.toHaveBeenCalled();
|
||||
|
||||
provider.setDisplayMode('panel');
|
||||
expect(provider.getDisplayMode()).toBe('panel');
|
||||
expect(modeChanged).toHaveBeenCalledWith('panel');
|
||||
});
|
||||
|
||||
test('state reflects changes after registration and open', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
|
||||
expect(provider.getChat()).toBeUndefined();
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
|
||||
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
|
||||
provider.open();
|
||||
|
||||
expect(provider.isOpen()).toBe(true);
|
||||
expect(provider.getChat()?.id).toBe('acme.chat');
|
||||
});
|
||||
|
||||
test('reset clears all state', () => {
|
||||
const provider = ChatProvider.getInstance();
|
||||
provider.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
|
||||
provider.open();
|
||||
provider.setDisplayMode('panel');
|
||||
|
||||
provider.reset();
|
||||
|
||||
expect(provider.getChat()).toBeUndefined();
|
||||
expect(provider.isOpen()).toBe(false);
|
||||
expect(provider.getDisplayMode()).toBe('floating');
|
||||
});
|
||||
209
superset-frontend/src/core/chat/ChatProvider.ts
Normal file
209
superset-frontend/src/core/chat/ChatProvider.ts
Normal file
@@ -0,0 +1,209 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { ComponentType } from 'react';
|
||||
import type { chat as chatApi } from '@apache-superset/core';
|
||||
import {
|
||||
LocalStorageKeys,
|
||||
getItem,
|
||||
setItem,
|
||||
} from 'src/utils/localStorageHelpers';
|
||||
import { Disposable } from '../models';
|
||||
import { createValueEventEmitter, createEventEmitter } from '../utils';
|
||||
|
||||
type Chat = chatApi.Chat;
|
||||
type DisplayMode = chatApi.DisplayMode;
|
||||
|
||||
/**
|
||||
* Singleton manager for the chat provider.
|
||||
* Handles registration, open/close state, and display mode.
|
||||
*/
|
||||
class ChatProvider {
|
||||
private static instance: ChatProvider;
|
||||
|
||||
private chat: Chat | undefined;
|
||||
|
||||
private trigger: ComponentType | undefined;
|
||||
|
||||
private panel: ComponentType | undefined;
|
||||
|
||||
private opened: boolean;
|
||||
|
||||
private stateSubscribers = new Set<() => void>();
|
||||
|
||||
private registerEmitter = createEventEmitter<Chat>();
|
||||
|
||||
private unregisterEmitter = createEventEmitter<Chat>();
|
||||
|
||||
private openEmitter = createEventEmitter<void>();
|
||||
|
||||
private closeEmitter = createEventEmitter<void>();
|
||||
|
||||
private resizePanelEmitter = createEventEmitter<{ width: number }>();
|
||||
|
||||
private modeEmitter: ReturnType<typeof createValueEventEmitter<DisplayMode>>;
|
||||
|
||||
private constructor() {
|
||||
const persisted = getItem(LocalStorageKeys.ChatState, {
|
||||
open: false,
|
||||
mode: 'floating',
|
||||
});
|
||||
const mode = (
|
||||
persisted.mode === 'panel' ? 'panel' : 'floating'
|
||||
) as DisplayMode;
|
||||
this.opened = persisted.open === true;
|
||||
this.modeEmitter = createValueEventEmitter<DisplayMode>(mode);
|
||||
}
|
||||
|
||||
public static getInstance(): ChatProvider {
|
||||
if (!ChatProvider.instance) {
|
||||
ChatProvider.instance = new ChatProvider();
|
||||
}
|
||||
return ChatProvider.instance;
|
||||
}
|
||||
|
||||
public subscribe = (listener: () => void): (() => void) => {
|
||||
this.stateSubscribers.add(listener);
|
||||
return () => this.stateSubscribers.delete(listener);
|
||||
};
|
||||
|
||||
private notifyState(): void {
|
||||
setItem(LocalStorageKeys.ChatState, {
|
||||
open: this.opened,
|
||||
mode: this.modeEmitter.getCurrent(),
|
||||
});
|
||||
this.stateSubscribers.forEach(fn => fn());
|
||||
}
|
||||
|
||||
private closePanel(): void {
|
||||
this.opened = false;
|
||||
this.closeEmitter.fire();
|
||||
}
|
||||
|
||||
public registerChat(
|
||||
chat: Chat,
|
||||
trigger: ComponentType,
|
||||
panel: ComponentType,
|
||||
): Disposable {
|
||||
if (this.chat) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.warn(
|
||||
`[Superset] Multiple chat extensions registered. Using "${chat.id}"; discarding "${this.chat.id}".`,
|
||||
);
|
||||
this.unregisterEmitter.fire(this.chat);
|
||||
if (this.opened) this.closePanel();
|
||||
}
|
||||
|
||||
this.chat = chat;
|
||||
this.trigger = trigger;
|
||||
this.panel = panel;
|
||||
this.registerEmitter.fire(chat);
|
||||
this.notifyState();
|
||||
|
||||
return new Disposable(() => {
|
||||
if (this.chat !== chat) return;
|
||||
this.chat = undefined;
|
||||
this.trigger = undefined;
|
||||
this.panel = undefined;
|
||||
this.unregisterEmitter.fire(chat);
|
||||
if (this.opened) this.closePanel();
|
||||
this.notifyState();
|
||||
});
|
||||
}
|
||||
|
||||
public getChat(): Chat | undefined {
|
||||
return this.chat;
|
||||
}
|
||||
|
||||
public getTrigger(): ComponentType | undefined {
|
||||
return this.trigger;
|
||||
}
|
||||
|
||||
public getPanel(): ComponentType | undefined {
|
||||
return this.panel;
|
||||
}
|
||||
|
||||
public open(): void {
|
||||
if (this.opened || !this.chat) return;
|
||||
this.opened = true;
|
||||
this.openEmitter.fire();
|
||||
this.notifyState();
|
||||
}
|
||||
|
||||
public close(): void {
|
||||
if (!this.opened || !this.chat) return;
|
||||
this.closePanel();
|
||||
this.notifyState();
|
||||
}
|
||||
|
||||
public isOpen(): boolean {
|
||||
return this.opened;
|
||||
}
|
||||
|
||||
public getDisplayMode(): DisplayMode {
|
||||
return this.modeEmitter.getCurrent();
|
||||
}
|
||||
|
||||
public setDisplayMode(displayMode: DisplayMode): void {
|
||||
if (displayMode === this.modeEmitter.getCurrent()) return;
|
||||
this.modeEmitter.fire(displayMode);
|
||||
this.notifyState();
|
||||
}
|
||||
|
||||
public get onDidRegisterChat() {
|
||||
return this.registerEmitter.subscribe;
|
||||
}
|
||||
|
||||
public get onDidUnregisterChat() {
|
||||
return this.unregisterEmitter.subscribe;
|
||||
}
|
||||
|
||||
public get onDidOpen() {
|
||||
return this.openEmitter.subscribe;
|
||||
}
|
||||
|
||||
public get onDidClose() {
|
||||
return this.closeEmitter.subscribe;
|
||||
}
|
||||
|
||||
public get onDidChangeDisplayMode() {
|
||||
return this.modeEmitter.subscribe;
|
||||
}
|
||||
|
||||
public get onDidResizePanel() {
|
||||
return this.resizePanelEmitter.subscribe;
|
||||
}
|
||||
|
||||
public reset(): void {
|
||||
this.chat = undefined;
|
||||
this.trigger = undefined;
|
||||
this.panel = undefined;
|
||||
this.opened = false;
|
||||
this.registerEmitter = createEventEmitter<Chat>();
|
||||
this.unregisterEmitter = createEventEmitter<Chat>();
|
||||
this.openEmitter = createEventEmitter<void>();
|
||||
this.closeEmitter = createEventEmitter<void>();
|
||||
this.resizePanelEmitter = createEventEmitter<{ width: number }>();
|
||||
this.modeEmitter = createValueEventEmitter<DisplayMode>('floating');
|
||||
this.stateSubscribers.clear();
|
||||
setItem(LocalStorageKeys.ChatState, { open: false, mode: 'floating' });
|
||||
}
|
||||
}
|
||||
|
||||
export default ChatProvider;
|
||||
68
superset-frontend/src/core/chat/index.test.ts
Normal file
68
superset-frontend/src/core/chat/index.test.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { createElement } from 'react';
|
||||
import { chat } from './index';
|
||||
import ChatProvider from './ChatProvider';
|
||||
|
||||
const trigger = () => createElement('button', null, 'Bubble');
|
||||
const panel = () => createElement('div', null, 'Panel');
|
||||
|
||||
beforeEach(() => {
|
||||
ChatProvider.getInstance().reset();
|
||||
});
|
||||
|
||||
test('getChat returns undefined when no chat is registered', () => {
|
||||
expect(chat.getChat()).toBeUndefined();
|
||||
});
|
||||
|
||||
test('registerChat makes the chat retrievable via getChat', () => {
|
||||
const descriptor = { id: 'acme.chat', name: 'Acme Chat' };
|
||||
chat.registerChat(descriptor, trigger, panel);
|
||||
|
||||
expect(chat.getChat()).toEqual(descriptor);
|
||||
});
|
||||
|
||||
test('the last-registered chat wins when multiple are registered', () => {
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
|
||||
chat.registerChat({ id: 'first.chat', name: 'First' }, trigger, panel);
|
||||
chat.registerChat({ id: 'second.chat', name: 'Second' }, trigger, panel);
|
||||
|
||||
expect(chat.getChat()?.id).toBe('second.chat');
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
test('open and close toggle isOpen', () => {
|
||||
chat.registerChat({ id: 'acme.chat', name: 'Acme' }, trigger, panel);
|
||||
|
||||
expect(chat.isOpen()).toBe(false);
|
||||
chat.open();
|
||||
expect(chat.isOpen()).toBe(true);
|
||||
chat.close();
|
||||
expect(chat.isOpen()).toBe(false);
|
||||
});
|
||||
|
||||
test('getDisplayMode defaults to floating', () => {
|
||||
expect(chat.getDisplayMode()).toBe('floating');
|
||||
});
|
||||
|
||||
test('setDisplayMode updates the display mode', () => {
|
||||
chat.setDisplayMode('panel');
|
||||
expect(chat.getDisplayMode()).toBe('panel');
|
||||
});
|
||||
82
superset-frontend/src/core/chat/index.ts
Normal file
82
superset-frontend/src/core/chat/index.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @fileoverview Host implementation of the `chat` contribution type.
|
||||
*
|
||||
* Extensions register via the public `chat.registerChat()` and the host owns
|
||||
* mounting, open/close state, and the display mode. Only the last-registered
|
||||
* chat is active at a time.
|
||||
*
|
||||
* The public namespace (`chat`) is exposed to extensions on `window.superset`.
|
||||
* `useChat` is host-internal and NOT part of the public `@apache-superset/core` API.
|
||||
*/
|
||||
|
||||
import { useSyncExternalStore } from 'react';
|
||||
import memoizeOne from 'memoize-one';
|
||||
import type { chat as chatApi } from '@apache-superset/core';
|
||||
import ChatProvider from './ChatProvider';
|
||||
|
||||
export { ChatFloatingHost, ChatPanelHost } from './ChatHost';
|
||||
|
||||
const provider = ChatProvider.getInstance();
|
||||
|
||||
const buildSnapshot = memoizeOne(
|
||||
(
|
||||
open: boolean,
|
||||
mode: chatApi.DisplayMode,
|
||||
chat: chatApi.Chat | undefined,
|
||||
trigger: ReturnType<typeof provider.getTrigger>,
|
||||
panel: ReturnType<typeof provider.getPanel>,
|
||||
) => ({ open, mode, chat, trigger, panel }),
|
||||
);
|
||||
|
||||
const getSnapshot = () =>
|
||||
buildSnapshot(
|
||||
provider.isOpen(),
|
||||
provider.getDisplayMode(),
|
||||
provider.getChat(),
|
||||
provider.getTrigger(),
|
||||
provider.getPanel(),
|
||||
);
|
||||
|
||||
/**
|
||||
* Host-internal hook. Returns the current open/mode state and the active chat
|
||||
* (trigger, panel, descriptor).
|
||||
*/
|
||||
export const useChat = () =>
|
||||
useSyncExternalStore(provider.subscribe, getSnapshot);
|
||||
|
||||
export const chat: typeof chatApi = {
|
||||
registerChat: provider.registerChat.bind(provider),
|
||||
getChat: provider.getChat.bind(provider),
|
||||
onDidRegisterChat: provider.onDidRegisterChat,
|
||||
onDidUnregisterChat: provider.onDidUnregisterChat,
|
||||
open: provider.open.bind(provider),
|
||||
close: provider.close.bind(provider),
|
||||
isOpen: provider.isOpen.bind(provider),
|
||||
onDidOpen: provider.onDidOpen,
|
||||
onDidClose: provider.onDidClose,
|
||||
getDisplayMode: provider.getDisplayMode.bind(provider),
|
||||
setDisplayMode: provider.setDisplayMode.bind(provider),
|
||||
onDidChangeDisplayMode: provider.onDidChangeDisplayMode,
|
||||
// The host fires this from its panel resizer; until that chrome exists the
|
||||
// event is exposed but never fires.
|
||||
onDidResizePanel: provider.onDidResizePanel,
|
||||
};
|
||||
@@ -254,33 +254,6 @@ test('event listeners can be disposed', () => {
|
||||
expect(listener).toHaveBeenCalledTimes(1); // Still only 1 call
|
||||
});
|
||||
|
||||
test('handles errors in event listeners gracefully', () => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation();
|
||||
|
||||
const errorListener = jest.fn(() => {
|
||||
throw new Error('Listener error');
|
||||
});
|
||||
const successListener = jest.fn();
|
||||
|
||||
manager.onDidRegister(errorListener);
|
||||
manager.onDidRegister(successListener);
|
||||
|
||||
manager.registerProvider(createMockEditor(), createMockEditorComponent());
|
||||
|
||||
// Both listeners should have been called
|
||||
expect(errorListener).toHaveBeenCalledTimes(1);
|
||||
expect(successListener).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Error should have been logged
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Error in event listener:',
|
||||
expect.any(Error),
|
||||
);
|
||||
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
|
||||
test('reset clears all providers and language mappings', () => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
import type { editors } from '@apache-superset/core';
|
||||
import { Disposable } from '../models';
|
||||
import { createEventEmitter } from '../utils';
|
||||
|
||||
type EditorLanguage = editors.EditorLanguage;
|
||||
type EditorProvider = editors.EditorProvider;
|
||||
@@ -27,45 +28,8 @@ type EditorComponent = editors.EditorComponent;
|
||||
type EditorRegisteredEvent = editors.EditorRegisteredEvent;
|
||||
type EditorUnregisteredEvent = editors.EditorUnregisteredEvent;
|
||||
|
||||
/**
|
||||
* Listener function type for events.
|
||||
*/
|
||||
type Listener<T> = (e: T) => void;
|
||||
|
||||
/**
|
||||
* Simple event emitter for editor provider lifecycle events.
|
||||
*/
|
||||
class EventEmitter<T> {
|
||||
private listeners: Set<Listener<T>> = new Set();
|
||||
|
||||
/**
|
||||
* Subscribe to this event.
|
||||
* @param listener The listener function to call when the event is fired.
|
||||
* @returns A Disposable to unsubscribe from the event.
|
||||
*/
|
||||
subscribe(listener: Listener<T>): Disposable {
|
||||
this.listeners.add(listener);
|
||||
return new Disposable(() => {
|
||||
this.listeners.delete(listener);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire the event with the given data.
|
||||
* @param data The event data to pass to listeners.
|
||||
*/
|
||||
fire(data: T): void {
|
||||
this.listeners.forEach(listener => {
|
||||
try {
|
||||
listener(data);
|
||||
} catch (error) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.error('Error in event listener:', error);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Singleton manager for editor providers.
|
||||
* Handles registration, resolution, and lifecycle of custom editor implementations.
|
||||
@@ -83,15 +47,9 @@ class EditorProviders {
|
||||
*/
|
||||
private languageToProvider: Map<EditorLanguage, string> = new Map();
|
||||
|
||||
/**
|
||||
* Event emitter for provider registration events.
|
||||
*/
|
||||
private registerEmitter = new EventEmitter<EditorRegisteredEvent>();
|
||||
private registerEmitter = createEventEmitter<EditorRegisteredEvent>();
|
||||
|
||||
/**
|
||||
* Event emitter for provider unregistration events.
|
||||
*/
|
||||
private unregisterEmitter = new EventEmitter<EditorUnregisteredEvent>();
|
||||
private unregisterEmitter = createEventEmitter<EditorUnregisteredEvent>();
|
||||
|
||||
private syncListeners: Set<() => void> = new Set();
|
||||
|
||||
@@ -226,8 +184,11 @@ class EditorProviders {
|
||||
* @param listener The listener function.
|
||||
* @returns A Disposable to unsubscribe.
|
||||
*/
|
||||
public onDidRegister(listener: Listener<EditorRegisteredEvent>): Disposable {
|
||||
return this.registerEmitter.subscribe(listener);
|
||||
public onDidRegister(
|
||||
listener: Listener<EditorRegisteredEvent>,
|
||||
thisArgs?: unknown,
|
||||
): Disposable {
|
||||
return this.registerEmitter.subscribe(listener, thisArgs);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -237,8 +198,9 @@ class EditorProviders {
|
||||
*/
|
||||
public onDidUnregister(
|
||||
listener: Listener<EditorUnregisteredEvent>,
|
||||
thisArgs?: unknown,
|
||||
): Disposable {
|
||||
return this.unregisterEmitter.subscribe(listener);
|
||||
return this.unregisterEmitter.subscribe(listener, thisArgs);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -248,6 +210,8 @@ class EditorProviders {
|
||||
this.providers.clear();
|
||||
this.languageToProvider.clear();
|
||||
this.syncListeners.clear();
|
||||
this.registerEmitter = createEventEmitter<EditorRegisteredEvent>();
|
||||
this.unregisterEmitter = createEventEmitter<EditorUnregisteredEvent>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,130 +18,39 @@
|
||||
*/
|
||||
|
||||
/**
|
||||
* @fileoverview Implementation of the editors API for Superset.
|
||||
* @fileoverview Host implementation of the `editors` contribution type.
|
||||
*
|
||||
* This module provides the runtime implementation of the editor registration
|
||||
* and resolution functions declared in the API types.
|
||||
* Extensions register via the public `editors.registerEditor()` and the host
|
||||
* resolves the appropriate provider per language, falling back to the built-in
|
||||
* AceEditorProvider when no extension is registered.
|
||||
*
|
||||
* The public namespace (`editors`) is exposed to extensions on `window.superset`.
|
||||
* `EditorHost` is the host-internal component for rendering editors and is NOT
|
||||
* part of the public `@apache-superset/core` API.
|
||||
*/
|
||||
|
||||
import { useSyncExternalStore } from 'react';
|
||||
import { editors as editorsApi } from '@apache-superset/core';
|
||||
import { Disposable } from '../models';
|
||||
import EditorProviders from './EditorProviders';
|
||||
|
||||
type EditorLanguage = editorsApi.EditorLanguage;
|
||||
type Editor = editorsApi.Editor;
|
||||
type EditorProvider = editorsApi.EditorProvider;
|
||||
type EditorComponent = editorsApi.EditorComponent;
|
||||
type EditorRegisteredEvent = editorsApi.EditorRegisteredEvent;
|
||||
type EditorUnregisteredEvent = editorsApi.EditorUnregisteredEvent;
|
||||
export type { EditorHostProps } from './EditorHost';
|
||||
export { default as EditorHost } from './EditorHost';
|
||||
export { default as AceEditorProvider } from './AceEditorProvider';
|
||||
|
||||
/**
|
||||
* Register an editor provider as a module-level side effect.
|
||||
* Takes the editor descriptor directly rather than looking it up
|
||||
* from a manifest by ID.
|
||||
*
|
||||
* @param editor The editor descriptor.
|
||||
* @param component The React component implementing the editor.
|
||||
* @returns A Disposable to unregister the provider.
|
||||
*/
|
||||
export const registerEditor = (
|
||||
editor: Editor,
|
||||
component: EditorComponent,
|
||||
): Disposable => {
|
||||
const providers = EditorProviders.getInstance();
|
||||
return providers.registerProvider(editor, component);
|
||||
};
|
||||
const provider = EditorProviders.getInstance();
|
||||
|
||||
/**
|
||||
* Get the editor provider for a specific language.
|
||||
* Returns the extension's editor if registered, otherwise undefined.
|
||||
*
|
||||
* @param language The language to get an editor for
|
||||
* @returns The editor provider or undefined if no extension provides one
|
||||
*/
|
||||
export const getEditor = (
|
||||
language: EditorLanguage,
|
||||
): EditorProvider | undefined => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
return manager.getProvider(language);
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if an extension has registered an editor for a language.
|
||||
*
|
||||
* @param language The language to check
|
||||
* @returns True if an extension provides an editor for this language
|
||||
*/
|
||||
export const hasEditor = (language: EditorLanguage): boolean => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
return manager.hasProvider(language);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all registered editor providers.
|
||||
*
|
||||
* @returns Array of all registered editor providers
|
||||
*/
|
||||
export const getAllEditors = (): EditorProvider[] => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
return manager.getAllProviders();
|
||||
};
|
||||
|
||||
/**
|
||||
* Event fired when an editor is registered.
|
||||
* Subscribe to this event to react when extensions register new editors.
|
||||
*/
|
||||
export const onDidRegisterEditor = (
|
||||
listener: (e: EditorRegisteredEvent) => void,
|
||||
): Disposable => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
return manager.onDidRegister(listener);
|
||||
};
|
||||
|
||||
/**
|
||||
* Event fired when an editor is unregistered.
|
||||
* Subscribe to this event to react when extensions unregister editors.
|
||||
*/
|
||||
export const onDidUnregisterEditor = (
|
||||
listener: (e: EditorUnregisteredEvent) => void,
|
||||
): Disposable => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
return manager.onDidUnregister(listener);
|
||||
};
|
||||
|
||||
/**
|
||||
* Hook that returns the editor provider for a specific language and re-renders when it changes.
|
||||
*
|
||||
* @param language The language to get an editor for
|
||||
* @returns The editor provider or undefined if no extension provides one
|
||||
*/
|
||||
export const useEditor = (
|
||||
language: EditorLanguage,
|
||||
): EditorProvider | undefined => {
|
||||
const manager = EditorProviders.getInstance();
|
||||
return useSyncExternalStore(
|
||||
manager.subscribe,
|
||||
() => manager.getProvider(language),
|
||||
export const useEditor = (language: editorsApi.EditorLanguage) =>
|
||||
useSyncExternalStore(
|
||||
provider.subscribe,
|
||||
() => provider.getProvider(language),
|
||||
() => undefined,
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Editors API object for use in the extension system.
|
||||
*/
|
||||
export const editors: typeof editorsApi = {
|
||||
registerEditor,
|
||||
getEditor,
|
||||
hasEditor,
|
||||
getAllEditors,
|
||||
onDidRegisterEditor,
|
||||
onDidUnregisterEditor,
|
||||
registerEditor: provider.registerProvider.bind(provider),
|
||||
getEditor: provider.getProvider.bind(provider),
|
||||
hasEditor: provider.hasProvider.bind(provider),
|
||||
getAllEditors: provider.getAllProviders.bind(provider),
|
||||
onDidRegisterEditor: provider.onDidRegister.bind(provider),
|
||||
onDidUnregisterEditor: provider.onDidUnregister.bind(provider),
|
||||
};
|
||||
|
||||
export { EditorProviders };
|
||||
|
||||
// Component exports
|
||||
export { default as EditorHost } from './EditorHost';
|
||||
export type { EditorHostProps } from './EditorHost';
|
||||
export { default as AceEditorProvider } from './AceEditorProvider';
|
||||
|
||||
@@ -27,11 +27,13 @@ export const core: typeof coreType = {
|
||||
};
|
||||
|
||||
export * from './authentication';
|
||||
export * from './chat';
|
||||
export * from './commands';
|
||||
export * from './editors';
|
||||
export * from './extensions';
|
||||
export * from './menus';
|
||||
export * from './models';
|
||||
export * from './navigation';
|
||||
export * from './sqlLab';
|
||||
export * from './utils';
|
||||
export * from './views';
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
import { useSyncExternalStore } from 'react';
|
||||
import type { menus as menusApi } from '@apache-superset/core';
|
||||
import { Disposable } from '../models';
|
||||
import { createEventEmitter } from '../utils';
|
||||
|
||||
type MenuItem = menusApi.MenuItem;
|
||||
type Menu = menusApi.Menu;
|
||||
@@ -47,19 +48,19 @@ const subscribe = (listener: () => void) => {
|
||||
return () => syncListeners.delete(listener);
|
||||
};
|
||||
|
||||
const registerListeners = new Set<(e: MenuItemRegisteredEvent) => void>();
|
||||
const unregisterListeners = new Set<(e: MenuItemUnregisteredEvent) => void>();
|
||||
const registerEmitter = createEventEmitter<MenuItemRegisteredEvent>();
|
||||
const unregisterEmitter = createEventEmitter<MenuItemUnregisteredEvent>();
|
||||
|
||||
const menuCache = new Map<string, Menu | undefined>();
|
||||
const notifyRegister = (event: MenuItemRegisteredEvent) => {
|
||||
menuCache.clear();
|
||||
syncListeners.forEach(l => l());
|
||||
registerListeners.forEach(l => l(event));
|
||||
registerEmitter.fire(event);
|
||||
};
|
||||
const notifyUnregister = (event: MenuItemUnregisteredEvent) => {
|
||||
menuCache.clear();
|
||||
syncListeners.forEach(l => l());
|
||||
unregisterListeners.forEach(l => l(event));
|
||||
unregisterEmitter.fire(event);
|
||||
};
|
||||
|
||||
const registerMenuItem: typeof menusApi.registerMenuItem = (
|
||||
@@ -117,16 +118,14 @@ export const useMenu = (location: string): Menu | undefined =>
|
||||
|
||||
export const onDidRegisterMenuItem: typeof menusApi.onDidRegisterMenuItem = (
|
||||
listener: (e: MenuItemRegisteredEvent) => void,
|
||||
): Disposable => {
|
||||
registerListeners.add(listener);
|
||||
return new Disposable(() => registerListeners.delete(listener));
|
||||
};
|
||||
thisArgs?: unknown,
|
||||
): Disposable => registerEmitter.subscribe(listener, thisArgs);
|
||||
|
||||
export const onDidUnregisterMenuItem: typeof menusApi.onDidUnregisterMenuItem =
|
||||
(listener: (e: MenuItemUnregisteredEvent) => void): Disposable => {
|
||||
unregisterListeners.add(listener);
|
||||
return new Disposable(() => unregisterListeners.delete(listener));
|
||||
};
|
||||
(
|
||||
listener: (e: MenuItemUnregisteredEvent) => void,
|
||||
thisArgs?: unknown,
|
||||
): Disposable => unregisterEmitter.subscribe(listener, thisArgs);
|
||||
|
||||
export const menus: typeof menusApi = {
|
||||
registerMenuItem,
|
||||
|
||||
124
superset-frontend/src/core/navigation/index.test.ts
Normal file
124
superset-frontend/src/core/navigation/index.test.ts
Normal file
@@ -0,0 +1,124 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
// Reset module state between tests so currentPage is re-initialized.
|
||||
beforeEach(() => {
|
||||
jest.resetModules();
|
||||
Object.defineProperty(window, 'location', {
|
||||
writable: true,
|
||||
value: { pathname: '/' },
|
||||
});
|
||||
});
|
||||
|
||||
async function importNavigation() {
|
||||
const mod = await import('./index');
|
||||
return mod;
|
||||
}
|
||||
|
||||
test('getPage falls back to "home" for the welcome page and unknown pathnames', async () => {
|
||||
const { navigation, notifyLocationChanged } = await importNavigation();
|
||||
// The default pathname ('/') is not enumerated and falls back to home.
|
||||
expect(navigation.getPage()).toBe('home');
|
||||
notifyLocationChanged('/superset/welcome/');
|
||||
expect(navigation.getPage()).toBe('home');
|
||||
});
|
||||
|
||||
test('getPage derives the page from window.location.pathname', async () => {
|
||||
window.location.pathname = '/superset/dashboard/42/';
|
||||
const { navigation } = await importNavigation();
|
||||
expect(navigation.getPage()).toBe('dashboard');
|
||||
});
|
||||
|
||||
test('notifyLocationChanged updates the current page type', async () => {
|
||||
const { navigation, notifyLocationChanged } = await importNavigation();
|
||||
notifyLocationChanged('/explore/?form_data={}');
|
||||
expect(navigation.getPage()).toBe('explore');
|
||||
});
|
||||
|
||||
test('notifyLocationChanged fires listeners on page type change', async () => {
|
||||
const { navigation, notifyLocationChanged } = await importNavigation();
|
||||
const listener = jest.fn();
|
||||
const disposable = navigation.onDidChangePage(listener);
|
||||
|
||||
notifyLocationChanged('/superset/dashboard/1/');
|
||||
expect(listener).toHaveBeenCalledWith('dashboard');
|
||||
|
||||
disposable.dispose();
|
||||
});
|
||||
|
||||
test('notifyLocationChanged does not fire listeners when page type is unchanged', async () => {
|
||||
window.location.pathname = '/superset/dashboard/1/';
|
||||
const { navigation, notifyLocationChanged } = await importNavigation();
|
||||
const listener = jest.fn();
|
||||
navigation.onDidChangePage(listener);
|
||||
|
||||
notifyLocationChanged('/superset/dashboard/2/');
|
||||
expect(listener).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('onDidChangePage listener is removed after dispose', async () => {
|
||||
const { navigation, notifyLocationChanged } = await importNavigation();
|
||||
const listener = jest.fn();
|
||||
const disposable = navigation.onDidChangePage(listener);
|
||||
|
||||
disposable.dispose();
|
||||
notifyLocationChanged('/superset/dashboard/1/');
|
||||
expect(listener).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('sqllab path is matched with and without trailing slash', async () => {
|
||||
const { notifyLocationChanged, navigation } = await importNavigation();
|
||||
notifyLocationChanged('/sqllab');
|
||||
expect(navigation.getPage()).toBe('sqllab');
|
||||
notifyLocationChanged('/explore/');
|
||||
notifyLocationChanged('/sqllab/');
|
||||
expect(navigation.getPage()).toBe('sqllab');
|
||||
});
|
||||
|
||||
test('chart and dashboard list pages get their own page types', async () => {
|
||||
const { notifyLocationChanged, navigation } = await importNavigation();
|
||||
notifyLocationChanged('/chart/list/');
|
||||
expect(navigation.getPage()).toBe('chart_list');
|
||||
notifyLocationChanged('/dashboard/list/');
|
||||
expect(navigation.getPage()).toBe('dashboard_list');
|
||||
});
|
||||
|
||||
test('dataset list and single-dataset pages get distinct page types', async () => {
|
||||
const { notifyLocationChanged, navigation } = await importNavigation();
|
||||
notifyLocationChanged('/tablemodelview/list/');
|
||||
expect(navigation.getPage()).toBe('dataset_list');
|
||||
notifyLocationChanged('/dataset/42');
|
||||
expect(navigation.getPage()).toBe('dataset');
|
||||
});
|
||||
|
||||
test('sqllab editor, query history, and saved queries get distinct page types', async () => {
|
||||
const { notifyLocationChanged, navigation } = await importNavigation();
|
||||
notifyLocationChanged('/sqllab/');
|
||||
expect(navigation.getPage()).toBe('sqllab');
|
||||
notifyLocationChanged('/sqllab/history/');
|
||||
expect(navigation.getPage()).toBe('query_history');
|
||||
notifyLocationChanged('/savedqueryview/list/');
|
||||
expect(navigation.getPage()).toBe('saved_queries');
|
||||
});
|
||||
|
||||
test('chart/add resolves to explore, not chart_list', async () => {
|
||||
const { notifyLocationChanged, navigation } = await importNavigation();
|
||||
notifyLocationChanged('/chart/add');
|
||||
expect(navigation.getPage()).toBe('explore');
|
||||
});
|
||||
94
superset-frontend/src/core/navigation/index.ts
Normal file
94
superset-frontend/src/core/navigation/index.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @fileoverview Host-internal implementation of the `navigation` namespace.
|
||||
*
|
||||
* Derives the current {@link Page} from the browser location by matching
|
||||
* against {@link RoutePaths}. Call {@link useNavigationTracker} once in the
|
||||
* app shell to keep the page in sync with React Router.
|
||||
*/
|
||||
|
||||
import { useEffect, useRef } from 'react';
|
||||
import { useLocation, matchPath } from 'react-router-dom';
|
||||
import type { navigation as navigationApi } from '@apache-superset/core';
|
||||
import { RoutePaths } from '../../views/routePaths';
|
||||
import { Disposable } from '../models';
|
||||
import { createValueEventEmitter } from '../utils';
|
||||
|
||||
type Page = navigationApi.Page;
|
||||
|
||||
/** Maps route path patterns to their corresponding Page type. */
|
||||
const PAGE_ROUTES: { path: string; page: Page }[] = [
|
||||
{ path: RoutePaths.DASHBOARD, page: 'dashboard' },
|
||||
{ path: RoutePaths.DASHBOARD_LIST, page: 'dashboard_list' },
|
||||
{ path: RoutePaths.QUERY_HISTORY, page: 'query_history' },
|
||||
{ path: RoutePaths.SAVED_QUERIES, page: 'saved_queries' },
|
||||
{ path: RoutePaths.SQLLAB, page: 'sqllab' },
|
||||
{ path: RoutePaths.CHART_ADD, page: 'explore' },
|
||||
{ path: RoutePaths.CHART_LIST, page: 'chart_list' },
|
||||
{ path: RoutePaths.EXPLORE, page: 'explore' },
|
||||
{ path: RoutePaths.EXPLORE_PERMALINK, page: 'explore' },
|
||||
{ path: RoutePaths.DATASET_LIST, page: 'dataset_list' },
|
||||
{ path: RoutePaths.DATASET_ADD, page: 'dataset' },
|
||||
{ path: RoutePaths.DATASET, page: 'dataset' },
|
||||
];
|
||||
|
||||
function derivePage(pathname: string): Page {
|
||||
for (const { path, page } of PAGE_ROUTES) {
|
||||
if (matchPath(pathname, { path, exact: false })) return page;
|
||||
}
|
||||
return 'home';
|
||||
}
|
||||
|
||||
const pageEmitter = createValueEventEmitter<Page>(
|
||||
derivePage(window.location.pathname),
|
||||
);
|
||||
|
||||
/** Updates the current page from a pathname. No-op when the page is unchanged. */
|
||||
export const notifyLocationChanged = (pathname: string): void => {
|
||||
const next = derivePage(pathname);
|
||||
if (next === pageEmitter.getCurrent()) return;
|
||||
pageEmitter.fire(next);
|
||||
};
|
||||
|
||||
const getPage: typeof navigationApi.getPage = () => pageEmitter.getCurrent();
|
||||
|
||||
const onDidChangePage: typeof navigationApi.onDidChangePage = (
|
||||
listener: (page: Page) => void,
|
||||
thisArgs?: unknown,
|
||||
): Disposable => pageEmitter.subscribe(listener, thisArgs);
|
||||
|
||||
/** Synchronizes the navigation module with React Router. Call once in the app shell. */
|
||||
export const useNavigationTracker = () => {
|
||||
const location = useLocation();
|
||||
const prevPathname = useRef<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (prevPathname.current !== location.pathname) {
|
||||
prevPathname.current = location.pathname;
|
||||
notifyLocationChanged(location.pathname);
|
||||
}
|
||||
}, [location.pathname]);
|
||||
};
|
||||
|
||||
export const navigation: typeof navigationApi = {
|
||||
getPage,
|
||||
onDidChangePage,
|
||||
};
|
||||
@@ -48,7 +48,7 @@ import { AnyListenerPredicate } from '@reduxjs/toolkit';
|
||||
import type { QueryEditor, SqlLabRootState } from 'src/SqlLab/types';
|
||||
import { newQueryTabName } from 'src/SqlLab/utils/newQueryTabName';
|
||||
import { Database, Disposable } from '../models';
|
||||
import { createActionListener } from '../utils';
|
||||
import { createActionListener } from '../storeUtils';
|
||||
import {
|
||||
Panel,
|
||||
Tab,
|
||||
|
||||
48
superset-frontend/src/core/storeUtils.ts
Normal file
48
superset-frontend/src/core/storeUtils.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import type { common as core } from '@apache-superset/core';
|
||||
import { listenerMiddleware, RootState, store } from 'src/views/store';
|
||||
import { AnyListenerPredicate } from '@reduxjs/toolkit';
|
||||
|
||||
export function createActionListener<V, A = unknown>(
|
||||
predicate: AnyListenerPredicate<RootState>,
|
||||
listener: (v: V) => void,
|
||||
valueParser: (action: A, state: RootState) => V | null | undefined,
|
||||
thisArgs?: unknown,
|
||||
): core.Disposable {
|
||||
const boundListener = thisArgs ? listener.bind(thisArgs as object) : listener;
|
||||
|
||||
const unsubscribe = listenerMiddleware.startListening({
|
||||
predicate,
|
||||
effect: action => {
|
||||
const state = store.getState();
|
||||
// The predicate already ensures the action matches type A at runtime.
|
||||
const value = valueParser(action as unknown as A, state);
|
||||
if (value != null) {
|
||||
boundListener(value);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
dispose: () => {
|
||||
unsubscribe();
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -17,33 +17,54 @@
|
||||
* under the License.
|
||||
*/
|
||||
import type { common as core } from '@apache-superset/core';
|
||||
import { AnyAction } from 'redux';
|
||||
import { listenerMiddleware, RootState, store } from 'src/views/store';
|
||||
import { AnyListenerPredicate } from '@reduxjs/toolkit';
|
||||
|
||||
export function createActionListener<V>(
|
||||
predicate: AnyListenerPredicate<RootState>,
|
||||
listener: (v: V) => void,
|
||||
valueParser: (action: AnyAction, state: RootState) => V | null | undefined,
|
||||
thisArgs?: any,
|
||||
): core.Disposable {
|
||||
const boundListener = thisArgs ? listener.bind(thisArgs) : listener;
|
||||
type Listener<T> = (e: T) => unknown;
|
||||
|
||||
const unsubscribe = listenerMiddleware.startListening({
|
||||
predicate,
|
||||
effect: (action: AnyAction) => {
|
||||
const state = store.getState();
|
||||
const value = valueParser(action, state);
|
||||
// Skip calling listener if valueParser returns null/undefined
|
||||
if (value != null) {
|
||||
boundListener(value);
|
||||
}
|
||||
},
|
||||
});
|
||||
/** A stateless event emitter exposing a VS Code-style `event` subscriber. */
|
||||
export interface EventEmitter<T> {
|
||||
/** Notifies every current subscriber with `value`. */
|
||||
fire(value: T): void;
|
||||
/** Registers a listener; returns a Disposable that removes it. */
|
||||
subscribe: core.Event<T>;
|
||||
}
|
||||
|
||||
/** An event emitter that also retains the last fired value. */
|
||||
export interface ValueEventEmitter<T> extends EventEmitter<T> {
|
||||
/** Returns the value last passed to {@link fire} (or the initial value). */
|
||||
getCurrent(): T;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a stateless event emitter. Listeners registered via `event` receive
|
||||
* every subsequent `fire`; a returned Disposable removes the listener.
|
||||
*/
|
||||
export function createEventEmitter<T>(): EventEmitter<T> {
|
||||
const listeners = new Set<Listener<T>>();
|
||||
const subscribe: core.Event<T> = (listener, thisArgs) => {
|
||||
const bound = thisArgs ? listener.bind(thisArgs) : listener;
|
||||
listeners.add(bound);
|
||||
return { dispose: () => listeners.delete(bound) };
|
||||
};
|
||||
return {
|
||||
dispose: () => {
|
||||
unsubscribe();
|
||||
},
|
||||
fire: value => listeners.forEach(fn => fn(value)),
|
||||
subscribe,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a value event emitter seeded with `initial`. Behaves like
|
||||
* {@link createEventEmitter} but also tracks the last fired value, readable
|
||||
* via `getCurrent` — useful for state that is both observed and queried.
|
||||
*/
|
||||
export function createValueEventEmitter<T>(initial: T): ValueEventEmitter<T> {
|
||||
const { fire, subscribe } = createEventEmitter<T>();
|
||||
let current = initial;
|
||||
return {
|
||||
fire: value => {
|
||||
current = value;
|
||||
fire(value);
|
||||
},
|
||||
subscribe,
|
||||
getCurrent: () => current,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -24,11 +24,12 @@
|
||||
* Extensions register views as side effects at import time.
|
||||
*/
|
||||
|
||||
import React, { ReactElement, useSyncExternalStore } from 'react';
|
||||
import React, { ComponentType, useSyncExternalStore } from 'react';
|
||||
import type { views as viewsApi } from '@apache-superset/core';
|
||||
import { ErrorBoundary } from 'src/components/ErrorBoundary';
|
||||
import ExtensionPlaceholder from 'src/extensions/ExtensionPlaceholder';
|
||||
import { Disposable } from '../models';
|
||||
import { createEventEmitter } from '../utils';
|
||||
|
||||
type View = viewsApi.View;
|
||||
type ViewRegisteredEvent = viewsApi.ViewRegisteredEvent;
|
||||
@@ -36,7 +37,7 @@ type ViewUnregisteredEvent = viewsApi.ViewUnregisteredEvent;
|
||||
|
||||
const viewRegistry: Map<
|
||||
string,
|
||||
{ view: View; location: string; provider: () => ReactElement }
|
||||
{ view: View; location: string; component: ComponentType }
|
||||
> = new Map();
|
||||
|
||||
const locationIndex: Map<string, Set<string>> = new Map();
|
||||
@@ -47,29 +48,29 @@ const subscribe = (listener: () => void) => {
|
||||
return () => syncListeners.delete(listener);
|
||||
};
|
||||
|
||||
const registerListeners = new Set<(e: ViewRegisteredEvent) => void>();
|
||||
const unregisterListeners = new Set<(e: ViewUnregisteredEvent) => void>();
|
||||
const registerEmitter = createEventEmitter<ViewRegisteredEvent>();
|
||||
const unregisterEmitter = createEventEmitter<ViewUnregisteredEvent>();
|
||||
|
||||
const viewsCache = new Map<string, View[] | undefined>();
|
||||
const notifyRegister = (event: ViewRegisteredEvent) => {
|
||||
viewsCache.clear();
|
||||
syncListeners.forEach(l => l());
|
||||
registerListeners.forEach(l => l(event));
|
||||
registerEmitter.fire(event);
|
||||
};
|
||||
const notifyUnregister = (event: ViewUnregisteredEvent) => {
|
||||
viewsCache.clear();
|
||||
syncListeners.forEach(l => l());
|
||||
unregisterListeners.forEach(l => l(event));
|
||||
unregisterEmitter.fire(event);
|
||||
};
|
||||
|
||||
const registerView: typeof viewsApi.registerView = (
|
||||
view: View,
|
||||
location: string,
|
||||
provider: () => ReactElement,
|
||||
component: ComponentType,
|
||||
): Disposable => {
|
||||
const { id } = view;
|
||||
|
||||
viewRegistry.set(id, { view, location, provider });
|
||||
viewRegistry.set(id, { view, location, component });
|
||||
|
||||
const ids = locationIndex.get(location) ?? new Set();
|
||||
ids.add(id);
|
||||
@@ -83,12 +84,16 @@ const registerView: typeof viewsApi.registerView = (
|
||||
});
|
||||
};
|
||||
|
||||
export const resolveView = (id: string): ReactElement => {
|
||||
const provider = viewRegistry.get(id)?.provider;
|
||||
if (!provider) {
|
||||
export const resolveView = (id: string): React.ReactElement => {
|
||||
const entry = viewRegistry.get(id);
|
||||
if (!entry) {
|
||||
return React.createElement(ExtensionPlaceholder, { id });
|
||||
}
|
||||
return React.createElement(ErrorBoundary, null, provider());
|
||||
return React.createElement(
|
||||
ErrorBoundary,
|
||||
null,
|
||||
React.createElement(entry.component),
|
||||
);
|
||||
};
|
||||
|
||||
const getViews: typeof viewsApi.getViews = (
|
||||
@@ -116,17 +121,11 @@ export const useViews = (location: string): View[] | undefined =>
|
||||
|
||||
export const onDidRegisterView: typeof viewsApi.onDidRegisterView = (
|
||||
listener: (e: ViewRegisteredEvent) => void,
|
||||
): Disposable => {
|
||||
registerListeners.add(listener);
|
||||
return new Disposable(() => registerListeners.delete(listener));
|
||||
};
|
||||
): Disposable => registerEmitter.subscribe(listener);
|
||||
|
||||
export const onDidUnregisterView: typeof viewsApi.onDidUnregisterView = (
|
||||
listener: (e: ViewUnregisteredEvent) => void,
|
||||
): Disposable => {
|
||||
unregisterListeners.add(listener);
|
||||
return new Disposable(() => unregisterListeners.delete(listener));
|
||||
};
|
||||
): Disposable => unregisterEmitter.subscribe(listener);
|
||||
|
||||
export const views: typeof viewsApi = {
|
||||
registerView,
|
||||
|
||||
@@ -31,7 +31,6 @@ function createMockExtension(overrides: Partial<Extension> = {}): Extension {
|
||||
version: '1.0.0',
|
||||
dependencies: [],
|
||||
remoteEntry: '',
|
||||
extensionDependencies: [],
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -72,6 +72,7 @@ afterEach(() => {
|
||||
test('renders without crashing', () => {
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialState,
|
||||
});
|
||||
|
||||
@@ -88,6 +89,7 @@ test('sets up global superset object when user is logged in', async () => {
|
||||
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialState,
|
||||
});
|
||||
|
||||
@@ -95,6 +97,7 @@ test('sets up global superset object when user is logged in', async () => {
|
||||
// Verify the global superset object is set up
|
||||
expect((window as any).superset).toBeDefined();
|
||||
expect((window as any).superset.authentication).toBeDefined();
|
||||
expect((window as any).superset.chat).toBeDefined();
|
||||
expect((window as any).superset.core).toBeDefined();
|
||||
expect((window as any).superset.commands).toBeDefined();
|
||||
expect((window as any).superset.extensions).toBeDefined();
|
||||
@@ -109,6 +112,7 @@ test('sets up global superset object when user is logged in', async () => {
|
||||
test('does not set up global superset object when user is not logged in', async () => {
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialStateNoUser,
|
||||
});
|
||||
|
||||
@@ -127,6 +131,7 @@ test('initializes ExtensionsLoader when user is logged in', async () => {
|
||||
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialState,
|
||||
});
|
||||
|
||||
@@ -144,6 +149,7 @@ test('initializes ExtensionsLoader when user is logged in', async () => {
|
||||
test('does not initialize ExtensionsLoader when user is not logged in', async () => {
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialStateNoUser,
|
||||
});
|
||||
|
||||
@@ -169,6 +175,7 @@ test('only initializes once even with multiple renders', async () => {
|
||||
|
||||
const { rerender } = render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialState,
|
||||
});
|
||||
|
||||
@@ -205,6 +212,7 @@ test('initializes ExtensionsLoader when EnableExtensions feature flag is enabled
|
||||
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialState,
|
||||
});
|
||||
|
||||
@@ -234,6 +242,7 @@ test('does not initialize ExtensionsLoader when EnableExtensions feature flag is
|
||||
|
||||
render(<ExtensionsStartup />, {
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialState,
|
||||
});
|
||||
|
||||
@@ -268,6 +277,7 @@ test('continues rendering children even when ExtensionsLoader initialization fai
|
||||
</ExtensionsStartup>,
|
||||
{
|
||||
useRedux: true,
|
||||
useRouter: true,
|
||||
initialState: mockInitialState,
|
||||
},
|
||||
);
|
||||
|
||||
@@ -17,41 +17,32 @@
|
||||
* under the License.
|
||||
*/
|
||||
import { useEffect } from 'react';
|
||||
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
import * as supersetCore from '@apache-superset/core';
|
||||
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
|
||||
import {
|
||||
authentication,
|
||||
chat,
|
||||
core,
|
||||
commands,
|
||||
editors,
|
||||
extensions,
|
||||
menus,
|
||||
navigation,
|
||||
useNavigationTracker,
|
||||
sqlLab,
|
||||
views,
|
||||
} from 'src/core';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { RootState } from 'src/views/store';
|
||||
import ExtensionsLoader from './ExtensionsLoader';
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
superset: {
|
||||
authentication: typeof authentication;
|
||||
core: typeof core;
|
||||
commands: typeof commands;
|
||||
editors: typeof editors;
|
||||
extensions: typeof extensions;
|
||||
menus: typeof menus;
|
||||
sqlLab: typeof sqlLab;
|
||||
views: typeof views;
|
||||
};
|
||||
}
|
||||
}
|
||||
import 'src/extensions/Namespaces';
|
||||
|
||||
const ExtensionsStartup: React.FC<{ children?: React.ReactNode }> = ({
|
||||
children,
|
||||
}) => {
|
||||
useNavigationTracker();
|
||||
|
||||
const userId = useSelector<RootState, number | undefined>(
|
||||
({ user }) => user.userId,
|
||||
);
|
||||
@@ -59,15 +50,19 @@ const ExtensionsStartup: React.FC<{ children?: React.ReactNode }> = ({
|
||||
useEffect(() => {
|
||||
if (userId == null) return;
|
||||
|
||||
// Provide the implementations for @apache-superset/core
|
||||
// Provide the implementations for @apache-superset/core.
|
||||
// Namespaces are listed explicitly — do not spread the core package here,
|
||||
// as that would leak un-contracted symbols onto window.superset.
|
||||
window.superset = {
|
||||
...supersetCore,
|
||||
authentication,
|
||||
chat,
|
||||
core,
|
||||
commands,
|
||||
editors,
|
||||
extensions,
|
||||
menus,
|
||||
navigation,
|
||||
sqlLab,
|
||||
views,
|
||||
};
|
||||
|
||||
60
superset-frontend/src/extensions/Namespaces.ts
Normal file
60
superset-frontend/src/extensions/Namespaces.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Global `window.superset` type augmentation.
|
||||
*
|
||||
* Lives in its own module (rather than inline in ExtensionsStartup) so every
|
||||
* file that reads or writes `window.superset` — notably ExtensionsLoader —
|
||||
* sees the type regardless of how files are batched during compilation. Both
|
||||
* the startup component and the loader import this module for its side effect.
|
||||
*/
|
||||
|
||||
import type {
|
||||
authentication,
|
||||
chat,
|
||||
commands,
|
||||
core,
|
||||
editors,
|
||||
extensions,
|
||||
menus,
|
||||
navigation,
|
||||
sqlLab,
|
||||
views,
|
||||
} from 'src/core';
|
||||
|
||||
/** The host namespaces exposed to extensions on `window.superset`. */
|
||||
export interface Namespaces {
|
||||
authentication: typeof authentication;
|
||||
core: typeof core;
|
||||
chat: typeof chat;
|
||||
commands: typeof commands;
|
||||
editors: typeof editors;
|
||||
extensions: typeof extensions;
|
||||
menus: typeof menus;
|
||||
navigation: typeof navigation;
|
||||
sqlLab: typeof sqlLab;
|
||||
views: typeof views;
|
||||
}
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
superset: Namespaces;
|
||||
}
|
||||
}
|
||||
@@ -77,18 +77,22 @@ const databaseFixture: DatabaseObject = {
|
||||
// eslint-disable-next-line no-restricted-globals -- TODO: Migrate from describe blocks
|
||||
describe('DatabaseModal', () => {
|
||||
beforeEach(() => {
|
||||
fetchMock.post(DATABASE_CONNECT_ENDPOINT, {
|
||||
id: 10,
|
||||
result: {
|
||||
configuration_method: 'sqlalchemy_form',
|
||||
database_name: 'Other2',
|
||||
driver: 'apsw',
|
||||
expose_in_sqllab: true,
|
||||
extra: '{"allows_virtual_table_explore":true}',
|
||||
sqlalchemy_uri: 'gsheets://',
|
||||
fetchMock.post(
|
||||
DATABASE_CONNECT_ENDPOINT,
|
||||
{
|
||||
id: 10,
|
||||
result: {
|
||||
configuration_method: 'sqlalchemy_form',
|
||||
database_name: 'Other2',
|
||||
driver: 'apsw',
|
||||
expose_in_sqllab: true,
|
||||
extra: '{"allows_virtual_table_explore":true}',
|
||||
sqlalchemy_uri: 'gsheets://',
|
||||
},
|
||||
json: 'foo',
|
||||
},
|
||||
json: 'foo',
|
||||
});
|
||||
{ name: 'database-connect' },
|
||||
);
|
||||
|
||||
fetchMock.get(DATABASE_FETCH_ENDPOINT, {
|
||||
result: {
|
||||
@@ -311,9 +315,13 @@ describe('DatabaseModal', () => {
|
||||
},
|
||||
],
|
||||
});
|
||||
fetchMock.post(VALIDATE_PARAMS_ENDPOINT, {
|
||||
message: 'OK',
|
||||
});
|
||||
fetchMock.post(
|
||||
VALIDATE_PARAMS_ENDPOINT,
|
||||
{
|
||||
message: 'OK',
|
||||
},
|
||||
{ name: 'validate-params' },
|
||||
);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -1367,14 +1375,16 @@ describe('DatabaseModal', () => {
|
||||
const textboxes = screen.getAllByRole('textbox');
|
||||
const hostField = textboxes[0];
|
||||
const portField = screen.getByRole('spinbutton');
|
||||
const databaseNameField = textboxes[1];
|
||||
// textboxes[1] is the connection `database` field; the engine display
|
||||
// name (`database_name`) auto-fills and is asserted separately.
|
||||
const databaseField = textboxes[1];
|
||||
const usernameField = textboxes[2];
|
||||
const passwordField = textboxes[3];
|
||||
const connectButton = screen.getByRole('button', { name: 'Connect' });
|
||||
|
||||
expect(hostField).toHaveValue('');
|
||||
expect(portField).toHaveValue(null);
|
||||
expect(databaseNameField).toHaveValue('');
|
||||
expect(databaseField).toHaveValue('');
|
||||
expect(usernameField).toHaveValue('');
|
||||
expect(passwordField).toHaveValue('');
|
||||
|
||||
@@ -1382,7 +1392,7 @@ describe('DatabaseModal', () => {
|
||||
|
||||
userEvent.type(hostField, 'localhost');
|
||||
userEvent.type(portField, '5432');
|
||||
userEvent.type(databaseNameField, 'postgres');
|
||||
userEvent.type(databaseField, 'postgres');
|
||||
userEvent.type(usernameField, 'testdb');
|
||||
userEvent.type(passwordField, 'demoPassword');
|
||||
|
||||
@@ -1391,7 +1401,7 @@ describe('DatabaseModal', () => {
|
||||
expect(await screen.findByDisplayValue(/5432/i)).toBeInTheDocument();
|
||||
expect(hostField).toHaveValue('localhost');
|
||||
expect(portField).toHaveValue(5432);
|
||||
expect(databaseNameField).toHaveValue('postgres');
|
||||
expect(databaseField).toHaveValue('postgres');
|
||||
expect(usernameField).toHaveValue('testdb');
|
||||
expect(passwordField).toHaveValue('demoPassword');
|
||||
|
||||
@@ -1581,6 +1591,135 @@ describe('DatabaseModal', () => {
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
// Tests migrated from the deprecated Cypress suite
|
||||
// (cypress-base/cypress/e2e/database/modal.test.ts). The two "error alert"
|
||||
// cases originally relied on a real backend connection attempt (real DNS /
|
||||
// socket behaviour), which made them flaky. They are reproduced here by
|
||||
// mocking the validate_parameters response: the frontend's responsibility is
|
||||
// to map an `extra.invalid` field error onto the matching form field, which
|
||||
// is exactly what these assertions exercise. Whether a bad host/port really
|
||||
// fails to connect is a backend concern, covered by backend tests.
|
||||
const selectPostgres = async () => {
|
||||
userEvent.click(await screen.findByRole('button', { name: /postgresql/i }));
|
||||
// Dynamic form (step 2 of 3) is now visible
|
||||
expect(await screen.findByText(/step 2 of 3/i)).toBeInTheDocument();
|
||||
};
|
||||
|
||||
// The modal renders into a portal on document.body, so fields are queried
|
||||
// from the document rather than the render container.
|
||||
const fieldByName = (name: string) =>
|
||||
document.querySelector(`input[name="${name}"]`) as HTMLInputElement;
|
||||
|
||||
const fillDynamicForm = () => {
|
||||
const values: Record<string, string> = {
|
||||
host: 'badhost',
|
||||
port: '5432',
|
||||
database: 'testdb',
|
||||
username: 'testusername',
|
||||
password: 'testpass',
|
||||
};
|
||||
Object.entries(values).forEach(([name, value]) =>
|
||||
userEvent.type(fieldByName(name), value),
|
||||
);
|
||||
};
|
||||
|
||||
test('defaults the display name to the selected engine', async () => {
|
||||
setup();
|
||||
await selectPostgres();
|
||||
|
||||
// The display name field auto-fills with the selected engine's name. The
|
||||
// empty initial state of the connection fields (host/port/database/…) is
|
||||
// already covered by the "enters form credentials" test above.
|
||||
expect(fieldByName('database_name')).toHaveValue('PostgreSQL');
|
||||
});
|
||||
|
||||
test('switches to the SQLAlchemy URI form via the connect link', async () => {
|
||||
setup();
|
||||
await selectPostgres();
|
||||
|
||||
userEvent.click(screen.getByTestId('sqla-connect-btn'));
|
||||
|
||||
expect(await screen.findByTestId('database-name-input')).toBeVisible();
|
||||
expect(screen.getByTestId('sqlalchemy-uri-input')).toBeVisible();
|
||||
});
|
||||
|
||||
test.each([
|
||||
{
|
||||
field: 'host',
|
||||
errorType: 'CONNECTION_INVALID_HOSTNAME_ERROR',
|
||||
message: "The hostname provided can't be resolved.",
|
||||
},
|
||||
{
|
||||
field: 'port',
|
||||
errorType: 'CONNECTION_PORT_CLOSED_ERROR',
|
||||
message: 'The port is closed.',
|
||||
},
|
||||
])(
|
||||
'surfaces a $field validation error returned by validate_parameters',
|
||||
async ({ field, errorType, message }) => {
|
||||
const createResource = jest.fn();
|
||||
const useSingleViewResourceMock = jest
|
||||
.spyOn(hooks, 'useSingleViewResource')
|
||||
.mockReturnValue({
|
||||
state: {
|
||||
loading: false,
|
||||
resource: null,
|
||||
error: null,
|
||||
},
|
||||
fetchResource: jest.fn(),
|
||||
createResource,
|
||||
updateResource: jest.fn(),
|
||||
clearError: jest.fn(),
|
||||
setResource: jest.fn(),
|
||||
});
|
||||
|
||||
setup();
|
||||
await selectPostgres();
|
||||
fillDynamicForm();
|
||||
|
||||
const submitButton = screen.getByTestId('btn-submit-connection');
|
||||
// Blur the last field and let the (default, passing) validation settle
|
||||
// so its async onBlur result can't race with — and overwrite — the
|
||||
// submit-time validation below. Mirrors the Cypress `body.click(0, 0)`.
|
||||
userEvent.click(document.body);
|
||||
await waitFor(() => expect(submitButton).toBeEnabled());
|
||||
|
||||
// Make validation fail the way a real backend would for an unreachable
|
||||
// host / closed port. The frontend's job is to map `extra.invalid:
|
||||
// [field]` onto the matching form field. The 422 short-circuits the
|
||||
// submit before the create (database-connect) request fires, so only
|
||||
// this validate_parameters mock drives the rendered error.
|
||||
fetchMock.modifyRoute('validate-params', {
|
||||
response: {
|
||||
status: 422,
|
||||
body: {
|
||||
errors: [
|
||||
{
|
||||
message,
|
||||
error_type: errorType,
|
||||
level: 'error',
|
||||
extra: { invalid: [field] },
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
userEvent.click(submitButton);
|
||||
|
||||
// Wait for the async error to render, then confirm it surfaced as an
|
||||
// antd inline field error (not a general alert)...
|
||||
const errorText = await screen.findByText(message);
|
||||
expect(errorText.closest('.ant-form-item-explain-error')).not.toBeNull();
|
||||
// ...and that it belongs to the form item for the field named in
|
||||
// `extra.invalid` — i.e. that input lives in the same `.ant-form-item`.
|
||||
const formItem = errorText.closest('.ant-form-item') as HTMLElement;
|
||||
expect(formItem.querySelector(`input[name="${field}"]`)).not.toBeNull();
|
||||
expect(createResource).not.toHaveBeenCalled();
|
||||
useSingleViewResourceMock.mockRestore();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
test('handleChangeWithValidation function clears validation errors when called', () => {
|
||||
|
||||
@@ -940,7 +940,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||
}
|
||||
|
||||
const errors = await getValidation(dbToUpdate, true);
|
||||
if (!isEmpty(validationErrors) || errors?.length) {
|
||||
if (!isEmpty(validationErrors) || !isEmpty(errors)) {
|
||||
addDangerToast(
|
||||
t('Connection failed, please check your connection settings.'),
|
||||
);
|
||||
|
||||
@@ -57,6 +57,7 @@ export enum LocalStorageKeys {
|
||||
DashboardExploreContext = 'dashboard__explore_context',
|
||||
DashboardEditorShowOnlyMyCharts = 'dashboard__editor_show_only_my_charts',
|
||||
CommonResizableSidebarWidths = 'common__resizable_sidebar_widths',
|
||||
ChatState = 'chat__state',
|
||||
}
|
||||
|
||||
export type LocalStorageValues = {
|
||||
@@ -78,6 +79,7 @@ export type LocalStorageValues = {
|
||||
dashboard__explore_context: Record<string, DashboardContextForExplore>;
|
||||
dashboard__editor_show_only_my_charts: boolean;
|
||||
common__resizable_sidebar_widths: Record<string, number>;
|
||||
chat__state: { open: boolean; mode: string };
|
||||
};
|
||||
|
||||
/*
|
||||
|
||||
@@ -25,8 +25,8 @@ import {
|
||||
useLocation,
|
||||
} from 'react-router-dom';
|
||||
import { bindActionCreators } from 'redux';
|
||||
import { css } from '@apache-superset/core/theme';
|
||||
import { Layout, Loading } from '@superset-ui/core/components';
|
||||
import { css, useTheme } from '@apache-superset/core/theme';
|
||||
import { Flex, Layout, Loading } from '@superset-ui/core/components';
|
||||
import { setupAGGridModules } from '@superset-ui/core/components/ThemedAgGridReact';
|
||||
import { ErrorBoundary } from 'src/components';
|
||||
import Menu from 'src/features/home/Menu';
|
||||
@@ -39,7 +39,12 @@ import { Logger, LOG_ACTIONS_SPA_NAVIGATION } from 'src/logger/LogUtils';
|
||||
import setupCodeOverrides from 'src/setup/setupCodeOverrides';
|
||||
import { logEvent } from 'src/logger/actions';
|
||||
import { store } from 'src/views/store';
|
||||
import { FeatureFlag, isFeatureEnabled } from '@superset-ui/core';
|
||||
import { isUser } from 'src/types/bootstrapTypes';
|
||||
import ExtensionsStartup from 'src/extensions/ExtensionsStartup';
|
||||
import { Splitter } from 'src/components/Splitter';
|
||||
import { ChatFloatingHost, ChatPanelHost, useChat } from 'src/core/chat';
|
||||
import useStoredSidebarWidth from 'src/components/ResizableSidebar/useStoredSidebarWidth';
|
||||
import { RootContextProviders } from './RootContextProviders';
|
||||
import { ScrollToTop } from './ScrollToTop';
|
||||
|
||||
@@ -79,42 +84,139 @@ const LocationPathnameLogger = () => {
|
||||
return <></>;
|
||||
};
|
||||
|
||||
const CHAT_PANEL_DEFAULT_WIDTH = 400;
|
||||
const CHAT_PANEL_MIN_WIDTH = 280;
|
||||
|
||||
const RouteSwitch = () => {
|
||||
const theme = useTheme();
|
||||
return (
|
||||
<Switch>
|
||||
{routes.map(({ path, Component, props = {}, Fallback = Loading }) => (
|
||||
<Route path={path} key={path}>
|
||||
<Suspense fallback={<Fallback />}>
|
||||
<ErrorBoundary
|
||||
css={css`
|
||||
margin: ${theme.sizeUnit * 4}px;
|
||||
`}
|
||||
>
|
||||
<Component user={bootstrapData.user} {...props} />
|
||||
</ErrorBoundary>
|
||||
</Suspense>
|
||||
</Route>
|
||||
))}
|
||||
<Redirect from="/" to="/superset/welcome/" exact />
|
||||
</Switch>
|
||||
);
|
||||
};
|
||||
|
||||
const layoutCss = css`
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
`;
|
||||
|
||||
const contentCss = css`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 0;
|
||||
overflow-y: auto;
|
||||
position: relative;
|
||||
`;
|
||||
|
||||
/**
|
||||
* Renders the main content area. When the chat panel is open in panel mode,
|
||||
* wraps <Layout> and <ChatPanelContent> in a Splitter so they sit side-by-side
|
||||
* with a lazy drag bar (blue preview line, resize committed on mouseup).
|
||||
* The full <Layout> tree lives inside the first panel so its internal flex
|
||||
* context is preserved — SQL Lab, Explore, and other pages are unaffected.
|
||||
*/
|
||||
const AppContent = () => {
|
||||
const isAuthenticated =
|
||||
isUser(bootstrapData.user) && !bootstrapData.user.isAnonymous;
|
||||
const chatExtensionsEnabled =
|
||||
isFeatureEnabled(FeatureFlag.EnableExtensions) && isAuthenticated;
|
||||
const { open: panelOpen, mode, chat } = useChat();
|
||||
const hasChatExtension = chatExtensionsEnabled && !!chat;
|
||||
const isPanelOpen = hasChatExtension && mode === 'panel' && panelOpen;
|
||||
|
||||
const [storedWidth, setStoredWidth] = useStoredSidebarWidth(
|
||||
'chat:panel',
|
||||
CHAT_PANEL_DEFAULT_WIDTH,
|
||||
);
|
||||
|
||||
const layoutContent = (
|
||||
<Layout css={layoutCss}>
|
||||
<Layout.Content css={contentCss}>
|
||||
<RouteSwitch />
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
);
|
||||
|
||||
if (!isPanelOpen) {
|
||||
return (
|
||||
<>
|
||||
{layoutContent}
|
||||
{hasChatExtension && <ChatFloatingHost />}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Splitter
|
||||
lazy
|
||||
onResizeEnd={sizes => {
|
||||
const chatWidth = sizes[sizes.length - 1];
|
||||
if (
|
||||
typeof chatWidth === 'number' &&
|
||||
chatWidth >= CHAT_PANEL_MIN_WIDTH
|
||||
) {
|
||||
setStoredWidth(chatWidth);
|
||||
}
|
||||
}}
|
||||
css={css`
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
|
||||
/*
|
||||
* Splitter.Panel is not a flex container by default, so flex:1 on
|
||||
* children (Layout, ChatPanelHost) has no height effect and
|
||||
* panels collapse. Making them flex columns lets children fill them.
|
||||
*/
|
||||
& > .ant-splitter-panel {
|
||||
display: flex !important;
|
||||
flex-direction: column;
|
||||
}
|
||||
`}
|
||||
>
|
||||
<Splitter.Panel>{layoutContent}</Splitter.Panel>
|
||||
<Splitter.Panel size={storedWidth} min={CHAT_PANEL_MIN_WIDTH}>
|
||||
<ChatPanelHost />
|
||||
</Splitter.Panel>
|
||||
</Splitter>
|
||||
);
|
||||
};
|
||||
|
||||
const App = () => (
|
||||
<Router basename={applicationRoot()}>
|
||||
<ScrollToTop />
|
||||
<LocationPathnameLogger />
|
||||
<RootContextProviders>
|
||||
<Menu
|
||||
data={bootstrapData.common.menu_data}
|
||||
isFrontendRoute={isFrontendRoute}
|
||||
/>
|
||||
<ExtensionsStartup>
|
||||
<Switch>
|
||||
{routes.map(({ path, Component, props = {}, Fallback = Loading }) => (
|
||||
<Route path={path} key={path}>
|
||||
<Suspense fallback={<Fallback />}>
|
||||
<Layout>
|
||||
<Layout.Content
|
||||
css={css`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
`}
|
||||
>
|
||||
<ErrorBoundary
|
||||
css={css`
|
||||
margin: 16px;
|
||||
`}
|
||||
>
|
||||
<Component user={bootstrapData.user} {...props} />
|
||||
</ErrorBoundary>
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
</Suspense>
|
||||
</Route>
|
||||
))}
|
||||
<Redirect from="/" to="/superset/welcome/" exact />
|
||||
</Switch>
|
||||
</ExtensionsStartup>
|
||||
<Flex
|
||||
vertical
|
||||
css={css`
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
`}
|
||||
>
|
||||
<Menu
|
||||
data={bootstrapData.common.menu_data}
|
||||
isFrontendRoute={isFrontendRoute}
|
||||
/>
|
||||
<ExtensionsStartup>
|
||||
<AppContent />
|
||||
</ExtensionsStartup>
|
||||
</Flex>
|
||||
<ToastContainer />
|
||||
</RootContextProviders>
|
||||
</Router>
|
||||
|
||||
60
superset-frontend/src/views/routePaths.ts
Normal file
60
superset-frontend/src/views/routePaths.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
export const RoutePaths = {
|
||||
REDIRECT: '/redirect/',
|
||||
LOGIN: '/login/',
|
||||
REGISTER_ACTIVATION: '/register/activation/:activationHash',
|
||||
REGISTER: '/register/',
|
||||
LOGOUT: '/logout/',
|
||||
HOME: '/superset/welcome/',
|
||||
FILE_HANDLER: '/superset/file-handler',
|
||||
DASHBOARD: '/superset/dashboard/:idOrSlug/',
|
||||
DASHBOARD_LIST: '/dashboard/list/',
|
||||
CHART_ADD: '/chart/add',
|
||||
CHART_LIST: '/chart/list/',
|
||||
DATASET_LIST: '/tablemodelview/list/',
|
||||
DATABASE_LIST: '/databaseview/list/',
|
||||
SAVED_QUERIES: '/savedqueryview/list/',
|
||||
CSS_TEMPLATES: '/csstemplatemodelview/list/',
|
||||
THEMES: '/theme/list/',
|
||||
ANNOTATION_LAYERS: '/annotationlayer/list/',
|
||||
ANNOTATION_LIST: '/annotationlayer/:annotationLayerId/annotation/',
|
||||
QUERY_HISTORY: '/sqllab/history/',
|
||||
ALERTS: '/alert/list/',
|
||||
REPORTS: '/report/list/',
|
||||
ALERT_LOG: '/alert/:alertId/log/',
|
||||
REPORT_LOG: '/report/:alertId/log/',
|
||||
EXPLORE: '/explore/',
|
||||
EXPLORE_PERMALINK: '/superset/explore/p',
|
||||
DATASET_ADD: '/dataset/add/',
|
||||
DATASET: '/dataset/:datasetId',
|
||||
ROW_LEVEL_SECURITY: '/rowlevelsecurity/list',
|
||||
TASKS: '/tasks/list/',
|
||||
SQLLAB: '/sqllab/',
|
||||
USER_INFO: '/user_info/',
|
||||
ACTION_LOG: '/actionlog/list',
|
||||
REGISTRATIONS: '/registrations/',
|
||||
ALL_ENTITIES: '/superset/all_entities/',
|
||||
TAGS: '/superset/tags/',
|
||||
ROLES: '/roles/',
|
||||
USERS: '/users/',
|
||||
GROUPS: '/list_groups/',
|
||||
EXTENSIONS: '/extensions/list/',
|
||||
} as const;
|
||||
@@ -26,6 +26,7 @@ import {
|
||||
} from 'react';
|
||||
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
|
||||
import getBootstrapData from 'src/utils/getBootstrapData';
|
||||
import { RoutePaths } from './routePaths';
|
||||
|
||||
// not lazy loaded since this is the home page.
|
||||
import Home from 'src/pages/Home';
|
||||
@@ -189,158 +190,58 @@ const RedirectWarning = lazy(
|
||||
|
||||
type Routes = {
|
||||
path: string;
|
||||
Component: ComponentType;
|
||||
Fallback?: ComponentType;
|
||||
Component: ComponentType<any>;
|
||||
Fallback?: ComponentType<any>;
|
||||
props?: ComponentProps<any>;
|
||||
}[];
|
||||
|
||||
export const routes: Routes = [
|
||||
{ path: RoutePaths.REDIRECT, Component: RedirectWarning },
|
||||
{ path: RoutePaths.LOGIN, Component: Login },
|
||||
{ path: RoutePaths.REGISTER_ACTIVATION, Component: Register },
|
||||
{ path: RoutePaths.REGISTER, Component: Register },
|
||||
{ path: RoutePaths.LOGOUT, Component: Login },
|
||||
{ path: RoutePaths.HOME, Component: Home },
|
||||
{ path: RoutePaths.FILE_HANDLER, Component: FileHandler },
|
||||
{ path: RoutePaths.DASHBOARD_LIST, Component: DashboardList },
|
||||
{ path: RoutePaths.DASHBOARD, Component: Dashboard },
|
||||
{ path: RoutePaths.CHART_ADD, Component: ChartCreation },
|
||||
{ path: RoutePaths.CHART_LIST, Component: ChartList },
|
||||
{ path: RoutePaths.DATASET_LIST, Component: DatasetList },
|
||||
{ path: RoutePaths.DATABASE_LIST, Component: DatabaseList },
|
||||
{ path: RoutePaths.SAVED_QUERIES, Component: SavedQueryList },
|
||||
{ path: RoutePaths.CSS_TEMPLATES, Component: CssTemplateList },
|
||||
{ path: RoutePaths.THEMES, Component: ThemeList },
|
||||
{ path: RoutePaths.ANNOTATION_LAYERS, Component: AnnotationLayerList },
|
||||
{ path: RoutePaths.ANNOTATION_LIST, Component: AnnotationList },
|
||||
{ path: RoutePaths.QUERY_HISTORY, Component: QueryHistoryList },
|
||||
{ path: RoutePaths.ALERTS, Component: AlertReportList },
|
||||
{
|
||||
path: '/redirect/',
|
||||
Component: RedirectWarning,
|
||||
},
|
||||
{
|
||||
path: '/login/',
|
||||
Component: Login,
|
||||
},
|
||||
{
|
||||
path: '/register/activation/:activationHash',
|
||||
Component: Register,
|
||||
},
|
||||
{
|
||||
path: '/register/',
|
||||
Component: Register,
|
||||
},
|
||||
{
|
||||
path: '/logout/',
|
||||
Component: Login,
|
||||
},
|
||||
{
|
||||
path: '/superset/welcome/',
|
||||
Component: Home,
|
||||
},
|
||||
{
|
||||
path: '/superset/file-handler',
|
||||
Component: FileHandler,
|
||||
},
|
||||
{
|
||||
path: '/dashboard/list/',
|
||||
Component: DashboardList,
|
||||
},
|
||||
{
|
||||
path: '/superset/dashboard/:idOrSlug/',
|
||||
Component: Dashboard,
|
||||
},
|
||||
{
|
||||
path: '/chart/add',
|
||||
Component: ChartCreation,
|
||||
},
|
||||
{
|
||||
path: '/chart/list/',
|
||||
Component: ChartList,
|
||||
},
|
||||
{
|
||||
path: '/tablemodelview/list/',
|
||||
Component: DatasetList,
|
||||
},
|
||||
{
|
||||
path: '/databaseview/list/',
|
||||
Component: DatabaseList,
|
||||
},
|
||||
{
|
||||
path: '/savedqueryview/list/',
|
||||
Component: SavedQueryList,
|
||||
},
|
||||
{
|
||||
path: '/csstemplatemodelview/list/',
|
||||
Component: CssTemplateList,
|
||||
},
|
||||
{
|
||||
path: '/theme/list/',
|
||||
Component: ThemeList,
|
||||
},
|
||||
{
|
||||
path: '/annotationlayer/list/',
|
||||
Component: AnnotationLayerList,
|
||||
},
|
||||
{
|
||||
path: '/annotationlayer/:annotationLayerId/annotation/',
|
||||
Component: AnnotationList,
|
||||
},
|
||||
{
|
||||
path: '/sqllab/history/',
|
||||
Component: QueryHistoryList,
|
||||
},
|
||||
{
|
||||
path: '/alert/list/',
|
||||
path: RoutePaths.REPORTS,
|
||||
Component: AlertReportList,
|
||||
props: { isReportEnabled: true },
|
||||
},
|
||||
{ path: RoutePaths.ALERT_LOG, Component: ExecutionLogList },
|
||||
{
|
||||
path: '/report/list/',
|
||||
Component: AlertReportList,
|
||||
props: {
|
||||
isReportEnabled: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
path: '/alert/:alertId/log/',
|
||||
path: RoutePaths.REPORT_LOG,
|
||||
Component: ExecutionLogList,
|
||||
props: { isReportEnabled: true },
|
||||
},
|
||||
{
|
||||
path: '/report/:alertId/log/',
|
||||
Component: ExecutionLogList,
|
||||
props: {
|
||||
isReportEnabled: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
path: '/explore/',
|
||||
Component: Chart,
|
||||
},
|
||||
{
|
||||
path: '/superset/explore/p',
|
||||
Component: Chart,
|
||||
},
|
||||
{
|
||||
path: '/dataset/add/',
|
||||
Component: DatasetCreation,
|
||||
},
|
||||
{
|
||||
path: '/dataset/:datasetId',
|
||||
Component: DatasetCreation,
|
||||
},
|
||||
{
|
||||
path: '/rowlevelsecurity/list',
|
||||
Component: RowLevelSecurityList,
|
||||
},
|
||||
{
|
||||
path: '/tasks/list/',
|
||||
Component: TaskList,
|
||||
},
|
||||
{
|
||||
path: '/sqllab/',
|
||||
Component: SqlLab,
|
||||
},
|
||||
{ path: '/user_info/', Component: UserInfo },
|
||||
{
|
||||
path: '/actionlog/list',
|
||||
Component: ActionLogList,
|
||||
},
|
||||
{
|
||||
path: '/registrations/',
|
||||
Component: UserRegistrations,
|
||||
},
|
||||
{ path: RoutePaths.EXPLORE, Component: Chart },
|
||||
{ path: RoutePaths.EXPLORE_PERMALINK, Component: Chart },
|
||||
{ path: RoutePaths.DATASET_ADD, Component: DatasetCreation },
|
||||
{ path: RoutePaths.DATASET, Component: DatasetCreation },
|
||||
{ path: RoutePaths.ROW_LEVEL_SECURITY, Component: RowLevelSecurityList },
|
||||
{ path: RoutePaths.TASKS, Component: TaskList },
|
||||
{ path: RoutePaths.SQLLAB, Component: SqlLab },
|
||||
{ path: RoutePaths.USER_INFO, Component: UserInfo },
|
||||
{ path: RoutePaths.ACTION_LOG, Component: ActionLogList },
|
||||
{ path: RoutePaths.REGISTRATIONS, Component: UserRegistrations },
|
||||
];
|
||||
|
||||
if (isFeatureEnabled(FeatureFlag.TaggingSystem)) {
|
||||
routes.push({
|
||||
path: '/superset/all_entities/',
|
||||
Component: AllEntities,
|
||||
});
|
||||
routes.push({
|
||||
path: '/superset/tags/',
|
||||
Component: Tags,
|
||||
});
|
||||
routes.push({ path: RoutePaths.ALL_ENTITIES, Component: AllEntities });
|
||||
routes.push({ path: RoutePaths.TAGS, Component: Tags });
|
||||
}
|
||||
|
||||
const user = getBootstrapData()?.user;
|
||||
@@ -350,33 +251,18 @@ const isAdmin = isUserAdmin(user);
|
||||
|
||||
if (isAdmin) {
|
||||
routes.push(
|
||||
{
|
||||
path: '/roles/',
|
||||
Component: RolesList,
|
||||
},
|
||||
{
|
||||
path: '/users/',
|
||||
Component: UsersList,
|
||||
},
|
||||
{
|
||||
path: '/list_groups/',
|
||||
Component: GroupsList,
|
||||
},
|
||||
{ path: RoutePaths.ROLES, Component: RolesList },
|
||||
{ path: RoutePaths.USERS, Component: UsersList },
|
||||
{ path: RoutePaths.GROUPS, Component: GroupsList },
|
||||
);
|
||||
|
||||
if (isFeatureEnabled(FeatureFlag.EnableExtensions)) {
|
||||
routes.push({
|
||||
path: '/extensions/list/',
|
||||
Component: Extensions,
|
||||
});
|
||||
routes.push({ path: RoutePaths.EXTENSIONS, Component: Extensions });
|
||||
}
|
||||
}
|
||||
|
||||
if (authRegistrationEnabled) {
|
||||
routes.push({
|
||||
path: '/registrations/',
|
||||
Component: UserRegistrations,
|
||||
});
|
||||
routes.push({ path: RoutePaths.REGISTRATIONS, Component: UserRegistrations });
|
||||
}
|
||||
|
||||
const frontEndRoutes: Record<string, boolean> = routes
|
||||
|
||||
@@ -1455,6 +1455,18 @@ class ChartDataQueryObjectSchema(Schema):
|
||||
fields.String(),
|
||||
allow_none=True,
|
||||
)
|
||||
time_compare_full_range = fields.Boolean(
|
||||
required=False,
|
||||
allow_none=True,
|
||||
metadata={
|
||||
"description": (
|
||||
"When using a time comparison (time_offsets), plot each shifted "
|
||||
"series across its full time range instead of truncating it to the "
|
||||
"main series' range. Useful for comparing a partial current period "
|
||||
"against complete prior periods."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@post_load
|
||||
def rename_deprecated_fields(
|
||||
|
||||
@@ -105,6 +105,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||
series_limit: int
|
||||
series_limit_metric: Metric | None
|
||||
time_offsets: list[str]
|
||||
time_compare_full_range: bool
|
||||
time_shift: str | None
|
||||
time_range: str | None
|
||||
to_dttm: datetime | None
|
||||
@@ -162,6 +163,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||
self.to_dttm = kwargs.get("to_dttm")
|
||||
self.result_type = kwargs.get("result_type")
|
||||
self.time_offsets = kwargs.get("time_offsets", [])
|
||||
self.time_compare_full_range = kwargs.get("time_compare_full_range", False)
|
||||
self.inner_from_dttm = kwargs.get("inner_from_dttm")
|
||||
self.inner_to_dttm = kwargs.get("inner_to_dttm")
|
||||
self._rename_deprecated_fields(kwargs)
|
||||
@@ -410,6 +412,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||
"group_others_when_limit_reached": self.group_others_when_limit_reached,
|
||||
"to_dttm": self.to_dttm,
|
||||
"time_shift": self.time_shift,
|
||||
"time_compare_full_range": self.time_compare_full_range,
|
||||
}
|
||||
return query_object_dict
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any, Literal, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -32,9 +32,14 @@ def left_join_df(
|
||||
join_keys: list[str],
|
||||
lsuffix: str = "",
|
||||
rsuffix: str = "",
|
||||
how: Literal["left", "right", "inner", "outer", "cross"] = "left",
|
||||
) -> pd.DataFrame:
|
||||
# `how` defaults to "left" so callers that only want the left frame's rows are
|
||||
# unaffected. Passing how="outer" keeps right-only rows, which is used by the
|
||||
# time-comparison "full range" option so historical series are not truncated to
|
||||
# the main series' time range.
|
||||
df = left_df.set_index(join_keys).join(
|
||||
right_df.set_index(join_keys), lsuffix=lsuffix, rsuffix=rsuffix
|
||||
right_df.set_index(join_keys), how=how, lsuffix=lsuffix, rsuffix=rsuffix
|
||||
)
|
||||
df.reset_index(inplace=True)
|
||||
return df
|
||||
|
||||
@@ -434,7 +434,13 @@ LOGO_RIGHT_TEXT: Callable[[], str] | str = ""
|
||||
|
||||
# Enables SWAGGER UI for superset openapi spec
|
||||
# ex: http://localhost:8080/swagger/v1
|
||||
FAB_API_SWAGGER_UI = True
|
||||
# Disabled by default so the interactive API documentation surface is opt-in.
|
||||
# Enable it by setting the SUPERSET_ENABLE_SWAGGER_UI environment variable
|
||||
# (e.g. for local development) or by overriding FAB_API_SWAGGER_UI in
|
||||
# superset_config.py.
|
||||
FAB_API_SWAGGER_UI = utils.cast_to_boolean(
|
||||
os.environ.get("SUPERSET_ENABLE_SWAGGER_UI", False)
|
||||
)
|
||||
|
||||
# ----------------------------------------------------
|
||||
# AUTHENTICATION CONFIG
|
||||
|
||||
@@ -38,18 +38,12 @@ from superset.db_engine_specs.base import (
|
||||
)
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.utils import json
|
||||
from superset.utils.core import get_user_agent, QuerySource
|
||||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
OAuth2State,
|
||||
OAuth2TokenResponse,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
@@ -283,105 +277,6 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
|
||||
"port": "port",
|
||||
}
|
||||
|
||||
# OAuth2 endpoints for different cloud providers
|
||||
_oauth2_endpoints = {
|
||||
"aws": {
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/{}/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/{}/v1/token",
|
||||
},
|
||||
"azure": {
|
||||
"authorization_request_uri": "https://login.microsoftonline.com/{}/oauth2/v2.0/authorize",
|
||||
"token_request_uri": "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
||||
},
|
||||
"gcp": {
|
||||
"authorization_request_uri": "https://accounts.gcp.databricks.com/oidc/accounts/{}/v1/authorize",
|
||||
"token_request_uri": "https://accounts.gcp.databricks.com/oidc/accounts/{}/v1/token",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _detect_cloud_provider(cls, database: Database) -> str:
|
||||
"""
|
||||
Detect the cloud provider based on the database configuration.
|
||||
|
||||
Returns:
|
||||
str: The cloud provider ('aws', 'azure', or 'gcp')
|
||||
"""
|
||||
# Check if cloud provider is explicitly configured in extra
|
||||
if isinstance(
|
||||
(cloud_provider := cls.get_extra_params(database).get("cloud_provider")),
|
||||
str,
|
||||
):
|
||||
provider = cloud_provider.lower()
|
||||
if provider in cls._oauth2_endpoints:
|
||||
return provider
|
||||
|
||||
# Try to detect from hostname
|
||||
hostname = database.url_object.host or ""
|
||||
hostname = hostname.lower()
|
||||
|
||||
if "azure" in hostname or "azuredatabricks" in hostname:
|
||||
return "azure"
|
||||
elif "gcp" in hostname or "googleusercontent" in hostname:
|
||||
return "gcp"
|
||||
else:
|
||||
# Default to AWS for compatibility
|
||||
return "aws"
|
||||
|
||||
@classmethod
|
||||
def _resolve_oauth2_endpoint(
|
||||
cls,
|
||||
database: Database,
|
||||
provider: str,
|
||||
endpoint_key: str,
|
||||
) -> str:
|
||||
"""
|
||||
Build a fully-resolved OAuth2 endpoint for the detected cloud provider.
|
||||
|
||||
The per-provider templates carry a single ``{}`` placeholder for the
|
||||
Databricks account id (or Azure tenant id), read from the database's
|
||||
``extra`` (``account_id``, or ``tenant_id`` for Azure). Raising when it
|
||||
is absent keeps the flow from issuing a request to an unresolved
|
||||
``.../{}/...`` endpoint.
|
||||
"""
|
||||
template = cls._oauth2_endpoints[provider][endpoint_key]
|
||||
if "{}" not in template:
|
||||
return template
|
||||
|
||||
extra = cls.get_extra_params(database)
|
||||
account_id = extra.get("account_id") or extra.get("tenant_id")
|
||||
if not account_id:
|
||||
raise OAuth2Error(
|
||||
"Databricks OAuth2 endpoints could not be resolved: set "
|
||||
"`account_id` (or `tenant_id` for Azure) in the database's "
|
||||
"engine parameters, or provide a fully-resolved "
|
||||
f"`{endpoint_key}` in DATABASE_OAUTH2_CLIENTS."
|
||||
)
|
||||
return template.format(account_id)
|
||||
|
||||
@classmethod
|
||||
def impersonate_user(
|
||||
cls,
|
||||
database: Database,
|
||||
username: str | None,
|
||||
user_token: str | None,
|
||||
url: URL,
|
||||
engine_kwargs: dict[str, Any],
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
"""
|
||||
Update connection with OAuth2 access token for user impersonation.
|
||||
"""
|
||||
if user_token:
|
||||
# Replace the access token in the URL with the user's OAuth2 token
|
||||
url = url.set(password=user_token)
|
||||
|
||||
# Also update connect_args if they contain access token
|
||||
connect_args = engine_kwargs.setdefault("connect_args", {})
|
||||
if "access_token" in connect_args:
|
||||
connect_args["access_token"] = user_token
|
||||
|
||||
return url, engine_kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(
|
||||
database: Database, source: QuerySource | None = None
|
||||
@@ -579,74 +474,6 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
supports_dynamic_catalog = True
|
||||
supports_cross_catalog_queries = True
|
||||
|
||||
# OAuth 2.0 support
|
||||
supports_oauth2 = True
|
||||
oauth2_exception = OAuth2RedirectError
|
||||
oauth2_scope = "sql"
|
||||
|
||||
# OAuth2 endpoints are determined dynamically based on cloud provider
|
||||
oauth2_authorization_request_uri = "" # Set dynamically
|
||||
oauth2_token_request_uri = "" # Set dynamically
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_authorization_uri(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
state: "OAuth2State",
|
||||
code_verifier: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Return URI for initial OAuth2 request with dynamic endpoint detection.
|
||||
|
||||
A fully-resolved `authorization_request_uri` from `DATABASE_OAUTH2_CLIENTS`
|
||||
is preserved; only fall back to the auto-detected, account-resolved
|
||||
endpoint when none is configured.
|
||||
"""
|
||||
if not config.get("authorization_request_uri"):
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
|
||||
# Get the database to detect cloud provider
|
||||
database_id = state["database_id"]
|
||||
if database := db.session.get(Database, database_id):
|
||||
provider = cls._detect_cloud_provider(database)
|
||||
from typing import cast
|
||||
|
||||
config = cast(
|
||||
"OAuth2ClientConfig",
|
||||
dict(config)
|
||||
| {
|
||||
"authorization_request_uri": cls._resolve_oauth2_endpoint(
|
||||
database, provider, "authorization_request_uri"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
return super().get_oauth2_authorization_uri(config, state, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_token(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
code: str,
|
||||
code_verifier: str | None = None,
|
||||
) -> "OAuth2TokenResponse":
|
||||
"""
|
||||
Exchange authorization code for refresh/access tokens.
|
||||
|
||||
The token request URI is resolved when the OAuth2 config is built (see
|
||||
`get_oauth2_config`) and already targets the correct cloud provider and
|
||||
account. There is no database context here to auto-detect it, so fail
|
||||
fast rather than POST to an unresolved endpoint when it is missing.
|
||||
"""
|
||||
if not config.get("token_request_uri"):
|
||||
raise OAuth2Error(
|
||||
"Databricks OAuth2 token endpoint is not configured: provide a "
|
||||
"fully-resolved `token_request_uri` in DATABASE_OAUTH2_CLIENTS."
|
||||
)
|
||||
|
||||
return super().get_oauth2_token(config, code, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_uri( # type: ignore
|
||||
cls, parameters: DatabricksNativeParametersType, *_
|
||||
@@ -858,74 +685,6 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
|
||||
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
|
||||
|
||||
# OAuth 2.0 support
|
||||
supports_oauth2 = True
|
||||
oauth2_exception = OAuth2RedirectError
|
||||
oauth2_scope = "sql"
|
||||
|
||||
# OAuth2 endpoints are determined dynamically based on cloud provider
|
||||
oauth2_authorization_request_uri = "" # Set dynamically
|
||||
oauth2_token_request_uri = "" # Set dynamically
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_authorization_uri(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
state: "OAuth2State",
|
||||
code_verifier: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Return URI for initial OAuth2 request with dynamic endpoint detection.
|
||||
|
||||
A fully-resolved `authorization_request_uri` from `DATABASE_OAUTH2_CLIENTS`
|
||||
is preserved; only fall back to the auto-detected, account-resolved
|
||||
endpoint when none is configured.
|
||||
"""
|
||||
if not config.get("authorization_request_uri"):
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
|
||||
# Get the database to detect cloud provider
|
||||
database_id = state["database_id"]
|
||||
if database := db.session.get(Database, database_id):
|
||||
provider = cls._detect_cloud_provider(database)
|
||||
from typing import cast
|
||||
|
||||
config = cast(
|
||||
"OAuth2ClientConfig",
|
||||
dict(config)
|
||||
| {
|
||||
"authorization_request_uri": cls._resolve_oauth2_endpoint(
|
||||
database, provider, "authorization_request_uri"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
return super().get_oauth2_authorization_uri(config, state, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_token(
|
||||
cls,
|
||||
config: "OAuth2ClientConfig",
|
||||
code: str,
|
||||
code_verifier: str | None = None,
|
||||
) -> "OAuth2TokenResponse":
|
||||
"""
|
||||
Exchange authorization code for refresh/access tokens.
|
||||
|
||||
The token request URI is resolved when the OAuth2 config is built (see
|
||||
`get_oauth2_config`) and already targets the correct cloud provider and
|
||||
account. There is no database context here to auto-detect it, so fail
|
||||
fast rather than POST to an unresolved endpoint when it is missing.
|
||||
"""
|
||||
if not config.get("token_request_uri"):
|
||||
raise OAuth2Error(
|
||||
"Databricks OAuth2 token endpoint is not configured: provide a "
|
||||
"fully-resolved `token_request_uri` in DATABASE_OAUTH2_CLIENTS."
|
||||
)
|
||||
|
||||
return super().get_oauth2_token(config, code, code_verifier)
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_uri( # type: ignore
|
||||
cls, parameters: DatabricksPythonConnectorParametersType, *_
|
||||
|
||||
@@ -238,6 +238,7 @@ def build_extension_data(extension: LoadedExtension) -> dict[str, Any]:
|
||||
manifest = extension.manifest
|
||||
extension_data: dict[str, Any] = {
|
||||
"id": manifest.id,
|
||||
"publisher": manifest.publisher,
|
||||
"name": extension.name,
|
||||
"version": extension.version,
|
||||
"description": manifest.description or "",
|
||||
|
||||
@@ -129,6 +129,7 @@ Dashboard Management:
|
||||
- get_dashboard_info: Get detailed dashboard information by ID
|
||||
- get_dashboard_layout: Get parsed tabs and chart positions for a dashboard (companion to get_dashboard_info when its omitted_fields hint flags position_json)
|
||||
- generate_dashboard: Create a dashboard from chart IDs (requires write access)
|
||||
- update_dashboard: Update an existing dashboard's title/description/slug/published/layout/theme/CSS (requires write access; ownership-checked per-instance)
|
||||
- add_chart_to_existing_dashboard: Add a chart to an existing dashboard (requires write access)
|
||||
|
||||
Annotation Layers:
|
||||
@@ -694,6 +695,7 @@ from superset.mcp_service.dashboard.tool import ( # noqa: F401, E402
|
||||
get_dashboard_info,
|
||||
get_dashboard_layout,
|
||||
list_dashboards,
|
||||
update_dashboard,
|
||||
)
|
||||
from superset.mcp_service.database.tool import ( # noqa: F401, E402
|
||||
get_database_info,
|
||||
|
||||
@@ -103,11 +103,15 @@ from superset.mcp_service.utils import (
|
||||
escape_llm_context_delimiters,
|
||||
sanitize_for_llm_context,
|
||||
)
|
||||
from superset.mcp_service.utils.response_utils import humanize_timestamp
|
||||
from superset.mcp_service.utils.response_utils import (
|
||||
humanize_timestamp,
|
||||
OmittedFieldsBuilder,
|
||||
)
|
||||
from superset.mcp_service.utils.sanitization import (
|
||||
sanitize_user_input,
|
||||
sanitize_user_input_with_changes,
|
||||
)
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.utils.json import loads as json_loads
|
||||
|
||||
|
||||
@@ -129,8 +133,6 @@ class DashboardError(BaseModel):
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "DashboardError":
|
||||
"""Create a standardized DashboardError with timestamp."""
|
||||
from datetime import datetime
|
||||
|
||||
return cls(error=error, error_type=error_type, timestamp=datetime.now())
|
||||
|
||||
|
||||
@@ -618,6 +620,48 @@ class GenerateDashboardRequest(BaseModel):
|
||||
published: bool = Field(
|
||||
default=False, description="Whether to publish the dashboard"
|
||||
)
|
||||
slug: str | None = Field(
|
||||
None,
|
||||
max_length=255,
|
||||
description=(
|
||||
"Optional URL slug for the dashboard. When set, the dashboard "
|
||||
"is reachable at /superset/dashboard/<slug>/ instead of "
|
||||
"/superset/dashboard/<id>/. Must be unique across the instance."
|
||||
),
|
||||
)
|
||||
position_json: Dict[str, Any] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional explicit dashboard layout (Superset's position_json "
|
||||
"dict). When set, replaces the auto-generated layout entirely. "
|
||||
"Pass this when you need custom row composition, MARKDOWN "
|
||||
"blocks, HEADER components, or specific chart widths/heights. "
|
||||
"Omit to let the tool auto-generate a packed grid from chart_ids."
|
||||
),
|
||||
)
|
||||
json_metadata_overrides: Dict[str, Any] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional overrides applied on top of the default "
|
||||
"json_metadata. Common fields: label_colors (per-series brand "
|
||||
"palette, e.g. {'Electronics': '#4C78A8'}), color_scheme "
|
||||
"(named Superset palette), cross_filters_enabled (bool, "
|
||||
"default False — set True for interactive dashboards), "
|
||||
"shared_label_colors (list of label names for cross-chart "
|
||||
"color consistency). Merged shallowly into the defaults; pass "
|
||||
"only the keys you want to override."
|
||||
),
|
||||
)
|
||||
css: str | None = Field(
|
||||
None,
|
||||
max_length=50000,
|
||||
description=(
|
||||
"Optional dashboard-level CSS. Useful for hiding chart chrome "
|
||||
"(kebab menus, cross-filter chips) on print-ready dashboards, "
|
||||
"or tweaking padding/typography. Applied as-is to the "
|
||||
"dashboard's css field."
|
||||
),
|
||||
)
|
||||
sanitization_warnings: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
@@ -690,6 +734,168 @@ class GenerateDashboardRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class UpdateDashboardRequest(BaseModel):
|
||||
"""Request schema for updating an existing dashboard's layout/theme/style.
|
||||
|
||||
All fields are optional; only the fields explicitly passed are applied.
|
||||
Use to retroactively set a custom layout, brand palette, or CSS on a
|
||||
dashboard that was created via ``generate_dashboard`` (or earlier via
|
||||
the REST API) without a full re-create.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
identifier: int | str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Dashboard ID (integer), UUID, or slug. Same identifier shape "
|
||||
"accepted by ``get_dashboard_info``."
|
||||
),
|
||||
)
|
||||
dashboard_title: str | None = Field(
|
||||
None,
|
||||
max_length=500,
|
||||
description="Optional new dashboard title.",
|
||||
validation_alias=AliasChoices("dashboard_title", "title", "name"),
|
||||
)
|
||||
description: str | None = Field(
|
||||
None,
|
||||
description="Optional new dashboard description.",
|
||||
)
|
||||
slug: str | None = Field(
|
||||
None,
|
||||
max_length=255,
|
||||
description=("Optional new URL slug. Pass empty string to clear a slug."),
|
||||
)
|
||||
published: bool | None = Field(
|
||||
None,
|
||||
description="Optional published flag.",
|
||||
)
|
||||
position_json: Dict[str, Any] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional replacement layout (Superset's position_json dict). "
|
||||
"When set, fully replaces the existing layout. Get the current "
|
||||
"layout via ``get_dashboard_info`` first if you want to make "
|
||||
"incremental changes."
|
||||
),
|
||||
)
|
||||
json_metadata_overrides: Dict[str, Any] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional overrides applied on top of the existing "
|
||||
"json_metadata. Merged shallowly — pass only the keys you "
|
||||
"want to change (e.g. label_colors, color_scheme, "
|
||||
"cross_filters_enabled)."
|
||||
),
|
||||
)
|
||||
css: str | None = Field(
|
||||
None,
|
||||
max_length=50000,
|
||||
description=(
|
||||
"Optional new dashboard CSS. Pass empty string to clear existing CSS."
|
||||
),
|
||||
)
|
||||
sanitization_warnings: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Internal: warnings emitted when user input was altered by "
|
||||
"sanitization. Populated by the ``mode='before'`` validator "
|
||||
"before dashboard_title is rewritten."
|
||||
),
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _detect_dashboard_title_sanitization(cls, data: Any) -> Any:
|
||||
"""Reject empty-after-sanitization titles and warn on partial strip.
|
||||
|
||||
Mirrors the same guard ``GenerateDashboardRequest`` applies so a
|
||||
prompt-injected LLM cannot push XSS payloads through the update
|
||||
path that the create path already rejects. Server-only
|
||||
``sanitization_warnings`` is reset here so a caller cannot inject
|
||||
warning text.
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
data["sanitization_warnings"] = []
|
||||
# Must match every AliasChoice on ``dashboard_title`` — otherwise
|
||||
# an XSS payload supplied via a different key (e.g. ``name``)
|
||||
# would bypass this ``mode='before'`` guard and slip through to
|
||||
# Pydantic's alias resolution unsanitized.
|
||||
for key in ("dashboard_title", "title", "name"):
|
||||
if key in data:
|
||||
raw = data[key]
|
||||
break
|
||||
else:
|
||||
raw = None
|
||||
if not isinstance(raw, str) or not raw.strip():
|
||||
return data
|
||||
sanitized, was_modified = sanitize_user_input_with_changes(
|
||||
raw, "Dashboard title", max_length=500, allow_empty=True
|
||||
)
|
||||
if was_modified and not sanitized:
|
||||
raise ValueError(
|
||||
"dashboard_title contained only disallowed content "
|
||||
"(HTML/script/URL schemes) and was removed entirely by "
|
||||
"sanitization. Provide a dashboard_title with plain text."
|
||||
)
|
||||
if was_modified:
|
||||
data["sanitization_warnings"].append(
|
||||
"dashboard_title was modified during sanitization to "
|
||||
"remove potentially unsafe content; the stored title "
|
||||
"differs from the input."
|
||||
)
|
||||
return data
|
||||
|
||||
@field_validator("dashboard_title")
|
||||
@classmethod
|
||||
def sanitize_dashboard_title(cls, v: str | None) -> str | None:
|
||||
"""Sanitize dashboard title to prevent XSS."""
|
||||
if v is None or v == "":
|
||||
return v
|
||||
return sanitize_user_input(
|
||||
v, "Dashboard title", max_length=500, allow_empty=True
|
||||
)
|
||||
|
||||
|
||||
class UpdateDashboardResponse(BaseModel):
|
||||
"""Response schema for ``update_dashboard``.
|
||||
|
||||
Distinct from ``GenerateDashboardResponse`` because the semantics
|
||||
differ: this response reports which fields actually changed on an
|
||||
existing dashboard, rather than describing a newly created one.
|
||||
"""
|
||||
|
||||
dashboard: DashboardInfo | None = Field(
|
||||
None, description="The updated dashboard info, if successful"
|
||||
)
|
||||
dashboard_url: str | None = Field(None, description="URL to view the dashboard")
|
||||
error: str | None = Field(None, description="Error message, if update failed")
|
||||
permission_denied: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"True when the user lacks edit rights on the target "
|
||||
"dashboard. When True, ``error`` carries the human-readable "
|
||||
"explanation and the response is otherwise empty."
|
||||
),
|
||||
)
|
||||
changed_fields: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Names of fields that were actually applied. Empty when the "
|
||||
"request was a no-op or failed before any field was applied."
|
||||
),
|
||||
)
|
||||
warnings: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Non-fatal advisory messages — for example, that the supplied "
|
||||
"title was altered by sanitization."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class GenerateDashboardResponse(BaseModel):
|
||||
"""Response schema for dashboard generation."""
|
||||
|
||||
@@ -972,7 +1178,6 @@ def _build_omitted_fields(
|
||||
Uses the shared OmittedFieldsBuilder utility so the pattern is consistent
|
||||
across all MCP tool serializers.
|
||||
"""
|
||||
from superset.mcp_service.utils.response_utils import OmittedFieldsBuilder
|
||||
|
||||
return (
|
||||
OmittedFieldsBuilder()
|
||||
@@ -1005,7 +1210,6 @@ def serialize_chart_summary(
|
||||
"""Serialize a chart to a lightweight summary for dashboard context."""
|
||||
if not chart:
|
||||
return None
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
|
||||
chart_id = getattr(chart, "id", None)
|
||||
chart_url = None
|
||||
@@ -1112,9 +1316,20 @@ def _sanitize_dashboard_info_for_llm_context(
|
||||
return DashboardInfo.model_validate(payload)
|
||||
|
||||
|
||||
def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
def _safe_user_label(value: Any) -> str | None:
|
||||
"""Coerce a `*_by_name` model attribute to a display string or None.
|
||||
|
||||
The Dashboard model exposes ``created_by_name`` / ``changed_by_name``
|
||||
as plain strings, but some serializer call sites pass through
|
||||
objects (User instances, Mocks in tests) — defensive coercion keeps
|
||||
the response a valid string and avoids leaking ``repr(user)``.
|
||||
"""
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
|
||||
include_data_model_metadata = user_can_view_data_model_metadata()
|
||||
base_url = get_superset_base_url()
|
||||
relative_url = dashboard.url # e.g. "/superset/dashboard/{slug_or_id}/"
|
||||
@@ -1175,7 +1390,6 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
|
||||
|
||||
def serialize_dashboard_object(dashboard: Any) -> DashboardInfo:
|
||||
"""Simple dashboard serializer that safely handles object attributes."""
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
|
||||
# Construct URL from id/slug (the model's @property isn't available on
|
||||
# column-only query tuples returned by DAO.list with select_columns)
|
||||
|
||||
@@ -20,6 +20,7 @@ from .generate_dashboard import generate_dashboard
|
||||
from .get_dashboard_info import get_dashboard_info
|
||||
from .get_dashboard_layout import get_dashboard_layout
|
||||
from .list_dashboards import list_dashboards
|
||||
from .update_dashboard import update_dashboard
|
||||
|
||||
__all__ = [
|
||||
"list_dashboards",
|
||||
@@ -27,4 +28,5 @@ __all__ = [
|
||||
"get_dashboard_layout",
|
||||
"generate_dashboard",
|
||||
"add_chart_to_existing_dashboard",
|
||||
"update_dashboard",
|
||||
]
|
||||
|
||||
@@ -26,9 +26,11 @@ from typing import Any, Dict, List
|
||||
|
||||
from fastmcp import Context
|
||||
from flask import g
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.extensions import db, event_logger
|
||||
from superset.mcp_service.dashboard.constants import (
|
||||
generate_id,
|
||||
GRID_COLUMN_COUNT,
|
||||
@@ -38,8 +40,11 @@ from superset.mcp_service.dashboard.schemas import (
|
||||
DashboardInfo,
|
||||
GenerateDashboardRequest,
|
||||
GenerateDashboardResponse,
|
||||
serialize_chart_summary,
|
||||
serialize_tag_object,
|
||||
)
|
||||
from superset.mcp_service.privacy import user_can_view_data_model_metadata
|
||||
from superset.mcp_service.utils.response_utils import humanize_timestamp
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.utils import json
|
||||
|
||||
@@ -197,22 +202,24 @@ def generate_dashboard( # noqa: C901
|
||||
- To add a chart to an EXISTING dashboard, use add_chart_to_existing_dashboard.
|
||||
Never use this tool as a fallback when add_chart_to_existing_dashboard fails.
|
||||
- All charts must exist and be accessible to current user
|
||||
- Charts arranged automatically in 2-column grid layout
|
||||
- Layout: by default, charts are arranged in an auto-generated 2-column
|
||||
grid. When ``position_json`` is supplied, that explicit layout is
|
||||
written verbatim and the auto-generated grid is skipped — use this to
|
||||
compose custom rows, header bands, or MARKDOWN/HEADER components.
|
||||
|
||||
Returns:
|
||||
- Dashboard ID and URL
|
||||
"""
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
# Advisory messages (e.g. title sanitization) surfaced to the caller
|
||||
# alongside the created dashboard so they can tell when their input
|
||||
# was altered.
|
||||
sanitization_warnings = list(getattr(request, "sanitization_warnings", []) or [])
|
||||
|
||||
try:
|
||||
# Get chart objects from IDs (required for SQLAlchemy relationships)
|
||||
from superset import db
|
||||
# avoids ImportError before Flask app initialisation:
|
||||
# `Exception: App not initialized yet. Please call init_app first`
|
||||
# raised from superset.utils.encrypt when Slice's encrypted Column
|
||||
# types are instantiated at model-class definition time.
|
||||
from superset.models.slice import Slice
|
||||
|
||||
with event_logger.log_context(action="mcp.generate_dashboard.chart_validation"):
|
||||
@@ -253,9 +260,14 @@ def generate_dashboard( # noqa: C901
|
||||
),
|
||||
)
|
||||
|
||||
# Create dashboard layout with chart objects
|
||||
# Create dashboard layout with chart objects.
|
||||
# If the caller provided an explicit position_json, use it verbatim;
|
||||
# otherwise auto-generate a packed-grid layout from the chart ids.
|
||||
with event_logger.log_context(action="mcp.generate_dashboard.layout"):
|
||||
layout = _create_dashboard_layout(chart_objects)
|
||||
if request.position_json:
|
||||
layout = request.position_json
|
||||
else:
|
||||
layout = _create_dashboard_layout(chart_objects)
|
||||
|
||||
# Resolve dashboard title: use provided title or derive from chart names
|
||||
dashboard_title = (
|
||||
@@ -276,27 +288,32 @@ def generate_dashboard( # noqa: C901
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
||||
with event_logger.log_context(action="mcp.generate_dashboard.db_write"):
|
||||
json_metadata = json.dumps(
|
||||
{
|
||||
"filter_scopes": {},
|
||||
"expanded_slices": {},
|
||||
"refresh_frequency": 0,
|
||||
"timed_refresh_immune_slices": [],
|
||||
"color_scheme": None,
|
||||
"label_colors": {},
|
||||
"shared_label_colors": {},
|
||||
"color_scheme_domain": [],
|
||||
"cross_filters_enabled": False,
|
||||
"native_filter_configuration": [],
|
||||
"global_chart_configuration": {
|
||||
"scope": {
|
||||
"rootPath": ["ROOT_ID"],
|
||||
"excluded": [],
|
||||
}
|
||||
},
|
||||
"chart_configuration": {},
|
||||
}
|
||||
)
|
||||
# Build the default json_metadata that every new dashboard gets.
|
||||
# When the caller supplied json_metadata_overrides, merge those
|
||||
# in shallowly so the LLM can override label_colors / color_scheme
|
||||
# / cross_filters_enabled without re-specifying the whole shape.
|
||||
json_metadata_dict: Dict[str, Any] = {
|
||||
"filter_scopes": {},
|
||||
"expanded_slices": {},
|
||||
"refresh_frequency": 0,
|
||||
"timed_refresh_immune_slices": [],
|
||||
"color_scheme": None,
|
||||
"label_colors": {},
|
||||
"shared_label_colors": {},
|
||||
"color_scheme_domain": [],
|
||||
"cross_filters_enabled": False,
|
||||
"native_filter_configuration": [],
|
||||
"global_chart_configuration": {
|
||||
"scope": {
|
||||
"rootPath": ["ROOT_ID"],
|
||||
"excluded": [],
|
||||
}
|
||||
},
|
||||
"chart_configuration": {},
|
||||
}
|
||||
if request.json_metadata_overrides:
|
||||
json_metadata_dict.update(request.json_metadata_overrides)
|
||||
json_metadata = json.dumps(json_metadata_dict)
|
||||
|
||||
try:
|
||||
dashboard = Dashboard()
|
||||
@@ -308,6 +325,12 @@ def generate_dashboard( # noqa: C901
|
||||
if request.description:
|
||||
dashboard.description = request.description
|
||||
|
||||
if request.slug:
|
||||
dashboard.slug = request.slug
|
||||
|
||||
if request.css:
|
||||
dashboard.css = request.css
|
||||
|
||||
# Re-query the current user and charts directly in the
|
||||
# current db.session. g.user was loaded in a Flask
|
||||
# app_context that has since been torn down (the
|
||||
@@ -344,6 +367,38 @@ def generate_dashboard( # noqa: C901
|
||||
dashboard.id,
|
||||
exc_info=True,
|
||||
)
|
||||
except IntegrityError as db_err:
|
||||
try:
|
||||
db.session.rollback() # pylint: disable=consider-using-transaction
|
||||
except SQLAlchemyError:
|
||||
logger.warning(
|
||||
"Database rollback failed during error handling",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.error("Dashboard creation failed: %s", db_err, exc_info=True)
|
||||
# Slug uniqueness is the only IntegrityError a caller
|
||||
# can fix on retry; surface a clear, structured message
|
||||
# so the LLM can propose a different slug. Detection
|
||||
# scans both the Postgres constraint name
|
||||
# (``dashboards_slug_key``) and the SQLite phrasing
|
||||
# (``UNIQUE constraint failed: dashboards.slug``).
|
||||
err_text = str(db_err).lower()
|
||||
if request.slug and "slug" in err_text:
|
||||
return GenerateDashboardResponse(
|
||||
dashboard=None,
|
||||
dashboard_url=None,
|
||||
error=(
|
||||
f"Slug {request.slug!r} is already in use by "
|
||||
"another dashboard. Choose a different slug "
|
||||
"and retry, or omit the slug to get a "
|
||||
"generated URL."
|
||||
),
|
||||
)
|
||||
return GenerateDashboardResponse(
|
||||
dashboard=None,
|
||||
dashboard_url=None,
|
||||
error=("Failed to create dashboard due to a database constraint."),
|
||||
)
|
||||
except SQLAlchemyError as db_err:
|
||||
try:
|
||||
db.session.rollback() # pylint: disable=consider-using-transaction
|
||||
@@ -363,6 +418,17 @@ def generate_dashboard( # noqa: C901
|
||||
error="Failed to create dashboard due to a database error.",
|
||||
)
|
||||
|
||||
# ``dashboard.id`` is fixed at create-commit time; the post-commit
|
||||
# re-fetch below either returns the same row or fails — in either
|
||||
# outcome the URL doesn't change. Bind it once so the three
|
||||
# downstream consumers (partial response, DashboardInfo.url,
|
||||
# response.dashboard_url) share a single source. Prefer the slug
|
||||
# over the id to match ``update_dashboard``'s canonical URL shape.
|
||||
dashboard_url = (
|
||||
f"{get_superset_base_url()}/superset/dashboard/"
|
||||
f"{dashboard.slug or dashboard.id}/"
|
||||
)
|
||||
|
||||
# Re-fetch with eager-loaded relationships for serialization.
|
||||
# The preceding commit may invalidate the session in multi-tenant
|
||||
# environments, causing "Can't reconnect until invalid transaction
|
||||
@@ -396,9 +462,6 @@ def generate_dashboard( # noqa: C901
|
||||
"Database rollback failed during dashboard re-fetch error handling",
|
||||
exc_info=True,
|
||||
)
|
||||
dashboard_url = (
|
||||
f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/"
|
||||
)
|
||||
return GenerateDashboardResponse(
|
||||
dashboard=DashboardInfo(
|
||||
id=dashboard.id,
|
||||
@@ -409,16 +472,15 @@ def generate_dashboard( # noqa: C901
|
||||
),
|
||||
dashboard_url=dashboard_url,
|
||||
error=None,
|
||||
warnings=sanitization_warnings,
|
||||
warnings=sanitization_warnings
|
||||
+ [
|
||||
"Dashboard created but response metadata is partial "
|
||||
"(post-create refresh failed); some fields are omitted. "
|
||||
"Call get_dashboard_info to retrieve the full record."
|
||||
],
|
||||
)
|
||||
|
||||
# Convert to our response format
|
||||
from superset.mcp_service.dashboard.schemas import (
|
||||
serialize_chart_summary,
|
||||
serialize_tag_object,
|
||||
)
|
||||
from superset.mcp_service.utils.response_utils import humanize_timestamp
|
||||
|
||||
include_data_model_metadata = user_can_view_data_model_metadata()
|
||||
dashboard_info = DashboardInfo(
|
||||
id=dashboard.id,
|
||||
@@ -433,7 +495,7 @@ def generate_dashboard( # noqa: C901
|
||||
created_by=dashboard.created_by_name or None,
|
||||
changed_by=dashboard.changed_by_name or None,
|
||||
uuid=str(dashboard.uuid) if dashboard.uuid else None,
|
||||
url=f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/",
|
||||
url=dashboard_url,
|
||||
chart_count=len(request.chart_ids),
|
||||
tags=[
|
||||
serialize_tag_object(tag)
|
||||
@@ -453,8 +515,6 @@ def generate_dashboard( # noqa: C901
|
||||
],
|
||||
)
|
||||
|
||||
dashboard_url = f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/"
|
||||
|
||||
logger.info(
|
||||
"Created dashboard %s with %s charts", dashboard.id, len(request.chart_ids)
|
||||
)
|
||||
@@ -467,17 +527,19 @@ def generate_dashboard( # noqa: C901
|
||||
)
|
||||
|
||||
except (SQLAlchemyError, ValueError, AttributeError, ValidationError) as e:
|
||||
from superset import db
|
||||
|
||||
try:
|
||||
db.session.rollback() # pylint: disable=consider-using-transaction
|
||||
except SQLAlchemyError:
|
||||
logger.warning(
|
||||
"Database rollback failed during error handling", exc_info=True
|
||||
)
|
||||
# ``str(e)`` on SQLAlchemyError frequently contains table/column/
|
||||
# constraint names that should not leak to the MCP response.
|
||||
# The raw exception is captured above via ``logger.error`` with
|
||||
# ``exc_info=True``; the response surfaces a generic message.
|
||||
logger.error("Error creating dashboard: %s", e, exc_info=True)
|
||||
return GenerateDashboardResponse(
|
||||
dashboard=None,
|
||||
dashboard_url=None,
|
||||
error=f"Failed to create dashboard: {str(e)}",
|
||||
error="Failed to create dashboard due to an internal error.",
|
||||
)
|
||||
|
||||
269
superset/mcp_service/dashboard/tool/update_dashboard.py
Normal file
269
superset/mcp_service/dashboard/tool/update_dashboard.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Update dashboard FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for updating an existing dashboard's
|
||||
layout, theme, and styling. Companion to ``generate_dashboard`` for
|
||||
incremental edits without re-creating the dashboard.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.dashboard.exceptions import DashboardNotFoundError
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.extensions import db, event_logger
|
||||
from superset.mcp_service.dashboard.schemas import (
|
||||
dashboard_serializer,
|
||||
DashboardError,
|
||||
UpdateDashboardRequest,
|
||||
UpdateDashboardResponse,
|
||||
)
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.utils import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_dashboard_url(dashboard: Any) -> str:
|
||||
"""Build the user-facing dashboard URL, preferring slug over id."""
|
||||
return (
|
||||
f"{get_superset_base_url()}/superset/dashboard/"
|
||||
f"{dashboard.slug or dashboard.id}/"
|
||||
)
|
||||
|
||||
|
||||
def _find_and_authorize_dashboard(
|
||||
identifier: int | str,
|
||||
) -> tuple[Any, UpdateDashboardResponse | DashboardError | None]:
|
||||
"""Return (dashboard, None) on success or (None, error_response) on failure.
|
||||
|
||||
Mirrors the helper in ``add_chart_to_existing_dashboard``: combines
|
||||
the not-found and forbidden cases so the main tool body has a single
|
||||
pre-condition branch. Returns ``DashboardError`` on not-found and
|
||||
``UpdateDashboardResponse`` (with ``permission_denied=True``) on
|
||||
ownership failure — the two shapes carry different information for
|
||||
the caller.
|
||||
"""
|
||||
# avoids ImportError before Flask app initialisation:
|
||||
# `Exception: App not initialized yet. Please call init_app first`
|
||||
# raised from superset.utils.encrypt when DashboardDAO is imported
|
||||
# (via Slice's encrypted Column types). `security_manager` is a
|
||||
# LocalProxy that needs the same app context to resolve at call
|
||||
# time, so it is co-located with the DAO it accompanies.
|
||||
from superset import security_manager
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
|
||||
try:
|
||||
dashboard = DashboardDAO.get_by_id_or_slug(identifier)
|
||||
except (DashboardNotFoundError, SQLAlchemyError):
|
||||
return None, DashboardError(
|
||||
error=f"Dashboard not found: {identifier!r}",
|
||||
error_type="DashboardNotFound",
|
||||
)
|
||||
|
||||
if dashboard is None:
|
||||
return None, DashboardError(
|
||||
error=f"Dashboard not found: {identifier!r}",
|
||||
error_type="DashboardNotFound",
|
||||
)
|
||||
|
||||
try:
|
||||
security_manager.raise_for_ownership(dashboard)
|
||||
except SupersetSecurityException:
|
||||
return None, UpdateDashboardResponse(
|
||||
permission_denied=True,
|
||||
error=(
|
||||
f"You don't have permission to edit dashboard "
|
||||
f"'{dashboard.dashboard_title}' (ID: {dashboard.id})."
|
||||
),
|
||||
)
|
||||
|
||||
return dashboard, None
|
||||
|
||||
|
||||
def _merge_json_metadata(dashboard: Any, overrides: dict[str, Any]) -> str:
|
||||
"""Shallow-merge ``overrides`` onto the dashboard's existing metadata.
|
||||
|
||||
Parses defensively: a row may carry malformed JSON or a non-object
|
||||
payload (e.g. ``"[]"``) from an older migration or manual edit. Either
|
||||
would raise out of the caller's ``SQLAlchemyError`` handler, so fall
|
||||
back to an empty object and overlay the overrides on top.
|
||||
"""
|
||||
existing: dict[str, Any] = {}
|
||||
if dashboard.json_metadata:
|
||||
try:
|
||||
parsed = json.loads(dashboard.json_metadata)
|
||||
if isinstance(parsed, dict):
|
||||
existing = parsed
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
existing.update(overrides)
|
||||
return json.dumps(existing)
|
||||
|
||||
|
||||
def _apply_field_updates(dashboard: Any, request: UpdateDashboardRequest) -> list[str]:
|
||||
"""Apply each explicitly-passed field to the dashboard.
|
||||
|
||||
Returns the names of fields actually changed. Mutates ``dashboard``
|
||||
in place. ``json_metadata_overrides`` is merged shallowly with the
|
||||
existing ``json_metadata``; an empty string in ``slug`` or ``css``
|
||||
clears the underlying value.
|
||||
"""
|
||||
changed: list[str] = []
|
||||
|
||||
if request.dashboard_title is not None:
|
||||
dashboard.dashboard_title = request.dashboard_title
|
||||
changed.append("dashboard_title")
|
||||
|
||||
if request.description is not None:
|
||||
dashboard.description = request.description
|
||||
changed.append("description")
|
||||
|
||||
if request.slug is not None:
|
||||
dashboard.slug = request.slug or None
|
||||
changed.append("slug")
|
||||
|
||||
if request.published is not None:
|
||||
dashboard.published = request.published
|
||||
changed.append("published")
|
||||
|
||||
if request.position_json is not None:
|
||||
dashboard.position_json = json.dumps(request.position_json)
|
||||
changed.append("position_json")
|
||||
|
||||
if request.json_metadata_overrides is not None:
|
||||
dashboard.json_metadata = _merge_json_metadata(
|
||||
dashboard, request.json_metadata_overrides
|
||||
)
|
||||
changed.append("json_metadata")
|
||||
|
||||
if request.css is not None:
|
||||
dashboard.css = request.css or None
|
||||
changed.append("css")
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["mutate"],
|
||||
class_permission_name="Dashboard",
|
||||
method_permission_name="write",
|
||||
annotations=ToolAnnotations(
|
||||
title="Update dashboard layout/theme/CSS",
|
||||
readOnlyHint=False,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
def update_dashboard(
|
||||
request: UpdateDashboardRequest, ctx: Context
|
||||
) -> UpdateDashboardResponse | DashboardError:
|
||||
"""Patch an existing dashboard's layout, theme, or styling.
|
||||
|
||||
Companion to ``generate_dashboard`` for incremental edits. Accepts
|
||||
the same layout/theme/CSS fields that ``generate_dashboard`` does, so
|
||||
an LLM can:
|
||||
|
||||
- Set or replace ``position_json`` after auto-generation
|
||||
- Apply brand ``label_colors`` and ``color_scheme`` via
|
||||
``json_metadata_overrides``
|
||||
- Toggle ``cross_filters_enabled`` via ``json_metadata_overrides``
|
||||
- Inject ``css`` to hide chrome on print-ready dashboards
|
||||
- Update ``dashboard_title``, ``description``, ``slug``, ``published``
|
||||
|
||||
Only the fields explicitly passed are applied; other fields are left
|
||||
unchanged. ``json_metadata_overrides`` is merged shallowly with the
|
||||
existing json_metadata — pass only the keys you want to change.
|
||||
|
||||
Example::
|
||||
|
||||
update_dashboard(request={
|
||||
"identifier": 42,
|
||||
"json_metadata_overrides": {
|
||||
"label_colors": {"Electronics": "#4C78A8"},
|
||||
"cross_filters_enabled": False,
|
||||
},
|
||||
"css": ".header-controls {display: none;}",
|
||||
})
|
||||
"""
|
||||
ctx.info(f"Updating dashboard: identifier={request.identifier}")
|
||||
|
||||
dashboard, auth_error = _find_and_authorize_dashboard(request.identifier)
|
||||
if auth_error is not None:
|
||||
return auth_error
|
||||
|
||||
changed_fields: list[str] = []
|
||||
warnings: list[str] = list(request.sanitization_warnings)
|
||||
|
||||
try:
|
||||
with event_logger.log_context(action="mcp.update_dashboard.apply"):
|
||||
changed_fields = _apply_field_updates(dashboard, request)
|
||||
|
||||
if not changed_fields:
|
||||
warnings.append("No fields provided; dashboard unchanged.")
|
||||
return UpdateDashboardResponse(
|
||||
dashboard=dashboard_serializer(dashboard),
|
||||
dashboard_url=_build_dashboard_url(dashboard),
|
||||
error=None,
|
||||
changed_fields=[],
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
try:
|
||||
db.session.refresh(dashboard)
|
||||
except SQLAlchemyError:
|
||||
logger.warning(
|
||||
"Dashboard %s updated but refresh failed; "
|
||||
"continuing with current values",
|
||||
dashboard.id,
|
||||
exc_info=True,
|
||||
)
|
||||
warnings.append(
|
||||
"Dashboard updated but post-update refresh failed; "
|
||||
"returned values may not reflect database state."
|
||||
)
|
||||
|
||||
except SQLAlchemyError as db_err:
|
||||
try:
|
||||
db.session.rollback() # pylint: disable=consider-using-transaction
|
||||
except SQLAlchemyError:
|
||||
logger.warning(
|
||||
"Database rollback failed during error handling",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.error("Dashboard update failed: %s", db_err, exc_info=True)
|
||||
return DashboardError(
|
||||
error="Failed to update dashboard due to a database error.",
|
||||
error_type="DatabaseError",
|
||||
)
|
||||
|
||||
ctx.info(f"Dashboard {dashboard.id} updated: changed={changed_fields}")
|
||||
|
||||
return UpdateDashboardResponse(
|
||||
dashboard=dashboard_serializer(dashboard),
|
||||
dashboard_url=_build_dashboard_url(dashboard),
|
||||
error=None,
|
||||
changed_fields=changed_fields,
|
||||
warnings=warnings,
|
||||
)
|
||||
@@ -33,6 +33,7 @@ from typing import (
|
||||
Callable,
|
||||
cast,
|
||||
ClassVar,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
@@ -130,7 +131,11 @@ from superset.utils.core import (
|
||||
SqlExpressionType,
|
||||
TIME_COMPARISON,
|
||||
)
|
||||
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
|
||||
from superset.utils.date_parser import (
|
||||
get_past_or_future,
|
||||
normalize_time_delta,
|
||||
TimeDeltaAmbiguousError,
|
||||
)
|
||||
from superset.utils.dates import datetime_to_epoch
|
||||
from superset.utils.rls import apply_rls
|
||||
|
||||
@@ -2013,6 +2018,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
offset_dfs,
|
||||
time_grain,
|
||||
join_keys,
|
||||
full_range=getattr(query_object, "time_compare_full_range", False),
|
||||
)
|
||||
|
||||
return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys)
|
||||
@@ -2210,7 +2216,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
return offset_df, join_keys
|
||||
|
||||
def _perform_join(
|
||||
self, df: pd.DataFrame, offset_df: pd.DataFrame, actual_join_keys: list[str]
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
offset_df: pd.DataFrame,
|
||||
actual_join_keys: list[str],
|
||||
how: Literal["left", "right", "inner", "outer", "cross"] = "left",
|
||||
) -> pd.DataFrame:
|
||||
"""Perform the appropriate join operation."""
|
||||
if actual_join_keys:
|
||||
@@ -2219,6 +2229,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
right_df=offset_df,
|
||||
join_keys=actual_join_keys,
|
||||
rsuffix=R_SUFFIX,
|
||||
how=how,
|
||||
)
|
||||
else:
|
||||
temp_key = "__temp_join_key__"
|
||||
@@ -2230,6 +2241,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
right_df=offset_df,
|
||||
join_keys=[temp_key],
|
||||
rsuffix=R_SUFFIX,
|
||||
how=how,
|
||||
)
|
||||
|
||||
# Remove temporary join keys
|
||||
@@ -2245,6 +2257,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
offset_dfs: dict[str, pd.DataFrame],
|
||||
time_grain: str | None,
|
||||
join_keys: list[str],
|
||||
full_range: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Join offset DataFrames with the main DataFrame.
|
||||
@@ -2253,6 +2266,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
:param offset_dfs: A list of offset DataFrames.
|
||||
:param time_grain: The time grain used to calculate the temporal join key.
|
||||
:param join_keys: The keys to join on.
|
||||
:param full_range: When True, time-shifted (offset) series keep their full
|
||||
time range instead of being truncated to the main series' range. This
|
||||
uses an outer join so offset-only rows (e.g. the rest of a prior day when
|
||||
the current day is still in progress) are preserved.
|
||||
"""
|
||||
join_column_producer = app.config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get(
|
||||
time_grain
|
||||
@@ -2280,13 +2297,60 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
join_column_producer,
|
||||
)
|
||||
|
||||
df = self._perform_join(df, offset_df, actual_join_keys)
|
||||
# The full-range option is only meaningful for relative offsets aligned
|
||||
# on a temporal join column (time_grain set). Date-range offsets and the
|
||||
# grain-less path keep the existing left-join behavior.
|
||||
use_outer_join = (
|
||||
full_range
|
||||
and bool(time_grain)
|
||||
and not is_date_range_offset
|
||||
and bool(join_keys)
|
||||
)
|
||||
how: Literal["left", "outer"] = "outer" if use_outer_join else "left"
|
||||
|
||||
df = self._perform_join(df, offset_df, actual_join_keys, how=how)
|
||||
|
||||
if use_outer_join:
|
||||
df = self._coalesce_offset_index(df, offset, join_keys)
|
||||
|
||||
df = self._apply_cleanup_logic(
|
||||
df, offset, time_grain, join_keys, is_date_range_offset
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
def _coalesce_offset_index(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
offset: str,
|
||||
join_keys: list[str],
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Rebuild the temporal x-axis after an outer join with an offset DataFrame.
|
||||
|
||||
Offset-only rows (those with no matching row in the main series) have a null
|
||||
x-axis value because the join happens on the normalized offset join column,
|
||||
not the raw temporal column. Their real timestamp lives in the suffixed
|
||||
right-hand column, expressed in the offset's own time range (e.g. "yesterday
|
||||
15:00"). Shifting it forward by the offset places it on the main series'
|
||||
axis (e.g. "today 15:00") so the comparison line spans the full period.
|
||||
"""
|
||||
x_axis = join_keys[0]
|
||||
offset_x_axis = f"{x_axis}{R_SUFFIX}"
|
||||
if x_axis not in df.columns or offset_x_axis not in df.columns:
|
||||
return df
|
||||
|
||||
# normalize_time_delta returns a negative delta for "... ago" offsets, so
|
||||
# subtracting it shifts the historical timestamp forward onto the main axis.
|
||||
try:
|
||||
forward_shift = DateOffset(**normalize_time_delta(offset))
|
||||
except (ValueError, TimeDeltaAmbiguousError):
|
||||
return df
|
||||
|
||||
shifted = df[offset_x_axis] - forward_shift
|
||||
df[x_axis] = df[x_axis].fillna(shifted)
|
||||
return df
|
||||
|
||||
def add_offset_join_column(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
|
||||
@@ -221,6 +221,7 @@ class QueryObjectDict(TypedDict, total=False):
|
||||
group_others_when_limit_reached: bool
|
||||
to_dttm: datetime | None
|
||||
time_shift: str | None
|
||||
time_compare_full_range: bool
|
||||
post_processing: list[dict[str, Any]]
|
||||
|
||||
# Additional fields used throughout the codebase
|
||||
|
||||
@@ -194,9 +194,9 @@ def should_use_v2_api() -> bool:
|
||||
except SlackApiError:
|
||||
# use the v1 api but warn with a deprecation message
|
||||
logger.warning(
|
||||
"""Your current Slack scopes are missing `channels:read`. Please add
|
||||
this to your Slack app in order to continue using the v1 API. Support
|
||||
for the old Slack API will be removed in Superset version 6.0.0."""
|
||||
"Your current Slack scopes are missing `channels:read`. Please add "
|
||||
"this to your Slack app in order to continue using the v1 API. Support "
|
||||
"for the old Slack API will be removed in Superset version 6.0.0."
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -135,6 +135,10 @@ ALERT_REPORTS_QUERY_EXECUTION_MAX_TRIES = 3
|
||||
|
||||
FAB_ADD_SECURITY_API = True
|
||||
|
||||
# Swagger UI / OpenAPI spec is opt-in in the base config; enable it for tests
|
||||
# that exercise the /api/v1/_openapi spec endpoint.
|
||||
FAB_API_SWAGGER_UI = True
|
||||
|
||||
|
||||
class CeleryConfig:
|
||||
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
|
||||
|
||||
@@ -53,6 +53,9 @@ _datasource._perform_join = ExploreMixin._perform_join.__get__(_datasource)
|
||||
_datasource._apply_cleanup_logic = ExploreMixin._apply_cleanup_logic.__get__(
|
||||
_datasource
|
||||
)
|
||||
_datasource._coalesce_offset_index = ExploreMixin._coalesce_offset_index.__get__(
|
||||
_datasource
|
||||
)
|
||||
# Static methods don't need binding - assign directly
|
||||
_datasource.generate_join_column = ExploreMixin.generate_join_column
|
||||
_datasource.is_valid_date_range_static = ExploreMixin.is_valid_date_range_static
|
||||
@@ -211,6 +214,91 @@ def test_join_offset_dfs_with_month_granularity():
|
||||
assert_frame_equal(expected, result)
|
||||
|
||||
|
||||
def test_join_offset_dfs_full_range_keeps_historical_tail():
|
||||
"""
|
||||
With full_range=True the offset (historical) series keeps its full time range
|
||||
even when the main series ends earlier.
|
||||
|
||||
Simulates "today so far" (main, ends at 01:00) compared against "1 day ago"
|
||||
(a complete prior day, runs to 02:00). The 02:00 historical point must survive
|
||||
and be aligned onto today's axis, with the main metric left null there.
|
||||
"""
|
||||
# Main series: today, only two hours of data so far.
|
||||
df = DataFrame(
|
||||
{
|
||||
"A": [Timestamp("2021-01-02 00:00"), Timestamp("2021-01-02 01:00")],
|
||||
"V": [1.0, 2.0],
|
||||
}
|
||||
)
|
||||
# Offset series: the full prior day (already renamed metric column "B").
|
||||
offset_df = DataFrame(
|
||||
{
|
||||
"A": [
|
||||
Timestamp("2021-01-01 00:00"),
|
||||
Timestamp("2021-01-01 01:00"),
|
||||
Timestamp("2021-01-01 02:00"),
|
||||
],
|
||||
"B": [10.0, 20.0, 30.0],
|
||||
}
|
||||
)
|
||||
offset_dfs = {"1 day ago": offset_df}
|
||||
time_grain = TimeGrain.HOUR
|
||||
join_keys = ["A"]
|
||||
|
||||
expected = DataFrame(
|
||||
{
|
||||
"A": [
|
||||
Timestamp("2021-01-02 00:00"),
|
||||
Timestamp("2021-01-02 01:00"),
|
||||
Timestamp("2021-01-02 02:00"),
|
||||
],
|
||||
"V": [1.0, 2.0, None],
|
||||
"B": [10.0, 20.0, 30.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = query_context_processor.join_offset_dfs(
|
||||
df, offset_dfs, time_grain, join_keys, full_range=True
|
||||
)
|
||||
|
||||
assert_frame_equal(expected, result)
|
||||
|
||||
|
||||
def test_join_offset_dfs_full_range_disabled_truncates_historical():
|
||||
"""The default (full_range=False) left join drops the historical 02:00 point."""
|
||||
df = DataFrame(
|
||||
{
|
||||
"A": [Timestamp("2021-01-02 00:00"), Timestamp("2021-01-02 01:00")],
|
||||
"V": [1.0, 2.0],
|
||||
}
|
||||
)
|
||||
offset_df = DataFrame(
|
||||
{
|
||||
"A": [
|
||||
Timestamp("2021-01-01 00:00"),
|
||||
Timestamp("2021-01-01 01:00"),
|
||||
Timestamp("2021-01-01 02:00"),
|
||||
],
|
||||
"B": [10.0, 20.0, 30.0],
|
||||
}
|
||||
)
|
||||
offset_dfs = {"1 day ago": offset_df}
|
||||
|
||||
expected = DataFrame(
|
||||
{
|
||||
"A": [Timestamp("2021-01-02 00:00"), Timestamp("2021-01-02 01:00")],
|
||||
"V": [1.0, 2.0],
|
||||
"B": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = query_context_processor.join_offset_dfs(
|
||||
df, offset_dfs, TimeGrain.HOUR, ["A"], full_range=False
|
||||
)
|
||||
|
||||
assert_frame_equal(expected, result)
|
||||
|
||||
|
||||
def test_join_offset_dfs_totals_query_no_dimensions():
|
||||
"""
|
||||
Test time offset join for totals query with no dimension columns.
|
||||
|
||||
77
tests/unit_tests/config_swagger_ui_test.py
Normal file
77
tests/unit_tests/config_swagger_ui_test.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Tests for environment-driven Swagger UI config defaults.
|
||||
|
||||
``superset.config`` is imported in a fresh subprocess for each case so the
|
||||
module is evaluated under a controlled environment without reloading (and
|
||||
mutating) the config module shared by the rest of the test session.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _resolve_swagger_default(env_value: str | None) -> str:
|
||||
"""Resolve ``FAB_API_SWAGGER_UI`` for a given ``SUPERSET_ENABLE_SWAGGER_UI``.
|
||||
|
||||
Evaluates ``superset.config`` in a fresh subprocess under a controlled
|
||||
environment so the result reflects only the supplied env var. Config-path
|
||||
overrides are stripped so a local ``superset_config`` cannot taint the
|
||||
default, and only the final stdout line is read in case config loading
|
||||
emits banner output.
|
||||
"""
|
||||
env = dict(os.environ)
|
||||
for var in (
|
||||
"SUPERSET_ENABLE_SWAGGER_UI",
|
||||
"SUPERSET_CONFIG_PATH",
|
||||
"SUPERSET_CONFIG",
|
||||
):
|
||||
env.pop(var, None)
|
||||
if env_value is not None:
|
||||
env["SUPERSET_ENABLE_SWAGGER_UI"] = env_value
|
||||
result = subprocess.run( # noqa: S603
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
"import superset.config as c; print(c.FAB_API_SWAGGER_UI)",
|
||||
],
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip().splitlines()[-1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_value, expected",
|
||||
[
|
||||
(None, "False"), # unset -> off by default
|
||||
("true", "True"),
|
||||
("True", "True"),
|
||||
("false", "False"),
|
||||
("", "False"),
|
||||
],
|
||||
)
|
||||
def test_fab_api_swagger_ui_is_env_driven_and_off_by_default(
|
||||
env_value: str | None, expected: str
|
||||
) -> None:
|
||||
"""Swagger UI defaults to off and follows ``SUPERSET_ENABLE_SWAGGER_UI``."""
|
||||
assert _resolve_swagger_default(env_value) == expected
|
||||
@@ -17,23 +17,14 @@
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
from superset.db_engine_specs.databricks import (
|
||||
DatabricksNativeEngineSpec,
|
||||
DatabricksPythonConnectorEngineSpec,
|
||||
)
|
||||
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils import json
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from tests.unit_tests.fixtures.common import dttm # noqa: F401
|
||||
|
||||
@@ -300,541 +291,3 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"USE CATALOG `evil`` USE CATALOG bad`",
|
||||
"USE SCHEMA `evil`` USE SCHEMA bad`",
|
||||
]
|
||||
|
||||
|
||||
# OAuth2 Tests
|
||||
|
||||
|
||||
def test_oauth2_attributes() -> None:
|
||||
"""
|
||||
Test that OAuth2 attributes are properly set for both engine specs.
|
||||
"""
|
||||
# Test DatabricksNativeEngineSpec
|
||||
assert DatabricksNativeEngineSpec.supports_oauth2 is True
|
||||
assert DatabricksNativeEngineSpec.oauth2_exception is OAuth2RedirectError
|
||||
assert DatabricksNativeEngineSpec.oauth2_scope == "sql"
|
||||
# OAuth2 endpoints are now dynamic and set at runtime
|
||||
assert DatabricksNativeEngineSpec.oauth2_authorization_request_uri == ""
|
||||
assert DatabricksNativeEngineSpec.oauth2_token_request_uri == ""
|
||||
|
||||
# Test DatabricksPythonConnectorEngineSpec
|
||||
assert DatabricksPythonConnectorEngineSpec.supports_oauth2 is True
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_exception is OAuth2RedirectError
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_scope == "sql"
|
||||
# OAuth2 endpoints are now dynamic and set at runtime
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_authorization_request_uri == ""
|
||||
assert DatabricksPythonConnectorEngineSpec.oauth2_token_request_uri == ""
|
||||
|
||||
|
||||
def test_impersonate_user_with_token(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method with OAuth2 token for DatabricksNativeEngineSpec.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks+connector://token:original-token@host:443/database"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test with user token
|
||||
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token="user-oauth-token", # noqa: S106
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that the password (token) was updated in the URL
|
||||
assert url.password == "user-oauth-token" # noqa: S105
|
||||
# Check that access_token was updated in connect_args
|
||||
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
|
||||
|
||||
|
||||
def test_impersonate_user_without_token(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method without OAuth2 token.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks+connector://token:original-token@host:443/database"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test without user token
|
||||
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token=None,
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that nothing was changed
|
||||
assert url.password == "original-token" # noqa: S105
|
||||
assert kwargs["connect_args"]["access_token"] == "original-token" # noqa: S105
|
||||
|
||||
|
||||
def test_impersonate_user_python_connector(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonate_user method for DatabricksPythonConnectorEngineSpec.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
original_url = make_url(
|
||||
"databricks://token:original-token@host:443?http_path=path&catalog=main&schema=default"
|
||||
)
|
||||
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
|
||||
|
||||
# Test with user token
|
||||
url, kwargs = DatabricksPythonConnectorEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="user1",
|
||||
user_token="user-oauth-token", # noqa: S106
|
||||
url=original_url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
)
|
||||
|
||||
# Check that the password (token) was updated in the URL
|
||||
assert url.password == "user-oauth-token" # noqa: S105
|
||||
# Check that access_token was updated in connect_args
|
||||
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_native() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks Native OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_python() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks Python Connector OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_no_config_native(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is not configured for Native engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={"DATABASE_OAUTH2_CLIENTS": {}},
|
||||
)
|
||||
|
||||
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is False
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_config_native(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is configured for Native engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={
|
||||
"DATABASE_OAUTH2_CLIENTS": {
|
||||
"Databricks (legacy)": {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is True
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_no_config_python(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is not configured for Python Connector engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={"DATABASE_OAUTH2_CLIENTS": {}},
|
||||
)
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is False
|
||||
|
||||
|
||||
def test_is_oauth2_enabled_config_python(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `is_oauth2_enabled` when OAuth2 is configured for Python Connector engine.
|
||||
"""
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
new={
|
||||
"DATABASE_OAUTH2_CLIENTS": {
|
||||
"Databricks": {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is True
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_authorization_uri` for Native engine.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_native, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.cloud.databricks.com"
|
||||
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_authorization_uri` for Python Connector engine.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_python, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.cloud.databricks.com"
|
||||
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_token_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token` for Native engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_oauth2_token(
|
||||
oauth2_config_native, "authorization-code"
|
||||
) == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"code": "authorization-code",
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_token_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token` for Python Connector engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.get_oauth2_token(
|
||||
oauth2_config_python, "authorization-code"
|
||||
) == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"code": "authorization-code",
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_native(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_native: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_fresh_token` for Native engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_oauth2_fresh_token(
|
||||
oauth2_config_native, "old-refresh-token"
|
||||
) == {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def _oauth2_state() -> OAuth2State:
|
||||
"""
|
||||
Build the default OAuth2 state shared by the OAuth2 tests.
|
||||
"""
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
return state
|
||||
|
||||
|
||||
def _unresolved_oauth2_config() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config as built by `get_oauth2_config` when no endpoints are overridden:
|
||||
the URIs default to the spec's empty class attributes.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "",
|
||||
"token_request_uri": "",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"extra, netloc, path",
|
||||
[
|
||||
(
|
||||
{"cloud_provider": "aws", "account_id": "acct-999"},
|
||||
"accounts.cloud.databricks.com",
|
||||
"/oidc/accounts/acct-999/v1/authorize",
|
||||
),
|
||||
(
|
||||
{"cloud_provider": "azure", "tenant_id": "tenant-abc"},
|
||||
"login.microsoftonline.com",
|
||||
"/tenant-abc/oauth2/v2.0/authorize",
|
||||
),
|
||||
(
|
||||
{"cloud_provider": "gcp", "account_id": "acct-gcp"},
|
||||
"accounts.gcp.databricks.com",
|
||||
"/oidc/accounts/acct-gcp/v1/authorize",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_autodetects_and_resolves(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
extra: dict[str, Any],
|
||||
netloc: str,
|
||||
path: str,
|
||||
) -> None:
|
||||
"""
|
||||
With no configured `authorization_request_uri`, the endpoint is auto-detected
|
||||
per cloud provider and the `account_id`/`tenant_id` placeholder is resolved
|
||||
from the database `extra`.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.extra = json.dumps(extra)
|
||||
database.url_object.host = "dbc-abc.cloud.databricks.com"
|
||||
mocker.patch("superset.db.session.get", return_value=database)
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(
|
||||
_unresolved_oauth2_config(), _oauth2_state()
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == netloc
|
||||
assert parsed.path == path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_preserves_configured(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
A fully-resolved `authorization_request_uri` is never overwritten by the
|
||||
auto-detected template, and no database lookup is needed.
|
||||
"""
|
||||
session_get = mocker.patch("superset.db.session.get")
|
||||
config = _unresolved_oauth2_config()
|
||||
config["authorization_request_uri"] = (
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/override/v1/authorize"
|
||||
)
|
||||
|
||||
url = spec.get_oauth2_authorization_uri(config, _oauth2_state())
|
||||
assert urlparse(url).path == "/oidc/accounts/override/v1/authorize"
|
||||
session_get.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_authorization_uri_fails_without_account_id(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
When the endpoint must be auto-detected but no `account_id`/`tenant_id` is
|
||||
available, fail fast instead of emitting an unresolved `.../{}/...` URL.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.extra = json.dumps({"cloud_provider": "aws"})
|
||||
database.url_object.host = "dbc-abc.cloud.databricks.com"
|
||||
mocker.patch("superset.db.session.get", return_value=database)
|
||||
|
||||
with pytest.raises(OAuth2Error):
|
||||
spec.get_oauth2_authorization_uri(_unresolved_oauth2_config(), _oauth2_state())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
|
||||
)
|
||||
def test_get_oauth2_token_fails_without_uri(
|
||||
mocker: MockerFixture,
|
||||
spec: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Token exchange has no database context to auto-detect the endpoint, so a
|
||||
missing `token_request_uri` fails fast rather than POSTing to `.../{}/...`.
|
||||
"""
|
||||
with pytest.raises(OAuth2Error):
|
||||
spec.get_oauth2_token(_unresolved_oauth2_config(), "authorization-code")
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_python(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_python: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_fresh_token` for Python Connector engine.
|
||||
"""
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
|
||||
assert DatabricksPythonConnectorEngineSpec.get_oauth2_fresh_token(
|
||||
oauth2_config_python, "old-refresh-token"
|
||||
) == {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "sql",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
json={
|
||||
"client_id": "databricks-client-id",
|
||||
"client_secret": "databricks-client-secret",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.db_engine_specs.databricks import (
|
||||
DatabricksNativeEngineSpec,
|
||||
DatabricksPythonConnectorEngineSpec,
|
||||
)
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils.oauth2 import decode_oauth2_state
|
||||
|
||||
# Multi-Cloud Provider Tests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database_aws(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock database with AWS hostname.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = "my-cluster.cloud.databricks.com"
|
||||
database.extra = "{}"
|
||||
database.id = 1
|
||||
return database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database_azure(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock database with Azure hostname.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = "adb-123456789.12.azuredatabricks.net"
|
||||
database.extra = '{"tenant_id": "azure-tenant-id"}'
|
||||
database.id = 2
|
||||
return database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database_gcp(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock database with GCP hostname.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = "123456789.gcp.databricks.com"
|
||||
database.extra = '{"account_id": "12345"}'
|
||||
database.id = 3
|
||||
return database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks OAuth2 with a fully-resolved endpoint configured.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
|
||||
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config_no_uri() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Databricks OAuth2 without a pre-configured endpoint, so the
|
||||
per-provider endpoint is auto-detected and account-resolved.
|
||||
"""
|
||||
return {
|
||||
"id": "databricks-client-id",
|
||||
"secret": "databricks-client-secret",
|
||||
"scope": "sql",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "",
|
||||
"token_request_uri": "",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
def test_cloud_provider_detection_aws(mock_database_aws: MagicMock) -> None:
|
||||
"""
|
||||
Test cloud provider detection for AWS.
|
||||
"""
|
||||
provider = DatabricksNativeEngineSpec._detect_cloud_provider(mock_database_aws)
|
||||
assert provider == "aws"
|
||||
|
||||
|
||||
def test_cloud_provider_detection_azure(mock_database_azure: MagicMock) -> None:
|
||||
"""
|
||||
Test cloud provider detection for Azure.
|
||||
"""
|
||||
provider = DatabricksNativeEngineSpec._detect_cloud_provider(mock_database_azure)
|
||||
assert provider == "azure"
|
||||
|
||||
|
||||
def test_cloud_provider_detection_gcp(mock_database_gcp: MagicMock) -> None:
|
||||
"""
|
||||
Test cloud provider detection for GCP.
|
||||
"""
|
||||
provider = DatabricksNativeEngineSpec._detect_cloud_provider(mock_database_gcp)
|
||||
assert provider == "gcp"
|
||||
|
||||
|
||||
def test_cloud_provider_detection_explicit_config(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test cloud provider detection with explicit configuration.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = "generic-host.com"
|
||||
|
||||
# Mock get_extra_params to return explicit cloud provider
|
||||
mocker.patch.object(
|
||||
DatabricksNativeEngineSpec,
|
||||
"get_extra_params",
|
||||
return_value={"cloud_provider": "azure"},
|
||||
)
|
||||
|
||||
provider = DatabricksNativeEngineSpec._detect_cloud_provider(database)
|
||||
assert provider == "azure"
|
||||
|
||||
|
||||
def test_cloud_provider_detection_invalid_config_falls_back_to_hostname(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
An unrecognized explicit `cloud_provider` is ignored and detection falls
|
||||
back to the hostname rather than raising or returning the bad value.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = "adb-123456789.12.azuredatabricks.net"
|
||||
|
||||
mocker.patch.object(
|
||||
DatabricksNativeEngineSpec,
|
||||
"get_extra_params",
|
||||
return_value={"cloud_provider": "oracle"},
|
||||
)
|
||||
|
||||
provider = DatabricksNativeEngineSpec._detect_cloud_provider(database)
|
||||
assert provider == "azure"
|
||||
|
||||
|
||||
def test_cloud_provider_detection_non_string_falls_back_to_hostname(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
A non-string `cloud_provider` (e.g. a boolean from malformed JSON) is
|
||||
ignored without raising and detection falls back to the hostname.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.url_object.host = "adb-123456789.12.azuredatabricks.net"
|
||||
|
||||
mocker.patch.object(
|
||||
DatabricksNativeEngineSpec,
|
||||
"get_extra_params",
|
||||
return_value={"cloud_provider": True},
|
||||
)
|
||||
|
||||
provider = DatabricksNativeEngineSpec._detect_cloud_provider(database)
|
||||
assert provider == "azure"
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_aws(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
mock_database_aws: MagicMock,
|
||||
) -> None:
|
||||
"""
|
||||
Test OAuth2 authorization URI generation for AWS provider.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
# Mock the database query
|
||||
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_aws)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(oauth2_config, state)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.cloud.databricks.com"
|
||||
assert "/oidc/accounts/" in parsed.path
|
||||
assert "/v1/authorize" in parsed.path
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_azure(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_no_uri: OAuth2ClientConfig,
|
||||
mock_database_azure: MagicMock,
|
||||
) -> None:
|
||||
"""
|
||||
Test OAuth2 authorization URI generation for Azure provider.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
# Mock the database query
|
||||
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_azure)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 2,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_no_uri, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "login.microsoftonline.com"
|
||||
assert "/oauth2/v2.0/authorize" in parsed.path
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_gcp(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_no_uri: OAuth2ClientConfig,
|
||||
mock_database_gcp: MagicMock,
|
||||
) -> None:
|
||||
"""
|
||||
Test OAuth2 authorization URI generation for GCP provider.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
# Mock the database query
|
||||
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_gcp)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 3,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_no_uri, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "accounts.gcp.databricks.com"
|
||||
assert "/oidc/accounts/" in parsed.path
|
||||
assert "/v1/authorize" in parsed.path
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
|
||||
|
||||
def test_python_connector_cloud_provider_detection_azure(
|
||||
mock_database_azure: MagicMock,
|
||||
) -> None:
|
||||
"""
|
||||
Test cloud provider detection for Python Connector with Azure.
|
||||
"""
|
||||
provider = DatabricksPythonConnectorEngineSpec._detect_cloud_provider(
|
||||
mock_database_azure
|
||||
)
|
||||
assert provider == "azure"
|
||||
|
||||
|
||||
def test_python_connector_oauth2_authorization_uri_azure(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config_no_uri: OAuth2ClientConfig,
|
||||
mock_database_azure: MagicMock,
|
||||
) -> None:
|
||||
"""
|
||||
Test OAuth2 authorization URI generation for Python Connector with Azure provider.
|
||||
"""
|
||||
from superset.db_engine_specs.base import OAuth2State
|
||||
|
||||
# Mock the database query
|
||||
mocker.patch("superset.extensions.db.session.get", return_value=mock_database_azure)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 2,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
|
||||
oauth2_config_no_uri, state
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
assert parsed.netloc == "login.microsoftonline.com"
|
||||
assert "/oauth2/v2.0/authorize" in parsed.path
|
||||
|
||||
query = parse_qs(parsed.query)
|
||||
assert query["scope"][0] == "sql"
|
||||
encoded_state = query["state"][0].replace("%2E", ".")
|
||||
assert decode_oauth2_state(encoded_state) == state
|
||||
@@ -30,10 +30,12 @@ from pydantic import ValidationError
|
||||
from superset.mcp_service.dashboard.schemas import (
|
||||
_extract_cross_filters_enabled,
|
||||
_extract_native_filters,
|
||||
_safe_user_label,
|
||||
dashboard_serializer,
|
||||
GenerateDashboardRequest,
|
||||
serialize_chart_summary,
|
||||
serialize_dashboard_object,
|
||||
UpdateDashboardRequest,
|
||||
)
|
||||
from superset.mcp_service.utils.sanitization import (
|
||||
LLM_CONTEXT_CLOSE_DELIMITER,
|
||||
@@ -87,8 +89,8 @@ def _mock_dashboard(
|
||||
class TestSerializeDashboardObject:
|
||||
"""Tests for serialize_dashboard_object slug handling."""
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_slug_none_returns_empty_string(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_slug_none_returns_empty_string(self, mock_base_url) -> None:
|
||||
"""Dashboards with slug=None should return slug="" for consistency
|
||||
with dashboard_serializer."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
@@ -98,8 +100,8 @@ class TestSerializeDashboardObject:
|
||||
|
||||
assert result.slug == ""
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_slug_empty_string_returns_empty_string(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_slug_empty_string_returns_empty_string(self, mock_base_url) -> None:
|
||||
"""Dashboards with slug="" should return slug=""."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
|
||||
@@ -108,8 +110,8 @@ class TestSerializeDashboardObject:
|
||||
|
||||
assert result.slug == ""
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_slug_with_value_preserved(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_slug_with_value_preserved(self, mock_base_url) -> None:
|
||||
"""Dashboards with a real slug should preserve it."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
|
||||
@@ -118,8 +120,8 @@ class TestSerializeDashboardObject:
|
||||
|
||||
assert result.slug == "my-dashboard"
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_url_uses_id_when_no_slug(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_url_uses_id_when_no_slug(self, mock_base_url) -> None:
|
||||
"""URL should use dashboard id when slug is None."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
|
||||
@@ -128,8 +130,8 @@ class TestSerializeDashboardObject:
|
||||
|
||||
assert result.url == "http://localhost:8088/superset/dashboard/42/"
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_url_uses_slug_when_available(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_url_uses_slug_when_available(self, mock_base_url) -> None:
|
||||
"""URL should use slug when available."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
|
||||
@@ -138,8 +140,8 @@ class TestSerializeDashboardObject:
|
||||
|
||||
assert result.url == "http://localhost:8088/superset/dashboard/my-dashboard/"
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_no_json_metadata_or_position_json_in_response(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_no_json_metadata_or_position_json_in_response(self, mock_base_url) -> None:
|
||||
"""DashboardInfo should not contain json_metadata or position_json."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
|
||||
@@ -150,7 +152,7 @@ class TestSerializeDashboardObject:
|
||||
assert not hasattr(result, "position_json")
|
||||
|
||||
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_native_filters_extracted_from_json_metadata(
|
||||
self,
|
||||
mock_base_url,
|
||||
@@ -196,7 +198,7 @@ class TestSerializeDashboardObject:
|
||||
assert result.cross_filters_enabled is True
|
||||
|
||||
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_restricted_user_redacts_native_filter_targets(
|
||||
self,
|
||||
mock_base_url,
|
||||
@@ -230,7 +232,7 @@ class TestSerializeDashboardObject:
|
||||
assert result.cross_filters_enabled is True
|
||||
|
||||
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_chart_summaries_are_lightweight(
|
||||
self,
|
||||
mock_base_url,
|
||||
@@ -262,7 +264,7 @@ class TestSerializeDashboardObject:
|
||||
assert not hasattr(result.charts[0], "owners")
|
||||
|
||||
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_restricted_user_redacts_chart_datasource_name(
|
||||
self,
|
||||
mock_base_url,
|
||||
@@ -288,7 +290,7 @@ class TestSerializeDashboardObject:
|
||||
assert result.charts[0].url == "http://localhost:8088/explore/?slice_id=5"
|
||||
|
||||
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_dashboard_serializer_restricted_user_redacts_data_model_metadata(
|
||||
self,
|
||||
mock_base_url,
|
||||
@@ -327,7 +329,7 @@ class TestSerializeDashboardObject:
|
||||
assert result.native_filters[0].targets == []
|
||||
|
||||
@patch("superset.mcp_service.dashboard.schemas.user_can_view_data_model_metadata")
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_descriptive_fields_are_sanitized(
|
||||
self,
|
||||
mock_base_url: MagicMock,
|
||||
@@ -393,22 +395,22 @@ class TestSerializeDashboardObject:
|
||||
class TestExtractNativeFilters:
|
||||
"""Tests for _extract_native_filters helper."""
|
||||
|
||||
def test_none_input(self):
|
||||
def test_none_input(self) -> None:
|
||||
assert _extract_native_filters(None) == []
|
||||
|
||||
def test_empty_string(self):
|
||||
def test_empty_string(self) -> None:
|
||||
assert _extract_native_filters("") == []
|
||||
|
||||
def test_invalid_json(self):
|
||||
def test_invalid_json(self) -> None:
|
||||
assert _extract_native_filters("not json") == []
|
||||
|
||||
def test_no_filter_config(self):
|
||||
def test_no_filter_config(self) -> None:
|
||||
assert _extract_native_filters("{}") == []
|
||||
|
||||
def test_non_list_filter_config(self):
|
||||
def test_non_list_filter_config(self) -> None:
|
||||
assert _extract_native_filters('{"native_filter_configuration": "bad"}') == []
|
||||
|
||||
def test_valid_filters(self):
|
||||
def test_valid_filters(self) -> None:
|
||||
metadata = json_dumps(
|
||||
{
|
||||
"native_filter_configuration": [
|
||||
@@ -428,7 +430,7 @@ class TestExtractNativeFilters:
|
||||
assert result[0].filter_type == "filter_select"
|
||||
assert result[0].targets == []
|
||||
|
||||
def test_valid_filters_include_targets_when_metadata_allowed(self):
|
||||
def test_valid_filters_include_targets_when_metadata_allowed(self) -> None:
|
||||
metadata = json_dumps(
|
||||
{
|
||||
"native_filter_configuration": [
|
||||
@@ -447,14 +449,14 @@ class TestExtractNativeFilters:
|
||||
)
|
||||
assert result[0].targets == [{"column": {"name": "col1"}}]
|
||||
|
||||
def test_skips_non_dict_entries(self):
|
||||
def test_skips_non_dict_entries(self) -> None:
|
||||
metadata = json_dumps(
|
||||
{"native_filter_configuration": [{"id": "f1", "name": "ok"}, "bad", 123]}
|
||||
)
|
||||
result = _extract_native_filters(metadata)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_non_dict_top_level_json(self):
|
||||
def test_non_dict_top_level_json(self) -> None:
|
||||
"""json_metadata that parses to a list/number should return empty."""
|
||||
assert _extract_native_filters("[]") == []
|
||||
assert _extract_native_filters("123") == []
|
||||
@@ -464,26 +466,26 @@ class TestExtractNativeFilters:
|
||||
class TestExtractCrossFiltersEnabled:
|
||||
"""Tests for _extract_cross_filters_enabled helper."""
|
||||
|
||||
def test_none_input(self):
|
||||
def test_none_input(self) -> None:
|
||||
assert _extract_cross_filters_enabled(None) is None
|
||||
|
||||
def test_empty_json(self):
|
||||
def test_empty_json(self) -> None:
|
||||
assert _extract_cross_filters_enabled("{}") is None
|
||||
|
||||
def test_true(self):
|
||||
def test_true(self) -> None:
|
||||
assert _extract_cross_filters_enabled('{"cross_filters_enabled": true}') is True
|
||||
|
||||
def test_false(self):
|
||||
def test_false(self) -> None:
|
||||
assert (
|
||||
_extract_cross_filters_enabled('{"cross_filters_enabled": false}') is False
|
||||
)
|
||||
|
||||
def test_non_bool_value(self):
|
||||
def test_non_bool_value(self) -> None:
|
||||
assert (
|
||||
_extract_cross_filters_enabled('{"cross_filters_enabled": "yes"}') is None
|
||||
)
|
||||
|
||||
def test_non_dict_top_level_json(self):
|
||||
def test_non_dict_top_level_json(self) -> None:
|
||||
"""json_metadata that parses to a list/number should return None."""
|
||||
assert _extract_cross_filters_enabled("[]") is None
|
||||
assert _extract_cross_filters_enabled("123") is None
|
||||
@@ -493,7 +495,7 @@ class TestExtractCrossFiltersEnabled:
|
||||
class TestSerializeChartSummary:
|
||||
"""Tests for serialize_chart_summary helper."""
|
||||
|
||||
def test_datasource_name_redacted_by_default(self):
|
||||
def test_datasource_name_redacted_by_default(self) -> None:
|
||||
chart = MagicMock()
|
||||
chart.id = 5
|
||||
chart.slice_name = "Revenue Chart"
|
||||
@@ -510,7 +512,7 @@ class TestSerializeChartSummary:
|
||||
class TestOmittedFieldsBuilder:
|
||||
"""Tests for the shared OmittedFieldsBuilder utility."""
|
||||
|
||||
def test_builder_basic(self):
|
||||
def test_builder_basic(self) -> None:
|
||||
from superset.mcp_service.utils.response_utils import OmittedFieldsBuilder
|
||||
|
||||
result = (
|
||||
@@ -525,7 +527,7 @@ class TestOmittedFieldsBuilder:
|
||||
assert "meta_field" in result
|
||||
assert "extracted" in result["meta_field"]
|
||||
|
||||
def test_builder_none_values(self):
|
||||
def test_builder_none_values(self) -> None:
|
||||
from superset.mcp_service.utils.response_utils import OmittedFieldsBuilder
|
||||
|
||||
result = (
|
||||
@@ -537,8 +539,8 @@ class TestOmittedFieldsBuilder:
|
||||
assert "empty" in result["empty_field"]
|
||||
assert "empty" in result["also_empty"]
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_omitted_fields_in_serialized_dashboard(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_omitted_fields_in_serialized_dashboard(self, mock_base_url) -> None:
|
||||
"""omitted_fields should describe what was stripped and include sizes."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
|
||||
@@ -555,8 +557,8 @@ class TestOmittedFieldsBuilder:
|
||||
assert "extracted" in result.omitted_fields["json_metadata"]
|
||||
assert "layout tree" in result.omitted_fields["position_json"].lower()
|
||||
|
||||
@patch("superset.mcp_service.utils.url_utils.get_superset_base_url")
|
||||
def test_omitted_fields_with_none_values(self, mock_base_url):
|
||||
@patch("superset.mcp_service.dashboard.schemas.get_superset_base_url")
|
||||
def test_omitted_fields_with_none_values(self, mock_base_url) -> None:
|
||||
"""omitted_fields should still be present when raw fields are None."""
|
||||
mock_base_url.return_value = "http://localhost:8088"
|
||||
|
||||
@@ -625,3 +627,212 @@ class TestGenerateDashboardRequestTitleSanitization:
|
||||
assert len(req.sanitization_warnings) == 1
|
||||
assert "dashboard_title" in req.sanitization_warnings[0]
|
||||
assert "injected" not in req.sanitization_warnings[0]
|
||||
|
||||
|
||||
class TestGenerateDashboardRequestLayoutTheme:
|
||||
"""generate_dashboard accepts optional position_json, theme overrides, CSS."""
|
||||
|
||||
def test_layout_theme_css_fields_default_to_none(self) -> None:
|
||||
req = GenerateDashboardRequest(chart_ids=[1])
|
||||
assert req.position_json is None
|
||||
assert req.json_metadata_overrides is None
|
||||
assert req.css is None
|
||||
assert req.slug is None
|
||||
|
||||
def test_position_json_accepted(self) -> None:
|
||||
position = {
|
||||
"ROOT_ID": {"children": ["GRID_ID"], "type": "ROOT"},
|
||||
"GRID_ID": {"children": ["ROW-1"], "type": "GRID"},
|
||||
}
|
||||
req = GenerateDashboardRequest(chart_ids=[1], position_json=position)
|
||||
assert req.position_json == position
|
||||
|
||||
def test_json_metadata_overrides_accepted(self) -> None:
|
||||
overrides = {
|
||||
"label_colors": {"Electronics": "#4C78A8"},
|
||||
"cross_filters_enabled": False,
|
||||
}
|
||||
req = GenerateDashboardRequest(chart_ids=[1], json_metadata_overrides=overrides)
|
||||
assert req.json_metadata_overrides == overrides
|
||||
|
||||
def test_css_accepted(self) -> None:
|
||||
req = GenerateDashboardRequest(
|
||||
chart_ids=[1], css=".header-controls{display:none}"
|
||||
)
|
||||
assert req.css == ".header-controls{display:none}"
|
||||
|
||||
def test_slug_accepted(self) -> None:
|
||||
req = GenerateDashboardRequest(chart_ids=[1], slug="my-dashboard")
|
||||
assert req.slug == "my-dashboard"
|
||||
|
||||
def test_css_max_length_enforced(self) -> None:
|
||||
with pytest.raises(ValidationError, match="at most 50000"):
|
||||
GenerateDashboardRequest(chart_ids=[1], css="x" * 50001)
|
||||
|
||||
def test_title_alias_accepted(self) -> None:
|
||||
"""``title`` is one of the AliasChoices for ``dashboard_title``
|
||||
— JSON callers using either name resolve to the same field."""
|
||||
req = GenerateDashboardRequest(chart_ids=[1], title="Q4 Review")
|
||||
assert req.dashboard_title == "Q4 Review"
|
||||
|
||||
def test_name_alias_accepted(self) -> None:
|
||||
"""``name`` is the third AliasChoice for ``dashboard_title``
|
||||
and must resolve identically to ``title`` and ``dashboard_title``."""
|
||||
req = GenerateDashboardRequest(chart_ids=[1], name="Q4 Review")
|
||||
assert req.dashboard_title == "Q4 Review"
|
||||
|
||||
def test_published_defaults_to_false(self) -> None:
|
||||
"""``published`` defaults to False — newly generated dashboards
|
||||
are drafts by default to avoid accidentally publishing partial
|
||||
work-in-progress to all users."""
|
||||
req = GenerateDashboardRequest(chart_ids=[1])
|
||||
assert req.published is False
|
||||
|
||||
def test_published_true_accepted(self) -> None:
|
||||
"""An explicit ``published=True`` is preserved."""
|
||||
req = GenerateDashboardRequest(chart_ids=[1], published=True)
|
||||
assert req.published is True
|
||||
|
||||
|
||||
class TestUpdateDashboardRequest:
|
||||
"""Schema validation for update_dashboard's request."""
|
||||
|
||||
def test_identifier_required(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Field required"):
|
||||
UpdateDashboardRequest()
|
||||
|
||||
def test_int_identifier_accepted(self) -> None:
|
||||
req = UpdateDashboardRequest(identifier=42)
|
||||
assert req.identifier == 42
|
||||
|
||||
def test_string_identifier_accepted(self) -> None:
|
||||
req = UpdateDashboardRequest(identifier="my-slug")
|
||||
assert req.identifier == "my-slug"
|
||||
|
||||
def test_all_optional_fields_default_to_none(self) -> None:
|
||||
req = UpdateDashboardRequest(identifier=1)
|
||||
assert req.dashboard_title is None
|
||||
assert req.description is None
|
||||
assert req.slug is None
|
||||
assert req.published is None
|
||||
assert req.position_json is None
|
||||
assert req.json_metadata_overrides is None
|
||||
assert req.css is None
|
||||
|
||||
def test_position_json_and_overrides_and_css(self) -> None:
|
||||
req = UpdateDashboardRequest(
|
||||
identifier=42,
|
||||
position_json={"ROOT_ID": {"type": "ROOT"}},
|
||||
json_metadata_overrides={"cross_filters_enabled": True},
|
||||
css=".x{}",
|
||||
)
|
||||
assert req.position_json == {"ROOT_ID": {"type": "ROOT"}}
|
||||
assert req.json_metadata_overrides == {"cross_filters_enabled": True}
|
||||
assert req.css == ".x{}"
|
||||
|
||||
def test_title_alias_accepted(self) -> None:
|
||||
"""`title` is accepted as an alias for `dashboard_title`."""
|
||||
req = UpdateDashboardRequest(identifier=1, title="New Title")
|
||||
assert req.dashboard_title == "New Title"
|
||||
|
||||
def test_name_alias_accepted(self) -> None:
|
||||
"""`name` is accepted as an alias for `dashboard_title` — mirrors
|
||||
``GenerateDashboardRequest``'s third AliasChoice so callers can
|
||||
use the same key name on both create and update paths."""
|
||||
req = UpdateDashboardRequest(identifier=1, name="New Title")
|
||||
assert req.dashboard_title == "New Title"
|
||||
|
||||
def test_css_max_length_enforced(self) -> None:
|
||||
with pytest.raises(ValidationError, match="at most 50000"):
|
||||
UpdateDashboardRequest(identifier=1, css="x" * 50001)
|
||||
|
||||
def test_title_partial_strip_emits_warning(self) -> None:
|
||||
"""Mirror of ``test_title_partial_strip_emits_warning`` on the
|
||||
create path — sanitization removes the HTML, the title survives,
|
||||
and a warning records that the input was altered."""
|
||||
req = UpdateDashboardRequest(identifier=1, dashboard_title="Q1 <b>Review</b>")
|
||||
assert req.dashboard_title == "Q1 Review"
|
||||
assert len(req.sanitization_warnings) == 1
|
||||
assert "dashboard_title" in req.sanitization_warnings[0]
|
||||
|
||||
def test_client_supplied_warnings_are_discarded(self) -> None:
|
||||
"""``sanitization_warnings`` is server-only on the update path
|
||||
too — caller-supplied entries are dropped so an attacker cannot
|
||||
smuggle warning text through the response."""
|
||||
req = UpdateDashboardRequest(
|
||||
identifier=1,
|
||||
dashboard_title="Clean Title",
|
||||
sanitization_warnings=["<script>injected</script>"],
|
||||
)
|
||||
assert req.sanitization_warnings == []
|
||||
|
||||
def test_title_xss_only_rejected_at_schema_level(self) -> None:
|
||||
"""An XSS-only title is rejected by the Pydantic validator
|
||||
before the tool ever runs — matches the create path's guard."""
|
||||
with pytest.raises(ValidationError, match="removed entirely"):
|
||||
UpdateDashboardRequest(
|
||||
identifier=1,
|
||||
dashboard_title="<script>alert(1)</script>",
|
||||
)
|
||||
|
||||
def test_title_alias_xss_rejected(self) -> None:
|
||||
"""The ``title`` alias resolves to the same sanitized field, so
|
||||
XSS-only input supplied via the alias must be rejected with the
|
||||
same guard. Otherwise an attacker could bypass sanitization just
|
||||
by choosing a different request key."""
|
||||
with pytest.raises(ValidationError, match="removed entirely"):
|
||||
UpdateDashboardRequest(
|
||||
identifier=1,
|
||||
title="<script>alert(1)</script>",
|
||||
)
|
||||
|
||||
def test_name_alias_xss_rejected(self) -> None:
|
||||
"""Same as ``test_title_alias_xss_rejected`` for the ``name``
|
||||
AliasChoice — every alias funnels through the same validator,
|
||||
not just the canonical field name."""
|
||||
with pytest.raises(ValidationError, match="removed entirely"):
|
||||
UpdateDashboardRequest(
|
||||
identifier=1,
|
||||
name="<script>alert(1)</script>",
|
||||
)
|
||||
|
||||
|
||||
class TestSafeUserLabel:
|
||||
"""``_safe_user_label`` defensively coerces ``*_by_name`` attributes
|
||||
so dashboard serialization never leaks a ``repr(user)`` or trips
|
||||
Pydantic with a non-string value."""
|
||||
|
||||
def test_plain_string_passes_through(self) -> None:
|
||||
assert _safe_user_label("alice") == "alice"
|
||||
|
||||
def test_empty_string_returns_none(self) -> None:
|
||||
"""Empty string is collapsed to None so the response carries
|
||||
an explicit "no author" signal rather than a misleading ""."""
|
||||
assert _safe_user_label("") is None
|
||||
|
||||
def test_none_returns_none(self) -> None:
|
||||
assert _safe_user_label(None) is None
|
||||
|
||||
def test_mock_object_returns_none(self) -> None:
|
||||
"""Mocks (and anything else non-string) become None — this is
|
||||
the case the helper was specifically introduced to handle."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
assert _safe_user_label(MagicMock()) is None
|
||||
|
||||
def test_user_object_returns_none(self) -> None:
|
||||
"""A User instance also coerces to None rather than leaking
|
||||
``repr(user)`` (which can contain memory addresses, hashes, or
|
||||
internal id fields). Callers that want a user display name
|
||||
should resolve it explicitly via ``created_by_name`` upstream."""
|
||||
|
||||
class _User:
|
||||
def __repr__(self) -> str:
|
||||
return "<User id=42 username='alice'>"
|
||||
|
||||
assert _safe_user_label(_User()) is None
|
||||
|
||||
def test_integer_returns_none(self) -> None:
|
||||
"""Numbers and other non-string scalars also coerce to None
|
||||
rather than being str-cast and silently leaking the value."""
|
||||
assert _safe_user_label(42) is None
|
||||
|
||||
@@ -206,8 +206,10 @@ class TestGenerateDashboard:
|
||||
== "Analytics Dashboard"
|
||||
)
|
||||
assert result.structured_content["dashboard"]["chart_count"] == 2
|
||||
# URL prefers the slug over the id, matching update_dashboard.
|
||||
assert (
|
||||
"/superset/dashboard/10/" in result.structured_content["dashboard_url"]
|
||||
"/superset/dashboard/test-dashboard-10/"
|
||||
in result.structured_content["dashboard_url"]
|
||||
)
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@@ -461,6 +463,117 @@ class TestGenerateDashboard:
|
||||
# rollback called by tool + event_logger error handling
|
||||
assert mock_db_session.rollback.call_count >= 1
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_duplicate_slug_returns_actionable_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_dashboard_cls,
|
||||
mcp_server,
|
||||
) -> None:
|
||||
"""When the supplied slug collides with an existing dashboard,
|
||||
``commit()`` raises ``IntegrityError``. The tool catches it,
|
||||
recognises the slug-uniqueness violation, and returns a
|
||||
structured error naming the offending slug so the LLM can
|
||||
propose a different one — instead of the generic
|
||||
"internal error" message used for other DB failures."""
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_query.filter_by.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=1)]
|
||||
mock_filter.first.return_value = Mock(
|
||||
id=1,
|
||||
username="admin",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
email="admin@example.com",
|
||||
active=True,
|
||||
)
|
||||
mock_db_session.query.return_value = mock_query
|
||||
# Postgres-shaped IntegrityError naming the slug constraint.
|
||||
mock_db_session.commit.side_effect = IntegrityError(
|
||||
statement="INSERT INTO dashboards ...",
|
||||
params={},
|
||||
orig=Exception(
|
||||
"duplicate key value violates unique constraint "
|
||||
'"dashboards_slug_key"\n'
|
||||
"DETAIL: Key (slug)=(my-slug) already exists."
|
||||
),
|
||||
)
|
||||
mock_dashboard_cls.return_value = _mock_dashboard(id=100)
|
||||
|
||||
request = {
|
||||
"chart_ids": [1],
|
||||
"dashboard_title": "Q4 Review",
|
||||
"slug": "my-slug",
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
err = result.structured_content["error"]
|
||||
assert err is not None
|
||||
assert "my-slug" in err
|
||||
assert "already in use" in err
|
||||
assert "different slug" in err or "Choose a different" in err
|
||||
assert result.structured_content["dashboard"] is None
|
||||
assert mock_db_session.rollback.call_count >= 1
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_integrity_error_unrelated_to_slug(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_dashboard_cls,
|
||||
mcp_server,
|
||||
) -> None:
|
||||
"""An IntegrityError that is NOT about the slug (e.g. an FK
|
||||
violation on chart_id) gets the generic constraint message
|
||||
rather than the slug-specific one — slug detection must not
|
||||
match every IntegrityError indiscriminately."""
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_query.filter_by.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=1)]
|
||||
mock_filter.first.return_value = Mock(
|
||||
id=1,
|
||||
username="admin",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
email="admin@example.com",
|
||||
active=True,
|
||||
)
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.commit.side_effect = IntegrityError(
|
||||
statement="INSERT INTO dashboard_slices ...",
|
||||
params={},
|
||||
orig=Exception(
|
||||
'violates foreign key constraint "dashboard_slices_slice_id_fkey"'
|
||||
),
|
||||
)
|
||||
mock_dashboard_cls.return_value = _mock_dashboard(id=101)
|
||||
|
||||
# No slug → the slug-detection branch must not match.
|
||||
request = {"chart_ids": [1], "dashboard_title": "No Slug Dashboard"}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
err = result.structured_content["error"]
|
||||
assert err is not None
|
||||
assert "constraint" in err
|
||||
assert "slug" not in err.lower()
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@@ -575,6 +688,160 @@ class TestGenerateDashboard:
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.dashboard_title == ""
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_position_json_override(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
mcp_server,
|
||||
) -> None:
|
||||
"""An explicit ``position_json`` replaces the auto-generated layout
|
||||
in full — the tool serializes the caller's dict verbatim into the
|
||||
dashboard's ``position_json`` column rather than calling the
|
||||
layout-builder helper."""
|
||||
from superset.utils import json
|
||||
|
||||
charts = [_mock_chart(id=1, slice_name="Sales")]
|
||||
mock_dashboard = _mock_dashboard(id=70, title="Custom Layout Dashboard")
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
charts,
|
||||
mock_dashboard,
|
||||
)
|
||||
|
||||
custom_layout = {
|
||||
"ROOT_ID": {"type": "ROOT", "children": ["GRID_ID"]},
|
||||
"GRID_ID": {"type": "GRID", "children": ["ROW-custom"]},
|
||||
"ROW-custom": {
|
||||
"type": "ROW",
|
||||
"children": ["CHART-1"],
|
||||
"meta": {"background": "BACKGROUND_TRANSPARENT"},
|
||||
},
|
||||
"CHART-1": {
|
||||
"type": "CHART",
|
||||
"children": [],
|
||||
"meta": {"chartId": 1, "width": 12, "height": 100},
|
||||
},
|
||||
}
|
||||
request = {
|
||||
"chart_ids": [1],
|
||||
"dashboard_title": "Custom Layout Dashboard",
|
||||
"position_json": custom_layout,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
created = mock_dashboard_cls.return_value
|
||||
# position_json on the model is a JSON string; round-trip it
|
||||
# and verify the caller's layout — not the auto-generated 2-col
|
||||
# grid — was written.
|
||||
stored = json.loads(created.position_json)
|
||||
assert stored == custom_layout
|
||||
# The auto-generated layout's HEADER/ROW ids wouldn't match
|
||||
# `ROW-custom`; this sanity-check guards against regressions
|
||||
# where the override silently merges with the default.
|
||||
assert "ROW-custom" in stored
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_json_metadata_overrides_shallow_merged(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
mcp_server,
|
||||
) -> None:
|
||||
"""``json_metadata_overrides`` is shallow-merged on top of the
|
||||
defaults: caller-supplied keys win, defaults for keys the caller
|
||||
didn't touch survive."""
|
||||
from superset.utils import json
|
||||
|
||||
charts = [_mock_chart(id=1, slice_name="Sales")]
|
||||
mock_dashboard = _mock_dashboard(id=71, title="Themed Dashboard")
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
charts,
|
||||
mock_dashboard,
|
||||
)
|
||||
|
||||
overrides = {
|
||||
"label_colors": {"Electronics": "#4C78A8"},
|
||||
"cross_filters_enabled": True,
|
||||
}
|
||||
request = {
|
||||
"chart_ids": [1],
|
||||
"dashboard_title": "Themed Dashboard",
|
||||
"json_metadata_overrides": overrides,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
created = mock_dashboard_cls.return_value
|
||||
merged = json.loads(created.json_metadata)
|
||||
# Caller overrides land verbatim
|
||||
assert merged["label_colors"] == {"Electronics": "#4C78A8"}
|
||||
assert merged["cross_filters_enabled"] is True
|
||||
# Defaults for keys the caller did NOT supply must still be
|
||||
# present — this is what makes it a *shallow merge* rather
|
||||
# than a full replace.
|
||||
assert "timed_refresh_immune_slices" in merged
|
||||
assert "refresh_frequency" in merged
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_slug_and_css_applied(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
mcp_server,
|
||||
) -> None:
|
||||
"""Both ``slug`` and ``css`` land on the created dashboard model
|
||||
when supplied. Verifying via the model directly — not just the
|
||||
response — confirms the tool wrote the fields, not merely echoed
|
||||
them back."""
|
||||
charts = [_mock_chart(id=1, slice_name="Sales")]
|
||||
mock_dashboard = _mock_dashboard(id=72, title="Branded Dashboard")
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
charts,
|
||||
mock_dashboard,
|
||||
)
|
||||
|
||||
css_value = ".header-controls { display: none; }"
|
||||
request = {
|
||||
"chart_ids": [1],
|
||||
"dashboard_title": "Branded Dashboard",
|
||||
"slug": "branded-q4",
|
||||
"css": css_value,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.slug == "branded-q4"
|
||||
assert created.css == css_value
|
||||
|
||||
|
||||
class TestAddChartToExistingDashboard:
|
||||
"""Tests for add_chart_to_existing_dashboard MCP tool."""
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for the update_dashboard MCP tool."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests in this module."""
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
with patch("superset.security_manager.raise_for_ownership"):
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _mock_dashboard(
|
||||
id: int = 42,
|
||||
title: str = "Test Dashboard",
|
||||
slug: str | None = "test-slug",
|
||||
published: bool = True,
|
||||
css: str | None = None,
|
||||
json_metadata: str | None = None,
|
||||
position_json: str | None = None,
|
||||
):
|
||||
"""Build a Mock with EVERY field the DashboardInfo serializer touches
|
||||
explicitly set. Without this, Mock returns auto-Mock objects for
|
||||
unset attributes, which Pydantic rejects as wrong-type."""
|
||||
dashboard = Mock()
|
||||
dashboard.id = id
|
||||
dashboard.dashboard_title = title
|
||||
dashboard.slug = slug
|
||||
dashboard.description = "desc"
|
||||
dashboard.published = published
|
||||
dashboard.css = css
|
||||
dashboard.json_metadata = json_metadata or json.dumps({"label_colors": {}})
|
||||
dashboard.position_json = position_json
|
||||
dashboard.certified_by = None
|
||||
dashboard.certification_details = None
|
||||
dashboard.is_managed_externally = False
|
||||
dashboard.external_url = None
|
||||
dashboard.created_on = datetime(2024, 1, 1)
|
||||
dashboard.changed_on = datetime(2024, 1, 1)
|
||||
dashboard.created_by = Mock(username="admin")
|
||||
dashboard.changed_by = Mock(username="admin")
|
||||
dashboard.created_by_name = "admin"
|
||||
dashboard.changed_by_name = "admin"
|
||||
dashboard.created_on_humanized = "a day ago"
|
||||
dashboard.changed_on_humanized = "a day ago"
|
||||
dashboard.uuid = f"dashboard-uuid-{id}"
|
||||
dashboard.slices = []
|
||||
dashboard.owners = []
|
||||
dashboard.tags = []
|
||||
return dashboard
|
||||
|
||||
|
||||
class TestUpdateDashboard:
|
||||
"""update_dashboard patches existing dashboard layout/theme/CSS."""
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
|
||||
@patch("superset.extensions.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_layout_theme_and_css(
|
||||
self, mock_session, mock_get, mcp_server
|
||||
) -> None:
|
||||
dash = _mock_dashboard(
|
||||
id=42,
|
||||
json_metadata=json.dumps(
|
||||
{"label_colors": {"Old": "#000"}, "cross_filters_enabled": True}
|
||||
),
|
||||
)
|
||||
mock_get.return_value = dash
|
||||
|
||||
position = {"ROOT_ID": {"type": "ROOT", "children": ["GRID_ID"]}}
|
||||
overrides = {
|
||||
"label_colors": {"Electronics": "#4C78A8"},
|
||||
"cross_filters_enabled": False,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"update_dashboard",
|
||||
{
|
||||
"request": {
|
||||
"identifier": 42,
|
||||
"position_json": position,
|
||||
"json_metadata_overrides": overrides,
|
||||
"css": ".x{color:red}",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# All three top-level writes applied to the model
|
||||
assert dash.position_json == json.dumps(position)
|
||||
assert dash.css == ".x{color:red}"
|
||||
# json_metadata is shallow-merged: label_colors REPLACED (top-level
|
||||
# key), but other keys not in overrides preserved
|
||||
merged = json.loads(dash.json_metadata)
|
||||
assert merged["label_colors"] == {"Electronics": "#4C78A8"}
|
||||
assert merged["cross_filters_enabled"] is False
|
||||
assert mock_session.commit.call_count >= 1
|
||||
# changed_fields enumerates what actually changed.
|
||||
# StructuredContentStripperMiddleware strips structured_content;
|
||||
# the JSON-encoded response lives in content[0].text.
|
||||
payload = json.loads(result.content[0].text)
|
||||
changed = set(payload.get("changed_fields") or [])
|
||||
assert {"position_json", "json_metadata", "css"} <= changed
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
|
||||
@patch("superset.extensions.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_no_fields_is_noop(
|
||||
self, mock_session, mock_get, mcp_server
|
||||
) -> None:
|
||||
dash = _mock_dashboard(id=42)
|
||||
original_css = dash.css
|
||||
original_title = dash.dashboard_title
|
||||
mock_get.return_value = dash
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"update_dashboard", {"request": {"identifier": 42}}
|
||||
)
|
||||
|
||||
# Nothing modified — dashboard fields unchanged, warning emitted
|
||||
assert dash.css == original_css
|
||||
assert dash.dashboard_title == original_title
|
||||
# StructuredContentStripperMiddleware strips structured_content;
|
||||
# the JSON-encoded response lives in content[0].text.
|
||||
payload = json.loads(result.content[0].text)
|
||||
warnings = payload.get("warnings") or []
|
||||
assert any("No fields provided" in w for w in warnings)
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_missing_dashboard_returns_error(
|
||||
self, mock_get, mcp_server
|
||||
) -> None:
|
||||
# get_by_id_or_slug raises when not found; the tool catches and
|
||||
# returns a structured DashboardError
|
||||
from superset.commands.dashboard.exceptions import (
|
||||
DashboardNotFoundError,
|
||||
)
|
||||
|
||||
mock_get.side_effect = DashboardNotFoundError()
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"update_dashboard", {"request": {"identifier": 999999}}
|
||||
)
|
||||
|
||||
payload = json.loads(result.content[0].text)
|
||||
assert "not found" in (payload.get("error") or "").lower()
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
|
||||
@patch("superset.extensions.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_title_and_slug_and_published(
|
||||
self, mock_session, mock_get, mcp_server
|
||||
) -> None:
|
||||
dash = _mock_dashboard(id=42, published=False)
|
||||
mock_get.return_value = dash
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"update_dashboard",
|
||||
{
|
||||
"request": {
|
||||
"identifier": 42,
|
||||
"dashboard_title": "Renamed",
|
||||
"slug": "renamed-slug",
|
||||
"published": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert dash.dashboard_title == "Renamed"
|
||||
assert dash.slug == "renamed-slug"
|
||||
assert dash.published is True
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
|
||||
@patch("superset.extensions.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_description(self, mock_session, mock_get, mcp_server) -> None:
|
||||
"""A description-only update writes ``description`` and reports
|
||||
it in ``changed_fields`` without touching other fields."""
|
||||
dash = _mock_dashboard(id=42)
|
||||
original_title = dash.dashboard_title
|
||||
original_slug = dash.slug
|
||||
mock_get.return_value = dash
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"update_dashboard",
|
||||
{
|
||||
"request": {
|
||||
"identifier": 42,
|
||||
"description": "Q4 executive review — refreshed weekly.",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert dash.description == "Q4 executive review — refreshed weekly."
|
||||
# Other fields untouched
|
||||
assert dash.dashboard_title == original_title
|
||||
assert dash.slug == original_slug
|
||||
payload = json.loads(result.content[0].text)
|
||||
assert "description" in payload.get("changed_fields", [])
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
|
||||
@patch("superset.extensions.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_slug_clears_slug(
|
||||
self, mock_session, mock_get, mcp_server
|
||||
) -> None:
|
||||
"""An explicit empty string clears the slug."""
|
||||
dash = _mock_dashboard(id=42, slug="had-a-slug")
|
||||
mock_get.return_value = dash
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"update_dashboard",
|
||||
{"request": {"identifier": 42, "slug": ""}},
|
||||
)
|
||||
|
||||
assert dash.slug is None
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_by_id_or_slug")
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_owner_gets_permission_denied(self, mock_get, mcp_server) -> None:
|
||||
"""A user without ownership on the dashboard receives a
|
||||
permission_denied response — the class-level Dashboard.write
|
||||
permission is not enough on its own.
|
||||
"""
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
dash = _mock_dashboard(id=42)
|
||||
mock_get.return_value = dash
|
||||
|
||||
# mock_auth fixture patches raise_for_ownership to a no-op for the
|
||||
# whole module; override here so this one test sees the real
|
||||
# ownership rejection path.
|
||||
with patch(
|
||||
"superset.security_manager.raise_for_ownership",
|
||||
side_effect=SupersetSecurityException(Mock(message="forbidden")),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"update_dashboard",
|
||||
{
|
||||
"request": {
|
||||
"identifier": 42,
|
||||
"dashboard_title": "Hostile rename",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
payload = json.loads(result.content[0].text)
|
||||
assert payload.get("permission_denied") is True
|
||||
assert "permission" in (payload.get("error") or "").lower()
|
||||
# Title was NOT applied
|
||||
assert dash.dashboard_title == "Test Dashboard"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_xss_only_title_is_rejected(self, mcp_server) -> None:
|
||||
"""A dashboard_title that sanitizes to an empty string raises at
|
||||
the Pydantic layer — same guard as ``generate_dashboard``. The
|
||||
update path must not be a backdoor for XSS payloads."""
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="dashboard_title"):
|
||||
await client.call_tool(
|
||||
"update_dashboard",
|
||||
{
|
||||
"request": {
|
||||
"identifier": 42,
|
||||
"dashboard_title": "<script>alert(1)</script>",
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -59,6 +59,7 @@ def test_default_query_object_to_dict():
|
||||
"series_limit": 0,
|
||||
"series_limit_metric": None,
|
||||
"time_shift": None,
|
||||
"time_compare_full_range": False,
|
||||
"to_dttm": None,
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,9 @@ from superset.security.api import RlsRuleSchema
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"app",
|
||||
[{"WTF_CSRF_ENABLED": True}],
|
||||
# Enable the Swagger UI / OpenAPI spec (opt-in, off by default) so the
|
||||
# OpenApi blueprint is registered and included in the exempt set below.
|
||||
[{"WTF_CSRF_ENABLED": True, "FAB_API_SWAGGER_UI": True}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_csrf_exempt_blueprints(app_context: None) -> None:
|
||||
|
||||
Reference in New Issue
Block a user