Compare commits

...

4 Commits

Author SHA1 Message Date
dependabot[bot]
f3bb8e0e0a chore(deps): bump docusaurus-plugin-openapi-docs in /docs
Bumps [docusaurus-plugin-openapi-docs](https://github.com/PaloAltoNetworks/docusaurus-openapi-docs/tree/HEAD/packages/docusaurus-plugin-openapi-docs) from 5.0.2 to 5.1.0.
- [Release notes](https://github.com/PaloAltoNetworks/docusaurus-openapi-docs/releases)
- [Changelog](https://github.com/PaloAltoNetworks/docusaurus-openapi-docs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/PaloAltoNetworks/docusaurus-openapi-docs/commits/v5.1.0/packages/docusaurus-plugin-openapi-docs)

---
updated-dependencies:
- dependency-name: docusaurus-plugin-openapi-docs
  dependency-version: 5.1.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-07-02 07:04:50 +00:00
Amin Ghadersohi
c3f5e997a1 feat(mcp): chart type plugin registry for extensible generate_chart (#39922) 2026-07-02 00:31:19 -04:00
dependabot[bot]
d507be2555 chore(deps): bump geostyler-openlayers-parser from 5.7.0 to 5.7.1 in /superset-frontend (#41615)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Joe Li <joe@preset.io>
2026-07-02 10:04:12 +07:00
dependabot[bot]
e3bd6e5c70 chore(deps-dev): bump @playwright/test from 1.61.0 to 1.61.1 in /superset-frontend (#41616)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Đỗ Trọng Hải <41283691+hainenber@users.noreply.github.com>
Co-authored-by: Joe Li <joe@preset.io>
2026-07-02 10:02:20 +07:00
45 changed files with 3636 additions and 1052 deletions

0
.pre-commit-config.yaml Normal file → Executable file
View File

View File

@@ -74,7 +74,7 @@
"antd": "^6.4.5",
"baseline-browser-mapping": "^2.10.38",
"caniuse-lite": "^1.0.30001799",
"docusaurus-plugin-openapi-docs": "^5.0.2",
"docusaurus-plugin-openapi-docs": "^5.1.0",
"docusaurus-theme-openapi-docs": "^5.0.2",
"js-yaml": "^5.1.0",
"js-yaml-loader": "^1.2.2",

View File

@@ -7163,10 +7163,10 @@ doctrine@^2.1.0:
dependencies:
esutils "^2.0.2"
docusaurus-plugin-openapi-docs@^5.0.2:
version "5.0.2"
resolved "https://registry.yarnpkg.com/docusaurus-plugin-openapi-docs/-/docusaurus-plugin-openapi-docs-5.0.2.tgz#f00028621deb9179065fe7d6a541256692ef941b"
integrity sha512-WCC2m6PpylXZfNga+ScelTG0a7jUGtbB9+AmbR9lUj93FPryTs8VHTMJ3fKtO0senJTWgOU3MDvZw0v+mE3ztA==
docusaurus-plugin-openapi-docs@^5.1.0:
version "5.1.0"
resolved "https://registry.yarnpkg.com/docusaurus-plugin-openapi-docs/-/docusaurus-plugin-openapi-docs-5.1.0.tgz#9732f81f45a5bc126bcafcb150332b7623ddece7"
integrity sha512-ocRemE3KmdhhPKaow5hja1m1NLIPfNlfRYFt7pja+nG26Wlp0MEC9ERS99gSWEWaxU+txDFBpUXsxo7nGlk8ZA==
dependencies:
"@apidevtools/json-schema-ref-parser" "^15.3.3"
"@redocly/openapi-core" "^2.25.2"

View File

@@ -98,7 +98,7 @@
"geolib": "^3.3.14",
"geostyler": "^18.6.0",
"geostyler-data": "^1.1.0",
"geostyler-openlayers-parser": "^5.7.0",
"geostyler-openlayers-parser": "^5.7.1",
"geostyler-style": "11.0.2",
"geostyler-wfs-parser": "^3.0.1",
"google-auth-library": "^10.7.0",
@@ -180,7 +180,7 @@
"@emotion/jest": "^11.14.2",
"@formatjs/intl-durationformat": "^0.10.15",
"@istanbuljs/nyc-config-typescript": "^1.0.1",
"@playwright/test": "^1.61.0",
"@playwright/test": "^1.61.1",
"@pmmmwh/react-refresh-webpack-plugin": "^0.6.2",
"@storybook/addon-docs": "10.4.5",
"@storybook/addon-links": "10.4.4",
@@ -8825,13 +8825,13 @@
}
},
"node_modules/@playwright/test": {
"version": "1.61.0",
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.61.0.tgz",
"integrity": "sha512-cKA5B6lpFEMyMGjxF54QihfYpB4FkEGH+qZhtArDEG+wezQAJY8Pq6C7T1SjWz+FFzt3TbyoXBQYk/0292TdJA==",
"version": "1.61.1",
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.61.1.tgz",
"integrity": "sha512-8nKv6+0RJSL9FE4jYOEGXnPeM/Hg12qZpmqzZjRh3qM0Y7c3z1mrOTfFLids72RDQYVh9WpLEfR5WdpNX4fkig==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"playwright": "1.61.0"
"playwright": "1.61.1"
},
"bin": {
"playwright": "cli.js"
@@ -21481,9 +21481,9 @@
}
},
"node_modules/geostyler-openlayers-parser": {
"version": "5.7.0",
"resolved": "https://registry.npmjs.org/geostyler-openlayers-parser/-/geostyler-openlayers-parser-5.7.0.tgz",
"integrity": "sha512-FRNTNPLoKJzKYnWas+E4hb4h38SGaK3KeNPZmLUqO5EcTootJjAJyTbCy/Cuv9afk56HYIBpM2gHh6q/fLwqsg==",
"version": "5.7.1",
"resolved": "https://registry.npmjs.org/geostyler-openlayers-parser/-/geostyler-openlayers-parser-5.7.1.tgz",
"integrity": "sha512-GKkFdki1XbNIWS8onAU2CatGCJ/BB3QzknligxTXtTuLOa6Gqp2RshgExt3BzVyQOlXMmb8zmhtd5Z0CbvrrgA==",
"license": "BSD-2-Clause",
"dependencies": {
"css-font-parser": "^2.0.0",
@@ -33684,13 +33684,13 @@
}
},
"node_modules/playwright": {
"version": "1.61.0",
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.61.0.tgz",
"integrity": "sha512-Z+7BeeqQPRRzklHsVFP4KTGIyMxKUmfeRA4WisM6G3/XW6nwGeX6fX9qYaDa+CiUqpOkb2f6X3nar05R3kSuJQ==",
"version": "1.61.1",
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.61.1.tgz",
"integrity": "sha512-DWnY5o3YbLWK4GovuAVwpqL+1VwGNdUGrRr++8j8PtQQzvAVZUIMjKQ90fY689sEJZJBbZVw1rXaOKSTitkzPQ==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"playwright-core": "1.61.0"
"playwright-core": "1.61.1"
},
"bin": {
"playwright": "cli.js"
@@ -33703,9 +33703,9 @@
}
},
"node_modules/playwright-core": {
"version": "1.61.0",
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.61.0.tgz",
"integrity": "sha512-caX7TrY3Ml6egyDX0WUcTHDxodl/b51y5wJOdCEA36QviK/s2g081hvmGs8eaE3DWb6NYZQ6BjO/QkNRPenoPA==",
"version": "1.61.1",
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.61.1.tgz",
"integrity": "sha512-h7Qlt6m4REp25qvIdvbDtVmD4LqVXfpRxhORv9L0jzETM05p4fuPJ3dKyuSXQxDSbXnmS79HAgi9589lGSpLkg==",
"dev": true,
"license": "Apache-2.0",
"bin": {

View File

@@ -181,7 +181,7 @@
"geolib": "^3.3.14",
"geostyler": "^18.6.0",
"geostyler-data": "^1.1.0",
"geostyler-openlayers-parser": "^5.7.0",
"geostyler-openlayers-parser": "^5.7.1",
"geostyler-style": "11.0.2",
"geostyler-wfs-parser": "^3.0.1",
"google-auth-library": "^10.7.0",
@@ -263,7 +263,7 @@
"@emotion/jest": "^11.14.2",
"@formatjs/intl-durationformat": "^0.10.15",
"@istanbuljs/nyc-config-typescript": "^1.0.1",
"@playwright/test": "^1.61.0",
"@playwright/test": "^1.61.1",
"@pmmmwh/react-refresh-webpack-plugin": "^0.6.2",
"@storybook/addon-docs": "10.4.5",
"@storybook/addon-links": "10.4.4",

View File

@@ -357,10 +357,12 @@ Time grain for temporal x-axis (time_grain parameter):
- PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly)
Chart Types in Existing Charts (viewable via list_charts/get_chart_info):
- pie, big_number, big_number_total, funnel, gauge_chart
- echarts_timeseries_line, echarts_timeseries_bar, echarts_timeseries_area
- pivot_table_v2, heatmap_v2, sankey_v2, sunburst_v2, treemap_v2
- word_cloud, world_map, box_plot, bubble, mixed_timeseries
Each chart returned by list_charts / get_chart_info includes a
chart_type_display_name field with a human-readable name when available.
This field is populated only for the 7 chart types supported by generate_chart
(xy, pie, table, pivot_table, big_number, mixed_timeseries, handlebars).
For all other viz_types (Funnel, Gauge, Heatmap, etc.) it will be null —
use the raw viz_type field instead when referring to those chart types.
Query Examples:
- List all tables:
@@ -674,6 +676,7 @@ warnings.filterwarnings(
# NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
# They register automatically on import, similar to tools.
import superset.mcp_service.chart.plugins # noqa: F401, E402 — registers all chart type plugins
from superset.mcp_service.annotation_layer.tool import ( # noqa: F401, E402
get_annotation_layer_info,
get_layer_annotation_info,

View File

@@ -353,29 +353,44 @@ def map_config_to_form_data(
| BigNumberChartConfig,
dataset_id: int | str | None = None,
) -> Dict[str, Any]:
"""Map chart config to Superset form_data."""
if isinstance(config, TableChartConfig):
return map_table_config(config)
elif isinstance(config, XYChartConfig):
return map_xy_config(config, dataset_id=dataset_id)
elif isinstance(config, PieChartConfig):
return map_pie_config(config)
elif isinstance(config, PivotTableChartConfig):
return map_pivot_table_config(config)
elif isinstance(config, MixedTimeseriesChartConfig):
return map_mixed_timeseries_config(config, dataset_id=dataset_id)
elif isinstance(config, HandlebarsChartConfig):
return map_handlebars_config(config)
elif isinstance(config, BigNumberChartConfig):
if config.show_trendline and config.temporal_column:
if not is_column_truly_temporal(config.temporal_column, dataset_id):
raise ValueError(
f"Big Number trendline requires a temporal SQL column; "
f"'{config.temporal_column}' is not temporal."
)
return map_big_number_config(config)
else:
raise ValueError(f"Unsupported config type: {type(config)}")
"""Map chart config to Superset form_data via the plugin registry.
The previous if/elif chain across all 7 chart types has been replaced by a
single registry lookup. Cross-field constraints (e.g. BigNumber trendline
temporal check) are now owned by each plugin's post_map_validate() method
rather than being baked into this dispatcher.
"""
# Local import: plugins call map_*_config from their to_form_data() methods,
# so chart_utils is loaded before plugins finish registering. A top-level
# import of registry here would trigger plugin loading mid-import = cycle.
from superset.mcp_service.chart.registry import get_registry
chart_type = getattr(config, "chart_type", None)
plugin = get_registry().get(chart_type) if chart_type else None
if plugin is None:
if chart_type is None:
raise ValueError(f"Unsupported config type: {type(config)}")
raise ValueError(
f"Unsupported config type: {type(config)} (chart_type={chart_type!r})"
)
form_data = plugin.to_form_data(config, dataset_id=dataset_id)
# Run post-map validation (e.g. BigNumber trendline temporal type check).
# Raise ValueError to preserve backward-compatible error handling in callers.
# Include details and suggestions so callers logging str(e) surface actionable
# context (e.g. BigNumber trendline guidance) rather than just the headline.
error = plugin.post_map_validate(config, form_data, dataset_id=dataset_id)
if error is not None:
parts = [error.message]
if error.details:
parts.append(error.details)
if error.suggestions:
parts.append("Suggestions: " + "; ".join(error.suggestions))
raise ValueError(" ".join(parts))
return form_data
def _add_adhoc_filters(
@@ -1279,87 +1294,32 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str:
def generate_chart_name(
config: TableChartConfig
| XYChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig
| HandlebarsChartConfig
| BigNumberChartConfig,
config: Any,
dataset_name: str | None = None,
) -> str:
"""Generate a descriptive chart name following a standard format.
Format conventions (by chart type):
Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
Time-series (line/area, no group_by): [Metric] Over Time
Table (no aggregates): [Dataset] Records
Table (with aggregates): [Metric] Summary
Pie: [Dimension] by [Metric]
Pivot Table: Pivot Table [Row1, Row2]
Mixed Timeseries: [Primary] + [Secondary]
An en-dash followed by context (filters / time grain) is appended
Delegates to each plugin's ``generate_name()`` method.
See each plugin's ``generate_name`` for chart-type-specific format conventions.
An en-dash followed by context (filters / time grain) is appended by the plugin
when such information is available.
"""
if isinstance(config, TableChartConfig):
what = _table_chart_what(config, dataset_name)
context = _summarize_filters(config.filters)
elif isinstance(config, XYChartConfig):
what = _xy_chart_what(config)
context = _xy_chart_context(config)
elif isinstance(config, PieChartConfig):
what = _pie_chart_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, PivotTableChartConfig):
what = _pivot_table_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, MixedTimeseriesChartConfig):
what = _mixed_timeseries_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, HandlebarsChartConfig):
what = _handlebars_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
elif isinstance(config, BigNumberChartConfig):
what = _big_number_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
else:
return "Chart"
from superset.mcp_service.chart.registry import get_registry
name = what
if context:
name = f"{what} \u2013 {context}"
return _truncate(name)
plugin = get_registry().get(getattr(config, "chart_type", ""))
if plugin is None:
return "Chart"
return _truncate(plugin.generate_name(config, dataset_name))
def _resolve_viz_type(config: Any) -> str:
"""Resolve the Superset viz_type from a chart config object."""
chart_type = getattr(config, "chart_type", "unknown")
if chart_type == "xy":
kind = getattr(config, "kind", "line")
viz_type_map = {
"line": "echarts_timeseries_line",
"bar": "echarts_timeseries_bar",
"area": "echarts_area",
"scatter": "echarts_timeseries_scatter",
}
return viz_type_map.get(kind, "echarts_timeseries_line")
elif chart_type == "table":
return getattr(config, "viz_type", "table")
elif chart_type == "pie":
return "pie"
elif chart_type == "pivot_table":
return "pivot_table_v2"
elif chart_type == "mixed_timeseries":
return "mixed_timeseries"
elif chart_type == "handlebars":
return "handlebars"
elif chart_type == "big_number":
show_trendline = getattr(config, "show_trendline", False)
temporal_column = getattr(config, "temporal_column", None)
return (
"big_number" if show_trendline and temporal_column else "big_number_total"
)
return "unknown"
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get(getattr(config, "chart_type", ""))
if plugin is None:
return "unknown"
return plugin.resolve_viz_type(config)
TABLE_VIZ_TYPE_LABELS = {

View File

@@ -40,7 +40,10 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.commands.exceptions import CommandException
from superset.errors import SupersetErrorType
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.chart.validation.dataset_validator import (
build_dataset_context_from_orm,
DatasetValidator,
)
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
ColumnSuggestion,
@@ -94,52 +97,6 @@ class CompileResult:
row_count: int | None = None
def build_dataset_context_from_orm(dataset: Any) -> DatasetContext | None:
    """Build a ``DatasetContext`` out of an ORM dataset that is already loaded.

    Equivalent in result to :py:meth:`DatasetValidator._get_dataset_context`,
    minus the ``DatasetDAO.find_by_id`` round trip, so callers that already
    hold the dataset (for permission checks, etc.) can reuse it.

    Args:
        dataset: The ORM dataset object, or ``None``.

    Returns:
        ``None`` when ``dataset`` is ``None``; otherwise a ``DatasetContext``
        describing the dataset's columns and metrics.
    """
    if dataset is None:
        return None

    # ``or []`` guards against the attribute being present but set to None.
    column_entries: List[Dict[str, Any]] = [
        {
            "name": column.column_name,
            "type": str(column.type) if column.type else "UNKNOWN",
            "is_temporal": getattr(column, "is_temporal", False),
            "is_numeric": getattr(column, "is_numeric", False),
        }
        for column in getattr(dataset, "columns", []) or []
    ]

    metric_entries: List[Dict[str, Any]] = [
        {
            "name": saved_metric.metric_name,
            "expression": saved_metric.expression,
            "description": saved_metric.description,
        }
        for saved_metric in getattr(dataset, "metrics", []) or []
    ]

    # ``DatasetContext.database_name`` is a required ``str``. When the
    # relationship is not loaded, fall back to an empty string so Pydantic
    # validation does not fail; the value only appears in error messages.
    related_database = getattr(dataset, "database", None)
    resolved_database_name = getattr(related_database, "database_name", None) or ""

    return DatasetContext(
        id=dataset.id,
        table_name=dataset.table_name,
        schema=dataset.schema,
        database_name=resolved_database_name,
        available_columns=column_entries,
        available_metrics=metric_entries,
    )
def _compile_chart(
form_data: Dict[str, Any],
dataset_id: int,

View File

@@ -0,0 +1,262 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ChartTypePlugin protocol and BaseChartPlugin base class.
Each chart type owns its pre-validation, column extraction, form_data mapping,
and post-map validation in a single plugin class. This eliminates the previous
pattern of 4 separate dispatch points (schema_validator.py, dataset_validator.py,
chart_utils.py, pipeline.py) that had to be updated in sync whenever a new chart
type was added.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, ClassVar, Protocol, runtime_checkable
from superset.mcp_service.chart.schemas import ColumnRef
from superset.mcp_service.common.error_schemas import ChartGenerationError
@runtime_checkable
class ChartTypePlugin(Protocol):
    """
    Structural interface that every chart-type plugin has to provide.

    A plugin bundles all nine per-chart-type hooks in one class, so supporting
    a new chart type means adding a single plugin file instead of editing
    several separate files.
    """

    #: Discriminator; equals the ``chart_type`` field of the matching ChartConfig.
    chart_type: str

    #: User-facing name of the chart type (e.g. "Line / Bar / Area / Scatter").
    display_name: str

    #: Every Superset-internal viz_type this plugin can produce, keyed to its
    #: user-facing display name, e.g. {"echarts_timeseries_line": "Line Chart"}.
    #: The registry uses it to resolve display names for existing charts, so no
    #: separate JSON mapping file is needed.
    native_viz_types: ClassVar[Mapping[str, str]]

    def pre_validate(self, config: dict[str, Any]) -> ChartGenerationError | None:
        """
        Check the raw config dict before Pydantic parses it.

        SchemaValidator calls this ahead of parsing the request. Implementations
        should verify that the required top-level keys exist and are well-typed.

        Returns a ChartGenerationError describing the problem, or None when the
        config passes.
        """
        ...

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        """
        Collect every column reference held by a parsed chart config.

        DatasetValidator calls this to confirm that all referenced columns exist
        in the dataset, so every field that carries a column name — filters
        included — must be covered.

        Returns the references as ColumnRef objects; the list may be empty.
        """
        ...

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        """
        Translate a parsed chart config into Superset's internal form_data dict.

        Takes the place of the if/elif chain in
        chart_utils.map_config_to_form_data().

        Returns a form_data dict ready for caching and rendering.
        """
        ...

    def post_map_validate(
        self,
        config: Any,
        form_data: dict[str, Any],
        dataset_id: int | str | None = None,
    ) -> ChartGenerationError | None:
        """
        Validate the form_data produced by to_form_data().

        Intended for cross-field constraints that are only checkable once
        form_data exists (e.g. a BigNumber trendline needs a temporal column
        whose type has to be verified against the dataset).

        Returns a ChartGenerationError on failure, None on success.
        """
        ...

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        """
        Produce a config whose column names use the dataset's canonical casing.

        DatasetValidator.normalize_column_names() calls this. BaseChartPlugin's
        default hands the config back untouched; plugins that have column fields
        override it to repair case-sensitivity mismatches.

        Returns a new config object, or the original one when nothing needs
        normalizing.
        """
        ...

    def get_runtime_warnings(
        self, config: Any, dataset_id: int | str
    ) -> tuple[list[str], list[str]]:
        """
        Gather runtime warnings and suggestions specific to this chart type.

        RuntimeValidator calls this to collect per-type warnings. They are
        purely informational and never block chart generation. The default
        implementation yields empty lists; plugins override it to report
        chart-type-specific findings (e.g. XY cardinality checks).

        Returns a ``(warnings, suggestions)`` tuple; either list may be empty.
        """
        ...

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        """
        Build a descriptive chart name for the given config.

        chart_utils.generate_chart_name() calls this, and the name should follow
        the standard format conventions documented there. A plugin that does not
        override it returns the generic fallback "Chart".
        """
        ...

    def resolve_viz_type(self, config: Any) -> str:
        """
        Give the Superset-internal viz_type string for this config.

        chart_utils._resolve_viz_type() calls this. The result must name a
        registered Superset viz plugin (e.g. "echarts_timeseries_line"). A
        plugin that does not override it returns "unknown".
        """
        ...

    def schema_error_hint(self) -> ChartGenerationError | None:
        """
        Supply a friendly error for Pydantic discriminated-union parse failures.

        SchemaValidator calls this when Pydantic cannot parse the config union
        but the chart_type is known. Returning None makes the caller fall back
        to the generic error.
        """
        ...
class BaseChartPlugin:
    """
    Base class providing sensible defaults for all ChartTypePlugin methods.

    Concrete plugins extend this and override only what they need. Default
    implementations: ``pre_validate`` → None (valid), ``extract_column_refs`` → [],
    ``post_map_validate`` → None, ``normalize_column_refs`` → config unchanged,
    ``get_runtime_warnings`` → ([], []), ``generate_name`` → "Chart",
    ``resolve_viz_type`` → "unknown", ``schema_error_hint`` → None.
    ``to_form_data`` raises ``NotImplementedError`` and must be overridden.
    """

    # Discriminator value matching ChartConfig's chart_type field; empty here,
    # concrete plugins set their own value.
    chart_type: str = ""
    # Human-readable name shown to users; empty here, set by concrete plugins.
    display_name: str = ""
    # Subclasses must override this with their own class attribute.
    native_viz_types: ClassVar[Mapping[str, str]] = {}

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        """Accept any raw config: the default performs no pre-validation."""
        return None

    def extract_column_refs(
        self,
        config: Any,
    ) -> list[ColumnRef]:
        """Return no column references (a fresh empty list on every call)."""
        return []

    def to_form_data(
        self,
        config: Any,
        dataset_id: int | str | None = None,
    ) -> dict[str, Any]:
        """Map ``config`` to form_data; concrete plugins must override this.

        Raises:
            NotImplementedError: Always, naming the concrete class that failed
                to provide an implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__}.to_form_data() is not implemented"
        )

    def post_map_validate(
        self,
        config: Any,
        form_data: dict[str, Any],
        dataset_id: int | str | None = None,
    ) -> ChartGenerationError | None:
        """Accept any mapped form_data: the default has no post-map checks."""
        return None

    def normalize_column_refs(
        self,
        config: Any,
        dataset_context: Any,
    ) -> Any:
        """Return ``config`` itself, unchanged (no normalization by default)."""
        return config

    def get_runtime_warnings(
        self,
        config: Any,
        dataset_id: int | str,
    ) -> tuple[list[str], list[str]]:
        """Return an empty ``(warnings, suggestions)`` pair."""
        return [], []

    def generate_name(
        self,
        config: Any,
        dataset_name: str | None = None,
    ) -> str:
        """Return the generic fallback chart name ``"Chart"``."""
        return "Chart"

    def resolve_viz_type(self, config: Any) -> str:
        """Return ``"unknown"``; plugins override this with a real viz_type."""
        return "unknown"

    def schema_error_hint(self) -> ChartGenerationError | None:
        """Return None so the caller falls back to its generic schema error."""
        return None

    @staticmethod
    def _with_context(what: str, context: str | None) -> str:
        """Combine a 'what' label and optional context with an en-dash.

        Args:
            what: The main label describing the chart.
            context: Optional qualifier (for example a filter summary).

        Returns:
            ``"<what> – <context>"`` when ``context`` is non-empty, otherwise
            ``what`` unchanged.
        """
        # The separator is U+2013 EN DASH surrounded by spaces, written as an
        # escape so it cannot be lost or confused with a hyphen. Joining with a
        # bare space would contradict the documented name format.
        return f"{what} \u2013 {context}" if context else what

View File

@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Chart type plugins package.
Importing this module registers all built-in chart type plugins in the global
registry. This module is imported by app.py at startup.
To add a new chart type:
1. Create ``superset/mcp_service/chart/plugins/{chart_type}.py``
2. Implement a class extending ``BaseChartPlugin``
3. Import and register it here
"""
from superset.mcp_service.chart.plugins.big_number import BigNumberChartPlugin
from superset.mcp_service.chart.plugins.handlebars import HandlebarsChartPlugin
from superset.mcp_service.chart.plugins.mixed_timeseries import (
MixedTimeseriesChartPlugin,
)
from superset.mcp_service.chart.plugins.pie import PieChartPlugin
from superset.mcp_service.chart.plugins.pivot_table import PivotTableChartPlugin
from superset.mcp_service.chart.plugins.table import TableChartPlugin
from superset.mcp_service.chart.plugins.xy import XYChartPlugin
from superset.mcp_service.chart.registry import register
# Register all built-in chart type plugins. Each class is instantiated and
# handed to register() immediately, in the order listed below.
for _plugin_cls in (
    XYChartPlugin,
    TableChartPlugin,
    PieChartPlugin,
    PivotTableChartPlugin,
    MixedTimeseriesChartPlugin,
    HandlebarsChartPlugin,
    BigNumberChartPlugin,
):
    register(_plugin_cls())
del _plugin_cls  # keep the loop variable out of the package namespace
__all__ = [
"BigNumberChartPlugin",
"HandlebarsChartPlugin",
"MixedTimeseriesChartPlugin",
"PieChartPlugin",
"PivotTableChartPlugin",
"TableChartPlugin",
"XYChartPlugin",
]

View File

@@ -0,0 +1,248 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Big number chart type plugin."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, ClassVar
from superset.mcp_service.chart.chart_utils import (
_big_number_chart_what,
_summarize_filters,
is_column_truly_temporal,
map_big_number_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import BigNumberChartConfig, ColumnRef
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class BigNumberChartPlugin(BaseChartPlugin):
"""Plugin for big_number chart type."""
chart_type = "big_number"
display_name = "Big Number"
native_viz_types: ClassVar[Mapping[str, str]] = {
"big_number": "Big Number with Trendline",
"big_number_total": "Big Number",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
if "metric" not in config:
return ChartGenerationError(
error_type="missing_metric",
message="Big Number chart missing required field: metric",
details=(
"Big Number charts require a 'metric' field "
"specifying the value to display"
),
suggestions=[
"Add 'metric' with name and aggregate: "
"{'name': 'revenue', 'aggregate': 'SUM'}",
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
"Example: {'chart_type': 'big_number', "
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
],
error_code="MISSING_BIG_NUMBER_METRIC",
)
metric = config.get("metric", {})
if not isinstance(metric, dict):
return ChartGenerationError(
error_type="invalid_metric_type",
message="Big Number metric must be a dict with 'name' and 'aggregate'",
details=(
f"The 'metric' field must be an object, got {type(metric).__name__}"
),
suggestions=[
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
)
if metric.get("sql_expression"):
label = metric.get("label")
if not isinstance(label, str) or not label.strip():
return ChartGenerationError(
error_type="missing_sql_metric_label",
message="SQL expression metrics require a non-empty 'label'",
details=(
"When using a custom SQL expression as the Big Number metric, "
"a human-readable 'label' string is required so Superset can "
"display the metric name."
),
suggestions=[
"Add 'label': e.g. {'sql_expression': 'SUM(a)/SUM(b)', "
"'label': 'Conversion Rate'}",
"The label must be a non-empty string",
],
error_code="MISSING_SQL_METRIC_LABEL",
)
elif not metric.get("aggregate") and not metric.get("saved_metric"):
return ChartGenerationError(
error_type="missing_metric_aggregate",
message=(
"Big Number metric must include an aggregate function "
"or reference a saved metric"
),
details=(
"The metric must have an 'aggregate' field or 'saved_metric': true"
),
suggestions=[
"Add 'aggregate': {'name': 'col', 'aggregate': 'SUM'}",
"Or use a saved metric: {'name': 'metric', 'saved_metric': true}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="MISSING_BIG_NUMBER_AGGREGATE",
)
show_trendline = config.get("show_trendline", False)
temporal_column = config.get("temporal_column")
if show_trendline and not temporal_column:
return ChartGenerationError(
error_type="missing_temporal_column",
message="Trendline requires a temporal column",
details=(
"When 'show_trendline' is True, "
"a 'temporal_column' must be specified"
),
suggestions=[
"Add 'temporal_column': 'date_column_name'",
"Or set 'show_trendline': false for number only",
"Use get_dataset_info to find temporal columns",
],
error_code="MISSING_TEMPORAL_COLUMN",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, BigNumberChartConfig):
return []
refs: list[ColumnRef] = [config.metric]
# temporal_column is a str field, not a ColumnRef — validate it exists
if config.temporal_column:
refs.append(ColumnRef(name=config.temporal_column))
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_big_number_config(config)
def post_map_validate(
self,
config: Any,
form_data: dict[str, Any],
dataset_id: int | str | None = None,
) -> ChartGenerationError | None:
"""Verify the trendline temporal column is a real temporal SQL type.
This check was previously baked into map_config_to_form_data() in
chart_utils.py as a special case. Moving it here keeps the dispatcher
clean and makes the constraint explicit and discoverable.
"""
if not isinstance(config, BigNumberChartConfig):
return None
if not (config.show_trendline and config.temporal_column):
return None
if not is_column_truly_temporal(config.temporal_column, dataset_id):
return ChartGenerationError(
error_type="non_temporal_trendline_column",
message=(
f"Big Number trendline requires a temporal SQL column; "
f"'{config.temporal_column}' is not temporal."
),
details=(
f"Column '{config.temporal_column}' does not have a temporal "
f"SQL type (DATE, DATETIME, TIMESTAMP). The trendline requires "
f"a true temporal column for DATE_TRUNC to work."
),
suggestions=[
"Use get_dataset_info to find columns with temporal SQL types",
"Set 'show_trendline': false to use any column as the metric",
"If the column contains dates stored as integers, "
"consider casting it in a virtual dataset",
],
error_code="NON_TEMPORAL_TRENDLINE_COLUMN",
)
return None
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
    """Compose a chart name from the Big Number summary and its filters.

    ``dataset_name`` is accepted but not used. ``filters`` is read with
    ``getattr`` so a config without that attribute is treated as unfiltered.
    """
    subject = _big_number_chart_what(config)
    filter_summary = _summarize_filters(getattr(config, "filters", None))
    return self._with_context(subject, filter_summary)
def resolve_viz_type(self, config: Any) -> str:
    """Pick the viz_type for a Big Number config.

    Returns ``"big_number"`` only when the config has a truthy
    ``show_trendline`` and a truthy ``temporal_column``; otherwise
    ``"big_number_total"``. Missing attributes count as unset.
    """
    wants_trendline = getattr(config, "show_trendline", False)
    time_column = getattr(config, "temporal_column", None)
    has_trendline = bool(wants_trendline and time_column)
    return "big_number" if has_trendline else "big_number_total"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
    """Return a new config whose names are canonicalized for the dataset.

    The config is dumped to a dict with ``model_dump()``, names are rewritten
    in place, and the dict is re-validated into a ``BigNumberChartConfig``.

    * ``metric``: left alone when it has a ``sql_expression``; resolved as a
      saved metric name when ``saved_metric`` is set; otherwise resolved as
      a column name.
    * ``temporal_column``: resolved as a column name when set.
    * filters: delegated to ``DatasetValidator.normalize_filters``.
    """
    data = config.model_dump()

    metric = data.get("metric")
    if metric and not metric.get("sql_expression"):
        resolve_name = (
            DatasetValidator.get_canonical_metric_name
            if metric.get("saved_metric")
            else DatasetValidator.get_canonical_column_name
        )
        metric["name"] = resolve_name(metric["name"], dataset_context)

    temporal_name = data.get("temporal_column")
    if temporal_name:
        data["temporal_column"] = DatasetValidator.get_canonical_column_name(
            temporal_name, dataset_context
        )

    DatasetValidator.normalize_filters(data, dataset_context)
    return BigNumberChartConfig.model_validate(data)
def schema_error_hint(self) -> ChartGenerationError | None:
    """Return the generic error shown when Big Number schema validation fails."""
    details = (
        "The Big Number chart configuration is missing required "
        "fields or has invalid structure"
    )
    suggestions = [
        "Ensure 'metric' field has 'name' and 'aggregate'",
        "Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}",
        "For trendline: add show_trendline=true and temporal_column='col'",
        "Without trendline: just provide the metric",
    ]
    return ChartGenerationError(
        error_type="big_number_validation_error",
        message="Big Number chart configuration validation failed",
        details=details,
        suggestions=suggestions,
        error_code="BIG_NUMBER_VALIDATION_ERROR",
    )

View File

@@ -0,0 +1,194 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Handlebars chart type plugin."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, ClassVar
from superset.mcp_service.chart.chart_utils import (
_handlebars_chart_what,
_summarize_filters,
map_handlebars_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, HandlebarsChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class HandlebarsChartPlugin(BaseChartPlugin):
    """Plugin for handlebars chart type (custom HTML template charts)."""

    chart_type = "handlebars"
    display_name = "Handlebars (Custom Template)"
    native_viz_types: ClassVar[Mapping[str, str]] = {
        "handlebars": "Custom Template Chart",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        """Run structural checks on a raw handlebars config dict.

        Checks run in this order and the first failure is returned:
        the template key is present, the template is a non-blank string,
        ``query_mode`` is ``'aggregate'`` or ``'raw'`` (default
        ``'aggregate'``), raw mode has ``columns``, aggregate mode has
        ``metrics``. Returns ``None`` when every check passes.
        """
        if "handlebars_template" not in config:
            return ChartGenerationError(
                error_type="missing_handlebars_template",
                message="Handlebars chart missing required field: handlebars_template",
                details=(
                    "Handlebars charts require a 'handlebars_template' string "
                    "containing Handlebars HTML template markup"
                ),
                suggestions=[
                    "Add 'handlebars_template' with a Handlebars HTML template",
                    "Data is available as {{data}} array in the template",
                    "Example: '<ul>{{#each data}}<li>{{this.name}}: "
                    "{{this.value}}</li>{{/each}}</ul>'",
                ],
                error_code="MISSING_HANDLEBARS_TEMPLATE",
            )

        template = config.get("handlebars_template")
        template_ok = isinstance(template, str) and bool(template.strip())
        if not template_ok:
            return ChartGenerationError(
                error_type="invalid_handlebars_template",
                message="Handlebars template must be a non-empty string",
                details=(
                    "The 'handlebars_template' field must be a non-empty string "
                    "containing valid Handlebars HTML template markup"
                ),
                suggestions=[
                    "Ensure handlebars_template is a non-empty string",
                    "Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
                ],
                error_code="INVALID_HANDLEBARS_TEMPLATE",
            )

        mode = config.get("query_mode", "aggregate")
        if mode not in ("aggregate", "raw"):
            return ChartGenerationError(
                error_type="invalid_query_mode",
                message="Invalid query_mode for handlebars chart",
                details="query_mode must be either 'aggregate' or 'raw'",
                suggestions=[
                    "Use 'aggregate' for aggregated data (default)",
                    "Use 'raw' for individual rows",
                ],
                error_code="INVALID_QUERY_MODE",
            )

        # Each mode has its own mandatory list: columns for raw rows,
        # metrics for aggregated results.
        is_raw = mode == "raw"
        if is_raw and not config.get("columns"):
            return ChartGenerationError(
                error_type="missing_raw_columns",
                message="Handlebars chart in 'raw' mode requires 'columns'",
                details=(
                    "When query_mode is 'raw', you must specify which columns "
                    "to include in the query results"
                ),
                suggestions=[
                    "Add 'columns': [{'name': 'column_name'}] for raw mode",
                    "Or use query_mode='aggregate' with 'metrics' and optional 'groupby'",  # noqa: E501
                ],
                error_code="MISSING_RAW_COLUMNS",
            )

        is_aggregate = mode == "aggregate"
        if is_aggregate and not config.get("metrics"):
            return ChartGenerationError(
                error_type="missing_aggregate_metrics",
                message="Handlebars chart in 'aggregate' mode requires 'metrics'",
                details=(
                    "When query_mode is 'aggregate' (default), you must specify "
                    "at least one metric with an aggregate function"
                ),
                suggestions=[
                    "Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
                    "Or use query_mode='raw' with 'columns' for individual rows",
                ],
                error_code="MISSING_AGGREGATE_METRICS",
            )
        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        """Collect column references from a handlebars config.

        Returns an empty list unless ``config`` is a
        ``HandlebarsChartConfig``. Order: columns, metrics, groupby, then
        one reference per filter column; unset groups are skipped.
        """
        if not isinstance(config, HandlebarsChartConfig):
            return []
        collected: list[ColumnRef] = []
        for group in (config.columns, config.metrics, config.groupby):
            if group:
                collected.extend(group)
        collected.extend(
            ColumnRef(name=flt.column) for flt in config.filters or []
        )
        return collected

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        """Build form_data via ``map_handlebars_config``; ``dataset_id`` is unused."""
        form_data = map_handlebars_config(config)
        return form_data

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        """Compose a chart name from the handlebars summary and its filters."""
        subject = _handlebars_chart_what(config)
        filter_summary = _summarize_filters(getattr(config, "filters", None))
        return self._with_context(subject, filter_summary)

    def resolve_viz_type(self, config: Any) -> str:
        """Return the fixed viz_type ``"handlebars"``."""
        return "handlebars"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        """Return a new config whose names are canonicalized for the dataset.

        Entries under ``columns``, ``metrics`` and ``groupby`` are rewritten
        in the dumped dict: saved metrics resolve as metric names, entries
        with a ``sql_expression`` are left alone, everything else resolves
        as a column name. Filters are handled by
        ``DatasetValidator.normalize_filters``.
        """
        data = config.model_dump()
        for key in ("columns", "metrics", "groupby"):
            for ref in data.get(key) or []:
                # saved_metric is checked first, so it wins over sql_expression
                if ref.get("saved_metric"):
                    ref["name"] = DatasetValidator.get_canonical_metric_name(
                        ref["name"], dataset_context
                    )
                elif not ref.get("sql_expression"):
                    ref["name"] = DatasetValidator.get_canonical_column_name(
                        ref["name"], dataset_context
                    )
        DatasetValidator.normalize_filters(data, dataset_context)
        return HandlebarsChartConfig.model_validate(data)

    def schema_error_hint(self) -> ChartGenerationError | None:
        """Return the generic error shown when handlebars schema validation fails."""
        details = (
            "The handlebars chart configuration is missing "
            "required fields or has invalid structure"
        )
        suggestions = [
            "Ensure 'handlebars_template' is a non-empty string",
            "For aggregate mode: add 'metrics' with aggregate functions",
            "For raw mode: set 'query_mode': 'raw' and add 'columns'",
            "Example: {'chart_type': 'handlebars', "
            "'handlebars_template': "
            "'<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>', "
            "'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
        ]
        return ChartGenerationError(
            error_type="handlebars_validation_error",
            message="Handlebars chart configuration validation failed",
            details=details,
            suggestions=suggestions,
            error_code="HANDLEBARS_VALIDATION_ERROR",
        )

View File

@@ -0,0 +1,173 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mixed timeseries chart type plugin."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, ClassVar
from superset.mcp_service.chart.chart_utils import (
_mixed_timeseries_what,
_summarize_filters,
map_mixed_timeseries_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, MixedTimeseriesChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class MixedTimeseriesChartPlugin(BaseChartPlugin):
    """Plugin for mixed_timeseries chart type."""

    chart_type = "mixed_timeseries"
    display_name = "Mixed Timeseries"
    native_viz_types: ClassVar[Mapping[str, str]] = {
        "mixed_timeseries": "Mixed Timeseries Chart",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        """Run structural checks on a raw mixed timeseries config dict.

        First reports every missing required field at once: an x-axis
        (``x`` or ``x_axis``), primary metrics (``y`` or ``metrics``) and
        secondary metrics (``y_secondary`` or ``metrics_b``). Then checks
        that each metric group is a list.

        Returns:
            The first applicable error, or ``None`` when the config passes.
        """
        presence_checks = (
            (
                "x" not in config and "x_axis" not in config,
                "'x' (X-axis temporal column)",
            ),
            (
                not config.get("y") and not config.get("metrics"),
                "'y' (primary Y-axis metrics)",
            ),
            (
                not config.get("y_secondary") and not config.get("metrics_b"),
                "'y_secondary' (secondary Y-axis metrics)",
            ),
        )
        missing_fields = [label for absent, label in presence_checks if absent]
        if missing_fields:
            return ChartGenerationError(
                error_type="missing_mixed_timeseries_fields",
                message=(
                    f"Mixed timeseries chart missing required fields: "
                    f"{', '.join(missing_fields)}"
                ),
                details=(
                    "Mixed timeseries charts require an x-axis, primary metrics, "
                    "and secondary metrics"
                ),
                suggestions=[
                    "Add 'x' field: {'name': 'date_column'}",
                    "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
                    "Add 'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]",
                    "Optional: 'primary_kind' and 'secondary_kind' for chart types",
                ],
                error_code="MISSING_MIXED_TIMESERIES_FIELDS",
            )

        # The presence check above accepts either spelling of each metric
        # group, so the shape check must look at whichever one was actually
        # supplied. Reading only 'y' / 'y_secondary' would let a non-list
        # 'metrics' / 'metrics_b' through and would reject an explicit
        # None next to a valid alternate key.
        metric_groups = (("y", "metrics"), ("y_secondary", "metrics_b"))
        for field_name, alternate_key in metric_groups:
            supplied = config.get(field_name) or config.get(alternate_key) or []
            if not isinstance(supplied, list):
                return ChartGenerationError(
                    error_type=f"invalid_{field_name}_format",
                    message=f"'{field_name}' must be a list of metrics",
                    details=(
                        f"The '{field_name}' field must be an array of metric "
                        "specifications"
                    ),
                    suggestions=[
                        f"Wrap in array: '{field_name}': "
                        "[{'name': 'col', 'aggregate': 'SUM'}]",
                    ],
                    error_code=f"INVALID_{field_name.upper()}_FORMAT",
                )
        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        """Collect column references from a mixed timeseries config.

        Returns an empty list unless ``config`` is a
        ``MixedTimeseriesChartConfig``. Order: x, y, y_secondary, group_by,
        group_by_secondary, then one reference per filter column.
        """
        if not isinstance(config, MixedTimeseriesChartConfig):
            return []
        refs: list[ColumnRef] = [config.x, *config.y, *config.y_secondary]
        for optional_group in (config.group_by, config.group_by_secondary):
            if optional_group:
                refs.extend(optional_group)
        refs.extend(ColumnRef(name=flt.column) for flt in config.filters or [])
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        """Build form_data via ``map_mixed_timeseries_config``, passing ``dataset_id``."""
        return map_mixed_timeseries_config(config, dataset_id=dataset_id)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        """Compose a chart name from the chart summary and its filters."""
        what = _mixed_timeseries_what(config)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        """Return the fixed viz_type ``"mixed_timeseries"``."""
        return "mixed_timeseries"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        """Return a new config whose names are canonicalized for the dataset.

        ``x`` always resolves as a column name. Entries under ``y``,
        ``y_secondary``, ``group_by`` and ``group_by_secondary`` are skipped
        when they carry a ``sql_expression``, resolve as metric names when
        ``saved_metric`` is set, and resolve as column names otherwise.
        Filters are handled by ``DatasetValidator.normalize_filters``.
        """
        config_dict = config.model_dump()
        canonical_column = DatasetValidator.get_canonical_column_name
        canonical_metric = DatasetValidator.get_canonical_metric_name

        x_ref = config_dict.get("x")
        if x_ref:
            x_ref["name"] = canonical_column(x_ref["name"], dataset_context)

        for key in ("y", "y_secondary", "group_by", "group_by_secondary"):
            for ref in config_dict.get(key) or []:
                if ref.get("sql_expression"):
                    continue
                resolve_name = (
                    canonical_metric if ref.get("saved_metric") else canonical_column
                )
                ref["name"] = resolve_name(ref["name"], dataset_context)

        DatasetValidator.normalize_filters(config_dict, dataset_context)
        return MixedTimeseriesChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        """Return the generic error shown when schema validation fails."""
        return ChartGenerationError(
            error_type="mixed_timeseries_validation_error",
            message="Mixed timeseries chart configuration validation failed",
            details=(
                "The mixed timeseries configuration is missing "
                "required fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'x' field has 'name' for the time axis column",
                "Ensure 'y' is an array of primary-axis metrics",
                "Ensure 'y_secondary' is an array of secondary-axis metrics",
                "Example: {'chart_type': 'mixed_timeseries', "
                "'x': {'name': 'order_date'}, "
                "'y': [{'name': 'revenue', 'aggregate': 'SUM'}], "
                "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}",
            ],
            error_code="MIXED_TIMESERIES_VALIDATION_ERROR",
        )

View File

@@ -0,0 +1,140 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pie chart type plugin."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, ClassVar
from superset.mcp_service.chart.chart_utils import (
_pie_chart_what,
_summarize_filters,
map_pie_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PieChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class PieChartPlugin(BaseChartPlugin):
    """Plugin for pie chart type."""

    chart_type = "pie"
    display_name = "Pie / Donut Chart"
    native_viz_types: ClassVar[Mapping[str, str]] = {
        "pie": "Pie Chart",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        """Check that a raw pie config names a dimension and a metric.

        The dimension may be given as ``dimension`` or ``groupby``. All
        missing fields are reported together in one error; ``None`` is
        returned when both are present.
        """
        presence_checks = (
            (
                "dimension" not in config and "groupby" not in config,
                "'dimension' (category column for slices)",
            ),
            ("metric" not in config, "'metric' (value metric for slice sizes)"),
        )
        missing = [label for absent, label in presence_checks if absent]
        if not missing:
            return None
        return ChartGenerationError(
            error_type="missing_pie_fields",
            message=(f"Pie chart missing required fields: {', '.join(missing)}"),
            details=(
                "Pie charts require a dimension (categories) and a metric (values)"
            ),
            suggestions=[
                "Add 'dimension' field: {'name': 'category_column'}",
                "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
                "Example: {'chart_type': 'pie', 'dimension': {'name': 'product'}, "
                "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
            ],
            error_code="MISSING_PIE_FIELDS",
        )

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        """Collect column references from a pie config.

        Returns an empty list unless ``config`` is a ``PieChartConfig``.
        Order: dimension, metric, then one reference per filter column.
        """
        if not isinstance(config, PieChartConfig):
            return []
        collected: list[ColumnRef] = [config.dimension, config.metric]
        collected.extend(
            ColumnRef(name=flt.column) for flt in config.filters or []
        )
        return collected

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        """Build form_data via ``map_pie_config``; ``dataset_id`` is unused."""
        form_data = map_pie_config(config)
        return form_data

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        """Compose a chart name from the pie summary and its filters."""
        subject = _pie_chart_what(config)
        filter_summary = _summarize_filters(config.filters)
        return self._with_context(subject, filter_summary)

    def resolve_viz_type(self, config: Any) -> str:
        """Return the fixed viz_type ``"pie"``."""
        return "pie"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        """Return a new config whose names are canonicalized for the dataset.

        The dimension resolves as a column name unless it has a
        ``sql_expression`` or ``saved_metric``. The metric is left alone
        when it has a ``sql_expression``, resolves as a metric name when
        ``saved_metric`` is set, and resolves as a column name otherwise.
        Filters are handled by ``DatasetValidator.normalize_filters``.
        """
        data = config.model_dump()

        dimension = data.get("dimension")
        if dimension:
            is_plain_column = not (
                dimension.get("sql_expression") or dimension.get("saved_metric")
            )
            if is_plain_column:
                dimension["name"] = DatasetValidator.get_canonical_column_name(
                    dimension["name"], dataset_context
                )

        metric = data.get("metric")
        if metric and not metric.get("sql_expression"):
            resolve_name = (
                DatasetValidator.get_canonical_metric_name
                if metric.get("saved_metric")
                else DatasetValidator.get_canonical_column_name
            )
            metric["name"] = resolve_name(metric["name"], dataset_context)

        DatasetValidator.normalize_filters(data, dataset_context)
        return PieChartConfig.model_validate(data)

    def schema_error_hint(self) -> ChartGenerationError | None:
        """Return the generic error shown when pie schema validation fails."""
        details = (
            "The pie chart configuration is missing required "
            "fields or has invalid structure"
        )
        suggestions = [
            "Ensure 'dimension' field has 'name' for the slice label",
            "Ensure 'metric' field has 'name' and 'aggregate'",
            "Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, "
            "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
        ]
        return ChartGenerationError(
            error_type="pie_validation_error",
            message="Pie chart configuration validation failed",
            details=details,
            suggestions=suggestions,
            error_code="PIE_VALIDATION_ERROR",
        )

View File

@@ -0,0 +1,164 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pivot table chart type plugin."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, ClassVar
from superset.mcp_service.chart.chart_utils import (
_pivot_table_what,
_summarize_filters,
map_pivot_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PivotTableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class PivotTableChartPlugin(BaseChartPlugin):
    """Plugin for pivot_table chart type."""

    chart_type = "pivot_table"
    display_name = "Pivot Table"
    native_viz_types: ClassVar[Mapping[str, str]] = {
        "pivot_table_v2": "Pivot Table",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        """Run structural checks on a raw pivot table config dict.

        Rows may be given as ``rows``, ``groupby`` or ``dimension`` (first
        truthy one wins). Missing rows and metrics are reported together;
        afterwards rows and metrics are each required to be lists. Returns
        ``None`` when every check passes.
        """
        row_spec = (
            config.get("rows") or config.get("groupby") or config.get("dimension")
        )
        presence_checks = (
            (not row_spec, "'rows' (row grouping columns)"),
            (not config.get("metrics"), "'metrics' (aggregation metrics)"),
        )
        missing = [label for absent, label in presence_checks if absent]
        if missing:
            return ChartGenerationError(
                error_type="missing_pivot_fields",
                message=(
                    f"Pivot table missing required fields: {', '.join(missing)}"
                ),
                details="Pivot tables require row groupings and metrics",
                suggestions=[
                    "Add 'rows' field: [{'name': 'category'}]",
                    "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
                    "Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
                ],
                error_code="MISSING_PIVOT_FIELDS",
            )

        # row_spec is guaranteed truthy here, so only its type is in question
        if not isinstance(row_spec, list):
            return ChartGenerationError(
                error_type="invalid_rows_format",
                message="Rows must be a list of columns",
                details="The 'rows' field must be an array of column specifications",
                suggestions=[
                    "Wrap row columns in array: 'rows': [{'name': 'category'}]",
                ],
                error_code="INVALID_ROWS_FORMAT",
            )

        metrics_spec = config.get("metrics", [])
        if not isinstance(metrics_spec, list):
            return ChartGenerationError(
                error_type="invalid_metrics_format",
                message="Metrics must be a list",
                details="The 'metrics' field must be an array of metric specifications",
                suggestions=[
                    "Wrap metrics in array: 'metrics': [{'name': 'sales', "
                    "'aggregate': 'SUM'}]",
                ],
                error_code="INVALID_METRICS_FORMAT",
            )
        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        """Collect column references from a pivot table config.

        Returns an empty list unless ``config`` is a
        ``PivotTableChartConfig``. Order: rows, metrics, columns (when set),
        then one reference per filter column.
        """
        if not isinstance(config, PivotTableChartConfig):
            return []
        collected: list[ColumnRef] = [*config.rows, *config.metrics]
        if config.columns:
            collected.extend(config.columns)
        collected.extend(
            ColumnRef(name=flt.column) for flt in config.filters or []
        )
        return collected

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        """Build form_data via ``map_pivot_table_config``; ``dataset_id`` is unused."""
        form_data = map_pivot_table_config(config)
        return form_data

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        """Compose a chart name from the pivot summary and its filters."""
        subject = _pivot_table_what(config)
        filter_summary = _summarize_filters(config.filters)
        return self._with_context(subject, filter_summary)

    def resolve_viz_type(self, config: Any) -> str:
        """Return the fixed viz_type ``"pivot_table_v2"``."""
        return "pivot_table_v2"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        """Return a new config whose names are canonicalized for the dataset.

        Entries under ``rows``, ``metrics`` and ``columns`` are skipped when
        they carry a ``sql_expression``, resolve as metric names when
        ``saved_metric`` is set, and resolve as column names otherwise.
        Filters are handled by ``DatasetValidator.normalize_filters``.
        """
        data = config.model_dump()
        for key in ("rows", "metrics", "columns"):
            for ref in data.get(key) or []:
                if ref.get("sql_expression"):
                    continue
                resolve_name = (
                    DatasetValidator.get_canonical_metric_name
                    if ref.get("saved_metric")
                    else DatasetValidator.get_canonical_column_name
                )
                ref["name"] = resolve_name(ref["name"], dataset_context)
        DatasetValidator.normalize_filters(data, dataset_context)
        return PivotTableChartConfig.model_validate(data)

    def schema_error_hint(self) -> ChartGenerationError | None:
        """Return the generic error shown when pivot schema validation fails."""
        details = (
            "The pivot table configuration is missing required "
            "fields or has invalid structure"
        )
        suggestions = [
            "Ensure 'rows' field is an array of column specs",
            "Ensure 'metrics' field is an array with aggregate funcs",
            "Optional: add 'columns' for column grouping",
            "Example: {'chart_type': 'pivot_table', "
            "'rows': [{'name': 'region'}], "
            "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}",
        ]
        return ChartGenerationError(
            error_type="pivot_table_validation_error",
            message="Pivot table configuration validation failed",
            details=details,
            suggestions=suggestions,
            error_code="PIVOT_TABLE_VALIDATION_ERROR",
        )

View File

@@ -0,0 +1,136 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Table chart type plugin."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, ClassVar
from superset.mcp_service.chart.chart_utils import (
_summarize_filters,
_table_chart_what,
map_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, TableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class TableChartPlugin(BaseChartPlugin):
    """Plugin for table chart type."""

    chart_type = "table"
    display_name = "Table"
    native_viz_types: ClassVar[Mapping[str, str]] = {
        "table": "Table",
        "ag-grid-table": "Interactive Table",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        """Check that a raw table config supplies a list of columns.

        Columns may be given as ``columns``, ``all_columns`` or ``groupby``
        (first truthy one wins). Returns an error when none is supplied or
        the supplied value is not a list, otherwise ``None``.
        """
        column_spec = (
            config.get("columns") or config.get("all_columns") or config.get("groupby")
        )
        if not column_spec:
            return ChartGenerationError(
                error_type="missing_columns",
                message="Table chart missing required field: columns",
                details=(
                    "Table charts require a 'columns' array to specify which "
                    "columns to display"
                ),
                suggestions=[
                    "Add 'columns' field with array of column specifications",
                    "Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
                    "'aggregate': 'SUM'}]",
                    "Each column can have optional 'aggregate' for metrics",
                ],
                error_code="MISSING_COLUMNS",
            )
        if isinstance(column_spec, list):
            return None
        return ChartGenerationError(
            error_type="invalid_columns_format",
            message="Columns must be a list",
            details="The 'columns' field must be an array of column specifications",
            suggestions=[
                "Ensure columns is an array: 'columns': [...]",
                "Each column should be an object with 'name' field",
            ],
            error_code="INVALID_COLUMNS_FORMAT",
        )

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        """Collect column references from a table config.

        Returns an empty list unless ``config`` is a ``TableChartConfig``.
        Order: the configured columns, then one reference per filter column.
        """
        if not isinstance(config, TableChartConfig):
            return []
        collected: list[ColumnRef] = [*config.columns]
        collected.extend(
            ColumnRef(name=flt.column) for flt in config.filters or []
        )
        return collected

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        """Build form_data via ``map_table_config``; ``dataset_id`` is unused."""
        form_data = map_table_config(config)
        return form_data

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        """Compose a chart name from the table summary and its filters."""
        subject = _table_chart_what(config, dataset_name)
        filter_summary = _summarize_filters(config.filters)
        return self._with_context(subject, filter_summary)

    def resolve_viz_type(self, config: Any) -> str:
        """Return the config's ``viz_type``, or ``"table"`` when it has none."""
        viz_type = getattr(config, "viz_type", "table")
        return viz_type

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        """Return a new config whose names are canonicalized for the dataset.

        Each entry under ``columns`` resolves as a metric name when
        ``saved_metric`` is set, is left alone when it only has a
        ``sql_expression``, and resolves as a column name otherwise.
        Filters are handled by ``DatasetValidator.normalize_filters``.
        """
        data = config.model_dump()
        canonical_column = DatasetValidator.get_canonical_column_name
        canonical_metric = DatasetValidator.get_canonical_metric_name
        for ref in data.get("columns") or []:
            if ref.get("saved_metric"):
                resolve_name = canonical_metric
            elif ref.get("sql_expression"):
                continue
            else:
                resolve_name = canonical_column
            ref["name"] = resolve_name(ref["name"], dataset_context)
        DatasetValidator.normalize_filters(data, dataset_context)
        return TableChartConfig.model_validate(data)

    def schema_error_hint(self) -> ChartGenerationError | None:
        """Return the generic error shown when table schema validation fails."""
        details = (
            "The table chart configuration is missing required "
            "fields or has invalid structure"
        )
        suggestions = [
            "Ensure 'columns' field is an array of column specifications",
            "Each column needs {'name': 'column_name'}",
            "Optional: add 'aggregate' for metrics",
            "Example: 'columns': [{'name': 'product'}, "
            "{'name': 'sales', 'aggregate': 'SUM'}]",
        ]
        return ChartGenerationError(
            error_type="table_validation_error",
            message="Table chart configuration validation failed",
            details=details,
            suggestions=suggestions,
            error_code="TABLE_VALIDATION_ERROR",
        )

View File

@@ -0,0 +1,202 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""XY chart type plugin (line, bar, area, scatter)."""
from __future__ import annotations
import logging
from collections.abc import Mapping
from typing import Any, ClassVar
from superset.mcp_service.chart.chart_utils import (
_xy_chart_context,
_xy_chart_what,
map_xy_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.chart.validation.runtime.cardinality_validator import (
CardinalityValidator,
)
from superset.mcp_service.chart.validation.runtime.format_validator import (
FormatTypeValidator,
)
from superset.mcp_service.common.error_schemas import ChartGenerationError
logger = logging.getLogger(__name__)
class XYChartPlugin(BaseChartPlugin):
"""Plugin for xy chart type (line, bar, area, scatter)."""
chart_type = "xy"
display_name = "Line / Bar / Area / Scatter Chart"
native_viz_types: ClassVar[Mapping[str, str]] = {
"echarts_timeseries_line": "Line Chart",
"echarts_timeseries_bar": "Bar Chart",
"echarts_area": "Area Chart",
"echarts_timeseries_scatter": "Scatter Plot",
}
def pre_validate(
    self,
    config: dict[str, Any],
) -> ChartGenerationError | None:
    """Run structural checks on a raw XY config dict.

    The Y-axis metrics may be given as ``y`` or ``metrics`` (first truthy
    one wins). ``x`` is not checked here.

    Returns:
        ``MISSING_XY_FIELDS`` when no metrics are supplied,
        ``INVALID_Y_FORMAT`` when the supplied metrics are not a list,
        otherwise ``None``.
    """
    # x is optional — defaults to dataset's main_dttm_col in map_xy_config
    # Resolve the metrics once so the presence check and the shape check
    # look at the same value. Checking the shape of 'y' alone would let a
    # non-list 'metrics' through and would reject an explicit None 'y'
    # that sits next to a valid 'metrics' list.
    y_metrics = config.get("y") or config.get("metrics")
    if not y_metrics:
        return ChartGenerationError(
            error_type="missing_xy_fields",
            message="XY chart missing required field: 'y' (Y-axis metrics)",
            details=(
                "XY charts require Y-axis (metrics) specifications. "
                "X-axis is optional and defaults to the dataset's primary "
                "datetime column when omitted."
            ),
            suggestions=[
                "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}]",
                "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
                "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
            ],
            error_code="MISSING_XY_FIELDS",
        )
    if not isinstance(y_metrics, list):
        return ChartGenerationError(
            error_type="invalid_y_format",
            message="Y-axis must be a list of metrics",
            details="The 'y' field must be an array of metric specifications",
            suggestions=[
                "Wrap Y-axis metric in array: 'y': [{'name': 'column', "
                "'aggregate': 'SUM'}]",
                "Multiple metrics supported: 'y': [metric1, metric2, ...]",
            ],
            error_code="INVALID_Y_FORMAT",
        )
    return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
    """Collect column references from an XY config.

    Returns an empty list unless ``config`` is an ``XYChartConfig``.
    Order: x (when not ``None``), the y metrics, group_by (when set), then
    one reference per filter column.
    """
    if not isinstance(config, XYChartConfig):
        return []
    collected: list[ColumnRef] = [] if config.x is None else [config.x]
    collected.extend(config.y)
    if config.group_by:
        collected.extend(config.group_by)
    collected.extend(ColumnRef(name=flt.column) for flt in config.filters or [])
    return collected
def to_form_data(
    self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
    """Build form_data via ``map_xy_config``, forwarding ``dataset_id``."""
    form_data = map_xy_config(config, dataset_id=dataset_id)
    return form_data
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
    """Return a new config whose names are canonicalized for the dataset.

    ``x`` and every ``group_by`` entry resolve as column names. Each ``y``
    entry is skipped when it has a ``sql_expression``, resolves as a metric
    name when ``saved_metric`` is set, and resolves as a column name
    otherwise. Filters are handled by ``DatasetValidator.normalize_filters``.
    """
    data = config.model_dump()
    canonical_column = DatasetValidator.get_canonical_column_name
    canonical_metric = DatasetValidator.get_canonical_metric_name

    x_ref = data.get("x")
    if x_ref:
        x_ref["name"] = canonical_column(x_ref["name"], dataset_context)

    for metric in data.get("y") or []:
        if metric.get("sql_expression"):
            # sql_expression metrics have no underlying column
            continue
        resolve_name = (
            canonical_metric if metric.get("saved_metric") else canonical_column
        )
        metric["name"] = resolve_name(metric["name"], dataset_context)

    for dimension in data.get("group_by") or []:
        dimension["name"] = canonical_column(dimension["name"], dataset_context)

    DatasetValidator.normalize_filters(data, dataset_context)
    return XYChartConfig.model_validate(data)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
    """Build a chart name from the XY config.

    Combines the "what" and "context" parts derived from ``config`` via
    ``self._with_context``. ``dataset_name`` is accepted but not used here.
    """
    return self._with_context(_xy_chart_what(config), _xy_chart_context(config))
def resolve_viz_type(self, config: Any) -> str:
    """Map the config's ``kind`` to the corresponding viz_type string.

    A config without a ``kind`` attribute is treated as ``"line"``; any
    kind that is not in the table also resolves to the line viz_type.
    """
    fallback = "echarts_timeseries_line"
    viz_by_kind = {
        "line": "echarts_timeseries_line",
        "bar": "echarts_timeseries_bar",
        "area": "echarts_area",
        "scatter": "echarts_timeseries_scatter",
    }
    chart_kind = getattr(config, "kind", "line")
    return viz_by_kind.get(chart_kind, fallback)
def get_runtime_warnings(
    self, config: Any, dataset_id: int | str
) -> tuple[list[str], list[str]]:
    """Gather non-blocking warnings and suggestions for an XY chart.

    Two independent checks run, each inside its own ``try`` so that a
    failure in one is logged and does not prevent the other:

    1. format compatibility (``FormatTypeValidator``), and
    2. cardinality of the x column (``CardinalityValidator``), only when
       an x column with a name is present.

    Args:
        config: The chart configuration.
        dataset_id: Dataset identifier forwarded to the cardinality check.

    Returns:
        ``(warnings, suggestions)``; both empty when ``config`` is not an
        ``XYChartConfig``.
    """
    if not isinstance(config, XYChartConfig):
        return [], []

    collected_warnings: list[str] = []
    collected_suggestions: list[str] = []

    # --- Check 1: format/type compatibility ---------------------------------
    try:
        _, format_issues = FormatTypeValidator.validate_format_compatibility(config)
        if format_issues:
            collected_warnings.extend(format_issues)
    except Exception as exc:  # noqa: BLE001 — non-blocking warning path
        logger.warning("XY format validation failed: %s", exc)

    # --- Check 2: x-axis cardinality ----------------------------------------
    try:
        kind = config.kind
        first_group = config.group_by[0].name if config.group_by else None
        x_ref = config.x
        if x_ref is not None and x_ref.name is not None:
            passed, info = CardinalityValidator.check_cardinality(
                dataset_id=dataset_id,
                x_column=x_ref.name,
                chart_type=kind,
                group_by_column=first_group,
            )
            # Only a failed check with details contributes messages.
            if not passed and info:
                collected_warnings.extend(info.get("warnings", []))
                collected_suggestions.extend(info.get("suggestions", []))
    except Exception as exc:  # noqa: BLE001 — DB queries may raise infra errors
        logger.warning("XY cardinality validation failed: %s", exc)

    return collected_warnings, collected_suggestions
def schema_error_hint(self) -> ChartGenerationError | None:
    """Return the generic error shown when an XY config fails schema validation."""
    explanation = (
        "The XY chart configuration is missing required "
        "fields or has invalid structure"
    )
    hints = [
        "Note: 'x' is optional and defaults to the dataset's primary "
        "datetime column",
        "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]",
        "Check that all column names are strings",
        "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX",
    ]
    return ChartGenerationError(
        error_type="xy_validation_error",
        message="XY chart configuration validation failed",
        details=explanation,
        suggestions=hints,
        error_code="XY_VALIDATION_ERROR",
    )

View File

@@ -0,0 +1,279 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ChartTypeRegistry — central registry mapping chart_type strings to plugins.
Replaces the four previously-scattered dispatch locations:
- schema_validator.py: chart_type_validators dict
- dataset_validator.py: isinstance branches in _extract_column_references()
- chart_utils.py: if/elif chain in map_config_to_form_data()
- dataset_validator.py: isinstance branches in normalize_column_names()
Usage::
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get("xy")
if plugin is None:
raise ValueError("Unknown chart type: xy")
form_data = plugin.to_form_data(config, dataset_id)
"""
from __future__ import annotations
import logging
import sys
import threading
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from superset.mcp_service.chart.plugin import ChartTypePlugin
logger = logging.getLogger(__name__)
# chart_type string -> plugin instance; written by register().
_REGISTRY: dict[str, "ChartTypePlugin"] = {}
# Set to True by _ensure_plugins_loaded() once the built-in plugins package
# has been imported successfully.
_plugins_loaded = False
# Set to True by _ensure_plugins_loaded() when that import raised; the import
# is then not attempted again.
_plugins_load_failed = False
# Guards registry mutation and the lazy plugin import. Re-entrant, presumably
# so plugin modules can call register() while the import inside
# _ensure_plugins_loaded() still holds the lock — TODO confirm.
_plugins_lock = threading.RLock()
# ---------------------------------------------------------------------------
# Plugin filter — replaced atomically by configure() at app startup.
# Default: all registered plugins visible (no disabled set, no callable).
# ---------------------------------------------------------------------------
PluginEnabledFunc = Callable[[str], bool]
@dataclass(frozen=True)
class _PluginFilterConfig:
disabled_plugins: frozenset[str] = field(default_factory=frozenset)
enabled_func: PluginEnabledFunc | None = None
_filter_config: _PluginFilterConfig = _PluginFilterConfig()
def _ensure_plugins_loaded() -> None:
    """Import the built-in plugins package once so that _REGISTRY is populated.

    Invoked at the start of every registry lookup, so callers that import
    this module directly (tests, chart_utils, validators) still see the
    built-in plugins without importing app.py first.

    If the import raises, the registry is restored to its pre-import
    contents, the failure is logged, and no further attempt is made.
    """
    global _plugins_loaded, _plugins_load_failed

    # Fast path: outcome already decided, no lock needed.
    if _plugins_loaded or _plugins_load_failed:
        return

    with _plugins_lock:
        # Re-check under the lock: another thread may have finished first.
        if _plugins_loaded or _plugins_load_failed:
            return

        snapshot = dict(_REGISTRY)
        try:
            import superset.mcp_service.chart.plugins  # noqa: F401
        except Exception:  # noqa: BLE001 — plugin import may raise anything
            # Roll back any partial registrations made before the failure.
            _REGISTRY.clear()
            _REGISTRY.update(snapshot)
            _plugins_load_failed = True
            logger.exception(
                "Failed to load built-in chart type plugins; "
                "further lookups will return None"
            )
        else:
            _plugins_loaded = True
def configure(
    disabled: Iterable[str] | None = None,
    enabled_func: PluginEnabledFunc | None = None,
) -> None:
    """Set runtime plugin filters. Called once during app initialization.

    Replaces the filter config with a single assignment so concurrent
    readers always observe a consistent (disabled_plugins, enabled_func) pair.

    Args:
        disabled: chart_type strings to suppress. Accepts any iterable (set,
            frozenset, list, tuple). A bare string is treated as a single
            chart type rather than being split into characters. Ignored when
            enabled_func is provided.
        enabled_func: callable(chart_type) -> bool. When set, overrides
            ``disabled``. Must be cheap and in-process — no network I/O per
            call. On exception the registry fails *closed* (plugin hidden).

    Raises:
        TypeError: If ``enabled_func`` is neither None nor callable.
    """
    global _filter_config

    if enabled_func is not None and not callable(enabled_func):
        raise TypeError("enabled_func must be callable or None")

    if isinstance(disabled, str):
        # A str is itself an iterable of characters: frozenset("xy") would
        # disable 'x' and 'y' instead of 'xy'. Wrap it as one entry; an
        # empty string keeps meaning "nothing disabled".
        disabled = (disabled,) if disabled else ()

    new_config = _PluginFilterConfig(
        disabled_plugins=frozenset(disabled or ()),
        enabled_func=enabled_func,
    )
    _filter_config = new_config

    if new_config.disabled_plugins:
        logger.info(
            "MCP chart plugins disabled: %s", sorted(new_config.disabled_plugins)
        )
    if new_config.enabled_func is not None:
        logger.info(
            "MCP chart plugin dynamic filter configured: %r", new_config.enabled_func
        )
def _is_plugin_enabled(chart_type: str) -> bool:
    """Tell whether ``chart_type`` passes the current runtime filter.

    When a dynamic ``enabled_func`` is configured it decides alone; if it
    raises, the plugin is treated as disabled. Otherwise the static
    ``disabled_plugins`` set is consulted.
    """
    active = _filter_config  # read once — atomic reference in CPython
    predicate = active.enabled_func

    if predicate is None:
        return chart_type not in active.disabled_plugins

    try:
        return bool(predicate(chart_type))
    except Exception:  # noqa: BLE001 — operator-supplied callable may raise anything
        logger.warning(
            "MCP_CHART_PLUGIN_ENABLED_FUNC raised for chart_type=%r; "
            "failing closed (plugin hidden)",
            chart_type,
            exc_info=True,
        )
        return False
def register(plugin: "ChartTypePlugin") -> None:
    """Add ``plugin`` to the global registry under its ``chart_type``.

    An existing entry for the same chart_type is replaced (with a warning).
    A warning is also logged when the new plugin declares native viz types
    that a different registered plugin already declares.

    Raises:
        ValueError: If the plugin's ``chart_type`` is empty.
    """
    if not plugin.chart_type:
        raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type")

    with _plugins_lock:
        if plugin.chart_type in _REGISTRY:
            logger.warning(
                "Overwriting existing plugin for chart_type=%r", plugin.chart_type
            )

        for existing in _REGISTRY.values():
            if existing.chart_type != plugin.chart_type:
                overlap = plugin.native_viz_types.keys() & existing.native_viz_types
                if overlap:
                    # display_name_for_viz_type() resolves to the first plugin in
                    # iteration order, making the later registration unreachable.
                    logger.warning(
                        "Plugin %r declares native_viz_types %s already claimed by "
                        "plugin %r; viz_type display-name lookups will resolve to "
                        "the earlier registration",
                        plugin.chart_type,
                        sorted(overlap),
                        existing.chart_type,
                    )

        _REGISTRY[plugin.chart_type] = plugin
        logger.debug("Registered chart plugin: %r", plugin.chart_type)
def get(chart_type: str) -> "ChartTypePlugin | None":
    """Look up the plugin for ``chart_type``.

    Returns:
        The registered plugin, or None when the chart type is unknown or is
        currently filtered out.
    """
    _ensure_plugins_loaded()
    plugin = _REGISTRY.get(chart_type)
    if plugin is None:
        return None
    if not _is_plugin_enabled(chart_type):
        return None
    return plugin
def all_types() -> list[str]:
    """List the chart types that are registered and enabled, in insertion order."""
    _ensure_plugins_loaded()
    return list(filter(_is_plugin_enabled, _REGISTRY))
def is_registered(chart_type: str) -> bool:
    """Tell whether a plugin exists for ``chart_type``, enabled or not.

    This lets callers distinguish an unknown chart type from a disabled one;
    is_enabled() additionally applies the runtime filter.
    """
    _ensure_plugins_loaded()
    known = chart_type in _REGISTRY
    return known
def is_enabled(chart_type: str) -> bool:
    """Tell whether ``chart_type`` is both registered and currently enabled."""
    _ensure_plugins_loaded()
    if chart_type not in _REGISTRY:
        return False
    return _is_plugin_enabled(chart_type)
def display_name_for_viz_type(viz_type: str) -> str | None:
    """Translate a Superset-internal viz_type into a user-facing name.

    Every registered plugin's ``native_viz_types`` mapping is consulted in
    registration order; the first non-None match wins.

    Returns:
        The display name, or None if no plugin recognises the viz_type.

    Example::

        display_name_for_viz_type("echarts_timeseries_line")  # "Line Chart"
        display_name_for_viz_type("pivot_table_v2")           # "Pivot Table"
        display_name_for_viz_type("unknown_type")             # None
    """
    _ensure_plugins_loaded()
    # Lazy generator: stops querying plugins as soon as one answers.
    candidates = (
        plugin.native_viz_types.get(viz_type) for plugin in _REGISTRY.values()
    )
    return next((name for name in candidates if name is not None), None)
def _reset_for_testing() -> None:
    """Restore every piece of registry state to its default value.

    Test-only helper. In production it would throw away all registered
    plugins and any runtime filter configuration.

    **Caller responsibility**: this function removes
    ``superset.mcp_service.chart.plugins`` from ``sys.modules`` and assigns
    module globals directly (``_REGISTRY``, ``_plugins_loaded``, etc.).
    pytest's ``monkeypatch`` fixture does NOT revert such direct
    assignments. Callers must either use ``monkeypatch.setattr`` for each
    global, or call ``_reset_for_testing()`` again in teardown to restore
    the clean state. See ``test_registry.py`` for the recommended
    ``monkeypatch.setattr`` isolation pattern.
    """
    global _REGISTRY, _plugins_loaded, _plugins_load_failed, _filter_config

    with _plugins_lock:
        # Fresh registry and default filter.
        _REGISTRY = {}
        _filter_config = _PluginFilterConfig()
        # Clear both load-state flags so the next lookup re-imports plugins.
        _plugins_loaded = False
        _plugins_load_failed = False
        # Drop the cached package so that re-import actually re-executes it.
        sys.modules.pop("superset.mcp_service.chart.plugins", None)
class _RegistryProxy:
"""Thin proxy exposing registry functions as instance methods."""
def get(self, chart_type: str) -> "ChartTypePlugin | None":
return get(chart_type)
def all_types(self) -> list[str]:
return all_types()
def is_registered(self, chart_type: str) -> bool:
return is_registered(chart_type)
def is_enabled(self, chart_type: str) -> bool:
return is_enabled(chart_type)
def display_name_for_viz_type(self, viz_type: str) -> str | None:
return display_name_for_viz_type(viz_type)
_PROXY = _RegistryProxy()
def get_registry() -> "_RegistryProxy":
"""Return the module-level registry proxy (convenience wrapper)."""
return _PROXY

View File

@@ -22,6 +22,7 @@ Pydantic schemas for chart-related responses
from __future__ import annotations
import difflib
import logging
from datetime import datetime
from typing import Annotated, Any, cast, Dict, List, Literal, Protocol
@@ -68,6 +69,8 @@ from superset.mcp_service.utils.sanitization import (
sanitize_user_input_with_changes,
)
logger = logging.getLogger(__name__)
class ChartLike(Protocol):
"""Protocol for chart-like objects with expected attributes."""
@@ -102,7 +105,15 @@ class ChartInfo(BaseModel):
id: int | None = Field(None, description="Chart ID")
slice_name: str | None = Field(None, description="Chart name")
viz_type: str | None = Field(None, description="Visualization type")
viz_type: str | None = Field(None, description="Visualization type (internal ID)")
chart_type_display_name: str | None = Field(
None,
description=(
"User-friendly chart type name (e.g. 'Line Chart', 'Pivot Table'). "
"Prefer this field when referring to chart types; "
"fall back to viz_type when this field is null."
),
)
datasource_name: str | None = Field(None, description="Datasource name")
datasource_type: str | None = Field(None, description="Datasource type")
url: str | None = Field(None, description="Chart explore page URL")
@@ -561,11 +572,27 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
# Extract structured filter information
filters_info = extract_filters_from_form_data(chart_form_data)
_viz_type = getattr(chart, "viz_type", None)
_display_name = None
if _viz_type:
try:
from superset.mcp_service.chart.registry import display_name_for_viz_type
except ImportError:
pass
else:
try:
_display_name = display_name_for_viz_type(_viz_type)
except Exception as exc: # noqa: BLE001 — third-party plugins may raise
logger.debug(
"Failed to resolve display name for viz_type=%r: %s", _viz_type, exc
)
return sanitize_chart_info_for_llm_context(
ChartInfo(
id=chart_id,
slice_name=getattr(chart, "slice_name", None),
viz_type=getattr(chart, "viz_type", None),
viz_type=_viz_type,
chart_type_display_name=_display_name,
datasource_name=getattr(chart, "datasource_name", None),
datasource_type=getattr(chart, "datasource_type", None),
url=chart_url,
@@ -747,7 +774,6 @@ class ColumnRef(BaseModel):
None,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
@@ -903,7 +929,6 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("column", "col"),
)
op: Literal[
@@ -935,7 +960,9 @@ class FilterConfig(BaseModel):
"""Sanitize filter column name to prevent injection attacks."""
# sanitize_user_input raises ValueError when allow_empty=False (default)
# so the return value is guaranteed to be a non-None str
return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value]
return sanitize_user_input(
v, "Filter column", max_length=255, check_sql_keywords=True
) # type: ignore[return-value]
@field_validator("value")
@classmethod
@@ -1052,8 +1079,13 @@ class PieChartConfig(UnknownFieldCheckMixin):
@model_validator(mode="after")
def reject_sql_expression_on_dimensions(self) -> "PieChartConfig":
"""sql_expression is metric-only; reject it on the dimension."""
"""sql_expression and saved_metric are metric-only; reject on the dimension."""
_reject_sql_expression_on_dimension(self.dimension, "dimension")
if self.dimension and self.dimension.saved_metric:
raise ValueError(
"dimension cannot use saved_metric=True; "
"saved metrics belong in the 'metric' field"
)
return self
@@ -1358,7 +1390,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
time_grain: TimeGrain | None = Field(
None,
@@ -1453,6 +1484,18 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
description="Filters to apply",
)
@field_validator("temporal_column")
@classmethod
def sanitize_temporal_column(cls, v: str | None) -> str | None:
"""Sanitize temporal column name to prevent SQL injection."""
return sanitize_user_input(
v,
"Temporal column",
max_length=255,
check_sql_keywords=True,
allow_empty=True,
)
@model_validator(mode="after")
def validate_trendline_fields(self) -> Self:
"""Validate trendline requires temporal column."""
@@ -1555,25 +1598,29 @@ class TableChartConfig(UnknownFieldCheckMixin):
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "TableChartConfig":
"""Ensure all column labels are unique."""
labels_seen = set()
# Key is (saved_metric, label) so a saved metric and a regular column
# with the same input name are not flagged as duplicates — saved metrics
# resolve to their actual casing from the dataset during normalization.
labels_seen: dict[tuple[bool, str], str] = {}
duplicates = []
for i, col in enumerate(self.columns):
# Generate the label that will be used (same logic as create_metric_object)
if col.sql_expression:
# SQL metrics carry a required label; use it verbatim.
label = col.label
label = col.label or ""
elif col.saved_metric:
label = col.label or col.name
label = col.label or col.name or ""
elif col.aggregate:
label = col.label or f"{col.aggregate}({col.name})"
else:
label = col.label or col.name
label = col.label or col.name or ""
if label in labels_seen:
key = (col.saved_metric, label)
if key in labels_seen:
duplicates.append(f"columns[{i}]: '{label}'")
else:
labels_seen.add(label)
labels_seen[key] = f"columns[{i}]"
if duplicates:
raise ValueError(
@@ -1703,24 +1750,28 @@ class XYChartConfig(UnknownFieldCheckMixin):
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "XYChartConfig":
"""Ensure all column labels are unique across x, y, and group_by."""
labels_seen: dict[str, str] = {}
# Key is (saved_metric, label) so a saved metric and a regular column
# with the same input name are not flagged as duplicates — saved metrics
# resolve to their actual casing from the dataset during normalization.
labels_seen: dict[tuple[bool, str], str] = {}
duplicates: list[str] = []
# Add x-axis label if present (x may be None, resolved later).
# The dimension validator rejects sql_expression on x, so name is set.
if self.x is not None:
x_label = self.x.label or self.x.name or ""
labels_seen[x_label] = "x"
labels_seen[(self.x.saved_metric, x_label)] = "x"
# Check Y-axis labels
for i, col in enumerate(self.y):
label = _metric_display_label(col)
if label in labels_seen:
key = (col.saved_metric, label)
if key in labels_seen:
duplicates.append(
f"y[{i}]: '{label}' (conflicts with {labels_seen[label]})"
f"y[{i}]: '{label}' (conflicts with {labels_seen[key]})"
)
else:
labels_seen[label] = f"y[{i}]"
labels_seen[key] = f"y[{i}]"
# Check group_by labels if present
if self.group_by:
@@ -1730,15 +1781,15 @@ class XYChartConfig(UnknownFieldCheckMixin):
# to prevent Superset "duplicate label" errors, so
# we allow them through validation.
continue
# group_by rejects sql_expression, so name is set.
group_label = col.label or col.name or ""
if group_label in labels_seen:
group_key = (col.saved_metric, group_label)
if group_key in labels_seen:
duplicates.append(
f"group_by[{i}]: '{group_label}' "
f"(conflicts with {labels_seen[group_label]})"
f"(conflicts with {labels_seen[group_key]})"
)
else:
labels_seen[group_label] = f"group_by[{i}]"
labels_seen[group_key] = f"group_by[{i}]"
if duplicates:
raise ValueError(

View File

@@ -105,18 +105,34 @@ async def generate_chart( # noqa: C901
- Set save_chart=True to permanently save the chart
- LLM clients MUST display returned chart URL to users
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- MUST include chart_type in config (either 'xy' or 'table')
- MUST include chart_type in config (one of: 'xy', 'table', 'pie',
'pivot_table', 'mixed_timeseries', 'handlebars', 'big_number')
IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines
which chart configuration schema to use. It MUST be included and MUST match the
other fields in your configuration:
- Use chart_type='xy' for charts with x and y axes (line, bar, area, scatter)
Required fields: x, y
Required fields: y (x is optional — defaults to dataset's primary datetime column)
- Use chart_type='table' for tabular visualizations
Required fields: columns
- Use chart_type='pie' for pie/donut charts
Required fields: dimension, metric
- Use chart_type='pivot_table' for pivot table visualizations
Required fields: rows, metrics
- Use chart_type='mixed_timeseries' for dual-axis time-series charts
Required fields: x, y, y_secondary
- Use chart_type='handlebars' for custom template-based visualizations
Required fields: handlebars_template
- Use chart_type='big_number' for single KPI metric displays
Required fields: metric
Example usage for XY chart:
```json
{

View File

@@ -170,6 +170,21 @@ def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None
# Update viz_type from cached form_data if present
if result.form_data and "viz_type" in result.form_data:
result.viz_type = result.form_data["viz_type"]
if result.viz_type:
try:
from superset.mcp_service.chart.registry import (
display_name_for_viz_type,
)
result.chart_type_display_name = display_name_for_viz_type(
result.viz_type
)
except Exception as exc: # noqa: BLE001
logger.debug(
"Failed to resolve display name for viz_type=%r: %s",
result.viz_type,
exc,
)
# Update filters from cached form_data
result.filters = extract_filters_from_form_data(result.form_data)

View File

@@ -471,6 +471,31 @@ async def update_chart( # noqa: C901
# config is already a typed ChartConfig | None (validated by Pydantic)
parsed_config = request.config
# Normalize column case to match dataset canonical names
# (mirrors generate_chart pipeline layer 4)
# When rebinding to a new dataset, normalize against the target dataset —
# not the chart's current datasource — so canonical names are resolved
# against the schema that will actually be used after the update.
effective_norm_dataset_id = (
request.dataset_id
if request.dataset_id is not None
else getattr(chart, "datasource_id", None)
)
if parsed_config is not None and effective_norm_dataset_id is not None:
from superset.mcp_service.chart.validation.dataset_validator import (
DatasetValidator,
NORMALIZATION_EXCEPTIONS,
)
try:
parsed_config = DatasetValidator.normalize_column_names(
parsed_config, effective_norm_dataset_id
)
except NORMALIZATION_EXCEPTIONS as e:
logger.warning(
"Column normalization failed for chart %s: %s", chart.id, e
)
if not request.generate_preview:
from superset.commands.chart.update import UpdateChartCommand

View File

@@ -22,17 +22,11 @@ Validates that referenced columns exist in the dataset schema.
import difflib
import logging
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, TypeVar
from superset.mcp_service.chart.schemas import (
BigNumberChartConfig,
ChartConfig,
ColumnRef,
HandlebarsChartConfig,
MixedTimeseriesChartConfig,
PieChartConfig,
PivotTableChartConfig,
TableChartConfig,
XYChartConfig,
)
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
@@ -40,8 +34,53 @@ from superset.mcp_service.common.error_schemas import (
DatasetContext,
)
_C = TypeVar("_C", bound=ChartConfig)
logger = logging.getLogger(__name__)
def build_dataset_context_from_orm(dataset: Any) -> DatasetContext | None:
    """Build a :class:`DatasetContext` from an ORM dataset already in hand.

    Intended for callers that have fetched the dataset themselves (e.g. for
    permission checks) and want to avoid a second ``DatasetDAO.find_by_id``
    lookup.

    Args:
        dataset: The ORM dataset object, or None.

    Returns:
        The populated context, or None when ``dataset`` is None.
    """
    if dataset is None:
        return None

    # A missing or None ``columns`` / ``metrics`` attribute yields an empty list.
    columns: List[Dict[str, Any]] = [
        {
            "name": col.column_name,
            "type": str(col.type) if col.type else "UNKNOWN",
            "is_temporal": getattr(col, "is_temporal", False),
            "is_numeric": getattr(col, "is_numeric", False),
        }
        for col in getattr(dataset, "columns", []) or []
    ]
    metrics: List[Dict[str, Any]] = [
        {
            "name": metric.metric_name,
            "expression": metric.expression,
            "description": metric.description,
        }
        for metric in getattr(dataset, "metrics", []) or []
    ]

    # Falls back to "" when the dataset has no database or no database name.
    database_name = (
        getattr(getattr(dataset, "database", None), "database_name", None) or ""
    )

    return DatasetContext(
        id=dataset.id,
        table_name=dataset.table_name,
        schema=dataset.schema,
        database_name=database_name,
        available_columns=columns,
        available_metrics=metrics,
    )
# Exceptions that can occur during column name normalization.
# Shared by the validation pipeline and tool-level normalization calls.
NORMALIZATION_EXCEPTIONS = (
@@ -58,7 +97,7 @@ class DatasetValidator:
@staticmethod
def validate_against_dataset(
config: Any,
config: ChartConfig,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
) -> Tuple[bool, ChartGenerationError | None]:
@@ -208,120 +247,49 @@ class DatasetValidator:
@staticmethod
def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None:
"""Get dataset context with column information."""
"""Fetch the ORM dataset by ID/UUID and build a :class:`DatasetContext`."""
try:
from superset.daos.dataset import DatasetDAO
# Find dataset
if isinstance(dataset_id, int) or (
isinstance(dataset_id, str) and dataset_id.isdigit()
):
dataset = DatasetDAO.find_by_id(int(dataset_id))
else:
# Try UUID lookup
dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid")
if not dataset:
return None
# Build context
columns = []
metrics = []
# Add table columns
for col in dataset.columns:
columns.append(
{
"name": col.column_name,
"type": str(col.type) if col.type else "UNKNOWN",
"is_temporal": col.is_temporal
if hasattr(col, "is_temporal")
else False,
"is_numeric": col.is_numeric
if hasattr(col, "is_numeric")
else False,
}
)
# Add metrics
for metric in dataset.metrics:
metrics.append(
{
"name": metric.metric_name,
"expression": metric.expression,
"description": metric.description,
}
)
return DatasetContext(
id=dataset.id,
table_name=dataset.table_name,
schema=dataset.schema,
database_name=dataset.database.database_name
if dataset.database
else None,
available_columns=columns,
available_metrics=metrics,
)
return build_dataset_context_from_orm(dataset)
except Exception as e:
logger.error("Error getting dataset context for %s: %s", dataset_id, e)
return None
@staticmethod
def _extract_column_references(config: Any) -> List[ColumnRef]: # noqa: C901
"""Extract all column references from a chart configuration.
def _extract_column_references(
config: ChartConfig,
) -> List[ColumnRef]:
"""Extract all column references from configuration via the plugin registry.
Covers every supported ``ChartConfig`` variant so fast-path tools
(``generate_explore_link``, ``update_chart_preview``) that only run
Tier-1 validation still catch bad column refs in pie / pivot table /
mixed timeseries / handlebars / big number charts — not just XY and
table.
Previously only handled TableChartConfig and XYChartConfig, causing
5 of 7 chart types to silently skip column validation. Now delegates
to the plugin for each chart type so all types are covered.
"""
refs: List[ColumnRef] = []
# Local import: plugins call DatasetValidator helpers from
# normalize_column_refs().
# A top-level import of registry in dataset_validator would make loading this
# module implicitly trigger plugin registration, creating a circular dependency.
from superset.mcp_service.chart.registry import get_registry
if isinstance(config, TableChartConfig):
refs.extend(config.columns)
elif isinstance(config, XYChartConfig):
if config.x is not None:
refs.append(config.x)
refs.extend(config.y)
if config.group_by:
refs.extend(config.group_by)
elif isinstance(config, PieChartConfig):
refs.append(config.dimension)
refs.append(config.metric)
elif isinstance(config, PivotTableChartConfig):
refs.extend(config.rows)
if config.columns:
refs.extend(config.columns)
refs.extend(config.metrics)
elif isinstance(config, MixedTimeseriesChartConfig):
refs.append(config.x)
refs.extend(config.y)
if config.group_by:
refs.extend(config.group_by)
refs.extend(config.y_secondary)
if config.group_by_secondary:
refs.extend(config.group_by_secondary)
elif isinstance(config, HandlebarsChartConfig):
if config.columns:
refs.extend(config.columns)
if config.groupby:
refs.extend(config.groupby)
if config.metrics:
refs.extend(config.metrics)
elif isinstance(config, BigNumberChartConfig):
refs.append(config.metric)
if config.temporal_column:
refs.append(ColumnRef(name=config.temporal_column))
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return []
# Filter columns (shared by every config type that defines ``filters``).
if filters := getattr(config, "filters", None):
for filter_config in filters:
refs.append(ColumnRef(name=filter_config.column))
plugin = get_registry().get(chart_type)
if plugin is None:
logger.warning("No plugin registered for chart_type=%r", chart_type)
return []
return refs
return plugin.extract_column_refs(config)
@staticmethod
def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
@@ -341,7 +309,7 @@ class DatasetValidator:
return False
@staticmethod
def _get_canonical_column_name(
def get_canonical_column_name(
column_name: str, dataset_context: DatasetContext
) -> str:
"""
@@ -375,50 +343,26 @@ class DatasetValidator:
return column_name
@staticmethod
def _normalize_xy_config(
config_dict: Dict[str, Any], dataset_context: DatasetContext
) -> None:
"""Normalize column names in an XY chart config dict in place."""
# Normalize x-axis column
if "x" in config_dict and config_dict["x"] and config_dict["x"].get("name"):
config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name(
config_dict["x"]["name"], dataset_context
)
def get_canonical_metric_name(
metric_name: str, dataset_context: DatasetContext
) -> str:
"""Return the canonical saved-metric name from available_metrics.
# Normalize y-axis columns (skip SQL-expression metrics; no name).
if "y" in config_dict and config_dict["y"]:
for y_col in config_dict["y"]:
if not y_col.get("name"):
continue
y_col["name"] = DatasetValidator._get_canonical_column_name(
y_col["name"], dataset_context
)
Unlike get_canonical_column_name, this only searches available_metrics
so that a same-named column with different casing cannot shadow the
metric's canonical name. Use this whenever saved_metric=True.
# Normalize group_by columns
if "group_by" in config_dict and config_dict["group_by"]:
for gb_col in config_dict["group_by"]:
if not gb_col.get("name"):
continue
gb_col["name"] = DatasetValidator._get_canonical_column_name(
gb_col["name"], dataset_context
)
Returns the original name when no metric matches (validation catches
the missing-metric case separately).
"""
metric_lower = metric_name.lower()
for metric in dataset_context.available_metrics:
if metric["name"].lower() == metric_lower:
return metric["name"]
return metric_name
@staticmethod
def _normalize_table_config(
config_dict: Dict[str, Any], dataset_context: DatasetContext
) -> None:
"""Normalize column names in a table chart config dict in place."""
if "columns" in config_dict and config_dict["columns"]:
for col in config_dict["columns"]:
# Skip SQL-expression metrics: no underlying column name.
if not col.get("name"):
continue
col["name"] = DatasetValidator._get_canonical_column_name(
col["name"], dataset_context
)
@staticmethod
def _normalize_filters(
def normalize_filters(
config_dict: Dict[str, Any], dataset_context: DatasetContext
) -> None:
"""Normalize filter column names in a config dict in place."""
@@ -426,17 +370,17 @@ class DatasetValidator:
for filter_config in config_dict["filters"]:
if filter_config and "column" in filter_config:
filter_config["column"] = (
DatasetValidator._get_canonical_column_name(
DatasetValidator.get_canonical_column_name(
filter_config["column"], dataset_context
)
)
@staticmethod
def normalize_column_names(
config: TableChartConfig | XYChartConfig,
config: _C,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
) -> TableChartConfig | XYChartConfig:
) -> _C:
"""
Normalize column names in config to match the canonical dataset column names.
@@ -445,6 +389,9 @@ class DatasetValidator:
(e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
so we need to ensure column names match exactly.
Previously only XYChartConfig and TableChartConfig were normalized; now
all 7 chart types are handled via the plugin registry.
Args:
config: Chart configuration with column references
dataset_id: Dataset ID to get canonical column names from
@@ -459,22 +406,24 @@ class DatasetValidator:
if not dataset_context:
return config
# Create a mutable copy of the config
config_dict = config.model_dump()
# Local import: plugins call DatasetValidator helpers from
# normalize_column_refs().
# A top-level import of registry in dataset_validator would make loading this
# module implicitly trigger plugin registration, creating a circular dependency.
from superset.mcp_service.chart.registry import get_registry
# Normalize based on config type
if isinstance(config, XYChartConfig):
DatasetValidator._normalize_xy_config(config_dict, dataset_context)
elif isinstance(config, TableChartConfig):
DatasetValidator._normalize_table_config(config_dict, dataset_context)
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return config
# Normalize filter columns (common to both config types)
DatasetValidator._normalize_filters(config_dict, dataset_context)
plugin = get_registry().get(chart_type)
if plugin is None:
logger.warning(
"No plugin for chart_type=%r; skipping column normalization", chart_type
)
return config
# Reconstruct the config with normalized names
if isinstance(config, XYChartConfig):
return XYChartConfig.model_validate(config_dict)
return TableChartConfig.model_validate(config_dict)
return plugin.normalize_column_refs(config, dataset_context)
@staticmethod
def _get_column_suggestions(

View File

@@ -23,10 +23,7 @@ Validates performance, compatibility, and user experience issues.
import logging
from typing import Any, Dict, List, Tuple
from superset.mcp_service.chart.schemas import (
ChartConfig,
XYChartConfig,
)
from superset.mcp_service.chart.schemas import ChartConfig
logger = logging.getLogger(__name__)
@@ -56,20 +53,14 @@ class RuntimeValidator:
warnings: List[str] = []
suggestions: List[str] = []
# Only check XY charts for format and cardinality issues
if isinstance(config, XYChartConfig):
# Format-type compatibility validation
format_warnings = RuntimeValidator._validate_format_compatibility(config)
if format_warnings:
warnings.extend(format_warnings)
# Cardinality validation
cardinality_warnings, cardinality_suggestions = (
RuntimeValidator._validate_cardinality(config, dataset_id)
)
if cardinality_warnings:
warnings.extend(cardinality_warnings)
suggestions.extend(cardinality_suggestions)
# Per-plugin runtime warnings (format, cardinality, etc.)
plugin_warnings, plugin_suggestions = RuntimeValidator._validate_plugin_runtime(
config, dataset_id
)
if plugin_warnings:
warnings.extend(plugin_warnings)
if plugin_suggestions:
suggestions.extend(plugin_suggestions)
# Chart type appropriateness validation (for all chart types)
type_warnings, type_suggestions = RuntimeValidator._validate_chart_type(
@@ -98,61 +89,28 @@ class RuntimeValidator:
return True, None
@staticmethod
def _validate_format_compatibility(config: XYChartConfig) -> List[str]:
"""Validate format-type compatibility."""
warnings: List[str] = []
try:
# Import here to avoid circular imports
from .format_validator import FormatTypeValidator
is_valid, format_warnings = (
FormatTypeValidator.validate_format_compatibility(config)
)
if format_warnings:
warnings.extend(format_warnings)
except ImportError:
logger.warning("Format validator not available")
except Exception as e:
logger.warning("Format validation failed: %s", e)
return warnings
@staticmethod
def _validate_cardinality(
config: XYChartConfig, dataset_id: int | str
def _validate_plugin_runtime(
config: ChartConfig, dataset_id: int | str
) -> Tuple[List[str], List[str]]:
"""Validate cardinality issues."""
warnings: List[str] = []
suggestions: List[str] = []
"""Delegate per-chart-type runtime warnings to the plugin registry.
Each plugin's get_runtime_warnings() method returns chart-type-specific
warnings and suggestions (e.g. format/cardinality for XY). The registry
dispatch removes the previous isinstance(config, XYChartConfig) hardcoding.
"""
try:
# Import here to avoid circular imports
from .cardinality_validator import CardinalityValidator
from superset.mcp_service.chart.registry import get_registry
# Determine chart type for cardinality thresholds
chart_type = config.kind if hasattr(config, "kind") else "default"
# Check X-axis cardinality
if config.x is None or config.x.name is None:
return warnings, suggestions
is_ok, cardinality_info = CardinalityValidator.check_cardinality(
dataset_id=dataset_id,
x_column=config.x.name,
chart_type=chart_type,
group_by_column=(config.group_by[0].name if config.group_by else None),
)
if not is_ok and cardinality_info:
warnings.extend(cardinality_info.get("warnings", []))
suggestions.extend(cardinality_info.get("suggestions", []))
except ImportError:
logger.warning("Cardinality validator not available")
except Exception as e:
logger.warning("Cardinality validation failed: %s", e)
return warnings, suggestions
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return [], []
plugin = get_registry().get(chart_type)
if plugin is None:
return [], []
return plugin.get_runtime_warnings(config, dataset_id)
except Exception as exc: # noqa: BLE001 — plugin code is third-party-extensible
logger.warning("Plugin runtime validation failed: %s", exc)
return [], []
@staticmethod
def _validate_chart_type(
@@ -184,7 +142,7 @@ class RuntimeValidator:
except ImportError:
logger.warning("Chart type suggester not available")
except Exception as e:
except Exception as e: # noqa: BLE001 — non-blocking warning path
logger.warning("Chart type validation failed: %s", e)
return warnings, suggestions

View File

@@ -147,19 +147,16 @@ class SchemaValidator:
chart_type: str,
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Validate chart type and dispatch to type-specific pre-validation."""
chart_type_validators = {
"xy": SchemaValidator._pre_validate_xy_config,
"table": SchemaValidator._pre_validate_table_config,
"pie": SchemaValidator._pre_validate_pie_config,
"pivot_table": SchemaValidator._pre_validate_pivot_table_config,
"mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config,
"handlebars": SchemaValidator._pre_validate_handlebars_config,
"big_number": SchemaValidator._pre_validate_big_number_config,
}
"""Validate chart type and dispatch to plugin pre-validation."""
# avoid circular import — a top-level import of registry here would pull in
# the plugins package before it finishes registering, creating an import cycle.
from superset.mcp_service.chart.registry import get_registry
if not isinstance(chart_type, str) or chart_type not in chart_type_validators:
valid_types = ", ".join(chart_type_validators.keys())
registry = get_registry()
# Compute once — used in both error branches below.
valid_types = ", ".join(registry.all_types())
if not isinstance(chart_type, str) or not registry.is_registered(chart_type):
return False, ChartGenerationError(
error_type="invalid_chart_type",
message=f"Invalid chart_type: '{chart_type}'",
@@ -178,376 +175,26 @@ class SchemaValidator:
error_code="INVALID_CHART_TYPE",
)
return chart_type_validators[chart_type](config)
@staticmethod
def _pre_validate_xy_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate XY chart configuration."""
# x is optional — defaults to dataset's main_dttm_col in map_xy_config
if "y" not in config:
# Single get() call — returns None when the plugin is disabled.
# Avoids calling enabled_func twice (separate is_enabled + get both
# invoke _is_plugin_enabled, which may call operator-supplied callable).
plugin = registry.get(chart_type)
if plugin is None:
return False, ChartGenerationError(
error_type="missing_xy_fields",
message="XY chart missing required field: 'y' (Y-axis metrics)",
details="XY charts require Y-axis (metrics) specifications. "
"X-axis is optional and defaults to the dataset's primary "
"datetime column when omitted.",
error_type="disabled_chart_type",
message=f"Chart type '{chart_type}' is not enabled on this instance",
details=f"Chart type '{chart_type}' is registered but has been "
f"disabled by the operator. "
f"Enabled chart types: {valid_types}",
suggestions=[
"Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] "
"for Y-axis",
"Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
"'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
f"Use one of the enabled chart types: {valid_types}",
"Contact your administrator if you believe this is an error",
],
error_code="MISSING_XY_FIELDS",
error_code="DISABLED_CHART_TYPE",
)
# Validate Y is a list
if not isinstance(config.get("y", []), list):
return False, ChartGenerationError(
error_type="invalid_y_format",
message="Y-axis must be a list of metrics",
details="The 'y' field must be an array of metric specifications",
suggestions=[
"Wrap Y-axis metric in array: 'y': [{'name': 'column', "
"'aggregate': 'SUM'}]",
"Multiple metrics supported: 'y': [metric1, metric2, ...]",
],
error_code="INVALID_Y_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_table_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate table chart configuration."""
if "columns" not in config:
return False, ChartGenerationError(
error_type="missing_columns",
message="Table chart missing required field: columns",
details="Table charts require a 'columns' array to specify which "
"columns to display",
suggestions=[
"Add 'columns' field with array of column specifications",
"Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
"'aggregate': 'SUM'}]",
"Each column can have optional 'aggregate' for metrics",
],
error_code="MISSING_COLUMNS",
)
if not isinstance(config.get("columns", []), list):
return False, ChartGenerationError(
error_type="invalid_columns_format",
message="Columns must be a list",
details="The 'columns' field must be an array of column specifications",
suggestions=[
"Ensure columns is an array: 'columns': [...]",
"Each column should be an object with 'name' field",
],
error_code="INVALID_COLUMNS_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_pie_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pie chart configuration."""
missing_fields = []
if "dimension" not in config:
missing_fields.append("'dimension' (category column for slices)")
if "metric" not in config:
missing_fields.append("'metric' (value metric for slice sizes)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pie_fields",
message=f"Pie chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Pie charts require a dimension (categories) and a metric "
"(values)",
suggestions=[
"Add 'dimension' field: {'name': 'category_column'}",
"Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
"Example: {'chart_type': 'pie', 'dimension': {'name': "
"'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
],
error_code="MISSING_PIE_FIELDS",
)
return True, None
@staticmethod
def _pre_validate_handlebars_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate handlebars chart configuration."""
if "handlebars_template" not in config:
return False, ChartGenerationError(
error_type="missing_handlebars_template",
message="Handlebars chart missing required field: handlebars_template",
details="Handlebars charts require a 'handlebars_template' string "
"containing Handlebars HTML template markup",
suggestions=[
"Add 'handlebars_template' with a Handlebars HTML template",
"Data is available as {{data}} array in the template",
"Example: '<ul>{{#each data}}<li>{{this.name}}: "
"{{this.value}}</li>{{/each}}</ul>'",
],
error_code="MISSING_HANDLEBARS_TEMPLATE",
)
template = config.get("handlebars_template")
if not isinstance(template, str) or not template.strip():
return False, ChartGenerationError(
error_type="invalid_handlebars_template",
message="Handlebars template must be a non-empty string",
details="The 'handlebars_template' field must be a non-empty string "
"containing valid Handlebars HTML template markup",
suggestions=[
"Ensure handlebars_template is a non-empty string",
"Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
],
error_code="INVALID_HANDLEBARS_TEMPLATE",
)
query_mode = config.get("query_mode", "aggregate")
if query_mode not in ("aggregate", "raw"):
return False, ChartGenerationError(
error_type="invalid_query_mode",
message="Invalid query_mode for handlebars chart",
details="query_mode must be either 'aggregate' or 'raw'",
suggestions=[
"Use 'aggregate' for aggregated data (default)",
"Use 'raw' for individual rows",
],
error_code="INVALID_QUERY_MODE",
)
if query_mode == "raw" and not config.get("columns"):
return False, ChartGenerationError(
error_type="missing_raw_columns",
message="Handlebars chart in 'raw' mode requires 'columns'",
details="When query_mode is 'raw', you must specify which columns "
"to include in the query results",
suggestions=[
"Add 'columns': [{'name': 'column_name'}] for raw mode",
"Or use query_mode='aggregate' with 'metrics' "
"and optional 'groupby'",
],
error_code="MISSING_RAW_COLUMNS",
)
if query_mode == "aggregate" and not config.get("metrics"):
return False, ChartGenerationError(
error_type="missing_aggregate_metrics",
message="Handlebars chart in 'aggregate' mode requires 'metrics'",
details="When query_mode is 'aggregate' (default), you must specify "
"at least one metric with an aggregate function",
suggestions=[
"Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
"Or use query_mode='raw' with 'columns' for individual rows",
],
error_code="MISSING_AGGREGATE_METRICS",
)
return True, None
@staticmethod
def _pre_validate_big_number_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate big number chart configuration."""
if "metric" not in config:
return False, ChartGenerationError(
error_type="missing_metric",
message="Big Number chart missing required field: metric",
details="Big Number charts require a 'metric' field "
"specifying the value to display",
suggestions=[
"Add 'metric' with name and aggregate: "
"{'name': 'revenue', 'aggregate': 'SUM'}",
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
"Example: {'chart_type': 'big_number', "
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
],
error_code="MISSING_BIG_NUMBER_METRIC",
)
metric = config.get("metric", {})
if not isinstance(metric, dict):
return False, ChartGenerationError(
error_type="invalid_metric_type",
message="Big Number metric must be a dict with 'name' and 'aggregate'",
details="The 'metric' field must be an object, "
f"got {type(metric).__name__}",
suggestions=[
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
)
if (
not metric.get("aggregate")
and not metric.get("saved_metric")
and not metric.get("sql_expression")
):
return False, ChartGenerationError(
error_type="missing_metric_aggregate",
message="Big Number metric must include an aggregate function, "
"a saved metric reference, or a SQL expression",
details="The metric must have an 'aggregate' field, "
"'saved_metric': true, or 'sql_expression'",
suggestions=[
"Add 'aggregate' to your metric: "
"{'name': 'col', 'aggregate': 'SUM'}",
"Or use a saved metric: "
"{'name': 'total_sales', 'saved_metric': true}",
"Or a custom SQL metric: "
"{'sql_expression': 'SUM(a)/SUM(b)', 'label': 'Ratio'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="MISSING_BIG_NUMBER_AGGREGATE",
)
# ``label`` may be any JSON type here (pre-Pydantic), so test the
# string-ness explicitly before calling ``.strip()``.
label = metric.get("label")
if metric.get("sql_expression") and not (
isinstance(label, str) and label.strip()
):
return False, ChartGenerationError(
error_type="missing_sql_metric_label",
message="Big Number metric with sql_expression requires a label",
details=(
"Custom SQL metrics have no column name to derive a label "
"from, so 'label' is required for display."
),
suggestions=[
"Add a 'label': "
"{'sql_expression': 'SUM(a)/SUM(b)', 'label': 'Ratio'}",
],
error_code="MISSING_SQL_METRIC_LABEL",
)
show_trendline = config.get("show_trendline", False)
temporal_column = config.get("temporal_column")
if show_trendline and not temporal_column:
return False, ChartGenerationError(
error_type="missing_temporal_column",
message="Trendline requires a temporal column",
details="When 'show_trendline' is True, a "
"'temporal_column' must be specified",
suggestions=[
"Add 'temporal_column': 'date_column_name'",
"Or set 'show_trendline': false for number only",
"Use get_dataset_info to find temporal columns",
],
error_code="MISSING_TEMPORAL_COLUMN",
)
return True, None
@staticmethod
def _pre_validate_pivot_table_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pivot table configuration."""
missing_fields = []
if "rows" not in config:
missing_fields.append("'rows' (row grouping columns)")
if "metrics" not in config:
missing_fields.append("'metrics' (aggregation metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pivot_fields",
message=f"Pivot table missing required "
f"fields: {', '.join(missing_fields)}",
details="Pivot tables require row groupings and metrics",
suggestions=[
"Add 'rows' field: [{'name': 'category'}]",
"Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
"Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
],
error_code="MISSING_PIVOT_FIELDS",
)
if not isinstance(config.get("rows", []), list):
return False, ChartGenerationError(
error_type="invalid_rows_format",
message="Rows must be a list of columns",
details="The 'rows' field must be an array of column specifications",
suggestions=[
"Wrap row columns in array: 'rows': [{'name': 'category'}]",
],
error_code="INVALID_ROWS_FORMAT",
)
if not isinstance(config.get("metrics", []), list):
return False, ChartGenerationError(
error_type="invalid_metrics_format",
message="Metrics must be a list",
details="The 'metrics' field must be an array of metric specifications",
suggestions=[
"Wrap metrics in array: 'metrics': [{'name': 'sales', "
"'aggregate': 'SUM'}]",
],
error_code="INVALID_METRICS_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_mixed_timeseries_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate mixed timeseries configuration."""
missing_fields = []
if "x" not in config:
missing_fields.append("'x' (X-axis temporal column)")
if "y" not in config:
missing_fields.append("'y' (primary Y-axis metrics)")
if "y_secondary" not in config:
missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_mixed_timeseries_fields",
message=f"Mixed timeseries chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Mixed timeseries charts require an x-axis, primary metrics, "
"and secondary metrics",
suggestions=[
"Add 'x' field: {'name': 'date_column'}",
"Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
"Add 'y_secondary' field: [{'name': 'orders', "
"'aggregate': 'COUNT'}]",
"Optional: 'primary_kind' and 'secondary_kind' for chart types",
],
error_code="MISSING_MIXED_TIMESERIES_FIELDS",
)
for field_name in ["y", "y_secondary"]:
if not isinstance(config.get(field_name, []), list):
return False, ChartGenerationError(
error_type=f"invalid_{field_name}_format",
message=f"'{field_name}' must be a list of metrics",
details=f"The '{field_name}' field must be an array of metric "
"specifications",
suggestions=[
f"Wrap in array: '{field_name}': "
"[{'name': 'col', 'aggregate': 'SUM'}]",
],
error_code=f"INVALID_{field_name.upper()}_FORMAT",
)
if (error := plugin.pre_validate(config)) is not None:
return False, error
return True, None
@staticmethod
@@ -562,89 +209,27 @@ class SchemaValidator:
if err.get("type") == "union_tag_invalid" or "discriminator" in str(
err.get("ctx", {})
):
# This is the generic union error - provide better message
config = request_data.get("config", {})
chart_type = config.get("chart_type", "unknown")
# avoid circular import
from superset.mcp_service.chart.registry import get_registry
if chart_type == "xy":
return ChartGenerationError(
error_type="xy_validation_error",
message="XY chart configuration validation failed",
details="The XY chart configuration is missing required "
"fields or has invalid structure",
suggestions=[
"Ensure 'x' field exists with {'name': 'column_name'}",
"Ensure 'y' field is an array: [{'name': 'metric', "
"'aggregate': 'SUM'}]",
"Check that all column names are strings",
"Verify aggregate functions are valid: SUM, COUNT, AVG, "
"MIN, MAX",
],
error_code="XY_VALIDATION_ERROR",
)
elif chart_type == "table":
return ChartGenerationError(
error_type="table_validation_error",
message="Table chart configuration validation failed",
details="The table chart configuration is missing required "
"fields or has invalid structure",
suggestions=[
"Ensure 'columns' field is an array of column "
"specifications",
"Each column needs {'name': 'column_name'}",
"Optional: add 'aggregate' for metrics",
"Example: 'columns': [{'name': 'product'}, {'name': "
"'sales', 'aggregate': 'SUM'}]",
],
error_code="TABLE_VALIDATION_ERROR",
)
elif chart_type == "handlebars":
return ChartGenerationError(
error_type="handlebars_validation_error",
message="Handlebars chart configuration validation failed",
details="The handlebars chart configuration is missing "
"required fields or has invalid structure",
suggestions=[
"Ensure 'handlebars_template' is a non-empty string",
"For aggregate mode: add 'metrics' with aggregate "
"functions",
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
"Example: {'chart_type': 'handlebars', "
"'handlebars_template': '<ul>{{#each data}}<li>"
"{{this.name}}</li>{{/each}}</ul>', "
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
],
error_code="HANDLEBARS_VALIDATION_ERROR",
)
elif chart_type == "big_number":
return ChartGenerationError(
error_type="big_number_validation_error",
message="Big Number chart configuration validation failed",
details="The Big Number chart configuration is "
"missing required fields or has invalid "
"structure",
suggestions=[
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: 'metric': {'name': 'revenue', "
"'aggregate': 'SUM'}",
"For trendline: add 'show_trendline': true "
"and 'temporal_column': 'date_col'",
"Without trendline: just provide the metric",
],
error_code="BIG_NUMBER_VALIDATION_ERROR",
)
chart_type = request_data.get("config", {}).get("chart_type", "")
plugin = get_registry().get(chart_type)
if plugin is not None:
hint = plugin.schema_error_hint()
if hint is not None:
return hint
# Default enhanced error
error_details = []
for err in errors[:3]: # Show first 3 errors
loc = " -> ".join(str(location) for location in err.get("loc", []))
msg = err.get("msg", "Validation failed")
error_details.append(f"{loc}: {msg}")
error_details.append(f"{loc}: {msg}" if loc else msg)
return ChartGenerationError(
error_type="validation_error",
message="Chart configuration validation failed",
details="; ".join(error_details),
details="; ".join(error_details) or "Invalid chart configuration structure",
suggestions=[
"Check that all required fields are present",
"Ensure field types match the schema",

View File

@@ -51,6 +51,16 @@ try:
logger.info("Reusing existing Flask app from app context for MCP service")
# Use _get_current_object() to get the actual Flask app, not the LocalProxy
app = current_app._get_current_object()
# Configure the chart plugin registry from the host app's config.
# This module is the registry's only configure site — core Superset
# startup must not import mcp_service (fastmcp is an optional extra).
from superset.mcp_service.chart import registry as _chart_registry
_chart_registry.configure(
disabled=app.config.get("MCP_DISABLED_CHART_PLUGINS"),
enabled_func=app.config.get("MCP_CHART_PLUGIN_ENABLED_FUNC"),
)
elif appbuilder_initialized:
# appbuilder is initialized but we have no app context. Calling
# create_app() here would invoke appbuilder.init_app() a second
@@ -81,6 +91,18 @@ try:
mcp_config = get_mcp_config(_mcp_app.config)
_mcp_app.config.update(mcp_config)
# Configure the chart plugin registry with post-overlay values so
# MCP-specific overrides (e.g. MCP_DISABLED_CHART_PLUGINS set by the
# operator) take effect. This module is the registry's only configure
# site — core Superset startup must not import mcp_service (fastmcp
# is an optional extra).
from superset.mcp_service.chart import registry as _chart_registry
_chart_registry.configure(
disabled=_mcp_app.config.get("MCP_DISABLED_CHART_PLUGINS"),
enabled_func=_mcp_app.config.get("MCP_CHART_PLUGIN_ENABLED_FUNC"),
)
with _mcp_app.app_context():
from superset.core.mcp.core_mcp_injection import (
initialize_core_mcp_dependencies,

View File

@@ -18,6 +18,7 @@
import logging
import secrets
from collections.abc import Callable
from typing import Any, Dict, Optional, Sequence
from fastmcp.server.auth.providers.jwt import JWTVerifier
@@ -83,6 +84,46 @@ MCP_RBAC_ENABLED = True
# MCP_DISABLED_TOOLS = {"extensions.myorg.myext.some_tool"}
MCP_DISABLED_TOOLS: set[str] = set()
# =============================================================================
# MCP Chart Plugin Filtering
# =============================================================================
#
# Overview:
# ---------
# These two settings let operators enable/disable individual chart type plugins
# at runtime without a code deploy.
#
# Use cases:
# - Emergency kill switch: add "handlebars" to MCP_DISABLED_CHART_PLUGINS and
# restart to immediately hide it from all callers.
# - Dynamic per-request control (A/B test, gradual rollout): set
# MCP_CHART_PLUGIN_ENABLED_FUNC to an in-process predicate that can vary
# by user, request header, or any other context available at call time.
#
# Priority:
# MCP_CHART_PLUGIN_ENABLED_FUNC takes precedence over MCP_DISABLED_CHART_PLUGINS.
# When the callable is set, the deny-list is ignored entirely.
#
# MCP_CHART_PLUGIN_ENABLED_FUNC contract:
# - Called as enabled_func(chart_type: str) -> bool for every registry lookup.
# - Must be cheap and in-process: consult already-loaded feature flags or
# request-local context (e.g. Flask g). Do NOT perform network I/O per call.
# - On exception, the registry fails CLOSED (plugin hidden) and logs a warning.
# - Example (Harness / Split via pre-fetched flags in g):
# from flask import g
# def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
# flags = getattr(g, "feature_flags", {})
# return flags.get(f"mcp_chart_{chart_type}", True)
# =============================================================================
# Chart types in this set are hidden from all registry lookups.
# Use frozenset to avoid accidental mutation.
MCP_DISABLED_CHART_PLUGINS: frozenset[str] = frozenset()
# Dynamic per-call predicate. When set, overrides MCP_DISABLED_CHART_PLUGINS.
# Signature: (chart_type: str) -> bool
MCP_CHART_PLUGIN_ENABLED_FUNC: Callable[[str], bool] | None = None
# MCP JWT Debug Errors - controls server-side JWT debug logging.
# When False (default), uses the default JWTVerifier with minimal logging.
# When True, uses DetailedJWTVerifier with tiered logging:
@@ -196,7 +237,7 @@ MCP_FACTORY_CONFIG = {
#
# For multi-pod/Kubernetes deployments, setting CACHE_REDIS_URL automatically
# enables Redis-backed EventStore to share session state across pods.
MCP_STORE_CONFIG: Dict[str, Any] = {
MCP_STORE_CONFIG: dict[str, Any] = {
"enabled": False, # Disabled by default - caching uses in-memory store
"CACHE_REDIS_URL": None, # Redis URL, e.g., "redis://localhost:6379/0"
# Wrapper class that prefixes all keys. Each consumer provides their own prefix.
@@ -209,7 +250,7 @@ MCP_STORE_CONFIG: Dict[str, Any] = {
# MCP Response Caching Configuration - controls caching behavior and TTLs
# When enabled without MCP_STORE_CONFIG, uses in-memory store.
# When enabled with MCP_STORE_CONFIG, uses Redis store.
MCP_CACHE_CONFIG: Dict[str, Any] = {
MCP_CACHE_CONFIG: dict[str, Any] = {
"enabled": False, # Disabled by default
"CACHE_KEY_PREFIX": None, # Only needed when using the store
"list_tools_ttl": 60 * 5, # 5 minutes
@@ -260,7 +301,7 @@ MCP_CACHE_CONFIG: Dict[str, Any] = {
# Uses character-based heuristic (~3.5 chars per token for JSON).
# This is intentionally conservative to avoid underestimating.
# =============================================================================
MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
MCP_RESPONSE_SIZE_CONFIG: dict[str, Any] = {
"enabled": True, # Enabled by default to protect LLM clients
"token_limit": DEFAULT_TOKEN_LIMIT,
"warn_threshold_pct": DEFAULT_WARN_THRESHOLD_PCT,
@@ -315,7 +356,7 @@ MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
# - compact_schemas is ignored when include_schemas=False (no schema to
# compact); max_description_length still applies in summary mode.
# =============================================================================
MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
MCP_TOOL_SEARCH_CONFIG: dict[str, Any] = {
"enabled": True, # Enabled by default — reduces initial context by ~70%
"strategy": "bm25", # "bm25" (natural language) or "regex" (pattern matching)
"max_results": 5, # Max tools returned per search
@@ -523,7 +564,7 @@ def generate_secret_key() -> str:
return secrets.token_urlsafe(42)
def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
def get_mcp_config(app_config: dict[str, Any] | None = None) -> dict[str, Any]:
"""
Get complete MCP configuration dictionary.
@@ -544,6 +585,8 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
"MCP_DEBUG": MCP_DEBUG,
"MCP_RBAC_ENABLED": MCP_RBAC_ENABLED,
"MCP_DISABLED_TOOLS": set(MCP_DISABLED_TOOLS),
"MCP_DISABLED_CHART_PLUGINS": MCP_DISABLED_CHART_PLUGINS,
"MCP_CHART_PLUGIN_ENABLED_FUNC": MCP_CHART_PLUGIN_ENABLED_FUNC,
**MCP_SESSION_CONFIG,
**MCP_CSRF_CONFIG,
}
@@ -553,8 +596,8 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
def get_mcp_config_with_overrides(
app_config: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
app_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Alternative approach: Allow any app_config keys, not just predefined ones.
@@ -568,7 +611,7 @@ def get_mcp_config_with_overrides(
return {**defaults, **app_config}
def get_mcp_factory_config() -> Dict[str, Any]:
def get_mcp_factory_config() -> dict[str, Any]:
"""
Get FastMCP factory configuration.

View File

@@ -82,8 +82,8 @@ class TestPrestoDbEngineSpec(SupersetTestCase):
def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
preparer = inspector.engine.dialect.identifier_preparer
preparer.quote_identifier = preparer.quote = preparer.quote_schema = lambda x: (
f'"{x}"'
preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
lambda x: f'"{x}"'
)
row = mock.Mock()
row.Column, row.Type, row.Null = column
@@ -828,8 +828,8 @@ class TestPrestoDbEngineSpec(SupersetTestCase):
def test_show_columns(self):
inspector = mock.MagicMock()
preparer = inspector.engine.dialect.identifier_preparer
preparer.quote_identifier = preparer.quote = preparer.quote_schema = lambda x: (
f'"{x}"'
preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
lambda x: f'"{x}"'
)
inspector.bind.execute.return_value.fetchall = mock.MagicMock(
return_value=["a", "b"]
@@ -845,8 +845,8 @@ class TestPrestoDbEngineSpec(SupersetTestCase):
def test_show_columns_with_schema(self):
inspector = mock.MagicMock()
preparer = inspector.engine.dialect.identifier_preparer
preparer.quote_identifier = preparer.quote = preparer.quote_schema = lambda x: (
f'"{x}"'
preparer.quote_identifier = preparer.quote = preparer.quote_schema = (
lambda x: f'"{x}"'
)
inspector.bind.execute.return_value.fetchall = mock.MagicMock(
return_value=["a", "b"]

View File

@@ -260,9 +260,8 @@ class TestSecurityGuestTokenApiTokenValidator(SupersetTestCase):
@with_config(
{
"GUEST_TOKEN_VALIDATOR_HOOK": lambda x: (
len(x["rls"]) == 1 and "tenant_id=" in x["rls"][0]["clause"]
)
"GUEST_TOKEN_VALIDATOR_HOOK": lambda x: len(x["rls"]) == 1
and "tenant_id=" in x["rls"][0]["clause"]
}
)
def test_guest_validator_hook_real_world_example_positive(self):
@@ -277,9 +276,8 @@ class TestSecurityGuestTokenApiTokenValidator(SupersetTestCase):
@with_config(
{
"GUEST_TOKEN_VALIDATOR_HOOK": lambda x: (
len(x["rls"]) == 1 and "tenant_id=" in x["rls"][0]["clause"]
)
"GUEST_TOKEN_VALIDATOR_HOOK": lambda x: len(x["rls"]) == 1
and "tenant_id=" in x["rls"][0]["clause"]
}
)
def test_guest_validator_hook_real_world_example_negative(self):

View File

@@ -140,8 +140,8 @@ class TestQueryContextFactory:
existing_columns = {"existing_col"}
with patch.object(self.factory, "_find_column_definition") as mock_find:
mock_find.side_effect = lambda qo, col: (
f"def_{col}" if col != "tooltip_col2" else None
mock_find.side_effect = (
lambda qo, col: f"def_{col}" if col != "tooltip_col2" else None
)
self.factory._append_missing_tooltip_columns(

View File

@@ -44,8 +44,8 @@ def mock_dataset() -> MagicMock:
dataset.database.get_sqla_engine.return_value.__exit__.return_value = None
# Mock apply_limit_to_sql to return SQL with LIMIT
dataset.database.apply_limit_to_sql = lambda sql, limit, force: (
f"{sql} LIMIT {limit}"
dataset.database.apply_limit_to_sql = (
lambda sql, limit, force: f"{sql} LIMIT {limit}"
)
return dataset

View File

@@ -230,8 +230,8 @@ def test_get_catalog_names(mocker: MockerFixture) -> None:
# StarRocks returns rows with keys: ['Catalog', 'Type', 'Comment']
mock_row_1 = mocker.MagicMock()
mock_row_1.keys.return_value = ["Catalog", "Type", "Comment"]
mock_row_1.__getitem__ = lambda self, key: (
"default_catalog" if key == "Catalog" else None
mock_row_1.__getitem__ = (
lambda self, key: "default_catalog" if key == "Catalog" else None
)
mock_row_2 = mocker.MagicMock()

View File

@@ -93,30 +93,36 @@ class TestBigNumberChartConfig:
"chart_type": "big_number",
"metric": {"name": "total_sales", "saved_metric": True},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
is_valid, error = SchemaValidator._pre_validate_chart_type("big_number", data)
assert is_valid is True
assert error is None
def test_sql_expression_with_label_passes_pre_validation(self) -> None:
"""A custom SQL metric is a valid third option alongside aggregate and
saved_metric in Tier-1 validation."""
from superset.mcp_service.chart.registry import get_registry
data = {
"chart_type": "big_number",
"metric": {"sql_expression": "SUM(a)/SUM(b)", "label": "Ratio"},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
assert is_valid is True
plugin = get_registry().get("big_number")
assert plugin is not None
error = plugin.pre_validate(data)
assert error is None
def test_sql_expression_without_label_fails_pre_validation(self) -> None:
"""Tier-1 surfaces the label-required error with an LLM-actionable
suggestion before the request reaches Pydantic's stricter error."""
from superset.mcp_service.chart.registry import get_registry
data = {
"chart_type": "big_number",
"metric": {"sql_expression": "SUM(a)/SUM(b)"},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
assert is_valid is False
plugin = get_registry().get("big_number")
assert plugin is not None
error = plugin.pre_validate(data)
assert error is not None
assert error.error_code == "MISSING_SQL_METRIC_LABEL"
@@ -124,12 +130,15 @@ class TestBigNumberChartConfig:
"""Pre-validation runs on raw dict input before Pydantic coercion, so
a non-string ``label`` (e.g. an int from a buggy client) must surface
as a validation error, not an AttributeError from ``.strip()``."""
from superset.mcp_service.chart.registry import get_registry
data = {
"chart_type": "big_number",
"metric": {"sql_expression": "SUM(a)/SUM(b)", "label": 123},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
assert is_valid is False
plugin = get_registry().get("big_number")
assert plugin is not None
error = plugin.pre_validate(data)
assert error is not None
assert error.error_code == "MISSING_SQL_METRIC_LABEL"

View File

@@ -23,7 +23,9 @@ import pytest
from pydantic import ValidationError
from superset.mcp_service.chart.schemas import (
BigNumberChartConfig,
ColumnRef,
FilterConfig,
GenerateChartRequest,
GenerateChartResponse,
MixedTimeseriesChartConfig,
@@ -49,6 +51,44 @@ class TestGenerateChartResponse:
assert response.chart_type_label == "table chart"
class TestColumnNameSanitization:
    """Relaxed column-name rules must still reject SQL-injection payloads."""

    def test_column_ref_rejects_sql_injection(self) -> None:
        """A statement separator inside ``ColumnRef.name`` is refused."""
        hostile_name = "revenue; DROP TABLE users"
        with pytest.raises(ValidationError, match="potentially unsafe"):
            ColumnRef(name=hostile_name)

    def test_filter_column_rejects_sql_injection(self) -> None:
        """A statement separator inside ``FilterConfig.column`` is refused."""
        hostile_column = "status; DROP TABLE users"
        with pytest.raises(ValidationError, match="potentially unsafe"):
            FilterConfig(column=hostile_column, op="=", value="active")

    def test_temporal_column_rejects_sql_injection(self) -> None:
        """A hostile ``BigNumberChartConfig.temporal_column`` is refused."""
        hostile_temporal = "created_at; DROP TABLE users"
        with pytest.raises(ValidationError, match="potentially unsafe"):
            BigNumberChartConfig(
                chart_type="big_number",
                metric={"name": "revenue", "aggregate": "SUM"},
                show_trendline=True,
                temporal_column=hostile_temporal,
            )

    def test_relaxed_column_names_still_pass(self) -> None:
        """Digit-prefixed, hyphenated and dotted names remain valid."""
        digit_prefixed = ColumnRef(name="1Q_revenue")
        assert digit_prefixed.name == "1Q_revenue"

        hyphenated = FilterConfig(column="order-date", op="=", value="active")
        assert hyphenated.column == "order-date"

        dotted = BigNumberChartConfig(
            chart_type="big_number",
            metric={"name": "revenue", "aggregate": "SUM"},
            show_trendline=True,
            temporal_column="events.created-at",
        )
        assert dotted.temporal_column == "events.created-at"
class TestTableChartConfig:
"""Test TableChartConfig validation."""
@@ -926,6 +966,14 @@ class TestSqlExpressionRejectedOnDimensionPositions:
metric=ColumnRef(name="sales", aggregate="SUM"),
)
def test_pie_config_rejects_saved_metric_on_dimension(self) -> None:
    """A saved-metric reference in the pie ``dimension`` slot fails validation."""
    with pytest.raises(ValidationError):
        # Both refs are built inside the ``raises`` block so an error from
        # either the ColumnRef or the PieChartConfig constructor is accepted.
        saved_metric_as_dimension = ColumnRef(name="Total Revenue", saved_metric=True)
        sales_metric = ColumnRef(name="sales", aggregate="SUM")
        PieChartConfig(
            chart_type="pie",
            dimension=saved_metric_as_dimension,
            metric=sales_metric,
        )
def test_pivot_config_rejects_sql_expression_on_rows(self) -> None:
with pytest.raises(ValidationError):
PivotTableChartConfig(

View File

@@ -1116,7 +1116,7 @@ class TestGenerateChartName:
def test_unsupported_config_type(self) -> None:
"""Unsupported config type returns generic name."""
result = generate_chart_name("invalid_config") # type: ignore
result = generate_chart_name("invalid_config")
assert result == "Chart"
def test_custom_labels_used(self) -> None:
@@ -2155,7 +2155,7 @@ class TestDatasetValidatorSkipsSqlMetrics:
def test_normalize_column_names_skips_sql_metric_dicts(self) -> None:
"""A SQL-metric ColumnRef dumps to {name: None, sql_expression: ...};
_get_canonical_column_name(None, ...) would crash without the guard."""
get_canonical_column_name(None, ...) would crash without the guard."""
from superset.mcp_service.chart.validation.dataset_validator import (
DatasetValidator,
)

View File

@@ -28,7 +28,6 @@ from unittest.mock import Mock, patch
import pytest
from superset.mcp_service.chart.compile import (
build_dataset_context_from_orm,
CompileResult,
validate_and_compile,
)
@@ -41,6 +40,9 @@ from superset.mcp_service.chart.schemas import (
TableChartConfig,
XYChartConfig,
)
from superset.mcp_service.chart.validation.dataset_validator import (
build_dataset_context_from_orm,
)
def _orm_dataset(

View File

@@ -0,0 +1,252 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the chart type plugin registry."""
import sys
import threading
from types import ModuleType
import pytest
import superset.mcp_service.chart.registry as registry_module
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import (
_RegistryProxy,
_reset_for_testing,
all_types,
display_name_for_viz_type,
get,
get_registry,
is_registered,
register,
)
@pytest.fixture(autouse=True)
def _isolated_registry(monkeypatch):
    """Run each test against a clean registry without touching the real one.

    Every piece of module state that lookups consult is swapped for a pristine
    value; ``monkeypatch`` restores the originals at teardown:

    * ``_REGISTRY`` starts empty so only plugins registered by the test exist.
    * ``_plugins_loaded`` is True so lookups do not trigger the lazy import of
      the built-in plugins package.
    * ``_plugins_load_failed`` is False so the circuit breaker is not tripped.
    * ``_filter_config`` is reset to a default filter: ``get()`` and
      ``all_types()`` honour the enable/disable filter, so a filter left behind
      by an earlier ``configure()`` call must not hide the fake plugins.
    """
    monkeypatch.setattr(registry_module, "_REGISTRY", {})
    monkeypatch.setattr(registry_module, "_plugins_loaded", True)
    monkeypatch.setattr(registry_module, "_plugins_load_failed", False)
    monkeypatch.setattr(
        registry_module, "_filter_config", registry_module._PluginFilterConfig()
    )
class _FakePlugin(BaseChartPlugin):
    """Minimal plugin registered under ``"fake"`` that claims ``fake_viz``."""

    native_viz_types = {"fake_viz": "Fake Viz"}
    display_name = "Fake Chart"
    chart_type = "fake"
class _AnotherPlugin(BaseChartPlugin):
    """Second minimal plugin, registered under ``"another"``."""

    native_viz_types = {"another_viz": "Another Viz"}
    display_name = "Another Chart"
    chart_type = "another"
def test_register_adds_plugin():
    """``register`` stores the exact instance under the plugin's chart_type."""
    fake = _FakePlugin()
    register(fake)
    looked_up = get("fake")
    assert looked_up is fake
def test_get_returns_none_for_unknown():
    """Looking up a chart type nobody registered yields ``None``."""
    missing = get("nonexistent")
    assert missing is None
def test_all_types_returns_registered_keys():
    """``all_types`` lists the chart_type of every registered plugin."""
    for plugin_cls in (_FakePlugin, _AnotherPlugin):
        register(plugin_cls())
    registered = all_types()
    for expected in ("fake", "another"):
        assert expected in registered
def test_all_types_insertion_order():
    """``all_types`` reports chart types in the order they were registered."""
    for plugin_cls in (_FakePlugin, _AnotherPlugin):
        register(plugin_cls())
    ordered = all_types()
    fake_position = ordered.index("fake")
    another_position = ordered.index("another")
    assert fake_position < another_position
def test_is_registered_true_for_known():
    """``is_registered`` is exactly ``True`` for a registered chart type."""
    register(_FakePlugin())
    known = is_registered("fake")
    assert known is True
def test_is_registered_false_for_unknown():
    """``is_registered`` is exactly ``False`` for an unregistered chart type."""
    known = is_registered("nonexistent")
    assert known is False
def test_register_warns_on_duplicate(caplog):
    """Registering the same chart_type twice logs an "Overwriting" warning."""
    register(_FakePlugin())
    with caplog.at_level("WARNING"):
        duplicate = _FakePlugin()
        register(duplicate)
    captured = caplog.text
    assert "Overwriting" in captured
def test_register_warns_on_viz_type_collision(caplog):
    """A plugin claiming an already-claimed viz_type is warned about and shadowed."""
    register(_FakePlugin())

    class _CollidingPlugin(BaseChartPlugin):
        # ``fake_viz`` is already owned by _FakePlugin; ``own_viz`` is new.
        chart_type = "colliding"
        display_name = "Colliding Chart"
        native_viz_types = {"fake_viz": "Shadowed Viz", "own_viz": "Own Viz"}

    with caplog.at_level("WARNING"):
        register(_CollidingPlugin())

    for fragment in ("already claimed by", "fake_viz"):
        assert fragment in caplog.text
    # Earlier registration wins in display-name lookups
    resolved = display_name_for_viz_type("fake_viz")
    assert resolved == "Fake Viz"
def test_register_same_plugin_no_collision_warning(caplog):
    """Re-registering the same chart_type does not log a viz_type collision."""
    register(_FakePlugin())
    with caplog.at_level("WARNING"):
        replacement = _FakePlugin()
        register(replacement)
    captured = caplog.text
    assert "already claimed by" not in captured
def test_register_raises_for_empty_chart_type():
    """A plugin whose chart_type is the empty string is rejected with ValueError."""

    class _BadPlugin(BaseChartPlugin):
        chart_type = ""

    expected_message = "non-empty chart_type"
    with pytest.raises(ValueError, match=expected_message):
        register(_BadPlugin())
def test_display_name_for_viz_type_found():
    """A viz_type claimed by a registered plugin resolves to its display name."""
    register(_FakePlugin())
    resolved = display_name_for_viz_type("fake_viz")
    assert resolved == "Fake Viz"
def test_display_name_for_viz_type_not_found():
    """A viz_type no plugin claims resolves to ``None``."""
    register(_FakePlugin())
    resolved = display_name_for_viz_type("unknown_viz")
    assert resolved is None
def test_display_name_searches_all_plugins():
    """The display-name lookup is not limited to the first registered plugin."""
    for plugin_cls in (_FakePlugin, _AnotherPlugin):
        register(plugin_cls())
    resolved = display_name_for_viz_type("another_viz")
    assert resolved == "Another Viz"
def test_get_registry_returns_proxy():
    """``get_registry`` hands back a ``_RegistryProxy`` instance."""
    proxy = get_registry()
    assert isinstance(proxy, _RegistryProxy)
def test_plugins_lock_allows_register_during_lazy_import():
    """The registry lock is re-entrant for plugin registration during import."""
    # ``threading.RLock`` is a factory function, so the concrete class is
    # obtained from a throwaway instance.
    reentrant_lock_type = type(threading.RLock())
    assert isinstance(registry_module._plugins_lock, reentrant_lock_type)
def test_registry_proxy_get():
    """``get`` on the proxy returns the registered plugin instance."""
    fake = _FakePlugin()
    register(fake)
    proxy = get_registry()
    assert proxy.get("fake") is fake
def test_registry_proxy_all_types():
    """``all_types`` on the proxy includes registered chart types."""
    register(_FakePlugin())
    listed = get_registry().all_types()
    assert "fake" in listed
def test_registry_proxy_is_registered():
    """``is_registered`` on the proxy distinguishes known from unknown types."""
    register(_FakePlugin())
    known = get_registry().is_registered("fake")
    assert known is True
    unknown = get_registry().is_registered("missing")
    assert unknown is False
def test_registry_proxy_display_name_for_viz_type():
    """The proxy resolves claimed viz_types and returns ``None`` otherwise."""
    register(_FakePlugin())
    claimed = get_registry().display_name_for_viz_type("fake_viz")
    assert claimed == "Fake Viz"
    unclaimed = get_registry().display_name_for_viz_type("unknown")
    assert unclaimed is None
def test_ensure_plugins_loaded_skips_when_load_failed(monkeypatch):
    """_ensure_plugins_loaded returns immediately when _plugins_load_failed is set."""
    from superset.mcp_service.chart.registry import _ensure_plugins_loaded

    tripped_breaker = (("_plugins_loaded", False), ("_plugins_load_failed", True))
    for flag_name, flag_value in tripped_breaker:
        monkeypatch.setattr(registry_module, flag_name, flag_value)

    # If the function tried to import, the real plugins module would load and flip
    # _plugins_loaded to True. The circuit breaker should prevent that.
    _ensure_plugins_loaded()

    still_unloaded = registry_module._plugins_loaded
    assert still_unloaded is False
def test_ensure_plugins_loaded_sets_failed_flag_on_error(monkeypatch):
    """A failed import sets _plugins_load_failed so subsequent calls are no-ops."""
    from unittest.mock import patch

    from superset.mcp_service.chart.registry import _ensure_plugins_loaded

    monkeypatch.setattr(registry_module, "_plugins_loaded", False)
    monkeypatch.setattr(registry_module, "_plugins_load_failed", False)
    # Use a fresh lock of the same (re-entrant) kind as the production lock.
    # A plain ``threading.Lock`` would turn any re-entrant acquire on the
    # failure path into a hung test run instead of a failing assertion.
    monkeypatch.setattr(registry_module, "_plugins_lock", threading.RLock())

    # Setting the module to None in sys.modules causes ImportError on import.
    with patch.dict("sys.modules", {"superset.mcp_service.chart.plugins": None}):
        _ensure_plugins_loaded()

    assert registry_module._plugins_load_failed is True
    assert registry_module._plugins_loaded is False
def test_ensure_plugins_loaded_rolls_back_partial_registration(monkeypatch):
    """A failed lazy import restores the registry to its previous state."""
    from superset.mcp_service.chart.registry import _ensure_plugins_loaded

    real_import = __import__
    existing_plugin = _FakePlugin()
    partial_plugin = _AnotherPlugin()

    monkeypatch.setattr(registry_module, "_REGISTRY", {"fake": existing_plugin})
    monkeypatch.setattr(registry_module, "_plugins_loaded", False)
    monkeypatch.setattr(registry_module, "_plugins_load_failed", False)

    def fail_plugin_import(name, globals=None, locals=None, fromlist=(), level=0):
        # Every import except the plugins package is delegated untouched.
        if name != "superset.mcp_service.chart.plugins":
            return real_import(name, globals, locals, fromlist, level)
        # Simulate a package that registers one plugin and then blows up.
        register(partial_plugin)
        raise RuntimeError("plugin import failed")

    monkeypatch.setattr("builtins.__import__", fail_plugin_import)

    _ensure_plugins_loaded()

    assert registry_module._plugins_load_failed is True
    # The partially registered plugin is gone; the pre-existing one remains.
    assert registry_module._REGISTRY == {"fake": existing_plugin}
def test_reset_for_testing_clears_cached_plugins_package(monkeypatch):
    """Reset removes the plugins package so lazy loading can re-run registration."""
    module_name = "superset.mcp_service.chart.plugins"
    placeholder_module = ModuleType(module_name)
    monkeypatch.setitem(sys.modules, module_name, placeholder_module)

    # Start from a fully "dirty" state so every field has something to reset.
    monkeypatch.setattr(registry_module, "_REGISTRY", {"fake": _FakePlugin()})
    monkeypatch.setattr(registry_module, "_plugins_loaded", True)
    monkeypatch.setattr(registry_module, "_plugins_load_failed", True)

    _reset_for_testing()

    assert registry_module._REGISTRY == {}
    assert registry_module._plugins_loaded is False
    assert registry_module._plugins_load_failed is False
    assert module_name not in sys.modules

View File

@@ -0,0 +1,222 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for registry plugin filtering (configure / is_enabled / get / all_types)."""
import pytest
import superset.mcp_service.chart.registry as registry_module
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import (
_PluginFilterConfig,
all_types,
configure,
display_name_for_viz_type,
get,
is_enabled,
is_registered,
register,
)
@pytest.fixture(autouse=True)
def _isolated_registry(monkeypatch):
    """Give every test an empty registry, a default filter, and two plugins.

    ``monkeypatch`` restores the real module state at teardown.
    """

    def reset(attribute, value):
        monkeypatch.setattr(registry_module, attribute, value)

    reset("_REGISTRY", {})
    reset("_plugins_loaded", True)
    reset("_filter_config", _PluginFilterConfig())

    for plugin_cls in (_AlphaPlugin, _BetaPlugin):
        register(plugin_cls())
class _AlphaPlugin(BaseChartPlugin):
    """Test plugin registered under ``"alpha"`` that claims ``alpha_viz``."""

    native_viz_types = {"alpha_viz": "Alpha Viz"}
    display_name = "Alpha Chart"
    chart_type = "alpha"
class _BetaPlugin(BaseChartPlugin):
    """Test plugin registered under ``"beta"`` that claims ``beta_viz``."""

    native_viz_types = {"beta_viz": "Beta Viz"}
    display_name = "Beta Chart"
    chart_type = "beta"
# ---------------------------------------------------------------------------
# Static deny-list tests
# ---------------------------------------------------------------------------
def test_get_returns_plugin_when_enabled():
    """With no filter configured, both registered plugins are returned."""
    for chart_type in ("alpha", "beta"):
        assert get(chart_type) is not None
def test_get_returns_none_for_disabled_plugin():
    """A chart type on the deny-list is hidden from ``get``."""
    configure(disabled={"alpha"})
    hidden = get("alpha")
    assert hidden is None
def test_get_still_returns_other_plugins_when_one_is_disabled():
    """Disabling one chart type leaves the others retrievable."""
    configure(disabled={"alpha"})
    untouched = get("beta")
    assert untouched is not None
def test_all_types_excludes_disabled():
    """``all_types`` omits deny-listed chart types and keeps the rest."""
    configure(disabled={"alpha"})
    visible = all_types()
    assert "alpha" not in visible
    assert "beta" in visible
def test_all_types_empty_when_all_disabled():
    """Deny-listing every registered plugin leaves an empty list."""
    configure(disabled={"alpha", "beta"})
    visible = all_types()
    assert visible == []
def test_is_registered_ignores_deny_list():
    """``is_registered`` reports registration regardless of the deny-list."""
    configure(disabled={"alpha"})
    registered = is_registered("alpha")
    assert registered is True
def test_is_enabled_returns_false_for_disabled():
    """``is_enabled`` is exactly ``False`` for a deny-listed chart type."""
    configure(disabled={"alpha"})
    enabled = is_enabled("alpha")
    assert enabled is False
def test_is_enabled_returns_true_when_not_disabled():
    """``is_enabled`` is exactly ``True`` for a type missing from the deny-list."""
    configure(disabled={"alpha"})
    enabled = is_enabled("beta")
    assert enabled is True
def test_is_enabled_returns_false_for_unknown():
    """An unregistered chart type is never reported as enabled."""
    enabled = is_enabled("nonexistent")
    assert enabled is False
# ---------------------------------------------------------------------------
# configure() accepts different iterable shapes
# ---------------------------------------------------------------------------
def test_configure_accepts_list():
    """A ``list`` deny-list disables the listed chart type."""
    deny_list = ["alpha"]
    configure(disabled=deny_list)
    assert get("alpha") is None
def test_configure_accepts_tuple():
    """A ``tuple`` deny-list disables the listed chart type."""
    deny_list = ("alpha",)
    configure(disabled=deny_list)
    assert get("alpha") is None
def test_configure_accepts_frozenset():
    """A ``frozenset`` deny-list disables the listed chart type."""
    deny_list = frozenset({"alpha"})
    configure(disabled=deny_list)
    assert get("alpha") is None
def test_configure_accepts_none_disabled():
    """``disabled=None`` is accepted and disables nothing."""
    configure(disabled=None)
    still_visible = get("alpha")
    assert still_visible is not None
def test_configure_rejects_noncallable_enabled_func():
    """Passing a non-callable ``enabled_func`` raises ``TypeError``."""
    not_a_function = "not_a_callable"
    with pytest.raises(TypeError):
        configure(enabled_func=not_a_function)
# ---------------------------------------------------------------------------
# Dynamic callable hook tests
# ---------------------------------------------------------------------------
def test_enabled_func_overrides_deny_list():
    """When a hook is supplied its verdict wins over the static deny-list."""

    def only_alpha(ct):
        return ct == "alpha"

    # alpha is in deny-list but callable says True → should be visible
    configure(disabled={"alpha"}, enabled_func=only_alpha)
    visible = get("alpha")
    assert visible is not None
def test_enabled_func_can_disable_plugin():
    """A hook returning ``False`` hides that chart type and only that one."""

    def everything_but_beta(ct):
        return ct != "beta"

    configure(enabled_func=everything_but_beta)
    assert get("beta") is None
    assert get("alpha") is not None
def test_enabled_func_called_per_lookup():
    """The hook is consulted on every ``get`` call rather than cached."""
    seen_chart_types = []

    def hook(ct: str) -> bool:
        seen_chart_types.append(ct)
        return True

    configure(enabled_func=hook)
    for _ in range(2):
        get("alpha")

    assert seen_chart_types.count("alpha") == 2
def test_enabled_func_exception_fails_closed(caplog):
    """A hook that raises hides the plugin and a warning is logged."""
    import logging

    def bad_hook(ct: str) -> bool:
        raise RuntimeError("Harness down")

    configure(enabled_func=bad_hook)

    registry_logger = "superset.mcp_service.chart.registry"
    with caplog.at_level(logging.WARNING, logger=registry_logger):
        result = get("alpha")

    assert result is None  # fail closed
    logged = caplog.text
    # Either the explicit wording or at least the chart type must be logged.
    assert "failing closed" in logged.lower() or "alpha" in logged
def test_enabled_func_all_types_respects_hook():
    """``all_types`` lists only the chart types the hook approves."""

    def only_alpha(ct):
        return ct == "alpha"

    configure(enabled_func=only_alpha)
    visible = all_types()
    assert visible == ["alpha"]
# ---------------------------------------------------------------------------
# display_name_for_viz_type is NOT filtered
# ---------------------------------------------------------------------------
def test_display_name_unaffected_by_deny_list():
    """Display-name lookup ignores the deny-list."""
    configure(disabled={"alpha"})
    # Even though alpha is disabled, its viz_type should still resolve
    resolved = display_name_for_viz_type("alpha_viz")
    assert resolved == "Alpha Viz"
def test_display_name_unaffected_by_callable():
    """Display-name lookup ignores a hook that disables every chart type."""

    def nothing_enabled(ct):
        return False

    configure(enabled_func=nothing_enabled)
    resolved = display_name_for_viz_type("beta_viz")
    assert resolved == "Beta Viz"
# ---------------------------------------------------------------------------
# configure() atomicity: replacing config is visible to next lookup
# ---------------------------------------------------------------------------
def test_reconfigure_replaces_previous_filter():
    """A later ``configure`` call replaces the earlier deny-list entirely."""
    configure(disabled={"alpha"})
    while_disabled = get("alpha")
    assert while_disabled is None

    configure(disabled=set())
    after_reset = get("alpha")
    assert after_reset is not None
def test_reconfigure_with_func_then_none_falls_back_to_deny_list():
    """Clearing the hook makes the static deny-list authoritative again."""

    def nothing_enabled(ct):
        return False

    configure(enabled_func=nothing_enabled)
    assert get("alpha") is None

    configure(disabled={"beta"}, enabled_func=None)
    alpha_after, beta_after = get("alpha"), get("beta")
    assert alpha_after is not None
    assert beta_after is None

View File

@@ -580,3 +580,31 @@ class TestGetChartInfoPrivacy:
)
assert result is error
def test_apply_unsaved_state_override_updates_display_name_for_new_viz_type() -> None:
    """Stale display name is recomputed when viz_type is overridden from form_data."""
    chart_info = ChartInfo(
        id=1,
        slice_name="My Chart",
        viz_type="table",
        chart_type_display_name="Table",
    )

    # Cached form_data says the chart is now a pie chart.
    cached_form_data = patch.object(
        get_chart_info_module,
        "get_cached_form_data",
        return_value='{"viz_type": "pie"}',
    )
    # The registry lookup for the new viz_type is stubbed with a fixed label.
    registry_lookup = patch(
        "superset.mcp_service.chart.registry.display_name_for_viz_type",
        return_value="Pie Chart",
    )

    with cached_form_data:
        with registry_lookup:
            get_chart_info_module._apply_unsaved_state_override(chart_info, "key")

    assert chart_info.viz_type == "pie"
    assert chart_info.chart_type_display_name == "Pie Chart"

View File

@@ -1177,6 +1177,11 @@ class TestUpdateChartValidationGate:
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@pytest.mark.asyncio
async def test_preview_path_validation_failure_skips_cache(
self,
@@ -1240,6 +1245,11 @@ class TestUpdateChartValidationGate:
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@pytest.mark.asyncio
async def test_persist_path_validation_failure_skips_db_write(
self,
@@ -1290,6 +1300,237 @@ class TestUpdateChartValidationGate:
mock_update_cmd_cls.assert_not_called()
# ---------------------------------------------------------------------------
# Column normalization in update_chart
# ---------------------------------------------------------------------------
class TestUpdateChartColumnNormalization:
    """Column names are normalized to dataset canonical case before validation."""

    @staticmethod
    def _mock_chart(datasource_id: int | None = 10) -> Mock:
        """Build a stand-in chart object bound to ``datasource_id``."""
        return Mock(
            id=1,
            datasource_id=datasource_id,
            slice_name="Existing",
            viz_type="table",
            uuid="abc-123",
            params='{"viz_type": "table", "datasource": "10__table"}',
            datasource=Mock(),
        )

    @staticmethod
    def _xy_update_request(**extra) -> dict:
        """Return an update_chart payload with an XY line config.

        ``extra`` entries are placed between ``identifier`` and ``config``.
        """
        return {
            "identifier": 1,
            **extra,
            "config": {
                "chart_type": "xy",
                "x": {"name": "ds"},
                "y": [{"name": "num_boys", "aggregate": "SUM"}],
                "kind": "line",
            },
        }

    def _prime_happy_path(
        self,
        mock_find_by_id,
        mock_check_access,
        mock_create_preview,
        mock_validate,
        dataset_id,
        dataset_name,
    ) -> None:
        """Make lookup, access check, validation and preview mocks all succeed."""
        from superset.mcp_service.chart.compile import CompileResult

        mock_find_by_id.return_value = self._mock_chart(datasource_id=10)
        mock_check_access.return_value = DatasetValidationResult(
            is_valid=True, dataset_id=dataset_id, dataset_name=dataset_name, warnings=[]
        )
        mock_validate.return_value = CompileResult(
            success=True, error=None, error_code=None, tier="validation", error_obj=None
        )
        mock_create_preview.return_value = ("http://example.com/explore", None, [])

    @patch.object(update_chart_module, "validate_and_compile")
    @patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
    @patch("superset.mcp_service.auth.check_chart_data_access", new_callable=Mock)
    @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
    @patch("superset.db.session")
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.validate_against_dataset",
        new=Mock(return_value=(True, None)),
    )
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.normalize_column_names",
    )
    @pytest.mark.asyncio
    async def test_normalization_called_with_guarded_datasource_id(
        self,
        mock_normalize,
        mock_db_session,
        mock_find_by_id,
        mock_check_access,
        mock_create_preview,
        mock_validate,
        mcp_server,
    ):
        """normalize_column_names receives the locally-guarded datasource_id, not
        chart.datasource_id re-accessed from the ORM object (which could raise on
        mock/partial objects)."""
        self._prime_happy_path(
            mock_find_by_id,
            mock_check_access,
            mock_create_preview,
            mock_validate,
            dataset_id=10,
            dataset_name="ds",
        )
        # normalize_column_names returns the config unchanged
        mock_normalize.side_effect = lambda config, dataset_id: config

        payload = {"request": self._xy_update_request()}
        async with Client(mcp) as client:
            await client.call_tool("update_chart", payload)

        mock_normalize.assert_called_once()
        _, call_dataset_id = mock_normalize.call_args.args
        assert call_dataset_id == 10  # guarded local var, not re-read from ORM

    @patch.object(update_chart_module, "validate_and_compile")
    @patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
    @patch("superset.mcp_service.auth.check_chart_data_access", new_callable=Mock)
    @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
    @patch("superset.db.session")
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.validate_against_dataset",
        new=Mock(return_value=(True, None)),
    )
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.normalize_column_names",
    )
    @pytest.mark.asyncio
    async def test_normalization_exception_is_caught_gracefully(
        self,
        mock_normalize,
        mock_db_session,
        mock_find_by_id,
        mock_check_access,
        mock_create_preview,
        mock_validate,
        mcp_server,
    ):
        """A normalization failure must not propagate — chart update continues."""
        self._prime_happy_path(
            mock_find_by_id,
            mock_check_access,
            mock_create_preview,
            mock_validate,
            dataset_id=10,
            dataset_name="ds",
        )
        mock_normalize.side_effect = ValueError("DB connection failed")

        payload = {"request": self._xy_update_request()}
        async with Client(mcp) as client:
            # Should not raise; normalization failure is a warning only
            await client.call_tool("update_chart", payload)

        # Normalization failed but tool still attempted the update path
        mock_normalize.assert_called_once()

    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.normalize_column_names",
    )
    def test_normalization_skipped_when_no_datasource_id(self, mock_normalize):
        """normalize_column_names is never called when chart has no datasource_id."""
        from superset.mcp_service.chart.schemas import XYChartConfig

        chart = self._mock_chart(datasource_id=None)
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="ds"),
            y=[ColumnRef(name="num_boys", aggregate="SUM")],
            kind="line",
        )

        # Simulate the guard from update_chart
        # NOTE(review): this re-implements the guard locally instead of calling
        # update_chart, so it does not exercise the production code path.
        chart_datasource_id = getattr(chart, "datasource_id", None)
        should_normalize = config is not None and chart_datasource_id is not None
        if should_normalize:
            from superset.mcp_service.chart.validation.dataset_validator import (
                DatasetValidator,
            )

            DatasetValidator.normalize_column_names(config, chart_datasource_id)

        mock_normalize.assert_not_called()

    @patch.object(update_chart_module, "validate_and_compile")
    @patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
    @patch("superset.mcp_service.auth.check_chart_data_access", new_callable=Mock)
    @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
    @patch("superset.db.session")
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.validate_against_dataset",
        new=Mock(return_value=(True, None)),
    )
    @patch(
        "superset.mcp_service.chart.validation.dataset_validator"
        ".DatasetValidator.normalize_column_names",
    )
    @pytest.mark.asyncio
    async def test_normalization_uses_request_dataset_id_when_rebinding(
        self,
        mock_normalize,
        mock_db_session,
        mock_find_by_id,
        mock_check_access,
        mock_create_preview,
        mock_validate,
        mcp_server,
    ):
        """When dataset_id is in the request, normalization must use it — not the
        chart's current datasource — so column names are resolved against the
        target schema after rebind."""
        self._prime_happy_path(
            mock_find_by_id,
            mock_check_access,
            mock_create_preview,
            mock_validate,
            dataset_id=99,
            dataset_name="new_ds",
        )
        mock_normalize.side_effect = lambda config, dataset_id: config

        payload = {"request": self._xy_update_request(dataset_id=99)}
        async with Client(mcp) as client:
            await client.call_tool("update_chart", payload)

        mock_normalize.assert_called_once()
        _, call_dataset_id = mock_normalize.call_args.args
        # Must use the request's dataset_id (99), not the chart's current one (10)
        assert call_dataset_id == 99
# ---------------------------------------------------------------------------
# Custom SQL metrics (sql_expression) — Ticket #3, update_chart side.
# ---------------------------------------------------------------------------

View File

@@ -30,6 +30,7 @@ import pytest
from superset.mcp_service.chart.schemas import (
ColumnRef,
FilterConfig,
PivotTableChartConfig,
TableChartConfig,
XYChartConfig,
)
@@ -58,13 +59,13 @@ def mock_dataset_context() -> DatasetContext:
class TestGetCanonicalColumnName:
"""Test _get_canonical_column_name static method."""
"""Test get_canonical_column_name static method."""
def test_exact_match_returns_same_name(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that exact match returns the same column name."""
result = DatasetValidator._get_canonical_column_name(
result = DatasetValidator.get_canonical_column_name(
"OrderDate", mock_dataset_context
)
assert result == "OrderDate"
@@ -73,7 +74,7 @@ class TestGetCanonicalColumnName:
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that lowercase input returns the canonical (dataset) column name."""
result = DatasetValidator._get_canonical_column_name(
result = DatasetValidator.get_canonical_column_name(
"orderdate", mock_dataset_context
)
assert result == "OrderDate"
@@ -84,7 +85,7 @@ class TestGetCanonicalColumnName:
"""Test that snake_case input returns the canonical column name."""
# 'order_date' won't match 'OrderDate' directly, but would match if
# the dataset had 'order_date'. This test verifies case-insensitive matching.
result = DatasetValidator._get_canonical_column_name(
result = DatasetValidator.get_canonical_column_name(
"productline", mock_dataset_context
)
assert result == "ProductLine"
@@ -93,7 +94,7 @@ class TestGetCanonicalColumnName:
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that uppercase input returns the canonical column name."""
result = DatasetValidator._get_canonical_column_name(
result = DatasetValidator.get_canonical_column_name(
"SALES", mock_dataset_context
)
assert result == "Sales"
@@ -102,7 +103,7 @@ class TestGetCanonicalColumnName:
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that metric names are also normalized."""
result = DatasetValidator._get_canonical_column_name(
result = DatasetValidator.get_canonical_column_name(
"totalrevenue", mock_dataset_context
)
assert result == "TotalRevenue"
@@ -111,91 +112,14 @@ class TestGetCanonicalColumnName:
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that unknown columns return the original name."""
result = DatasetValidator._get_canonical_column_name(
result = DatasetValidator.get_canonical_column_name(
"unknown_column", mock_dataset_context
)
assert result == "unknown_column"
class TestNormalizeXYConfig:
"""Test _normalize_xy_config static method."""
def test_normalize_x_axis_column(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that x-axis column name is normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "orderdate"},
"y": [{"name": "Sales", "aggregate": "SUM"}],
"kind": "line",
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["x"]["name"] == "OrderDate"
def test_normalize_y_axis_columns(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that y-axis column names are normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "OrderDate"},
"y": [
{"name": "sales", "aggregate": "SUM"},
{"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
],
"kind": "bar",
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["y"][0]["name"] == "Sales"
assert config_dict["y"][1]["name"] == "quantity_ordered"
def test_normalize_group_by_column(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that group_by column name is normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "OrderDate"},
"y": [{"name": "Sales", "aggregate": "SUM"}],
"kind": "line",
"group_by": [{"name": "productline"}],
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["group_by"][0]["name"] == "ProductLine"
class TestNormalizeTableConfig:
"""Test _normalize_table_config static method."""
def test_normalize_table_columns(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that table column names are normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "table",
"columns": [
{"name": "orderdate"},
{"name": "PRODUCTLINE"},
{"name": "sales", "aggregate": "SUM"},
],
}
DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
assert config_dict["columns"][0]["name"] == "OrderDate"
assert config_dict["columns"][1]["name"] == "ProductLine"
assert config_dict["columns"][2]["name"] == "Sales"
class TestNormalizeFilters:
"""Test _normalize_filters static method."""
"""Test normalize_filters static method."""
def test_normalize_filter_columns(
self, mock_dataset_context: DatasetContext
@@ -208,7 +132,7 @@ class TestNormalizeFilters:
],
}
DatasetValidator._normalize_filters(config_dict, mock_dataset_context)
DatasetValidator.normalize_filters(config_dict, mock_dataset_context)
assert config_dict["filters"][0]["column"] == "ProductLine"
assert config_dict["filters"][1]["column"] == "OrderDate"
@@ -264,6 +188,58 @@ class TestNormalizeColumnNames:
assert normalized.columns[1].name == "ProductLine"
assert normalized.columns[2].name == "Sales"
@patch.object(DatasetValidator, "_get_dataset_context")
def test_table_sql_expression_column_skips_name_normalization(
    self, mock_get_context, mock_dataset_context: DatasetContext
) -> None:
    """Check that a table column given only as sql_expression is left alone.

    Such a ColumnRef is built without a name, so normalize_column_names must
    pass it through with name=None while still canonicalizing the named
    column that sits next to it.
    """
    mock_get_context.return_value = mock_dataset_context

    named_ref = ColumnRef(name="orderdate")
    expression_ref = ColumnRef(
        sql_expression="SUM(sales)/COUNT(*)", label="Avg Sale"
    )
    table_config = TableChartConfig(
        chart_type="table",
        columns=[named_ref, expression_ref],
    )

    # The call itself is part of the check: it has to finish without raising
    # even though one of the columns has no name to look up.
    result = DatasetValidator.normalize_column_names(table_config, dataset_id=18)

    named_out = result.columns[0]
    expression_out = result.columns[1]
    assert named_out.name == "OrderDate"
    assert expression_out.sql_expression == "SUM(sales)/COUNT(*)"
    assert expression_out.name is None
@patch.object(DatasetValidator, "_get_dataset_context")
def test_handlebars_sql_expression_metric_skips_name_normalization(
    self, mock_get_context, mock_dataset_context: DatasetContext
) -> None:
    """Check that a handlebars metric given as sql_expression does not crash.

    HandlebarsChartConfig rejects sql_expression on columns/groupby but
    allows it on metrics, so metrics are the live path on which a reference
    with name=None can reach normalization.
    """
    from superset.mcp_service.chart.schemas import HandlebarsChartConfig

    mock_get_context.return_value = mock_dataset_context

    groupby_refs = [ColumnRef(name="orderdate")]
    metric_refs = [
        ColumnRef(name="sales", aggregate="SUM"),
        ColumnRef(sql_expression="COUNT(DISTINCT id)", label="Unique IDs"),
    ]
    handlebars_config = HandlebarsChartConfig(
        chart_type="handlebars",
        handlebars_template="{{col}}",
        query_mode="aggregate",
        groupby=groupby_refs,
        metrics=metric_refs,
    )

    result = DatasetValidator.normalize_column_names(
        handlebars_config, dataset_id=18
    )

    # The named groupby column is still canonicalized.
    assert result.groupby is not None
    assert result.groupby[0].name == "OrderDate"

    # The expression-only metric (second entry) comes back untouched.
    assert result.metrics is not None
    expression_metric = result.metrics[1]
    assert expression_metric.sql_expression == "COUNT(DISTINCT id)"
    assert expression_metric.name is None
@patch.object(DatasetValidator, "_get_dataset_context")
def test_returns_original_when_dataset_not_found(self, mock_get_context) -> None:
"""Test that original config is returned when dataset context is unavailable."""
@@ -742,3 +718,378 @@ class TestValidateSavedMetrics:
assert not is_valid
assert error is not None
assert error.error_code == "INVALID_SAVED_METRIC"
class TestGetCanonicalMetricName:
    """Exercise get_canonical_metric_name as a metrics-only lookup."""

    def test_exact_match(self, mock_dataset_context: DatasetContext) -> None:
        """A name already written in the metric's casing is returned as-is."""
        requested = "TotalRevenue"
        canonical = DatasetValidator.get_canonical_metric_name(
            requested, mock_dataset_context
        )
        assert canonical == "TotalRevenue"

    def test_case_insensitive_match(self, mock_dataset_context: DatasetContext) -> None:
        """A lowercase request resolves to the metric's own casing."""
        requested = "totalrevenue"
        canonical = DatasetValidator.get_canonical_metric_name(
            requested, mock_dataset_context
        )
        assert canonical == "TotalRevenue"

    def test_unknown_metric_returns_original(
        self, mock_dataset_context: DatasetContext
    ) -> None:
        """A name that matches no metric is handed back unchanged."""
        requested = "no_such_metric"
        canonical = DatasetValidator.get_canonical_metric_name(
            requested, mock_dataset_context
        )
        assert canonical == "no_such_metric"

    def test_column_name_not_matched(
        self, mock_dataset_context: DatasetContext
    ) -> None:
        """A name that matches a column but not a metric returns the original."""
        requested = "Sales"
        canonical = DatasetValidator.get_canonical_metric_name(
            requested, mock_dataset_context
        )
        assert canonical == "Sales"
@pytest.fixture
def collision_dataset_context() -> DatasetContext:
    """Provide a dataset whose only column and only metric differ just by case.

    The column is ``totalrevenue`` and the metric is ``TotalRevenue``: the
    same name case-insensitively, with different casing. This is the
    scenario that exposed the saved-metric bug.
    """
    lowercase_column = {
        "name": "totalrevenue",
        "type": "DECIMAL",
        "is_numeric": True,
    }
    mixed_case_metric = {
        "name": "TotalRevenue",
        "expression": "SUM(amount)",
        "description": None,
    }
    return DatasetContext(
        id=99,
        table_name="sales_data",
        schema="public",
        database_name="examples",
        available_columns=[lowercase_column],
        available_metrics=[mixed_case_metric],
    )
class TestSavedMetricNormalizationCorrectness:
    """Saved metrics must resolve against available_metrics, not available_columns.

    When a column and a metric share the same case-insensitive name but have
    different casing, get_canonical_column_name (columns-first) returns the
    column's casing. For saved_metric=True refs this is wrong — downstream
    metric resolution is exact-name based and expects the metric's casing.

    Every test patches DatasetValidator._get_dataset_context so that
    normalize_column_names works against an in-memory DatasetContext.
    """

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_xy_saved_metric_uses_metric_casing(
        self,
        mock_get_context: Any,
        collision_dataset_context: DatasetContext,
    ) -> None:
        """XY chart: x keeps the column casing, the saved-metric y gets metric casing."""
        mock_get_context.return_value = collision_dataset_context
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="totalrevenue"),
            y=[ColumnRef(name="totalrevenue", saved_metric=True)],
        )

        normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)

        # x is a regular column ref — gets column casing
        assert normalized.x is not None
        assert normalized.x.name == "totalrevenue"
        # y is a saved metric — must get metric casing, not column casing
        assert normalized.y[0].name == "TotalRevenue"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_table_saved_metric_uses_metric_casing(
        self,
        mock_get_context: Any,
        collision_dataset_context: DatasetContext,
    ) -> None:
        """Table chart: plain and saved-metric refs with the same name diverge."""
        # TableChartConfig comes from the module-level import; no local
        # re-import is needed here.
        mock_get_context.return_value = collision_dataset_context
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="totalrevenue"),
                ColumnRef(name="totalrevenue", saved_metric=True),
            ],
        )

        normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)

        assert normalized.columns[0].name == "totalrevenue"
        assert normalized.columns[1].name == "TotalRevenue"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_pie_saved_metric_uses_metric_casing(
        self,
        mock_get_context: Any,
        collision_dataset_context: DatasetContext,
    ) -> None:
        """Pie chart: dimension keeps column casing, saved metric gets metric casing."""
        from superset.mcp_service.chart.schemas import PieChartConfig

        mock_get_context.return_value = collision_dataset_context
        config = PieChartConfig(
            chart_type="pie",
            dimension=ColumnRef(name="totalrevenue"),
            metric=ColumnRef(name="totalrevenue", saved_metric=True),
        )

        normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)

        assert normalized.dimension.name == "totalrevenue"
        assert normalized.metric.name == "TotalRevenue"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_big_number_saved_metric_uses_metric_casing(
        self,
        mock_get_context: Any,
        collision_dataset_context: DatasetContext,
    ) -> None:
        """Big-number chart: the single saved metric gets metric casing."""
        from superset.mcp_service.chart.schemas import BigNumberChartConfig

        mock_get_context.return_value = collision_dataset_context
        config = BigNumberChartConfig(
            chart_type="big_number",
            metric=ColumnRef(name="totalrevenue", saved_metric=True),
        )

        normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)

        assert normalized.metric.name == "TotalRevenue"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_mixed_timeseries_saved_metrics_use_metric_casing(
        self,
        mock_get_context: Any,
    ) -> None:
        """Mixed timeseries: saved metrics on both axes get metric casing."""
        from superset.mcp_service.chart.schemas import MixedTimeseriesChartConfig

        # This test needs a temporal column and a second metric, so it builds
        # its own context instead of requesting collision_dataset_context.
        context = DatasetContext(
            id=99,
            table_name="sales_data",
            schema="public",
            database_name="examples",
            available_columns=[
                {"name": "ds", "type": "TIMESTAMP", "is_temporal": True},
                {"name": "totalrevenue", "type": "DECIMAL", "is_numeric": True},
            ],
            available_metrics=[
                {
                    "name": "TotalRevenue",
                    "expression": "SUM(amount)",
                    "description": None,
                },
                {
                    "name": "OrderCount",
                    "expression": "COUNT(*)",
                    "description": None,
                },
            ],
        )
        mock_get_context.return_value = context
        config = MixedTimeseriesChartConfig(
            chart_type="mixed_timeseries",
            x=ColumnRef(name="ds"),
            y=[ColumnRef(name="totalrevenue", saved_metric=True)],
            y_secondary=[ColumnRef(name="ordercount", saved_metric=True)],
        )

        normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)

        assert normalized.y[0].name == "TotalRevenue"
        assert normalized.y_secondary[0].name == "OrderCount"
class TestPreValidateAliasHandling:
    """pre_validate must accept schema field aliases, not just canonical names."""

    @staticmethod
    def _require_plugin(chart_type: str) -> Any:
        """Look up the registered plugin for *chart_type* and assert it exists."""
        # Imported at call time, exactly as each test did originally.
        from superset.mcp_service.chart.registry import get_registry

        plugin = get_registry().get(chart_type)
        assert plugin is not None
        return plugin

    def test_xy_pre_validate_accepts_metrics_alias(self) -> None:
        """An xy config that supplies ``metrics`` must pass pre_validate."""
        xy_plugin = self._require_plugin("xy")
        aliased_config = {
            "chart_type": "xy",
            "metrics": [{"name": "revenue", "aggregate": "SUM"}],
        }

        error = xy_plugin.pre_validate(aliased_config)

        assert error is None, f"pre_validate rejected 'metrics' alias: {error}"

    def test_mixed_timeseries_pre_validate_accepts_x_axis_alias(self) -> None:
        """A mixed_timeseries config using x_axis/metrics/metrics_b must pass."""
        mixed_plugin = self._require_plugin("mixed_timeseries")
        aliased_config = {
            "chart_type": "mixed_timeseries",
            "x_axis": {"name": "ds"},
            "metrics": [{"name": "revenue", "aggregate": "SUM"}],
            "metrics_b": [{"name": "orders", "aggregate": "COUNT"}],
        }

        error = mixed_plugin.pre_validate(aliased_config)

        assert error is None, f"pre_validate rejected aliases: {error}"

    def test_mixed_timeseries_pre_validate_still_rejects_truly_missing(self) -> None:
        """A mixed_timeseries config with no secondary metrics must be rejected."""
        mixed_plugin = self._require_plugin("mixed_timeseries")
        incomplete_config = {
            "chart_type": "mixed_timeseries",
            "x": {"name": "ds"},
            "y": [{"name": "revenue", "aggregate": "SUM"}],
        }

        error = mixed_plugin.pre_validate(incomplete_config)

        assert error is not None
class TestNormalizePivotTableColumnRefs:
    """Test normalize_column_refs for PivotTableChartPlugin.

    Covers rows, metrics, columns, and filters — the four field groups that
    PivotTableChartPlugin.normalize_column_refs() processes. Each test goes
    through DatasetValidator.normalize_column_names with
    _get_dataset_context patched to return a fixture context.
    """

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_normalize_rows_case_mismatch(
        self, mock_get_context, mock_dataset_context: DatasetContext
    ) -> None:
        """Rows with wrong case are normalized to the canonical dataset column name."""
        mock_get_context.return_value = mock_dataset_context
        pivot_config = PivotTableChartConfig(
            chart_type="pivot_table",
            rows=[ColumnRef(name="productline")],
            metrics=[ColumnRef(name="sales", aggregate="SUM")],
        )

        result = DatasetValidator.normalize_column_names(pivot_config, dataset_id=18)

        first_row = result.rows[0]
        first_metric = result.metrics[0]
        assert first_row.name == "ProductLine"
        assert first_metric.name == "Sales"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_normalize_columns_case_mismatch(
        self, mock_get_context, mock_dataset_context: DatasetContext
    ) -> None:
        """Optional column-grouping field is normalized when present."""
        mock_get_context.return_value = mock_dataset_context
        pivot_config = PivotTableChartConfig(
            chart_type="pivot_table",
            rows=[ColumnRef(name="ProductLine")],
            metrics=[ColumnRef(name="Sales", aggregate="SUM")],
            columns=[ColumnRef(name="PRODUCTLINE")],
        )

        result = DatasetValidator.normalize_column_names(pivot_config, dataset_id=18)

        assert result.columns is not None
        assert result.columns[0].name == "ProductLine"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_normalize_filters_alongside_rows(
        self, mock_get_context, mock_dataset_context: DatasetContext
    ) -> None:
        """Filters are normalized together with rows and metrics."""
        mock_get_context.return_value = mock_dataset_context
        date_filter = FilterConfig(column="orderdate", op=">", value="2023-01-01")
        pivot_config = PivotTableChartConfig(
            chart_type="pivot_table",
            rows=[ColumnRef(name="PRODUCTLINE")],
            metrics=[ColumnRef(name="QUANTITY_ORDERED", aggregate="SUM")],
            filters=[date_filter],
        )

        result = DatasetValidator.normalize_column_names(pivot_config, dataset_id=18)

        assert result.rows[0].name == "ProductLine"
        assert result.metrics[0].name == "quantity_ordered"
        assert result.filters is not None
        assert result.filters[0].column == "OrderDate"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_normalize_saved_metric_uses_canonical_metric_name(
        self, mock_get_context, mock_dataset_context: DatasetContext
    ) -> None:
        """A saved_metric=True entry in metrics uses get_canonical_metric_name."""
        mock_get_context.return_value = mock_dataset_context
        saved_metric_ref = ColumnRef(name="totalrevenue", saved_metric=True)
        pivot_config = PivotTableChartConfig(
            chart_type="pivot_table",
            rows=[ColumnRef(name="ProductLine")],
            metrics=[saved_metric_ref],
        )

        result = DatasetValidator.normalize_column_names(pivot_config, dataset_id=18)

        assert result.metrics[0].name == "TotalRevenue"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_normalize_sql_expression_metric_skipped(
        self, mock_get_context, mock_dataset_context: DatasetContext
    ) -> None:
        """sql_expression metrics are skipped — no AttributeError on name=None."""
        mock_get_context.return_value = mock_dataset_context
        expression_metric = ColumnRef(
            name=None,
            sql_expression="SUM(Sales * 1.1)",
            label="Adjusted Sales",
        )
        adhoc_metric = ColumnRef(name="sales", aggregate="AVG")
        pivot_config = PivotTableChartConfig(
            chart_type="pivot_table",
            rows=[ColumnRef(name="ProductLine")],
            metrics=[expression_metric, adhoc_metric],
        )

        result = DatasetValidator.normalize_column_names(pivot_config, dataset_id=18)

        # sql_expression metric is passed through unchanged (name stays None)
        expression_out = result.metrics[0]
        assert expression_out.name is None
        assert expression_out.sql_expression == "SUM(Sales * 1.1)"
        # ad-hoc metric is normalized
        assert result.metrics[1].name == "Sales"

    @patch.object(DatasetValidator, "_get_dataset_context")
    def test_normalize_multiple_rows_and_metrics(
        self, mock_get_context, uppercase_dataset_context: DatasetContext
    ) -> None:
        """Multiple rows and metrics entries are all normalized."""
        # Uses the context supplied by the uppercase_dataset_context fixture.
        mock_get_context.return_value = uppercase_dataset_context
        row_refs = [
            ColumnRef(name="airline"),
            ColumnRef(name="distance", aggregate="AVG"),
        ]
        metric_refs = [
            ColumnRef(name="departure_delay", aggregate="AVG"),
            ColumnRef(name="arrival_delay", aggregate="SUM"),
        ]
        pivot_config = PivotTableChartConfig(
            chart_type="pivot_table",
            rows=row_refs,
            metrics=metric_refs,
        )

        result = DatasetValidator.normalize_column_names(pivot_config, dataset_id=24)

        assert result.rows[0].name == "AIRLINE"
        assert result.rows[1].name == "DISTANCE"
        assert result.metrics[0].name == "DEPARTURE_DELAY"
        assert result.metrics[1].name == "ARRIVAL_DELAY"

View File

@@ -58,14 +58,15 @@ class TestRuntimeValidatorNonBlocking:
x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch
)
# Mock the format validator to return warnings
# Mock the plugin runtime dispatcher to return format warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format:
mock_format.return_value = [
"Currency format '$,.2f' may not display dates correctly"
]
"_validate_plugin_runtime"
) as mock_plugin:
mock_plugin.return_value = (
["Currency format '$,.2f' may not display dates correctly"],
[],
)
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -87,14 +88,14 @@ class TestRuntimeValidatorNonBlocking:
kind="bar",
)
# Mock the cardinality validator to return warnings
# Mock the plugin runtime dispatcher to return cardinality warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality:
mock_cardinality.return_value = (
"_validate_plugin_runtime"
) as mock_plugin:
mock_plugin.return_value = (
["High cardinality detected: 10000+ unique values"],
["Consider using aggregation or filtering"],
[],
)
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
@@ -148,25 +149,20 @@ class TestRuntimeValidatorNonBlocking:
x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id
)
# Mock all validators to return warnings
# Mock plugin runtime and chart type validators to return warnings
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality,
"_validate_plugin_runtime"
) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_type,
):
mock_format.return_value = ["Format mismatch warning"]
mock_cardinality.return_value = (
["High cardinality warning"],
["Cardinality suggestion"],
mock_plugin.return_value = (
["Format mismatch warning", "High cardinality warning"],
[],
)
mock_type.return_value = (
["Chart type warning"],
@@ -197,13 +193,13 @@ class TestRuntimeValidatorNonBlocking:
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
"_validate_plugin_runtime"
) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.logger"
) as mock_logger,
):
mock_format.return_value = ["Test warning message"]
mock_plugin.return_value = (["Test warning message"], [])
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -217,7 +213,7 @@ class TestRuntimeValidatorNonBlocking:
assert "warnings" in warnings_metadata
def test_validate_table_chart_skips_xy_validations(self):
"""Test that table charts skip XY-specific validations."""
"""Test that table charts produce no XY-specific runtime warnings."""
config = TableChartConfig(
chart_type="table",
columns=[
@@ -226,37 +222,29 @@ class TestRuntimeValidatorNonBlocking:
],
)
# These should not be called for table charts
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_chart_type,
):
# Mock chart type validator to return no warnings
# Plugin runtime dispatches to TableChartPlugin which returns no warnings.
# Chart type suggester is also stubbed to return no warnings.
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_chart_type:
mock_chart_type.return_value = ([], [])
is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1)
# Format and cardinality validation should not be called for table charts
mock_format.assert_not_called()
mock_cardinality.assert_not_called()
assert is_valid is True
assert error is None
def test_validate_cardinality_returns_cleanly_when_x_name_is_none(self) -> None:
"""The dimension-rejection guard on XYChartConfig normally forbids
x.name=None, but a model_construct bypass (or a future code path)
could land us here. The defensive guard must return cleanly without
calling into CardinalityValidator (which assumes a real column)."""
could land us here. The defensive guard in XYChartPlugin.get_runtime_warnings
must skip cardinality without crashing."""
from superset.mcp_service.chart.plugins.xy import XYChartPlugin
from superset.mcp_service.chart.validation.runtime.format_validator import (
FormatTypeValidator,
)
col = ColumnRef.model_construct(name=None)
config = XYChartConfig.model_construct(
chart_type="xy",
@@ -265,13 +253,19 @@ class TestRuntimeValidatorNonBlocking:
kind="line",
)
with patch(
"superset.mcp_service.chart.validation.runtime."
"cardinality_validator.CardinalityValidator.check_cardinality"
) as mock_check:
warnings, suggestions = RuntimeValidator._validate_cardinality(
config, dataset_id=1
)
plugin = XYChartPlugin()
with (
patch.object(
FormatTypeValidator,
"validate_format_compatibility",
return_value=(True, []),
),
patch(
"superset.mcp_service.chart.validation.runtime."
"cardinality_validator.CardinalityValidator.check_cardinality"
) as mock_check,
):
warnings, suggestions = plugin.get_runtime_warnings(config, dataset_id=1)
assert warnings == []
assert suggestions == []

View File

@@ -192,8 +192,8 @@ def _make_datasource_with_real_rls(dataset_id: int) -> MagicMock:
datasource = _make_datasource(dataset_id)
# Bind real BaseDatasource method so RLS logic executes against mocked
# security_manager rather than returning MagicMock auto-stub
datasource.get_sqla_row_level_filters = lambda **kwargs: (
BaseDatasource.get_sqla_row_level_filters(datasource, **kwargs)
datasource.get_sqla_row_level_filters = (
lambda **kwargs: BaseDatasource.get_sqla_row_level_filters(datasource, **kwargs)
)
return datasource

View File

@@ -97,8 +97,8 @@ def test_cache_query_by_user_flag_yields_distinct_keys(feature_flag_mock):
"""
Global ``CACHE_QUERY_BY_USER`` flag also reaches the legacy viz path.
"""
feature_flag_mock.is_feature_enabled.side_effect = lambda feature=None: (
feature == "CACHE_QUERY_BY_USER"
feature_flag_mock.is_feature_enabled.side_effect = (
lambda feature=None: feature == "CACHE_QUERY_BY_USER"
)
database = Database(database_name="d", sqlalchemy_uri="sqlite://")
obj = _viz_for(database)

View File

@@ -693,9 +693,8 @@ def test_get_user_agent(mocker: MockerFixture, app_context: None) -> None:
@with_config(
{
"USER_AGENT_FUNC": lambda database, source: (
f"{database.database_name} {source.name}"
)
"USER_AGENT_FUNC": lambda database,
source: f"{database.database_name} {source.name}"
}
)
def test_get_user_agent_custom(mocker: MockerFixture, app_context: None) -> None: