Compare commits

...

22 Commits

Author SHA1 Message Date
Amin Ghadersohi
f5ba09f7af feat(mcp): add runtime chart plugin enable/disable via _PluginFilterConfig
Introduces a dynamic filter layer in the chart type registry so operators can
disable individual plugins (e.g. `handlebars`) without a code deploy:

- `MCP_DISABLED_CHART_PLUGINS: frozenset[str]` — static deny-list in mcp_config.py
- `MCP_CHART_PLUGIN_ENABLED_FUNC: Callable[[str], bool] | None` — dynamic hook
  for Harness/Split/per-user targeting; takes precedence over the deny-list
- Both keys are propagated through `get_mcp_config()` defaults

registry.py changes:
- `_PluginFilterConfig` frozen dataclass replaces two bare globals so
  configure() replaces them atomically (no torn reads under concurrency)
- `configure(disabled, enabled_func)` — called at app init; accepts any
  iterable for `disabled`; validates `enabled_func` is callable
- `_is_plugin_enabled()` — reads config once, fails closed on callable exception
- `get()` / `all_types()` / `is_enabled()` apply the filter at lookup time;
  `is_registered()` and `display_name_for_viz_type()` intentionally bypass it
  so callers can distinguish "unknown" vs "disabled" and existing charts still
  resolve display names for disabled viz types

schema_validator.py: two-step pre-check — `is_registered()` for unknown types,
`is_enabled()` for disabled ones, with distinct `DISABLED_CHART_TYPE` error code.

Wiring:
- `SupersetAppInitializer.configure_mcp_chart_registry()` called after
  `configure_feature_flags()` in `init_app()`
- `flask_singleton.py` re-calls `registry.configure()` after the MCP config
  overlay so MCP-specific overrides in `superset_config.py` take effect in
  standalone MCP mode

Tests: 28 cases in test_registry_filters.py covering deny-list, callable hook,
fail-closed on exception, all_types() filtering, display_name bypass, atomic
reconfigure, and configure() with list/tuple/frozenset inputs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 00:59:54 +00:00
Amin Ghadersohi
9d298f86f0 fix(mcp): fix E501 in update_chart.py and update_chart test mocks for column validation
Split an 89-char comment line and an over-limit condition in update_chart.py
to satisfy the ruff E501 rule. Also applied ruff format.

Two TestUpdateChartValidationGate tests expected CHART_VALIDATION_FAILED but
received CHART_DATASET_NOT_FOUND. The cause: _validate_update_against_dataset
calls DatasetValidator.validate_against_dataset before validate_and_compile,
and the existing mocks supplied a bare Mock() for chart.datasource, whose .id
attribute is an auto-generated MagicMock rather than a real int. Added a patch
for DatasetValidator.validate_against_dataset returning (True, None) so the
column-validation tier is bypassed and the tests reach the mocked
validate_and_compile response as intended.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-13 22:12:08 +00:00
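The mock pitfall this fix works around is generic `unittest.mock` behavior: attribute access on a `Mock` auto-creates child mocks, so `chart.datasource.id` is never a real int. A minimal illustration (the validator here is a stand-in, not the real Superset class):

```python
from unittest import mock


class DatasetValidator:
    """Stand-in for the real validator, for illustration only."""

    def validate_against_dataset(self, chart, config):
        # The real code looks up the dataset by chart.datasource.id; with a
        # bare Mock() chart that id is not an int, so the lookup fails first.
        if not isinstance(chart.datasource.id, int):
            return (False, "dataset not found")
        return (True, None)


chart = mock.Mock()
# Mock auto-creates attributes, so this is a Mock object, not an int:
assert not isinstance(chart.datasource.id, int)
assert DatasetValidator().validate_against_dataset(chart, {})[0] is False

# The commit's fix: stub the earlier tier to report success, so the test
# falls through to the later (mocked) validate_and_compile layer.
with mock.patch.object(
    DatasetValidator, "validate_against_dataset", return_value=(True, None)
):
    assert DatasetValidator().validate_against_dataset(chart, {}) == (True, None)
```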
Amin Ghadersohi
b34d346e0d refactor(mcp): address Codex review — fix registry bug, DRY schema hints, remove column regex
P1.1 registry.py: move _plugins_loaded=True to after successful import so a
failed load doesn't permanently poison the registry.

P1.3 schemas.py: remove overly restrictive ColumnRef.name / FilterClause.column
/ BigNumberChartConfig.temporal_column regex that blocked valid column names
containing parentheses, slashes, and other SQL-common characters.

P2.3 (DRY): eliminate the _CHART_TYPE_ERROR_HINTS second registry in
schema_validator.py by adding schema_error_hint() to the ChartTypePlugin protocol,
the BaseChartPlugin default, and all 7 plugin classes. SchemaValidator now
delegates to the plugin registry instead of maintaining a parallel dict.

P3.3 test_registry.py: add full registry unit-test coverage (register, get,
all_types, is_registered, display_name_for_viz_type, proxy methods, duplicate
warning, empty chart_type validation, insertion-order guarantee).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-13 21:27:39 +00:00
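The P2.3 pattern (a protocol method with a base-class default, letting the validator delegate instead of keeping a parallel dict) can be sketched like this; the plugin classes and hint text are hypothetical:

```python
from typing import Dict, Optional, Protocol


class ChartTypePlugin(Protocol):
    chart_type: str

    def schema_error_hint(self) -> Optional[str]: ...


class BaseChartPlugin:
    chart_type = ""

    def schema_error_hint(self) -> Optional[str]:
        return None  # default: no extra hint for this chart type


class HandlebarsChartPlugin(BaseChartPlugin):
    chart_type = "handlebars"

    def schema_error_hint(self) -> Optional[str]:
        return "handlebars charts require a 'template' field"


_registry: Dict[str, BaseChartPlugin] = {
    p.chart_type: p for p in [HandlebarsChartPlugin()]
}


def hint_for(chart_type: str) -> Optional[str]:
    # SchemaValidator delegates here, so hints live next to the plugin
    # that owns them and cannot drift out of sync with the registry.
    plugin = _registry.get(chart_type)
    return plugin.schema_error_hint() if plugin else None
```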
Amin Ghadersohi
18a9eff641 fix(mcp): add full column validation to update_chart
update_chart was only running SchemaValidator + Tier-2 compile check,
silently skipping DatasetValidator's column-existence + fuzzy-match
and column-name normalisation layers that generate_chart runs.

A typo like {name: "reveneu"} would save the broken chart and only
surface as a render-time failure in the browser.

Now matches generate_chart pipeline:
- Layer 2: DatasetValidator.validate_against_dataset() — column
  existence check with fuzzy-match "did you mean?" suggestions returned
  to the LLM before any DB write occurs
- Layer 4: DatasetValidator.normalize_column_names() — case
  normalisation so "order_date" resolves to "OrderDate" if that is the
  canonical dataset name

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-13 20:14:32 +00:00
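A fuzzy-match "did you mean?" check of the kind Layer 2 performs can be sketched with the standard library. This is a rough sketch: the real DatasetValidator's matching and normalisation rules may differ (here we ignore case and underscores so "order_date" matches "OrderDate", per the commit's example):

```python
import difflib


def _key(name: str) -> str:
    # Assumed normalisation rule for this sketch: case- and
    # underscore-insensitive matching against canonical column names.
    return name.lower().replace("_", "")


def validate_columns(requested, dataset_columns):
    """Layer 2 sketch: return (ok, errors), errors carrying suggestions."""
    canonical = {_key(c): c for c in dataset_columns}
    errors = []
    for name in requested:
        if _key(name) in canonical:
            continue  # Layer 4 will rewrite it to the canonical spelling
        suggestion = difflib.get_close_matches(name, dataset_columns, n=1)
        hint = f'; did you mean "{suggestion[0]}"?' if suggestion else ""
        errors.append(f'Unknown column "{name}"{hint}')
    return (not errors, errors)


def normalize_column_names(requested, dataset_columns):
    """Layer 4 sketch: map matches onto the canonical dataset name."""
    canonical = {_key(c): c for c in dataset_columns}
    return [canonical.get(_key(name), name) for name in requested]
```

The key property is ordering: the suggestion is surfaced to the LLM before any DB write, so a typo like `reveneu` becomes a correctable error instead of a saved-but-broken chart.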
Amin Ghadersohi
eface3bf54 fix(mcp): add threading lock to registry plugin loader
_ensure_plugins_loaded() used an unprotected boolean flag, making it
unsafe under concurrent first-call scenarios (e.g. gunicorn multi-thread
workers). Double-checked locking with threading.Lock eliminates the race.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-13 17:40:10 +00:00
Amin Ghadersohi
b41a53bc8f fix(mcp): resolve E402 and E501 in dataset_validator.py
- Move error_schemas import above _C TypeVar definition (E402)
- Split two over-length comment lines to ≤88 chars (E501, lines 268 and 380)
2026-05-10 00:08:43 +00:00
Amin Ghadersohi
f1b95d6ae3 fix(mcp): resolve ruff E501 and formatting issues to pass pre-commit
- Split long string literal in schema_validator.py line 202 (E501, 94 > 88 chars)
- Apply ruff format auto-fixes to big_number.py, handlebars.py, and test_get_chart_data.py
2026-05-09 00:11:52 +00:00
Amin Ghadersohi
1e2b541600 refactor(mcp): move all local imports to top level in chart type plugins
All per-method local imports in the 7 chart plugins were moved to module-level.
None of them create circular imports: schemas.py, chart_utils.py, and
dataset_validator.py are safe to import at plugin load time because those
modules guard their own registry lookups with local imports.

- big_number: add map_big_number_config, _big_number_chart_what,
  _summarize_filters, DatasetValidator to top-level imports
- pie: add map_pie_config, _pie_chart_what, _summarize_filters, PieChartConfig,
  DatasetValidator to top-level imports
- xy: add map_xy_config, _xy_chart_what/context, XYChartConfig, DatasetValidator,
  FormatTypeValidator, CardinalityValidator to top-level imports
- table: add map_table_config, _table_chart_what, _summarize_filters,
  TableChartConfig, DatasetValidator to top-level imports
- pivot_table: add map_pivot_table_config, _pivot_table_what, _summarize_filters,
  PivotTableChartConfig, DatasetValidator to top-level imports
- mixed_timeseries: add map_mixed_timeseries_config, _mixed_timeseries_what,
  _summarize_filters, MixedTimeseriesChartConfig, DatasetValidator to top-level
- handlebars: add map_handlebars_config, _handlebars_chart_what, _summarize_filters,
  HandlebarsChartConfig, DatasetValidator to top-level imports

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-07 22:19:03 +00:00
Amin Ghadersohi
b16de3622f fix(mcp): address reviewer comments — local import rationale, x-optional corrections, cardinality suggestions
- Remove redundant local imports from BigNumberChartPlugin.post_map_validate()
  now that BigNumberChartConfig and is_column_truly_temporal are at top level
- Add explanatory comments on the two remaining local get_registry imports in
  chart_utils.py and dataset_validator.py (circular import prevention)
- Fix schema_validator.py and generate_chart.py docstring: XY 'x' field is
  optional (defaults to dataset primary datetime column), not required
- Propagate cardinality suggestions alongside warnings in XYChartPlugin
- Clarify app.py instructions: chart_type_display_name is null for viz_types
  outside the 7 generate_chart-supported types

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-07 22:19:03 +00:00
Amin Ghadersohi
5e02d0ec65 refactor(mcp): complete plugin protocol — registry bootstrap, mypy fixes, test repairs
On top of the dead-code elimination in the previous commit:
- Add lazy _ensure_plugins_loaded() bootstrap to ChartTypeRegistry so the
  registry is populated even without importing app.py (fixes isolated test runs)
- Delegate _RegistryProxy methods to module-level functions so bootstrap runs
- Guard register() against empty chart_type strings
- Add generate_name + resolve_viz_type to ChartTypePlugin Protocol and
  BaseChartPlugin; delegate generate_chart_name/_resolve_viz_type in
  chart_utils to the plugin registry
- Add _with_context static helper to BaseChartPlugin (shared by all plugins)
- Fix stale 'five methods' → 'eight methods' docstring in plugin.py
- Add TypeVar _C to normalize_column_names so mypy infers correct return type
- Fix broken tests: update _pre_validate_big_number_config → _pre_validate_chart_type,
  remove deleted TestNormalizeXYConfig/TestNormalizeTableConfig classes,
  update runtime validator tests for removed _validate_format_compatibility /
  _validate_cardinality methods, add x is not None narrowing guards

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-07 22:19:03 +00:00
Amin Ghadersohi
139eea92f6 refactor(mcp): eliminate dead code and complete plugin registry dispatch
H1: Delete 7 dead _pre_validate_* static methods from SchemaValidator
    — exact duplicates of plugin pre_validate() methods, never called
    after _pre_validate_chart_type() was updated to delegate to plugin.

H2: Inline DatasetValidator._normalize_xy_config/_normalize_table_config
    into XYChartPlugin/TableChartPlugin.normalize_column_refs() and delete
    both DatasetValidator helper methods. The 5 other plugins already
    called _get_canonical_column_name directly; XY and Table now match.

H3: Add generate_name()/resolve_viz_type() to ChartTypePlugin protocol
    and BaseChartPlugin, implement in all 7 plugins. Replace the 7-arm
    isinstance chain in generate_chart_name() and the 7-arm elif chain
    in _resolve_viz_type() with single-line registry dispatch.

H4: Add a sync comment above _CHART_TYPE_ERROR_HINTS to document that
    it must stay in sync with the plugin registry.

M4: Move logger=logging.getLogger(__name__) from inside
    XYChartPlugin.get_runtime_warnings() to module scope.
2026-05-07 22:19:03 +00:00
Amin Ghadersohi
b09cbc80aa feat(mcp): add display_name and native_viz_types to chart type plugins
Each ChartTypePlugin now declares:
- display_name: human-readable label for the chart_type discriminator
  (e.g. "Line / Bar / Area / Scatter Chart", "Pivot Table")
- native_viz_types: dict mapping every Superset-internal viz_type the
  plugin produces to a user-friendly name
  (e.g. {"echarts_timeseries_line": "Line Chart", "echarts_area": "Area Chart"})

The registry gains display_name_for_viz_type(viz_type) which searches
all plugins' native_viz_types maps, replacing the need for a separate
viz_type_display_names.json or viz_type_names.py module.

ChartInfo gains a chart_type_display_name field populated via the registry,
so list_charts / get_chart_info return human-readable chart type names.
The MCP system instructions now reference display names rather than
internal viz_type identifiers.
2026-05-07 22:19:03 +00:00
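The `display_name_for_viz_type` lookup described above amounts to a search over each plugin's `native_viz_types` map; a sketch with illustrative plugin data:

```python
from typing import Dict, Optional


class XYChartPlugin:
    display_name = "Line / Bar / Area / Scatter Chart"
    native_viz_types: Dict[str, str] = {
        "echarts_timeseries_line": "Line Chart",
        "echarts_area": "Area Chart",
    }


class PieChartPlugin:
    display_name = "Pie Chart"
    native_viz_types = {"pie": "Pie Chart"}


_plugins = [XYChartPlugin(), PieChartPlugin()]


def display_name_for_viz_type(viz_type: str) -> Optional[str]:
    # Search every plugin's native_viz_types map. Returning None for
    # viz_types no plugin produces is why chart_type_display_name is
    # null in list_charts output for unsupported chart types.
    for plugin in _plugins:
        name = plugin.native_viz_types.get(viz_type)
        if name is not None:
            return name
    return None
```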
Amin Ghadersohi
e7adf0c670 feat(mcp): introduce chart type plugin registry for extensible chart generation
Replaces four scattered dispatch locations (schema_validator, dataset_validator,
chart_utils, runtime validator) with a central ChartTypePlugin registry. Each of
the 7 supported chart types (xy, table, pie, pivot_table, mixed_timeseries,
handlebars, big_number) now owns its pre-validation, column extraction, form_data
mapping, post-map validation, column normalization, and runtime warnings in a
single plugin class.

Key changes:
- Add ChartTypePlugin protocol and BaseChartPlugin base class (plugin.py)
- Add ChartTypeRegistry with register/get/all_types helpers (registry.py)
- Add 7 chart type plugins under chart/plugins/ with full coverage
- Fix 5-type column validation gap: pie, pivot_table, mixed_timeseries, handlebars,
  and big_number now participate in dataset column validation (previously silently skipped)
- Move BigNumber trendline temporal check to BigNumberChartPlugin.post_map_validate()
- Add get_runtime_warnings() to plugin protocol; XYChartPlugin implements
  format/cardinality checks, removing isinstance(config, XYChartConfig) from RuntimeValidator
- Fix stale generate_chart.py docstring listing only 'xy' and 'table' chart types
- Add missing pie, pivot_table, mixed_timeseries handlers to _enhance_validation_error;
  refactor into a data-driven lookup table to stay within complexity limits
- Fix empty details fallback in Pydantic error handler
2026-05-07 22:19:03 +00:00
Vitor Avila
ad5e3170dd fix: OpenSearch dialect identifier delimiters (#39953) 2026-05-07 16:19:27 -03:00
Maxime Beauchemin
aa710672ed fix(ui): remove makeUrl() double-prefix bugs under subdirectory deployment (#39503)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com>
2026-05-07 15:39:38 -03:00
Richard Fogaca Nienkotter
8c80caefa3 fix(explore): preserve preview chart name on save (#39908) 2026-05-07 13:08:28 -03:00
Richard Fogaca Nienkotter
8088c5d1de fix(dashboard): match auto-refresh paused-dot outline to icon color (#39909) 2026-05-07 13:07:52 -03:00
Amin Ghadersohi
9b520312a1 fix(mcp): use tiktoken for response-size-guard token estimation (#39912) 2026-05-07 11:51:31 -04:00
Amin Ghadersohi
9ac4711ac8 fix(mcp): prevent DetachedInstanceError in get_chart_preview (#39921) 2026-05-07 11:44:11 -04:00
dependabot[bot]
7593d2a164 chore(deps): bump caniuse-lite from 1.0.30001791 to 1.0.30001792 in /docs (#39933)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 21:57:29 +07:00
dependabot[bot]
d3c44e311e chore(deps): bump aws-actions/amazon-ecr-login from 2.1.4 to 2.1.5 (#39931)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 21:54:59 +07:00
Enzo Martellucci
b5186d1c65 fix(reports): keep body sized so standalone screenshots don't time out (#39944) 2026-05-07 12:26:50 +02:00
54 changed files with 3116 additions and 933 deletions

View File

@@ -58,7 +58,7 @@ jobs:
- name: Login to Amazon ECR
if: steps.describe-services.outputs.active == 'true'
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Delete ECR image tag
if: steps.describe-services.outputs.active == 'true'

View File

@@ -199,7 +199,7 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Load, tag and push image to ECR
id: push-image
@@ -235,7 +235,7 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Check target image exists in ECR
id: check-image

View File

@@ -70,7 +70,7 @@
"@swc/core": "^1.15.33",
"antd": "^6.3.7",
"baseline-browser-mapping": "^2.10.27",
"caniuse-lite": "^1.0.30001791",
"caniuse-lite": "^1.0.30001792",
"docusaurus-plugin-openapi-docs": "^5.0.2",
"docusaurus-theme-openapi-docs": "^5.0.2",
"js-yaml": "^4.1.1",

View File

@@ -6035,10 +6035,10 @@ caniuse-api@^3.0.0:
lodash.memoize "^4.1.2"
lodash.uniq "^4.5.0"
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001791:
version "1.0.30001791"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001791.tgz#dfb93d85c40ad380c57123e72e10f3c575786b51"
integrity sha512-yk0l/YSrOnFZk3UROpDLQD9+kC1l4meK/wed583AXrzoarMGJcbRi2Q4RaUYbKxYAsZ8sWmaSa/DsLmdBeI1vQ==
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001792:
version "1.0.30001792"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001792.tgz#ca8bb9be244835a335e2018272ce7223691873c5"
integrity sha512-hVLMUZFgR4JJ6ACt1uEESvQN1/dBVqPAKY0hgrV70eN3391K6juAfTjKZLKvOMsx8PxA7gsY1/tLMMTcfFLLpw==
ccount@^2.0.0:
version "2.0.1"

View File

@@ -142,10 +142,16 @@ druid = ["pydruid>=0.6.5,<0.7"]
duckdb = ["duckdb>=1.4.2,<2", "duckdb-engine>=0.17.0"]
dynamodb = ["pydynamodb>=0.4.2"]
solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.13, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
fastmcp = ["fastmcp>=3.2.4,<4.0"]
fastmcp = [
"fastmcp>=3.2.4,<4.0",
# tiktoken backs the response-size-guard token estimator. Without
# it, the middleware falls back to a coarser character-based
# heuristic that under-counts JSON-heavy MCP responses.
"tiktoken>=0.7.0,<1.0",
]
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
gevent = ["gevent>=23.9.1"]
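The fallback behavior the pyproject comment describes can be sketched as below. This is illustrative, not the actual middleware; the encoding name `cl100k_base` is an assumption, and the catch-all except stands in for whatever import/availability check the real guard performs:

```python
def estimate_tokens(text: str) -> int:
    """Estimate token count for a response-size guard (illustrative)."""
    try:
        import tiktoken  # optional dependency

        enc = tiktoken.get_encoding("cl100k_base")  # assumed encoding name
        return len(enc.encode(text))
    except Exception:
        # Fallback heuristic: roughly 4 characters per token for English
        # prose. JSON-heavy MCP payloads tokenize denser than that, so
        # this path under-counts, which is why tiktoken was promoted to a
        # hard dependency of the fastmcp extra.
        return max(1, len(text) // 4)
```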

View File

@@ -183,7 +183,9 @@ idna==3.10
# trio
# url-normalize
isodate==0.7.2
# via apache-superset (pyproject.toml)
# via
# apache-superset (pyproject.toml)
# apache-superset-core
itsdangerous==2.2.0
# via
# flask
@@ -296,6 +298,7 @@ pyarrow==20.0.0
# via
# -r requirements/base.in
# apache-superset (pyproject.toml)
# apache-superset-core
pyasn1==0.6.3
# via
# pyasn1-modules

View File

@@ -442,6 +442,7 @@ isodate==0.7.2
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
isort==6.0.1
# via pylint
itsdangerous==2.2.0
@@ -715,6 +716,7 @@ pyarrow==20.0.0
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
# db-dtypes
# pandas-gbq
pyasn1==0.6.3
@@ -866,6 +868,8 @@ referencing==0.36.2
# jsonschema
# jsonschema-path
# jsonschema-specifications
regex==2026.4.4
# via tiktoken
requests==2.33.0
# via
# -c requirements/base-constraint.txt
@@ -878,6 +882,7 @@ requests==2.33.0
# requests-cache
# requests-oauthlib
# shillelagh
# tiktoken
# trino
requests-cache==1.2.1
# via
@@ -1003,6 +1008,8 @@ tabulate==0.9.0
# via
# -c requirements/base-constraint.txt
# apache-superset
tiktoken==0.12.0
# via apache-superset
tomli-w==1.2.0
# via apache-superset-extensions-cli
tomlkit==0.13.3

View File

@@ -17,7 +17,8 @@
* under the License.
*/
import { render, screen, act } from 'spec/helpers/testing-library';
import { StatusIndicatorDot } from './StatusIndicatorDot';
import { supersetTheme } from '@apache-superset/core/theme';
import { getStatusConfig, StatusIndicatorDot } from './StatusIndicatorDot';
import { AutoRefreshStatus } from '../../types/autoRefresh';
afterEach(() => {
@@ -62,6 +63,15 @@ test('renders with paused status', () => {
expect(dot).toHaveAttribute('data-status', AutoRefreshStatus.Paused);
});
test('uses the icon color for the paused status outline', () => {
expect(
getStatusConfig(supersetTheme, AutoRefreshStatus.Paused),
).toMatchObject({
needsBorder: true,
outlineColor: 'currentColor',
});
});
test('has correct accessibility attributes', () => {
render(<StatusIndicatorDot status={AutoRefreshStatus.Success} />);
const dot = screen.getByTestId('status-indicator-dot');

View File

@@ -39,9 +39,10 @@ export interface StatusIndicatorDotProps {
interface StatusConfig {
color: string;
needsBorder: boolean;
outlineColor?: string;
}
const getStatusConfig = (
export const getStatusConfig = (
theme: ReturnType<typeof useTheme>,
status: AutoRefreshStatus,
): StatusConfig => {
@@ -75,6 +76,7 @@ const getStatusConfig = (
return {
color: theme.colorBgContainer,
needsBorder: true,
outlineColor: 'currentColor',
};
default:
return {
@@ -136,13 +138,15 @@ export const StatusIndicatorDot: FC<StatusIndicatorDotProps> = ({
width: ${size}px;
height: ${size}px;
border-radius: 50%;
color: ${theme.colorTextSecondary};
background-color: ${statusConfig.color};
transition:
background-color ${theme.motionDurationMid} ease-in-out,
border-color ${theme.motionDurationMid} ease-in-out;
border: ${statusConfig.needsBorder
? `1px solid ${theme.colorBorder}`
: 'none'};
border: ${statusConfig.needsBorder ? '1px solid' : 'none'};
border-color: ${statusConfig.needsBorder
? statusConfig.outlineColor
: 'transparent'};
box-shadow: ${statusConfig.needsBorder
? 'none'
: `0 0 0 2px ${theme.colorBgContainer}`};

View File

@@ -21,6 +21,10 @@ import { VizType } from '@superset-ui/core';
import { hydrateExplore, HYDRATE_EXPLORE } from './hydrateExplore';
import { exploreInitialData } from '../fixtures';
afterEach(() => {
window.history.pushState({}, '', '/');
});
test('creates hydrate action from initial data', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({
@@ -168,6 +172,84 @@ test('creates hydrate action with existing state', () => {
);
});
test('hydrates sliceName from preview form data before saved slice name', () => {
window.history.pushState({}, '', '/explore/?form_data_key=preview-key');
const dispatch = jest.fn();
const getState = jest.fn(() => ({
user: {},
charts: {},
datasources: {},
common: {},
explore: {},
}));
const previewSliceName = 'RENAMED - Bug Evidence';
const savedSliceName = 'Most Populated Countries';
const previewInitialData = {
...exploreInitialData,
form_data: {
...exploreInitialData.form_data,
slice_name: previewSliceName,
},
slice: {
...exploreInitialData.slice!,
slice_name: savedSliceName,
},
};
// @ts-expect-error we only need the fields consumed by hydrateExplore
hydrateExplore(previewInitialData)(dispatch, getState);
expect(dispatch).toHaveBeenCalledWith(
expect.objectContaining({
type: HYDRATE_EXPLORE,
data: expect.objectContaining({
explore: expect.objectContaining({
sliceName: previewSliceName,
}),
}),
}),
);
});
test('hydrates sliceName from saved slice when regular form data has stale name', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({
user: {},
charts: {},
datasources: {},
common: {},
explore: {},
}));
const staleFormDataSliceName = 'Stale Params Name';
const savedSliceName = 'Current Saved Name';
const savedChartInitialData = {
...exploreInitialData,
form_data: {
...exploreInitialData.form_data,
slice_name: staleFormDataSliceName,
},
slice: {
...exploreInitialData.slice!,
slice_name: savedSliceName,
},
};
// @ts-expect-error we only need the fields consumed by hydrateExplore
hydrateExplore(savedChartInitialData)(dispatch, getState);
expect(dispatch).toHaveBeenCalledWith(
expect.objectContaining({
type: HYDRATE_EXPLORE,
data: expect.objectContaining({
explore: expect.objectContaining({
sliceName: savedSliceName,
}),
}),
}),
);
});
test('uses configured default time range if not set', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({

View File

@@ -77,6 +77,12 @@ export const hydrateExplore =
const fallbackSlice = sliceId ? sliceEntities?.slices?.[sliceId] : null;
const initialSlice = slice ?? fallbackSlice;
const initialFormData = form_data ?? initialSlice?.form_data;
const isCachedFormData = getUrlParam(URL_PARAMS.formDataKey) !== null;
const [primarySliceNameSource, fallbackSliceNameSource] = isCachedFormData
? [initialFormData, initialSlice]
: [initialSlice, initialFormData];
const initialSliceName =
primarySliceNameSource?.slice_name ?? fallbackSliceNameSource?.slice_name;
if (!initialFormData.viz_type) {
const defaultVizType = common?.conf.DEFAULT_VIZ_TYPE || VizType.Table;
initialFormData.viz_type =
@@ -183,6 +189,7 @@ export const hydrateExplore =
// because `bootstrapData.controls` is undefined.
controls: initialControls,
form_data: initialFormData,
sliceName: initialSliceName,
slice: initialSlice,
controlsTransferred: explore.controlsTransferred,
standalone: getUrlParam(URL_PARAMS.standalone),

View File

@@ -179,6 +179,33 @@ test('renders the right footer buttons', () => {
).toBeInTheDocument();
});
test('initializes chart name from current Explore slice name', () => {
const previewSliceName = 'RENAMED - Bug Evidence';
const savedSliceName = 'Most Populated Countries';
const { getByTestId } = setup(
{
...defaultProps,
form_data: {
...defaultProps.form_data,
slice_name: previewSliceName,
},
sliceName: previewSliceName,
},
mockStore({
...initialState,
explore: {
...initialState.explore,
slice: {
...initialState.explore.slice,
slice_name: savedSliceName,
},
},
}),
);
expect(getByTestId('new-chart-name')).toHaveValue(previewSliceName);
});
test('does not render a message when overriding', () => {
const { getByRole, queryByRole } = setup();

View File

@@ -35,7 +35,6 @@ import { CheckboxChangeEvent } from '@superset-ui/core/components/Checkbox/types
import { useHistory } from 'react-router-dom';
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
import { makeUrl } from 'src/utils/pathUtils';
import Tabs from '@superset-ui/core/components/Tabs';
import {
Button,
@@ -1824,7 +1823,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
onClick={() => {
setLoading(true);
fetchAndSetDB();
redirectURL(makeUrl(`/sqllab?db=true`));
// redirectURL() delegates to history.push; React Router's basename
// already prefixes the application root, so pass a relative path.
redirectURL('/sqllab?db=true');
}}
>
{t('Query data in SQL Lab')}

View File

@@ -24,7 +24,6 @@ import { TableTab } from 'src/views/CRUD/types';
import { t } from '@apache-superset/core/translation';
import { styled } from '@apache-superset/core/theme';
import { navigateTo } from 'src/utils/navigationUtils';
import { makeUrl } from 'src/utils/pathUtils';
import { WelcomeTable } from './types';
const EmptyContainer = styled.div`
@@ -59,7 +58,9 @@ const REDIRECTS = {
create: {
[WelcomeTable.Charts]: '/chart/add',
[WelcomeTable.Dashboards]: '/dashboard/new',
[WelcomeTable.SavedQueries]: makeUrl('/sqllab?new=true'),
// navigateTo() applies the application root internally; keep this
// relative so the prefix isn't added twice.
[WelcomeTable.SavedQueries]: '/sqllab?new=true',
},
viewAll: {
[WelcomeTable.Charts]: '/chart/list',

View File

@@ -44,7 +44,7 @@ import {
TelemetryPixel,
} from '@superset-ui/core/components';
import type { ItemType, MenuItem } from '@superset-ui/core/components/Menu';
import { ensureAppRoot, makeUrl } from 'src/utils/pathUtils';
import { ensureAppRoot } from 'src/utils/pathUtils';
import { isEmbedded } from 'src/dashboard/util/isEmbedded';
import { findPermission } from 'src/utils/findPermission';
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
@@ -213,7 +213,10 @@ const RightMenu = ({
},
{
label: t('SQL query'),
url: makeUrl('/sqllab?new=true'),
// Keep the URL relative so isFrontendRoute() matches and Link navigates
// via React Router; the <Typography.Link> fallback applies ensureAppRoot
// exactly once for non-frontend routes.
url: '/sqllab?new=true',
icon: <Icons.SearchOutlined data-test={`menu-item-${t('SQL query')}`} />,
perm: 'can_sqllab',
view: 'Superset',

View File

@@ -25,11 +25,20 @@ import {
fireEvent,
waitFor,
} from 'spec/helpers/testing-library';
import { MemoryRouter } from 'react-router-dom';
import { MemoryRouter, useLocation } from 'react-router-dom';
import { QueryParamProvider } from 'use-query-params';
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
import * as getBootstrapData from 'src/utils/getBootstrapData';
import SavedQueryList from '.';
// Renders the current router pathname+search so tests can assert navigation.
function LocationDisplay() {
const location = useLocation();
return (
<div data-test="location-display">{`${location.pathname}${location.search}`}</div>
);
}
// Increase default timeout
jest.setTimeout(30000);
@@ -88,6 +97,7 @@ const renderList = (props = {}, storeOverrides = {}) =>
<MemoryRouter>
<QueryParamProvider adapter={ReactRouter5Adapter}>
<SavedQueryList user={mockUser} {...props} />
<LocationDisplay />
</QueryParamProvider>
</MemoryRouter>,
{
@@ -242,4 +252,39 @@ describe('SavedQueryList', () => {
// Verify delete buttons are not shown
expect(screen.queryByTestId('delete-action')).not.toBeInTheDocument();
});
test('"+ Query" button pushes a router-relative path (subdirectory deployment)', async () => {
// Simulate SUPERSET_APP_ROOT=/superset. ensureAppRoot/makeUrl read
// applicationRoot() dynamically, so mocking it here makes the buggy code
// path (makeUrl() around history.push) produce '/superset/sqllab?new=true'
// instead of being a no-op. React Router's <Router basename> prefixes the
// app root on its own, so history.push MUST receive a path without the
// app-root prefix — otherwise navigation lands at /superset/superset/sqllab
// and shows a blank page (sc-103661).
const applicationRootSpy = jest
.spyOn(getBootstrapData, 'applicationRoot')
.mockReturnValue('/superset');
try {
renderList();
await screen.findByTestId('saved_query-list-view');
const queryButton = await screen.findByRole('button', {
name: /query/i,
});
fireEvent.click(queryButton);
await waitFor(() => {
// The MemoryRouter in renderList uses the default ('/') basename, so
// useLocation reflects exactly what history.push received. A correct
// router-relative push produces '/sqllab?new=true'; a buggy push that
// re-applied the app root would produce '/superset/sqllab?new=true'.
const location = screen.getByTestId('location-display').textContent;
expect(location).toBe('/sqllab?new=true');
});
} finally {
applicationRootSpy.mockRestore();
}
});
});

View File

@@ -223,7 +223,9 @@ function SavedQueryList({
name: t('Query'),
buttonStyle: 'primary',
onClick: () => {
history.push(makeUrl('/sqllab?new=true'));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push('/sqllab?new=true');
},
});
@@ -245,7 +247,9 @@ function SavedQueryList({
if (openInNewWindow) {
window.open(makeUrl(`/sqllab?savedQueryId=${id}`));
} else {
history.push(makeUrl(`/sqllab?savedQueryId=${id}`));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push(`/sqllab?savedQueryId=${id}`);
}
};
@@ -338,9 +342,7 @@ function SavedQueryList({
row: {
original: { id, label },
},
}: any) => (
<Link to={makeUrl(`/sqllab?savedQueryId=${id}`)}>{label}</Link>
),
}: any) => <Link to={`/sqllab?savedQueryId=${id}`}>{label}</Link>,
id: 'label',
},
{

superset/initialization/__init__.py (17 lines changed) Normal file → Executable file
View File

@@ -747,6 +747,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
# Configuration of feature_flags must be done first to allow init features
# conditionally
self.configure_feature_flags()
self.configure_mcp_chart_registry()
self.configure_db_encrypt()
self.setup_db()
@@ -821,6 +822,22 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
def configure_feature_flags(self) -> None:
feature_flag_manager.init_app(self.superset_app)
def configure_mcp_chart_registry(self) -> None:
from superset.mcp_service.chart import registry
from superset.mcp_service.mcp_config import (
MCP_CHART_PLUGIN_ENABLED_FUNC,
MCP_DISABLED_CHART_PLUGINS,
)
registry.configure(
disabled=self.config.get(
"MCP_DISABLED_CHART_PLUGINS", MCP_DISABLED_CHART_PLUGINS
),
enabled_func=self.config.get(
"MCP_CHART_PLUGIN_ENABLED_FUNC", MCP_CHART_PLUGIN_ENABLED_FUNC
),
)
def configure_sqlglot_dialects(self) -> None:
extensions = self.config["SQLGLOT_DIALECTS_EXTENSIONS"]

View File

@@ -222,10 +222,12 @@ Time grain for temporal x-axis (time_grain parameter):
- PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly)
Chart Types in Existing Charts (viewable via list_charts/get_chart_info):
- pie, big_number, big_number_total, funnel, gauge_chart
- echarts_timeseries_line, echarts_timeseries_bar, echarts_timeseries_area
- pivot_table_v2, heatmap_v2, sankey_v2, sunburst_v2, treemap_v2
- word_cloud, world_map, box_plot, bubble, mixed_timeseries
Each chart returned by list_charts / get_chart_info includes a
chart_type_display_name field with a human-readable name when available.
This field is populated only for the 7 chart types supported by generate_chart
(xy, pie, table, pivot_table, big_number, mixed_timeseries, handlebars).
For all other viz_types (Funnel, Gauge, Heatmap, etc.) it will be null —
use the raw viz_type field instead when referring to those chart types.
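A consumer following this guidance might resolve a label like so (a small sketch; the chart dict shape is assumed from the two fields named above):

```python
def label_for(chart: dict) -> str:
    # Prefer the human-readable display name; fall back to the raw
    # viz_type when the field is null (any type outside the 7 supported).
    return chart.get("chart_type_display_name") or chart["viz_type"]

assert label_for({"viz_type": "pie", "chart_type_display_name": "Pie Chart"}) == "Pie Chart"
assert label_for({"viz_type": "funnel", "chart_type_display_name": None}) == "funnel"
```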
Query Examples:
- List all tables:
@@ -503,6 +505,7 @@ warnings.filterwarnings(
# NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
# They register automatically on import, similar to tools.
import superset.mcp_service.chart.plugins # noqa: F401, E402 — registers all chart type plugins
from superset.mcp_service.chart import ( # noqa: F401, E402
prompts as chart_prompts,
resources as chart_resources,


@@ -318,29 +318,35 @@ def map_config_to_form_data(
| BigNumberChartConfig,
dataset_id: int | str | None = None,
) -> Dict[str, Any]:
"""Map chart config to Superset form_data."""
if isinstance(config, TableChartConfig):
return map_table_config(config)
elif isinstance(config, XYChartConfig):
return map_xy_config(config, dataset_id=dataset_id)
elif isinstance(config, PieChartConfig):
return map_pie_config(config)
elif isinstance(config, PivotTableChartConfig):
return map_pivot_table_config(config)
elif isinstance(config, MixedTimeseriesChartConfig):
return map_mixed_timeseries_config(config, dataset_id=dataset_id)
elif isinstance(config, HandlebarsChartConfig):
return map_handlebars_config(config)
elif isinstance(config, BigNumberChartConfig):
if config.show_trendline and config.temporal_column:
if not is_column_truly_temporal(config.temporal_column, dataset_id):
raise ValueError(
f"Big Number trendline requires a temporal SQL column; "
f"'{config.temporal_column}' is not temporal."
)
return map_big_number_config(config)
else:
raise ValueError(f"Unsupported config type: {type(config)}")
"""Map chart config to Superset form_data via the plugin registry.
The previous if/elif chain across all 7 chart types has been replaced by a
single registry lookup. Cross-field constraints (e.g. BigNumber trendline
temporal check) are now owned by each plugin's post_map_validate() method
rather than being baked into this dispatcher.
"""
# Local import: plugins call map_*_config from their to_form_data() methods,
# so chart_utils is loaded before plugins finish registering. A top-level
# import of registry here would trigger plugin loading mid-import, creating
# a circular import.
from superset.mcp_service.chart.registry import get_registry
chart_type = getattr(config, "chart_type", None)
plugin = get_registry().get(chart_type) if chart_type else None
if plugin is None:
raise ValueError(
f"Unsupported config type: {type(config)} (chart_type={chart_type!r})"
)
form_data = plugin.to_form_data(config, dataset_id=dataset_id)
# Run post-map validation (e.g. BigNumber trendline temporal type check).
# Raise ValueError to preserve backward-compatible error handling in callers.
error = plugin.post_map_validate(config, form_data, dataset_id=dataset_id)
if error is not None:
raise ValueError(error.message)
return form_data
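The lookup-then-validate contract can be mirrored in miniature (local stand-ins only; the real plugins implement the ChartTypePlugin protocol and post_map_validate returns ChartGenerationError objects rather than strings):

```python
class ToyPiePlugin:
    """Stand-in for a real chart type plugin."""

    chart_type = "pie"

    def to_form_data(self, config, dataset_id=None):
        return {"viz_type": "pie", "metric": config["metric"]}

    def post_map_validate(self, config, form_data, dataset_id=None):
        # None means valid; a string stands in for ChartGenerationError.message.
        return None if form_data.get("metric") else "metric is required"

_registry = {"pie": ToyPiePlugin()}

def map_config_to_form_data(config, dataset_id=None):
    plugin = _registry.get(config.get("chart_type", ""))
    if plugin is None:
        raise ValueError(f"Unsupported chart_type={config.get('chart_type')!r}")
    form_data = plugin.to_form_data(config, dataset_id=dataset_id)
    error = plugin.post_map_validate(config, form_data, dataset_id=dataset_id)
    if error is not None:
        raise ValueError(error)  # preserves the old ValueError contract
    return form_data

assert map_config_to_form_data({"chart_type": "pie", "metric": "count"}) == {
    "viz_type": "pie",
    "metric": "count",
}
```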
def _add_adhoc_filters(
@@ -1129,87 +1135,32 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str:
def generate_chart_name(
config: TableChartConfig
| XYChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig
| HandlebarsChartConfig
| BigNumberChartConfig,
config: Any,
dataset_name: str | None = None,
) -> str:
"""Generate a descriptive chart name following a standard format.
Format conventions (by chart type):
Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
Time-series (line/area, no group_by): [Metric] Over Time
Table (no aggregates): [Dataset] Records
Table (with aggregates): [Metric] Summary
Pie: [Dimension] by [Metric]
Pivot Table: Pivot Table [Row1, Row2]
Mixed Timeseries: [Primary] + [Secondary]
An en-dash followed by context (filters / time grain) is appended
Delegates to each plugin's ``generate_name()`` method.
See each plugin's ``generate_name`` for chart-type-specific format conventions.
An en-dash followed by context (filters / time grain) is appended by the plugin
when such information is available.
"""
if isinstance(config, TableChartConfig):
what = _table_chart_what(config, dataset_name)
context = _summarize_filters(config.filters)
elif isinstance(config, XYChartConfig):
what = _xy_chart_what(config)
context = _xy_chart_context(config)
elif isinstance(config, PieChartConfig):
what = _pie_chart_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, PivotTableChartConfig):
what = _pivot_table_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, MixedTimeseriesChartConfig):
what = _mixed_timeseries_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, HandlebarsChartConfig):
what = _handlebars_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
elif isinstance(config, BigNumberChartConfig):
what = _big_number_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
else:
return "Chart"
from superset.mcp_service.chart.registry import get_registry
name = what
if context:
name = f"{what} \u2013 {context}"
return _truncate(name)
plugin = get_registry().get(getattr(config, "chart_type", ""))
if plugin is None:
return "Chart"
return _truncate(plugin.generate_name(config, dataset_name))
def _resolve_viz_type(config: Any) -> str:
"""Resolve the Superset viz_type from a chart config object."""
chart_type = getattr(config, "chart_type", "unknown")
if chart_type == "xy":
kind = getattr(config, "kind", "line")
viz_type_map = {
"line": "echarts_timeseries_line",
"bar": "echarts_timeseries_bar",
"area": "echarts_area",
"scatter": "echarts_timeseries_scatter",
}
return viz_type_map.get(kind, "echarts_timeseries_line")
elif chart_type == "table":
return getattr(config, "viz_type", "table")
elif chart_type == "pie":
return "pie"
elif chart_type == "pivot_table":
return "pivot_table_v2"
elif chart_type == "mixed_timeseries":
return "mixed_timeseries"
elif chart_type == "handlebars":
return "handlebars"
elif chart_type == "big_number":
show_trendline = getattr(config, "show_trendline", False)
temporal_column = getattr(config, "temporal_column", None)
return (
"big_number" if show_trendline and temporal_column else "big_number_total"
)
return "unknown"
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get(getattr(config, "chart_type", ""))
if plugin is None:
return "unknown"
return plugin.resolve_viz_type(config)
def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities:


@@ -0,0 +1,255 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ChartTypePlugin protocol and BaseChartPlugin base class.
Each chart type owns its pre-validation, column extraction, form_data mapping,
and post-map validation in a single plugin class. This eliminates the previous
pattern of 4 separate dispatch points (schema_validator.py, dataset_validator.py,
chart_utils.py, pipeline.py) that had to be updated in sync whenever a new chart
type was added.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
from superset.mcp_service.chart.schemas import ColumnRef
from superset.mcp_service.common.error_schemas import ChartGenerationError
@runtime_checkable
class ChartTypePlugin(Protocol):
"""
Protocol that every chart-type plugin must satisfy.
Implementing all nine methods in a single class guarantees that adding a
new chart type requires only one new file — the plugin — rather than edits
across multiple separate files.
"""
#: Discriminator value matching ChartConfig's chart_type field.
chart_type: str
#: Human-readable name shown to users (e.g. "Line / Bar / Area / Scatter").
display_name: str
#: Maps every Superset-internal viz_type this plugin can produce to a
#: user-facing display name, e.g. {"echarts_timeseries_line": "Line Chart"}.
#: Used by the registry to resolve display names for existing charts without
#: needing a separate JSON mapping file.
native_viz_types: dict[str, str]
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
"""
Early validation of the raw config dict before Pydantic parsing.
Called by SchemaValidator before attempting to parse the request.
Should check that required top-level keys are present and well-typed.
Returns None if valid, ChartGenerationError if invalid.
"""
...
def extract_column_refs(
self,
config: Any,
) -> list[ColumnRef]:
"""
Extract all column references from a parsed chart config.
Called by DatasetValidator to validate that all referenced columns exist
in the dataset. Must cover every field that holds a column name,
including filters.
Returns a list of ColumnRef objects (may be empty).
"""
...
def to_form_data(
self,
config: Any,
dataset_id: int | str | None = None,
) -> dict[str, Any]:
"""
Map a parsed chart config to Superset's internal form_data dict.
Replaces the if/elif chain in chart_utils.map_config_to_form_data().
Returns a Superset form_data dict ready for caching and rendering.
"""
...
def post_map_validate(
self,
config: Any,
form_data: dict[str, Any],
dataset_id: int | str | None = None,
) -> ChartGenerationError | None:
"""
Validate the mapped form_data after to_form_data() runs.
Use this for cross-field constraints that can only be checked once
form_data is assembled (e.g. BigNumber trendline requires a temporal
column whose type must be verified against the dataset).
Returns None if valid, ChartGenerationError if invalid.
"""
...
def normalize_column_refs(
self,
config: Any,
dataset_context: Any,
) -> Any:
"""
Return a new config with column names normalized to canonical dataset casing.
Called by DatasetValidator.normalize_column_names(). The default
implementation (in BaseChartPlugin) returns the config unchanged; plugins
with column fields override this to fix case sensitivity mismatches.
Returns a new config object (or the original if no normalization needed).
"""
...
def get_runtime_warnings(
self,
config: Any,
dataset_id: int | str,
) -> list[str]:
"""
Return chart-type-specific runtime warnings (performance, compatibility).
Called by RuntimeValidator to collect per-type warnings. Warnings are
informational only — they never block chart generation. The default
implementation returns an empty list; plugins override this to emit
chart-type-specific warnings (e.g. XY cardinality checks).
Returns a list of warning message strings (may be empty).
"""
...
def generate_name(
self,
config: Any,
dataset_name: str | None = None,
) -> str:
"""
Return a descriptive chart name for the given config.
Called by chart_utils.generate_chart_name(). The name should follow
the standard format conventions documented in that function. Plugins
that do not override this return the generic fallback "Chart".
"""
...
def resolve_viz_type(self, config: Any) -> str:
"""
Return the Superset-internal viz_type string for this config.
Called by chart_utils._resolve_viz_type(). The returned string must
match a registered Superset viz plugin (e.g. "echarts_timeseries_line").
Plugins that do not override this return "unknown".
"""
...
def schema_error_hint(self) -> "ChartGenerationError | None":
"""
Return a user-friendly error for Pydantic discriminated-union parse failures.
Called by SchemaValidator when Pydantic cannot parse the config union and
the chart_type is known. Returning None falls back to the generic error.
"""
...
class BaseChartPlugin:
"""
Base class providing sensible defaults for all ChartTypePlugin methods.
Concrete plugins extend this and override only what they need.
"""
chart_type: str = ""
display_name: str = ""
native_viz_types: dict[str, str] = {}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
return None
def extract_column_refs(
self,
config: Any,
) -> list[ColumnRef]:
return []
def to_form_data(
self,
config: Any,
dataset_id: int | str | None = None,
) -> dict[str, Any]:
raise NotImplementedError(
f"{self.__class__.__name__}.to_form_data() is not implemented"
)
def post_map_validate(
self,
config: Any,
form_data: dict[str, Any],
dataset_id: int | str | None = None,
) -> ChartGenerationError | None:
return None
def normalize_column_refs(
self,
config: Any,
dataset_context: Any,
) -> Any:
return config
def get_runtime_warnings(
self,
config: Any,
dataset_id: int | str,
) -> list[str]:
return []
def generate_name(
self,
config: Any,
dataset_name: str | None = None,
) -> str:
return "Chart"
def resolve_viz_type(self, config: Any) -> str:
return "unknown"
def schema_error_hint(self) -> ChartGenerationError | None:
return None
@staticmethod
def _with_context(what: str, context: str | None) -> str:
"""Combine a 'what' label and optional context with an en-dash."""
return f"{what} \u2013 {context}" if context else what


@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Chart type plugins package.
Importing this module registers all built-in chart type plugins in the global
registry. This module is imported by app.py at startup.
To add a new chart type:
1. Create ``superset/mcp_service/chart/plugins/{chart_type}.py``
2. Implement a class extending ``BaseChartPlugin``
3. Import and register it here
"""
from superset.mcp_service.chart.plugins.big_number import BigNumberChartPlugin
from superset.mcp_service.chart.plugins.handlebars import HandlebarsChartPlugin
from superset.mcp_service.chart.plugins.mixed_timeseries import (
MixedTimeseriesChartPlugin,
)
from superset.mcp_service.chart.plugins.pie import PieChartPlugin
from superset.mcp_service.chart.plugins.pivot_table import PivotTableChartPlugin
from superset.mcp_service.chart.plugins.table import TableChartPlugin
from superset.mcp_service.chart.plugins.xy import XYChartPlugin
from superset.mcp_service.chart.registry import register
# Register all built-in chart type plugins
register(XYChartPlugin())
register(TableChartPlugin())
register(PieChartPlugin())
register(PivotTableChartPlugin())
register(MixedTimeseriesChartPlugin())
register(HandlebarsChartPlugin())
register(BigNumberChartPlugin())
__all__ = [
"BigNumberChartPlugin",
"HandlebarsChartPlugin",
"MixedTimeseriesChartPlugin",
"PieChartPlugin",
"PivotTableChartPlugin",
"TableChartPlugin",
"XYChartPlugin",
]


@@ -0,0 +1,220 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Big number chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_big_number_chart_what,
_summarize_filters,
is_column_truly_temporal,
map_big_number_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import BigNumberChartConfig, ColumnRef
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class BigNumberChartPlugin(BaseChartPlugin):
"""Plugin for big_number chart type."""
chart_type = "big_number"
display_name = "Big Number"
native_viz_types = {
"big_number": "Big Number with Trendline",
"big_number_total": "Big Number",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
if "metric" not in config:
return ChartGenerationError(
error_type="missing_metric",
message="Big Number chart missing required field: metric",
details=(
"Big Number charts require a 'metric' field "
"specifying the value to display"
),
suggestions=[
"Add 'metric' with name and aggregate: "
"{'name': 'revenue', 'aggregate': 'SUM'}",
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
"Example: {'chart_type': 'big_number', "
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
],
error_code="MISSING_BIG_NUMBER_METRIC",
)
metric = config.get("metric", {})
if not isinstance(metric, dict):
return ChartGenerationError(
error_type="invalid_metric_type",
message="Big Number metric must be a dict with 'name' and 'aggregate'",
details=(
f"The 'metric' field must be an object, got {type(metric).__name__}"
),
suggestions=[
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
)
if not metric.get("aggregate") and not metric.get("saved_metric"):
return ChartGenerationError(
error_type="missing_metric_aggregate",
message=(
"Big Number metric must include an aggregate function "
"or reference a saved metric"
),
details=(
"The metric must have an 'aggregate' field or 'saved_metric': true"
),
suggestions=[
"Add an 'aggregate' to the metric: {'name': 'col', 'aggregate': 'SUM'}",
"Or use a saved metric: {'name': 'metric', 'saved_metric': true}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="MISSING_BIG_NUMBER_AGGREGATE",
)
show_trendline = config.get("show_trendline", False)
temporal_column = config.get("temporal_column")
if show_trendline and not temporal_column:
return ChartGenerationError(
error_type="missing_temporal_column",
message="Trendline requires a temporal column",
details=(
"When 'show_trendline' is True, "
"a 'temporal_column' must be specified"
),
suggestions=[
"Add 'temporal_column': 'date_column_name'",
"Or set 'show_trendline': false for number only",
"Use get_dataset_info to find temporal columns",
],
error_code="MISSING_TEMPORAL_COLUMN",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, BigNumberChartConfig):
return []
refs: list[ColumnRef] = [config.metric]
# temporal_column is a str field, not a ColumnRef — validate it exists
if config.temporal_column:
refs.append(ColumnRef(name=config.temporal_column))
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_big_number_config(config)
def post_map_validate(
self,
config: Any,
form_data: dict[str, Any],
dataset_id: int | str | None = None,
) -> ChartGenerationError | None:
"""Verify the trendline temporal column is a real temporal SQL type.
This check was previously baked into map_config_to_form_data() in
chart_utils.py as a special case. Moving it here keeps the dispatcher
clean and makes the constraint explicit and discoverable.
"""
if not isinstance(config, BigNumberChartConfig):
return None
if not (config.show_trendline and config.temporal_column):
return None
if not is_column_truly_temporal(config.temporal_column, dataset_id):
return ChartGenerationError(
error_type="non_temporal_trendline_column",
message=(
f"Big Number trendline requires a temporal SQL column; "
f"'{config.temporal_column}' is not temporal."
),
details=(
f"Column '{config.temporal_column}' does not have a temporal "
f"SQL type (DATE, DATETIME, TIMESTAMP). The trendline requires "
f"a true temporal column for DATE_TRUNC to work."
),
suggestions=[
"Use get_dataset_info to find columns with temporal SQL types",
"Set 'show_trendline': false to use any column as the metric",
"If the column contains dates stored as integers, "
"consider casting it in a virtual dataset",
],
error_code="NON_TEMPORAL_TRENDLINE_COLUMN",
)
return None
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _big_number_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
show_trendline = getattr(config, "show_trendline", False)
temporal_column = getattr(config, "temporal_column", None)
if show_trendline and temporal_column:
return "big_number"
return "big_number_total"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"):
config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name(
config_dict["metric"]["name"], dataset_context
)
if config_dict.get("temporal_column"):
config_dict["temporal_column"] = (
DatasetValidator._get_canonical_column_name(
config_dict["temporal_column"], dataset_context
)
)
DatasetValidator._normalize_filters(config_dict, dataset_context)
return BigNumberChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="big_number_validation_error",
message="Big Number chart configuration validation failed",
details=(
"The Big Number chart configuration is missing required "
"fields or has invalid structure"
),
suggestions=[
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}",
"For trendline: add show_trendline=true and temporal_column='col'",
"Without trendline: just provide the metric",
],
error_code="BIG_NUMBER_VALIDATION_ERROR",
)


@@ -0,0 +1,189 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Handlebars chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_handlebars_chart_what,
_summarize_filters,
map_handlebars_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, HandlebarsChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class HandlebarsChartPlugin(BaseChartPlugin):
"""Plugin for handlebars chart type (custom HTML template charts)."""
chart_type = "handlebars"
display_name = "Handlebars (Custom Template)"
native_viz_types = {
"handlebars": "Custom Template Chart",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
if "handlebars_template" not in config:
return ChartGenerationError(
error_type="missing_handlebars_template",
message="Handlebars chart missing required field: handlebars_template",
details=(
"Handlebars charts require a 'handlebars_template' string "
"containing Handlebars HTML template markup"
),
suggestions=[
"Add 'handlebars_template' with a Handlebars HTML template",
"Data is available as {{data}} array in the template",
"Example: '<ul>{{#each data}}<li>{{this.name}}: "
"{{this.value}}</li>{{/each}}</ul>'",
],
error_code="MISSING_HANDLEBARS_TEMPLATE",
)
template = config.get("handlebars_template")
if not isinstance(template, str) or not template.strip():
return ChartGenerationError(
error_type="invalid_handlebars_template",
message="Handlebars template must be a non-empty string",
details=(
"The 'handlebars_template' field must be a non-empty string "
"containing valid Handlebars HTML template markup"
),
suggestions=[
"Ensure handlebars_template is a non-empty string",
"Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
],
error_code="INVALID_HANDLEBARS_TEMPLATE",
)
query_mode = config.get("query_mode", "aggregate")
if query_mode not in ("aggregate", "raw"):
return ChartGenerationError(
error_type="invalid_query_mode",
message="Invalid query_mode for handlebars chart",
details="query_mode must be either 'aggregate' or 'raw'",
suggestions=[
"Use 'aggregate' for aggregated data (default)",
"Use 'raw' for individual rows",
],
error_code="INVALID_QUERY_MODE",
)
if query_mode == "raw" and not config.get("columns"):
return ChartGenerationError(
error_type="missing_raw_columns",
message="Handlebars chart in 'raw' mode requires 'columns'",
details=(
"When query_mode is 'raw', you must specify which columns "
"to include in the query results"
),
suggestions=[
"Add 'columns': [{'name': 'column_name'}] for raw mode",
"Or use query_mode='aggregate' with 'metrics' and optional 'groupby'", # noqa: E501
],
error_code="MISSING_RAW_COLUMNS",
)
if query_mode == "aggregate" and not config.get("metrics"):
return ChartGenerationError(
error_type="missing_aggregate_metrics",
message="Handlebars chart in 'aggregate' mode requires 'metrics'",
details=(
"When query_mode is 'aggregate' (default), you must specify "
"at least one metric with an aggregate function"
),
suggestions=[
"Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
"Or use query_mode='raw' with 'columns' for individual rows",
],
error_code="MISSING_AGGREGATE_METRICS",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, HandlebarsChartConfig):
return []
refs: list[ColumnRef] = []
if config.columns:
refs.extend(config.columns)
if config.metrics:
refs.extend(config.metrics)
if config.groupby:
refs.extend(config.groupby)
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_handlebars_config(config)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _handlebars_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
return "handlebars"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
def _norm_list(key: str) -> None:
if config_dict.get(key):
for col in config_dict[key]:
if not col.get("saved_metric"):
col["name"] = DatasetValidator._get_canonical_column_name(
col["name"], dataset_context
)
_norm_list("columns")
_norm_list("metrics")
_norm_list("groupby")
DatasetValidator._normalize_filters(config_dict, dataset_context)
return HandlebarsChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="handlebars_validation_error",
message="Handlebars chart configuration validation failed",
details=(
"The handlebars chart configuration is missing "
"required fields or has invalid structure"
),
suggestions=[
"Ensure 'handlebars_template' is a non-empty string",
"For aggregate mode: add 'metrics' with aggregate functions",
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
"Example: {'chart_type': 'handlebars', "
"'handlebars_template': "
"'<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>', "
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
],
error_code="HANDLEBARS_VALIDATION_ERROR",
)


@@ -0,0 +1,165 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mixed timeseries chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_mixed_timeseries_what,
_summarize_filters,
map_mixed_timeseries_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, MixedTimeseriesChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class MixedTimeseriesChartPlugin(BaseChartPlugin):
"""Plugin for mixed_timeseries chart type."""
chart_type = "mixed_timeseries"
display_name = "Mixed Timeseries"
native_viz_types = {
"mixed_timeseries": "Mixed Timeseries Chart",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
missing_fields = []
if "x" not in config:
missing_fields.append("'x' (X-axis temporal column)")
if "y" not in config:
missing_fields.append("'y' (primary Y-axis metrics)")
if "y_secondary" not in config:
missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
if missing_fields:
return ChartGenerationError(
error_type="missing_mixed_timeseries_fields",
message=(
f"Mixed timeseries chart missing required fields: "
f"{', '.join(missing_fields)}"
),
details=(
"Mixed timeseries charts require an x-axis, primary metrics, "
"and secondary metrics"
),
suggestions=[
"Add 'x' field: {'name': 'date_column'}",
"Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
"Add 'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]",
"Optional: 'primary_kind' and 'secondary_kind' for chart types",
],
error_code="MISSING_MIXED_TIMESERIES_FIELDS",
)
for field_name in ["y", "y_secondary"]:
if not isinstance(config.get(field_name, []), list):
return ChartGenerationError(
error_type=f"invalid_{field_name}_format",
message=f"'{field_name}' must be a list of metrics",
details=(
f"The '{field_name}' field must be an array of metric "
"specifications"
),
suggestions=[
f"Wrap in array: '{field_name}': "
"[{'name': 'col', 'aggregate': 'SUM'}]",
],
error_code=f"INVALID_{field_name.upper()}_FORMAT",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, MixedTimeseriesChartConfig):
return []
refs: list[ColumnRef] = [config.x]
refs.extend(config.y)
refs.extend(config.y_secondary)
if config.group_by:
refs.extend(config.group_by)
if config.group_by_secondary:
refs.extend(config.group_by_secondary)
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_mixed_timeseries_config(config, dataset_id=dataset_id)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _mixed_timeseries_what(config)
context = _summarize_filters(config.filters)
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
return "mixed_timeseries"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
def _norm_single(key: str) -> None:
if config_dict.get(key):
config_dict[key]["name"] = DatasetValidator._get_canonical_column_name(
config_dict[key]["name"], dataset_context
)
def _norm_list(key: str) -> None:
if config_dict.get(key):
for col in config_dict[key]:
col["name"] = DatasetValidator._get_canonical_column_name(
col["name"], dataset_context
)
_norm_single("x")
_norm_list("y")
_norm_list("y_secondary")
_norm_list("group_by")
_norm_list("group_by_secondary")
DatasetValidator._normalize_filters(config_dict, dataset_context)
return MixedTimeseriesChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="mixed_timeseries_validation_error",
message="Mixed timeseries chart configuration validation failed",
details=(
"The mixed timeseries configuration is missing "
"required fields or has invalid structure"
),
suggestions=[
"Ensure 'x' field has 'name' for the time axis column",
"Ensure 'y' is an array of primary-axis metrics",
"Ensure 'y_secondary' is an array of secondary-axis metrics",
"Example: {'chart_type': 'mixed_timeseries', "
"'x': {'name': 'order_date'}, "
"'y': [{'name': 'revenue', 'aggregate': 'SUM'}], "
"'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}",
],
error_code="MIXED_TIMESERIES_VALIDATION_ERROR",
)


@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pie chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_pie_chart_what,
_summarize_filters,
map_pie_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PieChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class PieChartPlugin(BaseChartPlugin):
"""Plugin for pie chart type."""
chart_type = "pie"
display_name = "Pie / Donut Chart"
native_viz_types = {
"pie": "Pie Chart",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
missing_fields = []
if "dimension" not in config:
missing_fields.append("'dimension' (category column for slices)")
if "metric" not in config:
missing_fields.append("'metric' (value metric for slice sizes)")
if missing_fields:
return ChartGenerationError(
error_type="missing_pie_fields",
message=(
f"Pie chart missing required fields: {', '.join(missing_fields)}"
),
details=(
"Pie charts require a dimension (categories) and a metric (values)"
),
suggestions=[
"Add 'dimension' field: {'name': 'category_column'}",
"Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
"Example: {'chart_type': 'pie', 'dimension': {'name': 'product'}, "
"'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
],
error_code="MISSING_PIE_FIELDS",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, PieChartConfig):
return []
refs: list[ColumnRef] = [config.dimension, config.metric]
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_pie_config(config)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _pie_chart_what(config)
context = _summarize_filters(config.filters)
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
return "pie"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
if config_dict.get("dimension"):
config_dict["dimension"]["name"] = (
DatasetValidator._get_canonical_column_name(
config_dict["dimension"]["name"], dataset_context
)
)
if config_dict.get("metric") and not config_dict["metric"].get("saved_metric"):
config_dict["metric"]["name"] = DatasetValidator._get_canonical_column_name(
config_dict["metric"]["name"], dataset_context
)
DatasetValidator._normalize_filters(config_dict, dataset_context)
return PieChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="pie_validation_error",
message="Pie chart configuration validation failed",
details=(
"The pie chart configuration is missing required "
"fields or has invalid structure"
),
suggestions=[
"Ensure 'dimension' field has 'name' for the slice label",
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, "
"'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
],
error_code="PIE_VALIDATION_ERROR",
)
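The `normalize_column_refs` step above can be illustrated standalone. The internals of `DatasetValidator._get_canonical_column_name` are not shown in this diff, so `get_canonical` below is an assumed case-insensitive resolution; the real helper may do more.

```python
def get_canonical(name: str, columns: list[str]) -> str:
    # Hypothetical stand-in for DatasetValidator._get_canonical_column_name:
    # resolve a case-insensitive match to the dataset's exact column name,
    # falling back to the input unchanged when there is no match.
    by_lower = {c.lower(): c for c in columns}
    return by_lower.get(name.lower(), name)

dataset_columns = ["product", "revenue", "order_date"]
config = {
    "dimension": {"name": "PRODUCT"},
    "metric": {"name": "Revenue", "aggregate": "SUM"},
}
config["dimension"]["name"] = get_canonical(config["dimension"]["name"], dataset_columns)
# Saved metrics keep their own name, mirroring the plugin's saved_metric guard.
if not config["metric"].get("saved_metric"):
    config["metric"]["name"] = get_canonical(config["metric"]["name"], dataset_columns)

assert config["dimension"]["name"] == "product"
assert config["metric"]["name"] == "revenue"
```

The `saved_metric` guard matters: a saved metric's name refers to a metric definition on the dataset, not a physical column, so canonicalizing it against column names would corrupt it.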


@@ -0,0 +1,153 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pivot table chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_pivot_table_what,
_summarize_filters,
map_pivot_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PivotTableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class PivotTableChartPlugin(BaseChartPlugin):
"""Plugin for pivot_table chart type."""
chart_type = "pivot_table"
display_name = "Pivot Table"
native_viz_types = {
"pivot_table_v2": "Pivot Table",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
missing_fields = []
if "rows" not in config:
missing_fields.append("'rows' (row grouping columns)")
if "metrics" not in config:
missing_fields.append("'metrics' (aggregation metrics)")
if missing_fields:
return ChartGenerationError(
error_type="missing_pivot_fields",
message=(
f"Pivot table missing required fields: {', '.join(missing_fields)}"
),
details="Pivot tables require row groupings and metrics",
suggestions=[
"Add 'rows' field: [{'name': 'category'}]",
"Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
"Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
],
error_code="MISSING_PIVOT_FIELDS",
)
if not isinstance(config.get("rows", []), list):
return ChartGenerationError(
error_type="invalid_rows_format",
message="Rows must be a list of columns",
details="The 'rows' field must be an array of column specifications",
suggestions=[
"Wrap row columns in array: 'rows': [{'name': 'category'}]",
],
error_code="INVALID_ROWS_FORMAT",
)
if not isinstance(config.get("metrics", []), list):
return ChartGenerationError(
error_type="invalid_metrics_format",
message="Metrics must be a list",
details="The 'metrics' field must be an array of metric specifications",
suggestions=[
"Wrap metrics in array: 'metrics': [{'name': 'sales', "
"'aggregate': 'SUM'}]",
],
error_code="INVALID_METRICS_FORMAT",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, PivotTableChartConfig):
return []
refs: list[ColumnRef] = list(config.rows)
refs.extend(config.metrics)
if config.columns:
refs.extend(config.columns)
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_pivot_table_config(config)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _pivot_table_what(config)
context = _summarize_filters(config.filters)
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
return "pivot_table_v2"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
def _norm_col_list(key: str) -> None:
if config_dict.get(key):
for col in config_dict[key]:
col["name"] = DatasetValidator._get_canonical_column_name(
col["name"], dataset_context
)
_norm_col_list("rows")
_norm_col_list("metrics")
_norm_col_list("columns")
DatasetValidator._normalize_filters(config_dict, dataset_context)
return PivotTableChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="pivot_table_validation_error",
message="Pivot table configuration validation failed",
details=(
"The pivot table configuration is missing required "
"fields or has invalid structure"
),
suggestions=[
"Ensure 'rows' field is an array of column specs",
"Ensure 'metrics' field is an array with aggregate funcs",
"Optional: add 'columns' for column grouping",
"Example: {'chart_type': 'pivot_table', "
"'rows': [{'name': 'region'}], "
"'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}",
],
error_code="PIVOT_TABLE_VALIDATION_ERROR",
)
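The `extract_column_refs` pattern above (flatten rows, metrics, optional columns, and filter columns into one list) can be sketched with a simplified `ColumnRef`. This is an illustrative stand-in for the Pydantic schema, not the real class.

```python
from dataclasses import dataclass


@dataclass
class ColumnRef:
    """Simplified stand-in for the Pydantic ColumnRef schema."""
    name: str


def extract_refs(rows, metrics, columns=None, filters=None):
    # Order mirrors the plugin: rows first, then metrics, then optional
    # cross-tab columns, then any filter columns.
    refs = [ColumnRef(r) for r in rows] + [ColumnRef(m) for m in metrics]
    refs += [ColumnRef(c) for c in (columns or [])]
    refs += [ColumnRef(f) for f in (filters or [])]
    return refs


refs = extract_refs(["region"], ["revenue"], columns=["quarter"], filters=["country"])
assert [r.name for r in refs] == ["region", "revenue", "quarter", "country"]
```

Collecting every referenced column in one list lets the dataset validator verify all of them against the dataset schema in a single pass.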


@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Table chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_summarize_filters,
_table_chart_what,
map_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, TableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class TableChartPlugin(BaseChartPlugin):
"""Plugin for table chart type."""
chart_type = "table"
display_name = "Table"
native_viz_types = {
"table": "Table",
"ag-grid-table": "Interactive Table",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
if "columns" not in config:
return ChartGenerationError(
error_type="missing_columns",
message="Table chart missing required field: columns",
details=(
"Table charts require a 'columns' array to specify which "
"columns to display"
),
suggestions=[
"Add 'columns' field with array of column specifications",
"Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
"'aggregate': 'SUM'}]",
"Each column can have optional 'aggregate' for metrics",
],
error_code="MISSING_COLUMNS",
)
if not isinstance(config.get("columns", []), list):
return ChartGenerationError(
error_type="invalid_columns_format",
message="Columns must be a list",
details="The 'columns' field must be an array of column specifications",
suggestions=[
"Ensure columns is an array: 'columns': [...]",
"Each column should be an object with 'name' field",
],
error_code="INVALID_COLUMNS_FORMAT",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, TableChartConfig):
return []
refs: list[ColumnRef] = list(config.columns)
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_table_config(config)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _table_chart_what(config, dataset_name)
context = _summarize_filters(config.filters)
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
return getattr(config, "viz_type", "table")
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
get_canonical = DatasetValidator._get_canonical_column_name
for col in config_dict.get("columns") or []:
col["name"] = get_canonical(col["name"], dataset_context)
DatasetValidator._normalize_filters(config_dict, dataset_context)
return TableChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="table_validation_error",
message="Table chart configuration validation failed",
details=(
"The table chart configuration is missing required "
"fields or has invalid structure"
),
suggestions=[
"Ensure 'columns' field is an array of column specifications",
"Each column needs {'name': 'column_name'}",
"Optional: add 'aggregate' for metrics",
"Example: 'columns': [{'name': 'product'}, "
"{'name': 'sales', 'aggregate': 'SUM'}]",
],
error_code="TABLE_VALIDATION_ERROR",
)


@@ -0,0 +1,192 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""XY chart type plugin (line, bar, area, scatter)."""
from __future__ import annotations
import logging
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_xy_chart_context,
_xy_chart_what,
map_xy_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.chart.validation.runtime.cardinality_validator import (
CardinalityValidator,
)
from superset.mcp_service.chart.validation.runtime.format_validator import (
FormatTypeValidator,
)
from superset.mcp_service.common.error_schemas import ChartGenerationError
logger = logging.getLogger(__name__)
class XYChartPlugin(BaseChartPlugin):
"""Plugin for xy chart type (line, bar, area, scatter)."""
chart_type = "xy"
display_name = "Line / Bar / Area / Scatter Chart"
native_viz_types = {
"echarts_timeseries_line": "Line Chart",
"echarts_timeseries_bar": "Bar Chart",
"echarts_area": "Area Chart",
"echarts_timeseries_scatter": "Scatter Plot",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
# x is optional — defaults to dataset's main_dttm_col in map_xy_config
if "y" not in config:
return ChartGenerationError(
error_type="missing_xy_fields",
message="XY chart missing required field: 'y' (Y-axis metrics)",
details=(
"XY charts require Y-axis (metrics) specifications. "
"X-axis is optional and defaults to the dataset's primary "
"datetime column when omitted."
),
suggestions=[
"Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}]",
"Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
"'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
],
error_code="MISSING_XY_FIELDS",
)
if not isinstance(config.get("y", []), list):
return ChartGenerationError(
error_type="invalid_y_format",
message="Y-axis must be a list of metrics",
details="The 'y' field must be an array of metric specifications",
suggestions=[
"Wrap Y-axis metric in array: 'y': [{'name': 'column', "
"'aggregate': 'SUM'}]",
"Multiple metrics supported: 'y': [metric1, metric2, ...]",
],
error_code="INVALID_Y_FORMAT",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, XYChartConfig):
return []
refs: list[ColumnRef] = []
if config.x is not None:
refs.append(config.x)
refs.extend(config.y)
if config.group_by:
refs.extend(config.group_by)
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_xy_config(config, dataset_id=dataset_id)
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
get_canonical = DatasetValidator._get_canonical_column_name
if config_dict.get("x"):
config_dict["x"]["name"] = get_canonical(
config_dict["x"]["name"], dataset_context
)
for y_col in config_dict.get("y") or []:
y_col["name"] = get_canonical(y_col["name"], dataset_context)
for gb_col in config_dict.get("group_by") or []:
gb_col["name"] = get_canonical(gb_col["name"], dataset_context)
DatasetValidator._normalize_filters(config_dict, dataset_context)
return XYChartConfig.model_validate(config_dict)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _xy_chart_what(config)
context = _xy_chart_context(config)
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
kind = getattr(config, "kind", "line")
return {
"line": "echarts_timeseries_line",
"bar": "echarts_timeseries_bar",
"area": "echarts_area",
"scatter": "echarts_timeseries_scatter",
}.get(kind, "echarts_timeseries_line")
def get_runtime_warnings(self, config: Any, dataset_id: int | str) -> list[str]:
"""Return format-compatibility and cardinality warnings for XY charts."""
if not isinstance(config, XYChartConfig):
return []
warnings: list[str] = []
try:
_valid, format_warnings = FormatTypeValidator.validate_format_compatibility(
config
)
if format_warnings:
warnings.extend(format_warnings)
except Exception as exc:
logger.warning("XY format validation failed: %s", exc)
try:
chart_kind = config.kind
group_by_col = config.group_by[0].name if config.group_by else None
if config.x is not None:
_ok, card_info = CardinalityValidator.check_cardinality(
dataset_id=dataset_id,
x_column=config.x.name,
chart_type=chart_kind,
group_by_column=group_by_col,
)
if not _ok and card_info:
warnings.extend(card_info.get("warnings", []))
warnings.extend(card_info.get("suggestions", []))
except Exception as exc:
logger.warning("XY cardinality validation failed: %s", exc)
return warnings
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="xy_validation_error",
message="XY chart configuration validation failed",
details=(
"The XY chart configuration is missing required "
"fields or has invalid structure"
),
suggestions=[
"Note: 'x' is optional and defaults to the dataset's primary "
"datetime column",
"Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]",
"Check that all column names are strings",
"Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX",
],
error_code="XY_VALIDATION_ERROR",
)


@@ -0,0 +1,228 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ChartTypeRegistry — central registry mapping chart_type strings to plugins.
Replaces the four previously-scattered dispatch locations:
- schema_validator.py: chart_type_validators dict
- dataset_validator.py: isinstance branches in _extract_column_references()
- chart_utils.py: if/elif chain in map_config_to_form_data()
- dataset_validator.py: isinstance branches in normalize_column_names()
Usage::
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get("xy")
if plugin is None:
raise ValueError("Unknown chart type: xy")
form_data = plugin.to_form_data(config, dataset_id)
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from superset.mcp_service.chart.plugin import ChartTypePlugin
logger = logging.getLogger(__name__)
_REGISTRY: dict[str, "ChartTypePlugin"] = {}
_plugins_loaded = False
_plugins_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Plugin filter — replaced atomically by configure() at app startup.
# Default: all registered plugins visible (no disabled set, no callable).
# ---------------------------------------------------------------------------
PluginEnabledFunc = Callable[[str], bool]
@dataclass(frozen=True)
class _PluginFilterConfig:
disabled_plugins: frozenset[str] = field(default_factory=frozenset)
enabled_func: PluginEnabledFunc | None = None
_filter_config: _PluginFilterConfig = _PluginFilterConfig()
def _ensure_plugins_loaded() -> None:
"""Lazily import the plugins package to populate _REGISTRY.
Called before every registry lookup so the registry is always populated,
even when callers (tests, chart_utils, validators) import this module
directly without first importing app.py.
"""
global _plugins_loaded
if _plugins_loaded:
return
with _plugins_lock:
if not _plugins_loaded:
try:
import superset.mcp_service.chart.plugins # noqa: F401
_plugins_loaded = True
except Exception:
logger.exception("Failed to load built-in chart type plugins")
def configure(
disabled: Iterable[str] | None = None,
enabled_func: PluginEnabledFunc | None = None,
) -> None:
"""Set runtime plugin filters. Called once during app initialization.
Replaces the filter config atomically with a single assignment so concurrent
readers always observe a consistent (disabled_plugins, enabled_func) pair.
Args:
disabled: chart_type strings to suppress. Accepts any iterable (set,
frozenset, list, tuple). Ignored when enabled_func is provided.
enabled_func: callable(chart_type) -> bool. When set, overrides
``disabled``. Must be cheap and in-process — no network I/O per
call. On exception the registry fails *closed* (plugin hidden).
"""
global _filter_config
if enabled_func is not None and not callable(enabled_func):
raise TypeError("enabled_func must be callable or None")
new_config = _PluginFilterConfig(
disabled_plugins=frozenset(disabled or ()),
enabled_func=enabled_func,
)
_filter_config = new_config
if new_config.disabled_plugins:
logger.info(
"MCP chart plugins disabled: %s", sorted(new_config.disabled_plugins)
)
if new_config.enabled_func is not None:
logger.info(
"MCP chart plugin dynamic filter configured: %r", new_config.enabled_func
)
def _is_plugin_enabled(chart_type: str) -> bool:
"""Return True if the plugin is currently enabled (not filtered out)."""
config = _filter_config # read once — atomic reference in CPython
if config.enabled_func is not None:
try:
return bool(config.enabled_func(chart_type))
except Exception:
logger.warning(
"MCP_CHART_PLUGIN_ENABLED_FUNC raised for chart_type=%r; "
"failing closed (plugin hidden)",
chart_type,
exc_info=True,
)
return False
return chart_type not in config.disabled_plugins
def register(plugin: "ChartTypePlugin") -> None:
"""Register a chart type plugin in the global registry."""
if not plugin.chart_type:
raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type")
if plugin.chart_type in _REGISTRY:
logger.warning(
"Overwriting existing plugin for chart_type=%r", plugin.chart_type
)
_REGISTRY[plugin.chart_type] = plugin
logger.debug("Registered chart plugin: %r", plugin.chart_type)
def get(chart_type: str) -> "ChartTypePlugin | None":
"""Return the plugin for chart_type, or None if unknown or disabled."""
_ensure_plugins_loaded()
if chart_type not in _REGISTRY or not _is_plugin_enabled(chart_type):
return None
return _REGISTRY[chart_type]
def all_types() -> list[str]:
"""Return enabled registered chart type strings in insertion order."""
_ensure_plugins_loaded()
return [ct for ct in _REGISTRY if _is_plugin_enabled(ct)]
def is_registered(chart_type: str) -> bool:
"""Return True if chart_type has a registered plugin, regardless of enabled state.
Use this to distinguish an unknown chart type from a disabled one.
Use is_enabled() to check whether the plugin is currently available.
"""
_ensure_plugins_loaded()
return chart_type in _REGISTRY
def is_enabled(chart_type: str) -> bool:
"""Return True if chart_type is registered AND currently enabled."""
_ensure_plugins_loaded()
return chart_type in _REGISTRY and _is_plugin_enabled(chart_type)
def display_name_for_viz_type(viz_type: str) -> str | None:
"""Return the user-facing display name for a Superset-internal viz_type.
Searches every registered plugin's ``native_viz_types`` mapping.
Returns None if no plugin recognises the viz_type.
Example::
display_name_for_viz_type("echarts_timeseries_line") # "Line Chart"
display_name_for_viz_type("pivot_table_v2") # "Pivot Table"
display_name_for_viz_type("unknown_type") # None
"""
_ensure_plugins_loaded()
for plugin in _REGISTRY.values():
name = plugin.native_viz_types.get(viz_type)
if name is not None:
return name
return None
def get_registry() -> "_RegistryProxy":
"""Return a proxy object for registry access (convenience wrapper)."""
return _RegistryProxy()
class _RegistryProxy:
"""Thin proxy exposing registry functions as instance methods."""
def get(self, chart_type: str) -> "ChartTypePlugin | None":
return get(chart_type)
def all_types(self) -> list[str]:
return all_types()
def is_registered(self, chart_type: str) -> bool:
return is_registered(chart_type)
def is_enabled(self, chart_type: str) -> bool:
return is_enabled(chart_type)
def display_name_for_viz_type(self, viz_type: str) -> str | None:
return display_name_for_viz_type(viz_type)
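The registry's filter semantics (atomic reconfigure, deny-list, callable override, fail-closed on exception) can be demonstrated with a self-contained sketch. Names here (`FilterConfig`, `is_plugin_enabled`) are simplified stand-ins for the module's `_PluginFilterConfig` and `_is_plugin_enabled`:

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass(frozen=True)
class FilterConfig:
    disabled: frozenset = field(default_factory=frozenset)
    enabled_func: Optional[Callable[[str], bool]] = None


_cfg = FilterConfig()


def configure(disabled=None, enabled_func=None):
    global _cfg
    # Single reference assignment: concurrent readers observe either the old
    # or the new (disabled, enabled_func) pair, never a torn mix.
    _cfg = FilterConfig(frozenset(disabled or ()), enabled_func)


def is_plugin_enabled(chart_type: str) -> bool:
    cfg = _cfg  # read the reference once per call
    if cfg.enabled_func is not None:
        try:
            return bool(cfg.enabled_func(chart_type))
        except Exception:
            return False  # fail closed: a broken hook hides the plugin
    return chart_type not in cfg.disabled


# Static deny-list hides one plugin.
configure(disabled={"handlebars"})
assert not is_plugin_enabled("handlebars") and is_plugin_enabled("xy")

# A dynamic hook takes precedence over the deny-list entirely.
configure(disabled={"handlebars"}, enabled_func=lambda ct: ct == "pie")
assert is_plugin_enabled("pie") and not is_plugin_enabled("xy")


# An exception in the hook fails closed rather than exposing the plugin.
def boom(ct):
    raise RuntimeError("flag service down")


configure(enabled_func=boom)
assert not is_plugin_enabled("pie")
```

Failing closed is the conservative choice for a kill switch: if the feature-flag backend misbehaves, operators lose chart types temporarily instead of re-exposing a plugin they deliberately disabled.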

superset/mcp_service/chart/schemas.py Normal file → Executable file

@@ -101,7 +101,14 @@ class ChartInfo(BaseModel):
id: int | None = Field(None, description="Chart ID")
slice_name: str | None = Field(None, description="Chart name")
viz_type: str | None = Field(None, description="Visualization type")
viz_type: str | None = Field(None, description="Visualization type (internal ID)")
chart_type_display_name: str | None = Field(
None,
description=(
"User-friendly chart type name (e.g. 'Line Chart', 'Pivot Table'). "
"Use this field when referring to chart types — never expose viz_type."
),
)
datasource_name: str | None = Field(None, description="Datasource name")
datasource_type: str | None = Field(None, description="Datasource type")
url: str | None = Field(None, description="Chart explore page URL")
@@ -488,11 +495,20 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
# Extract structured filter information
filters_info = extract_filters_from_form_data(chart_form_data)
_viz_type = getattr(chart, "viz_type", None)
try:
from superset.mcp_service.chart.registry import display_name_for_viz_type
_display_name = display_name_for_viz_type(_viz_type) if _viz_type else None
except Exception:
_display_name = None
return sanitize_chart_info_for_llm_context(
ChartInfo(
id=chart_id,
slice_name=getattr(chart, "slice_name", None),
viz_type=getattr(chart, "viz_type", None),
viz_type=_viz_type,
chart_type_display_name=_display_name,
datasource_name=getattr(chart, "datasource_name", None),
datasource_type=getattr(chart, "datasource_type", None),
url=chart_url,
@@ -669,7 +685,6 @@ class ColumnRef(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
@@ -743,7 +758,6 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("column", "col"),
)
op: Literal[
@@ -1082,7 +1096,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
time_grain: TimeGrain | None = Field(
None,


@@ -100,18 +100,34 @@ async def generate_chart( # noqa: C901
- Set save_chart=True to permanently save the chart
- LLM clients MUST display returned chart URL to users
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- MUST include chart_type in config (either 'xy' or 'table')
- MUST include chart_type in config (one of: 'xy', 'table', 'pie',
'pivot_table', 'mixed_timeseries', 'handlebars', 'big_number')
IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines
which chart configuration schema to use. It MUST be included and MUST match the
other fields in your configuration:
- Use chart_type='xy' for charts with x and y axes (line, bar, area, scatter)
Required fields: x, y
Required fields: y (x is optional — defaults to dataset's primary datetime column)
- Use chart_type='table' for tabular visualizations
Required fields: columns
- Use chart_type='pie' for pie/donut charts
Required fields: dimension, metric
- Use chart_type='pivot_table' for pivot table visualizations
Required fields: rows, metrics
- Use chart_type='mixed_timeseries' for dual-axis time-series charts
Required fields: x, y, y_secondary
- Use chart_type='handlebars' for custom template-based visualizations
Required fields: handlebars_template
- Use chart_type='big_number' for single KPI metric displays
Required fields: metric
Example usage for XY chart:
```json
{


@@ -28,7 +28,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.chart.ascii_charts import (
generate_ascii_chart,
generate_ascii_table,
@@ -1140,6 +1140,15 @@ async def _get_chart_preview_internal( # noqa: C901
)
chart = find_chart_by_identifier(request.identifier)
# Eagerly refresh all attributes while the session is still
# active. SQLAlchemy expires object attributes after any
# commit; if a downstream operation commits before the strategy
# classes access chart attributes, a DetachedInstanceError will
# be raised. Calling refresh() here ensures all column values
# are loaded into the object's __dict__ upfront.
if chart is not None:
db.session.refresh(chart)
# If not found and looks like a form_data_key, try transient
if (
not chart
@@ -1371,6 +1380,20 @@ async def _get_chart_preview_internal( # noqa: C901
return _sanitize_chart_preview_for_llm_context(result)
except SQLAlchemyError as e:
# Catch DetachedInstanceError and other SQLAlchemy errors that can
# surface when the ORM session expires or commits mid-request.
await ctx.error(
"Chart preview failed due to database session error: "
"identifier=%s, error_type=%s, error=%s"
% (request.identifier, type(e).__name__, str(e))
)
logger.exception("SQLAlchemy error in get_chart_preview: %s", e)
return ChartError(
error="Database session error while generating chart preview. "
"Please retry the request.",
error_type="InternalError",
)
except (
CommandException,
SupersetException,

superset/mcp_service/chart/tool/update_chart.py Normal file → Executable file

@@ -195,6 +195,29 @@ def _validate_update_against_dataset(
}
)
# Column existence + fuzzy-match validation
# (mirrors generate_chart pipeline layer 2)
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
is_col_valid, col_error = DatasetValidator.validate_against_dataset(
parsed_config, dataset.id
)
if not is_col_valid and col_error is not None:
logger.warning(
"update_chart column validation failed for chart %s: %s",
getattr(chart, "id", None),
col_error,
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"error": col_error.model_dump(),
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
)
compile_result = validate_and_compile(
parsed_config, form_data, dataset, run_compile_check=True
)
@@ -388,6 +411,24 @@ async def update_chart( # noqa: C901
# config is already a typed ChartConfig | None (validated by Pydantic)
parsed_config = request.config
# Normalize column case to match dataset canonical names
# (mirrors generate_chart pipeline layer 4)
chart_datasource_id = getattr(chart, "datasource_id", None)
if parsed_config is not None and chart_datasource_id is not None:
from superset.mcp_service.chart.validation.dataset_validator import (
DatasetValidator,
NORMALIZATION_EXCEPTIONS,
)
try:
parsed_config = DatasetValidator.normalize_column_names(
parsed_config, chart.datasource_id
)
except NORMALIZATION_EXCEPTIONS as e:
logger.warning(
"Column normalization failed for chart %s: %s", chart.id, e
)
if not request.generate_preview:
from superset.commands.chart.update import UpdateChartCommand
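The try/except around `normalize_column_names` above is a fail-open guard: if normalization raises one of the expected `NORMALIZATION_EXCEPTIONS`, the update proceeds with the caller's original column casing and later validation layers catch genuinely bad names. A minimal standalone sketch of that pattern (the toy `normalize_columns` helper and canonical-name map are illustrative assumptions, not the real `DatasetValidator` API):

```python
import logging

logger = logging.getLogger(__name__)

# Stand-ins for illustration only (not the real Superset API).
NORMALIZATION_EXCEPTIONS = (KeyError, ValueError)
CANONICAL = {"orderdate": "OrderDate", "sales": "Sales"}

def normalize_columns(columns):
    """Map case-insensitive names to canonical dataset column names."""
    return [CANONICAL[c.lower()] for c in columns]

def safe_normalize(columns):
    """Fail open: keep the original names if normalization raises."""
    try:
        return normalize_columns(columns)
    except NORMALIZATION_EXCEPTIONS as exc:
        logger.warning("Column normalization failed: %s", exc)
        # Downstream validation will flag names that are truly unknown.
        return columns

print(safe_normalize(["ORDERDATE", "sales"]))  # ['OrderDate', 'Sales']
print(safe_normalize(["no_such_col"]))         # ['no_such_col'] (fail open)
```

Failing open here is a deliberate choice: a normalization hiccup should degrade to the user's original input rather than block the whole chart update.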


@@ -22,17 +22,11 @@ Validates that referenced columns exist in the dataset schema.
import difflib
import logging
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, TypeVar
from superset.mcp_service.chart.schemas import (
BigNumberChartConfig,
ChartConfig,
ColumnRef,
HandlebarsChartConfig,
MixedTimeseriesChartConfig,
PieChartConfig,
PivotTableChartConfig,
TableChartConfig,
XYChartConfig,
)
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
@@ -40,6 +34,8 @@ from superset.mcp_service.common.error_schemas import (
DatasetContext,
)
_C = TypeVar("_C", bound=ChartConfig)
logger = logging.getLogger(__name__)
# Exceptions that can occur during column name normalization.
@@ -58,7 +54,7 @@ class DatasetValidator:
@staticmethod
def validate_against_dataset(
config: Any,
config: ChartConfig,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
) -> Tuple[bool, ChartGenerationError | None]:
@@ -260,59 +256,31 @@ class DatasetValidator:
return None
@staticmethod
def _extract_column_references(config: Any) -> List[ColumnRef]: # noqa: C901
"""Extract all column references from a chart configuration.
def _extract_column_references(
config: ChartConfig,
) -> List[ColumnRef]:
"""Extract all column references from configuration via the plugin registry.
Covers every supported ``ChartConfig`` variant so fast-path tools
(``generate_explore_link``, ``update_chart_preview``) that only run
Tier-1 validation still catch bad column refs in pie / pivot table /
mixed timeseries / handlebars / big number charts — not just XY and
table.
Previously only handled TableChartConfig and XYChartConfig, causing
5 of 7 chart types to silently skip column validation. Now delegates
to the plugin for each chart type so all types are covered.
"""
refs: List[ColumnRef] = []
# Local import: plugins call DatasetValidator helpers from
# normalize_column_refs(), so a module-level import of the registry here
# would trigger plugin registration as a side effect of loading this
# module, creating a circular dependency.
from superset.mcp_service.chart.registry import get_registry
if isinstance(config, TableChartConfig):
refs.extend(config.columns)
elif isinstance(config, XYChartConfig):
if config.x is not None:
refs.append(config.x)
refs.extend(config.y)
if config.group_by:
refs.extend(config.group_by)
elif isinstance(config, PieChartConfig):
refs.append(config.dimension)
refs.append(config.metric)
elif isinstance(config, PivotTableChartConfig):
refs.extend(config.rows)
if config.columns:
refs.extend(config.columns)
refs.extend(config.metrics)
elif isinstance(config, MixedTimeseriesChartConfig):
refs.append(config.x)
refs.extend(config.y)
if config.group_by:
refs.extend(config.group_by)
refs.extend(config.y_secondary)
if config.group_by_secondary:
refs.extend(config.group_by_secondary)
elif isinstance(config, HandlebarsChartConfig):
if config.columns:
refs.extend(config.columns)
if config.groupby:
refs.extend(config.groupby)
if config.metrics:
refs.extend(config.metrics)
elif isinstance(config, BigNumberChartConfig):
refs.append(config.metric)
if config.temporal_column:
refs.append(ColumnRef(name=config.temporal_column))
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return []
# Filter columns (shared by every config type that defines ``filters``).
if filters := getattr(config, "filters", None):
for filter_config in filters:
refs.append(ColumnRef(name=filter_config.column))
plugin = get_registry().get(chart_type)
if plugin is None:
logger.warning("No plugin registered for chart_type=%r", chart_type)
return []
return refs
return plugin.extract_column_refs(config)
@staticmethod
def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
@@ -365,42 +333,6 @@ class DatasetValidator:
# Return original if not found (validation should catch this case)
return column_name
@staticmethod
def _normalize_xy_config(
config_dict: Dict[str, Any], dataset_context: DatasetContext
) -> None:
"""Normalize column names in an XY chart config dict in place."""
# Normalize x-axis column
if "x" in config_dict and config_dict["x"]:
config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name(
config_dict["x"]["name"], dataset_context
)
# Normalize y-axis columns
if "y" in config_dict and config_dict["y"]:
for y_col in config_dict["y"]:
y_col["name"] = DatasetValidator._get_canonical_column_name(
y_col["name"], dataset_context
)
# Normalize group_by columns
if "group_by" in config_dict and config_dict["group_by"]:
for gb_col in config_dict["group_by"]:
gb_col["name"] = DatasetValidator._get_canonical_column_name(
gb_col["name"], dataset_context
)
@staticmethod
def _normalize_table_config(
config_dict: Dict[str, Any], dataset_context: DatasetContext
) -> None:
"""Normalize column names in a table chart config dict in place."""
if "columns" in config_dict and config_dict["columns"]:
for col in config_dict["columns"]:
col["name"] = DatasetValidator._get_canonical_column_name(
col["name"], dataset_context
)
@staticmethod
def _normalize_filters(
config_dict: Dict[str, Any], dataset_context: DatasetContext
@@ -417,10 +349,10 @@ class DatasetValidator:
@staticmethod
def normalize_column_names(
config: TableChartConfig | XYChartConfig,
config: _C,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
) -> TableChartConfig | XYChartConfig:
) -> _C:
"""
Normalize column names in config to match the canonical dataset column names.
@@ -429,6 +361,9 @@ class DatasetValidator:
(e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
so we need to ensure column names match exactly.
Previously only XYChartConfig and TableChartConfig were normalized; now
all 7 chart types are handled via the plugin registry.
Args:
config: Chart configuration with column references
dataset_id: Dataset ID to get canonical column names from
@@ -443,22 +378,24 @@ class DatasetValidator:
if not dataset_context:
return config
# Create a mutable copy of the config
config_dict = config.model_dump()
# Local import: plugins call DatasetValidator helpers from
# normalize_column_refs(), so a module-level import of the registry here
# would trigger plugin registration as a side effect of loading this
# module, creating a circular dependency.
from superset.mcp_service.chart.registry import get_registry
# Normalize based on config type
if isinstance(config, XYChartConfig):
DatasetValidator._normalize_xy_config(config_dict, dataset_context)
elif isinstance(config, TableChartConfig):
DatasetValidator._normalize_table_config(config_dict, dataset_context)
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return config
# Normalize filter columns (common to both config types)
DatasetValidator._normalize_filters(config_dict, dataset_context)
plugin = get_registry().get(chart_type)
if plugin is None:
logger.warning(
"No plugin for chart_type=%r; skipping column normalization", chart_type
)
return config
# Reconstruct the config with normalized names
if isinstance(config, XYChartConfig):
return XYChartConfig.model_validate(config_dict)
return TableChartConfig.model_validate(config_dict)
return plugin.normalize_column_refs(config, dataset_context)
@staticmethod
def _get_column_suggestions(

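The `_C = TypeVar("_C", bound=ChartConfig)` change above is what lets `normalize_column_names` accept any of the seven config classes and promise the same concrete type back. A small sketch of why the bound TypeVar matters (frozen dataclasses stand in for the real Pydantic models):

```python
from dataclasses import dataclass, replace
from typing import TypeVar

@dataclass(frozen=True)
class ChartConfig:  # stand-in base; the real one is a Pydantic model
    chart_type: str

@dataclass(frozen=True)
class TableChartConfig(ChartConfig):
    columns: tuple[str, ...] = ()

_C = TypeVar("_C", bound=ChartConfig)

def normalize(config: _C) -> _C:
    """Return the SAME concrete type that was passed in.

    With a plain `-> ChartConfig` annotation, callers would lose static
    access to subclass-only fields like `columns` without a cast.
    """
    if isinstance(config, TableChartConfig):
        return replace(config, columns=tuple(c.title() for c in config.columns))
    return config

cfg = normalize(TableChartConfig(chart_type="table", columns=("orderdate",)))
print(type(cfg).__name__, cfg.columns)  # TableChartConfig ('Orderdate',)
```

This replaces the old `TableChartConfig | XYChartConfig` union annotation, which could not express "output type equals input type" and silently excluded the other five config variants.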

@@ -23,10 +23,7 @@ Validates performance, compatibility, and user experience issues.
import logging
from typing import Any, Dict, List, Tuple
from superset.mcp_service.chart.schemas import (
ChartConfig,
XYChartConfig,
)
from superset.mcp_service.chart.schemas import ChartConfig
logger = logging.getLogger(__name__)
@@ -56,20 +53,10 @@ class RuntimeValidator:
warnings: List[str] = []
suggestions: List[str] = []
# Only check XY charts for format and cardinality issues
if isinstance(config, XYChartConfig):
# Format-type compatibility validation
format_warnings = RuntimeValidator._validate_format_compatibility(config)
if format_warnings:
warnings.extend(format_warnings)
# Cardinality validation
cardinality_warnings, cardinality_suggestions = (
RuntimeValidator._validate_cardinality(config, dataset_id)
)
if cardinality_warnings:
warnings.extend(cardinality_warnings)
suggestions.extend(cardinality_suggestions)
# Per-plugin runtime warnings (format, cardinality, etc.)
plugin_warnings = RuntimeValidator._validate_plugin_runtime(config, dataset_id)
if plugin_warnings:
warnings.extend(plugin_warnings)
# Chart type appropriateness validation (for all chart types)
type_warnings, type_suggestions = RuntimeValidator._validate_chart_type(
@@ -98,61 +85,28 @@ class RuntimeValidator:
return True, None
@staticmethod
def _validate_format_compatibility(config: XYChartConfig) -> List[str]:
"""Validate format-type compatibility."""
warnings: List[str] = []
def _validate_plugin_runtime(
config: ChartConfig, dataset_id: int | str
) -> List[str]:
"""Delegate per-chart-type runtime warnings to the plugin registry.
Each plugin's get_runtime_warnings() method returns chart-type-specific
warnings (e.g. format/cardinality for XY). The registry dispatch removes
the previous isinstance(config, XYChartConfig) hardcoding.
"""
try:
# Import here to avoid circular imports
from .format_validator import FormatTypeValidator
from superset.mcp_service.chart.registry import get_registry
is_valid, format_warnings = (
FormatTypeValidator.validate_format_compatibility(config)
)
if format_warnings:
warnings.extend(format_warnings)
except ImportError:
logger.warning("Format validator not available")
except Exception as e:
logger.warning("Format validation failed: %s", e)
return warnings
@staticmethod
def _validate_cardinality(
config: XYChartConfig, dataset_id: int | str
) -> Tuple[List[str], List[str]]:
"""Validate cardinality issues."""
warnings: List[str] = []
suggestions: List[str] = []
try:
# Import here to avoid circular imports
from .cardinality_validator import CardinalityValidator
# Determine chart type for cardinality thresholds
chart_type = config.kind if hasattr(config, "kind") else "default"
# Check X-axis cardinality
if config.x is None:
return warnings, suggestions
is_ok, cardinality_info = CardinalityValidator.check_cardinality(
dataset_id=dataset_id,
x_column=config.x.name,
chart_type=chart_type,
group_by_column=config.group_by[0].name if config.group_by else None,
)
if not is_ok and cardinality_info:
warnings.extend(cardinality_info.get("warnings", []))
suggestions.extend(cardinality_info.get("suggestions", []))
except ImportError:
logger.warning("Cardinality validator not available")
except Exception as e:
logger.warning("Cardinality validation failed: %s", e)
return warnings, suggestions
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return []
plugin = get_registry().get(chart_type)
if plugin is None:
return []
return plugin.get_runtime_warnings(config, dataset_id)
except Exception as exc:
logger.warning("Plugin runtime validation failed: %s", exc)
return []
@staticmethod
def _validate_chart_type(

superset/mcp_service/chart/validation/schema_validator.py Normal file → Executable file

@@ -147,19 +147,13 @@ class SchemaValidator:
chart_type: str,
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Validate chart type and dispatch to type-specific pre-validation."""
chart_type_validators = {
"xy": SchemaValidator._pre_validate_xy_config,
"table": SchemaValidator._pre_validate_table_config,
"pie": SchemaValidator._pre_validate_pie_config,
"pivot_table": SchemaValidator._pre_validate_pivot_table_config,
"mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config,
"handlebars": SchemaValidator._pre_validate_handlebars_config,
"big_number": SchemaValidator._pre_validate_big_number_config,
}
"""Validate chart type and dispatch to plugin pre-validation."""
from superset.mcp_service.chart.registry import get_registry
if not isinstance(chart_type, str) or chart_type not in chart_type_validators:
valid_types = ", ".join(chart_type_validators.keys())
registry = get_registry()
if not isinstance(chart_type, str) or not registry.is_registered(chart_type):
valid_types = ", ".join(registry.all_types())
return False, ChartGenerationError(
error_type="invalid_chart_type",
message=f"Invalid chart_type: '{chart_type}'",
@@ -178,351 +172,33 @@ class SchemaValidator:
error_code="INVALID_CHART_TYPE",
)
return chart_type_validators[chart_type](config)
@staticmethod
def _pre_validate_xy_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate XY chart configuration."""
# x is optional — defaults to dataset's main_dttm_col in map_xy_config
if "y" not in config:
if not registry.is_enabled(chart_type):
valid_types = ", ".join(registry.all_types())
return False, ChartGenerationError(
error_type="missing_xy_fields",
message="XY chart missing required field: 'y' (Y-axis metrics)",
details="XY charts require Y-axis (metrics) specifications. "
"X-axis is optional and defaults to the dataset's primary "
"datetime column when omitted.",
error_type="disabled_chart_type",
message=f"Chart type '{chart_type}' is not enabled on this instance",
details=f"Chart type '{chart_type}' is registered but has been "
f"disabled by the operator. "
f"Enabled chart types: {valid_types}",
suggestions=[
"Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] "
"for Y-axis",
"Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
"'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
f"Use one of the enabled chart types: {valid_types}",
"Contact your administrator if you believe this is an error",
],
error_code="MISSING_XY_FIELDS",
error_code="DISABLED_CHART_TYPE",
)
# Validate Y is a list
if not isinstance(config.get("y", []), list):
plugin = registry.get(chart_type)
if plugin is None:
return False, ChartGenerationError(
error_type="invalid_y_format",
message="Y-axis must be a list of metrics",
details="The 'y' field must be an array of metric specifications",
suggestions=[
"Wrap Y-axis metric in array: 'y': [{'name': 'column', "
"'aggregate': 'SUM'}]",
"Multiple metrics supported: 'y': [metric1, metric2, ...]",
],
error_code="INVALID_Y_FORMAT",
error_type="invalid_chart_type",
message=f"Chart type '{chart_type}' has no registered plugin",
details="Internal error: chart type is listed but has no plugin",
suggestions=["Use a supported chart_type"],
error_code="INVALID_CHART_TYPE",
)
return True, None
@staticmethod
def _pre_validate_table_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate table chart configuration."""
if "columns" not in config:
return False, ChartGenerationError(
error_type="missing_columns",
message="Table chart missing required field: columns",
details="Table charts require a 'columns' array to specify which "
"columns to display",
suggestions=[
"Add 'columns' field with array of column specifications",
"Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
"'aggregate': 'SUM'}]",
"Each column can have optional 'aggregate' for metrics",
],
error_code="MISSING_COLUMNS",
)
if not isinstance(config.get("columns", []), list):
return False, ChartGenerationError(
error_type="invalid_columns_format",
message="Columns must be a list",
details="The 'columns' field must be an array of column specifications",
suggestions=[
"Ensure columns is an array: 'columns': [...]",
"Each column should be an object with 'name' field",
],
error_code="INVALID_COLUMNS_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_pie_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pie chart configuration."""
missing_fields = []
if "dimension" not in config:
missing_fields.append("'dimension' (category column for slices)")
if "metric" not in config:
missing_fields.append("'metric' (value metric for slice sizes)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pie_fields",
message=f"Pie chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Pie charts require a dimension (categories) and a metric "
"(values)",
suggestions=[
"Add 'dimension' field: {'name': 'category_column'}",
"Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
"Example: {'chart_type': 'pie', 'dimension': {'name': "
"'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
],
error_code="MISSING_PIE_FIELDS",
)
return True, None
@staticmethod
def _pre_validate_handlebars_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate handlebars chart configuration."""
if "handlebars_template" not in config:
return False, ChartGenerationError(
error_type="missing_handlebars_template",
message="Handlebars chart missing required field: handlebars_template",
details="Handlebars charts require a 'handlebars_template' string "
"containing Handlebars HTML template markup",
suggestions=[
"Add 'handlebars_template' with a Handlebars HTML template",
"Data is available as {{data}} array in the template",
"Example: '<ul>{{#each data}}<li>{{this.name}}: "
"{{this.value}}</li>{{/each}}</ul>'",
],
error_code="MISSING_HANDLEBARS_TEMPLATE",
)
template = config.get("handlebars_template")
if not isinstance(template, str) or not template.strip():
return False, ChartGenerationError(
error_type="invalid_handlebars_template",
message="Handlebars template must be a non-empty string",
details="The 'handlebars_template' field must be a non-empty string "
"containing valid Handlebars HTML template markup",
suggestions=[
"Ensure handlebars_template is a non-empty string",
"Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
],
error_code="INVALID_HANDLEBARS_TEMPLATE",
)
query_mode = config.get("query_mode", "aggregate")
if query_mode not in ("aggregate", "raw"):
return False, ChartGenerationError(
error_type="invalid_query_mode",
message="Invalid query_mode for handlebars chart",
details="query_mode must be either 'aggregate' or 'raw'",
suggestions=[
"Use 'aggregate' for aggregated data (default)",
"Use 'raw' for individual rows",
],
error_code="INVALID_QUERY_MODE",
)
if query_mode == "raw" and not config.get("columns"):
return False, ChartGenerationError(
error_type="missing_raw_columns",
message="Handlebars chart in 'raw' mode requires 'columns'",
details="When query_mode is 'raw', you must specify which columns "
"to include in the query results",
suggestions=[
"Add 'columns': [{'name': 'column_name'}] for raw mode",
"Or use query_mode='aggregate' with 'metrics' "
"and optional 'groupby'",
],
error_code="MISSING_RAW_COLUMNS",
)
if query_mode == "aggregate" and not config.get("metrics"):
return False, ChartGenerationError(
error_type="missing_aggregate_metrics",
message="Handlebars chart in 'aggregate' mode requires 'metrics'",
details="When query_mode is 'aggregate' (default), you must specify "
"at least one metric with an aggregate function",
suggestions=[
"Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
"Or use query_mode='raw' with 'columns' for individual rows",
],
error_code="MISSING_AGGREGATE_METRICS",
)
return True, None
@staticmethod
def _pre_validate_big_number_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate big number chart configuration."""
if "metric" not in config:
return False, ChartGenerationError(
error_type="missing_metric",
message="Big Number chart missing required field: metric",
details="Big Number charts require a 'metric' field "
"specifying the value to display",
suggestions=[
"Add 'metric' with name and aggregate: "
"{'name': 'revenue', 'aggregate': 'SUM'}",
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
"Example: {'chart_type': 'big_number', "
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
],
error_code="MISSING_BIG_NUMBER_METRIC",
)
metric = config.get("metric", {})
if not isinstance(metric, dict):
return False, ChartGenerationError(
error_type="invalid_metric_type",
message="Big Number metric must be a dict with 'name' and 'aggregate'",
details="The 'metric' field must be an object, "
f"got {type(metric).__name__}",
suggestions=[
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
)
if not metric.get("aggregate") and not metric.get("saved_metric"):
return False, ChartGenerationError(
error_type="missing_metric_aggregate",
message="Big Number metric must include an aggregate function "
"or reference a saved metric",
details="The metric must have an 'aggregate' field "
"or 'saved_metric': true",
suggestions=[
"Add 'aggregate' to your metric: "
"{'name': 'col', 'aggregate': 'SUM'}",
"Or use a saved metric: "
"{'name': 'total_sales', 'saved_metric': true}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="MISSING_BIG_NUMBER_AGGREGATE",
)
show_trendline = config.get("show_trendline", False)
temporal_column = config.get("temporal_column")
if show_trendline and not temporal_column:
return False, ChartGenerationError(
error_type="missing_temporal_column",
message="Trendline requires a temporal column",
details="When 'show_trendline' is True, a "
"'temporal_column' must be specified",
suggestions=[
"Add 'temporal_column': 'date_column_name'",
"Or set 'show_trendline': false for number only",
"Use get_dataset_info to find temporal columns",
],
error_code="MISSING_TEMPORAL_COLUMN",
)
return True, None
@staticmethod
def _pre_validate_pivot_table_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pivot table configuration."""
missing_fields = []
if "rows" not in config:
missing_fields.append("'rows' (row grouping columns)")
if "metrics" not in config:
missing_fields.append("'metrics' (aggregation metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pivot_fields",
message=f"Pivot table missing required "
f"fields: {', '.join(missing_fields)}",
details="Pivot tables require row groupings and metrics",
suggestions=[
"Add 'rows' field: [{'name': 'category'}]",
"Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
"Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
],
error_code="MISSING_PIVOT_FIELDS",
)
if not isinstance(config.get("rows", []), list):
return False, ChartGenerationError(
error_type="invalid_rows_format",
message="Rows must be a list of columns",
details="The 'rows' field must be an array of column specifications",
suggestions=[
"Wrap row columns in array: 'rows': [{'name': 'category'}]",
],
error_code="INVALID_ROWS_FORMAT",
)
if not isinstance(config.get("metrics", []), list):
return False, ChartGenerationError(
error_type="invalid_metrics_format",
message="Metrics must be a list",
details="The 'metrics' field must be an array of metric specifications",
suggestions=[
"Wrap metrics in array: 'metrics': [{'name': 'sales', "
"'aggregate': 'SUM'}]",
],
error_code="INVALID_METRICS_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_mixed_timeseries_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate mixed timeseries configuration."""
missing_fields = []
if "x" not in config:
missing_fields.append("'x' (X-axis temporal column)")
if "y" not in config:
missing_fields.append("'y' (primary Y-axis metrics)")
if "y_secondary" not in config:
missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_mixed_timeseries_fields",
message=f"Mixed timeseries chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Mixed timeseries charts require an x-axis, primary metrics, "
"and secondary metrics",
suggestions=[
"Add 'x' field: {'name': 'date_column'}",
"Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
"Add 'y_secondary' field: [{'name': 'orders', "
"'aggregate': 'COUNT'}]",
"Optional: 'primary_kind' and 'secondary_kind' for chart types",
],
error_code="MISSING_MIXED_TIMESERIES_FIELDS",
)
for field_name in ["y", "y_secondary"]:
if not isinstance(config.get(field_name, []), list):
return False, ChartGenerationError(
error_type=f"invalid_{field_name}_format",
message=f"'{field_name}' must be a list of metrics",
details=f"The '{field_name}' field must be an array of metric "
"specifications",
suggestions=[
f"Wrap in array: '{field_name}': "
"[{'name': 'col', 'aggregate': 'SUM'}]",
],
error_code=f"INVALID_{field_name.upper()}_FORMAT",
)
if (error := plugin.pre_validate(config)) is not None:
return False, error
return True, None
@staticmethod
@@ -537,89 +213,26 @@ class SchemaValidator:
if err.get("type") == "union_tag_invalid" or "discriminator" in str(
err.get("ctx", {})
):
# This is the generic union error - provide better message
config = request_data.get("config", {})
chart_type = config.get("chart_type", "unknown")
from superset.mcp_service.chart.registry import get_registry
if chart_type == "xy":
return ChartGenerationError(
error_type="xy_validation_error",
message="XY chart configuration validation failed",
details="The XY chart configuration is missing required "
"fields or has invalid structure",
suggestions=[
"Ensure 'x' field exists with {'name': 'column_name'}",
"Ensure 'y' field is an array: [{'name': 'metric', "
"'aggregate': 'SUM'}]",
"Check that all column names are strings",
"Verify aggregate functions are valid: SUM, COUNT, AVG, "
"MIN, MAX",
],
error_code="XY_VALIDATION_ERROR",
)
elif chart_type == "table":
return ChartGenerationError(
error_type="table_validation_error",
message="Table chart configuration validation failed",
details="The table chart configuration is missing required "
"fields or has invalid structure",
suggestions=[
"Ensure 'columns' field is an array of column "
"specifications",
"Each column needs {'name': 'column_name'}",
"Optional: add 'aggregate' for metrics",
"Example: 'columns': [{'name': 'product'}, {'name': "
"'sales', 'aggregate': 'SUM'}]",
],
error_code="TABLE_VALIDATION_ERROR",
)
elif chart_type == "handlebars":
return ChartGenerationError(
error_type="handlebars_validation_error",
message="Handlebars chart configuration validation failed",
details="The handlebars chart configuration is missing "
"required fields or has invalid structure",
suggestions=[
"Ensure 'handlebars_template' is a non-empty string",
"For aggregate mode: add 'metrics' with aggregate "
"functions",
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
"Example: {'chart_type': 'handlebars', "
"'handlebars_template': '<ul>{{#each data}}<li>"
"{{this.name}}</li>{{/each}}</ul>', "
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
],
error_code="HANDLEBARS_VALIDATION_ERROR",
)
elif chart_type == "big_number":
return ChartGenerationError(
error_type="big_number_validation_error",
message="Big Number chart configuration validation failed",
details="The Big Number chart configuration is "
"missing required fields or has invalid "
"structure",
suggestions=[
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: 'metric': {'name': 'revenue', "
"'aggregate': 'SUM'}",
"For trendline: add 'show_trendline': true "
"and 'temporal_column': 'date_col'",
"Without trendline: just provide the metric",
],
error_code="BIG_NUMBER_VALIDATION_ERROR",
)
chart_type = request_data.get("config", {}).get("chart_type", "")
plugin = get_registry().get(chart_type)
if plugin is not None:
hint = plugin.schema_error_hint()
if hint is not None:
return hint
# Default enhanced error
error_details = []
for err in errors[:3]: # Show first 3 errors
loc = " -> ".join(str(location) for location in err.get("loc", []))
msg = err.get("msg", "Validation failed")
error_details.append(f"{loc}: {msg}")
error_details.append(f"{loc}: {msg}" if loc else msg)
return ChartGenerationError(
error_type="validation_error",
message="Chart configuration validation failed",
details="; ".join(error_details) or "Invalid chart configuration structure",
suggestions=[
"Check that all required fields are present",
"Ensure field types match the schema",

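The unknown-vs-disabled distinction described in the commit message can be sketched with a minimal pre-check; the registry functions and error codes below are simplified stand-ins for the real module, not the Superset implementation:

```python
from typing import Optional

# Stand-in registry state: one enabled plugin, one on the deny-list.
_REGISTRY = {"table": object(), "handlebars": object()}
_DISABLED = frozenset({"handlebars"})

def is_registered(chart_type: str) -> bool:
    # Intentionally ignores the filter so callers can tell "unknown" apart
    # from "disabled".
    return chart_type in _REGISTRY

def is_enabled(chart_type: str) -> bool:
    return is_registered(chart_type) and chart_type not in _DISABLED

def pre_check(chart_type: str) -> Optional[str]:
    """Return a distinct error code, or None when validation may proceed."""
    if not is_registered(chart_type):
        return "UNKNOWN_CHART_TYPE"
    if not is_enabled(chart_type):
        return "DISABLED_CHART_TYPE"
    return None

print(pre_check("nope"))        # UNKNOWN_CHART_TYPE
print(pre_check("handlebars"))  # DISABLED_CHART_TYPE
print(pre_check("table"))       # None
```

The two-step order matters: checking `is_enabled()` alone would report a typo'd chart type as "disabled".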
View File

@@ -81,6 +81,17 @@ try:
mcp_config = get_mcp_config(_mcp_app.config)
_mcp_app.config.update(mcp_config)
# Re-configure chart registry so MCP-specific overrides (e.g.
# MCP_DISABLED_CHART_PLUGINS set by the operator) take effect after
# the MCP config overlay. SupersetAppInitializer.configure_mcp_chart_registry()
# ran earlier with pre-overlay values; this call corrects them.
from superset.mcp_service.chart import registry as _chart_registry
_chart_registry.configure(
disabled=_mcp_app.config.get("MCP_DISABLED_CHART_PLUGINS"),
enabled_func=_mcp_app.config.get("MCP_CHART_PLUGIN_ENABLED_FUNC"),
)
with _mcp_app.app_context():
from superset.core.mcp.core_mcp_injection import (
initialize_core_mcp_dependencies,

View File

@@ -18,6 +18,7 @@
import logging
import secrets
from collections.abc import Callable
from typing import Any, Dict, Optional
from flask import Flask
@@ -56,6 +57,46 @@ MCP_DEBUG = False
# against the FAB security_manager before execution.
MCP_RBAC_ENABLED = True
# =============================================================================
# MCP Chart Plugin Filtering
# =============================================================================
#
# Overview:
# ---------
# These two settings let operators enable/disable individual chart type plugins
# at runtime without a code deploy.
#
# Use cases:
# - Emergency kill switch: add "handlebars" to MCP_DISABLED_CHART_PLUGINS and
# restart to immediately hide it from all callers.
# - Dynamic per-request control (A/B test, gradual rollout): set
# MCP_CHART_PLUGIN_ENABLED_FUNC to an in-process predicate that can vary
# by user, request header, or any other context available at call time.
#
# Priority:
# MCP_CHART_PLUGIN_ENABLED_FUNC takes precedence over MCP_DISABLED_CHART_PLUGINS.
# When the callable is set, the deny-list is ignored entirely.
#
# MCP_CHART_PLUGIN_ENABLED_FUNC contract:
# - Called as enabled_func(chart_type: str) -> bool for every registry lookup.
# - Must be cheap and in-process: consult already-loaded feature flags or
# request-local context (e.g. Flask g). Do NOT perform network I/O per call.
# - On exception, the registry fails CLOSED (plugin hidden) and logs a warning.
# - Example (Harness / Split via pre-fetched flags in g):
# from flask import g
# def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
# flags = getattr(g, "feature_flags", {})
# return flags.get(f"mcp_chart_{chart_type}", True)
# =============================================================================
# Chart types in this set are hidden from all registry lookups.
# Use frozenset to avoid accidental mutation.
MCP_DISABLED_CHART_PLUGINS: frozenset[str] = frozenset()
# Dynamic per-call predicate. When set, overrides MCP_DISABLED_CHART_PLUGINS.
# Signature: (chart_type: str) -> bool
MCP_CHART_PLUGIN_ENABLED_FUNC: Callable[[str], bool] | None = None
# MCP JWT Debug Errors - controls server-side JWT debug logging.
# When False (default), uses the default JWTVerifier with minimal logging.
# When True, uses DetailedJWTVerifier with tiered logging:
@@ -402,6 +443,8 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
"MCP_SERVICE_PORT": MCP_SERVICE_PORT,
"MCP_DEBUG": MCP_DEBUG,
"MCP_RBAC_ENABLED": MCP_RBAC_ENABLED,
"MCP_DISABLED_CHART_PLUGINS": MCP_DISABLED_CHART_PLUGINS,
"MCP_CHART_PLUGIN_ENABLED_FUNC": MCP_CHART_PLUGIN_ENABLED_FUNC,
**MCP_SESSION_CONFIG,
**MCP_CSRF_CONFIG,
}
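As a usage sketch of these two settings together, an operator's `superset_config.py` could deny-list a plugin globally while a predicate re-enables it per request. Everything here is hypothetical illustration: `_request_flags` stands in for per-request context such as pre-fetched flags on `flask.g`.

```python
# Hypothetical superset_config.py snippet. The deny-list hides handlebars for
# everyone; the predicate (which takes precedence when set) consults
# per-request flags and falls back to the deny-list default.
MCP_DISABLED_CHART_PLUGINS = frozenset({"handlebars"})

# Stand-in for request-local context (e.g. flags pre-fetched into flask.g).
_request_flags = {"mcp_chart_handlebars": False}

def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
    # Must stay cheap and in-process -- no network I/O per call.
    default = chart_type not in MCP_DISABLED_CHART_PLUGINS
    return _request_flags.get(f"mcp_chart_{chart_type}", default)

print(MCP_CHART_PLUGIN_ENABLED_FUNC("handlebars"))  # False
print(MCP_CHART_PLUGIN_ENABLED_FUNC("table"))       # True
```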

View File

@@ -41,6 +41,12 @@ from superset.mcp_service.constants import (
DEFAULT_TOKEN_LIMIT,
DEFAULT_WARN_THRESHOLD_PCT,
)
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
format_size_limit_error,
INFO_TOOLS,
truncate_oversized_response,
)
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
@@ -1104,11 +1110,6 @@ class ResponseSizeGuardMiddleware(Middleware):
``content[0].text`` as a JSON string. We parse that string, run the
truncation phases on the resulting dict, then re-wrap the result.
"""
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
truncate_oversized_response,
)
# Unwrap ToolResult so truncation operates on the real payload
extracted = self._extract_payload_from_tool_result(response)
if extracted is not None:
@@ -1191,12 +1192,6 @@ class ResponseSizeGuardMiddleware(Middleware):
# Execute the tool
response = await call_next(context)
# Estimate response token count (guard against huge responses causing OOM)
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
format_size_limit_error,
)
# When the response is a ToolResult, estimate tokens on the actual
# payload inside content[0].text rather than on the ToolResult
# wrapper (which would double-serialize the JSON string).
@@ -1233,8 +1228,6 @@ class ResponseSizeGuardMiddleware(Middleware):
params = getattr(context.message, "params", {}) or {}
# For info tools, try dynamic truncation before blocking
from superset.mcp_service.utils.token_utils import INFO_TOOLS
if tool_name in INFO_TOOLS:
truncated = self._try_truncate_info_response(
tool_name, response, estimated_tokens

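The double-serialization concern the middleware comments describe can be illustrated with plain `json`. The wrapper shape (`content[0].text` holding a JSON string) follows the docstring above, but the dict here is a stand-in for a real `ToolResult`:

```python
import json

# Stand-in for a ToolResult wrapper whose content[0].text is a JSON string.
payload = {"rows": [{"name": "alpha"}, {"name": "beta"}], "count": 2}
wrapper = {"content": [{"type": "text", "text": json.dumps(payload)}]}

# Estimating tokens on the wrapper re-serializes the inner JSON string,
# escaping every quote, so the wrapper is strictly longer than its payload.
payload_len = len(json.dumps(payload))
wrapper_len = len(json.dumps(wrapper))
assert wrapper_len > payload_len

# Unwrap before estimating/truncating, then re-wrap the resulting dict.
inner = json.loads(wrapper["content"][0]["text"])
assert inner == payload
wrapper["content"][0]["text"] = json.dumps(inner)
```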
View File

@@ -21,6 +21,26 @@ Token counting and response size utilities for MCP service.
This module provides utilities to estimate token counts and generate smart
suggestions when responses exceed configured limits. This prevents large
responses from overwhelming LLM clients like Claude Desktop.
Token counting strategy:
1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is
installed (it is shipped as part of the ``fastmcp`` extra). This is a
real BPE tokenizer trained on a similar vocabulary to Claude's; for
English and JSON-heavy MCP payloads it tracks Claude's tokenizer
within roughly ±10%, which is far more accurate than the legacy
character heuristic.
2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not
importable. The fallback uses a slightly more conservative ratio than
before (3.0 chars/token instead of 3.5) so that JSON-heavy responses
are not under-counted, which previously let oversized payloads slip
past the response-size guard.
The exact-Claude tokenizer is only available via Anthropic's network
``count_tokens`` API; calling it from a synchronous middleware on every
tool result is too slow and adds an external dependency on every
response. ``tiktoken`` is the closest approximation we can ship without
that risk.
"""
from __future__ import annotations
@@ -36,18 +56,63 @@ logger = logging.getLogger(__name__)
# Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]
# Fallback character-to-token ratio used when tiktoken is unavailable.
# 3.0 is conservative for JSON content (the previous 3.5 under-counted
# JSON-heavy payloads relative to Claude's actual tokenizer, which let
# oversized responses slip past the response-size guard).
CHARS_PER_TOKEN = 3.0
# Encoding used when tiktoken is available. cl100k_base is OpenAI's
# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to
# Claude's and tracks Claude's token counts within roughly ±10% for
# English and JSON-heavy MCP responses.
_TIKTOKEN_ENCODING_NAME = "cl100k_base"
def _load_tiktoken_encoding() -> Any:
"""Return a tiktoken encoding instance, or None if tiktoken is unavailable.
Imported lazily so the module can be used in environments without
tiktoken installed. The encoding is small (~1 MB) so we cache it on
first use.
"""
try:
import tiktoken
except ImportError:
logger.info(
"tiktoken not installed; falling back to char-based token "
"estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra "
"for accurate counts.",
CHARS_PER_TOKEN,
)
return None
try:
return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME)
except (KeyError, ValueError) as exc:
# tiktoken installed but the requested encoding is missing — this
# only happens on partial installs. Treat as no tokenizer rather
# than crashing on every tool call.
logger.warning(
"tiktoken encoding '%s' unavailable: %s; falling back to "
"char-based token estimation",
_TIKTOKEN_ENCODING_NAME,
exc,
)
return None
# Cached encoding instance (None if tiktoken not importable).
_ENCODING = _load_tiktoken_encoding()
def estimate_token_count(text: str | bytes) -> int:
"""
Estimate the token count for a given text.
Uses tiktoken's ``cl100k_base`` encoding when available for
Claude-aligned accuracy (within ~10%), falling back to a
character-based heuristic otherwise.
Args:
text: The text to estimate tokens for (string or bytes)
@@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int:
if isinstance(text, bytes):
text = text.decode("utf-8", errors="replace")
if not text:
return 0
if _ENCODING is not None:
try:
return len(_ENCODING.encode(text))
except (ValueError, UnicodeError) as exc:
# Defensive: if tiktoken chokes on a specific input, fall
# back to the char heuristic for this call rather than
# raising — the response size guard must never fail-open.
logger.warning("tiktoken encode failed (%s); using fallback", exc)
return max(1, int(len(text) / CHARS_PER_TOKEN))
def estimate_response_tokens(response: ToolResponse) -> int:

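When tiktoken is unavailable, the fallback path is plain arithmetic; a minimal sketch of just that path, with the ratio copied from `CHARS_PER_TOKEN` above:

```python
# Fallback only: ~3.0 characters per token, floored, with a minimum of 1
# so non-empty text never estimates to zero tokens.
CHARS_PER_TOKEN = 3.0

def estimate_tokens_fallback(text: str) -> int:
    if not text:
        return 0
    return max(1, int(len(text) / CHARS_PER_TOKEN))

print(estimate_tokens_fallback(""))         # 0
print(estimate_tokens_fallback("ab"))       # 1  (clamped up from 0)
print(estimate_tokens_fallback("x" * 300))  # 100
```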
View File

@@ -19,9 +19,7 @@
OpenSearch SQL dialect.
OpenSearch SQL is syntactically close to MySQL but accepts both backticks and
double-quotes as identifier delimiters.
"""
from __future__ import annotations
@@ -31,4 +29,4 @@ from sqlglot.dialects.mysql import MySQL
class OpenSearch(MySQL):
class Tokenizer(MySQL.Tokenizer):
IDENTIFIERS = ["`", '"']

View File

@@ -45,6 +45,13 @@
color: #000;
}
{% endif %}
{% if standalone_mode %}
/* Keep body sized so screenshot waits don't see it as hidden before React mounts. */
html, body.standalone {
min-height: 100vh;
margin: 0;
}
{% endif %}
</style>
{% if dark_theme_bg and entry != 'embedded' %}

View File

@@ -68,3 +68,50 @@ def test_spa_template_includes_css_bundles():
"spa.html must call css_bundle for the page entry to load "
"entry-specific extracted CSS in production builds"
)
def test_spa_template_standalone_body_has_min_height():
"""Standalone body must be measurable so screenshot waits don't time out."""
from jinja2 import DictLoader, Environment
template_path = join(SUPERSET_DIR, "templates", "superset", "spa.html")
with open(template_path) as f:
template_content = f.read()
env = Environment( # noqa: S701
loader=DictLoader(
{
"spa.html": template_content,
# Stub out includes/imports that are not relevant for this test.
"appbuilder/general/lib.html": "",
"superset/partials/asset_bundle.html": (
"{% macro css_bundle(prefix, entry) %}{% endmacro %}"
"{% macro js_bundle(prefix, entry) %}{% endmacro %}"
),
"superset/macros.html": ("{% macro get_nonce() %}{% endmacro %}"),
"tail_js_custom_extra.html": "",
"head_custom_extra.html": "",
}
)
)
appbuilder = Mock()
appbuilder.app.config = {"FAVICONS": []}
def render(standalone_mode: bool) -> str:
return env.get_template("spa.html").render(
appbuilder=appbuilder,
assets_prefix="",
bootstrap_data="{}",
entry="spa",
standalone_mode=standalone_mode,
theme_tokens={},
spinner_svg=None,
)
standalone_html = render(standalone_mode=True)
assert "body.standalone" in standalone_html
assert "min-height: 100vh" in standalone_html
non_standalone_html = render(standalone_mode=False)
assert "body.standalone" not in non_standalone_html

View File

@@ -90,7 +90,7 @@ class TestBigNumberChartConfig:
"chart_type": "big_number",
"metric": {"name": "total_sales", "saved_metric": True},
}
is_valid, error = SchemaValidator._pre_validate_chart_type("big_number", data)
assert is_valid is True
assert error is None

View File

@@ -0,0 +1,143 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the chart type plugin registry."""
import pytest
import superset.mcp_service.chart.registry as registry_module
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import (
_RegistryProxy,
all_types,
display_name_for_viz_type,
get,
get_registry,
is_registered,
register,
)
@pytest.fixture(autouse=True)
def _isolated_registry(monkeypatch):
"""Run each test against a clean registry without touching the real one."""
monkeypatch.setattr(registry_module, "_REGISTRY", {})
monkeypatch.setattr(registry_module, "_plugins_loaded", True)
class _FakePlugin(BaseChartPlugin):
chart_type = "fake"
display_name = "Fake Chart"
native_viz_types = {"fake_viz": "Fake Viz"}
class _AnotherPlugin(BaseChartPlugin):
chart_type = "another"
display_name = "Another Chart"
native_viz_types = {"another_viz": "Another Viz"}
def test_register_adds_plugin():
plugin = _FakePlugin()
register(plugin)
assert get("fake") is plugin
def test_get_returns_none_for_unknown():
assert get("nonexistent") is None
def test_all_types_returns_registered_keys():
register(_FakePlugin())
register(_AnotherPlugin())
types = all_types()
assert "fake" in types
assert "another" in types
def test_all_types_insertion_order():
register(_FakePlugin())
register(_AnotherPlugin())
types = all_types()
assert types.index("fake") < types.index("another")
def test_is_registered_true_for_known():
register(_FakePlugin())
assert is_registered("fake") is True
def test_is_registered_false_for_unknown():
assert is_registered("nonexistent") is False
def test_register_warns_on_duplicate(caplog):
register(_FakePlugin())
with caplog.at_level("WARNING"):
register(_FakePlugin())
assert "Overwriting" in caplog.text
def test_register_raises_for_empty_chart_type():
class _BadPlugin(BaseChartPlugin):
chart_type = ""
with pytest.raises(ValueError, match="non-empty chart_type"):
register(_BadPlugin())
def test_display_name_for_viz_type_found():
register(_FakePlugin())
assert display_name_for_viz_type("fake_viz") == "Fake Viz"
def test_display_name_for_viz_type_not_found():
register(_FakePlugin())
assert display_name_for_viz_type("unknown_viz") is None
def test_display_name_searches_all_plugins():
register(_FakePlugin())
register(_AnotherPlugin())
assert display_name_for_viz_type("another_viz") == "Another Viz"
def test_get_registry_returns_proxy():
assert isinstance(get_registry(), _RegistryProxy)
def test_registry_proxy_get():
plugin = _FakePlugin()
register(plugin)
assert get_registry().get("fake") is plugin
def test_registry_proxy_all_types():
register(_FakePlugin())
assert "fake" in get_registry().all_types()
def test_registry_proxy_is_registered():
register(_FakePlugin())
assert get_registry().is_registered("fake") is True
assert get_registry().is_registered("missing") is False
def test_registry_proxy_display_name_for_viz_type():
register(_FakePlugin())
assert get_registry().display_name_for_viz_type("fake_viz") == "Fake Viz"
assert get_registry().display_name_for_viz_type("unknown") is None

View File

@@ -0,0 +1,222 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for registry plugin filtering (configure / is_enabled / get / all_types)."""
import pytest
import superset.mcp_service.chart.registry as registry_module
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import (
_PluginFilterConfig,
all_types,
configure,
display_name_for_viz_type,
get,
is_enabled,
is_registered,
register,
)
@pytest.fixture(autouse=True)
def _isolated_registry(monkeypatch):
"""Isolated registry with two known plugins and a clean filter for each test."""
monkeypatch.setattr(registry_module, "_REGISTRY", {})
monkeypatch.setattr(registry_module, "_plugins_loaded", True)
monkeypatch.setattr(registry_module, "_filter_config", _PluginFilterConfig())
register(_AlphaPlugin())
register(_BetaPlugin())
class _AlphaPlugin(BaseChartPlugin):
chart_type = "alpha"
display_name = "Alpha Chart"
native_viz_types = {"alpha_viz": "Alpha Viz"}
class _BetaPlugin(BaseChartPlugin):
chart_type = "beta"
display_name = "Beta Chart"
native_viz_types = {"beta_viz": "Beta Viz"}
# ---------------------------------------------------------------------------
# Static deny-list tests
# ---------------------------------------------------------------------------
def test_get_returns_plugin_when_enabled():
assert get("alpha") is not None
assert get("beta") is not None
def test_get_returns_none_for_disabled_plugin():
configure(disabled={"alpha"})
assert get("alpha") is None
def test_get_still_returns_other_plugins_when_one_is_disabled():
configure(disabled={"alpha"})
assert get("beta") is not None
def test_all_types_excludes_disabled():
configure(disabled={"alpha"})
types = all_types()
assert "alpha" not in types
assert "beta" in types
def test_all_types_empty_when_all_disabled():
configure(disabled={"alpha", "beta"})
assert all_types() == []
def test_is_registered_ignores_deny_list():
configure(disabled={"alpha"})
assert is_registered("alpha") is True
def test_is_enabled_returns_false_for_disabled():
configure(disabled={"alpha"})
assert is_enabled("alpha") is False
def test_is_enabled_returns_true_when_not_disabled():
configure(disabled={"alpha"})
assert is_enabled("beta") is True
def test_is_enabled_returns_false_for_unknown():
assert is_enabled("nonexistent") is False
# ---------------------------------------------------------------------------
# configure() accepts different iterable shapes
# ---------------------------------------------------------------------------
def test_configure_accepts_list():
configure(disabled=["alpha"])
assert get("alpha") is None
def test_configure_accepts_tuple():
configure(disabled=("alpha",))
assert get("alpha") is None
def test_configure_accepts_frozenset():
configure(disabled=frozenset({"alpha"}))
assert get("alpha") is None
def test_configure_accepts_none_disabled():
configure(disabled=None)
assert get("alpha") is not None
def test_configure_rejects_noncallable_enabled_func():
with pytest.raises(TypeError):
configure(enabled_func="not_a_callable")
# ---------------------------------------------------------------------------
# Dynamic callable hook tests
# ---------------------------------------------------------------------------
def test_enabled_func_overrides_deny_list():
# alpha is in deny-list but callable says True → should be visible
configure(disabled={"alpha"}, enabled_func=lambda ct: ct == "alpha")
assert get("alpha") is not None
def test_enabled_func_can_disable_plugin():
configure(enabled_func=lambda ct: ct != "beta")
assert get("beta") is None
assert get("alpha") is not None
def test_enabled_func_called_per_lookup():
calls = []
def hook(ct: str) -> bool:
calls.append(ct)
return True
configure(enabled_func=hook)
get("alpha")
get("alpha")
assert calls.count("alpha") == 2
def test_enabled_func_exception_fails_closed(caplog):
import logging
def bad_hook(ct: str) -> bool:
raise RuntimeError("Harness down")
configure(enabled_func=bad_hook)
with caplog.at_level(logging.WARNING, logger="superset.mcp_service.chart.registry"):
result = get("alpha")
assert result is None # fail closed
assert "failing closed" in caplog.text.lower() or "alpha" in caplog.text
def test_enabled_func_all_types_respects_hook():
configure(enabled_func=lambda ct: ct == "alpha")
assert all_types() == ["alpha"]
# ---------------------------------------------------------------------------
# display_name_for_viz_type is NOT filtered
# ---------------------------------------------------------------------------
def test_display_name_unaffected_by_deny_list():
configure(disabled={"alpha"})
# Even though alpha is disabled, its viz_type should still resolve
assert display_name_for_viz_type("alpha_viz") == "Alpha Viz"
def test_display_name_unaffected_by_callable():
configure(enabled_func=lambda ct: False)
assert display_name_for_viz_type("beta_viz") == "Beta Viz"
# ---------------------------------------------------------------------------
# configure() atomicity: replacing config is visible to next lookup
# ---------------------------------------------------------------------------
def test_reconfigure_replaces_previous_filter():
configure(disabled={"alpha"})
assert get("alpha") is None
configure(disabled=set())
assert get("alpha") is not None
def test_reconfigure_with_func_then_none_falls_back_to_deny_list():
configure(enabled_func=lambda ct: False)
assert get("alpha") is None
configure(disabled={"beta"}, enabled_func=None)
assert get("alpha") is not None
assert get("beta") is None

View File

@@ -595,3 +595,191 @@ Market Share
"""
# These demonstrate the expected ASCII formats for different chart types
class TestDetachedInstanceError:
"""Tests that DetachedInstanceError is handled gracefully.
When the SQLAlchemy session commits mid-request, ORM objects expire and
become detached. Accessing lazy attributes on a detached Slice raises
DetachedInstanceError. The tool must:
1. Call db.session.refresh() immediately after loading the chart so all
column values are loaded upfront before any downstream operation.
2. Catch SQLAlchemyError (the base class) and return a ChartError
instead of propagating the exception.
"""
@pytest.mark.asyncio
async def test_session_refresh_called_after_chart_load(self):
"""db.session.refresh() is invoked right after find_chart_by_identifier."""
import importlib
from contextlib import nullcontext
from unittest.mock import MagicMock, patch
from superset.mcp_service.chart.schemas import URLPreview
from superset.utils import json
get_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_preview"
)
mock_chart = MagicMock()
mock_chart.id = 42
mock_chart.slice_name = "Sales Chart"
mock_chart.viz_type = "table"
mock_chart.datasource_id = 1
mock_chart.datasource_type = "table"
mock_chart.params = "{}"
refresh_calls: list[object] = []
def _fake_refresh(obj: object) -> None:
refresh_calls.append(obj)
url_preview = URLPreview(
preview_url="http://localhost/explore/?slice_id=42",
width=800,
height=600,
)
with (
patch.object(
get_chart_preview_module,
"find_chart_by_identifier",
return_value=mock_chart,
),
patch.object(
get_chart_preview_module.db,
"session",
**{"refresh.side_effect": _fake_refresh},
),
patch.object(
get_chart_preview_module,
"validate_chart_dataset",
return_value=MagicMock(is_valid=True, warnings=[]),
),
patch.object(
get_chart_preview_module.event_logger,
"log_context",
return_value=nullcontext(),
),
# Return a real URLPreview so Pydantic model validation succeeds
patch.object(
get_chart_preview_module.PreviewFormatGenerator,
"generate",
return_value=url_preview,
),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost",
),
):
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import GetChartPreviewRequest
with patch("superset.mcp_service.auth.get_user_from_request") as mu:
mu.return_value = MagicMock(id=1, username="admin")
with patch(
"superset.mcp_service.auth.check_tool_permission", return_value=True
):
async with Client(mcp) as client:
response = await client.call_tool(
"get_chart_preview",
{
"request": GetChartPreviewRequest(
identifier=42, format="url"
).model_dump()
},
)
data = json.loads(response.content[0].text)
# The tool should succeed — not return a ChartError
assert "error_type" not in data, (
f"Expected ChartPreview but got ChartError: {data.get('error')}"
)
assert data.get("chart_id") == 42
assert len(refresh_calls) == 1, (
"db.session.refresh() should be called once after loading the chart"
)
assert refresh_calls[0] is mock_chart
@pytest.mark.asyncio
async def test_detached_instance_error_returns_chart_error(self):
"""DetachedInstanceError during preview generation returns ChartError."""
import importlib
from contextlib import nullcontext
from unittest.mock import MagicMock, patch
from sqlalchemy.orm.exc import DetachedInstanceError
get_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_preview"
)
mock_chart = MagicMock()
mock_chart.id = 7
mock_chart.slice_name = "Broken Chart"
mock_chart.viz_type = "bar"
mock_chart.datasource_id = 3
mock_chart.datasource_type = "table"
mock_chart.params = "{}"
with (
patch.object(
get_chart_preview_module,
"find_chart_by_identifier",
return_value=mock_chart,
),
patch.object(
get_chart_preview_module.db,
"session",
**{"refresh.return_value": None},
),
patch.object(
get_chart_preview_module,
"validate_chart_dataset",
return_value=MagicMock(is_valid=True, warnings=[]),
),
patch.object(
get_chart_preview_module.event_logger,
"log_context",
return_value=nullcontext(),
),
# Simulate the session expiring inside the strategy
patch.object(
get_chart_preview_module.PreviewFormatGenerator,
"generate",
side_effect=DetachedInstanceError(),
),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost",
),
):
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import GetChartPreviewRequest
from superset.utils import json
with patch("superset.mcp_service.auth.get_user_from_request") as mu:
mu.return_value = MagicMock(id=1, username="admin")
with patch(
"superset.mcp_service.auth.check_tool_permission", return_value=True
):
async with Client(mcp) as client:
response = await client.call_tool(
"get_chart_preview",
{
"request": GetChartPreviewRequest(
identifier=7, format="ascii"
).model_dump()
},
)
data = json.loads(response.content[0].text)
assert data["error_type"] == "InternalError"
assert "session" in data["error"].lower() or "retry" in data["error"].lower()

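The refresh-then-catch pattern these two tests assert can be sketched with stand-in objects; no SQLAlchemy is needed here, since `DetachedInstanceError` and the session are illustrative stubs, not the Superset implementation:

```python
# Stubs: the real tool refreshes via db.session and catches SQLAlchemyError,
# the base class of DetachedInstanceError.
class SQLAlchemyError(Exception): ...
class DetachedInstanceError(SQLAlchemyError): ...

class FakeSession:
    def refresh(self, obj) -> None:
        obj.refreshed = True  # 1) load column values up front

class Chart:
    id = 7
    refreshed = False

def get_chart_preview(chart, session, generate):
    session.refresh(chart)  # refresh immediately after loading
    try:
        return {"chart_id": chart.id, "preview": generate(chart)}
    except SQLAlchemyError:
        # 2) convert to a structured error instead of propagating
        return {"error_type": "InternalError",
                "error": "Chart session expired; please retry"}

def boom(chart):
    raise DetachedInstanceError("lazy attribute access on detached instance")

ok = get_chart_preview(Chart(), FakeSession(), lambda c: "ascii-art")
err = get_chart_preview(Chart(), FakeSession(), boom)
```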
View File

@@ -1175,6 +1175,11 @@ class TestUpdateChartValidationGate:
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@pytest.mark.asyncio
async def test_preview_path_validation_failure_skips_cache(
self,
@@ -1238,6 +1243,11 @@ class TestUpdateChartValidationGate:
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@pytest.mark.asyncio
async def test_persist_path_validation_failure_skips_db_write(
self,

View File

@@ -117,83 +117,6 @@ class TestGetCanonicalColumnName:
assert result == "unknown_column"
class TestNormalizeXYConfig:
"""Test _normalize_xy_config static method."""
def test_normalize_x_axis_column(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that x-axis column name is normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "orderdate"},
"y": [{"name": "Sales", "aggregate": "SUM"}],
"kind": "line",
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["x"]["name"] == "OrderDate"
def test_normalize_y_axis_columns(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that y-axis column names are normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "OrderDate"},
"y": [
{"name": "sales", "aggregate": "SUM"},
{"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
],
"kind": "bar",
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["y"][0]["name"] == "Sales"
assert config_dict["y"][1]["name"] == "quantity_ordered"
def test_normalize_group_by_column(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that group_by column name is normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "OrderDate"},
"y": [{"name": "Sales", "aggregate": "SUM"}],
"kind": "line",
"group_by": [{"name": "productline"}],
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["group_by"][0]["name"] == "ProductLine"
class TestNormalizeTableConfig:
"""Test _normalize_table_config static method."""
def test_normalize_table_columns(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that table column names are normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "table",
"columns": [
{"name": "orderdate"},
{"name": "PRODUCTLINE"},
{"name": "sales", "aggregate": "SUM"},
],
}
DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
assert config_dict["columns"][0]["name"] == "OrderDate"
assert config_dict["columns"][1]["name"] == "ProductLine"
assert config_dict["columns"][2]["name"] == "Sales"
class TestNormalizeFilters:
"""Test _normalize_filters static method."""

View File

@@ -58,12 +58,12 @@ class TestRuntimeValidatorNonBlocking:
x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch
)
# Mock the plugin runtime dispatcher to return format warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_plugin_runtime"
) as mock_plugin:
mock_plugin.return_value = [
"Currency format '$,.2f' may not display dates correctly"
]
@@ -87,15 +87,14 @@ class TestRuntimeValidatorNonBlocking:
kind="bar",
)
# Mock the plugin runtime dispatcher to return cardinality warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_plugin_runtime"
) as mock_plugin:
mock_plugin.return_value = [
"High cardinality detected: 10000+ unique values"
]
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -148,26 +147,21 @@ class TestRuntimeValidatorNonBlocking:
x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id
)
# Mock all validators to return warnings
# Mock plugin runtime and chart type validators to return warnings
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality,
"_validate_plugin_runtime"
) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_type,
):
mock_format.return_value = ["Format mismatch warning"]
mock_cardinality.return_value = (
["High cardinality warning"],
["Cardinality suggestion"],
)
mock_plugin.return_value = [
"Format mismatch warning",
"High cardinality warning",
]
mock_type.return_value = (
["Chart type warning"],
["Chart type suggestion"],
@@ -197,13 +191,13 @@ class TestRuntimeValidatorNonBlocking:
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
"_validate_plugin_runtime"
) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.logger"
) as mock_logger,
):
mock_format.return_value = ["Test warning message"]
mock_plugin.return_value = ["Test warning message"]
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -217,7 +211,7 @@ class TestRuntimeValidatorNonBlocking:
assert "warnings" in warnings_metadata
def test_validate_table_chart_skips_xy_validations(self):
"""Test that table charts skip XY-specific validations."""
"""Test that table charts produce no XY-specific runtime warnings."""
config = TableChartConfig(
chart_type="table",
columns=[
@@ -226,28 +220,15 @@ class TestRuntimeValidatorNonBlocking:
],
)
# These should not be called for table charts
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_chart_type,
):
# Mock chart type validator to return no warnings
# Plugin runtime dispatches to TableChartPlugin which returns no warnings.
# Chart type suggester is also stubbed to return no warnings.
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_chart_type:
mock_chart_type.return_value = ([], [])
is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1)
# Format and cardinality validation should not be called for table charts
mock_format.assert_not_called()
mock_cardinality.assert_not_called()
assert is_valid is True
assert error is None
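The diff above replaces separate `_validate_format_compatibility` and `_validate_cardinality` mocks with a single `_validate_plugin_runtime` entry point that dispatches to per-chart-type plugins. A minimal sketch of such a dispatcher (the plugin classes and registry here are illustrative assumptions, not the actual implementation):

```python
from typing import Dict, List, Protocol


class ChartPlugin(Protocol):
    def runtime_warnings(self, config: dict) -> List[str]: ...


class XYChartPlugin:
    """Runs XY-specific runtime checks such as axis-format compatibility."""

    def runtime_warnings(self, config: dict) -> List[str]:
        warnings: List[str] = []
        fmt = config.get("x_axis_format", "")
        if fmt.startswith("$") and config.get("x_axis_type") == "time":
            warnings.append(
                f"Currency format {fmt!r} may not display dates correctly"
            )
        return warnings


class TableChartPlugin:
    """Table charts have no XY-specific runtime checks."""

    def runtime_warnings(self, config: dict) -> List[str]:
        return []


_PLUGINS: Dict[str, ChartPlugin] = {
    "line": XYChartPlugin(),
    "bar": XYChartPlugin(),
    "table": TableChartPlugin(),
}


def validate_plugin_runtime(config: dict) -> List[str]:
    """Dispatch to the plugin for the chart type; unknown types warn nothing."""
    plugin = _PLUGINS.get(config.get("chart_type", ""))
    return plugin.runtime_warnings(config) if plugin else []
```

This structure explains why the table-chart test above only needs to stub `_validate_chart_type`: dispatch to the table plugin naturally yields no XY warnings, so nothing XY-specific has to be patched out.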

View File

@@ -146,7 +146,13 @@ class TestResponseSizeGuardMiddleware:
@pytest.mark.asyncio
async def test_logs_warning_at_threshold(self) -> None:
"""Should log warning when approaching limit."""
"""Should log warning when approaching limit.
Mocks the token estimator to return a specific value above the
warn threshold but below the hard limit, decoupling the test
from whichever tokenizer (tiktoken or char heuristic) happens
to be loaded.
"""
middleware = ResponseSizeGuardMiddleware(
token_limit=1000, warn_threshold_pct=80
)
@@ -155,18 +161,21 @@ class TestResponseSizeGuardMiddleware:
context.message.name = "list_charts"
context.message.params = {}
# Response at ~85% of limit (should trigger warning but not block)
response = {"data": "x" * 2900} # ~828 tokens at 3.5 chars/token
response = {"data": "approaching the limit"}
call_next = AsyncMock(return_value=response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch(
"superset.mcp_service.middleware.estimate_response_tokens",
return_value=850,
),
patch("superset.mcp_service.middleware.logger") as mock_logger,
):
result = await middleware.on_call_tool(context, call_next)
# Should return response (not blocked)
# Should return response (not blocked at 85% of limit)
assert result == response
# Should log warning
mock_logger.warning.assert_called()
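The threshold arithmetic this test exercises (850 tokens against a 1000-token limit with an 80% warn threshold) can be modeled as a simplified decision function; this is a sketch of the behavior under test, not the middleware's actual code:

```python
def classify_response(tokens: int, token_limit: int, warn_threshold_pct: int) -> str:
    """Simplified model of the size-guard decision: block at or over the
    hard limit, warn above the percentage threshold, pass otherwise."""
    if tokens >= token_limit:
        return "block"
    if tokens >= token_limit * warn_threshold_pct / 100:
        return "warn"
    return "ok"
```

With the test's mocked estimate of 850 tokens, `classify_response(850, 1000, 80)` lands in the warn band: above the 800-token threshold but below the 1000-token limit, so the response passes through while `logger.warning` fires.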

View File

@@ -20,9 +20,11 @@ Unit tests for MCP service token utilities.
"""
from typing import Any, List
from unittest.mock import patch
from pydantic import BaseModel
from superset.mcp_service.utils import token_utils
from superset.mcp_service.utils.token_utils import (
_replace_collections_with_summaries,
_summarize_large_dicts,
@@ -45,29 +47,65 @@ class TestEstimateTokenCount:
"""Test estimate_token_count function."""
def test_estimate_string(self) -> None:
"""Should estimate tokens for a string."""
"""Should produce a positive non-zero estimate for a normal string.
We don't assert on a specific number because the result depends on
which tokenizer is loaded (tiktoken when available, char heuristic
otherwise).
"""
text = "Hello world"
result = estimate_token_count(text)
expected = int(len(text) / CHARS_PER_TOKEN)
assert result == expected
assert result > 0
def test_estimate_bytes(self) -> None:
"""Should estimate tokens for bytes."""
text = b"Hello world"
result = estimate_token_count(text)
expected = int(len(text) / CHARS_PER_TOKEN)
assert result == expected
"""Bytes input should be decoded and produce the same count as the
equivalent string."""
text = "Hello world"
assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text)
def test_empty_string(self) -> None:
"""Should return 0 for empty string."""
"""Should return 0 for empty string and empty bytes."""
assert estimate_token_count("") == 0
assert estimate_token_count(b"") == 0
def test_json_like_content(self) -> None:
"""Should estimate tokens for JSON-like content."""
"""JSON content should produce a positive estimate."""
json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
result = estimate_token_count(json_str)
assert result > 0
assert result == int(len(json_str) / CHARS_PER_TOKEN)
assert estimate_token_count(json_str) > 0
def test_long_text_roughly_scales_with_length(self) -> None:
"""A doubled string should produce roughly double the token count
(within ±10%)."""
small = "the quick brown fox jumps over the lazy dog. " * 20
large = small * 2
small_n = estimate_token_count(small)
large_n = estimate_token_count(large)
# Within 10% of 2x — both tokenizers (tiktoken and the char
# fallback) preserve length monotonicity.
assert 1.8 * small_n <= large_n <= 2.2 * small_n
def test_fallback_uses_chars_per_token_when_tiktoken_unavailable(
self,
) -> None:
"""When the tiktoken encoding is None (not installed), the
function falls back to len/CHARS_PER_TOKEN math."""
text = "x" * 100
with patch.object(token_utils, "_ENCODING", None):
result = estimate_token_count(text)
assert result == int(100 / CHARS_PER_TOKEN)
def test_fallback_when_tiktoken_encode_raises(self) -> None:
"""A misbehaving encoding should fall back to the char heuristic
rather than raise — the size guard must never fail-open."""
class BoomEncoding:
def encode(self, text: str) -> list[int]:
raise ValueError("simulated tiktoken failure")
text = "abc" * 50
with patch.object(token_utils, "_ENCODING", BoomEncoding()):
result = estimate_token_count(text)
assert result == int(len(text) / CHARS_PER_TOKEN)
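The fallback behavior these tests pin down can be modeled as follows. This is a sketch assuming a module-level `_ENCODING` (the tiktoken encoding, or `None` when unavailable) and a `CHARS_PER_TOKEN` of 3.5, consistent with the "~828 tokens at 3.5 chars/token" comment removed in the middleware diff above:

```python
from typing import Union

CHARS_PER_TOKEN = 3.5  # char heuristic assumed by the fallback tests
_ENCODING = None  # tiktoken encoding when installed, else None


def estimate_token_count(text: Union[str, bytes]) -> int:
    """Count tokens with tiktoken when available; on any failure or when
    tiktoken is absent, fall back to the char heuristic (sketch)."""
    if isinstance(text, bytes):
        text = text.decode("utf-8", errors="replace")
    if _ENCODING is not None:
        try:
            return len(_ENCODING.encode(text))
        except Exception:
            pass  # misbehaving encoding: fall back rather than raise
    return int(len(text) / CHARS_PER_TOKEN)
```

This shape is why the tests patch `token_utils._ENCODING` directly: setting it to `None` forces the heuristic path, and swapping in a raising stub exercises the fail-safe branch.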
class TestEstimateResponseTokens:

View File

@@ -33,7 +33,8 @@ def test_opensearch_dialect_registered() -> None:
def test_double_quotes_as_identifiers() -> None:
"""
Test that double quotes are treated as identifiers, not string literals.
Test that double quotes are treated as identifiers, not string literals,
and normalized to backticks in output.
"""
sql = 'SELECT "AvgTicketPrice" FROM "flights"'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -42,8 +43,8 @@ def test_double_quotes_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
@@ -69,8 +70,7 @@ WHERE
def test_backticks_as_identifiers() -> None:
"""
Test that backticks work as identifiers (MySQL-style).
Backticks are normalized to double quotes in output.
Test that backticks are accepted as identifiers and preserved on output.
"""
sql = "SELECT `AvgTicketPrice` FROM `flights`"
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -79,15 +79,16 @@ def test_backticks_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
def test_mixed_identifier_quotes() -> None:
"""
Test mixing double quotes and backticks for identifiers.
Test that mixed double quotes and backticks for identifiers are all
normalized to backticks on output.
"""
sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -96,12 +97,26 @@ def test_mixed_identifier_quotes() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice" AS "AvgTicketPrice"
FROM "default"."flights"
`AvgTicketPrice` AS `AvgTicketPrice`
FROM `default`.`flights`
""".strip()
)
def test_alias_with_space() -> None:
"""
Test that an alias containing a space (e.g. a metric key like ``my test``)
is preserved as a backtick-quoted identifier through the round-trip.
"""
sql = 'SELECT COUNT(*) AS "my test" FROM `flights`'
ast = sqlglot.parse_one(sql, OpenSearch)
assert (
OpenSearch().generate(expression=ast, pretty=False)
== "SELECT COUNT(*) AS `my test` FROM `flights`"
)
@pytest.mark.parametrize(
"sql, expected",
[
@@ -110,20 +125,20 @@ FROM "default"."flights"
"""
SELECT
COUNT(*)
FROM "flights"
FROM `flights`
WHERE
"Cancelled" = TRUE
`Cancelled` = TRUE
""".strip(),
),
(
'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"',
"""
SELECT
"Carrier",
SUM("AvgTicketPrice")
FROM "flights"
`Carrier`,
SUM(`AvgTicketPrice`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip(),
),
(
@@ -131,9 +146,9 @@ GROUP BY
"""
SELECT
*
FROM "flights"
FROM `flights`
WHERE
"DestCountry" IN ('US', 'CA', 'MX')
`DestCountry` IN ('US', 'CA', 'MX')
""".strip(),
),
],
@@ -165,13 +180,13 @@ GROUP BY "Carrier"
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
`Carrier`,
COUNT(*),
AVG("AvgTicketPrice"),
MAX("FlightDelayMin")
FROM "flights"
AVG(`AvgTicketPrice`),
MAX(`FlightDelayMin`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip()
)
@@ -190,10 +205,10 @@ SELECT
*
FROM (
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
) AS "sub"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
) AS `sub`
""".strip()
)
@@ -212,12 +227,12 @@ def test_order_by_with_quoted_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
ORDER BY
"AvgTicketPrice" DESC,
"Carrier" ASC
`AvgTicketPrice` DESC,
`Carrier` ASC
""".strip()
)
@@ -234,7 +249,7 @@ def test_limit_clause() -> None:
== """
SELECT
*
FROM "flights"
FROM `flights`
LIMIT 10
""".strip()
)