Mirror of https://github.com/apache/superset.git
Synced 2026-05-16 05:15:16 +00:00

Compare commits: fix/mcp-ex ... mcp-chart- (22 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | f5ba09f7af |  |
|  | 9d298f86f0 |  |
|  | b34d346e0d |  |
|  | 18a9eff641 |  |
|  | eface3bf54 |  |
|  | b41a53bc8f |  |
|  | f1b95d6ae3 |  |
|  | 1e2b541600 |  |
|  | b16de3622f |  |
|  | 5e02d0ec65 |  |
|  | 139eea92f6 |  |
|  | b09cbc80aa |  |
|  | e7adf0c670 |  |
|  | ad5e3170dd |  |
|  | aa710672ed |  |
|  | 8c80caefa3 |  |
|  | 8088c5d1de |  |
|  | 9b520312a1 |  |
|  | 9ac4711ac8 |  |
|  | 7593d2a164 |  |
|  | d3c44e311e |  |
|  | b5186d1c65 |  |

.github/workflows/ephemeral-env-pr-close.yml (vendored, 2 lines changed)
@@ -58,7 +58,7 @@ jobs:
- name: Login to Amazon ECR
if: steps.describe-services.outputs.active == 'true'
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2

- name: Delete ECR image tag
if: steps.describe-services.outputs.active == 'true'

.github/workflows/ephemeral-env.yml (vendored, 4 lines changed)
@@ -199,7 +199,7 @@ jobs:

- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2

- name: Load, tag and push image to ECR
id: push-image
@@ -235,7 +235,7 @@

- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2

- name: Check target image exists in ECR
id: check-image
@@ -70,7 +70,7 @@
"@swc/core": "^1.15.33",
"antd": "^6.3.7",
"baseline-browser-mapping": "^2.10.27",
"caniuse-lite": "^1.0.30001791",
"caniuse-lite": "^1.0.30001792",
"docusaurus-plugin-openapi-docs": "^5.0.2",
"docusaurus-theme-openapi-docs": "^5.0.2",
"js-yaml": "^4.1.1",

@@ -6035,10 +6035,10 @@ caniuse-api@^3.0.0:
lodash.memoize "^4.1.2"
lodash.uniq "^4.5.0"

caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001791:
version "1.0.30001791"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001791.tgz#dfb93d85c40ad380c57123e72e10f3c575786b51"
integrity sha512-yk0l/YSrOnFZk3UROpDLQD9+kC1l4meK/wed583AXrzoarMGJcbRi2Q4RaUYbKxYAsZ8sWmaSa/DsLmdBeI1vQ==
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001792:
version "1.0.30001792"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001792.tgz#ca8bb9be244835a335e2018272ce7223691873c5"
integrity sha512-hVLMUZFgR4JJ6ACt1uEESvQN1/dBVqPAKY0hgrV70eN3391K6juAfTjKZLKvOMsx8PxA7gsY1/tLMMTcfFLLpw==

ccount@^2.0.0:
version "2.0.1"
@@ -142,10 +142,16 @@ druid = ["pydruid>=0.6.5,<0.7"]
duckdb = ["duckdb>=1.4.2,<2", "duckdb-engine>=0.17.0"]
dynamodb = ["pydynamodb>=0.4.2"]
solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.13, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
fastmcp = ["fastmcp>=3.2.4,<4.0"]
fastmcp = [
"fastmcp>=3.2.4,<4.0",
# tiktoken backs the response-size-guard token estimator. Without
# it, the middleware falls back to a coarser character-based
# heuristic that under-counts JSON-heavy MCP responses.
"tiktoken>=0.7.0,<1.0",
]
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
gevent = ["gevent>=23.9.1"]
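The dependency comment above explains why tiktoken ships with the fastmcp extra. As a rough illustration of the described behaviour (not the actual response-size-guard middleware, which is not part of this diff), a token estimator with a character-based fallback might look like the sketch below; the function name, the encoding choice, and the 4-chars-per-token ratio are assumptions.

```python
# Hypothetical sketch of the fallback described in the pyproject comment above;
# not the actual middleware code from this branch.
import json
from typing import Any

try:
    import tiktoken  # optional: installed via the "fastmcp" extra
    _ENCODING = tiktoken.get_encoding("cl100k_base")  # encoding choice is an assumption
except ImportError:  # pragma: no cover
    _ENCODING = None


def estimate_tokens(payload: Any) -> int:
    """Estimate the token count of an MCP response payload."""
    text = payload if isinstance(payload, str) else json.dumps(payload, default=str)
    if _ENCODING is not None:
        return len(_ENCODING.encode(text))
    # Coarser character-based heuristic used when tiktoken is unavailable; the
    # 4-chars-per-token ratio is illustrative and tends to under-count
    # JSON-heavy responses, which is exactly why the extra pins tiktoken.
    return max(1, len(text) // 4)
```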
@@ -183,7 +183,9 @@ idna==3.10
# trio
# url-normalize
isodate==0.7.2
# via apache-superset (pyproject.toml)
# via
# apache-superset (pyproject.toml)
# apache-superset-core
itsdangerous==2.2.0
# via
# flask
@@ -296,6 +298,7 @@ pyarrow==20.0.0
# via
# -r requirements/base.in
# apache-superset (pyproject.toml)
# apache-superset-core
pyasn1==0.6.3
# via
# pyasn1-modules

@@ -442,6 +442,7 @@ isodate==0.7.2
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
isort==6.0.1
# via pylint
itsdangerous==2.2.0
@@ -715,6 +716,7 @@ pyarrow==20.0.0
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
# db-dtypes
# pandas-gbq
pyasn1==0.6.3
@@ -866,6 +868,8 @@ referencing==0.36.2
# jsonschema
# jsonschema-path
# jsonschema-specifications
regex==2026.4.4
# via tiktoken
requests==2.33.0
# via
# -c requirements/base-constraint.txt
@@ -878,6 +882,7 @@ requests==2.33.0
# requests-cache
# requests-oauthlib
# shillelagh
# tiktoken
# trino
requests-cache==1.2.1
# via
@@ -1003,6 +1008,8 @@ tabulate==0.9.0
# via
# -c requirements/base-constraint.txt
# apache-superset
tiktoken==0.12.0
# via apache-superset
tomli-w==1.2.0
# via apache-superset-extensions-cli
tomlkit==0.13.3

@@ -17,7 +17,8 @@
* under the License.
*/
import { render, screen, act } from 'spec/helpers/testing-library';
import { StatusIndicatorDot } from './StatusIndicatorDot';
import { supersetTheme } from '@apache-superset/core/theme';
import { getStatusConfig, StatusIndicatorDot } from './StatusIndicatorDot';
import { AutoRefreshStatus } from '../../types/autoRefresh';

afterEach(() => {
@@ -62,6 +63,15 @@ test('renders with paused status', () => {
expect(dot).toHaveAttribute('data-status', AutoRefreshStatus.Paused);
});

test('uses the icon color for the paused status outline', () => {
expect(
getStatusConfig(supersetTheme, AutoRefreshStatus.Paused),
).toMatchObject({
needsBorder: true,
outlineColor: 'currentColor',
});
});

test('has correct accessibility attributes', () => {
render(<StatusIndicatorDot status={AutoRefreshStatus.Success} />);
const dot = screen.getByTestId('status-indicator-dot');

@@ -39,9 +39,10 @@ export interface StatusIndicatorDotProps {
interface StatusConfig {
color: string;
needsBorder: boolean;
outlineColor?: string;
}

const getStatusConfig = (
export const getStatusConfig = (
theme: ReturnType<typeof useTheme>,
status: AutoRefreshStatus,
): StatusConfig => {
@@ -75,6 +76,7 @@ const getStatusConfig = (
return {
color: theme.colorBgContainer,
needsBorder: true,
outlineColor: 'currentColor',
};
default:
return {
@@ -136,13 +138,15 @@ export const StatusIndicatorDot: FC<StatusIndicatorDotProps> = ({
width: ${size}px;
height: ${size}px;
border-radius: 50%;
color: ${theme.colorTextSecondary};
background-color: ${statusConfig.color};
transition:
background-color ${theme.motionDurationMid} ease-in-out,
border-color ${theme.motionDurationMid} ease-in-out;
border: ${statusConfig.needsBorder
? `1px solid ${theme.colorBorder}`
: 'none'};
border: ${statusConfig.needsBorder ? '1px solid' : 'none'};
border-color: ${statusConfig.needsBorder
? statusConfig.outlineColor
: 'transparent'};
box-shadow: ${statusConfig.needsBorder
? 'none'
: `0 0 0 2px ${theme.colorBgContainer}`};
@@ -21,6 +21,10 @@ import { VizType } from '@superset-ui/core';
|
||||
import { hydrateExplore, HYDRATE_EXPLORE } from './hydrateExplore';
|
||||
import { exploreInitialData } from '../fixtures';
|
||||
|
||||
afterEach(() => {
|
||||
window.history.pushState({}, '', '/');
|
||||
});
|
||||
|
||||
test('creates hydrate action from initial data', () => {
|
||||
const dispatch = jest.fn();
|
||||
const getState = jest.fn(() => ({
|
||||
@@ -168,6 +172,84 @@ test('creates hydrate action with existing state', () => {
|
||||
);
|
||||
});
|
||||
|
||||
test('hydrates sliceName from preview form data before saved slice name', () => {
|
||||
window.history.pushState({}, '', '/explore/?form_data_key=preview-key');
|
||||
|
||||
const dispatch = jest.fn();
|
||||
const getState = jest.fn(() => ({
|
||||
user: {},
|
||||
charts: {},
|
||||
datasources: {},
|
||||
common: {},
|
||||
explore: {},
|
||||
}));
|
||||
const previewSliceName = 'RENAMED - Bug Evidence';
|
||||
const savedSliceName = 'Most Populated Countries';
|
||||
const previewInitialData = {
|
||||
...exploreInitialData,
|
||||
form_data: {
|
||||
...exploreInitialData.form_data,
|
||||
slice_name: previewSliceName,
|
||||
},
|
||||
slice: {
|
||||
...exploreInitialData.slice!,
|
||||
slice_name: savedSliceName,
|
||||
},
|
||||
};
|
||||
|
||||
// @ts-expect-error we only need the fields consumed by hydrateExplore
|
||||
hydrateExplore(previewInitialData)(dispatch, getState);
|
||||
|
||||
expect(dispatch).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: HYDRATE_EXPLORE,
|
||||
data: expect.objectContaining({
|
||||
explore: expect.objectContaining({
|
||||
sliceName: previewSliceName,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('hydrates sliceName from saved slice when regular form data has stale name', () => {
|
||||
const dispatch = jest.fn();
|
||||
const getState = jest.fn(() => ({
|
||||
user: {},
|
||||
charts: {},
|
||||
datasources: {},
|
||||
common: {},
|
||||
explore: {},
|
||||
}));
|
||||
const staleFormDataSliceName = 'Stale Params Name';
|
||||
const savedSliceName = 'Current Saved Name';
|
||||
const savedChartInitialData = {
|
||||
...exploreInitialData,
|
||||
form_data: {
|
||||
...exploreInitialData.form_data,
|
||||
slice_name: staleFormDataSliceName,
|
||||
},
|
||||
slice: {
|
||||
...exploreInitialData.slice!,
|
||||
slice_name: savedSliceName,
|
||||
},
|
||||
};
|
||||
|
||||
// @ts-expect-error we only need the fields consumed by hydrateExplore
|
||||
hydrateExplore(savedChartInitialData)(dispatch, getState);
|
||||
|
||||
expect(dispatch).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: HYDRATE_EXPLORE,
|
||||
data: expect.objectContaining({
|
||||
explore: expect.objectContaining({
|
||||
sliceName: savedSliceName,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('uses configured default time range if not set', () => {
|
||||
const dispatch = jest.fn();
|
||||
const getState = jest.fn(() => ({
|
||||
|
||||

@@ -77,6 +77,12 @@ export const hydrateExplore =
const fallbackSlice = sliceId ? sliceEntities?.slices?.[sliceId] : null;
const initialSlice = slice ?? fallbackSlice;
const initialFormData = form_data ?? initialSlice?.form_data;
const isCachedFormData = getUrlParam(URL_PARAMS.formDataKey) !== null;
const [primarySliceNameSource, fallbackSliceNameSource] = isCachedFormData
? [initialFormData, initialSlice]
: [initialSlice, initialFormData];
const initialSliceName =
primarySliceNameSource?.slice_name ?? fallbackSliceNameSource?.slice_name;
if (!initialFormData.viz_type) {
const defaultVizType = common?.conf.DEFAULT_VIZ_TYPE || VizType.Table;
initialFormData.viz_type =
@@ -183,6 +189,7 @@ export const hydrateExplore =
// because `bootstrapData.controls` is undefined.
controls: initialControls,
form_data: initialFormData,
sliceName: initialSliceName,
slice: initialSlice,
controlsTransferred: explore.controlsTransferred,
standalone: getUrlParam(URL_PARAMS.standalone),
@@ -179,6 +179,33 @@ test('renders the right footer buttons', () => {
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('initializes chart name from current Explore slice name', () => {
|
||||
const previewSliceName = 'RENAMED - Bug Evidence';
|
||||
const savedSliceName = 'Most Populated Countries';
|
||||
const { getByTestId } = setup(
|
||||
{
|
||||
...defaultProps,
|
||||
form_data: {
|
||||
...defaultProps.form_data,
|
||||
slice_name: previewSliceName,
|
||||
},
|
||||
sliceName: previewSliceName,
|
||||
},
|
||||
mockStore({
|
||||
...initialState,
|
||||
explore: {
|
||||
...initialState.explore,
|
||||
slice: {
|
||||
...initialState.explore.slice,
|
||||
slice_name: savedSliceName,
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
expect(getByTestId('new-chart-name')).toHaveValue(previewSliceName);
|
||||
});
|
||||
|
||||
test('does not render a message when overriding', () => {
|
||||
const { getByRole, queryByRole } = setup();
|
||||
|
||||
|
||||

@@ -35,7 +35,6 @@ import { CheckboxChangeEvent } from '@superset-ui/core/components/Checkbox/types

import { useHistory } from 'react-router-dom';
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
import { makeUrl } from 'src/utils/pathUtils';
import Tabs from '@superset-ui/core/components/Tabs';
import {
Button,
@@ -1824,7 +1823,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
onClick={() => {
setLoading(true);
fetchAndSetDB();
redirectURL(makeUrl(`/sqllab?db=true`));
// redirectURL() delegates to history.push; React Router's basename
// already prefixes the application root, so pass a relative path.
redirectURL('/sqllab?db=true');
}}
>
{t('Query data in SQL Lab')}

@@ -24,7 +24,6 @@ import { TableTab } from 'src/views/CRUD/types';
import { t } from '@apache-superset/core/translation';
import { styled } from '@apache-superset/core/theme';
import { navigateTo } from 'src/utils/navigationUtils';
import { makeUrl } from 'src/utils/pathUtils';
import { WelcomeTable } from './types';

const EmptyContainer = styled.div`
@@ -59,7 +58,9 @@ const REDIRECTS = {
create: {
[WelcomeTable.Charts]: '/chart/add',
[WelcomeTable.Dashboards]: '/dashboard/new',
[WelcomeTable.SavedQueries]: makeUrl('/sqllab?new=true'),
// navigateTo() applies the application root internally; keep this
// relative so the prefix isn't added twice.
[WelcomeTable.SavedQueries]: '/sqllab?new=true',
},
viewAll: {
[WelcomeTable.Charts]: '/chart/list',

@@ -44,7 +44,7 @@ import {
TelemetryPixel,
} from '@superset-ui/core/components';
import type { ItemType, MenuItem } from '@superset-ui/core/components/Menu';
import { ensureAppRoot, makeUrl } from 'src/utils/pathUtils';
import { ensureAppRoot } from 'src/utils/pathUtils';
import { isEmbedded } from 'src/dashboard/util/isEmbedded';
import { findPermission } from 'src/utils/findPermission';
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
@@ -213,7 +213,10 @@ const RightMenu = ({
},
{
label: t('SQL query'),
url: makeUrl('/sqllab?new=true'),
// Keep the URL relative so isFrontendRoute() matches and Link navigates
// via React Router; the <Typography.Link> fallback applies ensureAppRoot
// exactly once for non-frontend routes.
url: '/sqllab?new=true',
icon: <Icons.SearchOutlined data-test={`menu-item-${t('SQL query')}`} />,
perm: 'can_sqllab',
view: 'Superset',
@@ -25,11 +25,20 @@ import {
|
||||
fireEvent,
|
||||
waitFor,
|
||||
} from 'spec/helpers/testing-library';
|
||||
import { MemoryRouter } from 'react-router-dom';
|
||||
import { MemoryRouter, useLocation } from 'react-router-dom';
|
||||
import { QueryParamProvider } from 'use-query-params';
|
||||
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
|
||||
import * as getBootstrapData from 'src/utils/getBootstrapData';
|
||||
import SavedQueryList from '.';
|
||||
|
||||
// Renders the current router pathname+search so tests can assert navigation.
|
||||
function LocationDisplay() {
|
||||
const location = useLocation();
|
||||
return (
|
||||
<div data-test="location-display">{`${location.pathname}${location.search}`}</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Increase default timeout
|
||||
jest.setTimeout(30000);
|
||||
|
||||
@@ -88,6 +97,7 @@ const renderList = (props = {}, storeOverrides = {}) =>
|
||||
<MemoryRouter>
|
||||
<QueryParamProvider adapter={ReactRouter5Adapter}>
|
||||
<SavedQueryList user={mockUser} {...props} />
|
||||
<LocationDisplay />
|
||||
</QueryParamProvider>
|
||||
</MemoryRouter>,
|
||||
{
|
||||
@@ -242,4 +252,39 @@ describe('SavedQueryList', () => {
|
||||
// Verify delete buttons are not shown
|
||||
expect(screen.queryByTestId('delete-action')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('"+ Query" button pushes a router-relative path (subdirectory deployment)', async () => {
|
||||
// Simulate SUPERSET_APP_ROOT=/superset. ensureAppRoot/makeUrl read
|
||||
// applicationRoot() dynamically, so mocking it here makes the buggy code
|
||||
// path (makeUrl() around history.push) produce '/superset/sqllab?new=true'
|
||||
// instead of being a no-op. React Router's <Router basename> prefixes the
|
||||
// app root on its own, so history.push MUST receive a path without the
|
||||
// app-root prefix — otherwise navigation lands at /superset/superset/sqllab
|
||||
// and shows a blank page (sc-103661).
|
||||
const applicationRootSpy = jest
|
||||
.spyOn(getBootstrapData, 'applicationRoot')
|
||||
.mockReturnValue('/superset');
|
||||
|
||||
try {
|
||||
renderList();
|
||||
|
||||
await screen.findByTestId('saved_query-list-view');
|
||||
|
||||
const queryButton = await screen.findByRole('button', {
|
||||
name: /query/i,
|
||||
});
|
||||
fireEvent.click(queryButton);
|
||||
|
||||
await waitFor(() => {
|
||||
// The MemoryRouter in renderList uses the default ('/') basename, so
|
||||
// useLocation reflects exactly what history.push received. A correct
|
||||
// router-relative push produces '/sqllab?new=true'; a buggy push that
|
||||
// re-applied the app root would produce '/superset/sqllab?new=true'.
|
||||
const location = screen.getByTestId('location-display').textContent;
|
||||
expect(location).toBe('/sqllab?new=true');
|
||||
});
|
||||
} finally {
|
||||
applicationRootSpy.mockRestore();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||

@@ -223,7 +223,9 @@ function SavedQueryList({
name: t('Query'),
buttonStyle: 'primary',
onClick: () => {
history.push(makeUrl('/sqllab?new=true'));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push('/sqllab?new=true');
},
});

@@ -245,7 +247,9 @@ function SavedQueryList({
if (openInNewWindow) {
window.open(makeUrl(`/sqllab?savedQueryId=${id}`));
} else {
history.push(makeUrl(`/sqllab?savedQueryId=${id}`));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push(`/sqllab?savedQueryId=${id}`);
}
};

@@ -338,9 +342,7 @@ function SavedQueryList({
row: {
original: { id, label },
},
}: any) => (
<Link to={makeUrl(`/sqllab?savedQueryId=${id}`)}>{label}</Link>
),
}: any) => <Link to={`/sqllab?savedQueryId=${id}`}>{label}</Link>,
id: 'label',
},
{

superset/initialization/__init__.py (17 lines changed; Normal file → Executable file)
@@ -747,6 +747,7 @@ class SupersetAppInitializer:  # pylint: disable=too-many-public-methods
# Configuration of feature_flags must be done first to allow init features
# conditionally
self.configure_feature_flags()
self.configure_mcp_chart_registry()
self.configure_db_encrypt()
self.setup_db()

@@ -821,6 +822,22 @@ class SupersetAppInitializer:  # pylint: disable=too-many-public-methods
def configure_feature_flags(self) -> None:
feature_flag_manager.init_app(self.superset_app)

def configure_mcp_chart_registry(self) -> None:
from superset.mcp_service.chart import registry
from superset.mcp_service.mcp_config import (
MCP_CHART_PLUGIN_ENABLED_FUNC,
MCP_DISABLED_CHART_PLUGINS,
)

registry.configure(
disabled=self.config.get(
"MCP_DISABLED_CHART_PLUGINS", MCP_DISABLED_CHART_PLUGINS
),
enabled_func=self.config.get(
"MCP_CHART_PLUGIN_ENABLED_FUNC", MCP_CHART_PLUGIN_ENABLED_FUNC
),
)

def configure_sqlglot_dialects(self) -> None:
extensions = self.config["SQLGLOT_DIALECTS_EXTENSIONS"]
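Because both keys are read via self.config.get(), a deployment can override them in superset_config.py. The following is a minimal sketch only: it assumes MCP_DISABLED_CHART_PLUGINS is a collection of chart_type identifiers and MCP_CHART_PLUGIN_ENABLED_FUNC is an optional callable consulted per plugin; the authoritative defaults and signatures live in superset/mcp_service/mcp_config.py, which is not shown in this diff.

```python
# superset_config.py — illustrative only; the value shapes are assumptions.

# Disable individual MCP chart-type plugins by their chart_type identifier.
MCP_DISABLED_CHART_PLUGINS = {"handlebars"}


# Optional hook consulted by the registry; assumed here to receive the
# plugin's chart_type and return whether that plugin should be enabled.
def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
    return chart_type != "handlebars"
```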

@@ -222,10 +222,12 @@ Time grain for temporal x-axis (time_grain parameter):
- PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly)

Chart Types in Existing Charts (viewable via list_charts/get_chart_info):
- pie, big_number, big_number_total, funnel, gauge_chart
- echarts_timeseries_line, echarts_timeseries_bar, echarts_timeseries_area
- pivot_table_v2, heatmap_v2, sankey_v2, sunburst_v2, treemap_v2
- word_cloud, world_map, box_plot, bubble, mixed_timeseries
Each chart returned by list_charts / get_chart_info includes a
chart_type_display_name field with a human-readable name when available.
This field is populated only for the 7 chart types supported by generate_chart
(xy, pie, table, pivot_table, big_number, mixed_timeseries, handlebars).
For all other viz_types (Funnel, Gauge, Heatmap, etc.) it will be null —
use the raw viz_type field instead when referring to those chart types.

Query Examples:
- List all tables:
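The display-name fallback described in the prompt text above is straightforward to apply on the consuming side. A minimal sketch (hypothetical helper, assuming list_charts returns plain dicts carrying the two fields named above):

```python
# Hypothetical client-side helper, not part of this branch; it simply applies
# the "use viz_type when chart_type_display_name is null" rule described above.
def chart_label(chart: dict) -> str:
    return chart.get("chart_type_display_name") or chart.get("viz_type", "unknown")
```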

@@ -503,6 +505,7 @@ warnings.filterwarnings(
# NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
# They register automatically on import, similar to tools.
import superset.mcp_service.chart.plugins  # noqa: F401, E402 — registers all chart type plugins
from superset.mcp_service.chart import (  # noqa: F401, E402
prompts as chart_prompts,
resources as chart_resources,

@@ -318,29 +318,35 @@ def map_config_to_form_data(
| BigNumberChartConfig,
dataset_id: int | str | None = None,
) -> Dict[str, Any]:
"""Map chart config to Superset form_data."""
if isinstance(config, TableChartConfig):
return map_table_config(config)
elif isinstance(config, XYChartConfig):
return map_xy_config(config, dataset_id=dataset_id)
elif isinstance(config, PieChartConfig):
return map_pie_config(config)
elif isinstance(config, PivotTableChartConfig):
return map_pivot_table_config(config)
elif isinstance(config, MixedTimeseriesChartConfig):
return map_mixed_timeseries_config(config, dataset_id=dataset_id)
elif isinstance(config, HandlebarsChartConfig):
return map_handlebars_config(config)
elif isinstance(config, BigNumberChartConfig):
if config.show_trendline and config.temporal_column:
if not is_column_truly_temporal(config.temporal_column, dataset_id):
raise ValueError(
f"Big Number trendline requires a temporal SQL column; "
f"'{config.temporal_column}' is not temporal."
)
return map_big_number_config(config)
else:
raise ValueError(f"Unsupported config type: {type(config)}")
"""Map chart config to Superset form_data via the plugin registry.

The previous if/elif chain across all 7 chart types has been replaced by a
single registry lookup. Cross-field constraints (e.g. BigNumber trendline
temporal check) are now owned by each plugin's post_map_validate() method
rather than being baked into this dispatcher.
"""
# Local import: plugins call map_*_config from their to_form_data() methods,
# so chart_utils is loaded before plugins finish registering. A top-level
# import of registry here would trigger plugin loading mid-import = cycle.
from superset.mcp_service.chart.registry import get_registry

chart_type = getattr(config, "chart_type", None)
plugin = get_registry().get(chart_type) if chart_type else None

if plugin is None:
raise ValueError(
f"Unsupported config type: {type(config)} (chart_type={chart_type!r})"
)

form_data = plugin.to_form_data(config, dataset_id=dataset_id)

# Run post-map validation (e.g. BigNumber trendline temporal type check).
# Raise ValueError to preserve backward-compatible error handling in callers.
error = plugin.post_map_validate(config, form_data, dataset_id=dataset_id)
if error is not None:
raise ValueError(error.message)

return form_data


def _add_adhoc_filters(
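To make the registry dispatch above concrete, here is a minimal sketch of a new chart-type plugin slotting into it. It reuses BaseChartPlugin and register() from elsewhere in this diff, but the "funnel" chart type, its config attributes, and the form_data fields are purely illustrative and not part of this branch, which supports only the 7 chart types listed earlier.

```python
# Illustrative only — "funnel" is not one of the chart types added by this
# branch; the config fields below are assumptions made for the sketch.
from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import register


class FunnelChartPlugin(BaseChartPlugin):
    chart_type = "funnel"
    display_name = "Funnel"
    native_viz_types = {"funnel": "Funnel Chart"}

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        # map_config_to_form_data() finds this plugin via
        # get_registry().get("funnel"), calls to_form_data(), then runs
        # post_map_validate() on the result.
        return {
            "viz_type": self.resolve_viz_type(config),
            "metric": getattr(config, "metric", None),
            "groupby": getattr(config, "group_by", []),
        }

    def resolve_viz_type(self, config: Any) -> str:
        return "funnel"


# Registration normally happens in superset/mcp_service/chart/plugins/__init__.py.
register(FunnelChartPlugin())
```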
@@ -1129,87 +1135,32 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str:
|
||||
|
||||
|
||||
def generate_chart_name(
|
||||
config: TableChartConfig
|
||||
| XYChartConfig
|
||||
| PieChartConfig
|
||||
| PivotTableChartConfig
|
||||
| MixedTimeseriesChartConfig
|
||||
| HandlebarsChartConfig
|
||||
| BigNumberChartConfig,
|
||||
config: Any,
|
||||
dataset_name: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a descriptive chart name following a standard format.
|
||||
|
||||
Format conventions (by chart type):
|
||||
Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
|
||||
Time-series (line/area, no group_by): [Metric] Over Time
|
||||
Table (no aggregates): [Dataset] Records
|
||||
Table (with aggregates): [Metric] Summary
|
||||
Pie: [Dimension] by [Metric]
|
||||
Pivot Table: Pivot Table – [Row1, Row2]
|
||||
Mixed Timeseries: [Primary] + [Secondary]
|
||||
An en-dash followed by context (filters / time grain) is appended
|
||||
Delegates to each plugin's ``generate_name()`` method.
|
||||
See each plugin's ``generate_name`` for chart-type-specific format conventions.
|
||||
An en-dash followed by context (filters / time grain) is appended by the plugin
|
||||
when such information is available.
|
||||
"""
|
||||
if isinstance(config, TableChartConfig):
|
||||
what = _table_chart_what(config, dataset_name)
|
||||
context = _summarize_filters(config.filters)
|
||||
elif isinstance(config, XYChartConfig):
|
||||
what = _xy_chart_what(config)
|
||||
context = _xy_chart_context(config)
|
||||
elif isinstance(config, PieChartConfig):
|
||||
what = _pie_chart_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
elif isinstance(config, PivotTableChartConfig):
|
||||
what = _pivot_table_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
elif isinstance(config, MixedTimeseriesChartConfig):
|
||||
what = _mixed_timeseries_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
elif isinstance(config, HandlebarsChartConfig):
|
||||
what = _handlebars_chart_what(config)
|
||||
context = _summarize_filters(getattr(config, "filters", None))
|
||||
elif isinstance(config, BigNumberChartConfig):
|
||||
what = _big_number_chart_what(config)
|
||||
context = _summarize_filters(getattr(config, "filters", None))
|
||||
else:
|
||||
return "Chart"
|
||||
from superset.mcp_service.chart.registry import get_registry
|
||||
|
||||
name = what
|
||||
if context:
|
||||
name = f"{what} \u2013 {context}"
|
||||
return _truncate(name)
|
||||
plugin = get_registry().get(getattr(config, "chart_type", ""))
|
||||
if plugin is None:
|
||||
return "Chart"
|
||||
return _truncate(plugin.generate_name(config, dataset_name))
|
||||
|
||||
|
||||
def _resolve_viz_type(config: Any) -> str:
|
||||
"""Resolve the Superset viz_type from a chart config object."""
|
||||
chart_type = getattr(config, "chart_type", "unknown")
|
||||
if chart_type == "xy":
|
||||
kind = getattr(config, "kind", "line")
|
||||
viz_type_map = {
|
||||
"line": "echarts_timeseries_line",
|
||||
"bar": "echarts_timeseries_bar",
|
||||
"area": "echarts_area",
|
||||
"scatter": "echarts_timeseries_scatter",
|
||||
}
|
||||
return viz_type_map.get(kind, "echarts_timeseries_line")
|
||||
elif chart_type == "table":
|
||||
return getattr(config, "viz_type", "table")
|
||||
elif chart_type == "pie":
|
||||
return "pie"
|
||||
elif chart_type == "pivot_table":
|
||||
return "pivot_table_v2"
|
||||
elif chart_type == "mixed_timeseries":
|
||||
return "mixed_timeseries"
|
||||
elif chart_type == "handlebars":
|
||||
return "handlebars"
|
||||
elif chart_type == "big_number":
|
||||
show_trendline = getattr(config, "show_trendline", False)
|
||||
temporal_column = getattr(config, "temporal_column", None)
|
||||
return (
|
||||
"big_number" if show_trendline and temporal_column else "big_number_total"
|
||||
)
|
||||
return "unknown"
|
||||
from superset.mcp_service.chart.registry import get_registry
|
||||
|
||||
plugin = get_registry().get(getattr(config, "chart_type", ""))
|
||||
if plugin is None:
|
||||
return "unknown"
|
||||
return plugin.resolve_viz_type(config)
|
||||
|
||||
|
||||
def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities:
|
||||
|
||||

superset/mcp_service/chart/plugin.py (new executable file, 255 lines)
@@ -0,0 +1,255 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
ChartTypePlugin protocol and BaseChartPlugin base class.
|
||||
|
||||
Each chart type owns its pre-validation, column extraction, form_data mapping,
|
||||
and post-map validation in a single plugin class. This eliminates the previous
|
||||
pattern of 4 separate dispatch points (schema_validator.py, dataset_validator.py,
|
||||
chart_utils.py, pipeline.py) that had to be updated in sync whenever a new chart
|
||||
type was added.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from superset.mcp_service.chart.schemas import ColumnRef
|
||||
from superset.mcp_service.common.error_schemas import ChartGenerationError
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ChartTypePlugin(Protocol):
|
||||
"""
|
||||
Protocol that every chart-type plugin must satisfy.
|
||||
|
||||
Implementing all eight methods in a single class guarantees that adding a
|
||||
new chart type requires only one new file — the plugin — rather than edits
|
||||
across multiple separate files.
|
||||
"""
|
||||
|
||||
#: Discriminator value matching ChartConfig's chart_type field.
|
||||
chart_type: str
|
||||
|
||||
#: Human-readable name shown to users (e.g. "Line / Bar / Area / Scatter").
|
||||
display_name: str
|
||||
|
||||
#: Maps every Superset-internal viz_type this plugin can produce to a
|
||||
#: user-facing display name, e.g. {"echarts_timeseries_line": "Line Chart"}.
|
||||
#: Used by the registry to resolve display names for existing charts without
|
||||
#: needing a separate JSON mapping file.
|
||||
native_viz_types: dict[str, str]
|
||||
|
||||
def pre_validate(
|
||||
self,
|
||||
config: dict[str, Any],
|
||||
) -> ChartGenerationError | None:
|
||||
"""
|
||||
Early validation of the raw config dict before Pydantic parsing.
|
||||
|
||||
Called by SchemaValidator before attempting to parse the request.
|
||||
Should check that required top-level keys are present and well-typed.
|
||||
|
||||
Returns None if valid, ChartGenerationError if invalid.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract_column_refs(
|
||||
self,
|
||||
config: Any,
|
||||
) -> list[ColumnRef]:
|
||||
"""
|
||||
Extract all column references from a parsed chart config.
|
||||
|
||||
Called by DatasetValidator to validate that all referenced columns exist
|
||||
in the dataset. Must cover every field that holds a column name,
|
||||
including filters.
|
||||
|
||||
Returns a list of ColumnRef objects (may be empty).
|
||||
"""
|
||||
...
|
||||
|
||||
def to_form_data(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_id: int | str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Map a parsed chart config to Superset's internal form_data dict.
|
||||
|
||||
Replaces the if/elif chain in chart_utils.map_config_to_form_data().
|
||||
|
||||
Returns a Superset form_data dict ready for caching and rendering.
|
||||
"""
|
||||
...
|
||||
|
||||
def post_map_validate(
|
||||
self,
|
||||
config: Any,
|
||||
form_data: dict[str, Any],
|
||||
dataset_id: int | str | None = None,
|
||||
) -> ChartGenerationError | None:
|
||||
"""
|
||||
Validate the mapped form_data after to_form_data() runs.
|
||||
|
||||
Use this for cross-field constraints that can only be checked once
|
||||
form_data is assembled (e.g. BigNumber trendline requires a temporal
|
||||
column whose type must be verified against the dataset).
|
||||
|
||||
Returns None if valid, ChartGenerationError if invalid.
|
||||
"""
|
||||
...
|
||||
|
||||
def normalize_column_refs(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_context: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Return a new config with column names normalized to canonical dataset casing.
|
||||
|
||||
Called by DatasetValidator.normalize_column_names(). The default
|
||||
implementation (in BaseChartPlugin) returns the config unchanged; plugins
|
||||
with column fields override this to fix case sensitivity mismatches.
|
||||
|
||||
Returns a new config object (or the original if no normalization needed).
|
||||
"""
|
||||
...
|
||||
|
||||
def get_runtime_warnings(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_id: int | str,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Return chart-type-specific runtime warnings (performance, compatibility).
|
||||
|
||||
Called by RuntimeValidator to collect per-type warnings. Warnings are
|
||||
informational only — they never block chart generation. The default
|
||||
implementation returns an empty list; plugins override this to emit
|
||||
chart-type-specific warnings (e.g. XY cardinality checks).
|
||||
|
||||
Returns a list of warning message strings (may be empty).
|
||||
"""
|
||||
...
|
||||
|
||||
def generate_name(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_name: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Return a descriptive chart name for the given config.
|
||||
|
||||
Called by chart_utils.generate_chart_name(). The name should follow
|
||||
the standard format conventions documented in that function. Plugins
|
||||
that do not override this return the generic fallback "Chart".
|
||||
"""
|
||||
...
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
"""
|
||||
Return the Superset-internal viz_type string for this config.
|
||||
|
||||
Called by chart_utils._resolve_viz_type(). The returned string must
|
||||
match a registered Superset viz plugin (e.g. "echarts_timeseries_line").
|
||||
Plugins that do not override this return "unknown".
|
||||
"""
|
||||
...
|
||||
|
||||
def schema_error_hint(self) -> "ChartGenerationError | None":
|
||||
"""
|
||||
Return a user-friendly error for Pydantic discriminated-union parse failures.
|
||||
|
||||
Called by SchemaValidator when Pydantic cannot parse the config union and
|
||||
the chart_type is known. Returning None falls back to the generic error.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BaseChartPlugin:
|
||||
"""
|
||||
Base class providing sensible defaults for all ChartTypePlugin methods.
|
||||
|
||||
Concrete plugins extend this and override only what they need.
|
||||
"""
|
||||
|
||||
chart_type: str = ""
|
||||
display_name: str = ""
|
||||
native_viz_types: dict[str, str] = {}
|
||||
|
||||
def pre_validate(
|
||||
self,
|
||||
config: dict[str, Any],
|
||||
) -> ChartGenerationError | None:
|
||||
return None
|
||||
|
||||
def extract_column_refs(
|
||||
self,
|
||||
config: Any,
|
||||
) -> list[ColumnRef]:
|
||||
return []
|
||||
|
||||
def to_form_data(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_id: int | str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__}.to_form_data() is not implemented"
|
||||
)
|
||||
|
||||
def post_map_validate(
|
||||
self,
|
||||
config: Any,
|
||||
form_data: dict[str, Any],
|
||||
dataset_id: int | str | None = None,
|
||||
) -> ChartGenerationError | None:
|
||||
return None
|
||||
|
||||
def normalize_column_refs(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_context: Any,
|
||||
) -> Any:
|
||||
return config
|
||||
|
||||
def get_runtime_warnings(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_id: int | str,
|
||||
) -> list[str]:
|
||||
return []
|
||||
|
||||
def generate_name(
|
||||
self,
|
||||
config: Any,
|
||||
dataset_name: str | None = None,
|
||||
) -> str:
|
||||
return "Chart"
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return "unknown"
|
||||
|
||||
def schema_error_hint(self) -> ChartGenerationError | None:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _with_context(what: str, context: str | None) -> str:
|
||||
"""Combine a 'what' label and optional context with an en-dash."""
|
||||
return f"{what} – {context}" if context else what
|
||||

superset/mcp_service/chart/plugins/__init__.py (new file, 58 lines)
@@ -0,0 +1,58 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Chart type plugins package.
|
||||
|
||||
Importing this module registers all built-in chart type plugins in the global
|
||||
registry. This module is imported by app.py at startup.
|
||||
|
||||
To add a new chart type:
|
||||
1. Create ``superset/mcp_service/chart/plugins/{chart_type}.py``
|
||||
2. Implement a class extending ``BaseChartPlugin``
|
||||
3. Import and register it here
|
||||
"""
|
||||
|
||||
from superset.mcp_service.chart.plugins.big_number import BigNumberChartPlugin
|
||||
from superset.mcp_service.chart.plugins.handlebars import HandlebarsChartPlugin
|
||||
from superset.mcp_service.chart.plugins.mixed_timeseries import (
|
||||
MixedTimeseriesChartPlugin,
|
||||
)
|
||||
from superset.mcp_service.chart.plugins.pie import PieChartPlugin
|
||||
from superset.mcp_service.chart.plugins.pivot_table import PivotTableChartPlugin
|
||||
from superset.mcp_service.chart.plugins.table import TableChartPlugin
|
||||
from superset.mcp_service.chart.plugins.xy import XYChartPlugin
|
||||
from superset.mcp_service.chart.registry import register
|
||||
|
||||
# Register all built-in chart type plugins
|
||||
register(XYChartPlugin())
|
||||
register(TableChartPlugin())
|
||||
register(PieChartPlugin())
|
||||
register(PivotTableChartPlugin())
|
||||
register(MixedTimeseriesChartPlugin())
|
||||
register(HandlebarsChartPlugin())
|
||||
register(BigNumberChartPlugin())
|
||||
|
||||
__all__ = [
|
||||
"BigNumberChartPlugin",
|
||||
"HandlebarsChartPlugin",
|
||||
"MixedTimeseriesChartPlugin",
|
||||
"PieChartPlugin",
|
||||
"PivotTableChartPlugin",
|
||||
"TableChartPlugin",
|
||||
"XYChartPlugin",
|
||||
]
|
||||

superset/mcp_service/chart/plugins/big_number.py (new executable file, 220 lines)
@@ -0,0 +1,220 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Big number chart type plugin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
_big_number_chart_what,
|
||||
_summarize_filters,
|
||||
is_column_truly_temporal,
|
||||
map_big_number_config,
|
||||
)
|
||||
from superset.mcp_service.chart.plugin import BaseChartPlugin
|
||||
from superset.mcp_service.chart.schemas import BigNumberChartConfig, ColumnRef
|
||||
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
|
||||
from superset.mcp_service.common.error_schemas import ChartGenerationError
|
||||
|
||||
|
||||
class BigNumberChartPlugin(BaseChartPlugin):
|
||||
"""Plugin for big_number chart type."""
|
||||
|
||||
chart_type = "big_number"
|
||||
display_name = "Big Number"
|
||||
native_viz_types = {
|
||||
"big_number": "Big Number with Trendline",
|
||||
"big_number_total": "Big Number",
|
||||
}
|
||||
|
||||
def pre_validate(
|
||||
self,
|
||||
config: dict[str, Any],
|
||||
) -> ChartGenerationError | None:
|
||||
if "metric" not in config:
|
||||
return ChartGenerationError(
|
||||
error_type="missing_metric",
|
||||
message="Big Number chart missing required field: metric",
|
||||
details=(
|
||||
"Big Number charts require a 'metric' field "
|
||||
"specifying the value to display"
|
||||
),
|
||||
suggestions=[
|
||||
"Add 'metric' with name and aggregate: "
|
||||
"{'name': 'revenue', 'aggregate': 'SUM'}",
|
||||
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
|
||||
"Example: {'chart_type': 'big_number', "
|
||||
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
|
||||
],
|
||||
error_code="MISSING_BIG_NUMBER_METRIC",
|
||||
)
|
||||
|
||||
metric = config.get("metric", {})
|
||||
if not isinstance(metric, dict):
|
||||
return ChartGenerationError(
|
||||
error_type="invalid_metric_type",
|
||||
message="Big Number metric must be a dict with 'name' and 'aggregate'",
|
||||
details=(
|
||||
f"The 'metric' field must be an object, got {type(metric).__name__}"
|
||||
),
|
||||
suggestions=[
|
||||
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
|
||||
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
|
||||
],
|
||||
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
|
||||
)
|
||||
if not metric.get("aggregate") and not metric.get("saved_metric"):
|
||||
return ChartGenerationError(
|
||||
error_type="missing_metric_aggregate",
|
||||
message=(
|
||||
"Big Number metric must include an aggregate function "
|
||||
"or reference a saved metric"
|
||||
),
|
||||
details=(
|
||||
"The metric must have an 'aggregate' field or 'saved_metric': true"
|
||||
),
|
||||
suggestions=[
|
||||
"Add 'aggregate': {'name': 'col', 'aggregate': 'SUM'}",
|
||||
"Or use a saved metric: {'name': 'metric', 'saved_metric': true}",
|
||||
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
|
||||
],
|
||||
error_code="MISSING_BIG_NUMBER_AGGREGATE",
|
||||
)
|
||||
|
||||
show_trendline = config.get("show_trendline", False)
|
||||
temporal_column = config.get("temporal_column")
|
||||
if show_trendline and not temporal_column:
|
||||
return ChartGenerationError(
|
||||
error_type="missing_temporal_column",
|
||||
message="Trendline requires a temporal column",
|
||||
details=(
|
||||
"When 'show_trendline' is True, "
|
||||
"a 'temporal_column' must be specified"
|
||||
),
|
||||
suggestions=[
|
||||
"Add 'temporal_column': 'date_column_name'",
|
||||
"Or set 'show_trendline': false for number only",
|
||||
"Use get_dataset_info to find temporal columns",
|
||||
],
|
||||
error_code="MISSING_TEMPORAL_COLUMN",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
|
||||
if not isinstance(config, BigNumberChartConfig):
|
||||
return []
|
||||
refs: list[ColumnRef] = [config.metric]
|
||||
# temporal_column is a str field, not a ColumnRef — validate it exists
|
||||
if config.temporal_column:
|
||||
refs.append(ColumnRef(name=config.temporal_column))
|
||||
if config.filters:
|
||||
for f in config.filters:
|
||||
refs.append(ColumnRef(name=f.column))
|
||||
return refs
|
||||
|
||||
def to_form_data(
|
||||
self, config: Any, dataset_id: int | str | None = None
|
||||
) -> dict[str, Any]:
|
||||
return map_big_number_config(config)
|
||||
|
||||
def post_map_validate(
|
||||
self,
|
||||
config: Any,
|
||||
form_data: dict[str, Any],
|
||||
dataset_id: int | str | None = None,
|
||||
) -> ChartGenerationError | None:
|
||||
"""Verify the trendline temporal column is a real temporal SQL type.
|
||||
|
||||
This check was previously baked into map_config_to_form_data() in
|
||||
chart_utils.py as a special case. Moving it here keeps the dispatcher
|
||||
clean and makes the constraint explicit and discoverable.
|
||||
"""
|
||||
if not isinstance(config, BigNumberChartConfig):
|
||||
return None
|
||||
if not (config.show_trendline and config.temporal_column):
|
||||
return None
|
||||
|
||||
if not is_column_truly_temporal(config.temporal_column, dataset_id):
|
||||
return ChartGenerationError(
|
||||
error_type="non_temporal_trendline_column",
|
||||
message=(
|
||||
f"Big Number trendline requires a temporal SQL column; "
|
||||
f"'{config.temporal_column}' is not temporal."
|
||||
),
|
||||
details=(
|
||||
f"Column '{config.temporal_column}' does not have a temporal "
|
||||
f"SQL type (DATE, DATETIME, TIMESTAMP). The trendline requires "
|
||||
f"a true temporal column for DATE_TRUNC to work."
|
||||
),
|
||||
suggestions=[
|
||||
"Use get_dataset_info to find columns with temporal SQL types",
|
||||
"Set 'show_trendline': false to use any column as the metric",
|
||||
"If the column contains dates stored as integers, "
|
||||
"consider casting it in a virtual dataset",
|
||||
],
|
||||
error_code="NON_TEMPORAL_TRENDLINE_COLUMN",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
|
||||
what = _big_number_chart_what(config)
|
||||
context = _summarize_filters(getattr(config, "filters", None))
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
show_trendline = getattr(config, "show_trendline", False)
|
||||
temporal_column = getattr(config, "temporal_column", None)
|
||||
if show_trendline and temporal_column:
|
||||
return "big_number"
|
||||
return "big_number_total"
|
||||
|
||||
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
|
||||
config_dict = config.model_dump()
|
||||
|
||||
if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"):
|
||||
config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name(
|
||||
config_dict["metric"]["name"], dataset_context
|
||||
)
|
||||
if config_dict.get("temporal_column"):
|
||||
config_dict["temporal_column"] = (
|
||||
DatasetValidator._get_canonical_column_name(
|
||||
config_dict["temporal_column"], dataset_context
|
||||
)
|
||||
)
|
||||
DatasetValidator._normalize_filters(config_dict, dataset_context)
|
||||
return BigNumberChartConfig.model_validate(config_dict)
|
||||
|
||||
def schema_error_hint(self) -> ChartGenerationError | None:
|
||||
return ChartGenerationError(
|
||||
error_type="big_number_validation_error",
|
||||
message="Big Number chart configuration validation failed",
|
||||
details=(
|
||||
"The Big Number chart configuration is missing required "
|
||||
"fields or has invalid structure"
|
||||
),
|
||||
suggestions=[
|
||||
"Ensure 'metric' field has 'name' and 'aggregate'",
|
||||
"Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}",
|
||||
"For trendline: add show_trendline=true and temporal_column='col'",
|
||||
"Without trendline: just provide the metric",
|
||||
],
|
||||
error_code="BIG_NUMBER_VALIDATION_ERROR",
|
||||
)
|
||||

superset/mcp_service/chart/plugins/handlebars.py (new executable file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Handlebars chart type plugin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
_handlebars_chart_what,
|
||||
_summarize_filters,
|
||||
map_handlebars_config,
|
||||
)
|
||||
from superset.mcp_service.chart.plugin import BaseChartPlugin
|
||||
from superset.mcp_service.chart.schemas import ColumnRef, HandlebarsChartConfig
|
||||
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
|
||||
from superset.mcp_service.common.error_schemas import ChartGenerationError
|
||||
|
||||
|
||||
class HandlebarsChartPlugin(BaseChartPlugin):
|
||||
"""Plugin for handlebars chart type (custom HTML template charts)."""
|
||||
|
||||
chart_type = "handlebars"
|
||||
display_name = "Handlebars (Custom Template)"
|
||||
native_viz_types = {
|
||||
"handlebars": "Custom Template Chart",
|
||||
}
|
||||
|
||||
def pre_validate(
|
||||
self,
|
||||
config: dict[str, Any],
|
||||
) -> ChartGenerationError | None:
|
||||
if "handlebars_template" not in config:
|
||||
return ChartGenerationError(
|
||||
error_type="missing_handlebars_template",
|
||||
message="Handlebars chart missing required field: handlebars_template",
|
||||
details=(
|
||||
"Handlebars charts require a 'handlebars_template' string "
|
||||
"containing Handlebars HTML template markup"
|
||||
),
|
||||
suggestions=[
|
||||
"Add 'handlebars_template' with a Handlebars HTML template",
|
||||
"Data is available as {{data}} array in the template",
|
||||
"Example: '<ul>{{#each data}}<li>{{this.name}}: "
|
||||
"{{this.value}}</li>{{/each}}</ul>'",
|
||||
],
|
||||
error_code="MISSING_HANDLEBARS_TEMPLATE",
|
||||
)
|
||||
|
||||
template = config.get("handlebars_template")
|
||||
if not isinstance(template, str) or not template.strip():
|
||||
return ChartGenerationError(
|
||||
error_type="invalid_handlebars_template",
|
||||
message="Handlebars template must be a non-empty string",
|
||||
details=(
|
||||
"The 'handlebars_template' field must be a non-empty string "
|
||||
"containing valid Handlebars HTML template markup"
|
||||
),
|
||||
suggestions=[
|
||||
"Ensure handlebars_template is a non-empty string",
|
||||
"Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
|
||||
],
|
||||
error_code="INVALID_HANDLEBARS_TEMPLATE",
|
||||
)
|
||||
|
||||
query_mode = config.get("query_mode", "aggregate")
|
||||
if query_mode not in ("aggregate", "raw"):
|
||||
return ChartGenerationError(
|
||||
error_type="invalid_query_mode",
|
||||
message="Invalid query_mode for handlebars chart",
|
||||
details="query_mode must be either 'aggregate' or 'raw'",
|
||||
suggestions=[
|
||||
"Use 'aggregate' for aggregated data (default)",
|
||||
"Use 'raw' for individual rows",
|
||||
],
|
||||
error_code="INVALID_QUERY_MODE",
|
||||
)
|
||||
|
||||
if query_mode == "raw" and not config.get("columns"):
|
||||
return ChartGenerationError(
|
||||
error_type="missing_raw_columns",
|
||||
message="Handlebars chart in 'raw' mode requires 'columns'",
|
||||
details=(
|
||||
"When query_mode is 'raw', you must specify which columns "
|
||||
"to include in the query results"
|
||||
),
|
||||
suggestions=[
|
||||
"Add 'columns': [{'name': 'column_name'}] for raw mode",
|
||||
"Or use query_mode='aggregate' with 'metrics' and optional 'groupby'", # noqa: E501
|
||||
],
|
||||
error_code="MISSING_RAW_COLUMNS",
|
||||
)
|
||||
|
||||
if query_mode == "aggregate" and not config.get("metrics"):
|
||||
return ChartGenerationError(
|
||||
error_type="missing_aggregate_metrics",
|
||||
message="Handlebars chart in 'aggregate' mode requires 'metrics'",
|
||||
details=(
|
||||
"When query_mode is 'aggregate' (default), you must specify "
|
||||
"at least one metric with an aggregate function"
|
||||
),
|
||||
suggestions=[
|
||||
"Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
|
||||
"Or use query_mode='raw' with 'columns' for individual rows",
|
||||
],
|
||||
error_code="MISSING_AGGREGATE_METRICS",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
|
||||
if not isinstance(config, HandlebarsChartConfig):
|
||||
return []
|
||||
refs: list[ColumnRef] = []
|
||||
if config.columns:
|
||||
refs.extend(config.columns)
|
||||
if config.metrics:
|
||||
refs.extend(config.metrics)
|
||||
if config.groupby:
|
||||
refs.extend(config.groupby)
|
||||
if config.filters:
|
||||
for f in config.filters:
|
||||
refs.append(ColumnRef(name=f.column))
|
||||
return refs
|
||||
|
||||
def to_form_data(
|
||||
self, config: Any, dataset_id: int | str | None = None
|
||||
) -> dict[str, Any]:
|
||||
return map_handlebars_config(config)
|
||||
|
||||
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
|
||||
what = _handlebars_chart_what(config)
|
||||
context = _summarize_filters(getattr(config, "filters", None))
|
||||
return self._with_context(what, context)
|
||||
|
||||
def resolve_viz_type(self, config: Any) -> str:
|
||||
return "handlebars"
|
||||
|
||||
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
|
||||
config_dict = config.model_dump()
|
||||
|
||||
def _norm_list(key: str) -> None:
|
||||
if config_dict.get(key):
|
||||
for col in config_dict[key]:
|
||||
if not col.get("saved_metric"):
|
||||
col["name"] = DatasetValidator._get_canonical_column_name(
|
||||
col["name"], dataset_context
|
||||
)
|
||||
|
||||
_norm_list("columns")
|
||||
_norm_list("metrics")
|
||||
_norm_list("groupby")
|
||||
DatasetValidator._normalize_filters(config_dict, dataset_context)
|
||||
return HandlebarsChartConfig.model_validate(config_dict)
|
||||
|
||||
def schema_error_hint(self) -> ChartGenerationError | None:
|
||||
return ChartGenerationError(
|
||||
error_type="handlebars_validation_error",
|
||||
message="Handlebars chart configuration validation failed",
|
||||
details=(
|
||||
"The handlebars chart configuration is missing "
|
||||
"required fields or has invalid structure"
|
||||
),
|
||||
suggestions=[
|
||||
"Ensure 'handlebars_template' is a non-empty string",
|
||||
"For aggregate mode: add 'metrics' with aggregate functions",
|
||||
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
|
||||
"Example: {'chart_type': 'handlebars', "
|
||||
"'handlebars_template': "
|
||||
"'<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>', "
|
||||
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
|
||||
],
|
||||
error_code="HANDLEBARS_VALIDATION_ERROR",
|
||||
)
|
||||
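For reference, configurations that satisfy the checks above could look like the following minimal sketch; the column names ("sales", "product", "name", "value") are placeholders, not fields from the actual dataset:

```python
# Illustrative handlebars configs (hypothetical column names).
# Aggregate mode (the default) requires 'metrics'.
aggregate_config = {
    "chart_type": "handlebars",
    "handlebars_template": "<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>",
    "metrics": [{"name": "sales", "aggregate": "SUM"}],
    "groupby": [{"name": "product"}],
}

# Raw mode requires 'columns' instead of 'metrics'.
raw_config = {
    "chart_type": "handlebars",
    "query_mode": "raw",
    "handlebars_template": "<ul>{{#each data}}<li>{{this.name}}: {{this.value}}</li>{{/each}}</ul>",
    "columns": [{"name": "name"}, {"name": "value"}],
}
```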
165
superset/mcp_service/chart/plugins/mixed_timeseries.py
Executable file
@@ -0,0 +1,165 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Mixed timeseries chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _mixed_timeseries_what,
    _summarize_filters,
    map_mixed_timeseries_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, MixedTimeseriesChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class MixedTimeseriesChartPlugin(BaseChartPlugin):
    """Plugin for mixed_timeseries chart type."""

    chart_type = "mixed_timeseries"
    display_name = "Mixed Timeseries"
    native_viz_types = {
        "mixed_timeseries": "Mixed Timeseries Chart",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        missing_fields = []

        if "x" not in config:
            missing_fields.append("'x' (X-axis temporal column)")
        if "y" not in config:
            missing_fields.append("'y' (primary Y-axis metrics)")
        if "y_secondary" not in config:
            missing_fields.append("'y_secondary' (secondary Y-axis metrics)")

        if missing_fields:
            return ChartGenerationError(
                error_type="missing_mixed_timeseries_fields",
                message=(
                    f"Mixed timeseries chart missing required fields: "
                    f"{', '.join(missing_fields)}"
                ),
                details=(
                    "Mixed timeseries charts require an x-axis, primary metrics, "
                    "and secondary metrics"
                ),
                suggestions=[
                    "Add 'x' field: {'name': 'date_column'}",
                    "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
                    "Add 'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]",
                    "Optional: 'primary_kind' and 'secondary_kind' for chart types",
                ],
                error_code="MISSING_MIXED_TIMESERIES_FIELDS",
            )

        for field_name in ["y", "y_secondary"]:
            if not isinstance(config.get(field_name, []), list):
                return ChartGenerationError(
                    error_type=f"invalid_{field_name}_format",
                    message=f"'{field_name}' must be a list of metrics",
                    details=(
                        f"The '{field_name}' field must be an array of metric "
                        "specifications"
                    ),
                    suggestions=[
                        f"Wrap in array: '{field_name}': "
                        "[{'name': 'col', 'aggregate': 'SUM'}]",
                    ],
                    error_code=f"INVALID_{field_name.upper()}_FORMAT",
                )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, MixedTimeseriesChartConfig):
            return []
        refs: list[ColumnRef] = [config.x]
        refs.extend(config.y)
        refs.extend(config.y_secondary)
        if config.group_by:
            refs.extend(config.group_by)
        if config.group_by_secondary:
            refs.extend(config.group_by_secondary)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_mixed_timeseries_config(config, dataset_id=dataset_id)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _mixed_timeseries_what(config)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return "mixed_timeseries"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()

        def _norm_single(key: str) -> None:
            if config_dict.get(key):
                config_dict[key]["name"] = DatasetValidator._get_canonical_column_name(
                    config_dict[key]["name"], dataset_context
                )

        def _norm_list(key: str) -> None:
            if config_dict.get(key):
                for col in config_dict[key]:
                    col["name"] = DatasetValidator._get_canonical_column_name(
                        col["name"], dataset_context
                    )

        _norm_single("x")
        _norm_list("y")
        _norm_list("y_secondary")
        _norm_list("group_by")
        _norm_list("group_by_secondary")
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return MixedTimeseriesChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="mixed_timeseries_validation_error",
            message="Mixed timeseries chart configuration validation failed",
            details=(
                "The mixed timeseries configuration is missing "
                "required fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'x' field has 'name' for the time axis column",
                "Ensure 'y' is an array of primary-axis metrics",
                "Ensure 'y_secondary' is an array of secondary-axis metrics",
                "Example: {'chart_type': 'mixed_timeseries', "
                "'x': {'name': 'order_date'}, "
                "'y': [{'name': 'revenue', 'aggregate': 'SUM'}], "
                "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}",
            ],
            error_code="MIXED_TIMESERIES_VALIDATION_ERROR",
        )
128
superset/mcp_service/chart/plugins/pie.py
Executable file
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Pie chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _pie_chart_what,
    _summarize_filters,
    map_pie_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PieChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class PieChartPlugin(BaseChartPlugin):
    """Plugin for pie chart type."""

    chart_type = "pie"
    display_name = "Pie / Donut Chart"
    native_viz_types = {
        "pie": "Pie Chart",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        missing_fields = []

        if "dimension" not in config:
            missing_fields.append("'dimension' (category column for slices)")
        if "metric" not in config:
            missing_fields.append("'metric' (value metric for slice sizes)")

        if missing_fields:
            return ChartGenerationError(
                error_type="missing_pie_fields",
                message=(
                    f"Pie chart missing required fields: {', '.join(missing_fields)}"
                ),
                details=(
                    "Pie charts require a dimension (categories) and a metric (values)"
                ),
                suggestions=[
                    "Add 'dimension' field: {'name': 'category_column'}",
                    "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
                    "Example: {'chart_type': 'pie', 'dimension': {'name': 'product'}, "
                    "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
                ],
                error_code="MISSING_PIE_FIELDS",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, PieChartConfig):
            return []
        refs: list[ColumnRef] = [config.dimension, config.metric]
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_pie_config(config)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _pie_chart_what(config)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return "pie"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()

        if config_dict.get("dimension"):
            config_dict["dimension"]["name"] = (
                DatasetValidator._get_canonical_column_name(
                    config_dict["dimension"]["name"], dataset_context
                )
            )
        if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"):
            config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name(
                config_dict["metric"]["name"], dataset_context
            )
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return PieChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="pie_validation_error",
            message="Pie chart configuration validation failed",
            details=(
                "The pie chart configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'dimension' field has 'name' for the slice label",
                "Ensure 'metric' field has 'name' and 'aggregate'",
                "Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, "
                "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
            ],
            error_code="PIE_VALIDATION_ERROR",
        )
153
superset/mcp_service/chart/plugins/pivot_table.py
Executable file
@@ -0,0 +1,153 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Pivot table chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _pivot_table_what,
    _summarize_filters,
    map_pivot_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PivotTableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class PivotTableChartPlugin(BaseChartPlugin):
    """Plugin for pivot_table chart type."""

    chart_type = "pivot_table"
    display_name = "Pivot Table"
    native_viz_types = {
        "pivot_table_v2": "Pivot Table",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        missing_fields = []

        if "rows" not in config:
            missing_fields.append("'rows' (row grouping columns)")
        if "metrics" not in config:
            missing_fields.append("'metrics' (aggregation metrics)")

        if missing_fields:
            return ChartGenerationError(
                error_type="missing_pivot_fields",
                message=(
                    f"Pivot table missing required fields: {', '.join(missing_fields)}"
                ),
                details="Pivot tables require row groupings and metrics",
                suggestions=[
                    "Add 'rows' field: [{'name': 'category'}]",
                    "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
                    "Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
                ],
                error_code="MISSING_PIVOT_FIELDS",
            )

        if not isinstance(config.get("rows", []), list):
            return ChartGenerationError(
                error_type="invalid_rows_format",
                message="Rows must be a list of columns",
                details="The 'rows' field must be an array of column specifications",
                suggestions=[
                    "Wrap row columns in array: 'rows': [{'name': 'category'}]",
                ],
                error_code="INVALID_ROWS_FORMAT",
            )

        if not isinstance(config.get("metrics", []), list):
            return ChartGenerationError(
                error_type="invalid_metrics_format",
                message="Metrics must be a list",
                details="The 'metrics' field must be an array of metric specifications",
                suggestions=[
                    "Wrap metrics in array: 'metrics': [{'name': 'sales', "
                    "'aggregate': 'SUM'}]",
                ],
                error_code="INVALID_METRICS_FORMAT",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, PivotTableChartConfig):
            return []
        refs: list[ColumnRef] = list(config.rows)
        refs.extend(config.metrics)
        if config.columns:
            refs.extend(config.columns)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_pivot_table_config(config)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _pivot_table_what(config)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return "pivot_table_v2"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()

        def _norm_col_list(key: str) -> None:
            if config_dict.get(key):
                for col in config_dict[key]:
                    col["name"] = DatasetValidator._get_canonical_column_name(
                        col["name"], dataset_context
                    )

        _norm_col_list("rows")
        _norm_col_list("metrics")
        _norm_col_list("columns")
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return PivotTableChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="pivot_table_validation_error",
            message="Pivot table configuration validation failed",
            details=(
                "The pivot table configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'rows' field is an array of column specs",
                "Ensure 'metrics' field is an array with aggregate funcs",
                "Optional: add 'columns' for column grouping",
                "Example: {'chart_type': 'pivot_table', "
                "'rows': [{'name': 'region'}], "
                "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}",
            ],
            error_code="PIVOT_TABLE_VALIDATION_ERROR",
        )
128
superset/mcp_service/chart/plugins/table.py
Executable file
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Table chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _summarize_filters,
    _table_chart_what,
    map_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, TableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class TableChartPlugin(BaseChartPlugin):
    """Plugin for table chart type."""

    chart_type = "table"
    display_name = "Table"
    native_viz_types = {
        "table": "Table",
        "ag-grid-table": "Interactive Table",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        if "columns" not in config:
            return ChartGenerationError(
                error_type="missing_columns",
                message="Table chart missing required field: columns",
                details=(
                    "Table charts require a 'columns' array to specify which "
                    "columns to display"
                ),
                suggestions=[
                    "Add 'columns' field with array of column specifications",
                    "Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
                    "'aggregate': 'SUM'}]",
                    "Each column can have optional 'aggregate' for metrics",
                ],
                error_code="MISSING_COLUMNS",
            )

        if not isinstance(config.get("columns", []), list):
            return ChartGenerationError(
                error_type="invalid_columns_format",
                message="Columns must be a list",
                details="The 'columns' field must be an array of column specifications",
                suggestions=[
                    "Ensure columns is an array: 'columns': [...]",
                    "Each column should be an object with 'name' field",
                ],
                error_code="INVALID_COLUMNS_FORMAT",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, TableChartConfig):
            return []
        refs: list[ColumnRef] = list(config.columns)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_table_config(config)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _table_chart_what(config, dataset_name)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return getattr(config, "viz_type", "table")

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()
        get_canonical = DatasetValidator._get_canonical_column_name

        for col in config_dict.get("columns") or []:
            col["name"] = get_canonical(col["name"], dataset_context)

        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return TableChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="table_validation_error",
            message="Table chart configuration validation failed",
            details=(
                "The table chart configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'columns' field is an array of column specifications",
                "Each column needs {'name': 'column_name'}",
                "Optional: add 'aggregate' for metrics",
                "Example: 'columns': [{'name': 'product'}, "
                "{'name': 'sales', 'aggregate': 'SUM'}]",
            ],
            error_code="TABLE_VALIDATION_ERROR",
        )
192
superset/mcp_service/chart/plugins/xy.py
Executable file
@@ -0,0 +1,192 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""XY chart type plugin (line, bar, area, scatter)."""

from __future__ import annotations

import logging
from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _xy_chart_context,
    _xy_chart_what,
    map_xy_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.chart.validation.runtime.cardinality_validator import (
    CardinalityValidator,
)
from superset.mcp_service.chart.validation.runtime.format_validator import (
    FormatTypeValidator,
)
from superset.mcp_service.common.error_schemas import ChartGenerationError

logger = logging.getLogger(__name__)


class XYChartPlugin(BaseChartPlugin):
    """Plugin for xy chart type (line, bar, area, scatter)."""

    chart_type = "xy"
    display_name = "Line / Bar / Area / Scatter Chart"
    native_viz_types = {
        "echarts_timeseries_line": "Line Chart",
        "echarts_timeseries_bar": "Bar Chart",
        "echarts_area": "Area Chart",
        "echarts_timeseries_scatter": "Scatter Plot",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        # x is optional — defaults to dataset's main_dttm_col in map_xy_config
        if "y" not in config:
            return ChartGenerationError(
                error_type="missing_xy_fields",
                message="XY chart missing required field: 'y' (Y-axis metrics)",
                details=(
                    "XY charts require Y-axis (metrics) specifications. "
                    "X-axis is optional and defaults to the dataset's primary "
                    "datetime column when omitted."
                ),
                suggestions=[
                    "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}]",
                    "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
                    "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
                ],
                error_code="MISSING_XY_FIELDS",
            )

        if not isinstance(config.get("y", []), list):
            return ChartGenerationError(
                error_type="invalid_y_format",
                message="Y-axis must be a list of metrics",
                details="The 'y' field must be an array of metric specifications",
                suggestions=[
                    "Wrap Y-axis metric in array: 'y': [{'name': 'column', "
                    "'aggregate': 'SUM'}]",
                    "Multiple metrics supported: 'y': [metric1, metric2, ...]",
                ],
                error_code="INVALID_Y_FORMAT",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, XYChartConfig):
            return []
        refs: list[ColumnRef] = []
        if config.x is not None:
            refs.append(config.x)
        refs.extend(config.y)
        if config.group_by:
            refs.extend(config.group_by)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_xy_config(config, dataset_id=dataset_id)

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()
        get_canonical = DatasetValidator._get_canonical_column_name

        if config_dict.get("x"):
            config_dict["x"]["name"] = get_canonical(
                config_dict["x"]["name"], dataset_context
            )
        for y_col in config_dict.get("y") or []:
            y_col["name"] = get_canonical(y_col["name"], dataset_context)
        for gb_col in config_dict.get("group_by") or []:
            gb_col["name"] = get_canonical(gb_col["name"], dataset_context)

        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return XYChartConfig.model_validate(config_dict)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _xy_chart_what(config)
        context = _xy_chart_context(config)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        kind = getattr(config, "kind", "line")
        return {
            "line": "echarts_timeseries_line",
            "bar": "echarts_timeseries_bar",
            "area": "echarts_area",
            "scatter": "echarts_timeseries_scatter",
        }.get(kind, "echarts_timeseries_line")

    def get_runtime_warnings(self, config: Any, dataset_id: int | str) -> list[str]:
        """Return format-compatibility and cardinality warnings for XY charts."""
        if not isinstance(config, XYChartConfig):
            return []

        warnings: list[str] = []

        try:
            _valid, format_warnings = FormatTypeValidator.validate_format_compatibility(
                config
            )
            if format_warnings:
                warnings.extend(format_warnings)
        except Exception as exc:
            logger.warning("XY format validation failed: %s", exc)

        try:
            chart_kind = config.kind
            group_by_col = config.group_by[0].name if config.group_by else None
            if config.x is not None:
                _ok, card_info = CardinalityValidator.check_cardinality(
                    dataset_id=dataset_id,
                    x_column=config.x.name,
                    chart_type=chart_kind,
                    group_by_column=group_by_col,
                )
                if not _ok and card_info:
                    warnings.extend(card_info.get("warnings", []))
                    warnings.extend(card_info.get("suggestions", []))
        except Exception as exc:
            logger.warning("XY cardinality validation failed: %s", exc)

        return warnings

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="xy_validation_error",
            message="XY chart configuration validation failed",
            details=(
                "The XY chart configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Note: 'x' is optional and defaults to the dataset's primary "
                "datetime column",
                "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]",
                "Check that all column names are strings",
                "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX",
            ],
            error_code="XY_VALIDATION_ERROR",
        )
228
superset/mcp_service/chart/registry.py
Executable file
@@ -0,0 +1,228 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
ChartTypeRegistry — central registry mapping chart_type strings to plugins.

Replaces the four previously-scattered dispatch locations:
- schema_validator.py: chart_type_validators dict
- dataset_validator.py: isinstance branches in _extract_column_references()
- chart_utils.py: if/elif chain in map_config_to_form_data()
- dataset_validator.py: isinstance branches in normalize_column_names()

Usage::

    from superset.mcp_service.chart.registry import get_registry

    plugin = get_registry().get("xy")
    if plugin is None:
        raise ValueError("Unknown chart type: xy")
    form_data = plugin.to_form_data(config, dataset_id)
"""

from __future__ import annotations

import logging
import threading
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from superset.mcp_service.chart.plugin import ChartTypePlugin

logger = logging.getLogger(__name__)

_REGISTRY: dict[str, "ChartTypePlugin"] = {}
_plugins_loaded = False
_plugins_lock = threading.Lock()

# ---------------------------------------------------------------------------
# Plugin filter — replaced atomically by configure() at app startup.
# Default: all registered plugins visible (no disabled set, no callable).
# ---------------------------------------------------------------------------

PluginEnabledFunc = Callable[[str], bool]


@dataclass(frozen=True)
class _PluginFilterConfig:
    disabled_plugins: frozenset[str] = field(default_factory=frozenset)
    enabled_func: PluginEnabledFunc | None = None


_filter_config: _PluginFilterConfig = _PluginFilterConfig()


def _ensure_plugins_loaded() -> None:
    """Lazily import the plugins package to populate _REGISTRY.

    Called before every registry lookup so the registry is always populated,
    even when callers (tests, chart_utils, validators) import this module
    directly without first importing app.py.
    """
    global _plugins_loaded
    if _plugins_loaded:
        return
    with _plugins_lock:
        if not _plugins_loaded:
            try:
                import superset.mcp_service.chart.plugins # noqa: F401

                _plugins_loaded = True
            except Exception:
                logger.exception("Failed to load built-in chart type plugins")


def configure(
    disabled: Iterable[str] | None = None,
    enabled_func: PluginEnabledFunc | None = None,
) -> None:
    """Set runtime plugin filters. Called once during app initialization.

    Replaces the filter config atomically with a single assignment so concurrent
    readers always observe a consistent (disabled_plugins, enabled_func) pair.

    Args:
        disabled: chart_type strings to suppress. Accepts any iterable (set,
            frozenset, list, tuple). Ignored when enabled_func is provided.
        enabled_func: callable(chart_type) -> bool. When set, overrides
            ``disabled``. Must be cheap and in-process — no network I/O per
            call. On exception the registry fails *closed* (plugin hidden).
    """
    global _filter_config

    if enabled_func is not None and not callable(enabled_func):
        raise TypeError("enabled_func must be callable or None")

    new_config = _PluginFilterConfig(
        disabled_plugins=frozenset(disabled or ()),
        enabled_func=enabled_func,
    )
    _filter_config = new_config

    if new_config.disabled_plugins:
        logger.info(
            "MCP chart plugins disabled: %s", sorted(new_config.disabled_plugins)
        )
    if new_config.enabled_func is not None:
        logger.info(
            "MCP chart plugin dynamic filter configured: %r", new_config.enabled_func
        )


def _is_plugin_enabled(chart_type: str) -> bool:
    """Return True if the plugin is currently enabled (not filtered out)."""
    config = _filter_config # read once — atomic reference in CPython
    if config.enabled_func is not None:
        try:
            return bool(config.enabled_func(chart_type))
        except Exception:
            logger.warning(
                "MCP_CHART_PLUGIN_ENABLED_FUNC raised for chart_type=%r; "
                "failing closed (plugin hidden)",
                chart_type,
                exc_info=True,
            )
            return False
    return chart_type not in config.disabled_plugins


def register(plugin: "ChartTypePlugin") -> None:
    """Register a chart type plugin in the global registry."""
    if not plugin.chart_type:
        raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type")
    if plugin.chart_type in _REGISTRY:
        logger.warning(
            "Overwriting existing plugin for chart_type=%r", plugin.chart_type
        )
    _REGISTRY[plugin.chart_type] = plugin
    logger.debug("Registered chart plugin: %r", plugin.chart_type)


def get(chart_type: str) -> "ChartTypePlugin | None":
    """Return the plugin for chart_type, or None if unknown or disabled."""
    _ensure_plugins_loaded()
    if chart_type not in _REGISTRY or not _is_plugin_enabled(chart_type):
        return None
    return _REGISTRY[chart_type]


def all_types() -> list[str]:
    """Return enabled registered chart type strings in insertion order."""
    _ensure_plugins_loaded()
    return [ct for ct in _REGISTRY if _is_plugin_enabled(ct)]


def is_registered(chart_type: str) -> bool:
    """Return True if chart_type has a registered plugin, regardless of enabled state.

    Use this to distinguish an unknown chart type from a disabled one.
    Use is_enabled() to check whether the plugin is currently available.
    """
    _ensure_plugins_loaded()
    return chart_type in _REGISTRY


def is_enabled(chart_type: str) -> bool:
    """Return True if chart_type is registered AND currently enabled."""
    _ensure_plugins_loaded()
    return chart_type in _REGISTRY and _is_plugin_enabled(chart_type)


def display_name_for_viz_type(viz_type: str) -> str | None:
    """Return the user-facing display name for a Superset-internal viz_type.

    Searches every registered plugin's ``native_viz_types`` mapping.
    Returns None if no plugin recognises the viz_type.

    Example::

        display_name_for_viz_type("echarts_timeseries_line") # "Line Chart"
        display_name_for_viz_type("pivot_table_v2") # "Pivot Table"
        display_name_for_viz_type("unknown_type") # None
    """
    _ensure_plugins_loaded()
    for plugin in _REGISTRY.values():
        name = plugin.native_viz_types.get(viz_type)
        if name is not None:
            return name
    return None


def get_registry() -> "_RegistryProxy":
    """Return a proxy object for registry access (convenience wrapper)."""
    return _RegistryProxy()


class _RegistryProxy:
    """Thin proxy exposing registry functions as instance methods."""

    def get(self, chart_type: str) -> "ChartTypePlugin | None":
        return get(chart_type)

    def all_types(self) -> list[str]:
        return all_types()

    def is_registered(self, chart_type: str) -> bool:
        return is_registered(chart_type)

    def is_enabled(self, chart_type: str) -> bool:
        return is_enabled(chart_type)

    def display_name_for_viz_type(self, viz_type: str) -> str | None:
        return display_name_for_viz_type(viz_type)
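A minimal sketch of how application startup might wire these filters; the config key names below are assumptions inferred from the log message in `_is_plugin_enabled`, not confirmed configuration options:

```python
# Sketch only: assumes hypothetical MCP_CHART_PLUGIN_DISABLED and
# MCP_CHART_PLUGIN_ENABLED_FUNC entries in the app config.
from superset.mcp_service.chart import registry

def init_chart_plugins(app_config: dict) -> None:
    registry.configure(
        disabled=app_config.get("MCP_CHART_PLUGIN_DISABLED", ()),
        enabled_func=app_config.get("MCP_CHART_PLUGIN_ENABLED_FUNC"),
    )

# After configure(disabled={"handlebars"}):
#   registry.is_registered("handlebars")  -> True  (plugin exists)
#   registry.is_enabled("handlebars")     -> False (filtered out)
#   registry.get("handlebars")            -> None
```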
23
superset/mcp_service/chart/schemas.py
Normal file → Executable file
@@ -101,7 +101,14 @@ class ChartInfo(BaseModel):

    id: int | None = Field(None, description="Chart ID")
    slice_name: str | None = Field(None, description="Chart name")
    viz_type: str | None = Field(None, description="Visualization type")
    viz_type: str | None = Field(None, description="Visualization type (internal ID)")
    chart_type_display_name: str | None = Field(
        None,
        description=(
            "User-friendly chart type name (e.g. 'Line Chart', 'Pivot Table'). "
            "Use this field when referring to chart types — never expose viz_type."
        ),
    )
    datasource_name: str | None = Field(None, description="Datasource name")
    datasource_type: str | None = Field(None, description="Datasource type")
    url: str | None = Field(None, description="Chart explore page URL")
@@ -488,11 +495,20 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
    # Extract structured filter information
    filters_info = extract_filters_from_form_data(chart_form_data)

    _viz_type = getattr(chart, "viz_type", None)
    try:
        from superset.mcp_service.chart.registry import display_name_for_viz_type

        _display_name = display_name_for_viz_type(_viz_type) if _viz_type else None
    except Exception:
        _display_name = None

    return sanitize_chart_info_for_llm_context(
        ChartInfo(
            id=chart_id,
            slice_name=getattr(chart, "slice_name", None),
            viz_type=getattr(chart, "viz_type", None),
            viz_type=_viz_type,
            chart_type_display_name=_display_name,
            datasource_name=getattr(chart, "datasource_name", None),
            datasource_type=getattr(chart, "datasource_type", None),
            url=chart_url,
@@ -669,7 +685,6 @@ class ColumnRef(BaseModel):
        ...,
        min_length=1,
        max_length=255,
        pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
        validation_alias=AliasChoices("name", "column_name"),
    )
    label: str | None = Field(None, max_length=500)
@@ -743,7 +758,6 @@ class FilterConfig(BaseModel):
        ...,
        min_length=1,
        max_length=255,
        pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
        validation_alias=AliasChoices("column", "col"),
    )
    op: Literal[
@@ -1082,7 +1096,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
        ),
        min_length=1,
        max_length=255,
        pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
    )
    time_grain: TimeGrain | None = Field(
        None,
@@ -100,18 +100,34 @@ async def generate_chart( # noqa: C901
    - Set save_chart=True to permanently save the chart
    - LLM clients MUST display returned chart URL to users
    - Use numeric dataset ID or UUID (NOT schema.table_name format)
    - MUST include chart_type in config (either 'xy' or 'table')
    - MUST include chart_type in config (one of: 'xy', 'table', 'pie',
      'pivot_table', 'mixed_timeseries', 'handlebars', 'big_number')

    IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines
    which chart configuration schema to use. It MUST be included and MUST match the
    other fields in your configuration:

    - Use chart_type='xy' for charts with x and y axes (line, bar, area, scatter)
      Required fields: x, y
      Required fields: y (x is optional — defaults to dataset's primary datetime column)

    - Use chart_type='table' for tabular visualizations
      Required fields: columns

    - Use chart_type='pie' for pie/donut charts
      Required fields: dimension, metric

    - Use chart_type='pivot_table' for pivot table visualizations
      Required fields: rows, metrics

    - Use chart_type='mixed_timeseries' for dual-axis time-series charts
      Required fields: x, y, y_secondary

    - Use chart_type='handlebars' for custom template-based visualizations
      Required fields: handlebars_template

    - Use chart_type='big_number' for single KPI metric displays
      Required fields: metric

    Example usage for XY chart:
    ```json
    {
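To make the discriminator requirements above concrete, minimal configs for two of the newly documented chart types could look like this sketch; the column names are placeholders taken from the suggestion strings in the corresponding plugins, not real dataset columns:

```python
# Placeholder column names; shapes follow the required fields listed above.
pie_config = {
    "chart_type": "pie",
    "dimension": {"name": "category"},
    "metric": {"name": "revenue", "aggregate": "SUM"},
}

mixed_timeseries_config = {
    "chart_type": "mixed_timeseries",
    "x": {"name": "order_date"},
    "y": [{"name": "revenue", "aggregate": "SUM"}],
    "y_secondary": [{"name": "orders", "aggregate": "COUNT"}],
}
```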
@@ -28,7 +28,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations

from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.chart.ascii_charts import (
    generate_ascii_chart,
    generate_ascii_table,
@@ -1140,6 +1140,15 @@ async def _get_chart_preview_internal( # noqa: C901
        )
        chart = find_chart_by_identifier(request.identifier)

        # Eagerly refresh all attributes while the session is still
        # active. SQLAlchemy expires object attributes after any
        # commit; if a downstream operation commits before the strategy
        # classes access chart attributes, a DetachedInstanceError will
        # be raised. Calling refresh() here ensures all column values
        # are loaded into the object's __dict__ upfront.
        if chart is not None:
            db.session.refresh(chart)

        # If not found and looks like a form_data_key, try transient
        if (
            not chart
@@ -1371,6 +1380,20 @@ async def _get_chart_preview_internal( # noqa: C901

        return _sanitize_chart_preview_for_llm_context(result)

    except SQLAlchemyError as e:
        # Catch DetachedInstanceError and other SQLAlchemy errors that can
        # surface when the ORM session expires or commits mid-request.
        await ctx.error(
            "Chart preview failed due to database session error: "
            "identifier=%s, error_type=%s, error=%s"
            % (request.identifier, type(e).__name__, str(e))
        )
        logger.exception("SQLAlchemy error in get_chart_preview: %s", e)
        return ChartError(
            error="Database session error while generating chart preview. "
            "Please retry the request.",
            error_type="InternalError",
        )
    except (
        CommandException,
        SupersetException,
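The comment block above describes the failure mode the refresh() call guards against. A stripped-down, generic SQLAlchemy illustration (not Superset code; `Chart` here stands in for any mapped class) of how expired attributes on a detached instance blow up:

```python
# Generic SQLAlchemy sketch of the DetachedInstanceError scenario.
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import DetachedInstanceError

def demo(session: Session, chart_id: int) -> None:
    chart = session.get(Chart, chart_id)  # Chart: any mapped class (assumed)
    session.expire(chart)                 # simulate a commit expiring attributes
    session.expunge(chart)                # detach, as at the end of a request
    try:
        _ = chart.slice_name              # lazy reload impossible when detached
    except DetachedInstanceError:
        pass  # calling session.refresh(chart) before detaching avoids this
```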
41
superset/mcp_service/chart/tool/update_chart.py
Normal file → Executable file
@@ -195,6 +195,29 @@ def _validate_update_against_dataset(
        }
    )

    # Column existence + fuzzy-match validation
    # (mirrors generate_chart pipeline layer 2)
    from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator

    is_col_valid, col_error = DatasetValidator.validate_against_dataset(
        parsed_config, dataset.id
    )
    if not is_col_valid and col_error is not None:
        logger.warning(
            "update_chart column validation failed for chart %s: %s",
            getattr(chart, "id", None),
            col_error,
        )
        return GenerateChartResponse.model_validate(
            {
                "chart": None,
                "error": col_error.model_dump(),
                "success": False,
                "schema_version": "2.0",
                "api_version": "v1",
            }
        )

    compile_result = validate_and_compile(
        parsed_config, form_data, dataset, run_compile_check=True
    )
@@ -388,6 +411,24 @@ async def update_chart( # noqa: C901
    # config is already a typed ChartConfig | None (validated by Pydantic)
    parsed_config = request.config

    # Normalize column case to match dataset canonical names
    # (mirrors generate_chart pipeline layer 4)
    chart_datasource_id = getattr(chart, "datasource_id", None)
    if parsed_config is not None and chart_datasource_id is not None:
        from superset.mcp_service.chart.validation.dataset_validator import (
            DatasetValidator,
            NORMALIZATION_EXCEPTIONS,
        )

        try:
            parsed_config = DatasetValidator.normalize_column_names(
                parsed_config, chart.datasource_id
            )
        except NORMALIZATION_EXCEPTIONS as e:
            logger.warning(
                "Column normalization failed for chart %s: %s", chart.id, e
            )

    if not request.generate_preview:
        from superset.commands.chart.update import UpdateChartCommand
@@ -22,17 +22,11 @@ Validates that referenced columns exist in the dataset schema.

import difflib
import logging
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, TypeVar

from superset.mcp_service.chart.schemas import (
    BigNumberChartConfig,
    ChartConfig,
    ColumnRef,
    HandlebarsChartConfig,
    MixedTimeseriesChartConfig,
    PieChartConfig,
    PivotTableChartConfig,
    TableChartConfig,
    XYChartConfig,
)
from superset.mcp_service.common.error_schemas import (
    ChartGenerationError,
@@ -40,6 +34,8 @@ from superset.mcp_service.common.error_schemas import (
    DatasetContext,
)

_C = TypeVar("_C", bound=ChartConfig)

logger = logging.getLogger(__name__)

# Exceptions that can occur during column name normalization.
@@ -58,7 +54,7 @@ class DatasetValidator:

    @staticmethod
    def validate_against_dataset(
        config: Any,
        config: ChartConfig,
        dataset_id: int | str,
        dataset_context: DatasetContext | None = None,
    ) -> Tuple[bool, ChartGenerationError | None]:
@@ -260,59 +256,31 @@ class DatasetValidator:
        return None

    @staticmethod
    def _extract_column_references(config: Any) -> List[ColumnRef]: # noqa: C901
        """Extract all column references from a chart configuration.
    def _extract_column_references(
        config: ChartConfig,
    ) -> List[ColumnRef]:
        """Extract all column references from configuration via the plugin registry.

        Covers every supported ``ChartConfig`` variant so fast-path tools
        (``generate_explore_link``, ``update_chart_preview``) that only run
        Tier-1 validation still catch bad column refs in pie / pivot table /
        mixed timeseries / handlebars / big number charts — not just XY and
        table.
        Previously only handled TableChartConfig and XYChartConfig, causing
        5 of 7 chart types to silently skip column validation. Now delegates
        to the plugin for each chart type so all types are covered.
        """
        refs: List[ColumnRef] = []
        # Local import: plugins call DatasetValidator helpers from
        # normalize_column_refs().
        # A top-level import of registry in dataset_validator would make loading this
        # module implicitly trigger plugin registration, creating a circular dependency.
        from superset.mcp_service.chart.registry import get_registry

        if isinstance(config, TableChartConfig):
            refs.extend(config.columns)
        elif isinstance(config, XYChartConfig):
            if config.x is not None:
                refs.append(config.x)
            refs.extend(config.y)
            if config.group_by:
                refs.extend(config.group_by)
        elif isinstance(config, PieChartConfig):
            refs.append(config.dimension)
            refs.append(config.metric)
        elif isinstance(config, PivotTableChartConfig):
            refs.extend(config.rows)
            if config.columns:
                refs.extend(config.columns)
            refs.extend(config.metrics)
        elif isinstance(config, MixedTimeseriesChartConfig):
            refs.append(config.x)
            refs.extend(config.y)
            if config.group_by:
                refs.extend(config.group_by)
            refs.extend(config.y_secondary)
            if config.group_by_secondary:
                refs.extend(config.group_by_secondary)
        elif isinstance(config, HandlebarsChartConfig):
            if config.columns:
                refs.extend(config.columns)
            if config.groupby:
                refs.extend(config.groupby)
            if config.metrics:
                refs.extend(config.metrics)
        elif isinstance(config, BigNumberChartConfig):
            refs.append(config.metric)
            if config.temporal_column:
                refs.append(ColumnRef(name=config.temporal_column))
        chart_type = getattr(config, "chart_type", None)
        if chart_type is None:
            return []

        # Filter columns (shared by every config type that defines ``filters``).
        if filters := getattr(config, "filters", None):
            for filter_config in filters:
                refs.append(ColumnRef(name=filter_config.column))
        plugin = get_registry().get(chart_type)
        if plugin is None:
            logger.warning("No plugin registered for chart_type=%r", chart_type)
            return []

        return refs
        return plugin.extract_column_refs(config)

    @staticmethod
    def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
@@ -365,42 +333,6 @@ class DatasetValidator:
        # Return original if not found (validation should catch this case)
        return column_name

    @staticmethod
    def _normalize_xy_config(
        config_dict: Dict[str, Any], dataset_context: DatasetContext
    ) -> None:
        """Normalize column names in an XY chart config dict in place."""
        # Normalize x-axis column
        if "x" in config_dict and config_dict["x"]:
            config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name(
                config_dict["x"]["name"], dataset_context
            )

        # Normalize y-axis columns
        if "y" in config_dict and config_dict["y"]:
            for y_col in config_dict["y"]:
                y_col["name"] = DatasetValidator._get_canonical_column_name(
                    y_col["name"], dataset_context
                )

        # Normalize group_by columns
        if "group_by" in config_dict and config_dict["group_by"]:
            for gb_col in config_dict["group_by"]:
                gb_col["name"] = DatasetValidator._get_canonical_column_name(
                    gb_col["name"], dataset_context
                )

    @staticmethod
    def _normalize_table_config(
        config_dict: Dict[str, Any], dataset_context: DatasetContext
    ) -> None:
        """Normalize column names in a table chart config dict in place."""
        if "columns" in config_dict and config_dict["columns"]:
            for col in config_dict["columns"]:
                col["name"] = DatasetValidator._get_canonical_column_name(
                    col["name"], dataset_context
                )

    @staticmethod
    def _normalize_filters(
        config_dict: Dict[str, Any], dataset_context: DatasetContext
@@ -417,10 +349,10 @@ class DatasetValidator:

    @staticmethod
    def normalize_column_names(
        config: TableChartConfig | XYChartConfig,
        config: _C,
        dataset_id: int | str,
        dataset_context: DatasetContext | None = None,
    ) -> TableChartConfig | XYChartConfig:
    ) -> _C:
        """
        Normalize column names in config to match the canonical dataset column names.

@@ -429,6 +361,9 @@ class DatasetValidator:
        (e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
        so we need to ensure column names match exactly.

        Previously only XYChartConfig and TableChartConfig were normalized; now
        all 7 chart types are handled via the plugin registry.

        Args:
            config: Chart configuration with column references
            dataset_id: Dataset ID to get canonical column names from
@@ -443,22 +378,24 @@ class DatasetValidator:
        if not dataset_context:
            return config

        # Create a mutable copy of the config
        config_dict = config.model_dump()
        # Local import: plugins call DatasetValidator helpers from
        # normalize_column_refs().
        # A top-level import of registry in dataset_validator would make loading this
        # module implicitly trigger plugin registration, creating a circular dependency.
        from superset.mcp_service.chart.registry import get_registry

        # Normalize based on config type
        if isinstance(config, XYChartConfig):
            DatasetValidator._normalize_xy_config(config_dict, dataset_context)
        elif isinstance(config, TableChartConfig):
            DatasetValidator._normalize_table_config(config_dict, dataset_context)
        chart_type = getattr(config, "chart_type", None)
        if chart_type is None:
            return config

        # Normalize filter columns (common to both config types)
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        plugin = get_registry().get(chart_type)
        if plugin is None:
            logger.warning(
                "No plugin for chart_type=%r; skipping column normalization", chart_type
            )
            return config

        # Reconstruct the config with normalized names
        if isinstance(config, XYChartConfig):
            return XYChartConfig.model_validate(config_dict)
        return TableChartConfig.model_validate(config_dict)
        return plugin.normalize_column_refs(config, dataset_context)

    @staticmethod
    def _get_column_suggestions(
@@ -23,10 +23,7 @@ Validates performance, compatibility, and user experience issues.
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
from superset.mcp_service.chart.schemas import ChartConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -56,20 +53,10 @@ class RuntimeValidator:
|
||||
warnings: List[str] = []
|
||||
suggestions: List[str] = []
|
||||
|
||||
# Only check XY charts for format and cardinality issues
|
||||
if isinstance(config, XYChartConfig):
|
||||
# Format-type compatibility validation
|
||||
format_warnings = RuntimeValidator._validate_format_compatibility(config)
|
||||
if format_warnings:
|
||||
warnings.extend(format_warnings)
|
||||
|
||||
# Cardinality validation
|
||||
cardinality_warnings, cardinality_suggestions = (
|
||||
RuntimeValidator._validate_cardinality(config, dataset_id)
|
||||
)
|
||||
if cardinality_warnings:
|
||||
warnings.extend(cardinality_warnings)
|
||||
suggestions.extend(cardinality_suggestions)
|
||||
# Per-plugin runtime warnings (format, cardinality, etc.)
|
||||
plugin_warnings = RuntimeValidator._validate_plugin_runtime(config, dataset_id)
|
||||
if plugin_warnings:
|
||||
warnings.extend(plugin_warnings)
|
||||
|
||||
# Chart type appropriateness validation (for all chart types)
|
||||
type_warnings, type_suggestions = RuntimeValidator._validate_chart_type(
|
||||
@@ -98,61 +85,28 @@ class RuntimeValidator:
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _validate_format_compatibility(config: XYChartConfig) -> List[str]:
|
||||
"""Validate format-type compatibility."""
|
||||
warnings: List[str] = []
|
||||
def _validate_plugin_runtime(
|
||||
config: ChartConfig, dataset_id: int | str
|
||||
) -> List[str]:
|
||||
"""Delegate per-chart-type runtime warnings to the plugin registry.
|
||||
|
||||
Each plugin's get_runtime_warnings() method returns chart-type-specific
|
||||
warnings (e.g. format/cardinality for XY). The registry dispatch removes
|
||||
the previous isinstance(config, XYChartConfig) hardcoding.
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from .format_validator import FormatTypeValidator
|
||||
from superset.mcp_service.chart.registry import get_registry
|
||||
|
||||
is_valid, format_warnings = (
|
||||
FormatTypeValidator.validate_format_compatibility(config)
|
||||
)
|
||||
if format_warnings:
|
||||
warnings.extend(format_warnings)
|
||||
except ImportError:
|
||||
logger.warning("Format validator not available")
|
||||
except Exception as e:
|
||||
logger.warning("Format validation failed: %s", e)
|
||||
|
||||
return warnings
|
||||
|
||||
@staticmethod
|
||||
def _validate_cardinality(
|
||||
config: XYChartConfig, dataset_id: int | str
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Validate cardinality issues."""
|
||||
warnings: List[str] = []
|
||||
suggestions: List[str] = []
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from .cardinality_validator import CardinalityValidator
|
||||
|
||||
# Determine chart type for cardinality thresholds
|
||||
chart_type = config.kind if hasattr(config, "kind") else "default"
|
||||
|
||||
# Check X-axis cardinality
|
||||
if config.x is None:
|
||||
return warnings, suggestions
|
||||
is_ok, cardinality_info = CardinalityValidator.check_cardinality(
|
||||
dataset_id=dataset_id,
|
||||
x_column=config.x.name,
|
||||
chart_type=chart_type,
|
||||
group_by_column=config.group_by[0].name if config.group_by else None,
|
||||
)
|
||||
|
||||
if not is_ok and cardinality_info:
|
||||
warnings.extend(cardinality_info.get("warnings", []))
|
||||
suggestions.extend(cardinality_info.get("suggestions", []))
|
||||
|
||||
except ImportError:
|
||||
logger.warning("Cardinality validator not available")
|
||||
except Exception as e:
|
||||
logger.warning("Cardinality validation failed: %s", e)
|
||||
|
||||
return warnings, suggestions
|
||||
chart_type = getattr(config, "chart_type", None)
|
||||
if chart_type is None:
|
||||
return []
|
||||
plugin = get_registry().get(chart_type)
|
||||
if plugin is None:
|
||||
return []
|
||||
return plugin.get_runtime_warnings(config, dataset_id)
|
||||
except Exception as exc:
|
||||
logger.warning("Plugin runtime validation failed: %s", exc)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _validate_chart_type(
|
||||
|
||||
455
superset/mcp_service/chart/validation/schema_validator.py
Normal file → Executable file
@@ -147,19 +147,13 @@ class SchemaValidator:
|
||||
chart_type: str,
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Validate chart type and dispatch to type-specific pre-validation."""
|
||||
chart_type_validators = {
|
||||
"xy": SchemaValidator._pre_validate_xy_config,
|
||||
"table": SchemaValidator._pre_validate_table_config,
|
||||
"pie": SchemaValidator._pre_validate_pie_config,
|
||||
"pivot_table": SchemaValidator._pre_validate_pivot_table_config,
|
||||
"mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config,
|
||||
"handlebars": SchemaValidator._pre_validate_handlebars_config,
|
||||
"big_number": SchemaValidator._pre_validate_big_number_config,
|
||||
}
|
||||
"""Validate chart type and dispatch to plugin pre-validation."""
|
||||
from superset.mcp_service.chart.registry import get_registry
|
||||
|
||||
if not isinstance(chart_type, str) or chart_type not in chart_type_validators:
|
||||
valid_types = ", ".join(chart_type_validators.keys())
|
||||
registry = get_registry()
|
||||
|
||||
if not isinstance(chart_type, str) or not registry.is_registered(chart_type):
|
||||
valid_types = ", ".join(registry.all_types())
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_chart_type",
|
||||
message=f"Invalid chart_type: '{chart_type}'",
|
||||
@@ -178,351 +172,33 @@ class SchemaValidator:
|
||||
error_code="INVALID_CHART_TYPE",
|
||||
)
|
||||
|
||||
return chart_type_validators[chart_type](config)
|
||||
|
||||
@staticmethod
|
||||
def _pre_validate_xy_config(
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Pre-validate XY chart configuration."""
|
||||
# x is optional — defaults to dataset's main_dttm_col in map_xy_config
|
||||
if "y" not in config:
|
||||
if not registry.is_enabled(chart_type):
|
||||
valid_types = ", ".join(registry.all_types())
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_xy_fields",
|
||||
message="XY chart missing required field: 'y' (Y-axis metrics)",
|
||||
details="XY charts require Y-axis (metrics) specifications. "
|
||||
"X-axis is optional and defaults to the dataset's primary "
|
||||
"datetime column when omitted.",
|
||||
error_type="disabled_chart_type",
|
||||
message=f"Chart type '{chart_type}' is not enabled on this instance",
|
||||
details=f"Chart type '{chart_type}' is registered but has been "
|
||||
f"disabled by the operator. "
|
||||
f"Enabled chart types: {valid_types}",
|
||||
suggestions=[
|
||||
"Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] "
|
||||
"for Y-axis",
|
||||
"Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
|
||||
"'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
|
||||
f"Use one of the enabled chart types: {valid_types}",
|
||||
"Contact your administrator if you believe this is an error",
|
||||
],
|
||||
error_code="MISSING_XY_FIELDS",
|
||||
error_code="DISABLED_CHART_TYPE",
|
||||
)
|
||||
|
||||
# Validate Y is a list
|
||||
if not isinstance(config.get("y", []), list):
|
||||
plugin = registry.get(chart_type)
|
||||
if plugin is None:
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_y_format",
|
||||
message="Y-axis must be a list of metrics",
|
||||
details="The 'y' field must be an array of metric specifications",
|
||||
suggestions=[
|
||||
"Wrap Y-axis metric in array: 'y': [{'name': 'column', "
|
||||
"'aggregate': 'SUM'}]",
|
||||
"Multiple metrics supported: 'y': [metric1, metric2, ...]",
|
||||
],
|
||||
error_code="INVALID_Y_FORMAT",
|
||||
error_type="invalid_chart_type",
|
||||
message=f"Chart type '{chart_type}' has no registered plugin",
|
||||
details="Internal error: chart type is listed but has no plugin",
|
||||
suggestions=["Use a supported chart_type"],
|
||||
error_code="INVALID_CHART_TYPE",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _pre_validate_table_config(
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Pre-validate table chart configuration."""
|
||||
if "columns" not in config:
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_columns",
|
||||
message="Table chart missing required field: columns",
|
||||
details="Table charts require a 'columns' array to specify which "
|
||||
"columns to display",
|
||||
suggestions=[
|
||||
"Add 'columns' field with array of column specifications",
|
||||
"Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
|
||||
"'aggregate': 'SUM'}]",
|
||||
"Each column can have optional 'aggregate' for metrics",
|
||||
],
|
||||
error_code="MISSING_COLUMNS",
|
||||
)
|
||||
|
||||
if not isinstance(config.get("columns", []), list):
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_columns_format",
|
||||
message="Columns must be a list",
|
||||
details="The 'columns' field must be an array of column specifications",
|
||||
suggestions=[
|
||||
"Ensure columns is an array: 'columns': [...]",
|
||||
"Each column should be an object with 'name' field",
|
||||
],
|
||||
error_code="INVALID_COLUMNS_FORMAT",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _pre_validate_pie_config(
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Pre-validate pie chart configuration."""
|
||||
missing_fields = []
|
||||
|
||||
if "dimension" not in config:
|
||||
missing_fields.append("'dimension' (category column for slices)")
|
||||
if "metric" not in config:
|
||||
missing_fields.append("'metric' (value metric for slice sizes)")
|
||||
|
||||
if missing_fields:
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_pie_fields",
|
||||
message=f"Pie chart missing required "
|
||||
f"fields: {', '.join(missing_fields)}",
|
||||
details="Pie charts require a dimension (categories) and a metric "
|
||||
"(values)",
|
||||
suggestions=[
|
||||
"Add 'dimension' field: {'name': 'category_column'}",
|
||||
"Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
|
||||
"Example: {'chart_type': 'pie', 'dimension': {'name': "
|
||||
"'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
|
||||
],
|
||||
error_code="MISSING_PIE_FIELDS",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _pre_validate_handlebars_config(
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Pre-validate handlebars chart configuration."""
|
||||
if "handlebars_template" not in config:
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_handlebars_template",
|
||||
message="Handlebars chart missing required field: handlebars_template",
|
||||
details="Handlebars charts require a 'handlebars_template' string "
|
||||
"containing Handlebars HTML template markup",
|
||||
suggestions=[
|
||||
"Add 'handlebars_template' with a Handlebars HTML template",
|
||||
"Data is available as {{data}} array in the template",
|
||||
"Example: '<ul>{{#each data}}<li>{{this.name}}: "
|
||||
"{{this.value}}</li>{{/each}}</ul>'",
|
||||
],
|
||||
error_code="MISSING_HANDLEBARS_TEMPLATE",
|
||||
)
|
||||
|
||||
template = config.get("handlebars_template")
|
||||
if not isinstance(template, str) or not template.strip():
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_handlebars_template",
|
||||
message="Handlebars template must be a non-empty string",
|
||||
details="The 'handlebars_template' field must be a non-empty string "
|
||||
"containing valid Handlebars HTML template markup",
|
||||
suggestions=[
|
||||
"Ensure handlebars_template is a non-empty string",
|
||||
"Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
|
||||
],
|
||||
error_code="INVALID_HANDLEBARS_TEMPLATE",
|
||||
)
|
||||
|
||||
query_mode = config.get("query_mode", "aggregate")
|
||||
if query_mode not in ("aggregate", "raw"):
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_query_mode",
|
||||
message="Invalid query_mode for handlebars chart",
|
||||
details="query_mode must be either 'aggregate' or 'raw'",
|
||||
suggestions=[
|
||||
"Use 'aggregate' for aggregated data (default)",
|
||||
"Use 'raw' for individual rows",
|
||||
],
|
||||
error_code="INVALID_QUERY_MODE",
|
||||
)
|
||||
|
||||
if query_mode == "raw" and not config.get("columns"):
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_raw_columns",
|
||||
message="Handlebars chart in 'raw' mode requires 'columns'",
|
||||
details="When query_mode is 'raw', you must specify which columns "
|
||||
"to include in the query results",
|
||||
suggestions=[
|
||||
"Add 'columns': [{'name': 'column_name'}] for raw mode",
|
||||
"Or use query_mode='aggregate' with 'metrics' "
|
||||
"and optional 'groupby'",
|
||||
],
|
||||
error_code="MISSING_RAW_COLUMNS",
|
||||
)
|
||||
|
||||
if query_mode == "aggregate" and not config.get("metrics"):
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_aggregate_metrics",
|
||||
message="Handlebars chart in 'aggregate' mode requires 'metrics'",
|
||||
details="When query_mode is 'aggregate' (default), you must specify "
|
||||
"at least one metric with an aggregate function",
|
||||
suggestions=[
|
||||
"Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
|
||||
"Or use query_mode='raw' with 'columns' for individual rows",
|
||||
],
|
||||
error_code="MISSING_AGGREGATE_METRICS",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _pre_validate_big_number_config(
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Pre-validate big number chart configuration."""
|
||||
if "metric" not in config:
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_metric",
|
||||
message="Big Number chart missing required field: metric",
|
||||
details="Big Number charts require a 'metric' field "
|
||||
"specifying the value to display",
|
||||
suggestions=[
|
||||
"Add 'metric' with name and aggregate: "
|
||||
"{'name': 'revenue', 'aggregate': 'SUM'}",
|
||||
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
|
||||
"Example: {'chart_type': 'big_number', "
|
||||
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
|
||||
],
|
||||
error_code="MISSING_BIG_NUMBER_METRIC",
|
||||
)
|
||||
|
||||
metric = config.get("metric", {})
|
||||
if not isinstance(metric, dict):
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_metric_type",
|
||||
message="Big Number metric must be a dict with 'name' and 'aggregate'",
|
||||
details="The 'metric' field must be an object, "
|
||||
f"got {type(metric).__name__}",
|
||||
suggestions=[
|
||||
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
|
||||
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
|
||||
],
|
||||
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
|
||||
)
|
||||
if not metric.get("aggregate") and not metric.get("saved_metric"):
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_metric_aggregate",
|
||||
message="Big Number metric must include an aggregate function "
|
||||
"or reference a saved metric",
|
||||
details="The metric must have an 'aggregate' field "
|
||||
"or 'saved_metric': true",
|
||||
suggestions=[
|
||||
"Add 'aggregate' to your metric: "
|
||||
"{'name': 'col', 'aggregate': 'SUM'}",
|
||||
"Or use a saved metric: "
|
||||
"{'name': 'total_sales', 'saved_metric': true}",
|
||||
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
|
||||
],
|
||||
error_code="MISSING_BIG_NUMBER_AGGREGATE",
|
||||
)
|
||||
|
||||
show_trendline = config.get("show_trendline", False)
|
||||
temporal_column = config.get("temporal_column")
|
||||
if show_trendline and not temporal_column:
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_temporal_column",
|
||||
message="Trendline requires a temporal column",
|
||||
details="When 'show_trendline' is True, a "
|
||||
"'temporal_column' must be specified",
|
||||
suggestions=[
|
||||
"Add 'temporal_column': 'date_column_name'",
|
||||
"Or set 'show_trendline': false for number only",
|
||||
"Use get_dataset_info to find temporal columns",
|
||||
],
|
||||
error_code="MISSING_TEMPORAL_COLUMN",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _pre_validate_pivot_table_config(
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Pre-validate pivot table configuration."""
|
||||
missing_fields = []
|
||||
|
||||
if "rows" not in config:
|
||||
missing_fields.append("'rows' (row grouping columns)")
|
||||
if "metrics" not in config:
|
||||
missing_fields.append("'metrics' (aggregation metrics)")
|
||||
|
||||
if missing_fields:
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_pivot_fields",
|
||||
message=f"Pivot table missing required "
|
||||
f"fields: {', '.join(missing_fields)}",
|
||||
details="Pivot tables require row groupings and metrics",
|
||||
suggestions=[
|
||||
"Add 'rows' field: [{'name': 'category'}]",
|
||||
"Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
|
||||
"Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
|
||||
],
|
||||
error_code="MISSING_PIVOT_FIELDS",
|
||||
)
|
||||
|
||||
if not isinstance(config.get("rows", []), list):
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_rows_format",
|
||||
message="Rows must be a list of columns",
|
||||
details="The 'rows' field must be an array of column specifications",
|
||||
suggestions=[
|
||||
"Wrap row columns in array: 'rows': [{'name': 'category'}]",
|
||||
],
|
||||
error_code="INVALID_ROWS_FORMAT",
|
||||
)
|
||||
|
||||
if not isinstance(config.get("metrics", []), list):
|
||||
return False, ChartGenerationError(
|
||||
error_type="invalid_metrics_format",
|
||||
message="Metrics must be a list",
|
||||
details="The 'metrics' field must be an array of metric specifications",
|
||||
suggestions=[
|
||||
"Wrap metrics in array: 'metrics': [{'name': 'sales', "
|
||||
"'aggregate': 'SUM'}]",
|
||||
],
|
||||
error_code="INVALID_METRICS_FORMAT",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _pre_validate_mixed_timeseries_config(
|
||||
config: Dict[str, Any],
|
||||
) -> Tuple[bool, ChartGenerationError | None]:
|
||||
"""Pre-validate mixed timeseries configuration."""
|
||||
missing_fields = []
|
||||
|
||||
if "x" not in config:
|
||||
missing_fields.append("'x' (X-axis temporal column)")
|
||||
if "y" not in config:
|
||||
missing_fields.append("'y' (primary Y-axis metrics)")
|
||||
if "y_secondary" not in config:
|
||||
missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
|
||||
|
||||
if missing_fields:
|
||||
return False, ChartGenerationError(
|
||||
error_type="missing_mixed_timeseries_fields",
|
||||
message=f"Mixed timeseries chart missing required "
|
||||
f"fields: {', '.join(missing_fields)}",
|
||||
details="Mixed timeseries charts require an x-axis, primary metrics, "
|
||||
"and secondary metrics",
|
||||
suggestions=[
|
||||
"Add 'x' field: {'name': 'date_column'}",
|
||||
"Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
|
||||
"Add 'y_secondary' field: [{'name': 'orders', "
|
||||
"'aggregate': 'COUNT'}]",
|
||||
"Optional: 'primary_kind' and 'secondary_kind' for chart types",
|
||||
],
|
||||
error_code="MISSING_MIXED_TIMESERIES_FIELDS",
|
||||
)
|
||||
|
||||
for field_name in ["y", "y_secondary"]:
|
||||
if not isinstance(config.get(field_name, []), list):
|
||||
return False, ChartGenerationError(
|
||||
error_type=f"invalid_{field_name}_format",
|
||||
message=f"'{field_name}' must be a list of metrics",
|
||||
details=f"The '{field_name}' field must be an array of metric "
|
||||
"specifications",
|
||||
suggestions=[
|
||||
f"Wrap in array: '{field_name}': "
|
||||
"[{'name': 'col', 'aggregate': 'SUM'}]",
|
||||
],
|
||||
error_code=f"INVALID_{field_name.upper()}_FORMAT",
|
||||
)
|
||||
|
||||
if (error := plugin.pre_validate(config)) is not None:
|
||||
return False, error
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
@@ -537,89 +213,26 @@ class SchemaValidator:
|
||||
if err.get("type") == "union_tag_invalid" or "discriminator" in str(
|
||||
err.get("ctx", {})
|
||||
):
|
||||
# This is the generic union error - provide better message
|
||||
config = request_data.get("config", {})
|
||||
chart_type = config.get("chart_type", "unknown")
|
||||
from superset.mcp_service.chart.registry import get_registry
|
||||
|
||||
if chart_type == "xy":
|
||||
return ChartGenerationError(
|
||||
error_type="xy_validation_error",
|
||||
message="XY chart configuration validation failed",
|
||||
details="The XY chart configuration is missing required "
|
||||
"fields or has invalid structure",
|
||||
suggestions=[
|
||||
"Ensure 'x' field exists with {'name': 'column_name'}",
|
||||
"Ensure 'y' field is an array: [{'name': 'metric', "
|
||||
"'aggregate': 'SUM'}]",
|
||||
"Check that all column names are strings",
|
||||
"Verify aggregate functions are valid: SUM, COUNT, AVG, "
|
||||
"MIN, MAX",
|
||||
],
|
||||
error_code="XY_VALIDATION_ERROR",
|
||||
)
|
||||
elif chart_type == "table":
|
||||
return ChartGenerationError(
|
||||
error_type="table_validation_error",
|
||||
message="Table chart configuration validation failed",
|
||||
details="The table chart configuration is missing required "
|
||||
"fields or has invalid structure",
|
||||
suggestions=[
|
||||
"Ensure 'columns' field is an array of column "
|
||||
"specifications",
|
||||
"Each column needs {'name': 'column_name'}",
|
||||
"Optional: add 'aggregate' for metrics",
|
||||
"Example: 'columns': [{'name': 'product'}, {'name': "
|
||||
"'sales', 'aggregate': 'SUM'}]",
|
||||
],
|
||||
error_code="TABLE_VALIDATION_ERROR",
|
||||
)
|
||||
elif chart_type == "handlebars":
|
||||
return ChartGenerationError(
|
||||
error_type="handlebars_validation_error",
|
||||
message="Handlebars chart configuration validation failed",
|
||||
details="The handlebars chart configuration is missing "
|
||||
"required fields or has invalid structure",
|
||||
suggestions=[
|
||||
"Ensure 'handlebars_template' is a non-empty string",
|
||||
"For aggregate mode: add 'metrics' with aggregate "
|
||||
"functions",
|
||||
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
|
||||
"Example: {'chart_type': 'handlebars', "
|
||||
"'handlebars_template': '<ul>{{#each data}}<li>"
|
||||
"{{this.name}}</li>{{/each}}</ul>', "
|
||||
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
|
||||
],
|
||||
error_code="HANDLEBARS_VALIDATION_ERROR",
|
||||
)
|
||||
elif chart_type == "big_number":
|
||||
return ChartGenerationError(
|
||||
error_type="big_number_validation_error",
|
||||
message="Big Number chart configuration validation failed",
|
||||
details="The Big Number chart configuration is "
|
||||
"missing required fields or has invalid "
|
||||
"structure",
|
||||
suggestions=[
|
||||
"Ensure 'metric' field has 'name' and 'aggregate'",
|
||||
"Example: 'metric': {'name': 'revenue', "
|
||||
"'aggregate': 'SUM'}",
|
||||
"For trendline: add 'show_trendline': true "
|
||||
"and 'temporal_column': 'date_col'",
|
||||
"Without trendline: just provide the metric",
|
||||
],
|
||||
error_code="BIG_NUMBER_VALIDATION_ERROR",
|
||||
)
|
||||
chart_type = request_data.get("config", {}).get("chart_type", "")
|
||||
plugin = get_registry().get(chart_type)
|
||||
if plugin is not None:
|
||||
hint = plugin.schema_error_hint()
|
||||
if hint is not None:
|
||||
return hint
|
||||
|
||||
# Default enhanced error
|
||||
error_details = []
|
||||
for err in errors[:3]: # Show first 3 errors
|
||||
loc = " -> ".join(str(location) for location in err.get("loc", []))
|
||||
msg = err.get("msg", "Validation failed")
|
||||
error_details.append(f"{loc}: {msg}")
|
||||
error_details.append(f"{loc}: {msg}" if loc else msg)
|
||||
|
||||
return ChartGenerationError(
|
||||
error_type="validation_error",
|
||||
message="Chart configuration validation failed",
|
||||
details="; ".join(error_details),
|
||||
details="; ".join(error_details) or "Invalid chart configuration structure",
|
||||
suggestions=[
|
||||
"Check that all required fields are present",
|
||||
"Ensure field types match the schema",
|
||||
|
||||
@@ -81,6 +81,17 @@ try:
|
||||
mcp_config = get_mcp_config(_mcp_app.config)
|
||||
_mcp_app.config.update(mcp_config)
|
||||
|
||||
# Re-configure chart registry so MCP-specific overrides (e.g.
|
||||
# MCP_DISABLED_CHART_PLUGINS set by the operator) take effect after
|
||||
# the MCP config overlay. SupersetAppInitializer.configure_mcp_chart_registry()
|
||||
# ran earlier with pre-overlay values; this call corrects them.
|
||||
from superset.mcp_service.chart import registry as _chart_registry
|
||||
|
||||
_chart_registry.configure(
|
||||
disabled=_mcp_app.config.get("MCP_DISABLED_CHART_PLUGINS"),
|
||||
enabled_func=_mcp_app.config.get("MCP_CHART_PLUGIN_ENABLED_FUNC"),
|
||||
)
|
||||
|
||||
with _mcp_app.app_context():
|
||||
from superset.core.mcp.core_mcp_injection import (
|
||||
initialize_core_mcp_dependencies,
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from flask import Flask
|
||||
@@ -56,6 +57,46 @@ MCP_DEBUG = False
|
||||
# against the FAB security_manager before execution.
|
||||
MCP_RBAC_ENABLED = True

# =============================================================================
# MCP Chart Plugin Filtering
# =============================================================================
#
# Overview:
# ---------
# These two settings let operators enable/disable individual chart type plugins
# at runtime without a code deploy.
#
# Use cases:
# - Emergency kill switch: add "handlebars" to MCP_DISABLED_CHART_PLUGINS and
#   restart to immediately hide it from all callers.
# - Dynamic per-request control (A/B test, gradual rollout): set
#   MCP_CHART_PLUGIN_ENABLED_FUNC to an in-process predicate that can vary
#   by user, request header, or any other context available at call time.
#
# Priority:
#   MCP_CHART_PLUGIN_ENABLED_FUNC takes precedence over MCP_DISABLED_CHART_PLUGINS.
#   When the callable is set, the deny-list is ignored entirely.
#
# MCP_CHART_PLUGIN_ENABLED_FUNC contract:
# - Called as enabled_func(chart_type: str) -> bool for every registry lookup.
# - Must be cheap and in-process: consult already-loaded feature flags or
#   request-local context (e.g. Flask g). Do NOT perform network I/O per call.
# - On exception, the registry fails CLOSED (plugin hidden) and logs a warning.
# - Example (Harness / Split via pre-fetched flags in g):
#     from flask import g
#     def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
#         flags = getattr(g, "feature_flags", {})
#         return flags.get(f"mcp_chart_{chart_type}", True)
# =============================================================================

# Chart types in this set are hidden from all registry lookups.
# Use frozenset to avoid accidental mutation.
MCP_DISABLED_CHART_PLUGINS: frozenset[str] = frozenset()

# Dynamic per-call predicate. When set, overrides MCP_DISABLED_CHART_PLUGINS.
# Signature: (chart_type: str) -> bool
MCP_CHART_PLUGIN_ENABLED_FUNC: Callable[[str], bool] | None = None

# MCP JWT Debug Errors - controls server-side JWT debug logging.
# When False (default), uses the default JWTVerifier with minimal logging.
# When True, uses DetailedJWTVerifier with tiered logging:
@@ -402,6 +443,8 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
        "MCP_SERVICE_PORT": MCP_SERVICE_PORT,
        "MCP_DEBUG": MCP_DEBUG,
        "MCP_RBAC_ENABLED": MCP_RBAC_ENABLED,
        "MCP_DISABLED_CHART_PLUGINS": MCP_DISABLED_CHART_PLUGINS,
        "MCP_CHART_PLUGIN_ENABLED_FUNC": MCP_CHART_PLUGIN_ENABLED_FUNC,
        **MCP_SESSION_CONFIG,
        **MCP_CSRF_CONFIG,
    }
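# Illustrative sketch (not part of the change above): how an operator might
# wire these settings up in their own superset_config.py override. The setting
# names come from the diff; the "mcp_chart_<type>" flag key mirrors the example
# in the comments and is otherwise hypothetical.
from flask import g

# Static kill switch: hide the handlebars plugin from every registry lookup.
MCP_DISABLED_CHART_PLUGINS = frozenset({"handlebars"})

# Dynamic predicate: when set, it takes precedence and the deny-list above is
# ignored. It must stay cheap and in-process (no network I/O per call); on
# exception the registry fails closed and the plugin is hidden.
def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
    flags = getattr(g, "feature_flags", {})  # flags pre-fetched earlier in the request
    return flags.get(f"mcp_chart_{chart_type}", True)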
|
||||
@@ -41,6 +41,12 @@ from superset.mcp_service.constants import (
|
||||
DEFAULT_TOKEN_LIMIT,
|
||||
DEFAULT_WARN_THRESHOLD_PCT,
|
||||
)
|
||||
from superset.mcp_service.utils.token_utils import (
|
||||
estimate_response_tokens,
|
||||
format_size_limit_error,
|
||||
INFO_TOOLS,
|
||||
truncate_oversized_response,
|
||||
)
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1104,11 +1110,6 @@ class ResponseSizeGuardMiddleware(Middleware):
|
||||
``content[0].text`` as a JSON string. We parse that string, run the
|
||||
truncation phases on the resulting dict, then re-wrap the result.
|
||||
"""
|
||||
from superset.mcp_service.utils.token_utils import (
|
||||
estimate_response_tokens,
|
||||
truncate_oversized_response,
|
||||
)
|
||||
|
||||
# Unwrap ToolResult so truncation operates on the real payload
|
||||
extracted = self._extract_payload_from_tool_result(response)
|
||||
if extracted is not None:
|
||||
@@ -1191,12 +1192,6 @@ class ResponseSizeGuardMiddleware(Middleware):
|
||||
# Execute the tool
|
||||
response = await call_next(context)
|
||||
|
||||
# Estimate response token count (guard against huge responses causing OOM)
|
||||
from superset.mcp_service.utils.token_utils import (
|
||||
estimate_response_tokens,
|
||||
format_size_limit_error,
|
||||
)
|
||||
|
||||
# When the response is a ToolResult, estimate tokens on the actual
|
||||
# payload inside content[0].text rather than on the ToolResult
|
||||
# wrapper (which would double-serialize the JSON string).
|
||||
@@ -1233,8 +1228,6 @@ class ResponseSizeGuardMiddleware(Middleware):
|
||||
params = getattr(context.message, "params", {}) or {}
|
||||
|
||||
# For info tools, try dynamic truncation before blocking
|
||||
from superset.mcp_service.utils.token_utils import INFO_TOOLS
|
||||
|
||||
if tool_name in INFO_TOOLS:
|
||||
truncated = self._try_truncate_info_response(
|
||||
tool_name, response, estimated_tokens
|
||||
|
||||
@@ -21,6 +21,26 @@ Token counting and response size utilities for MCP service.
This module provides utilities to estimate token counts and generate smart
suggestions when responses exceed configured limits. This prevents large
responses from overwhelming LLM clients like Claude Desktop.

Token counting strategy:

1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is
   installed (it is shipped as part of the ``fastmcp`` extra). This is a
   real BPE tokenizer with a vocabulary similar to Claude's; for
   English and JSON-heavy MCP payloads it tracks Claude's tokenizer
   within roughly ±10%, which is far more accurate than the legacy
   character heuristic.
2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not
   importable. The fallback uses a slightly more conservative ratio than
   before (3.0 chars/token instead of 3.5) so that JSON-heavy responses
   are not under-counted, which previously let oversized payloads slip
   past the response-size guard.

The exact-Claude tokenizer is only available via Anthropic's network
``count_tokens`` API; calling it from a synchronous middleware on every
tool result is too slow and adds an external dependency on every
response. ``tiktoken`` is the closest approximation we can ship without
that risk.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -36,18 +56,63 @@ logger = logging.getLogger(__name__)
|
||||
# Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
|
||||
ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]
|
||||
|
||||
# Approximate characters per token for estimation
|
||||
# Claude tokenizer averages ~4 chars per token for English text
|
||||
# JSON tends to be more verbose, so we use a slightly lower ratio
|
||||
CHARS_PER_TOKEN = 3.5
|
||||
# Fallback character-to-token ratio used when tiktoken is unavailable.
|
||||
# 3.0 is conservative for JSON content (the previous 3.5 under-counted
|
||||
# JSON-heavy payloads relative to Claude's actual tokenizer, which let
|
||||
# oversized responses slip past the response-size guard).
|
||||
CHARS_PER_TOKEN = 3.0
|
||||
|
||||
# Encoding used when tiktoken is available. cl100k_base is OpenAI's
|
||||
# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to
|
||||
# Claude's and tracks Claude's token counts within roughly ±10% for
|
||||
# English and JSON-heavy MCP responses.
|
||||
_TIKTOKEN_ENCODING_NAME = "cl100k_base"
|
||||
|
||||
|
||||
def _load_tiktoken_encoding() -> Any:
|
||||
"""Return a tiktoken encoding instance, or None if tiktoken is unavailable.
|
||||
|
||||
Imported lazily so the module can be used in environments without
|
||||
tiktoken installed. The encoding is small (~1 MB) so we cache it on
|
||||
first use.
|
||||
"""
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
logger.info(
|
||||
"tiktoken not installed; falling back to char-based token "
|
||||
"estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra "
|
||||
"for accurate counts.",
|
||||
CHARS_PER_TOKEN,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME)
|
||||
except (KeyError, ValueError) as exc:
|
||||
# tiktoken installed but the requested encoding is missing — this
|
||||
# only happens on partial installs. Treat as no tokenizer rather
|
||||
# than crashing on every tool call.
|
||||
logger.warning(
|
||||
"tiktoken encoding '%s' unavailable: %s; falling back to "
|
||||
"char-based token estimation",
|
||||
_TIKTOKEN_ENCODING_NAME,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Cached encoding instance (None if tiktoken not importable).
|
||||
_ENCODING = _load_tiktoken_encoding()
|
||||
|
||||
|
||||
def estimate_token_count(text: str | bytes) -> int:
|
||||
"""
|
||||
Estimate the token count for a given text.
|
||||
|
||||
Uses a character-based heuristic since we don't have direct access to
|
||||
the actual tokenizer. This is conservative to avoid underestimating.
|
||||
Uses tiktoken's ``cl100k_base`` encoding when available for
|
||||
Claude-aligned accuracy (within ~10%), falling back to a
|
||||
character-based heuristic otherwise.
|
||||
|
||||
Args:
|
||||
text: The text to estimate tokens for (string or bytes)
|
||||
@@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int:
|
||||
if isinstance(text, bytes):
|
||||
text = text.decode("utf-8", errors="replace")
|
||||
|
||||
# Simple heuristic: ~3.5 characters per token for JSON/code
|
||||
text_length = len(text)
|
||||
if text_length == 0:
|
||||
if not text:
|
||||
return 0
|
||||
return max(1, int(text_length / CHARS_PER_TOKEN))
|
||||
|
||||
if _ENCODING is not None:
|
||||
try:
|
||||
return len(_ENCODING.encode(text))
|
||||
except (ValueError, UnicodeError) as exc:
|
||||
# Defensive: if tiktoken chokes on a specific input, fall
|
||||
# back to the char heuristic for this call rather than
|
||||
# raising — the response size guard must never fail-open.
|
||||
logger.warning("tiktoken encode failed (%s); using fallback", exc)
|
||||
|
||||
return max(1, int(len(text) / CHARS_PER_TOKEN))
|
||||
|
||||
|
||||
def estimate_response_tokens(response: ToolResponse) -> int:
|
||||
|
||||
@@ -19,9 +19,7 @@
|
||||
OpenSearch SQL dialect.
|
||||
|
||||
OpenSearch SQL is syntactically close to MySQL but accepts both backticks and
|
||||
double-quotes as identifier delimiters. Treating ``"`` as an identifier (rather
|
||||
than a string delimiter, as MySQL does) is what keeps mixed-case column names
|
||||
from being emitted as string literals after a SQLGlot round-trip.
|
||||
double-quotes as identifier delimiters.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -31,4 +29,4 @@ from sqlglot.dialects.mysql import MySQL
|
||||
|
||||
class OpenSearch(MySQL):
|
||||
class Tokenizer(MySQL.Tokenizer):
|
||||
IDENTIFIERS = ['"', "`"]
|
||||
IDENTIFIERS = ["`", '"']
|
||||
|
||||
@@ -45,6 +45,13 @@
|
||||
color: #000;
|
||||
}
|
||||
{% endif %}
|
||||
{% if standalone_mode %}
|
||||
/* Keep body sized so screenshot waits don't see it as hidden before React mounts. */
|
||||
html, body.standalone {
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
}
|
||||
{% endif %}
|
||||
</style>
|
||||
|
||||
{% if dark_theme_bg and entry != 'embedded' %}
|
||||
|
||||
@@ -68,3 +68,50 @@ def test_spa_template_includes_css_bundles():
|
||||
"spa.html must call css_bundle for the page entry to load "
|
||||
"entry-specific extracted CSS in production builds"
|
||||
)
|
||||
|
||||
|
||||
def test_spa_template_standalone_body_has_min_height():
|
||||
"""Standalone body must be measurable so screenshot waits don't time out."""
|
||||
from jinja2 import DictLoader, Environment
|
||||
|
||||
template_path = join(SUPERSET_DIR, "templates", "superset", "spa.html")
|
||||
with open(template_path) as f:
|
||||
template_content = f.read()
|
||||
|
||||
env = Environment( # noqa: S701
|
||||
loader=DictLoader(
|
||||
{
|
||||
"spa.html": template_content,
|
||||
# Stub out includes/imports that are not relevant for this test.
|
||||
"appbuilder/general/lib.html": "",
|
||||
"superset/partials/asset_bundle.html": (
|
||||
"{% macro css_bundle(prefix, entry) %}{% endmacro %}"
|
||||
"{% macro js_bundle(prefix, entry) %}{% endmacro %}"
|
||||
),
|
||||
"superset/macros.html": ("{% macro get_nonce() %}{% endmacro %}"),
|
||||
"tail_js_custom_extra.html": "",
|
||||
"head_custom_extra.html": "",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
appbuilder = Mock()
|
||||
appbuilder.app.config = {"FAVICONS": []}
|
||||
|
||||
def render(standalone_mode: bool) -> str:
|
||||
return env.get_template("spa.html").render(
|
||||
appbuilder=appbuilder,
|
||||
assets_prefix="",
|
||||
bootstrap_data="{}",
|
||||
entry="spa",
|
||||
standalone_mode=standalone_mode,
|
||||
theme_tokens={},
|
||||
spinner_svg=None,
|
||||
)
|
||||
|
||||
standalone_html = render(standalone_mode=True)
|
||||
assert "body.standalone" in standalone_html
|
||||
assert "min-height: 100vh" in standalone_html
|
||||
|
||||
non_standalone_html = render(standalone_mode=False)
|
||||
assert "body.standalone" not in non_standalone_html
|
||||
|
||||
@@ -90,7 +90,7 @@ class TestBigNumberChartConfig:
|
||||
"chart_type": "big_number",
|
||||
"metric": {"name": "total_sales", "saved_metric": True},
|
||||
}
|
||||
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
|
||||
is_valid, error = SchemaValidator._pre_validate_chart_type("big_number", data)
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
|
||||
143
tests/unit_tests/mcp_service/chart/test_registry.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for the chart type plugin registry."""
|
||||
|
||||
import pytest
|
||||
|
||||
import superset.mcp_service.chart.registry as registry_module
|
||||
from superset.mcp_service.chart.plugin import BaseChartPlugin
|
||||
from superset.mcp_service.chart.registry import (
|
||||
_RegistryProxy,
|
||||
all_types,
|
||||
display_name_for_viz_type,
|
||||
get,
|
||||
get_registry,
|
||||
is_registered,
|
||||
register,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolated_registry(monkeypatch):
|
||||
"""Run each test against a clean registry without touching the real one."""
|
||||
monkeypatch.setattr(registry_module, "_REGISTRY", {})
|
||||
monkeypatch.setattr(registry_module, "_plugins_loaded", True)
|
||||
|
||||
|
||||
class _FakePlugin(BaseChartPlugin):
|
||||
chart_type = "fake"
|
||||
display_name = "Fake Chart"
|
||||
native_viz_types = {"fake_viz": "Fake Viz"}
|
||||
|
||||
|
||||
class _AnotherPlugin(BaseChartPlugin):
|
||||
chart_type = "another"
|
||||
display_name = "Another Chart"
|
||||
native_viz_types = {"another_viz": "Another Viz"}
|
||||
|
||||
|
||||
def test_register_adds_plugin():
|
||||
plugin = _FakePlugin()
|
||||
register(plugin)
|
||||
assert get("fake") is plugin
|
||||
|
||||
|
||||
def test_get_returns_none_for_unknown():
|
||||
assert get("nonexistent") is None
|
||||
|
||||
|
||||
def test_all_types_returns_registered_keys():
|
||||
register(_FakePlugin())
|
||||
register(_AnotherPlugin())
|
||||
types = all_types()
|
||||
assert "fake" in types
|
||||
assert "another" in types
|
||||
|
||||
|
||||
def test_all_types_insertion_order():
|
||||
register(_FakePlugin())
|
||||
register(_AnotherPlugin())
|
||||
types = all_types()
|
||||
assert types.index("fake") < types.index("another")
|
||||
|
||||
|
||||
def test_is_registered_true_for_known():
|
||||
register(_FakePlugin())
|
||||
assert is_registered("fake") is True
|
||||
|
||||
|
||||
def test_is_registered_false_for_unknown():
|
||||
assert is_registered("nonexistent") is False
|
||||
|
||||
|
||||
def test_register_warns_on_duplicate(caplog):
|
||||
register(_FakePlugin())
|
||||
with caplog.at_level("WARNING"):
|
||||
register(_FakePlugin())
|
||||
assert "Overwriting" in caplog.text
|
||||
|
||||
|
||||
def test_register_raises_for_empty_chart_type():
|
||||
class _BadPlugin(BaseChartPlugin):
|
||||
chart_type = ""
|
||||
|
||||
with pytest.raises(ValueError, match="non-empty chart_type"):
|
||||
register(_BadPlugin())
|
||||
|
||||
|
||||
def test_display_name_for_viz_type_found():
|
||||
register(_FakePlugin())
|
||||
assert display_name_for_viz_type("fake_viz") == "Fake Viz"
|
||||
|
||||
|
||||
def test_display_name_for_viz_type_not_found():
|
||||
register(_FakePlugin())
|
||||
assert display_name_for_viz_type("unknown_viz") is None
|
||||
|
||||
|
||||
def test_display_name_searches_all_plugins():
|
||||
register(_FakePlugin())
|
||||
register(_AnotherPlugin())
|
||||
assert display_name_for_viz_type("another_viz") == "Another Viz"
|
||||
|
||||
|
||||
def test_get_registry_returns_proxy():
|
||||
assert isinstance(get_registry(), _RegistryProxy)
|
||||
|
||||
|
||||
def test_registry_proxy_get():
|
||||
plugin = _FakePlugin()
|
||||
register(plugin)
|
||||
assert get_registry().get("fake") is plugin
|
||||
|
||||
|
||||
def test_registry_proxy_all_types():
|
||||
register(_FakePlugin())
|
||||
assert "fake" in get_registry().all_types()
|
||||
|
||||
|
||||
def test_registry_proxy_is_registered():
|
||||
register(_FakePlugin())
|
||||
assert get_registry().is_registered("fake") is True
|
||||
assert get_registry().is_registered("missing") is False
|
||||
|
||||
|
||||
def test_registry_proxy_display_name_for_viz_type():
|
||||
register(_FakePlugin())
|
||||
assert get_registry().display_name_for_viz_type("fake_viz") == "Fake Viz"
|
||||
assert get_registry().display_name_for_viz_type("unknown") is None
|
||||
222
tests/unit_tests/mcp_service/chart/test_registry_filters.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for registry plugin filtering (configure / is_enabled / get / all_types)."""
|
||||
|
||||
import pytest
|
||||
|
||||
import superset.mcp_service.chart.registry as registry_module
|
||||
from superset.mcp_service.chart.plugin import BaseChartPlugin
|
||||
from superset.mcp_service.chart.registry import (
|
||||
_PluginFilterConfig,
|
||||
all_types,
|
||||
configure,
|
||||
display_name_for_viz_type,
|
||||
get,
|
||||
is_enabled,
|
||||
is_registered,
|
||||
register,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolated_registry(monkeypatch):
|
||||
"""Isolated registry with two known plugins and a clean filter for each test."""
|
||||
monkeypatch.setattr(registry_module, "_REGISTRY", {})
|
||||
monkeypatch.setattr(registry_module, "_plugins_loaded", True)
|
||||
monkeypatch.setattr(registry_module, "_filter_config", _PluginFilterConfig())
|
||||
register(_AlphaPlugin())
|
||||
register(_BetaPlugin())
|
||||
|
||||
|
||||
class _AlphaPlugin(BaseChartPlugin):
|
||||
chart_type = "alpha"
|
||||
display_name = "Alpha Chart"
|
||||
native_viz_types = {"alpha_viz": "Alpha Viz"}
|
||||
|
||||
|
||||
class _BetaPlugin(BaseChartPlugin):
|
||||
chart_type = "beta"
|
||||
display_name = "Beta Chart"
|
||||
native_viz_types = {"beta_viz": "Beta Viz"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Static deny-list tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_returns_plugin_when_enabled():
|
||||
assert get("alpha") is not None
|
||||
assert get("beta") is not None
|
||||
|
||||
|
||||
def test_get_returns_none_for_disabled_plugin():
|
||||
configure(disabled={"alpha"})
|
||||
assert get("alpha") is None
|
||||
|
||||
|
||||
def test_get_still_returns_other_plugins_when_one_is_disabled():
|
||||
configure(disabled={"alpha"})
|
||||
assert get("beta") is not None
|
||||
|
||||
|
||||
def test_all_types_excludes_disabled():
|
||||
configure(disabled={"alpha"})
|
||||
types = all_types()
|
||||
assert "alpha" not in types
|
||||
assert "beta" in types
|
||||
|
||||
|
||||
def test_all_types_empty_when_all_disabled():
|
||||
configure(disabled={"alpha", "beta"})
|
||||
assert all_types() == []
|
||||
|
||||
|
||||
def test_is_registered_ignores_deny_list():
|
||||
configure(disabled={"alpha"})
|
||||
assert is_registered("alpha") is True
|
||||
|
||||
|
||||
def test_is_enabled_returns_false_for_disabled():
|
||||
configure(disabled={"alpha"})
|
||||
assert is_enabled("alpha") is False
|
||||
|
||||
|
||||
def test_is_enabled_returns_true_when_not_disabled():
|
||||
configure(disabled={"alpha"})
|
||||
assert is_enabled("beta") is True
|
||||
|
||||
|
||||
def test_is_enabled_returns_false_for_unknown():
|
||||
assert is_enabled("nonexistent") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# configure() accepts different iterable shapes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_configure_accepts_list():
|
||||
configure(disabled=["alpha"])
|
||||
assert get("alpha") is None
|
||||
|
||||
|
||||
def test_configure_accepts_tuple():
|
||||
configure(disabled=("alpha",))
|
||||
assert get("alpha") is None
|
||||
|
||||
|
||||
def test_configure_accepts_frozenset():
|
||||
configure(disabled=frozenset({"alpha"}))
|
||||
assert get("alpha") is None
|
||||
|
||||
|
||||
def test_configure_accepts_none_disabled():
|
||||
configure(disabled=None)
|
||||
assert get("alpha") is not None
|
||||
|
||||
|
||||
def test_configure_rejects_noncallable_enabled_func():
|
||||
with pytest.raises(TypeError):
|
||||
configure(enabled_func="not_a_callable")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dynamic callable hook tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_enabled_func_overrides_deny_list():
|
||||
# alpha is in deny-list but callable says True → should be visible
|
||||
configure(disabled={"alpha"}, enabled_func=lambda ct: ct == "alpha")
|
||||
assert get("alpha") is not None
|
||||
|
||||
|
||||
def test_enabled_func_can_disable_plugin():
|
||||
configure(enabled_func=lambda ct: ct != "beta")
|
||||
assert get("beta") is None
|
||||
assert get("alpha") is not None
|
||||
|
||||
|
||||
def test_enabled_func_called_per_lookup():
|
||||
calls = []
|
||||
|
||||
def hook(ct: str) -> bool:
|
||||
calls.append(ct)
|
||||
return True
|
||||
|
||||
configure(enabled_func=hook)
|
||||
get("alpha")
|
||||
get("alpha")
|
||||
assert calls.count("alpha") == 2
|
||||
|
||||
|
||||
def test_enabled_func_exception_fails_closed(caplog):
|
||||
import logging
|
||||
|
||||
def bad_hook(ct: str) -> bool:
|
||||
raise RuntimeError("Harness down")
|
||||
|
||||
configure(enabled_func=bad_hook)
|
||||
with caplog.at_level(logging.WARNING, logger="superset.mcp_service.chart.registry"):
|
||||
result = get("alpha")
|
||||
|
||||
assert result is None # fail closed
|
||||
assert "failing closed" in caplog.text.lower() or "alpha" in caplog.text
|
||||
|
||||
|
||||
def test_enabled_func_all_types_respects_hook():
|
||||
configure(enabled_func=lambda ct: ct == "alpha")
|
||||
assert all_types() == ["alpha"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# display_name_for_viz_type is NOT filtered
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_display_name_unaffected_by_deny_list():
|
||||
configure(disabled={"alpha"})
|
||||
# Even though alpha is disabled, its viz_type should still resolve
|
||||
assert display_name_for_viz_type("alpha_viz") == "Alpha Viz"
|
||||
|
||||
|
||||
def test_display_name_unaffected_by_callable():
|
||||
configure(enabled_func=lambda ct: False)
|
||||
assert display_name_for_viz_type("beta_viz") == "Beta Viz"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# configure() atomicity: replacing config is visible to next lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reconfigure_replaces_previous_filter():
|
||||
configure(disabled={"alpha"})
|
||||
assert get("alpha") is None
|
||||
configure(disabled=set())
|
||||
assert get("alpha") is not None
|
||||
|
||||
|
||||
def test_reconfigure_with_func_then_none_falls_back_to_deny_list():
|
||||
configure(enabled_func=lambda ct: False)
|
||||
assert get("alpha") is None
|
||||
|
||||
configure(disabled={"beta"}, enabled_func=None)
|
||||
assert get("alpha") is not None
|
||||
assert get("beta") is None
|
||||
@@ -595,3 +595,191 @@ Market Share
|
||||
"""
|
||||
|
||||
# These demonstrate the expected ASCII formats for different chart types
|
||||
|
||||
|
||||
class TestDetachedInstanceError:
    """Tests that DetachedInstanceError is handled gracefully.

    When the SQLAlchemy session commits mid-request, ORM objects expire and
    become detached. Accessing lazy attributes on a detached Slice raises
    DetachedInstanceError. The tool must:
    1. Call db.session.refresh() immediately after loading the chart so all
       column values are loaded upfront before any downstream operation.
    2. Catch SQLAlchemyError (the base class) and return a ChartError
       instead of propagating the exception.
    """

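# Illustrative sketch (not part of the test file above): the load-refresh-catch
# pattern these tests pin down. The imported names mirror what the tests patch
# on superset.mcp_service.chart.tool.get_chart_preview; the simplified body,
# the generate() call signature, and the return value are assumptions, not the
# tool's real implementation.
from sqlalchemy.exc import SQLAlchemyError

from superset import db
from superset.mcp_service.chart.tool.get_chart_preview import (
    PreviewFormatGenerator,
    find_chart_by_identifier,
)

def preview_chart_safely(identifier: int):
    chart = find_chart_by_identifier(identifier)
    db.session.refresh(chart)  # load every column now, before any mid-request commit
    try:
        return PreviewFormatGenerator.generate(chart)
    except SQLAlchemyError:  # DetachedInstanceError is a subclass of SQLAlchemyError
        return None  # the real tool returns a ChartError payload instead of raising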
@pytest.mark.asyncio
|
||||
async def test_session_refresh_called_after_chart_load(self):
|
||||
"""db.session.refresh() is invoked right after find_chart_by_identifier."""
|
||||
import importlib
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.mcp_service.chart.schemas import URLPreview
|
||||
from superset.utils import json
|
||||
|
||||
get_chart_preview_module = importlib.import_module(
|
||||
"superset.mcp_service.chart.tool.get_chart_preview"
|
||||
)
|
||||
|
||||
mock_chart = MagicMock()
|
||||
mock_chart.id = 42
|
||||
mock_chart.slice_name = "Sales Chart"
|
||||
mock_chart.viz_type = "table"
|
||||
mock_chart.datasource_id = 1
|
||||
mock_chart.datasource_type = "table"
|
||||
mock_chart.params = "{}"
|
||||
|
||||
refresh_calls: list[object] = []
|
||||
|
||||
def _fake_refresh(obj: object) -> None:
|
||||
refresh_calls.append(obj)
|
||||
|
||||
url_preview = URLPreview(
|
||||
preview_url="http://localhost/explore/?slice_id=42",
|
||||
width=800,
|
||||
height=600,
|
||||
)

        with (
            patch.object(
                get_chart_preview_module,
                "find_chart_by_identifier",
                return_value=mock_chart,
            ),
            patch.object(
                get_chart_preview_module.db,
                "session",
                **{"refresh.side_effect": _fake_refresh},
            ),
            patch.object(
                get_chart_preview_module,
                "validate_chart_dataset",
                return_value=MagicMock(is_valid=True, warnings=[]),
            ),
            patch.object(
                get_chart_preview_module.event_logger,
                "log_context",
                return_value=nullcontext(),
            ),
            # Return a real URLPreview so Pydantic model validation succeeds
            patch.object(
                get_chart_preview_module.PreviewFormatGenerator,
                "generate",
                return_value=url_preview,
            ),
            patch(
                "superset.mcp_service.utils.url_utils.get_superset_base_url",
                return_value="http://localhost",
            ),
        ):
            from fastmcp import Client

            from superset.mcp_service.app import mcp
            from superset.mcp_service.chart.schemas import GetChartPreviewRequest

            with patch("superset.mcp_service.auth.get_user_from_request") as mu:
                mu.return_value = MagicMock(id=1, username="admin")
                with patch(
                    "superset.mcp_service.auth.check_tool_permission", return_value=True
                ):
                    async with Client(mcp) as client:
                        response = await client.call_tool(
                            "get_chart_preview",
                            {
                                "request": GetChartPreviewRequest(
                                    identifier=42, format="url"
                                ).model_dump()
                            },
                        )

        data = json.loads(response.content[0].text)
        # The tool should succeed — not return a ChartError
        assert "error_type" not in data, (
            f"Expected ChartPreview but got ChartError: {data.get('error')}"
        )
        assert data.get("chart_id") == 42

        assert len(refresh_calls) == 1, (
            "db.session.refresh() should be called once after loading the chart"
        )
        assert refresh_calls[0] is mock_chart

    @pytest.mark.asyncio
    async def test_detached_instance_error_returns_chart_error(self):
        """DetachedInstanceError during preview generation returns ChartError."""
        import importlib
        from contextlib import nullcontext
        from unittest.mock import MagicMock, patch

        from sqlalchemy.orm.exc import DetachedInstanceError

        get_chart_preview_module = importlib.import_module(
            "superset.mcp_service.chart.tool.get_chart_preview"
        )

        mock_chart = MagicMock()
        mock_chart.id = 7
        mock_chart.slice_name = "Broken Chart"
        mock_chart.viz_type = "bar"
        mock_chart.datasource_id = 3
        mock_chart.datasource_type = "table"
        mock_chart.params = "{}"

        with (
            patch.object(
                get_chart_preview_module,
                "find_chart_by_identifier",
                return_value=mock_chart,
            ),
            patch.object(
                get_chart_preview_module.db,
                "session",
                **{"refresh.return_value": None},
            ),
            patch.object(
                get_chart_preview_module,
                "validate_chart_dataset",
                return_value=MagicMock(is_valid=True, warnings=[]),
            ),
            patch.object(
                get_chart_preview_module.event_logger,
                "log_context",
                return_value=nullcontext(),
            ),
            # Simulate the session expiring inside the strategy
            patch.object(
                get_chart_preview_module.PreviewFormatGenerator,
                "generate",
                side_effect=DetachedInstanceError(),
            ),
            patch(
                "superset.mcp_service.utils.url_utils.get_superset_base_url",
                return_value="http://localhost",
            ),
        ):
            from fastmcp import Client

            from superset.mcp_service.app import mcp
            from superset.mcp_service.chart.schemas import GetChartPreviewRequest
            from superset.utils import json

            with patch("superset.mcp_service.auth.get_user_from_request") as mu:
                mu.return_value = MagicMock(id=1, username="admin")
                with patch(
                    "superset.mcp_service.auth.check_tool_permission", return_value=True
                ):
                    async with Client(mcp) as client:
                        response = await client.call_tool(
                            "get_chart_preview",
                            {
                                "request": GetChartPreviewRequest(
                                    identifier=7, format="ascii"
                                ).model_dump()
                            },
                        )

        data = json.loads(response.content[0].text)
        assert data["error_type"] == "InternalError"
        assert "session" in data["error"].lower() or "retry" in data["error"].lower()

10
tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
Normal file → Executable file
@@ -1175,6 +1175,11 @@ class TestUpdateChartValidationGate:
    )
    @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
    @patch("superset.db.session")
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.validate_against_dataset",
        new=Mock(return_value=(True, None)),
    )
    @pytest.mark.asyncio
    async def test_preview_path_validation_failure_skips_cache(
        self,
@@ -1238,6 +1243,11 @@ class TestUpdateChartValidationGate:
    )
    @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
    @patch("superset.db.session")
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.validate_against_dataset",
        new=Mock(return_value=(True, None)),
    )
    @pytest.mark.asyncio
    async def test_persist_path_validation_failure_skips_db_write(
        self,

@@ -117,83 +117,6 @@ class TestGetCanonicalColumnName:
        assert result == "unknown_column"


class TestNormalizeXYConfig:
    """Test _normalize_xy_config static method."""

    def test_normalize_x_axis_column(
        self, mock_dataset_context: DatasetContext
    ) -> None:
        """Test that x-axis column name is normalized."""
        config_dict: Dict[str, Any] = {
            "chart_type": "xy",
            "x": {"name": "orderdate"},
            "y": [{"name": "Sales", "aggregate": "SUM"}],
            "kind": "line",
        }

        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)

        assert config_dict["x"]["name"] == "OrderDate"

    def test_normalize_y_axis_columns(
        self, mock_dataset_context: DatasetContext
    ) -> None:
        """Test that y-axis column names are normalized."""
        config_dict: Dict[str, Any] = {
            "chart_type": "xy",
            "x": {"name": "OrderDate"},
            "y": [
                {"name": "sales", "aggregate": "SUM"},
                {"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
            ],
            "kind": "bar",
        }

        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)

        assert config_dict["y"][0]["name"] == "Sales"
        assert config_dict["y"][1]["name"] == "quantity_ordered"

    def test_normalize_group_by_column(
        self, mock_dataset_context: DatasetContext
    ) -> None:
        """Test that group_by column name is normalized."""
        config_dict: Dict[str, Any] = {
            "chart_type": "xy",
            "x": {"name": "OrderDate"},
            "y": [{"name": "Sales", "aggregate": "SUM"}],
            "kind": "line",
            "group_by": [{"name": "productline"}],
        }

        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)

        assert config_dict["group_by"][0]["name"] == "ProductLine"

class TestNormalizeTableConfig:
    """Test _normalize_table_config static method."""

    def test_normalize_table_columns(
        self, mock_dataset_context: DatasetContext
    ) -> None:
        """Test that table column names are normalized."""
        config_dict: Dict[str, Any] = {
            "chart_type": "table",
            "columns": [
                {"name": "orderdate"},
                {"name": "PRODUCTLINE"},
                {"name": "sales", "aggregate": "SUM"},
            ],
        }

        DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)

        assert config_dict["columns"][0]["name"] == "OrderDate"
        assert config_dict["columns"][1]["name"] == "ProductLine"
        assert config_dict["columns"][2]["name"] == "Sales"


class TestNormalizeFilters:
    """Test _normalize_filters static method."""

@@ -58,12 +58,12 @@ class TestRuntimeValidatorNonBlocking:
            x_axis=AxisConfig(format="$,.2f"),  # Currency format for date - mismatch
        )

        # Mock the format validator to return warnings
        # Mock the plugin runtime dispatcher to return format warnings
        with patch(
            "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
            "_validate_format_compatibility"
        ) as mock_format:
            mock_format.return_value = [
            "_validate_plugin_runtime"
        ) as mock_plugin:
            mock_plugin.return_value = [
                "Currency format '$,.2f' may not display dates correctly"
            ]

@@ -87,15 +87,14 @@ class TestRuntimeValidatorNonBlocking:
            kind="bar",
        )

        # Mock the cardinality validator to return warnings
        # Mock the plugin runtime dispatcher to return cardinality warnings
        with patch(
            "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
            "_validate_cardinality"
        ) as mock_cardinality:
            mock_cardinality.return_value = (
                ["High cardinality detected: 10000+ unique values"],
                ["Consider using aggregation or filtering"],
            )
            "_validate_plugin_runtime"
        ) as mock_plugin:
            mock_plugin.return_value = [
                "High cardinality detected: 10000+ unique values"
            ]

            is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
                config, 1
@@ -148,26 +147,21 @@ class TestRuntimeValidatorNonBlocking:
            x_axis=AxisConfig(format="smart_date"),  # Wrong format for user_id
        )

        # Mock all validators to return warnings
        # Mock plugin runtime and chart type validators to return warnings
        with (
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_format_compatibility"
            ) as mock_format,
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_cardinality"
            ) as mock_cardinality,
                "_validate_plugin_runtime"
            ) as mock_plugin,
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_chart_type"
            ) as mock_type,
        ):
            mock_format.return_value = ["Format mismatch warning"]
            mock_cardinality.return_value = (
                ["High cardinality warning"],
                ["Cardinality suggestion"],
            )
            mock_plugin.return_value = [
                "Format mismatch warning",
                "High cardinality warning",
            ]
            mock_type.return_value = (
                ["Chart type warning"],
                ["Chart type suggestion"],
@@ -197,13 +191,13 @@ class TestRuntimeValidatorNonBlocking:
        with (
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_format_compatibility"
            ) as mock_format,
                "_validate_plugin_runtime"
            ) as mock_plugin,
            patch(
                "superset.mcp_service.chart.validation.runtime.logger"
            ) as mock_logger,
        ):
            mock_format.return_value = ["Test warning message"]
            mock_plugin.return_value = ["Test warning message"]

            is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
                config, 1
@@ -217,7 +211,7 @@ class TestRuntimeValidatorNonBlocking:
        assert "warnings" in warnings_metadata

    def test_validate_table_chart_skips_xy_validations(self):
        """Test that table charts skip XY-specific validations."""
        """Test that table charts produce no XY-specific runtime warnings."""
        config = TableChartConfig(
            chart_type="table",
            columns=[
@@ -226,28 +220,15 @@ class TestRuntimeValidatorNonBlocking:
            ],
        )

        # These should not be called for table charts
        with (
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_format_compatibility"
            ) as mock_format,
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_cardinality"
            ) as mock_cardinality,
            patch(
                "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
                "_validate_chart_type"
            ) as mock_chart_type,
        ):
            # Mock chart type validator to return no warnings
        # Plugin runtime dispatches to TableChartPlugin which returns no warnings.
        # Chart type suggester is also stubbed to return no warnings.
        with patch(
            "superset.mcp_service.chart.validation.runtime.RuntimeValidator."
            "_validate_chart_type"
        ) as mock_chart_type:
            mock_chart_type.return_value = ([], [])

            is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1)

            # Format and cardinality validation should not be called for table charts
            mock_format.assert_not_called()
            mock_cardinality.assert_not_called()
            assert is_valid is True
            assert error is None

@@ -146,7 +146,13 @@ class TestResponseSizeGuardMiddleware:

    @pytest.mark.asyncio
    async def test_logs_warning_at_threshold(self) -> None:
        """Should log warning when approaching limit."""
        """Should log warning when approaching limit.

        Mocks the token estimator to return a specific value above the
        warn threshold but below the hard limit, decoupling the test
        from whichever tokenizer (tiktoken or char heuristic) happens
        to be loaded.
        """
        middleware = ResponseSizeGuardMiddleware(
            token_limit=1000, warn_threshold_pct=80
        )
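        # Worked example of the warn-threshold arithmetic implied by the
        # constructor arguments above (an assumption about the middleware's
        # behaviour, not lifted from its source):
        #
        #     token_limit = 1000
        #     warn_at = token_limit * 80 // 100   # 800 tokens
        #     estimate = 850                      # value the test mocks below
        #     warn_at <= estimate <= token_limit  # True: log a warning, do not block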
@@ -155,18 +161,21 @@ class TestResponseSizeGuardMiddleware:
        context.message.name = "list_charts"
        context.message.params = {}

        # Response at ~85% of limit (should trigger warning but not block)
        response = {"data": "x" * 2900}  # ~828 tokens at 3.5 chars/token
        response = {"data": "approaching the limit"}
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch(
                "superset.mcp_service.middleware.estimate_response_tokens",
                return_value=850,
            ),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
        ):
            result = await middleware.on_call_tool(context, call_next)

            # Should return response (not blocked)
            # Should return response (not blocked at 85% of limit)
            assert result == response
            # Should log warning
            mock_logger.warning.assert_called()

@@ -20,9 +20,11 @@ Unit tests for MCP service token utilities.
"""

from typing import Any, List
from unittest.mock import patch

from pydantic import BaseModel

from superset.mcp_service.utils import token_utils
from superset.mcp_service.utils.token_utils import (
    _replace_collections_with_summaries,
    _summarize_large_dicts,
@@ -45,29 +47,65 @@ class TestEstimateTokenCount:
    """Test estimate_token_count function."""

    def test_estimate_string(self) -> None:
        """Should estimate tokens for a string."""
        """Should produce a positive non-zero estimate for a normal string.

        We don't assert on a specific number because the result depends on
        which tokenizer is loaded (tiktoken when available, char heuristic
        otherwise).
        """
        text = "Hello world"
        result = estimate_token_count(text)
        expected = int(len(text) / CHARS_PER_TOKEN)
        assert result == expected
        assert result > 0

    def test_estimate_bytes(self) -> None:
        """Should estimate tokens for bytes."""
        text = b"Hello world"
        result = estimate_token_count(text)
        expected = int(len(text) / CHARS_PER_TOKEN)
        assert result == expected
        """Bytes input should be decoded and produce the same count as the
        equivalent string."""
        text = "Hello world"
        assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text)

    def test_empty_string(self) -> None:
        """Should return 0 for empty string."""
        """Should return 0 for empty string and empty bytes."""
        assert estimate_token_count("") == 0
        assert estimate_token_count(b"") == 0

    def test_json_like_content(self) -> None:
        """Should estimate tokens for JSON-like content."""
        """JSON content should produce a positive estimate."""
        json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
        result = estimate_token_count(json_str)
        assert result > 0
        assert result == int(len(json_str) / CHARS_PER_TOKEN)
        assert estimate_token_count(json_str) > 0

    def test_long_text_roughly_scales_with_length(self) -> None:
        """A doubled string should produce roughly double the token count
        (within ±10%)."""
        small = "the quick brown fox jumps over the lazy dog. " * 20
        large = small * 2
        small_n = estimate_token_count(small)
        large_n = estimate_token_count(large)
        # Within 10% of 2x — both tokenizers (tiktoken and the char
        # fallback) preserve length monotonicity.
        assert 1.8 * small_n <= large_n <= 2.2 * small_n

    def test_fallback_uses_chars_per_token_when_tiktoken_unavailable(
        self,
    ) -> None:
        """When the tiktoken encoding is None (not installed), the
        function falls back to len/CHARS_PER_TOKEN math."""
        text = "x" * 100
        with patch.object(token_utils, "_ENCODING", None):
            result = estimate_token_count(text)
        assert result == int(100 / CHARS_PER_TOKEN)
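    # Rough sketch of the estimator shape these fallback tests assume, kept in
    # comments so it is not mistaken for diff content. Only token_utils._ENCODING
    # and CHARS_PER_TOKEN are taken from the tests above; the body below is an
    # illustration, not the actual superset.mcp_service.utils.token_utils code.
    #
    #     def estimate_token_count(payload: str | bytes) -> int:
    #         text = payload.decode("utf-8", errors="ignore") if isinstance(payload, bytes) else payload
    #         if not text:
    #             return 0
    #         if _ENCODING is not None:
    #             try:
    #                 return len(_ENCODING.encode(text))  # tiktoken path
    #             except Exception:
    #                 pass  # never let the size guard crash on a bad encoding
    #         return int(len(text) / CHARS_PER_TOKEN)  # char heuristic fallback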

    def test_fallback_when_tiktoken_encode_raises(self) -> None:
        """A misbehaving encoding should fall back to the char heuristic
        rather than raise — the size guard must never fail-open."""

        class BoomEncoding:
            def encode(self, text: str) -> list[int]:
                raise ValueError("simulated tiktoken failure")

        text = "abc" * 50
        with patch.object(token_utils, "_ENCODING", BoomEncoding()):
            result = estimate_token_count(text)
        assert result == int(len(text) / CHARS_PER_TOKEN)


class TestEstimateResponseTokens:

@@ -33,7 +33,8 @@ def test_opensearch_dialect_registered() -> None:

def test_double_quotes_as_identifiers() -> None:
    """
    Test that double quotes are treated as identifiers, not string literals.
    Test that double quotes are treated as identifiers, not string literals,
    and normalized to backticks in output.
    """
    sql = 'SELECT "AvgTicketPrice" FROM "flights"'
    ast = sqlglot.parse_one(sql, OpenSearch)
@@ -42,8 +43,8 @@ def test_double_quotes_as_identifiers() -> None:
        OpenSearch().generate(expression=ast, pretty=True)
        == """
SELECT
  "AvgTicketPrice"
FROM "flights"
  `AvgTicketPrice`
FROM `flights`
""".strip()
    )

@@ -69,8 +70,7 @@ WHERE

def test_backticks_as_identifiers() -> None:
    """
    Test that backticks work as identifiers (MySQL-style).
    Backticks are normalized to double quotes in output.
    Test that backticks are accepted as identifiers and preserved on output.
    """
    sql = "SELECT `AvgTicketPrice` FROM `flights`"
    ast = sqlglot.parse_one(sql, OpenSearch)
@@ -79,15 +79,16 @@ def test_backticks_as_identifiers() -> None:
        OpenSearch().generate(expression=ast, pretty=True)
        == """
SELECT
  "AvgTicketPrice"
FROM "flights"
  `AvgTicketPrice`
FROM `flights`
""".strip()
    )
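# Taken together, the two tests above pin down the normalization contract:
# however an identifier is quoted on input, generation emits backticks. A
# minimal round-trip sketch using only the calls already exercised in this
# file (the exact single-line pretty=False output is an assumption inferred
# from the alias test further down, not a verified sqlglot result):
#
#     for src in (
#         'SELECT "AvgTicketPrice" FROM "flights"',
#         "SELECT `AvgTicketPrice` FROM `flights`",
#     ):
#         ast = sqlglot.parse_one(src, OpenSearch)
#         assert (
#             OpenSearch().generate(expression=ast, pretty=False)
#             == "SELECT `AvgTicketPrice` FROM `flights`"
#         )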


def test_mixed_identifier_quotes() -> None:
    """
    Test mixing double quotes and backticks for identifiers.
    Test that mixed double quotes and backticks for identifiers are all normalized
    to backticks on output.
    """
    sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`'
    ast = sqlglot.parse_one(sql, OpenSearch)
@@ -96,12 +97,26 @@ def test_mixed_identifier_quotes() -> None:
        OpenSearch().generate(expression=ast, pretty=True)
        == """
SELECT
  "AvgTicketPrice" AS "AvgTicketPrice"
FROM "default"."flights"
  `AvgTicketPrice` AS `AvgTicketPrice`
FROM `default`.`flights`
""".strip()
    )


def test_alias_with_space() -> None:
    """
    Test that an alias containing a space (e.g. a metric key like ``my test``)
    is preserved as a backtick-quoted identifier through the round-trip.
    """
    sql = 'SELECT COUNT(*) AS "my test" FROM `flights`'
    ast = sqlglot.parse_one(sql, OpenSearch)

    assert (
        OpenSearch().generate(expression=ast, pretty=False)
        == "SELECT COUNT(*) AS `my test` FROM `flights`"
    )


@pytest.mark.parametrize(
    "sql, expected",
    [
@@ -110,20 +125,20 @@ FROM "default"."flights"
            """
SELECT
  COUNT(*)
FROM "flights"
FROM `flights`
WHERE
  "Cancelled" = TRUE
  `Cancelled` = TRUE
""".strip(),
        ),
        (
            'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"',
            """
SELECT
  "Carrier",
  SUM("AvgTicketPrice")
FROM "flights"
  `Carrier`,
  SUM(`AvgTicketPrice`)
FROM `flights`
GROUP BY
  "Carrier"
  `Carrier`
""".strip(),
        ),
        (
@@ -131,9 +146,9 @@ GROUP BY
            """
SELECT
  *
FROM "flights"
FROM `flights`
WHERE
  "DestCountry" IN ('US', 'CA', 'MX')
  `DestCountry` IN ('US', 'CA', 'MX')
""".strip(),
        ),
    ],
@@ -165,13 +180,13 @@ GROUP BY "Carrier"
        OpenSearch().generate(expression=ast, pretty=True)
        == """
SELECT
  "Carrier",
  `Carrier`,
  COUNT(*),
  AVG("AvgTicketPrice"),
  MAX("FlightDelayMin")
FROM "flights"
  AVG(`AvgTicketPrice`),
  MAX(`FlightDelayMin`)
FROM `flights`
GROUP BY
  "Carrier"
  `Carrier`
""".strip()
    )

@@ -190,10 +205,10 @@ SELECT
  *
FROM (
  SELECT
    "Carrier",
    "AvgTicketPrice"
  FROM "flights"
) AS "sub"
    `Carrier`,
    `AvgTicketPrice`
  FROM `flights`
) AS `sub`
""".strip()
    )

@@ -212,12 +227,12 @@ def test_order_by_with_quoted_identifiers() -> None:
        OpenSearch().generate(expression=ast, pretty=True)
        == """
SELECT
  "Carrier",
  "AvgTicketPrice"
FROM "flights"
  `Carrier`,
  `AvgTicketPrice`
FROM `flights`
ORDER BY
  "AvgTicketPrice" DESC,
  "Carrier" ASC
  `AvgTicketPrice` DESC,
  `Carrier` ASC
""".strip()
    )

@@ -234,7 +249,7 @@ def test_limit_clause() -> None:
        == """
SELECT
  *
FROM "flights"
FROM `flights`
LIMIT 10
""".strip()
    )