Compare commits

35 Commits

Author SHA1 Message Date
Amin Ghadersohi
44aaa2cde5 chore(mcp): drop spurious executable bits on three Python files 2026-06-11 00:19:12 +00:00
Amin Ghadersohi
df5c242248 chore(mcp): restore out-of-scope files to master state
Earlier style commits reformatted 24 unrelated files with a newer local
ruff than the pinned 0.9.7, removed the PT004 ignore from
pyproject.toml, and the initial branch commit accidentally deleted
tests/unit_tests/utils/test_split.py. None of these belong to the chart
plugin registry change; restore them to master to keep the PR scoped.
2026-06-11 00:18:25 +00:00
Amin Ghadersohi
be3395614b feat(mcp): warn on native_viz_types collision at plugin registration
If two plugins claim the same viz_type, display_name_for_viz_type()
silently resolves to the iteration-order winner. Surface a warning at
register() time so plugin authors catch the shadowing immediately.
2026-06-10 23:16:36 +00:00
Amin Ghadersohi
71d696d1dd fix(mcp): restore master's API-key auth factory lost in rebase conflict
The rebase of the self-review commit left conflict markers in
mcp_config.py; resolve in favor of master's get_mcp_api_key_enabled
based factory.
2026-06-10 23:16:18 +00:00
Amin Ghadersohi
62f4df3ec0 fix(mcp): address self-review findings — comments, dedup, modern types
- schema_validator.py: add circular-import comment to both local registry
  imports (H1); extract valid_types before the conditional so all_types()
  is called once instead of in each error branch (N1)
- plugin.py: expand BaseChartPlugin docstring to list all default method
  behaviours including schema_error_hint (N3); add comment warning that
  native_viz_types is a class-level shared dict — subclasses must override
  as a class attribute, not mutate in place (M1)
- registry.py: expand _reset_for_testing() docstring with explicit warning
  that direct global assignment is not reverted by pytest monkeypatch —
  callers must restore state in teardown (M2)
- mcp_config.py: replace Dict/Optional from typing with dict/X|None modern
  syntax; remove now-unused Optional and Dict imports (N2)
- initialization/__init__.py: add docstring to configure_mcp_chart_registry()
  explaining the known two-call pattern in MCP-standalone startup and why
  the stale-config window between the two calls is benign in practice (H2)
2026-06-10 23:06:04 +00:00
Amin Ghadersohi
e7078d42b3 fix(mcp): fix import sort order in app.py (ruff I001)
Move bare `import superset.mcp_service.chart.plugins` before the `from`
imports per isort conventions; CI was failing with I001 (unsorted-imports).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:05:42 +00:00
Amin Ghadersohi
84496e89cb fix(mcp): extend pre_validate empty-list guard to xy, mixed_timeseries, table plugins
Presence checks like `"y" not in config` pass silently when `y=[]` is submitted,
deferring to Pydantic's min_length error instead of the friendlier ChartGenerationError.
Switch to falsy checks (`not config.get("y")`) to catch both missing keys and empty
lists in the same early guard — matching the pattern already applied to pivot_table.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
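The difference between the two guard styles in 84496e89cb is easy to see on a plain dict. The helper names below are hypothetical and exist only for illustration; the commit itself inlines the checks in each plugin's `pre_validate`.

```python
def missing_by_presence(config: dict, key: str) -> bool:
    """Presence check: only an absent key counts as missing."""
    return key not in config


def missing_by_falsiness(config: dict, key: str) -> bool:
    """Falsy check: an absent key, None, or an empty list all count."""
    return not config.get(key)


empty_y = {"chart_type": "xy", "y": []}
no_y = {"chart_type": "xy"}

# The empty list slips past the presence check but not the falsy check.
print(missing_by_presence(empty_y, "y"), missing_by_falsiness(empty_y, "y"))
# A genuinely absent key is caught by both.
print(missing_by_presence(no_y, "y"), missing_by_falsiness(no_y, "y"))
```

One trade-off of the falsy form is that it also rejects other falsy values such as `0` or `""`, which is harmless for fields that are expected to hold non-empty lists.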
Amin Ghadersohi
d78f438e1a fix(mcp): fix CI failures — sql_expression handling, cardinality guard, test mock fixes
- big_number.py pre_validate: add sql_expression branch; return
  MISSING_SQL_METRIC_LABEL when label is absent/non-string, so the
  existing unit tests (and LLM callers) get a clear actionable error.
- xy.py normalize_column_refs: skip entries with sql_expression set
  (name is None for these metrics); previously crashed with
  AttributeError: 'NoneType'.lower() in _get_canonical_column_name.
- test_big_number_chart.py: replace three calls to deleted
  SchemaValidator._pre_validate_big_number_config with
  plugin.pre_validate() via get_registry().
- test_runtime_validator.py: replace call to deleted
  RuntimeValidator._validate_cardinality with XYChartPlugin.get_runtime_warnings;
  patch FormatTypeValidator to isolate cardinality guard.
- test_update_chart.py: set mock_create_preview.return_value to a
  3-tuple so the update_chart unpack doesn't crash; change RuntimeError
  to ValueError which is in NORMALIZATION_EXCEPTIONS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
2556b91002 fix(mcp): address copilot + bito review comments
- schema_validator: collapse is_enabled()+get() double lookup into a
  single get() call so operator-supplied enabled_func is invoked once
- update_chart: use guarded chart_datasource_id local var instead of
  re-accessing chart.datasource_id after the None check
- chart_utils: propagate post_map_validate() details+suggestions into
  the raised ValueError so callers log actionable context
- schemas: clarify chart_type_display_name description — prefer over
  viz_type when present, fall back to viz_type when null
- schemas: add or-empty-string fallback to dedup key labels to satisfy
  mypy (dict is typed dict[tuple[bool, str], str])
- plugins/xy: guard config.x.name is not None before cardinality check
- runtime/__init__, plugins/xy, registry: add noqa BLE001 to intentional
  broad exception catches with inline rationale comments
- tests: add TestUpdateChartColumnNormalization covering normalization
  called with guarded ID, graceful exception handling, and skip-when-null

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
ece185bbdc fix(mcp): address bito review — pre_validate empty-list and chart_type=None error
- pivot_table.pre_validate: `not config.get("rows/metrics")` catches both
  missing keys and empty lists, matching PivotTableChartConfig min_length=1
- chart_utils.map_config_to_form_data: omit `(chart_type=None)` suffix from
  ValueError when chart_type is None to avoid misleading error messages
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
8b9a0ea0c0 fix(mcp): harden chart registry review fixes 2026-06-10 23:04:53 +00:00
Amin Ghadersohi
3cdd00fb0a fix(mcp): address fitzee review — sanitize temporal_column, narrow BLE001, add _reset_for_testing(), remove exec bits
schemas.py:
- Add @field_validator('temporal_column') on BigNumberChartConfig — the regex removal
  in d883b622 left this field with only min/max_length guards; ColumnRef.name and
  FilterConfig.column already used sanitize_user_input (check_sql_keywords=True) and
  sanitize_user_input respectively, but BigNumberChartConfig.temporal_column was missed.
  PR #39915 (relaxing the original regex) was closed; this PR covers the same intent
  by relying on sanitize_user_input validators instead.
- serialize_chart_object: split the broad except Exception (BLE001) into ImportError
  (for import-failure path) + Exception (for third-party plugin errors) so the scope
  of each catch is explicit.

registry.py:
- Add _reset_for_testing() that resets _REGISTRY, _plugins_loaded, _plugins_load_failed,
  and _filter_config — gives tests a single clean-slate function instead of four
  separate monkeypatches.
- Move _RegistryProxy instantiation to module level (_PROXY); get_registry() returns
  the singleton instead of allocating a new object on every call.

file modes:
- Remove executable bits (100755 → 100644) from 9 files: plugin.py, all 7 plugin
  files, registry.py, and initialization/__init__.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
a376b38ac9 style: fix E501 in test_registry docstring (93 > 88 chars)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
74ff96902c fix(mcp): lock register() writes and add circuit breaker to _ensure_plugins_loaded
- register() now holds _plugins_lock when writing to _REGISTRY, preventing
  concurrent write races if plugins are registered outside the bootstrap path
- _ensure_plugins_loaded() now sets _plugins_load_failed=True on ImportError
  so subsequent lookups return None immediately instead of retrying the import
  on every call
- _isolated_registry fixture in test_registry.py resets _plugins_load_failed
- Two new tests cover the circuit-breaker skip path and the failure-flag path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
3dbfbbdefa style: fix E501 noqa placement and PT001 in export_test.py
noqa: E501 comments were on the closing-paren line instead of on the
actual long string lines, so ruff did not suppress the violations.
Add # noqa: PT001 on the @pytest.fixture decorator to pin the
no-parentheses style (ruff 0.9.7 default) and prevent ruff 0.5.x
from auto-converting it in either direction.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
e0149f38ee style: fix E501 noqa placement and PT001 in export_test.py
noqa: E501 comments were on the closing-paren line instead of on the
actual long string lines, so ruff did not suppress the violations.
Also applied ruff auto-fix for PT001 (@pytest.fixture -> @pytest.fixture()).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
f86abee451 perf(mcp): remove redundant DatasetValidator call in update_chart
validate_and_compile already runs Tier 1 (validate_against_dataset)
with the same fuzzy-match suggestions in CompileResult.error_obj.
The explicit pre-call fetched dataset context a second time via a
separate DB query, producing duplicate work on every update request.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
e63d309dca style: remove obsolete PT004 ruff rule (dropped in ruff 0.9.7)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
7459b5ab40 style: ruff-format auto-format fix 2026-06-10 23:04:53 +00:00
Amin Ghadersohi
eb78d4a405 fix(mcp): address Bito review — log bare exception in schemas, remove redundant annotation quotes
- schemas.py: CWE-390 bare except → add `as exc` + debug log so display-name
  lookup failures are observable (BLE001 suppressed: intentional fallback)
- plugin.py: remove redundant forward-reference quotes from schema_error_hint
  return type (from __future__ import annotations already makes all annotations
  lazy strings)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
80c320b971 fix(mcp): use (saved_metric, label) dedup key in XYChartConfig
A saved metric and a regular column with the same input name resolve
to different display labels after normalization (saved metrics use the
dataset's actual casing).  Using a plain string key incorrectly flags
them as duplicates; keying on (saved_metric, label) avoids the false
collision.

Fixes test_xy_saved_metric_uses_metric_casing.
2026-06-10 23:04:53 +00:00
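The effect of the composite key can be shown with a small duplicate finder. This is a sketch under assumptions: the refs are plain dicts rather than the PR's Pydantic models, and `find_duplicate_labels` is a hypothetical helper, not a function from `schemas.py`.

```python
def find_duplicate_labels(refs: list) -> list:
    """Return labels that repeat within the same (saved_metric, label) bucket."""
    seen = set()
    duplicates = []
    for ref in refs:
        # Fall back to "" so the key is always a (bool, str) tuple.
        key = (bool(ref.get("saved_metric")), ref.get("label") or "")
        if key in seen:
            duplicates.append(key[1])
        seen.add(key)
    return duplicates


mixed = [
    {"label": "revenue", "saved_metric": True},
    {"label": "revenue", "saved_metric": False},
]
repeated = [
    {"label": "revenue", "saved_metric": False},
    {"label": "revenue", "saved_metric": False},
]
print(find_duplicate_labels(mixed))
print(find_duplicate_labels(repeated))
```

With a plain string key both inputs would report `revenue` as a duplicate; with the tuple key only the second one does.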
Amin Ghadersohi
ec178b862c fix(mcp): fix saved-metric name normalization across all chart plugins
Add _get_canonical_metric_name() to DatasetValidator that searches only
available_metrics, preventing a column with matching case-insensitive name
from shadowing a saved metric's canonical casing.

Update all 7 chart plugins (xy, table, pie, big_number, handlebars,
mixed_timeseries, pivot_table) to branch on saved_metric flag: saved
metrics now go through _get_canonical_metric_name while regular column
refs continue to use _get_canonical_column_name.

Fix pre_validate alias handling in xy and mixed_timeseries plugins to
accept Pydantic AliasChoices keys (metrics/x_axis/metrics_b) so payloads
using canonical Superset field names are not incorrectly rejected.

Add TestGetCanonicalMetricName, TestSavedMetricNormalizationCorrectness,
and TestPreValidateAliasHandling test classes covering the collision case
and alias acceptance.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
Amin Ghadersohi
4e0eb3a395 feat(mcp): add runtime chart plugin enable/disable via _PluginFilterConfig
Introduces a dynamic filter layer in the chart type registry so operators can
disable individual plugins (e.g. `handlebars`) without a code deploy:

- `MCP_DISABLED_CHART_PLUGINS: frozenset[str]` — static deny-list in mcp_config.py
- `MCP_CHART_PLUGIN_ENABLED_FUNC: Callable[[str], bool] | None` — dynamic hook
  for Harness/Split/per-user targeting; takes precedence over the deny-list
- Both keys are propagated through `get_mcp_config()` defaults

registry.py changes:
- `_PluginFilterConfig` frozen dataclass replaces two bare globals so
  configure() replaces them atomically (no torn reads under concurrency)
- `configure(disabled, enabled_func)` — called at app init; accepts any
  iterable for `disabled`; validates `enabled_func` is callable
- `_is_plugin_enabled()` — reads config once, fails closed on callable exception
- `get()` / `all_types()` / `is_enabled()` apply the filter at lookup time;
  `is_registered()` and `display_name_for_viz_type()` intentionally bypass it
  so callers can distinguish "unknown" vs "disabled" and existing charts still
  resolve display names for disabled viz types

schema_validator.py: two-step pre-check — `is_registered()` for unknown types,
`is_enabled()` for disabled ones, with distinct `DISABLED_CHART_TYPE` error code.

Wiring:
- `SupersetAppInitializer.configure_mcp_chart_registry()` called after
  `configure_feature_flags()` in `init_app()`
- `flask_singleton.py` re-calls `registry.configure()` after the MCP config
  overlay so MCP-specific overrides in `superset_config.py` take effect in
  standalone MCP mode

Tests: 28 cases in test_registry_filters.py covering deny-list, callable hook,
fail-closed on exception, all_types() filtering, display_name bypass, atomic
reconfigure, and configure() with list/tuple/frozenset inputs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:53 +00:00
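The filter layer described above might look roughly like the following. This is a simplified sketch, not the registry's code: `FilterConfig`, `configure`, and `is_plugin_enabled` stand in for `_PluginFilterConfig`, `configure()`, and `_is_plugin_enabled()`, and the precedence rule (the hook fully overrides the deny-list when set) is an assumption read from the commit message.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Optional


@dataclass(frozen=True)
class FilterConfig:
    """Immutable snapshot of both settings, replaced in a single assignment."""

    disabled: frozenset = frozenset()
    enabled_func: Optional[Callable[[str], bool]] = None


_config = FilterConfig()


def configure(
    disabled: Iterable[str] = (),
    enabled_func: Optional[Callable[[str], bool]] = None,
) -> None:
    """Swap in a whole new config object so readers never see a half-update."""
    global _config
    if enabled_func is not None and not callable(enabled_func):
        raise TypeError("enabled_func must be callable or None")
    _config = FilterConfig(frozenset(disabled), enabled_func)


def is_plugin_enabled(chart_type: str) -> bool:
    """Consult the hook first; treat a raising hook as 'disabled'."""
    config = _config  # read once so both fields come from the same snapshot
    if config.enabled_func is not None:
        try:
            return bool(config.enabled_func(chart_type))
        except Exception:  # fail closed on any hook error
            return False
    return chart_type not in config.disabled
```

Because the dataclass is frozen and `configure()` rebinds one name, a concurrent reader sees either the old pair of settings or the new pair, never a mix of the two.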
Amin Ghadersohi
c1d4b454e4 fix(mcp): fix E501 in update_chart.py and update_chart test mocks for column validation
Split an 89-char comment line and an over-limit condition in update_chart.py
to satisfy the ruff E501 rule. Also applied ruff format.

Two TestUpdateChartValidationGate tests expected CHART_VALIDATION_FAILED but
received CHART_DATASET_NOT_FOUND because _validate_update_against_dataset calls
DatasetValidator.validate_against_dataset before validate_and_compile, and the
existing mocks provided a Mock() object for chart.datasource whose .id attribute
is an auto-generated MagicMock (not a real int). Added a patch for
DatasetValidator.validate_against_dataset returning (True, None) so the
column-validation tier is bypassed and the test reaches the mocked
validate_and_compile response as intended.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
760a973c76 refactor(mcp): address Codex review — fix registry bug, DRY schema hints, remove column regex
P1.1 registry.py: move _plugins_loaded=True to after successful import so a
failed load doesn't permanently poison the registry.

P1.3 schemas.py: remove overly restrictive ColumnRef.name / FilterClause.column
/ BigNumberChartConfig.temporal_column regex that blocked valid column names
containing parentheses, slashes, and other SQL-common characters.

P2.3 (DRY): eliminate _CHART_TYPE_ERROR_HINTS second-registry in
schema_validator.py by adding schema_error_hint() to ChartTypePlugin protocol,
BaseChartPlugin default, and all 7 plugin classes. SchemaValidator now delegates
to the plugin registry instead of maintaining a parallel dict.

P3.3 test_registry.py: add full registry unit-test coverage (register, get,
all_types, is_registered, display_name_for_viz_type, proxy methods, duplicate
warning, empty chart_type validation, insertion-order guarantee).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
5d41fe1d53 fix(mcp): add full column validation to update_chart
update_chart was only running SchemaValidator + Tier-2 compile check,
silently skipping DatasetValidator's column-existence + fuzzy-match
and column-name normalisation layers that generate_chart runs.

A typo like {name: "reveneu"} would save the broken chart and only
surface as a render-time failure in the browser.

Now matches generate_chart pipeline:
- Layer 2: DatasetValidator.validate_against_dataset() — column
  existence check with fuzzy-match "did you mean?" suggestions returned
  to the LLM before any DB write occurs
- Layer 4: DatasetValidator.normalize_column_names() — case
  normalisation so "order_date" resolves to "OrderDate" if that is the
  canonical dataset name

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:09 +00:00
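The two layers named above can be approximated with the standard library. This is only a sketch: `canonical_column_name` and `suggest_columns` are hypothetical helpers, `difflib` is a stand-in for whatever fuzzy matcher `DatasetValidator` actually uses, and the lookup here matches on letter case alone (so `orderdate` maps to `OrderDate`) without attempting anything beyond that.

```python
import difflib
from typing import Optional


def canonical_column_name(name: str, columns: list) -> Optional[str]:
    """Case-insensitive lookup that returns the dataset's own casing."""
    by_lower = {column.lower(): column for column in columns}
    return by_lower.get(name.lower())


def suggest_columns(name: str, columns: list, limit: int = 3) -> list:
    """Return close matches for a misspelled column name, best match first."""
    by_lower = {column.lower(): column for column in columns}
    matches = difflib.get_close_matches(
        name.lower(), list(by_lower), n=limit, cutoff=0.6
    )
    return [by_lower[match] for match in matches]


dataset_columns = ["revenue", "OrderDate"]
print(canonical_column_name("orderdate", dataset_columns))
print(suggest_columns("reveneu", dataset_columns))
```

Running the existence check and suggestions before any write is what lets a typo such as `reveneu` come back to the caller as a correctable error instead of a saved, broken chart.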
Amin Ghadersohi
fe24d8cdcd fix(mcp): add threading lock to registry plugin loader
_ensure_plugins_loaded() used an unprotected boolean flag, making it
unsafe under concurrent first-call scenarios (e.g. gunicorn multi-thread
workers). Double-checked locking with threading.Lock eliminates the race.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
20c179390c fix(mcp): resolve E402 and E501 in dataset_validator.py
- Move error_schemas import above _C TypeVar definition (E402)
- Split two over-length comment lines to ≤88 chars (E501, lines 268 and 380)
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
77dc099da7 fix(mcp): resolve ruff E501 and formatting issues to pass pre-commit
- Split long string literal in schema_validator.py line 202 (E501, 94 > 88 chars)
- Apply ruff format auto-fixes to big_number.py, handlebars.py, and test_get_chart_data.py
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
3fa84d8838 refactor(mcp): move all local imports to top level in chart type plugins
All per-method local imports in the 7 chart plugins were moved to module-level.
None of them create circular imports: schemas.py, chart_utils.py, and
dataset_validator.py are safe to import at plugin load time because those
modules guard their own registry lookups with local imports.

- big_number: add map_big_number_config, _big_number_chart_what,
  _summarize_filters, DatasetValidator to top-level imports
- pie: add map_pie_config, _pie_chart_what, _summarize_filters, PieChartConfig,
  DatasetValidator to top-level imports
- xy: add map_xy_config, _xy_chart_what/context, XYChartConfig, DatasetValidator,
  FormatTypeValidator, CardinalityValidator to top-level imports
- table: add map_table_config, _table_chart_what, _summarize_filters,
  TableChartConfig, DatasetValidator to top-level imports
- pivot_table: add map_pivot_table_config, _pivot_table_what, _summarize_filters,
  PivotTableChartConfig, DatasetValidator to top-level imports
- mixed_timeseries: add map_mixed_timeseries_config, _mixed_timeseries_what,
  _summarize_filters, MixedTimeseriesChartConfig, DatasetValidator to top-level
- handlebars: add map_handlebars_config, _handlebars_chart_what, _summarize_filters,
  HandlebarsChartConfig, DatasetValidator to top-level imports

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
6f43a56935 fix(mcp): address reviewer comments — local import rationale, x-optional corrections, cardinality suggestions
- Remove redundant local imports from BigNumberChartPlugin.post_map_validate()
  now that BigNumberChartConfig and is_column_truly_temporal are at top level
- Add explanatory comments on the two remaining local get_registry imports in
  chart_utils.py and dataset_validator.py (circular import prevention)
- Fix schema_validator.py and generate_chart.py docstring: XY 'x' field is
  optional (defaults to dataset primary datetime column), not required
- Propagate cardinality suggestions alongside warnings in XYChartPlugin
- Clarify app.py instructions: chart_type_display_name is null for viz_types
  outside the 7 generate_chart-supported types

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
487f8afc72 refactor(mcp): complete plugin protocol — registry bootstrap, mypy fixes, test repairs
On top of the dead-code elimination in the previous commit:
- Add lazy _ensure_plugins_loaded() bootstrap to ChartTypeRegistry so the
  registry is populated even without importing app.py (fixes isolated test runs)
- Delegate _RegistryProxy methods to module-level functions so bootstrap runs
- Guard register() against empty chart_type strings
- Add generate_name + resolve_viz_type to ChartTypePlugin Protocol and
  BaseChartPlugin; delegate generate_chart_name/_resolve_viz_type in
  chart_utils to the plugin registry
- Add _with_context static helper to BaseChartPlugin (shared by all plugins)
- Fix stale 'five methods' → 'eight methods' docstring in plugin.py
- Add TypeVar _C to normalize_column_names so mypy infers correct return type
- Fix broken tests: update _pre_validate_big_number_config → _pre_validate_chart_type,
  remove deleted TestNormalizeXYConfig/TestNormalizeTableConfig classes,
  update runtime validator tests for removed _validate_format_compatibility /
  _validate_cardinality methods, add x is not None narrowing guards

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
65afe7d577 refactor(mcp): eliminate dead code and complete plugin registry dispatch
H1: Delete 7 dead _pre_validate_* static methods from SchemaValidator
    — exact duplicates of plugin pre_validate() methods, never called
    after _pre_validate_chart_type() was updated to delegate to plugin.

H2: Inline DatasetValidator._normalize_xy_config/_normalize_table_config
    into XYChartPlugin/TableChartPlugin.normalize_column_refs() and delete
    both DatasetValidator helper methods. The 5 other plugins already
    called _get_canonical_column_name directly; XY and Table now match.

H3: Add generate_name()/resolve_viz_type() to ChartTypePlugin protocol
    and BaseChartPlugin, implement in all 7 plugins. Replace the 7-arm
    isinstance chain in generate_chart_name() and the 7-arm elif chain
    in _resolve_viz_type() with single-line registry dispatch.

H4: Add a sync comment above _CHART_TYPE_ERROR_HINTS to document that
    it must stay in sync with the plugin registry.

M4: Move logger=logging.getLogger(__name__) from inside
    XYChartPlugin.get_runtime_warnings() to module scope.
2026-06-10 23:04:09 +00:00
Amin Ghadersohi
0110b523a0 feat(mcp): add display_name and native_viz_types to chart type plugins
Each ChartTypePlugin now declares:
- display_name: human-readable label for the chart_type discriminator
  (e.g. "Line / Bar / Area / Scatter Chart", "Pivot Table")
- native_viz_types: dict mapping every Superset-internal viz_type the
  plugin produces to a user-friendly name
  (e.g. {"echarts_timeseries_line": "Line Chart", "echarts_area": "Area Chart"})

The registry gains display_name_for_viz_type(viz_type) which searches
all plugins' native_viz_types maps, replacing the need for a separate
viz_type_display_names.json or viz_type_names.py module.

ChartInfo gains a chart_type_display_name field populated via the registry,
so list_charts / get_chart_info return human-readable chart type names.
The MCP system instructions now reference display names rather than
internal viz_type identifiers.
2026-06-10 23:04:09 +00:00
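As an illustration of the lookup this commit adds, the sketch below searches each plugin's `native_viz_types` map in registration order. The plugin records are plain dicts invented for the example; the `pivot_table_v2` entry in particular is an assumed mapping, and the real registry holds plugin objects rather than dicts.

```python
from typing import Optional

# Hypothetical plugin records, listed in registration order.
PLUGINS = [
    {
        "chart_type": "xy",
        "display_name": "Line / Bar / Area / Scatter Chart",
        "native_viz_types": {
            "echarts_timeseries_line": "Line Chart",
            "echarts_area": "Area Chart",
        },
    },
    {
        "chart_type": "pivot_table",
        "display_name": "Pivot Table",
        "native_viz_types": {"pivot_table_v2": "Pivot Table"},  # assumed entry
    },
]


def display_name_for_viz_type(viz_type: str) -> Optional[str]:
    """Return the first matching user-friendly name, or None if unclaimed."""
    for plugin in PLUGINS:
        name = plugin["native_viz_types"].get(viz_type)
        if name is not None:
            return name
    return None


print(display_name_for_viz_type("echarts_area"))
print(display_name_for_viz_type("funnel"))
```

A viz_type that no plugin claims yields `None`, which corresponds to a null `chart_type_display_name` for charts outside the supported types.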
Amin Ghadersohi
689c0fb5b7 feat(mcp): introduce chart type plugin registry for extensible chart generation
Replaces four scattered dispatch locations (schema_validator, dataset_validator,
chart_utils, runtime validator) with a central ChartTypePlugin registry. Each of
the 7 supported chart types (xy, table, pie, pivot_table, mixed_timeseries,
handlebars, big_number) now owns its pre-validation, column extraction, form_data
mapping, post-map validation, column normalization, and runtime warnings in a
single plugin class.

Key changes:
- Add ChartTypePlugin protocol and BaseChartPlugin base class (plugin.py)
- Add ChartTypeRegistry with register/get/all_types helpers (registry.py)
- Add 7 chart type plugins under chart/plugins/ with full coverage
- Fix 5-type column validation gap: pie, pivot_table, mixed_timeseries, handlebars,
  and big_number now participate in dataset column validation (previously silently skipped)
- Move BigNumber trendline temporal check to BigNumberChartPlugin.post_map_validate()
- Add get_runtime_warnings() to plugin protocol; XYChartPlugin implements
  format/cardinality checks, removing isinstance(config, XYChartConfig) from RuntimeValidator
- Fix stale generate_chart.py docstring listing only 'xy' and 'table' chart types
- Add missing pie, pivot_table, mixed_timeseries handlers to _enhance_validation_error;
  refactor into a data-driven lookup table to stay within complexity limits
- Fix empty details fallback in Pydantic error handler
2026-06-10 23:04:09 +00:00
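To make the shape of the change concrete, here is a deliberately small sketch of protocol-plus-registry dispatch. It is not the PR's API: the `ChartPlugin` protocol below has two methods instead of the full set, configs are plain dicts instead of Pydantic models, and `PiePlugin`, `register`, and `map_config` are hypothetical names.

```python
from typing import Any, Optional, Protocol


class ChartPlugin(Protocol):
    """Reduced protocol: the real one declares more hooks than these two."""

    chart_type: str

    def pre_validate(self, config: dict) -> Optional[str]: ...

    def to_form_data(self, config: dict) -> dict: ...


class PiePlugin:
    """Illustrative plugin that owns its own validation and mapping."""

    chart_type = "pie"

    def pre_validate(self, config: dict) -> Optional[str]:
        return None if config.get("metric") else "pie requires 'metric'"

    def to_form_data(self, config: dict) -> dict:
        return {"viz_type": "pie", "metric": config["metric"]}


_PLUGINS: dict = {}


def register(plugin: ChartPlugin) -> None:
    """Add a plugin under its chart_type, rejecting empty identifiers."""
    if not plugin.chart_type:
        raise ValueError("chart_type must be a non-empty string")
    _PLUGINS[plugin.chart_type] = plugin


def map_config(config: dict) -> Any:
    """Single registry lookup in place of a per-type if/elif chain."""
    plugin = _PLUGINS.get(config.get("chart_type"))
    if plugin is None:
        raise ValueError(f"Unsupported chart_type: {config.get('chart_type')!r}")
    error = plugin.pre_validate(config)
    if error is not None:
        raise ValueError(error)
    return plugin.to_form_data(config)


register(PiePlugin())
print(map_config({"chart_type": "pie", "metric": "revenue"}))
```

Adding a chart type under this layout means writing one class and registering it; the dispatcher itself does not change.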
31 changed files with 3205 additions and 1037 deletions

View File

@@ -103,19 +103,6 @@ class DatasourceTypeUpdateRequiredValidationError(ValidationError):
         )
 
 
-class ChartQueryContextDatasourceMismatchValidationError(ValidationError):
-    """
-    Raised when a query-context-only update carries a datasource that does not
-    match the chart's own datasource.
-    """
-
-    def __init__(self) -> None:
-        super().__init__(
-            _("The query context datasource does not match the chart datasource"),
-            field_name="query_context",
-        )
-
-
 class ChartNotFoundError(CommandException):
     message = "Chart not found."

View File

@@ -29,7 +29,6 @@ from superset.commands.chart.exceptions import (
     ChartForbiddenError,
     ChartInvalidError,
     ChartNotFoundError,
-    ChartQueryContextDatasourceMismatchValidationError,
     ChartUpdateFailedError,
     DashboardsForbiddenError,
     DashboardsNotFoundValidationError,
@@ -42,7 +41,6 @@ from superset.exceptions import SupersetSecurityException
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.tags.models import ObjectType
-from superset.utils import json
 from superset.utils.decorators import on_error, transaction
 
 logger = logging.getLogger(__name__)
@@ -103,51 +101,6 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
if not security_manager.is_owner(dash):
raise DashboardsForbiddenError()
-    def _validate_query_context_datasource(
-        self, exceptions: list[ValidationError]
-    ) -> None:
-        """
-        Ensure a query-context-only update keeps the chart's own datasource.
-
-        The submitted query context is only verified when it carries a parseable
-        ``datasource`` object; a payload that references a different datasource than
-        the chart's persisted one is rejected. Payloads without a datasource fall
-        back to the chart's datasource at execution time and need no check.
-        """
-        if not self._model:
-            return
-
-        raw_query_context = self._properties.get("query_context")
-        if not raw_query_context:
-            return
-
-        try:
-            query_context = json.loads(raw_query_context)
-        except (TypeError, ValueError):
-            # An unparseable payload cannot be verified or replayed; leave it for
-            # downstream handling rather than guessing at its intent.
-            return
-
-        datasource = (
-            query_context.get("datasource") if isinstance(query_context, dict) else None
-        )
-        if not isinstance(datasource, dict):
-            return
-
-        try:
-            ids_match = int(datasource["id"]) == self._model.datasource_id
-        except (KeyError, TypeError, ValueError):
-            ids_match = False
-
-        datasource_type = datasource.get("type")
-        types_match = (
-            datasource_type is None
-            or str(datasource_type) == self._model.datasource_type
-        )
-
-        if not ids_match or not types_match:
-            exceptions.append(ChartQueryContextDatasourceMismatchValidationError())
-
     def validate(self) -> None:  # noqa: C901
         exceptions: list[ValidationError] = []
         dashboard_ids = self._properties.get("dashboards")
@@ -181,12 +134,6 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
                 raise ChartForbiddenError() from ex
             except ValidationError as ex:
                 exceptions.append(ex)
-        else:
-            # The query-context-only path skips the ownership check so report and
-            # alert workers can refresh a chart's cached payload. Keep that payload
-            # bound to the chart's own datasource so it cannot be repointed at an
-            # unrelated one.
-            self._validate_query_context_datasource(exceptions)
 
         # validate tags
         try:

View File

@@ -807,6 +807,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
         self.configure_feature_flags()
         self.check_guest_token_secret()
         self.check_async_query_secret()
+        self.configure_mcp_chart_registry()
         self.configure_db_encrypt()
         self.setup_db()

@@ -888,6 +889,37 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
     def configure_feature_flags(self) -> None:
         feature_flag_manager.init_app(self.superset_app)

+    def configure_mcp_chart_registry(self) -> None:
+        """Configure the MCP chart plugin registry with operator overrides.
+
+        Called from ``post_init()`` during ``create_app()``. In normal
+        (non-MCP-standalone) Superset deployments this call is the only one
+        and picks up ``MCP_DISABLED_CHART_PLUGINS`` / ``MCP_CHART_PLUGIN_ENABLED_FUNC``
+        from the fully resolved config.
+
+        In the MCP-standalone deployment (``flask_singleton.py``), ``create_app()``
+        calls this method first — before the MCP-specific config overlay is applied —
+        and then ``flask_singleton.configure()`` calls ``registry.configure()`` a second
+        time with the correct post-overlay values. Any registry lookup that occurs
+        between these two calls (during ``initialize_core_mcp_dependencies()``) sees
+        the pre-overlay config. In practice no lookups occur at that point because
+        tools are invoked only after startup completes, so the window is benign.
+        """
+        from superset.mcp_service.chart import registry
+        from superset.mcp_service.mcp_config import (
+            MCP_CHART_PLUGIN_ENABLED_FUNC,
+            MCP_DISABLED_CHART_PLUGINS,
+        )
+
+        registry.configure(
+            disabled=self.config.get(
+                "MCP_DISABLED_CHART_PLUGINS", MCP_DISABLED_CHART_PLUGINS
+            ),
+            enabled_func=self.config.get(
+                "MCP_CHART_PLUGIN_ENABLED_FUNC", MCP_CHART_PLUGIN_ENABLED_FUNC
+            ),
+        )
+
     def configure_sqlglot_dialects(self) -> None:
         extensions = self.config["SQLGLOT_DIALECTS_EXTENSIONS"]

View File

@@ -342,10 +342,12 @@ Time grain for temporal x-axis (time_grain parameter):
 - PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly)

 Chart Types in Existing Charts (viewable via list_charts/get_chart_info):
-- pie, big_number, big_number_total, funnel, gauge_chart
-- echarts_timeseries_line, echarts_timeseries_bar, echarts_timeseries_area
-- pivot_table_v2, heatmap_v2, sankey_v2, sunburst_v2, treemap_v2
-- word_cloud, world_map, box_plot, bubble, mixed_timeseries
+Each chart returned by list_charts / get_chart_info includes a
+chart_type_display_name field with a human-readable name when available.
+This field is populated only for the 7 chart types supported by generate_chart
+(xy, pie, table, pivot_table, big_number, mixed_timeseries, handlebars).
+For all other viz_types (Funnel, Gauge, Heatmap, etc.) it will be null —
+use the raw viz_type field instead when referring to those chart types.

 Query Examples:
 - List all tables:
@@ -656,6 +658,7 @@ warnings.filterwarnings(
 # NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
 # Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
 # They register automatically on import, similar to tools.
+import superset.mcp_service.chart.plugins  # noqa: F401, E402 — registers all chart type plugins
 from superset.mcp_service.annotation_layer.tool import (  # noqa: F401, E402
     get_annotation_layer_info,
     get_layer_annotation_info,

View File

@@ -321,29 +321,44 @@ def map_config_to_form_data(
     | BigNumberChartConfig,
     dataset_id: int | str | None = None,
 ) -> Dict[str, Any]:
-    """Map chart config to Superset form_data."""
-    if isinstance(config, TableChartConfig):
-        return map_table_config(config)
-    elif isinstance(config, XYChartConfig):
-        return map_xy_config(config, dataset_id=dataset_id)
-    elif isinstance(config, PieChartConfig):
-        return map_pie_config(config)
-    elif isinstance(config, PivotTableChartConfig):
-        return map_pivot_table_config(config)
-    elif isinstance(config, MixedTimeseriesChartConfig):
-        return map_mixed_timeseries_config(config, dataset_id=dataset_id)
-    elif isinstance(config, HandlebarsChartConfig):
-        return map_handlebars_config(config)
-    elif isinstance(config, BigNumberChartConfig):
-        if config.show_trendline and config.temporal_column:
-            if not is_column_truly_temporal(config.temporal_column, dataset_id):
-                raise ValueError(
-                    f"Big Number trendline requires a temporal SQL column; "
-                    f"'{config.temporal_column}' is not temporal."
-                )
-        return map_big_number_config(config)
-    else:
-        raise ValueError(f"Unsupported config type: {type(config)}")
+    """Map chart config to Superset form_data via the plugin registry.
+
+    The previous if/elif chain across all 7 chart types has been replaced by a
+    single registry lookup. Cross-field constraints (e.g. BigNumber trendline
+    temporal check) are now owned by each plugin's post_map_validate() method
+    rather than being baked into this dispatcher.
+    """
+    # Local import: plugins call map_*_config from their to_form_data() methods,
+    # so chart_utils is loaded before plugins finish registering. A top-level
+    # import of registry here would trigger plugin loading mid-import = cycle.
+    from superset.mcp_service.chart.registry import get_registry
+
+    chart_type = getattr(config, "chart_type", None)
+    plugin = get_registry().get(chart_type) if chart_type else None
+    if plugin is None:
+        if chart_type is None:
+            raise ValueError(f"Unsupported config type: {type(config)}")
+        raise ValueError(
+            f"Unsupported config type: {type(config)} (chart_type={chart_type!r})"
+        )
+
+    form_data = plugin.to_form_data(config, dataset_id=dataset_id)
+
+    # Run post-map validation (e.g. BigNumber trendline temporal type check).
+    # Raise ValueError to preserve backward-compatible error handling in callers.
+    # Include details and suggestions so callers logging str(e) surface actionable
+    # context (e.g. BigNumber trendline guidance) rather than just the headline.
+    error = plugin.post_map_validate(config, form_data, dataset_id=dataset_id)
+    if error is not None:
+        parts = [error.message]
+        if error.details:
+            parts.append(error.details)
+        if error.suggestions:
+            parts.append("Suggestions: " + "; ".join(error.suggestions))
+        raise ValueError(" ".join(parts))
return form_data
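The dispatch flow above (registry lookup, mapping, then post-map validation flattened into a single ValueError message) can be sketched in isolation. This is a minimal approximation under stated assumptions: `FakeError`, `DemoPlugin`, `_REGISTRY`, and `dispatch` are hypothetical stand-ins invented for illustration and are not the real Superset classes or registry.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class FakeError:
    """Hypothetical stand-in for ChartGenerationError (three fields only)."""

    message: str
    details: str = ""
    suggestions: list[str] = field(default_factory=list)


class DemoPlugin:
    """Hypothetical plugin: maps a dict config and rejects negative limits."""

    chart_type = "demo"

    def to_form_data(self, config: dict[str, Any]) -> dict[str, Any]:
        return {"viz_type": "demo", "row_limit": config.get("limit", 10)}

    def post_map_validate(
        self, config: dict[str, Any], form_data: dict[str, Any]
    ) -> FakeError | None:
        if form_data["row_limit"] < 0:
            return FakeError(
                message="Row limit must be non-negative.",
                details="Got a negative 'limit'.",
                suggestions=["Use a positive integer", "Omit 'limit'"],
            )
        return None


# Hypothetical registry: a plain dict keyed by chart_type.
_REGISTRY: dict[str, DemoPlugin] = {"demo": DemoPlugin()}


def dispatch(config: dict[str, Any]) -> dict[str, Any]:
    """Look up the plugin, map the config, then run post-map validation."""
    chart_type = config.get("chart_type")
    plugin = _REGISTRY.get(chart_type) if chart_type else None
    if plugin is None:
        raise ValueError(f"Unsupported config (chart_type={chart_type!r})")
    form_data = plugin.to_form_data(config)
    error = plugin.post_map_validate(config, form_data)
    if error is not None:
        # Join headline, details, and suggestions into one message so that
        # callers logging str(e) see the full context.
        parts = [error.message]
        if error.details:
            parts.append(error.details)
        if error.suggestions:
            parts.append("Suggestions: " + "; ".join(error.suggestions))
        raise ValueError(" ".join(parts))
    return form_data
```

Calling `dispatch({"chart_type": "demo", "limit": -1})` in this sketch raises a ValueError whose text carries the headline, the details, and both suggestions in one string, which is the behaviour the comments above describe for callers that only log `str(e)`.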
def _add_adhoc_filters(
@@ -1244,87 +1259,32 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str:
def generate_chart_name(
config: TableChartConfig
| XYChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig
| HandlebarsChartConfig
| BigNumberChartConfig,
config: Any,
dataset_name: str | None = None,
) -> str:
"""Generate a descriptive chart name following a standard format.
Format conventions (by chart type):
Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
Time-series (line/area, no group_by): [Metric] Over Time
Table (no aggregates): [Dataset] Records
Table (with aggregates): [Metric] Summary
Pie: [Dimension] by [Metric]
Pivot Table: Pivot Table [Row1, Row2]
Mixed Timeseries: [Primary] + [Secondary]
An en-dash followed by context (filters / time grain) is appended
Delegates to each plugin's ``generate_name()`` method.
See each plugin's ``generate_name`` for chart-type-specific format conventions.
An en-dash followed by context (filters / time grain) is appended by the plugin
when such information is available.
"""
if isinstance(config, TableChartConfig):
what = _table_chart_what(config, dataset_name)
context = _summarize_filters(config.filters)
elif isinstance(config, XYChartConfig):
what = _xy_chart_what(config)
context = _xy_chart_context(config)
elif isinstance(config, PieChartConfig):
what = _pie_chart_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, PivotTableChartConfig):
what = _pivot_table_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, MixedTimeseriesChartConfig):
what = _mixed_timeseries_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, HandlebarsChartConfig):
what = _handlebars_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
elif isinstance(config, BigNumberChartConfig):
what = _big_number_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
else:
return "Chart"
from superset.mcp_service.chart.registry import get_registry
name = what
if context:
name = f"{what} \u2013 {context}"
return _truncate(name)
plugin = get_registry().get(getattr(config, "chart_type", ""))
if plugin is None:
return "Chart"
return _truncate(plugin.generate_name(config, dataset_name))
def _resolve_viz_type(config: Any) -> str:
"""Resolve the Superset viz_type from a chart config object."""
chart_type = getattr(config, "chart_type", "unknown")
if chart_type == "xy":
kind = getattr(config, "kind", "line")
viz_type_map = {
"line": "echarts_timeseries_line",
"bar": "echarts_timeseries_bar",
"area": "echarts_area",
"scatter": "echarts_timeseries_scatter",
}
return viz_type_map.get(kind, "echarts_timeseries_line")
elif chart_type == "table":
return getattr(config, "viz_type", "table")
elif chart_type == "pie":
return "pie"
elif chart_type == "pivot_table":
return "pivot_table_v2"
elif chart_type == "mixed_timeseries":
return "mixed_timeseries"
elif chart_type == "handlebars":
return "handlebars"
elif chart_type == "big_number":
show_trendline = getattr(config, "show_trendline", False)
temporal_column = getattr(config, "temporal_column", None)
return (
"big_number" if show_trendline and temporal_column else "big_number_total"
)
return "unknown"
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get(getattr(config, "chart_type", ""))
if plugin is None:
return "unknown"
return plugin.resolve_viz_type(config)
TABLE_VIZ_TYPE_LABELS = {


@@ -0,0 +1,263 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ChartTypePlugin protocol and BaseChartPlugin base class.
Each chart type owns its pre-validation, column extraction, form_data mapping,
and post-map validation in a single plugin class. This eliminates the previous
pattern of 4 separate dispatch points (schema_validator.py, dataset_validator.py,
chart_utils.py, pipeline.py) that had to be updated in sync whenever a new chart
type was added.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
from superset.mcp_service.chart.schemas import ColumnRef
from superset.mcp_service.common.error_schemas import ChartGenerationError
@runtime_checkable
class ChartTypePlugin(Protocol):
"""
Protocol that every chart-type plugin must satisfy.
Implementing all nine methods in a single class means that adding a new
chart type requires one new plugin file (plus its registration in the
plugins package) rather than edits across multiple separate files.
"""
#: Discriminator value matching ChartConfig's chart_type field.
chart_type: str
#: Human-readable name shown to users (e.g. "Line / Bar / Area / Scatter").
display_name: str
#: Maps every Superset-internal viz_type this plugin can produce to a
#: user-facing display name, e.g. {"echarts_timeseries_line": "Line Chart"}.
#: Used by the registry to resolve display names for existing charts without
#: needing a separate JSON mapping file.
native_viz_types: dict[str, str]
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
"""
Early validation of the raw config dict before Pydantic parsing.
Called by SchemaValidator before attempting to parse the request.
Should check that required top-level keys are present and well-typed.
Returns None if valid, ChartGenerationError if invalid.
"""
...
def extract_column_refs(
self,
config: Any,
) -> list[ColumnRef]:
"""
Extract all column references from a parsed chart config.
Called by DatasetValidator to validate that all referenced columns exist
in the dataset. Must cover every field that holds a column name,
including filters.
Returns a list of ColumnRef objects (may be empty).
"""
...
def to_form_data(
self,
config: Any,
dataset_id: int | str | None = None,
) -> dict[str, Any]:
"""
Map a parsed chart config to Superset's internal form_data dict.
Replaces the if/elif chain in chart_utils.map_config_to_form_data().
Returns a Superset form_data dict ready for caching and rendering.
"""
...
def post_map_validate(
self,
config: Any,
form_data: dict[str, Any],
dataset_id: int | str | None = None,
) -> ChartGenerationError | None:
"""
Validate the mapped form_data after to_form_data() runs.
Use this for cross-field constraints that can only be checked once
form_data is assembled (e.g. BigNumber trendline requires a temporal
column whose type must be verified against the dataset).
Returns None if valid, ChartGenerationError if invalid.
"""
...
def normalize_column_refs(
self,
config: Any,
dataset_context: Any,
) -> Any:
"""
Return a new config with column names normalized to canonical dataset casing.
Called by DatasetValidator.normalize_column_names(). The default
implementation (in BaseChartPlugin) returns the config unchanged; plugins
with column fields override this to fix case sensitivity mismatches.
Returns a new config object (or the original if no normalization needed).
"""
...
def get_runtime_warnings(
self,
config: Any,
dataset_id: int | str,
) -> list[str]:
"""
Return chart-type-specific runtime warnings (performance, compatibility).
Called by RuntimeValidator to collect per-type warnings. Warnings are
informational only — they never block chart generation. The default
implementation returns an empty list; plugins override this to emit
chart-type-specific warnings (e.g. XY cardinality checks).
Returns a list of warning message strings (may be empty).
"""
...
def generate_name(
self,
config: Any,
dataset_name: str | None = None,
) -> str:
"""
Return a descriptive chart name for the given config.
Called by chart_utils.generate_chart_name(). Each plugin defines its own
chart-type-specific format conventions. Plugins that do not override
this return the generic fallback "Chart".
"""
...
def resolve_viz_type(self, config: Any) -> str:
"""
Return the Superset-internal viz_type string for this config.
Called by chart_utils._resolve_viz_type(). The returned string must
match a registered Superset viz plugin (e.g. "echarts_timeseries_line").
Plugins that do not override this return "unknown".
"""
...
def schema_error_hint(self) -> ChartGenerationError | None:
"""
Return a user-friendly error for Pydantic discriminated-union parse failures.
Called by SchemaValidator when Pydantic cannot parse the config union and
the chart_type is known. Returning None falls back to the generic error.
"""
...
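One consequence of declaring the protocol with `@runtime_checkable` is worth illustrating: an `isinstance()` check only confirms that the members exist, not that their signatures match. The sketch below uses a hypothetical two-member reduction of the protocol (`MiniPlugin`) and three made-up classes; none of these names come from the PR.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class MiniPlugin(Protocol):
    """Hypothetical two-member reduction of ChartTypePlugin."""

    chart_type: str

    def resolve_viz_type(self, config: Any) -> str: ...


class Good:
    chart_type = "pie"

    def resolve_viz_type(self, config: Any) -> str:
        return "pie"


class WrongSignature:
    chart_type = "pie"

    # Takes no config argument, yet still passes the isinstance() check,
    # because runtime protocol checks only look for member presence.
    def resolve_viz_type(self) -> str:
        return "pie"


class MissingMethod:
    chart_type = "pie"


print(isinstance(Good(), MiniPlugin))            # True
print(isinstance(WrongSignature(), MiniPlugin))  # True
print(isinstance(MissingMethod(), MiniPlugin))   # False
```

Signature mismatches like `WrongSignature` are only caught by a static type checker or by actually calling the method, so a runtime check of this kind is best treated as a coarse sanity check.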
class BaseChartPlugin:
"""
Base class providing sensible defaults for all ChartTypePlugin methods.
Concrete plugins extend this and override only what they need. Default
implementations: ``pre_validate`` → None (valid), ``extract_column_refs`` → [],
``post_map_validate`` → None, ``normalize_column_refs`` → config unchanged,
``get_runtime_warnings`` → [], ``generate_name`` → "Chart",
``resolve_viz_type`` → "unknown", ``schema_error_hint`` → None.
``to_form_data`` raises ``NotImplementedError`` and must be overridden.
"""
chart_type: str = ""
display_name: str = ""
# Class-level dict shared across all subclasses that don't override it.
# Subclasses MUST override this as a class attribute (not mutate in place)
# to avoid corrupting the shared empty-dict default for other plugins.
native_viz_types: dict[str, str] = {}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
return None
def extract_column_refs(
self,
config: Any,
) -> list[ColumnRef]:
return []
def to_form_data(
self,
config: Any,
dataset_id: int | str | None = None,
) -> dict[str, Any]:
raise NotImplementedError(
f"{self.__class__.__name__}.to_form_data() is not implemented"
)
def post_map_validate(
self,
config: Any,
form_data: dict[str, Any],
dataset_id: int | str | None = None,
) -> ChartGenerationError | None:
return None
def normalize_column_refs(
self,
config: Any,
dataset_context: Any,
) -> Any:
return config
def get_runtime_warnings(
self,
config: Any,
dataset_id: int | str,
) -> list[str]:
return []
def generate_name(
self,
config: Any,
dataset_name: str | None = None,
) -> str:
return "Chart"
def resolve_viz_type(self, config: Any) -> str:
return "unknown"
def schema_error_hint(self) -> ChartGenerationError | None:
return None
@staticmethod
def _with_context(what: str, context: str | None) -> str:
"""Combine a 'what' label and optional context with an en-dash."""
return f"{what} \u2013 {context}" if context else what
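The warning on `native_viz_types` (override as a class attribute, never mutate in place) is easier to see with a small demonstration. The classes below are hypothetical stand-ins, not the real `BaseChartPlugin`; they only reproduce the shared class-level dict.

```python
from __future__ import annotations


class Base:
    """Hypothetical stand-in for BaseChartPlugin."""

    native_viz_types: dict[str, str] = {}


class Overrides(Base):
    # Safe: a fresh dict is bound on the subclass itself.
    native_viz_types = {"pie": "Pie Chart"}


class Mutates(Base):
    """Does not override the attribute, so it still points at Base's dict."""


# Unsafe: this writes into the single dict object owned by Base.
Mutates.native_viz_types["funnel"] = "Funnel"

print(Base.native_viz_types)       # {'funnel': 'Funnel'}
print(Overrides.native_viz_types)  # {'pie': 'Pie Chart'}
```

After the in-place mutation, every plugin that relies on the inherited default would also report the `funnel` entry, which is the corruption the comment warns about.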


@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Chart type plugins package.
Importing this module registers all built-in chart type plugins in the global
registry. This module is imported by app.py at startup.
To add a new chart type:
1. Create ``superset/mcp_service/chart/plugins/{chart_type}.py``
2. Implement a class extending ``BaseChartPlugin``
3. Import and register it here
"""
from superset.mcp_service.chart.plugins.big_number import BigNumberChartPlugin
from superset.mcp_service.chart.plugins.handlebars import HandlebarsChartPlugin
from superset.mcp_service.chart.plugins.mixed_timeseries import (
MixedTimeseriesChartPlugin,
)
from superset.mcp_service.chart.plugins.pie import PieChartPlugin
from superset.mcp_service.chart.plugins.pivot_table import PivotTableChartPlugin
from superset.mcp_service.chart.plugins.table import TableChartPlugin
from superset.mcp_service.chart.plugins.xy import XYChartPlugin
from superset.mcp_service.chart.registry import register
# Register all built-in chart type plugins
register(XYChartPlugin())
register(TableChartPlugin())
register(PieChartPlugin())
register(PivotTableChartPlugin())
register(MixedTimeseriesChartPlugin())
register(HandlebarsChartPlugin())
register(BigNumberChartPlugin())
__all__ = [
"BigNumberChartPlugin",
"HandlebarsChartPlugin",
"MixedTimeseriesChartPlugin",
"PieChartPlugin",
"PivotTableChartPlugin",
"TableChartPlugin",
"XYChartPlugin",
]


@@ -0,0 +1,247 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Big number chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_big_number_chart_what,
_summarize_filters,
is_column_truly_temporal,
map_big_number_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import BigNumberChartConfig, ColumnRef
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class BigNumberChartPlugin(BaseChartPlugin):
"""Plugin for big_number chart type."""
chart_type = "big_number"
display_name = "Big Number"
native_viz_types = {
"big_number": "Big Number with Trendline",
"big_number_total": "Big Number",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
if "metric" not in config:
return ChartGenerationError(
error_type="missing_metric",
message="Big Number chart missing required field: metric",
details=(
"Big Number charts require a 'metric' field "
"specifying the value to display"
),
suggestions=[
"Add 'metric' with name and aggregate: "
"{'name': 'revenue', 'aggregate': 'SUM'}",
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
"Example: {'chart_type': 'big_number', "
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
],
error_code="MISSING_BIG_NUMBER_METRIC",
)
metric = config.get("metric", {})
if not isinstance(metric, dict):
return ChartGenerationError(
error_type="invalid_metric_type",
message="Big Number metric must be a dict with 'name' and 'aggregate'",
details=(
f"The 'metric' field must be an object, got {type(metric).__name__}"
),
suggestions=[
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
)
if metric.get("sql_expression"):
label = metric.get("label")
if not isinstance(label, str) or not label.strip():
return ChartGenerationError(
error_type="missing_sql_metric_label",
message="SQL expression metrics require a non-empty 'label'",
details=(
"When using a custom SQL expression as the Big Number metric, "
"a human-readable 'label' string is required so Superset can "
"display the metric name."
),
suggestions=[
"Add 'label': e.g. {'sql_expression': 'SUM(a)/SUM(b)', "
"'label': 'Conversion Rate'}",
"The label must be a non-empty string",
],
error_code="MISSING_SQL_METRIC_LABEL",
)
elif not metric.get("aggregate") and not metric.get("saved_metric"):
return ChartGenerationError(
error_type="missing_metric_aggregate",
message=(
"Big Number metric must include an aggregate function "
"or reference a saved metric"
),
details=(
"The metric must have an 'aggregate' field or 'saved_metric': true"
),
suggestions=[
"Add 'aggregate': {'name': 'col', 'aggregate': 'SUM'}",
"Or use a saved metric: {'name': 'metric', 'saved_metric': true}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="MISSING_BIG_NUMBER_AGGREGATE",
)
show_trendline = config.get("show_trendline", False)
temporal_column = config.get("temporal_column")
if show_trendline and not temporal_column:
return ChartGenerationError(
error_type="missing_temporal_column",
message="Trendline requires a temporal column",
details=(
"When 'show_trendline' is True, "
"a 'temporal_column' must be specified"
),
suggestions=[
"Add 'temporal_column': 'date_column_name'",
"Or set 'show_trendline': false for number only",
"Use get_dataset_info to find temporal columns",
],
error_code="MISSING_TEMPORAL_COLUMN",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, BigNumberChartConfig):
return []
refs: list[ColumnRef] = [config.metric]
# temporal_column is a str field, not a ColumnRef — validate it exists
if config.temporal_column:
refs.append(ColumnRef(name=config.temporal_column))
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_big_number_config(config)
def post_map_validate(
self,
config: Any,
form_data: dict[str, Any],
dataset_id: int | str | None = None,
) -> ChartGenerationError | None:
"""Verify the trendline temporal column is a real temporal SQL type.
This check was previously baked into map_config_to_form_data() in
chart_utils.py as a special case. Moving it here keeps the dispatcher
clean and makes the constraint explicit and discoverable.
"""
if not isinstance(config, BigNumberChartConfig):
return None
if not (config.show_trendline and config.temporal_column):
return None
if not is_column_truly_temporal(config.temporal_column, dataset_id):
return ChartGenerationError(
error_type="non_temporal_trendline_column",
message=(
f"Big Number trendline requires a temporal SQL column; "
f"'{config.temporal_column}' is not temporal."
),
details=(
f"Column '{config.temporal_column}' does not have a temporal "
f"SQL type (DATE, DATETIME, TIMESTAMP). The trendline requires "
f"a true temporal column for DATE_TRUNC to work."
),
suggestions=[
"Use get_dataset_info to find columns with temporal SQL types",
"Set 'show_trendline': false to use any column as the metric",
"If the column contains dates stored as integers, "
"consider casting it in a virtual dataset",
],
error_code="NON_TEMPORAL_TRENDLINE_COLUMN",
)
return None
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _big_number_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
show_trendline = getattr(config, "show_trendline", False)
temporal_column = getattr(config, "temporal_column", None)
if show_trendline and temporal_column:
return "big_number"
return "big_number_total"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
if config_dict.get("metric"):
if config_dict["metric"].get("saved_metric"):
config_dict["metric"]["name"] = (
DatasetValidator._get_canonical_metric_name(
config_dict["metric"]["name"], dataset_context
)
)
else:
config_dict["metric"]["name"] = (
DatasetValidator._get_canonical_column_name(
config_dict["metric"]["name"], dataset_context
)
)
if config_dict.get("temporal_column"):
config_dict["temporal_column"] = (
DatasetValidator._get_canonical_column_name(
config_dict["temporal_column"], dataset_context
)
)
DatasetValidator._normalize_filters(config_dict, dataset_context)
return BigNumberChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="big_number_validation_error",
message="Big Number chart configuration validation failed",
details=(
"The Big Number chart configuration is missing required "
"fields or has invalid structure"
),
suggestions=[
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: 'metric': {'name': 'revenue', 'aggregate': 'SUM'}",
"For trendline: add show_trendline=true and temporal_column='col'",
"Without trendline: just provide the metric",
],
error_code="BIG_NUMBER_VALIDATION_ERROR",
)
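The viz_type rule used by the Big Number plugin is small enough to restate as a standalone function. The sketch below applies the same condition to a plain dict instead of a parsed config; `resolve_big_number_viz_type` and `NATIVE_VIZ_TYPES` are illustrative names, with the mapping values copied from the plugin above.

```python
from typing import Any

# Display names copied from BigNumberChartPlugin.native_viz_types.
NATIVE_VIZ_TYPES = {
    "big_number": "Big Number with Trendline",
    "big_number_total": "Big Number",
}


def resolve_big_number_viz_type(config: dict[str, Any]) -> str:
    """Apply the trendline rule to a plain dict, for illustration only."""
    if config.get("show_trendline", False) and config.get("temporal_column"):
        return "big_number"
    return "big_number_total"


with_trend = {"show_trendline": True, "temporal_column": "order_date"}
no_column = {"show_trendline": True}

print(NATIVE_VIZ_TYPES[resolve_big_number_viz_type(with_trend)])
print(NATIVE_VIZ_TYPES[resolve_big_number_viz_type(no_column)])
```

Both conditions must hold to get the trendline variant; a config with `show_trendline` set but no `temporal_column` falls back to `big_number_total` here, although in the real pipeline `pre_validate` would reject that combination earlier.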


@@ -0,0 +1,193 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Handlebars chart type plugin."""
from __future__ import annotations
from typing import Any
from superset.mcp_service.chart.chart_utils import (
_handlebars_chart_what,
_summarize_filters,
map_handlebars_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, HandlebarsChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError
class HandlebarsChartPlugin(BaseChartPlugin):
"""Plugin for handlebars chart type (custom HTML template charts)."""
chart_type = "handlebars"
display_name = "Handlebars (Custom Template)"
native_viz_types = {
"handlebars": "Custom Template Chart",
}
def pre_validate(
self,
config: dict[str, Any],
) -> ChartGenerationError | None:
if "handlebars_template" not in config:
return ChartGenerationError(
error_type="missing_handlebars_template",
message="Handlebars chart missing required field: handlebars_template",
details=(
"Handlebars charts require a 'handlebars_template' string "
"containing Handlebars HTML template markup"
),
suggestions=[
"Add 'handlebars_template' with a Handlebars HTML template",
"Data is available as {{data}} array in the template",
"Example: '<ul>{{#each data}}<li>{{this.name}}: "
"{{this.value}}</li>{{/each}}</ul>'",
],
error_code="MISSING_HANDLEBARS_TEMPLATE",
)
template = config.get("handlebars_template")
if not isinstance(template, str) or not template.strip():
return ChartGenerationError(
error_type="invalid_handlebars_template",
message="Handlebars template must be a non-empty string",
details=(
"The 'handlebars_template' field must be a non-empty string "
"containing valid Handlebars HTML template markup"
),
suggestions=[
"Ensure handlebars_template is a non-empty string",
"Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
],
error_code="INVALID_HANDLEBARS_TEMPLATE",
)
query_mode = config.get("query_mode", "aggregate")
if query_mode not in ("aggregate", "raw"):
return ChartGenerationError(
error_type="invalid_query_mode",
message="Invalid query_mode for handlebars chart",
details="query_mode must be either 'aggregate' or 'raw'",
suggestions=[
"Use 'aggregate' for aggregated data (default)",
"Use 'raw' for individual rows",
],
error_code="INVALID_QUERY_MODE",
)
if query_mode == "raw" and not config.get("columns"):
return ChartGenerationError(
error_type="missing_raw_columns",
message="Handlebars chart in 'raw' mode requires 'columns'",
details=(
"When query_mode is 'raw', you must specify which columns "
"to include in the query results"
),
suggestions=[
"Add 'columns': [{'name': 'column_name'}] for raw mode",
"Or use query_mode='aggregate' with 'metrics' and optional 'groupby'", # noqa: E501
],
error_code="MISSING_RAW_COLUMNS",
)
if query_mode == "aggregate" and not config.get("metrics"):
return ChartGenerationError(
error_type="missing_aggregate_metrics",
message="Handlebars chart in 'aggregate' mode requires 'metrics'",
details=(
"When query_mode is 'aggregate' (default), you must specify "
"at least one metric with an aggregate function"
),
suggestions=[
"Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
"Or use query_mode='raw' with 'columns' for individual rows",
],
error_code="MISSING_AGGREGATE_METRICS",
)
return None
def extract_column_refs(self, config: Any) -> list[ColumnRef]:
if not isinstance(config, HandlebarsChartConfig):
return []
refs: list[ColumnRef] = []
if config.columns:
refs.extend(config.columns)
if config.metrics:
refs.extend(config.metrics)
if config.groupby:
refs.extend(config.groupby)
if config.filters:
for f in config.filters:
refs.append(ColumnRef(name=f.column))
return refs
def to_form_data(
self, config: Any, dataset_id: int | str | None = None
) -> dict[str, Any]:
return map_handlebars_config(config)
def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
what = _handlebars_chart_what(config)
context = _summarize_filters(getattr(config, "filters", None))
return self._with_context(what, context)
def resolve_viz_type(self, config: Any) -> str:
return "handlebars"
def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
config_dict = config.model_dump()
def _norm_list(key: str) -> None:
if config_dict.get(key):
for col in config_dict[key]:
if col.get("saved_metric"):
col["name"] = DatasetValidator._get_canonical_metric_name(
col["name"], dataset_context
)
else:
col["name"] = DatasetValidator._get_canonical_column_name(
col["name"], dataset_context
)
_norm_list("columns")
_norm_list("metrics")
_norm_list("groupby")
DatasetValidator._normalize_filters(config_dict, dataset_context)
return HandlebarsChartConfig.model_validate(config_dict)
def schema_error_hint(self) -> ChartGenerationError | None:
return ChartGenerationError(
error_type="handlebars_validation_error",
message="Handlebars chart configuration validation failed",
details=(
"The handlebars chart configuration is missing "
"required fields or has invalid structure"
),
suggestions=[
"Ensure 'handlebars_template' is a non-empty string",
"For aggregate mode: add 'metrics' with aggregate functions",
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
"Example: {'chart_type': 'handlebars', "
"'handlebars_template': "
"'<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>', "
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
],
error_code="HANDLEBARS_VALIDATION_ERROR",
)
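The order of the Handlebars pre-validation checks matters, because the first failing check wins. The sketch below condenses the checks into a function that returns only the error code; `handlebars_error_code` is a hypothetical helper written for illustration, while the codes and the check order are taken from `pre_validate` above.

```python
from __future__ import annotations

from typing import Any


def handlebars_error_code(config: dict[str, Any]) -> str | None:
    """Return the error code of the first failing check, or None if valid."""
    if "handlebars_template" not in config:
        return "MISSING_HANDLEBARS_TEMPLATE"
    template = config.get("handlebars_template")
    if not isinstance(template, str) or not template.strip():
        return "INVALID_HANDLEBARS_TEMPLATE"
    # query_mode defaults to 'aggregate' when omitted.
    query_mode = config.get("query_mode", "aggregate")
    if query_mode not in ("aggregate", "raw"):
        return "INVALID_QUERY_MODE"
    if query_mode == "raw" and not config.get("columns"):
        return "MISSING_RAW_COLUMNS"
    if query_mode == "aggregate" and not config.get("metrics"):
        return "MISSING_AGGREGATE_METRICS"
    return None


template = "<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>"
print(handlebars_error_code({"handlebars_template": template}))
print(
    handlebars_error_code(
        {
            "handlebars_template": template,
            "query_mode": "raw",
            "columns": [{"name": "name"}],
        }
    )
)
```

Because `query_mode` defaults to `aggregate`, a config that supplies only a template fails the metrics check rather than the raw-columns check.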


@@ -0,0 +1,170 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mixed timeseries chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _mixed_timeseries_what,
    _summarize_filters,
    map_mixed_timeseries_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, MixedTimeseriesChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class MixedTimeseriesChartPlugin(BaseChartPlugin):
    """Plugin for mixed_timeseries chart type."""

    chart_type = "mixed_timeseries"
    display_name = "Mixed Timeseries"
    native_viz_types = {
        "mixed_timeseries": "Mixed Timeseries Chart",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        missing_fields = []

        if "x" not in config and "x_axis" not in config:
            missing_fields.append("'x' (X-axis temporal column)")
        if not config.get("y") and not config.get("metrics"):
            missing_fields.append("'y' (primary Y-axis metrics)")
        if not config.get("y_secondary") and not config.get("metrics_b"):
            missing_fields.append("'y_secondary' (secondary Y-axis metrics)")

        if missing_fields:
            return ChartGenerationError(
                error_type="missing_mixed_timeseries_fields",
                message=(
                    f"Mixed timeseries chart missing required fields: "
                    f"{', '.join(missing_fields)}"
                ),
                details=(
                    "Mixed timeseries charts require an x-axis, primary metrics, "
                    "and secondary metrics"
                ),
                suggestions=[
                    "Add 'x' field: {'name': 'date_column'}",
                    "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
                    "Add 'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]",
                    "Optional: 'primary_kind' and 'secondary_kind' for chart types",
                ],
                error_code="MISSING_MIXED_TIMESERIES_FIELDS",
            )

        for field_name in ["y", "y_secondary"]:
            if not isinstance(config.get(field_name, []), list):
                return ChartGenerationError(
                    error_type=f"invalid_{field_name}_format",
                    message=f"'{field_name}' must be a list of metrics",
                    details=(
                        f"The '{field_name}' field must be an array of metric "
                        "specifications"
                    ),
                    suggestions=[
                        f"Wrap in array: '{field_name}': "
                        "[{'name': 'col', 'aggregate': 'SUM'}]",
                    ],
                    error_code=f"INVALID_{field_name.upper()}_FORMAT",
                )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, MixedTimeseriesChartConfig):
            return []
        refs: list[ColumnRef] = [config.x]
        refs.extend(config.y)
        refs.extend(config.y_secondary)
        if config.group_by:
            refs.extend(config.group_by)
        if config.group_by_secondary:
            refs.extend(config.group_by_secondary)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_mixed_timeseries_config(config, dataset_id=dataset_id)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _mixed_timeseries_what(config)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return "mixed_timeseries"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()

        def _norm_single(key: str) -> None:
            if config_dict.get(key):
                config_dict[key]["name"] = DatasetValidator._get_canonical_column_name(
                    config_dict[key]["name"], dataset_context
                )

        def _norm_list(key: str) -> None:
            if config_dict.get(key):
                for col in config_dict[key]:
                    if col.get("saved_metric"):
                        col["name"] = DatasetValidator._get_canonical_metric_name(
                            col["name"], dataset_context
                        )
                    else:
                        col["name"] = DatasetValidator._get_canonical_column_name(
                            col["name"], dataset_context
                        )

        _norm_single("x")
        _norm_list("y")
        _norm_list("y_secondary")
        _norm_list("group_by")
        _norm_list("group_by_secondary")
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return MixedTimeseriesChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="mixed_timeseries_validation_error",
            message="Mixed timeseries chart configuration validation failed",
            details=(
                "The mixed timeseries configuration is missing "
                "required fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'x' field has 'name' for the time axis column",
                "Ensure 'y' is an array of primary-axis metrics",
                "Ensure 'y_secondary' is an array of secondary-axis metrics",
                "Example: {'chart_type': 'mixed_timeseries', "
                "'x': {'name': 'order_date'}, "
                "'y': [{'name': 'revenue', 'aggregate': 'SUM'}], "
                "'y_secondary': [{'name': 'orders', 'aggregate': 'COUNT'}]}",
            ],
            error_code="MIXED_TIMESERIES_VALIDATION_ERROR",
        )

@@ -0,0 +1,137 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pie chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _pie_chart_what,
    _summarize_filters,
    map_pie_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PieChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class PieChartPlugin(BaseChartPlugin):
    """Plugin for pie chart type."""

    chart_type = "pie"
    display_name = "Pie / Donut Chart"
    native_viz_types = {
        "pie": "Pie Chart",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        missing_fields = []

        if "dimension" not in config:
            missing_fields.append("'dimension' (category column for slices)")
        if "metric" not in config:
            missing_fields.append("'metric' (value metric for slice sizes)")

        if missing_fields:
            return ChartGenerationError(
                error_type="missing_pie_fields",
                message=(
                    f"Pie chart missing required fields: {', '.join(missing_fields)}"
                ),
                details=(
                    "Pie charts require a dimension (categories) and a metric (values)"
                ),
                suggestions=[
                    "Add 'dimension' field: {'name': 'category_column'}",
                    "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
                    "Example: {'chart_type': 'pie', 'dimension': {'name': 'product'}, "
                    "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
                ],
                error_code="MISSING_PIE_FIELDS",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, PieChartConfig):
            return []
        refs: list[ColumnRef] = [config.dimension, config.metric]
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_pie_config(config)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _pie_chart_what(config)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return "pie"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()
        if config_dict.get("dimension"):
            config_dict["dimension"]["name"] = (
                DatasetValidator._get_canonical_column_name(
                    config_dict["dimension"]["name"], dataset_context
                )
            )
        if config_dict.get("metric"):
            if config_dict["metric"].get("saved_metric"):
                config_dict["metric"]["name"] = (
                    DatasetValidator._get_canonical_metric_name(
                        config_dict["metric"]["name"], dataset_context
                    )
                )
            else:
                config_dict["metric"]["name"] = (
                    DatasetValidator._get_canonical_column_name(
                        config_dict["metric"]["name"], dataset_context
                    )
                )
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return PieChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="pie_validation_error",
            message="Pie chart configuration validation failed",
            details=(
                "The pie chart configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'dimension' field has 'name' for the slice label",
                "Ensure 'metric' field has 'name' and 'aggregate'",
                "Example: {'chart_type': 'pie', 'dimension': {'name': 'category'}, "
                "'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
            ],
            error_code="PIE_VALIDATION_ERROR",
        )
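
Note on the pie plugin's pre-validation: the check runs on the raw config dict, before any Pydantic parsing. A minimal standalone sketch of that key-presence logic follows. `find_missing_pie_fields` is a hypothetical helper invented for this illustration (it is not part of the plugin), and it returns the missing-field labels instead of building a `ChartGenerationError`.

```python
from __future__ import annotations

from typing import Any


def find_missing_pie_fields(config: dict[str, Any]) -> list[str]:
    """Hypothetical helper: list required pie fields absent from a raw config."""
    missing: list[str] = []
    # Same key-presence tests as PieChartPlugin.pre_validate in the diff.
    if "dimension" not in config:
        missing.append("'dimension' (category column for slices)")
    if "metric" not in config:
        missing.append("'metric' (value metric for slice sizes)")
    return missing
```

An empty list here corresponds to `pre_validate` returning `None`; a non-empty list corresponds to the `MISSING_PIE_FIELDS` error path.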

@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pivot table chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _pivot_table_what,
    _summarize_filters,
    map_pivot_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, PivotTableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class PivotTableChartPlugin(BaseChartPlugin):
    """Plugin for pivot_table chart type."""

    chart_type = "pivot_table"
    display_name = "Pivot Table"
    native_viz_types = {
        "pivot_table_v2": "Pivot Table",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        missing_fields = []

        if not config.get("rows"):
            missing_fields.append("'rows' (row grouping columns)")
        if not config.get("metrics"):
            missing_fields.append("'metrics' (aggregation metrics)")

        if missing_fields:
            return ChartGenerationError(
                error_type="missing_pivot_fields",
                message=(
                    f"Pivot table missing required fields: {', '.join(missing_fields)}"
                ),
                details="Pivot tables require row groupings and metrics",
                suggestions=[
                    "Add 'rows' field: [{'name': 'category'}]",
                    "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
                    "Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
                ],
                error_code="MISSING_PIVOT_FIELDS",
            )

        if not isinstance(config.get("rows", []), list):
            return ChartGenerationError(
                error_type="invalid_rows_format",
                message="Rows must be a list of columns",
                details="The 'rows' field must be an array of column specifications",
                suggestions=[
                    "Wrap row columns in array: 'rows': [{'name': 'category'}]",
                ],
                error_code="INVALID_ROWS_FORMAT",
            )

        if not isinstance(config.get("metrics", []), list):
            return ChartGenerationError(
                error_type="invalid_metrics_format",
                message="Metrics must be a list",
                details="The 'metrics' field must be an array of metric specifications",
                suggestions=[
                    "Wrap metrics in array: 'metrics': [{'name': 'sales', "
                    "'aggregate': 'SUM'}]",
                ],
                error_code="INVALID_METRICS_FORMAT",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, PivotTableChartConfig):
            return []
        refs: list[ColumnRef] = list(config.rows)
        refs.extend(config.metrics)
        if config.columns:
            refs.extend(config.columns)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_pivot_table_config(config)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _pivot_table_what(config)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return "pivot_table_v2"

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()

        def _norm_col_list(key: str) -> None:
            if config_dict.get(key):
                for col in config_dict[key]:
                    if col.get("saved_metric"):
                        col["name"] = DatasetValidator._get_canonical_metric_name(
                            col["name"], dataset_context
                        )
                    else:
                        col["name"] = DatasetValidator._get_canonical_column_name(
                            col["name"], dataset_context
                        )

        _norm_col_list("rows")
        _norm_col_list("metrics")
        _norm_col_list("columns")
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return PivotTableChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="pivot_table_validation_error",
            message="Pivot table configuration validation failed",
            details=(
                "The pivot table configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'rows' field is an array of column specs",
                "Ensure 'metrics' field is an array with aggregate funcs",
                "Optional: add 'columns' for column grouping",
                "Example: {'chart_type': 'pivot_table', "
                "'rows': [{'name': 'region'}], "
                "'metrics': [{'name': 'revenue', 'aggregate': 'SUM'}]}",
            ],
            error_code="PIVOT_TABLE_VALIDATION_ERROR",
        )

@@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Table chart type plugin."""

from __future__ import annotations

from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _summarize_filters,
    _table_chart_what,
    map_table_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, TableChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.common.error_schemas import ChartGenerationError


class TableChartPlugin(BaseChartPlugin):
    """Plugin for table chart type."""

    chart_type = "table"
    display_name = "Table"
    native_viz_types = {
        "table": "Table",
        "ag-grid-table": "Interactive Table",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        if not config.get("columns"):
            return ChartGenerationError(
                error_type="missing_columns",
                message="Table chart missing required field: columns",
                details=(
                    "Table charts require a 'columns' array to specify which "
                    "columns to display"
                ),
                suggestions=[
                    "Add 'columns' field with array of column specifications",
                    "Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
                    "'aggregate': 'SUM'}]",
                    "Each column can have optional 'aggregate' for metrics",
                ],
                error_code="MISSING_COLUMNS",
            )

        if not isinstance(config.get("columns", []), list):
            return ChartGenerationError(
                error_type="invalid_columns_format",
                message="Columns must be a list",
                details="The 'columns' field must be an array of column specifications",
                suggestions=[
                    "Ensure columns is an array: 'columns': [...]",
                    "Each column should be an object with 'name' field",
                ],
                error_code="INVALID_COLUMNS_FORMAT",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, TableChartConfig):
            return []
        refs: list[ColumnRef] = list(config.columns)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_table_config(config)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _table_chart_what(config, dataset_name)
        context = _summarize_filters(config.filters)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        return getattr(config, "viz_type", "table")

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()
        get_canonical = DatasetValidator._get_canonical_column_name
        get_canonical_metric = DatasetValidator._get_canonical_metric_name
        for col in config_dict.get("columns") or []:
            if col.get("saved_metric"):
                col["name"] = get_canonical_metric(col["name"], dataset_context)
            else:
                col["name"] = get_canonical(col["name"], dataset_context)
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return TableChartConfig.model_validate(config_dict)

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="table_validation_error",
            message="Table chart configuration validation failed",
            details=(
                "The table chart configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Ensure 'columns' field is an array of column specifications",
                "Each column needs {'name': 'column_name'}",
                "Optional: add 'aggregate' for metrics",
                "Example: 'columns': [{'name': 'product'}, "
                "{'name': 'sales', 'aggregate': 'SUM'}]",
            ],
            error_code="TABLE_VALIDATION_ERROR",
        )
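
Note on `native_viz_types`: the table plugin claims two viz_type keys, and the commit message for this change says registration warns when two plugins claim the same viz_type. A rough sketch of such an overlap check on plain dicts could look like the following. `find_viz_type_collisions` is a hypothetical name used only for this example, and the `other_plugin` mapping in the usage note is invented.

```python
from __future__ import annotations


def find_viz_type_collisions(
    new_types: dict[str, str], existing_types: dict[str, str]
) -> list[str]:
    """Hypothetical helper: viz_type keys claimed by both mappings, sorted."""
    # dict key views support set intersection directly.
    return sorted(new_types.keys() & existing_types.keys())


# Mapping copied from TableChartPlugin in the diff.
TABLE_VIZ_TYPES = {"table": "Table", "ag-grid-table": "Interactive Table"}
```

For instance, an invented `other_plugin = {"ag-grid-table": "Grid"}` would overlap with `TABLE_VIZ_TYPES` on `ag-grid-table`, which is the situation the registration warning is meant to surface.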

@@ -0,0 +1,198 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""XY chart type plugin (line, bar, area, scatter)."""

from __future__ import annotations

import logging
from typing import Any

from superset.mcp_service.chart.chart_utils import (
    _xy_chart_context,
    _xy_chart_what,
    map_xy_config,
)
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig
from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
from superset.mcp_service.chart.validation.runtime.cardinality_validator import (
    CardinalityValidator,
)
from superset.mcp_service.chart.validation.runtime.format_validator import (
    FormatTypeValidator,
)
from superset.mcp_service.common.error_schemas import ChartGenerationError

logger = logging.getLogger(__name__)


class XYChartPlugin(BaseChartPlugin):
    """Plugin for xy chart type (line, bar, area, scatter)."""

    chart_type = "xy"
    display_name = "Line / Bar / Area / Scatter Chart"
    native_viz_types = {
        "echarts_timeseries_line": "Line Chart",
        "echarts_timeseries_bar": "Bar Chart",
        "echarts_area": "Area Chart",
        "echarts_timeseries_scatter": "Scatter Plot",
    }

    def pre_validate(
        self,
        config: dict[str, Any],
    ) -> ChartGenerationError | None:
        # x is optional — defaults to dataset's main_dttm_col in map_xy_config
        if not config.get("y") and not config.get("metrics"):
            return ChartGenerationError(
                error_type="missing_xy_fields",
                message="XY chart missing required field: 'y' (Y-axis metrics)",
                details=(
                    "XY charts require Y-axis (metrics) specifications. "
                    "X-axis is optional and defaults to the dataset's primary "
                    "datetime column when omitted."
                ),
                suggestions=[
                    "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}]",
                    "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
                    "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
                ],
                error_code="MISSING_XY_FIELDS",
            )

        if not isinstance(config.get("y", []), list):
            return ChartGenerationError(
                error_type="invalid_y_format",
                message="Y-axis must be a list of metrics",
                details="The 'y' field must be an array of metric specifications",
                suggestions=[
                    "Wrap Y-axis metric in array: 'y': [{'name': 'column', "
                    "'aggregate': 'SUM'}]",
                    "Multiple metrics supported: 'y': [metric1, metric2, ...]",
                ],
                error_code="INVALID_Y_FORMAT",
            )

        return None

    def extract_column_refs(self, config: Any) -> list[ColumnRef]:
        if not isinstance(config, XYChartConfig):
            return []
        refs: list[ColumnRef] = []
        if config.x is not None:
            refs.append(config.x)
        refs.extend(config.y)
        if config.group_by:
            refs.extend(config.group_by)
        if config.filters:
            for f in config.filters:
                refs.append(ColumnRef(name=f.column))
        return refs

    def to_form_data(
        self, config: Any, dataset_id: int | str | None = None
    ) -> dict[str, Any]:
        return map_xy_config(config, dataset_id=dataset_id)

    def normalize_column_refs(self, config: Any, dataset_context: Any) -> Any:
        config_dict = config.model_dump()
        get_canonical = DatasetValidator._get_canonical_column_name
        get_canonical_metric = DatasetValidator._get_canonical_metric_name

        if config_dict.get("x"):
            config_dict["x"]["name"] = get_canonical(
                config_dict["x"]["name"], dataset_context
            )
        for y_col in config_dict.get("y") or []:
            if y_col.get("sql_expression"):
                continue  # sql_expression metrics have no underlying column
            if y_col.get("saved_metric"):
                y_col["name"] = get_canonical_metric(y_col["name"], dataset_context)
            else:
                y_col["name"] = get_canonical(y_col["name"], dataset_context)
        for gb_col in config_dict.get("group_by") or []:
            gb_col["name"] = get_canonical(gb_col["name"], dataset_context)
        DatasetValidator._normalize_filters(config_dict, dataset_context)
        return XYChartConfig.model_validate(config_dict)

    def generate_name(self, config: Any, dataset_name: str | None = None) -> str:
        what = _xy_chart_what(config)
        context = _xy_chart_context(config)
        return self._with_context(what, context)

    def resolve_viz_type(self, config: Any) -> str:
        kind = getattr(config, "kind", "line")
        return {
            "line": "echarts_timeseries_line",
            "bar": "echarts_timeseries_bar",
            "area": "echarts_area",
            "scatter": "echarts_timeseries_scatter",
        }.get(kind, "echarts_timeseries_line")

    def get_runtime_warnings(self, config: Any, dataset_id: int | str) -> list[str]:
        """Return format-compatibility and cardinality warnings for XY charts."""
        if not isinstance(config, XYChartConfig):
            return []

        warnings: list[str] = []

        try:
            _valid, format_warnings = FormatTypeValidator.validate_format_compatibility(
                config
            )
            if format_warnings:
                warnings.extend(format_warnings)
        except Exception as exc:  # noqa: BLE001 — non-blocking warning path
            logger.warning("XY format validation failed: %s", exc)

        try:
            chart_kind = config.kind
            group_by_col = config.group_by[0].name if config.group_by else None
            if config.x is not None and config.x.name is not None:
                _ok, card_info = CardinalityValidator.check_cardinality(
                    dataset_id=dataset_id,
                    x_column=config.x.name,
                    chart_type=chart_kind,
                    group_by_column=group_by_col,
                )
                if not _ok and card_info:
                    warnings.extend(card_info.get("warnings", []))
                    warnings.extend(card_info.get("suggestions", []))
        except Exception as exc:  # noqa: BLE001 — DB queries may raise infra errors
            logger.warning("XY cardinality validation failed: %s", exc)

        return warnings

    def schema_error_hint(self) -> ChartGenerationError | None:
        return ChartGenerationError(
            error_type="xy_validation_error",
            message="XY chart configuration validation failed",
            details=(
                "The XY chart configuration is missing required "
                "fields or has invalid structure"
            ),
            suggestions=[
                "Note: 'x' is optional and defaults to the dataset's primary "
                "datetime column",
                "Ensure 'y' is an array: [{'name': 'metric', 'aggregate': 'SUM'}]",
                "Check that all column names are strings",
                "Verify aggregate functions are valid: SUM, COUNT, AVG, MIN, MAX",
            ],
            error_code="XY_VALIDATION_ERROR",
        )
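
Note on `resolve_viz_type` in the XY plugin: it falls back to the line chart both when the config has no `kind` attribute and when `kind` is not one of the four known values. The standalone sketch below reproduces that lookup. `SimpleNamespace` stands in for a real `XYChartConfig`, which is an assumption made only so the example runs without Superset.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any

# Same mapping as XYChartPlugin.resolve_viz_type in the diff.
_KIND_TO_VIZ_TYPE = {
    "line": "echarts_timeseries_line",
    "bar": "echarts_timeseries_bar",
    "area": "echarts_area",
    "scatter": "echarts_timeseries_scatter",
}


def resolve_viz_type(config: Any) -> str:
    """Map a config's ``kind`` to a viz_type, defaulting to the line chart."""
    kind = getattr(config, "kind", "line")
    return _KIND_TO_VIZ_TYPE.get(kind, "echarts_timeseries_line")


# Stand-in config objects (hypothetical, not real XYChartConfig instances).
bar_config = SimpleNamespace(kind="bar")
bare_config = SimpleNamespace()
```

The double default means an unexpected `kind` never raises here; whether the schema already rejects such values earlier is not shown in this part of the diff.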

@@ -0,0 +1,279 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ChartTypeRegistry — central registry mapping chart_type strings to plugins.

Replaces the four previously-scattered dispatch locations:
  - schema_validator.py: chart_type_validators dict
  - dataset_validator.py: isinstance branches in _extract_column_references()
  - chart_utils.py: if/elif chain in map_config_to_form_data()
  - dataset_validator.py: isinstance branches in normalize_column_names()

Usage::

    from superset.mcp_service.chart.registry import get_registry

    plugin = get_registry().get("xy")
    if plugin is None:
        raise ValueError("Unknown chart type: xy")
    form_data = plugin.to_form_data(config, dataset_id)
"""

from __future__ import annotations

import logging
import sys
import threading
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from superset.mcp_service.chart.plugin import ChartTypePlugin

logger = logging.getLogger(__name__)

_REGISTRY: dict[str, "ChartTypePlugin"] = {}
_plugins_loaded = False
_plugins_load_failed = False
_plugins_lock = threading.RLock()

# ---------------------------------------------------------------------------
# Plugin filter — replaced atomically by configure() at app startup.
# Default: all registered plugins visible (no disabled set, no callable).
# ---------------------------------------------------------------------------
PluginEnabledFunc = Callable[[str], bool]


@dataclass(frozen=True)
class _PluginFilterConfig:
    disabled_plugins: frozenset[str] = field(default_factory=frozenset)
    enabled_func: PluginEnabledFunc | None = None


_filter_config: _PluginFilterConfig = _PluginFilterConfig()


def _ensure_plugins_loaded() -> None:
    """Lazily import the plugins package to populate _REGISTRY.

    Called before every registry lookup so the registry is always populated,
    even when callers (tests, chart_utils, validators) import this module
    directly without first importing app.py.
    """
    global _plugins_loaded, _plugins_load_failed
    if _plugins_loaded or _plugins_load_failed:
        return
    with _plugins_lock:
        if not _plugins_loaded and not _plugins_load_failed:
            registry_before_import = dict(_REGISTRY)
            try:
                import superset.mcp_service.chart.plugins  # noqa: F401

                _plugins_loaded = True
            except Exception:  # noqa: BLE001 — plugin import may raise anything
                _REGISTRY.clear()
                _REGISTRY.update(registry_before_import)
                _plugins_load_failed = True
                logger.exception(
                    "Failed to load built-in chart type plugins; "
                    "further lookups will return None"
                )


def configure(
    disabled: Iterable[str] | None = None,
    enabled_func: PluginEnabledFunc | None = None,
) -> None:
    """Set runtime plugin filters. Called once during app initialization.

    Replaces the filter config atomically with a single assignment so concurrent
    readers always observe a consistent (disabled_plugins, enabled_func) pair.

    Args:
        disabled: chart_type strings to suppress. Accepts any iterable (set,
            frozenset, list, tuple). Ignored when enabled_func is provided.
        enabled_func: callable(chart_type) -> bool. When set, overrides
            ``disabled``. Must be cheap and in-process — no network I/O per
            call. On exception the registry fails *closed* (plugin hidden).
    """
    global _filter_config
    if enabled_func is not None and not callable(enabled_func):
        raise TypeError("enabled_func must be callable or None")
    new_config = _PluginFilterConfig(
        disabled_plugins=frozenset(disabled or ()),
        enabled_func=enabled_func,
    )
    _filter_config = new_config
    if new_config.disabled_plugins:
        logger.info(
            "MCP chart plugins disabled: %s", sorted(new_config.disabled_plugins)
        )
    if new_config.enabled_func is not None:
        logger.info(
            "MCP chart plugin dynamic filter configured: %r", new_config.enabled_func
        )


def _is_plugin_enabled(chart_type: str) -> bool:
    """Return True if the plugin is currently enabled (not filtered out)."""
    config = _filter_config  # read once — atomic reference in CPython
    if config.enabled_func is not None:
        try:
            return bool(config.enabled_func(chart_type))
        except Exception:  # noqa: BLE001 — operator-supplied callable may raise anything
            logger.warning(
                "MCP_CHART_PLUGIN_ENABLED_FUNC raised for chart_type=%r; "
                "failing closed (plugin hidden)",
                chart_type,
                exc_info=True,
            )
            return False
    return chart_type not in config.disabled_plugins


def register(plugin: "ChartTypePlugin") -> None:
    """Register a chart type plugin in the global registry."""
    if not plugin.chart_type:
        raise ValueError(f"{type(plugin).__name__} must define a non-empty chart_type")
    with _plugins_lock:
        if plugin.chart_type in _REGISTRY:
            logger.warning(
                "Overwriting existing plugin for chart_type=%r", plugin.chart_type
            )
        for existing in _REGISTRY.values():
            if existing.chart_type == plugin.chart_type:
                continue
            colliding = plugin.native_viz_types.keys() & existing.native_viz_types
            if colliding:
                # display_name_for_viz_type() resolves to the first plugin in
                # iteration order, making the later registration unreachable.
                logger.warning(
                    "Plugin %r declares native_viz_types %s already claimed by "
                    "plugin %r; viz_type display-name lookups will resolve to "
                    "the earlier registration",
                    plugin.chart_type,
                    sorted(colliding),
                    existing.chart_type,
                )
        _REGISTRY[plugin.chart_type] = plugin
    logger.debug("Registered chart plugin: %r", plugin.chart_type)


def get(chart_type: str) -> "ChartTypePlugin | None":
    """Return the plugin for chart_type, or None if unknown or disabled."""
    _ensure_plugins_loaded()
    if chart_type not in _REGISTRY or not _is_plugin_enabled(chart_type):
        return None
    return _REGISTRY[chart_type]


def all_types() -> list[str]:
    """Return enabled registered chart type strings in insertion order."""
    _ensure_plugins_loaded()
    return [ct for ct in _REGISTRY if _is_plugin_enabled(ct)]


def is_registered(chart_type: str) -> bool:
    """Return True if chart_type has a registered plugin, regardless of enabled state.

    Use this to distinguish an unknown chart type from a disabled one.
    Use is_enabled() to check whether the plugin is currently available.
    """
    _ensure_plugins_loaded()
    return chart_type in _REGISTRY


def is_enabled(chart_type: str) -> bool:
    """Return True if chart_type is registered AND currently enabled."""
    _ensure_plugins_loaded()
    return chart_type in _REGISTRY and _is_plugin_enabled(chart_type)


def display_name_for_viz_type(viz_type: str) -> str | None:
    """Return the user-facing display name for a Superset-internal viz_type.

    Searches every registered plugin's ``native_viz_types`` mapping.
    Returns None if no plugin recognises the viz_type.

    Example::

        display_name_for_viz_type("echarts_timeseries_line")  # "Line Chart"
        display_name_for_viz_type("pivot_table_v2")  # "Pivot Table"
        display_name_for_viz_type("unknown_type")  # None
    """
    _ensure_plugins_loaded()
    for plugin in _REGISTRY.values():
        name = plugin.native_viz_types.get(viz_type)
        if name is not None:
            return name
    return None
def _reset_for_testing() -> None:
"""Reset all registry state to defaults.
Only for use in tests that need a clean slate. Calling this in production
will discard all registered plugins and any runtime filter configuration.
**Caller responsibility**: This function pops ``superset.mcp_service.chart.plugins``
from ``sys.modules`` and directly assigns module globals (``_REGISTRY``,
``_plugins_loaded``, etc.). Direct global assignment is NOT automatically
reverted by pytest's ``monkeypatch`` fixture. Callers must either use
``monkeypatch.setattr`` for each global, or call ``_reset_for_testing()`` again
in teardown to restore the clean state. See ``test_registry.py`` for the
recommended ``monkeypatch.setattr`` isolation pattern.
"""
global _REGISTRY, _plugins_loaded, _plugins_load_failed, _filter_config
with _plugins_lock:
_REGISTRY = {}
_plugins_loaded = False
_plugins_load_failed = False
_filter_config = _PluginFilterConfig()
sys.modules.pop("superset.mcp_service.chart.plugins", None)
class _RegistryProxy:
"""Thin proxy exposing registry functions as instance methods."""
def get(self, chart_type: str) -> "ChartTypePlugin | None":
return get(chart_type)
def all_types(self) -> list[str]:
return all_types()
def is_registered(self, chart_type: str) -> bool:
return is_registered(chart_type)
def is_enabled(self, chart_type: str) -> bool:
return is_enabled(chart_type)
def display_name_for_viz_type(self, viz_type: str) -> str | None:
return display_name_for_viz_type(viz_type)
_PROXY = _RegistryProxy()
def get_registry() -> "_RegistryProxy":
"""Return the module-level registry proxy (convenience wrapper)."""
return _PROXY
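The lookup above returns the first plugin that claims a viz_type, which is why the commit log mentions a collision warning at registration time. A minimal sketch of that interaction follows; `FakePlugin`, `REGISTRY`, and this `register()` are hypothetical stand-ins for illustration, not the module's real API.

```python
from __future__ import annotations

import logging

logger = logging.getLogger("registry_sketch")


class FakePlugin:
    """Hypothetical stand-in for a chart plugin (not the real base class)."""

    def __init__(self, chart_type: str, native_viz_types: dict[str, str]) -> None:
        self.chart_type = chart_type
        self.native_viz_types = native_viz_types


# Hypothetical registry: chart_type -> plugin, in registration order.
REGISTRY: dict[str, FakePlugin] = {}


def register(plugin: FakePlugin) -> None:
    """Register a plugin and warn about viz_types another plugin already claims."""
    for other in REGISTRY.values():
        clashes = plugin.native_viz_types.keys() & other.native_viz_types.keys()
        for viz_type in sorted(clashes):
            logger.warning(
                "viz_type %r is claimed by both %r and %r",
                viz_type,
                other.chart_type,
                plugin.chart_type,
            )
    REGISTRY[plugin.chart_type] = plugin


def display_name_for_viz_type(viz_type: str) -> str | None:
    """Return the first matching display name, or None if nothing matches."""
    for plugin in REGISTRY.values():
        name = plugin.native_viz_types.get(viz_type)
        if name is not None:
            return name
    return None


register(FakePlugin("xy", {"echarts_timeseries_line": "Line Chart"}))
register(FakePlugin("custom", {"echarts_timeseries_line": "My Line"}))  # warns
```

Without the warning, the second plugin's name would be shadowed silently, because dictionary iteration order decides the winner.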


@@ -22,6 +22,7 @@ Pydantic schemas for chart-related responses
from __future__ import annotations
import difflib
import logging
from datetime import datetime
from typing import Annotated, Any, cast, Dict, List, Literal, Protocol
@@ -68,6 +69,8 @@ from superset.mcp_service.utils.sanitization import (
sanitize_user_input_with_changes,
)
logger = logging.getLogger(__name__)
class ChartLike(Protocol):
"""Protocol for chart-like objects with expected attributes."""
@@ -102,7 +105,15 @@ class ChartInfo(BaseModel):
id: int | None = Field(None, description="Chart ID")
slice_name: str | None = Field(None, description="Chart name")
viz_type: str | None = Field(None, description="Visualization type")
viz_type: str | None = Field(None, description="Visualization type (internal ID)")
chart_type_display_name: str | None = Field(
None,
description=(
"User-friendly chart type name (e.g. 'Line Chart', 'Pivot Table'). "
"Prefer this field when referring to chart types; "
"fall back to viz_type when this field is null."
),
)
datasource_name: str | None = Field(None, description="Datasource name")
datasource_type: str | None = Field(None, description="Datasource type")
url: str | None = Field(None, description="Chart explore page URL")
@@ -561,11 +572,27 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
# Extract structured filter information
filters_info = extract_filters_from_form_data(chart_form_data)
_viz_type = getattr(chart, "viz_type", None)
_display_name = None
if _viz_type:
try:
from superset.mcp_service.chart.registry import display_name_for_viz_type
except ImportError:
pass
else:
try:
_display_name = display_name_for_viz_type(_viz_type)
except Exception as exc: # noqa: BLE001 — third-party plugins may raise
logger.debug(
"Failed to resolve display name for viz_type=%r: %s", _viz_type, exc
)
return sanitize_chart_info_for_llm_context(
ChartInfo(
id=chart_id,
slice_name=getattr(chart, "slice_name", None),
viz_type=getattr(chart, "viz_type", None),
viz_type=_viz_type,
chart_type_display_name=_display_name,
datasource_name=getattr(chart, "datasource_name", None),
datasource_type=getattr(chart, "datasource_type", None),
url=chart_url,
@@ -742,7 +769,6 @@ class ColumnRef(BaseModel):
None,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("name", "column_name"),
)
label: str | None = Field(None, max_length=500)
@@ -898,7 +924,6 @@ class FilterConfig(BaseModel):
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
validation_alias=AliasChoices("column", "col"),
)
op: Literal[
@@ -930,7 +955,9 @@ class FilterConfig(BaseModel):
"""Sanitize filter column name to prevent injection attacks."""
# sanitize_user_input raises ValueError when allow_empty=False (default)
# so the return value is guaranteed to be a non-None str
return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value]
return sanitize_user_input(
v, "Filter column", max_length=255, check_sql_keywords=True
) # type: ignore[return-value]
@field_validator("value")
@classmethod
@@ -1353,7 +1380,6 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
time_grain: TimeGrain | None = Field(
None,
@@ -1422,6 +1448,18 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
description="Filters to apply",
)
@field_validator("temporal_column")
@classmethod
def sanitize_temporal_column(cls, v: str | None) -> str | None:
"""Sanitize temporal column name to prevent SQL injection."""
return sanitize_user_input(
v,
"Temporal column",
max_length=255,
check_sql_keywords=True,
allow_empty=True,
)
@model_validator(mode="after")
def validate_trendline_fields(self) -> Self:
"""Validate trendline requires temporal column."""
@@ -1517,25 +1555,29 @@ class TableChartConfig(UnknownFieldCheckMixin):
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "TableChartConfig":
"""Ensure all column labels are unique."""
labels_seen = set()
# Key is (saved_metric, label) so a saved metric and a regular column
# with the same input name are not flagged as duplicates — saved metrics
# resolve to their actual casing from the dataset during normalization.
labels_seen: dict[tuple[bool, str], str] = {}
duplicates = []
for i, col in enumerate(self.columns):
# Generate the label that will be used (same logic as create_metric_object)
if col.sql_expression:
# SQL metrics carry a required label; use it verbatim.
label = col.label
label = col.label or ""
elif col.saved_metric:
label = col.label or col.name
label = col.label or col.name or ""
elif col.aggregate:
label = col.label or f"{col.aggregate}({col.name})"
else:
label = col.label or col.name
label = col.label or col.name or ""
if label in labels_seen:
key = (col.saved_metric, label)
if key in labels_seen:
duplicates.append(f"columns[{i}]: '{label}'")
else:
labels_seen.add(label)
labels_seen[key] = f"columns[{i}]"
if duplicates:
raise ValueError(
@@ -1665,24 +1707,28 @@ class XYChartConfig(UnknownFieldCheckMixin):
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "XYChartConfig":
"""Ensure all column labels are unique across x, y, and group_by."""
labels_seen: dict[str, str] = {}
# Key is (saved_metric, label) so a saved metric and a regular column
# with the same input name are not flagged as duplicates — saved metrics
# resolve to their actual casing from the dataset during normalization.
labels_seen: dict[tuple[bool, str], str] = {}
duplicates: list[str] = []
# Add x-axis label if present (x may be None, resolved later).
# The dimension validator rejects sql_expression on x, so name is set.
if self.x is not None:
x_label = self.x.label or self.x.name or ""
labels_seen[x_label] = "x"
labels_seen[(self.x.saved_metric, x_label)] = "x"
# Check Y-axis labels
for i, col in enumerate(self.y):
label = _metric_display_label(col)
if label in labels_seen:
key = (col.saved_metric, label)
if key in labels_seen:
duplicates.append(
f"y[{i}]: '{label}' (conflicts with {labels_seen[label]})"
f"y[{i}]: '{label}' (conflicts with {labels_seen[key]})"
)
else:
labels_seen[label] = f"y[{i}]"
labels_seen[key] = f"y[{i}]"
# Check group_by labels if present
if self.group_by:
@@ -1692,15 +1738,15 @@ class XYChartConfig(UnknownFieldCheckMixin):
# to prevent Superset "duplicate label" errors, so
# we allow them through validation.
continue
# group_by rejects sql_expression, so name is set.
group_label = col.label or col.name or ""
if group_label in labels_seen:
group_key = (col.saved_metric, group_label)
if group_key in labels_seen:
duplicates.append(
f"group_by[{i}]: '{group_label}' "
f"(conflicts with {labels_seen[group_label]})"
f"(conflicts with {labels_seen[group_key]})"
)
else:
labels_seen[group_label] = f"group_by[{i}]"
labels_seen[group_key] = f"group_by[{i}]"
if duplicates:
raise ValueError(

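The duplicate-label validators in this file now key on `(saved_metric, label)` rather than on the label alone. A small sketch of that idea, assuming a simplified `Col` dataclass (hypothetical, much smaller than the real `ColumnRef`):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Col:
    """Hypothetical, simplified column reference for illustration only."""

    name: str
    label: str | None = None
    saved_metric: bool = False


def find_duplicate_labels(columns: list[Col]) -> list[str]:
    """Report labels that repeat within the same (saved_metric, label) key."""
    seen: dict[tuple[bool, str], str] = {}
    duplicates: list[str] = []
    for i, col in enumerate(columns):
        label = col.label or col.name
        key = (col.saved_metric, label)
        if key in seen:
            duplicates.append(
                f"columns[{i}]: '{label}' (conflicts with {seen[key]})"
            )
        else:
            seen[key] = f"columns[{i}]"
    return duplicates
```

With this key, a saved metric and a plain column that share an input name are not reported, while two plain columns with the same label still are.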

@@ -105,18 +105,34 @@ async def generate_chart( # noqa: C901
- Set save_chart=True to permanently save the chart
- LLM clients MUST display returned chart URL to users
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- MUST include chart_type in config (either 'xy' or 'table')
- MUST include chart_type in config (one of: 'xy', 'table', 'pie',
'pivot_table', 'mixed_timeseries', 'handlebars', 'big_number')
IMPORTANT: The 'chart_type' field in the config is a DISCRIMINATOR that determines
which chart configuration schema to use. It MUST be included and MUST match the
other fields in your configuration:
- Use chart_type='xy' for charts with x and y axes (line, bar, area, scatter)
Required fields: x, y
Required fields: y (x is optional — defaults to dataset's primary datetime column)
- Use chart_type='table' for tabular visualizations
Required fields: columns
- Use chart_type='pie' for pie/donut charts
Required fields: dimension, metric
- Use chart_type='pivot_table' for pivot table visualizations
Required fields: rows, metrics
- Use chart_type='mixed_timeseries' for dual-axis time-series charts
Required fields: x, y, y_secondary
- Use chart_type='handlebars' for custom template-based visualizations
Required fields: handlebars_template
- Use chart_type='big_number' for single KPI metric displays
Required fields: metric
Example usage for XY chart:
```json
{


@@ -420,6 +420,24 @@ async def update_chart( # noqa: C901
# config is already a typed ChartConfig | None (validated by Pydantic)
parsed_config = request.config
# Normalize column case to match dataset canonical names
# (mirrors generate_chart pipeline layer 4)
chart_datasource_id = getattr(chart, "datasource_id", None)
if parsed_config is not None and chart_datasource_id is not None:
from superset.mcp_service.chart.validation.dataset_validator import (
DatasetValidator,
NORMALIZATION_EXCEPTIONS,
)
try:
parsed_config = DatasetValidator.normalize_column_names(
parsed_config, chart_datasource_id
)
except NORMALIZATION_EXCEPTIONS as e:
logger.warning(
"Column normalization failed for chart %s: %s", chart.id, e
)
if not request.generate_preview:
from superset.commands.chart.update import UpdateChartCommand

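The hunk above treats column normalization as best effort: when it fails, the update proceeds with the config as submitted and only a warning is logged. A minimal sketch of that pattern, where `normalize_or_keep` and the exception tuple below are illustrative assumptions rather than the project's actual definitions:

```python
import logging
from typing import Callable, TypeVar

logger = logging.getLogger("update_chart_sketch")

T = TypeVar("T")

# Assumption: the real NORMALIZATION_EXCEPTIONS tuple may contain other types.
NORMALIZATION_EXCEPTIONS = (KeyError, ValueError)


def normalize_or_keep(config: T, normalize: Callable[[T], T]) -> T:
    """Apply ``normalize``; on a known failure, log and return the input as-is."""
    try:
        return normalize(config)
    except NORMALIZATION_EXCEPTIONS as exc:
        logger.warning("Column normalization failed: %s", exc)
        return config
```

Catching a narrow tuple instead of `Exception` keeps unexpected errors visible while still letting the update continue when the dataset lookup fails.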

@@ -22,17 +22,11 @@ Validates that referenced columns exist in the dataset schema.
import difflib
import logging
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, TypeVar
from superset.mcp_service.chart.schemas import (
BigNumberChartConfig,
ChartConfig,
ColumnRef,
HandlebarsChartConfig,
MixedTimeseriesChartConfig,
PieChartConfig,
PivotTableChartConfig,
TableChartConfig,
XYChartConfig,
)
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
@@ -40,6 +34,8 @@ from superset.mcp_service.common.error_schemas import (
DatasetContext,
)
_C = TypeVar("_C", bound=ChartConfig)
logger = logging.getLogger(__name__)
# Exceptions that can occur during column name normalization.
@@ -58,7 +54,7 @@ class DatasetValidator:
@staticmethod
def validate_against_dataset(
config: Any,
config: ChartConfig,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
) -> Tuple[bool, ChartGenerationError | None]:
@@ -269,59 +265,31 @@ class DatasetValidator:
return None
@staticmethod
def _extract_column_references(config: Any) -> List[ColumnRef]: # noqa: C901
"""Extract all column references from a chart configuration.
def _extract_column_references(
config: ChartConfig,
) -> List[ColumnRef]:
"""Extract all column references from configuration via the plugin registry.
Covers every supported ``ChartConfig`` variant so fast-path tools
(``generate_explore_link``, ``update_chart_preview``) that only run
Tier-1 validation still catch bad column refs in pie / pivot table /
mixed timeseries / handlebars / big number charts — not just XY and
table.
Previously only handled TableChartConfig and XYChartConfig, causing
5 of 7 chart types to silently skip column validation. Now delegates
to the plugin for each chart type so all types are covered.
"""
refs: List[ColumnRef] = []
# Local import: plugins call DatasetValidator helpers from
# normalize_column_refs(), so a top-level import of the registry here would
# trigger plugin registration while this module is still loading and
# create a circular import.
from superset.mcp_service.chart.registry import get_registry
if isinstance(config, TableChartConfig):
refs.extend(config.columns)
elif isinstance(config, XYChartConfig):
if config.x is not None:
refs.append(config.x)
refs.extend(config.y)
if config.group_by:
refs.extend(config.group_by)
elif isinstance(config, PieChartConfig):
refs.append(config.dimension)
refs.append(config.metric)
elif isinstance(config, PivotTableChartConfig):
refs.extend(config.rows)
if config.columns:
refs.extend(config.columns)
refs.extend(config.metrics)
elif isinstance(config, MixedTimeseriesChartConfig):
refs.append(config.x)
refs.extend(config.y)
if config.group_by:
refs.extend(config.group_by)
refs.extend(config.y_secondary)
if config.group_by_secondary:
refs.extend(config.group_by_secondary)
elif isinstance(config, HandlebarsChartConfig):
if config.columns:
refs.extend(config.columns)
if config.groupby:
refs.extend(config.groupby)
if config.metrics:
refs.extend(config.metrics)
elif isinstance(config, BigNumberChartConfig):
refs.append(config.metric)
if config.temporal_column:
refs.append(ColumnRef(name=config.temporal_column))
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return []
# Filter columns (shared by every config type that defines ``filters``).
if filters := getattr(config, "filters", None):
for filter_config in filters:
refs.append(ColumnRef(name=filter_config.column))
plugin = get_registry().get(chart_type)
if plugin is None:
logger.warning("No plugin registered for chart_type=%r", chart_type)
return []
return refs
return plugin.extract_column_refs(config)
@staticmethod
def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
@@ -375,47 +343,23 @@ class DatasetValidator:
return column_name
@staticmethod
def _normalize_xy_config(
config_dict: Dict[str, Any], dataset_context: DatasetContext
) -> None:
"""Normalize column names in an XY chart config dict in place."""
# Normalize x-axis column
if "x" in config_dict and config_dict["x"] and config_dict["x"].get("name"):
config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name(
config_dict["x"]["name"], dataset_context
)
def _get_canonical_metric_name(
metric_name: str, dataset_context: DatasetContext
) -> str:
"""Return the canonical saved-metric name from available_metrics.
# Normalize y-axis columns (skip SQL-expression metrics; no name).
if "y" in config_dict and config_dict["y"]:
for y_col in config_dict["y"]:
if not y_col.get("name"):
continue
y_col["name"] = DatasetValidator._get_canonical_column_name(
y_col["name"], dataset_context
)
Unlike _get_canonical_column_name, this only searches available_metrics
so that a same-named column with different casing cannot shadow the
metric's canonical name. Use this whenever saved_metric=True.
# Normalize group_by columns
if "group_by" in config_dict and config_dict["group_by"]:
for gb_col in config_dict["group_by"]:
if not gb_col.get("name"):
continue
gb_col["name"] = DatasetValidator._get_canonical_column_name(
gb_col["name"], dataset_context
)
@staticmethod
def _normalize_table_config(
config_dict: Dict[str, Any], dataset_context: DatasetContext
) -> None:
"""Normalize column names in a table chart config dict in place."""
if "columns" in config_dict and config_dict["columns"]:
for col in config_dict["columns"]:
# Skip SQL-expression metrics: no underlying column name.
if not col.get("name"):
continue
col["name"] = DatasetValidator._get_canonical_column_name(
col["name"], dataset_context
)
Returns the original name when no metric matches (validation catches
the missing-metric case separately).
"""
metric_lower = metric_name.lower()
for metric in dataset_context.available_metrics:
if metric["name"].lower() == metric_lower:
return metric["name"]
return metric_name
@staticmethod
def _normalize_filters(
@@ -433,10 +377,10 @@ class DatasetValidator:
@staticmethod
def normalize_column_names(
config: TableChartConfig | XYChartConfig,
config: _C,
dataset_id: int | str,
dataset_context: DatasetContext | None = None,
) -> TableChartConfig | XYChartConfig:
) -> _C:
"""
Normalize column names in config to match the canonical dataset column names.
@@ -445,6 +389,9 @@ class DatasetValidator:
(e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
so we need to ensure column names match exactly.
Previously only XYChartConfig and TableChartConfig were normalized; now
all 7 chart types are handled via the plugin registry.
Args:
config: Chart configuration with column references
dataset_id: Dataset ID to get canonical column names from
@@ -459,22 +406,24 @@ class DatasetValidator:
if not dataset_context:
return config
# Create a mutable copy of the config
config_dict = config.model_dump()
# Local import: plugins call DatasetValidator helpers from
# normalize_column_refs(), so a top-level import of the registry here would
# trigger plugin registration while this module is still loading and
# create a circular import.
from superset.mcp_service.chart.registry import get_registry
# Normalize based on config type
if isinstance(config, XYChartConfig):
DatasetValidator._normalize_xy_config(config_dict, dataset_context)
elif isinstance(config, TableChartConfig):
DatasetValidator._normalize_table_config(config_dict, dataset_context)
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return config
# Normalize filter columns (common to both config types)
DatasetValidator._normalize_filters(config_dict, dataset_context)
plugin = get_registry().get(chart_type)
if plugin is None:
logger.warning(
"No plugin for chart_type=%r; skipping column normalization", chart_type
)
return config
# Reconstruct the config with normalized names
if isinstance(config, XYChartConfig):
return XYChartConfig.model_validate(config_dict)
return TableChartConfig.model_validate(config_dict)
return plugin.normalize_column_refs(config, dataset_context)
@staticmethod
def _get_column_suggestions(

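The new `_get_canonical_metric_name` helper does a case-insensitive match against saved metrics only. A standalone sketch of the same lookup, taking the metric list directly instead of a `DatasetContext` (an assumption made to keep the example self-contained):

```python
def canonical_metric_name(
    metric_name: str, available_metrics: list[dict[str, str]]
) -> str:
    """Return the dataset's casing for a saved metric, or the input unchanged."""
    wanted = metric_name.lower()
    for metric in available_metrics:
        if metric["name"].lower() == wanted:
            return metric["name"]
    # No match: leave the name alone so later validation can report it.
    return metric_name
```

Searching only the metrics means a column named `total_sales` cannot override the casing of a saved metric named `Total_Sales`.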

@@ -23,10 +23,7 @@ Validates performance, compatibility, and user experience issues.
import logging
from typing import Any, Dict, List, Tuple
from superset.mcp_service.chart.schemas import (
ChartConfig,
XYChartConfig,
)
from superset.mcp_service.chart.schemas import ChartConfig
logger = logging.getLogger(__name__)
@@ -56,20 +53,10 @@ class RuntimeValidator:
warnings: List[str] = []
suggestions: List[str] = []
# Only check XY charts for format and cardinality issues
if isinstance(config, XYChartConfig):
# Format-type compatibility validation
format_warnings = RuntimeValidator._validate_format_compatibility(config)
if format_warnings:
warnings.extend(format_warnings)
# Cardinality validation
cardinality_warnings, cardinality_suggestions = (
RuntimeValidator._validate_cardinality(config, dataset_id)
)
if cardinality_warnings:
warnings.extend(cardinality_warnings)
suggestions.extend(cardinality_suggestions)
# Per-plugin runtime warnings (format, cardinality, etc.)
plugin_warnings = RuntimeValidator._validate_plugin_runtime(config, dataset_id)
if plugin_warnings:
warnings.extend(plugin_warnings)
# Chart type appropriateness validation (for all chart types)
type_warnings, type_suggestions = RuntimeValidator._validate_chart_type(
@@ -98,61 +85,28 @@ class RuntimeValidator:
return True, None
@staticmethod
def _validate_format_compatibility(config: XYChartConfig) -> List[str]:
"""Validate format-type compatibility."""
warnings: List[str] = []
def _validate_plugin_runtime(
config: ChartConfig, dataset_id: int | str
) -> List[str]:
"""Delegate per-chart-type runtime warnings to the plugin registry.
Each plugin's get_runtime_warnings() method returns chart-type-specific
warnings (e.g. format/cardinality for XY). Dispatching through the registry
replaces the previous hard-coded isinstance(config, XYChartConfig) check.
"""
try:
# Import here to avoid circular imports
from .format_validator import FormatTypeValidator
from superset.mcp_service.chart.registry import get_registry
is_valid, format_warnings = (
FormatTypeValidator.validate_format_compatibility(config)
)
if format_warnings:
warnings.extend(format_warnings)
except ImportError:
logger.warning("Format validator not available")
except Exception as e:
logger.warning("Format validation failed: %s", e)
return warnings
@staticmethod
def _validate_cardinality(
config: XYChartConfig, dataset_id: int | str
) -> Tuple[List[str], List[str]]:
"""Validate cardinality issues."""
warnings: List[str] = []
suggestions: List[str] = []
try:
# Import here to avoid circular imports
from .cardinality_validator import CardinalityValidator
# Determine chart type for cardinality thresholds
chart_type = config.kind if hasattr(config, "kind") else "default"
# Check X-axis cardinality
if config.x is None or config.x.name is None:
return warnings, suggestions
is_ok, cardinality_info = CardinalityValidator.check_cardinality(
dataset_id=dataset_id,
x_column=config.x.name,
chart_type=chart_type,
group_by_column=(config.group_by[0].name if config.group_by else None),
)
if not is_ok and cardinality_info:
warnings.extend(cardinality_info.get("warnings", []))
suggestions.extend(cardinality_info.get("suggestions", []))
except ImportError:
logger.warning("Cardinality validator not available")
except Exception as e:
logger.warning("Cardinality validation failed: %s", e)
return warnings, suggestions
chart_type = getattr(config, "chart_type", None)
if chart_type is None:
return []
plugin = get_registry().get(chart_type)
if plugin is None:
return []
return plugin.get_runtime_warnings(config, dataset_id)
except Exception as exc: # noqa: BLE001 — plugin code is third-party-extensible
logger.warning("Plugin runtime validation failed: %s", exc)
return []
@staticmethod
def _validate_chart_type(
@@ -184,7 +138,7 @@ class RuntimeValidator:
except ImportError:
logger.warning("Chart type suggester not available")
except Exception as e:
except Exception as e: # noqa: BLE001 — non-blocking warning path
logger.warning("Chart type validation failed: %s", e)
return warnings, suggestions
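The runtime validator now delegates to whichever plugin owns the config's `chart_type` and treats any plugin failure as "no warnings". A minimal sketch of that non-blocking dispatch, where `PLUGINS` and the warning functions are hypothetical placeholders for the real registry and plugin methods:

```python
import logging
from types import SimpleNamespace
from typing import Any, Callable

logger = logging.getLogger("runtime_validator_sketch")

# Hypothetical registry: chart_type -> function returning warning strings.
PLUGINS: dict[str, Callable[[Any, int], list[str]]] = {}


def plugin_runtime_warnings(config: Any, dataset_id: int) -> list[str]:
    """Return the owning plugin's warnings; never raise into the caller."""
    chart_type = getattr(config, "chart_type", None)
    plugin = PLUGINS.get(chart_type)
    if plugin is None:
        return []
    try:
        return plugin(config, dataset_id)
    except Exception as exc:  # plugin code may be third-party
        logger.warning("Plugin runtime validation failed: %s", exc)
        return []


def _xy_warnings(config: Any, dataset_id: int) -> list[str]:
    return ["x axis has high cardinality"]


def _broken(config: Any, dataset_id: int) -> list[str]:
    raise RuntimeError("plugin bug")


PLUGINS["xy"] = _xy_warnings
PLUGINS["pie"] = _broken

xy_result = plugin_runtime_warnings(SimpleNamespace(chart_type="xy"), 1)
pie_result = plugin_runtime_warnings(SimpleNamespace(chart_type="pie"), 1)
```

Because these warnings are advisory, swallowing the exception keeps a faulty plugin from blocking chart generation.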


@@ -147,19 +147,16 @@ class SchemaValidator:
chart_type: str,
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Validate chart type and dispatch to type-specific pre-validation."""
chart_type_validators = {
"xy": SchemaValidator._pre_validate_xy_config,
"table": SchemaValidator._pre_validate_table_config,
"pie": SchemaValidator._pre_validate_pie_config,
"pivot_table": SchemaValidator._pre_validate_pivot_table_config,
"mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config,
"handlebars": SchemaValidator._pre_validate_handlebars_config,
"big_number": SchemaValidator._pre_validate_big_number_config,
}
"""Validate chart type and dispatch to plugin pre-validation."""
# Local import to avoid a cycle: a top-level import of the registry here
# would pull in the plugins package before it finishes registering.
from superset.mcp_service.chart.registry import get_registry
if not isinstance(chart_type, str) or chart_type not in chart_type_validators:
valid_types = ", ".join(chart_type_validators.keys())
registry = get_registry()
# Compute once — used in both error branches below.
valid_types = ", ".join(registry.all_types())
if not isinstance(chart_type, str) or not registry.is_registered(chart_type):
return False, ChartGenerationError(
error_type="invalid_chart_type",
message=f"Invalid chart_type: '{chart_type}'",
@@ -178,376 +175,26 @@ class SchemaValidator:
error_code="INVALID_CHART_TYPE",
)
return chart_type_validators[chart_type](config)
@staticmethod
def _pre_validate_xy_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate XY chart configuration."""
# x is optional — defaults to dataset's main_dttm_col in map_xy_config
if "y" not in config:
# Single get() call: returns None when the plugin is disabled. Calling
# is_enabled() and then get() would invoke _is_plugin_enabled twice, and
# that may call an operator-supplied callable each time.
plugin = registry.get(chart_type)
if plugin is None:
return False, ChartGenerationError(
error_type="missing_xy_fields",
message="XY chart missing required field: 'y' (Y-axis metrics)",
details="XY charts require Y-axis (metrics) specifications. "
"X-axis is optional and defaults to the dataset's primary "
"datetime column when omitted.",
error_type="disabled_chart_type",
message=f"Chart type '{chart_type}' is not enabled on this instance",
details=f"Chart type '{chart_type}' is registered but has been "
f"disabled by the operator. "
f"Enabled chart types: {valid_types}",
suggestions=[
"Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] "
"for Y-axis",
"Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
"'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
f"Use one of the enabled chart types: {valid_types}",
"Contact your administrator if you believe this is an error",
],
error_code="MISSING_XY_FIELDS",
error_code="DISABLED_CHART_TYPE",
)
# Validate Y is a list
if not isinstance(config.get("y", []), list):
return False, ChartGenerationError(
error_type="invalid_y_format",
message="Y-axis must be a list of metrics",
details="The 'y' field must be an array of metric specifications",
suggestions=[
"Wrap Y-axis metric in array: 'y': [{'name': 'column', "
"'aggregate': 'SUM'}]",
"Multiple metrics supported: 'y': [metric1, metric2, ...]",
],
error_code="INVALID_Y_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_table_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate table chart configuration."""
if "columns" not in config:
return False, ChartGenerationError(
error_type="missing_columns",
message="Table chart missing required field: columns",
details="Table charts require a 'columns' array to specify which "
"columns to display",
suggestions=[
"Add 'columns' field with array of column specifications",
"Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
"'aggregate': 'SUM'}]",
"Each column can have optional 'aggregate' for metrics",
],
error_code="MISSING_COLUMNS",
)
if not isinstance(config.get("columns", []), list):
return False, ChartGenerationError(
error_type="invalid_columns_format",
message="Columns must be a list",
details="The 'columns' field must be an array of column specifications",
suggestions=[
"Ensure columns is an array: 'columns': [...]",
"Each column should be an object with 'name' field",
],
error_code="INVALID_COLUMNS_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_pie_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pie chart configuration."""
missing_fields = []
if "dimension" not in config:
missing_fields.append("'dimension' (category column for slices)")
if "metric" not in config:
missing_fields.append("'metric' (value metric for slice sizes)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pie_fields",
message=f"Pie chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Pie charts require a dimension (categories) and a metric "
"(values)",
suggestions=[
"Add 'dimension' field: {'name': 'category_column'}",
"Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
"Example: {'chart_type': 'pie', 'dimension': {'name': "
"'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
],
error_code="MISSING_PIE_FIELDS",
)
return True, None
@staticmethod
def _pre_validate_handlebars_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate handlebars chart configuration."""
if "handlebars_template" not in config:
return False, ChartGenerationError(
error_type="missing_handlebars_template",
message="Handlebars chart missing required field: handlebars_template",
details="Handlebars charts require a 'handlebars_template' string "
"containing Handlebars HTML template markup",
suggestions=[
"Add 'handlebars_template' with a Handlebars HTML template",
"Data is available as {{data}} array in the template",
"Example: '<ul>{{#each data}}<li>{{this.name}}: "
"{{this.value}}</li>{{/each}}</ul>'",
],
error_code="MISSING_HANDLEBARS_TEMPLATE",
)
template = config.get("handlebars_template")
if not isinstance(template, str) or not template.strip():
return False, ChartGenerationError(
error_type="invalid_handlebars_template",
message="Handlebars template must be a non-empty string",
details="The 'handlebars_template' field must be a non-empty string "
"containing valid Handlebars HTML template markup",
suggestions=[
"Ensure handlebars_template is a non-empty string",
"Example: '<ul>{{#each data}}<li>{{this.name}}</li>{{/each}}</ul>'",
],
error_code="INVALID_HANDLEBARS_TEMPLATE",
)
query_mode = config.get("query_mode", "aggregate")
if query_mode not in ("aggregate", "raw"):
return False, ChartGenerationError(
error_type="invalid_query_mode",
message="Invalid query_mode for handlebars chart",
details="query_mode must be either 'aggregate' or 'raw'",
suggestions=[
"Use 'aggregate' for aggregated data (default)",
"Use 'raw' for individual rows",
],
error_code="INVALID_QUERY_MODE",
)
if query_mode == "raw" and not config.get("columns"):
return False, ChartGenerationError(
error_type="missing_raw_columns",
message="Handlebars chart in 'raw' mode requires 'columns'",
details="When query_mode is 'raw', you must specify which columns "
"to include in the query results",
suggestions=[
"Add 'columns': [{'name': 'column_name'}] for raw mode",
"Or use query_mode='aggregate' with 'metrics' "
"and optional 'groupby'",
],
error_code="MISSING_RAW_COLUMNS",
)
if query_mode == "aggregate" and not config.get("metrics"):
return False, ChartGenerationError(
error_type="missing_aggregate_metrics",
message="Handlebars chart in 'aggregate' mode requires 'metrics'",
details="When query_mode is 'aggregate' (default), you must specify "
"at least one metric with an aggregate function",
suggestions=[
"Add 'metrics': [{'name': 'column', 'aggregate': 'SUM'}]",
"Or use query_mode='raw' with 'columns' for individual rows",
],
error_code="MISSING_AGGREGATE_METRICS",
)
return True, None
@staticmethod
def _pre_validate_big_number_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate big number chart configuration."""
if "metric" not in config:
return False, ChartGenerationError(
error_type="missing_metric",
message="Big Number chart missing required field: metric",
details="Big Number charts require a 'metric' field "
"specifying the value to display",
suggestions=[
"Add 'metric' with name and aggregate: "
"{'name': 'revenue', 'aggregate': 'SUM'}",
"The aggregate function is required (SUM, COUNT, AVG, MIN, MAX)",
"Example: {'chart_type': 'big_number', "
"'metric': {'name': 'sales', 'aggregate': 'SUM'}}",
],
error_code="MISSING_BIG_NUMBER_METRIC",
)
metric = config.get("metric", {})
if not isinstance(metric, dict):
return False, ChartGenerationError(
error_type="invalid_metric_type",
message="Big Number metric must be a dict with 'name' and 'aggregate'",
details="The 'metric' field must be an object, "
f"got {type(metric).__name__}",
suggestions=[
"Use a dict: {'name': 'col', 'aggregate': 'SUM'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="INVALID_BIG_NUMBER_METRIC_TYPE",
)
if (
not metric.get("aggregate")
and not metric.get("saved_metric")
and not metric.get("sql_expression")
):
return False, ChartGenerationError(
error_type="missing_metric_aggregate",
message="Big Number metric must include an aggregate function, "
"a saved metric reference, or a SQL expression",
details="The metric must have an 'aggregate' field, "
"'saved_metric': true, or 'sql_expression'",
suggestions=[
"Add 'aggregate' to your metric: "
"{'name': 'col', 'aggregate': 'SUM'}",
"Or use a saved metric: "
"{'name': 'total_sales', 'saved_metric': true}",
"Or a custom SQL metric: "
"{'sql_expression': 'SUM(a)/SUM(b)', 'label': 'Ratio'}",
"Valid aggregates: SUM, COUNT, AVG, MIN, MAX",
],
error_code="MISSING_BIG_NUMBER_AGGREGATE",
)
# ``label`` may be any JSON type here (pre-Pydantic), so test the
# string-ness explicitly before calling ``.strip()``.
label = metric.get("label")
if metric.get("sql_expression") and not (
isinstance(label, str) and label.strip()
):
return False, ChartGenerationError(
error_type="missing_sql_metric_label",
message="Big Number metric with sql_expression requires a label",
details=(
"Custom SQL metrics have no column name to derive a label "
"from, so 'label' is required for display."
),
suggestions=[
"Add a 'label': "
"{'sql_expression': 'SUM(a)/SUM(b)', 'label': 'Ratio'}",
],
error_code="MISSING_SQL_METRIC_LABEL",
)
show_trendline = config.get("show_trendline", False)
temporal_column = config.get("temporal_column")
if show_trendline and not temporal_column:
return False, ChartGenerationError(
error_type="missing_temporal_column",
message="Trendline requires a temporal column",
details="When 'show_trendline' is True, a "
"'temporal_column' must be specified",
suggestions=[
"Add 'temporal_column': 'date_column_name'",
"Or set 'show_trendline': false for number only",
"Use get_dataset_info to find temporal columns",
],
error_code="MISSING_TEMPORAL_COLUMN",
)
return True, None
@staticmethod
def _pre_validate_pivot_table_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pivot table configuration."""
missing_fields = []
if "rows" not in config:
missing_fields.append("'rows' (row grouping columns)")
if "metrics" not in config:
missing_fields.append("'metrics' (aggregation metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pivot_fields",
message=f"Pivot table missing required "
f"fields: {', '.join(missing_fields)}",
details="Pivot tables require row groupings and metrics",
suggestions=[
"Add 'rows' field: [{'name': 'category'}]",
"Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
"Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
],
error_code="MISSING_PIVOT_FIELDS",
)
if not isinstance(config.get("rows", []), list):
return False, ChartGenerationError(
error_type="invalid_rows_format",
message="Rows must be a list of columns",
details="The 'rows' field must be an array of column specifications",
suggestions=[
"Wrap row columns in array: 'rows': [{'name': 'category'}]",
],
error_code="INVALID_ROWS_FORMAT",
)
if not isinstance(config.get("metrics", []), list):
return False, ChartGenerationError(
error_type="invalid_metrics_format",
message="Metrics must be a list",
details="The 'metrics' field must be an array of metric specifications",
suggestions=[
"Wrap metrics in array: 'metrics': [{'name': 'sales', "
"'aggregate': 'SUM'}]",
],
error_code="INVALID_METRICS_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_mixed_timeseries_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate mixed timeseries configuration."""
missing_fields = []
if "x" not in config:
missing_fields.append("'x' (X-axis temporal column)")
if "y" not in config:
missing_fields.append("'y' (primary Y-axis metrics)")
if "y_secondary" not in config:
missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_mixed_timeseries_fields",
message=f"Mixed timeseries chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Mixed timeseries charts require an x-axis, primary metrics, "
"and secondary metrics",
suggestions=[
"Add 'x' field: {'name': 'date_column'}",
"Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
"Add 'y_secondary' field: [{'name': 'orders', "
"'aggregate': 'COUNT'}]",
"Optional: 'primary_kind' and 'secondary_kind' for chart types",
],
error_code="MISSING_MIXED_TIMESERIES_FIELDS",
)
for field_name in ["y", "y_secondary"]:
if not isinstance(config.get(field_name, []), list):
return False, ChartGenerationError(
error_type=f"invalid_{field_name}_format",
message=f"'{field_name}' must be a list of metrics",
details=f"The '{field_name}' field must be an array of metric "
"specifications",
suggestions=[
f"Wrap in array: '{field_name}': "
"[{'name': 'col', 'aggregate': 'SUM'}]",
],
error_code=f"INVALID_{field_name.upper()}_FORMAT",
)
if (error := plugin.pre_validate(config)) is not None:
return False, error
return True, None
@staticmethod
@@ -562,89 +209,27 @@ class SchemaValidator:
if err.get("type") == "union_tag_invalid" or "discriminator" in str(
err.get("ctx", {})
):
# This is the generic union error - provide better message
config = request_data.get("config", {})
chart_type = config.get("chart_type", "unknown")
# avoid circular import
from superset.mcp_service.chart.registry import get_registry
if chart_type == "xy":
return ChartGenerationError(
error_type="xy_validation_error",
message="XY chart configuration validation failed",
details="The XY chart configuration is missing required "
"fields or has invalid structure",
suggestions=[
"Ensure 'x' field exists with {'name': 'column_name'}",
"Ensure 'y' field is an array: [{'name': 'metric', "
"'aggregate': 'SUM'}]",
"Check that all column names are strings",
"Verify aggregate functions are valid: SUM, COUNT, AVG, "
"MIN, MAX",
],
error_code="XY_VALIDATION_ERROR",
)
elif chart_type == "table":
return ChartGenerationError(
error_type="table_validation_error",
message="Table chart configuration validation failed",
details="The table chart configuration is missing required "
"fields or has invalid structure",
suggestions=[
"Ensure 'columns' field is an array of column "
"specifications",
"Each column needs {'name': 'column_name'}",
"Optional: add 'aggregate' for metrics",
"Example: 'columns': [{'name': 'product'}, {'name': "
"'sales', 'aggregate': 'SUM'}]",
],
error_code="TABLE_VALIDATION_ERROR",
)
elif chart_type == "handlebars":
return ChartGenerationError(
error_type="handlebars_validation_error",
message="Handlebars chart configuration validation failed",
details="The handlebars chart configuration is missing "
"required fields or has invalid structure",
suggestions=[
"Ensure 'handlebars_template' is a non-empty string",
"For aggregate mode: add 'metrics' with aggregate "
"functions",
"For raw mode: set 'query_mode': 'raw' and add 'columns'",
"Example: {'chart_type': 'handlebars', "
"'handlebars_template': '<ul>{{#each data}}<li>"
"{{this.name}}</li>{{/each}}</ul>', "
"'metrics': [{'name': 'sales', 'aggregate': 'SUM'}]}",
],
error_code="HANDLEBARS_VALIDATION_ERROR",
)
elif chart_type == "big_number":
return ChartGenerationError(
error_type="big_number_validation_error",
message="Big Number chart configuration validation failed",
details="The Big Number chart configuration is "
"missing required fields or has invalid "
"structure",
suggestions=[
"Ensure 'metric' field has 'name' and 'aggregate'",
"Example: 'metric': {'name': 'revenue', "
"'aggregate': 'SUM'}",
"For trendline: add 'show_trendline': true "
"and 'temporal_column': 'date_col'",
"Without trendline: just provide the metric",
],
error_code="BIG_NUMBER_VALIDATION_ERROR",
)
chart_type = request_data.get("config", {}).get("chart_type", "")
plugin = get_registry().get(chart_type)
if plugin is not None:
hint = plugin.schema_error_hint()
if hint is not None:
return hint
# Default enhanced error
error_details = []
for err in errors[:3]: # Show first 3 errors
loc = " -> ".join(str(location) for location in err.get("loc", []))
msg = err.get("msg", "Validation failed")
error_details.append(f"{loc}: {msg}")
error_details.append(f"{loc}: {msg}" if loc else msg)
return ChartGenerationError(
error_type="validation_error",
message="Chart configuration validation failed",
details="; ".join(error_details),
details="; ".join(error_details) or "Invalid chart configuration structure",
suggestions=[
"Check that all required fields are present",
"Ensure field types match the schema",

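The hunks above replace the per-chart-type `if/elif` chains with a registry lookup followed by `plugin.pre_validate(config)` or `plugin.schema_error_hint()`. The following is a minimal, self-contained sketch of that dispatch pattern, not the PR's implementation. `SketchError`, `SketchPlugin`, the two example plugins, the `_PLUGINS` dict, and the `VALIDATION_ERROR` code are hypothetical stand-ins for `ChartGenerationError`, `BaseChartPlugin`, and `get_registry()`. How the real validator treats an unknown chart type is not shown in the diff, so the pass-through here is an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class SketchError:
    """Hypothetical stand-in for ChartGenerationError."""

    error_code: str
    details: str = ""


class SketchPlugin:
    """Hypothetical base plugin: both hooks default to 'nothing to report'."""

    def pre_validate(self, config: dict[str, Any]) -> SketchError | None:
        return None

    def schema_error_hint(self) -> SketchError | None:
        return None


class PivotTableSketch(SketchPlugin):
    def pre_validate(self, config: dict[str, Any]) -> SketchError | None:
        missing = [name for name in ("rows", "metrics") if name not in config]
        if missing:
            return SketchError("MISSING_PIVOT_FIELDS", ", ".join(missing))
        return None


class BigNumberSketch(SketchPlugin):
    def schema_error_hint(self) -> SketchError | None:
        return SketchError("BIG_NUMBER_VALIDATION_ERROR")


# Assumption: a plain dict keyed by chart_type stands in for get_registry().
_PLUGINS: dict[str, SketchPlugin] = {
    "pivot_table": PivotTableSketch(),
    "big_number": BigNumberSketch(),
}


def pre_validate(
    chart_type: str, config: dict[str, Any]
) -> tuple[bool, SketchError | None]:
    plugin = _PLUGINS.get(chart_type)
    # Assumption: unknown chart types are reported elsewhere, so they pass here.
    if plugin is not None and (error := plugin.pre_validate(config)) is not None:
        return False, error
    return True, None


def enhance_schema_error(chart_type: str, errors: list[dict[str, Any]]) -> SketchError:
    plugin = _PLUGINS.get(chart_type)
    if plugin is not None and (hint := plugin.schema_error_hint()) is not None:
        return hint
    # Default path: first three errors; drop the "loc: " prefix when loc is empty.
    parts = []
    for err in errors[:3]:
        loc = " -> ".join(str(item) for item in err.get("loc", []))
        msg = err.get("msg", "Validation failed")
        parts.append(f"{loc}: {msg}" if loc else msg)
    return SketchError(
        "VALIDATION_ERROR",
        "; ".join(parts) or "Invalid chart configuration structure",
    )
```

With this shape, adding a chart type means registering one more plugin object; neither dispatcher function needs to change.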
@@ -81,6 +81,17 @@ try:
mcp_config = get_mcp_config(_mcp_app.config)
_mcp_app.config.update(mcp_config)
# Re-configure chart registry so MCP-specific overrides (e.g.
# MCP_DISABLED_CHART_PLUGINS set by the operator) take effect after
# the MCP config overlay. SupersetAppInitializer.configure_mcp_chart_registry()
# ran earlier with pre-overlay values; this call corrects them.
from superset.mcp_service.chart import registry as _chart_registry
_chart_registry.configure(
disabled=_mcp_app.config.get("MCP_DISABLED_CHART_PLUGINS"),
enabled_func=_mcp_app.config.get("MCP_CHART_PLUGIN_ENABLED_FUNC"),
)
with _mcp_app.app_context():
from superset.core.mcp.core_mcp_injection import (
initialize_core_mcp_dependencies,

@@ -18,6 +18,7 @@
import logging
import secrets
from collections.abc import Callable
from typing import Any, Dict, Optional, Sequence
from authlib.jose.errors import JoseError
@@ -73,6 +74,46 @@ MCP_RBAC_ENABLED = True
# MCP_DISABLED_TOOLS = {"extensions.myorg.myext.some_tool"}
MCP_DISABLED_TOOLS: set[str] = set()
# =============================================================================
# MCP Chart Plugin Filtering
# =============================================================================
#
# Overview:
# ---------
# These two settings let operators enable/disable individual chart type plugins
# at runtime without a code deploy.
#
# Use cases:
# - Emergency kill switch: add "handlebars" to MCP_DISABLED_CHART_PLUGINS and
# restart to immediately hide it from all callers.
# - Dynamic per-request control (A/B test, gradual rollout): set
# MCP_CHART_PLUGIN_ENABLED_FUNC to an in-process predicate that can vary
# by user, request header, or any other context available at call time.
#
# Priority:
# MCP_CHART_PLUGIN_ENABLED_FUNC takes precedence over MCP_DISABLED_CHART_PLUGINS.
# When the callable is set, the deny-list is ignored entirely.
#
# MCP_CHART_PLUGIN_ENABLED_FUNC contract:
# - Called as enabled_func(chart_type: str) -> bool for every registry lookup.
# - Must be cheap and in-process: consult already-loaded feature flags or
# request-local context (e.g. Flask g). Do NOT perform network I/O per call.
# - On exception, the registry fails CLOSED (plugin hidden) and logs a warning.
# - Example (Harness / Split via pre-fetched flags in g):
# from flask import g
# def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
# flags = getattr(g, "feature_flags", {})
# return flags.get(f"mcp_chart_{chart_type}", True)
# =============================================================================
# Chart types in this set are hidden from all registry lookups.
# Use frozenset to avoid accidental mutation.
MCP_DISABLED_CHART_PLUGINS: frozenset[str] = frozenset()
# Dynamic per-call predicate. When set, overrides MCP_DISABLED_CHART_PLUGINS.
# Signature: (chart_type: str) -> bool
MCP_CHART_PLUGIN_ENABLED_FUNC: Callable[[str], bool] | None = None
# MCP JWT Debug Errors - controls server-side JWT debug logging.
# When False (default), uses the default JWTVerifier with minimal logging.
# When True, uses DetailedJWTVerifier with tiered logging:
@@ -186,7 +227,7 @@ MCP_FACTORY_CONFIG = {
#
# For multi-pod/Kubernetes deployments, setting CACHE_REDIS_URL automatically
# enables Redis-backed EventStore to share session state across pods.
MCP_STORE_CONFIG: Dict[str, Any] = {
MCP_STORE_CONFIG: dict[str, Any] = {
"enabled": False, # Disabled by default - caching uses in-memory store
"CACHE_REDIS_URL": None, # Redis URL, e.g., "redis://localhost:6379/0"
# Wrapper class that prefixes all keys. Each consumer provides their own prefix.
@@ -199,7 +240,7 @@ MCP_STORE_CONFIG: Dict[str, Any] = {
# MCP Response Caching Configuration - controls caching behavior and TTLs
# When enabled without MCP_STORE_CONFIG, uses in-memory store.
# When enabled with MCP_STORE_CONFIG, uses Redis store.
MCP_CACHE_CONFIG: Dict[str, Any] = {
MCP_CACHE_CONFIG: dict[str, Any] = {
"enabled": False, # Disabled by default
"CACHE_KEY_PREFIX": None, # Only needed when using the store
"list_tools_ttl": 60 * 5, # 5 minutes
@@ -249,7 +290,7 @@ MCP_CACHE_CONFIG: Dict[str, Any] = {
# Uses character-based heuristic (~3.5 chars per token for JSON).
# This is intentionally conservative to avoid underestimating.
# =============================================================================
MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
MCP_RESPONSE_SIZE_CONFIG: dict[str, Any] = {
"enabled": True, # Enabled by default to protect LLM clients
"token_limit": DEFAULT_TOKEN_LIMIT,
"warn_threshold_pct": DEFAULT_WARN_THRESHOLD_PCT,
@@ -304,7 +345,7 @@ MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
# - compact_schemas is ignored when include_schemas=False (no schema to
# compact); max_description_length still applies in summary mode.
# =============================================================================
MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
MCP_TOOL_SEARCH_CONFIG: dict[str, Any] = {
"enabled": True, # Enabled by default — reduces initial context by ~70%
"strategy": "bm25", # "bm25" (natural language) or "regex" (pattern matching)
"max_results": 5, # Max tools returned per search
@@ -498,7 +539,7 @@ def generate_secret_key() -> str:
return secrets.token_urlsafe(42)
def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
def get_mcp_config(app_config: dict[str, Any] | None = None) -> dict[str, Any]:
"""
Get complete MCP configuration dictionary.
@@ -519,6 +560,8 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
"MCP_DEBUG": MCP_DEBUG,
"MCP_RBAC_ENABLED": MCP_RBAC_ENABLED,
"MCP_DISABLED_TOOLS": set(MCP_DISABLED_TOOLS),
"MCP_DISABLED_CHART_PLUGINS": MCP_DISABLED_CHART_PLUGINS,
"MCP_CHART_PLUGIN_ENABLED_FUNC": MCP_CHART_PLUGIN_ENABLED_FUNC,
**MCP_SESSION_CONFIG,
**MCP_CSRF_CONFIG,
}
@@ -528,8 +571,8 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
def get_mcp_config_with_overrides(
app_config: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
app_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Alternative approach: Allow any app_config keys, not just predefined ones.
@@ -543,7 +586,7 @@ def get_mcp_config_with_overrides(
return {**defaults, **app_config}
def get_mcp_factory_config() -> Dict[str, Any]:
def get_mcp_factory_config() -> dict[str, Any]:
"""
Get FastMCP factory configuration.
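
The comment block added to this file states that `MCP_CHART_PLUGIN_ENABLED_FUNC` takes precedence over `MCP_DISABLED_CHART_PLUGINS`, and that the deny-list is ignored entirely once the callable is set. One consequence is that an operator who wants both a static kill switch and a dynamic rollout presumably has to consult the deny-list inside the predicate. A minimal sketch of such a `superset_config.py` fragment follows. `_ROLLOUT_FLAGS` and its keys are hypothetical placeholders for whatever pre-fetched, in-process flag store a deployment uses.

```python
# Sketch of a superset_config.py fragment. Only the two MCP_* names come from
# the diff; everything else is illustrative.

# Static kill switch: a frozenset, matching the default's type.
MCP_DISABLED_CHART_PLUGINS = frozenset({"handlebars"})

# Hypothetical flags, assumed to be loaded ahead of time (no network I/O per call).
_ROLLOUT_FLAGS = {"mcp_chart_pivot_table": False}


def MCP_CHART_PLUGIN_ENABLED_FUNC(chart_type: str) -> bool:
    # The callable overrides the deny-list, so re-check the deny-list here
    # if the kill switch should keep working alongside the rollout flags.
    if chart_type in MCP_DISABLED_CHART_PLUGINS:
        return False
    # Unknown chart types default to enabled, as in the example in the comments.
    return _ROLLOUT_FLAGS.get(f"mcp_chart_{chart_type}", True)
```

The predicate stays cheap and in-process, which is what the contract in the comments asks for.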

@@ -17,11 +17,10 @@
import pytest
from pytest_mock import MockerFixture
from superset.commands.chart.exceptions import ChartForbiddenError, ChartInvalidError
from superset.commands.chart.exceptions import ChartForbiddenError
from superset.commands.chart.update import UpdateChartCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.utils import json
def _ownership_exc() -> SupersetSecurityException:
@@ -92,73 +91,3 @@ def test_update_chart_owner_can_perform_regular_update(
find_by_id.assert_called_once_with(1)
raise_for_ownership.assert_called_once()
def _query_context_payload(datasource: object) -> dict[str, object]:
return {
"query_context": json.dumps({"datasource": datasource, "queries": []}),
"query_context_generation": True,
}
def test_update_chart_query_context_matching_datasource_is_allowed(
mocker: MockerFixture,
) -> None:
"""A query context that targets the chart's own datasource is accepted."""
find_by_id = mocker.patch("superset.commands.chart.update.ChartDAO.find_by_id")
find_by_id.return_value = mocker.MagicMock(
id=1, tags=[], dashboards=[], datasource_id=42, datasource_type="table"
)
mocker.patch("superset.commands.chart.update.security_manager.raise_for_ownership")
UpdateChartCommand(
1, _query_context_payload({"id": 42, "type": "table"})
).validate()
@pytest.mark.parametrize(
"datasource",
[
{"id": 99, "type": "table"}, # different id
{"id": 42, "type": "query"}, # different type
{"id": "99", "type": "table"}, # different id as string
],
)
def test_update_chart_query_context_mismatched_datasource_is_rejected(
mocker: MockerFixture,
datasource: dict[str, object],
) -> None:
"""A query context pointing at a different datasource is rejected with a 4xx."""
find_by_id = mocker.patch("superset.commands.chart.update.ChartDAO.find_by_id")
find_by_id.return_value = mocker.MagicMock(
id=1, tags=[], dashboards=[], datasource_id=42, datasource_type="table"
)
mocker.patch("superset.commands.chart.update.security_manager.raise_for_ownership")
with pytest.raises(ChartInvalidError):
UpdateChartCommand(1, _query_context_payload(datasource)).validate()
@pytest.mark.parametrize(
"query_context",
[
"{}", # no datasource key
'{"datasource": null}', # null datasource
"not-json", # unparseable payload
],
)
def test_update_chart_query_context_without_datasource_is_allowed(
mocker: MockerFixture,
query_context: str,
) -> None:
"""Payloads with no verifiable datasource fall back to the chart's own."""
find_by_id = mocker.patch("superset.commands.chart.update.ChartDAO.find_by_id")
find_by_id.return_value = mocker.MagicMock(
id=1, tags=[], dashboards=[], datasource_id=42, datasource_type="table"
)
mocker.patch("superset.commands.chart.update.security_manager.raise_for_ownership")
UpdateChartCommand(
1,
{"query_context": query_context, "query_context_generation": True},
).validate()

@@ -93,30 +93,36 @@ class TestBigNumberChartConfig:
"chart_type": "big_number",
"metric": {"name": "total_sales", "saved_metric": True},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
is_valid, error = SchemaValidator._pre_validate_chart_type("big_number", data)
assert is_valid is True
assert error is None
def test_sql_expression_with_label_passes_pre_validation(self) -> None:
"""A custom SQL metric is a valid third option alongside aggregate and
saved_metric in Tier-1 validation."""
from superset.mcp_service.chart.registry import get_registry
data = {
"chart_type": "big_number",
"metric": {"sql_expression": "SUM(a)/SUM(b)", "label": "Ratio"},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
assert is_valid is True
plugin = get_registry().get("big_number")
assert plugin is not None
error = plugin.pre_validate(data)
assert error is None
def test_sql_expression_without_label_fails_pre_validation(self) -> None:
"""Tier-1 surfaces the label-required error with an LLM-actionable
suggestion before the request reaches Pydantic's stricter error."""
from superset.mcp_service.chart.registry import get_registry
data = {
"chart_type": "big_number",
"metric": {"sql_expression": "SUM(a)/SUM(b)"},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
assert is_valid is False
plugin = get_registry().get("big_number")
assert plugin is not None
error = plugin.pre_validate(data)
assert error is not None
assert error.error_code == "MISSING_SQL_METRIC_LABEL"
@@ -124,12 +130,15 @@ class TestBigNumberChartConfig:
"""Pre-validation runs on raw dict input before Pydantic coercion, so
a non-string ``label`` (e.g. an int from a buggy client) must surface
as a validation error, not an AttributeError from ``.strip()``."""
from superset.mcp_service.chart.registry import get_registry
data = {
"chart_type": "big_number",
"metric": {"sql_expression": "SUM(a)/SUM(b)", "label": 123},
}
is_valid, error = SchemaValidator._pre_validate_big_number_config(data)
assert is_valid is False
plugin = get_registry().get("big_number")
assert plugin is not None
error = plugin.pre_validate(data)
assert error is not None
assert error.error_code == "MISSING_SQL_METRIC_LABEL"
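
The last test above covers a non-string `label` arriving before Pydantic coercion. The guard it exercises can be sketched in isolation as follows. `has_usable_label` is a hypothetical helper name; the check itself mirrors the `isinstance(label, str) and label.strip()` condition in the pre-validation code.

```python
from typing import Any


def has_usable_label(metric: dict[str, Any]) -> bool:
    """Return True only for a non-blank string label.

    The isinstance check runs first, so a non-string value (for example an
    int from a buggy client) never reaches .strip() and cannot raise
    AttributeError.
    """
    label = metric.get("label")
    return isinstance(label, str) and bool(label.strip())
```

A missing label, a blank string, and a non-string value all fall into the same "label required" branch, which is why the three failing cases share the `MISSING_SQL_METRIC_LABEL` error code.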

@@ -23,7 +23,9 @@ import pytest
from pydantic import ValidationError
from superset.mcp_service.chart.schemas import (
BigNumberChartConfig,
ColumnRef,
FilterConfig,
GenerateChartRequest,
GenerateChartResponse,
MixedTimeseriesChartConfig,
@@ -49,6 +51,44 @@ class TestGenerateChartResponse:
assert response.chart_type_label == "table chart"
class TestColumnNameSanitization:
"""Test relaxed column names retain SQL-injection protection."""
def test_column_ref_rejects_sql_injection(self) -> None:
"""ColumnRef rejects SQL injection patterns."""
with pytest.raises(ValidationError, match="potentially unsafe"):
ColumnRef(name="revenue; DROP TABLE users")
def test_filter_column_rejects_sql_injection(self) -> None:
"""FilterConfig.column rejects SQL injection patterns."""
with pytest.raises(ValidationError, match="potentially unsafe"):
FilterConfig(column="status; DROP TABLE users", op="=", value="active")
def test_temporal_column_rejects_sql_injection(self) -> None:
"""BigNumberChartConfig.temporal_column rejects SQL injection patterns."""
with pytest.raises(ValidationError, match="potentially unsafe"):
BigNumberChartConfig(
chart_type="big_number",
metric={"name": "revenue", "aggregate": "SUM"},
show_trendline=True,
temporal_column="created_at; DROP TABLE users",
)
def test_relaxed_column_names_still_pass(self) -> None:
"""Digit-prefixed, dotted, and hyphenated column names are accepted."""
assert ColumnRef(name="1Q_revenue").name == "1Q_revenue"
assert FilterConfig(column="order-date", op="=", value="active").column == (
"order-date"
)
config = BigNumberChartConfig(
chart_type="big_number",
metric={"name": "revenue", "aggregate": "SUM"},
show_trendline=True,
temporal_column="events.created-at",
)
assert config.temporal_column == "events.created-at"
class TestTableChartConfig:
"""Test TableChartConfig validation."""

@@ -0,0 +1,252 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the chart type plugin registry."""
import sys
import threading
from types import ModuleType
import pytest
import superset.mcp_service.chart.registry as registry_module
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import (
_RegistryProxy,
_reset_for_testing,
all_types,
display_name_for_viz_type,
get,
get_registry,
is_registered,
register,
)
@pytest.fixture(autouse=True)
def _isolated_registry(monkeypatch):
"""Run each test against a clean registry without touching the real one."""
monkeypatch.setattr(registry_module, "_REGISTRY", {})
monkeypatch.setattr(registry_module, "_plugins_loaded", True)
monkeypatch.setattr(registry_module, "_plugins_load_failed", False)
class _FakePlugin(BaseChartPlugin):
chart_type = "fake"
display_name = "Fake Chart"
native_viz_types = {"fake_viz": "Fake Viz"}
class _AnotherPlugin(BaseChartPlugin):
chart_type = "another"
display_name = "Another Chart"
native_viz_types = {"another_viz": "Another Viz"}
def test_register_adds_plugin():
plugin = _FakePlugin()
register(plugin)
assert get("fake") is plugin
def test_get_returns_none_for_unknown():
assert get("nonexistent") is None
def test_all_types_returns_registered_keys():
register(_FakePlugin())
register(_AnotherPlugin())
types = all_types()
assert "fake" in types
assert "another" in types
def test_all_types_insertion_order():
register(_FakePlugin())
register(_AnotherPlugin())
types = all_types()
assert types.index("fake") < types.index("another")
def test_is_registered_true_for_known():
register(_FakePlugin())
assert is_registered("fake") is True
def test_is_registered_false_for_unknown():
assert is_registered("nonexistent") is False
def test_register_warns_on_duplicate(caplog):
register(_FakePlugin())
with caplog.at_level("WARNING"):
register(_FakePlugin())
assert "Overwriting" in caplog.text
def test_register_warns_on_viz_type_collision(caplog):
register(_FakePlugin())
class _CollidingPlugin(BaseChartPlugin):
chart_type = "colliding"
display_name = "Colliding Chart"
native_viz_types = {"fake_viz": "Shadowed Viz", "own_viz": "Own Viz"}
with caplog.at_level("WARNING"):
register(_CollidingPlugin())
assert "already claimed by" in caplog.text
assert "fake_viz" in caplog.text
# Earlier registration wins in display-name lookups
assert display_name_for_viz_type("fake_viz") == "Fake Viz"
def test_register_same_plugin_no_collision_warning(caplog):
register(_FakePlugin())
with caplog.at_level("WARNING"):
register(_FakePlugin())
assert "already claimed by" not in caplog.text
def test_register_raises_for_empty_chart_type():
class _BadPlugin(BaseChartPlugin):
chart_type = ""
with pytest.raises(ValueError, match="non-empty chart_type"):
register(_BadPlugin())
def test_display_name_for_viz_type_found():
register(_FakePlugin())
assert display_name_for_viz_type("fake_viz") == "Fake Viz"
def test_display_name_for_viz_type_not_found():
register(_FakePlugin())
assert display_name_for_viz_type("unknown_viz") is None
def test_display_name_searches_all_plugins():
register(_FakePlugin())
register(_AnotherPlugin())
assert display_name_for_viz_type("another_viz") == "Another Viz"
def test_get_registry_returns_proxy():
assert isinstance(get_registry(), _RegistryProxy)
def test_plugins_lock_allows_register_during_lazy_import():
"""The registry lock is re-entrant for plugin registration during import."""
assert isinstance(registry_module._plugins_lock, type(threading.RLock()))
def test_registry_proxy_get():
plugin = _FakePlugin()
register(plugin)
assert get_registry().get("fake") is plugin
def test_registry_proxy_all_types():
register(_FakePlugin())
assert "fake" in get_registry().all_types()
def test_registry_proxy_is_registered():
register(_FakePlugin())
assert get_registry().is_registered("fake") is True
assert get_registry().is_registered("missing") is False
def test_registry_proxy_display_name_for_viz_type():
register(_FakePlugin())
assert get_registry().display_name_for_viz_type("fake_viz") == "Fake Viz"
assert get_registry().display_name_for_viz_type("unknown") is None
def test_ensure_plugins_loaded_skips_when_load_failed(monkeypatch):
"""_ensure_plugins_loaded returns immediately when _plugins_load_failed is set."""
from superset.mcp_service.chart.registry import _ensure_plugins_loaded
monkeypatch.setattr(registry_module, "_plugins_loaded", False)
monkeypatch.setattr(registry_module, "_plugins_load_failed", True)
# If the function tried to import, the real plugins module would load and flip
# _plugins_loaded to True. The circuit breaker should prevent that.
_ensure_plugins_loaded()
assert registry_module._plugins_loaded is False
def test_ensure_plugins_loaded_sets_failed_flag_on_error(monkeypatch):
"""A failed import sets _plugins_load_failed so subsequent calls are no-ops."""
from unittest.mock import patch
from superset.mcp_service.chart.registry import _ensure_plugins_loaded
monkeypatch.setattr(registry_module, "_plugins_loaded", False)
monkeypatch.setattr(registry_module, "_plugins_load_failed", False)
monkeypatch.setattr(registry_module, "_plugins_lock", threading.Lock())
# Setting the module to None in sys.modules causes ImportError on import.
with patch.dict("sys.modules", {"superset.mcp_service.chart.plugins": None}):
_ensure_plugins_loaded()
assert registry_module._plugins_load_failed is True
assert registry_module._plugins_loaded is False
def test_ensure_plugins_loaded_rolls_back_partial_registration(monkeypatch):
"""A failed lazy import restores the registry to its previous state."""
from superset.mcp_service.chart.registry import _ensure_plugins_loaded
original_import = __import__
existing_plugin = _FakePlugin()
partial_plugin = _AnotherPlugin()
monkeypatch.setattr(registry_module, "_REGISTRY", {"fake": existing_plugin})
monkeypatch.setattr(registry_module, "_plugins_loaded", False)
monkeypatch.setattr(registry_module, "_plugins_load_failed", False)
def fail_plugin_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "superset.mcp_service.chart.plugins":
register(partial_plugin)
raise RuntimeError("plugin import failed")
return original_import(name, globals, locals, fromlist, level)
monkeypatch.setattr("builtins.__import__", fail_plugin_import)
_ensure_plugins_loaded()
assert registry_module._plugins_load_failed is True
assert registry_module._REGISTRY == {"fake": existing_plugin}
def test_reset_for_testing_clears_cached_plugins_package(monkeypatch):
"""Reset removes the plugins package so lazy loading can re-run registration."""
module_name = "superset.mcp_service.chart.plugins"
monkeypatch.setitem(sys.modules, module_name, ModuleType(module_name))
monkeypatch.setattr(registry_module, "_REGISTRY", {"fake": _FakePlugin()})
monkeypatch.setattr(registry_module, "_plugins_loaded", True)
monkeypatch.setattr(registry_module, "_plugins_load_failed", True)
_reset_for_testing()
assert registry_module._REGISTRY == {}
assert registry_module._plugins_loaded is False
assert registry_module._plugins_load_failed is False
assert module_name not in sys.modules
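
The lazy-loading tests above check three behaviours: a failed load sets a flag so later calls become no-ops, a partial registration is rolled back, and the lock is re-entrant so plugins can call `register()` while loading. A minimal sketch of how those pieces could fit together is shown below. `LazyRegistrySketch` and its injected `loader` callable are hypothetical; the real module imports `superset.mcp_service.chart.plugins` rather than taking a loader argument.

```python
from __future__ import annotations

import threading
from typing import Callable


class LazyRegistrySketch:
    def __init__(self, loader: Callable[[LazyRegistrySketch], None]) -> None:
        self._loader = loader
        self._plugins: dict[str, object] = {}
        self._loaded = False
        self._load_failed = False
        # Re-entrant: the loader calls register() while the lock is held.
        self._lock = threading.RLock()

    def register(self, chart_type: str, plugin: object) -> None:
        with self._lock:
            self._plugins[chart_type] = plugin

    def _ensure_loaded(self) -> None:
        # Circuit breaker: after one failure, never retry the load.
        if self._loaded or self._load_failed:
            return
        with self._lock:
            if self._loaded or self._load_failed:
                return
            snapshot = dict(self._plugins)
            try:
                self._loader(self)
            except Exception:
                # Roll back anything the loader registered before failing.
                self._plugins = snapshot
                self._load_failed = True
            else:
                self._loaded = True

    def get(self, chart_type: str) -> object | None:
        self._ensure_loaded()
        return self._plugins.get(chart_type)
```

Taking a snapshot before the load keeps the rollback simple: the registry is either fully loaded or exactly as it was.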

@@ -0,0 +1,222 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for registry plugin filtering (configure / is_enabled / get / all_types)."""
import pytest
import superset.mcp_service.chart.registry as registry_module
from superset.mcp_service.chart.plugin import BaseChartPlugin
from superset.mcp_service.chart.registry import (
_PluginFilterConfig,
all_types,
configure,
display_name_for_viz_type,
get,
is_enabled,
is_registered,
register,
)
@pytest.fixture(autouse=True)
def _isolated_registry(monkeypatch):
"""Isolated registry with two known plugins and a clean filter for each test."""
monkeypatch.setattr(registry_module, "_REGISTRY", {})
monkeypatch.setattr(registry_module, "_plugins_loaded", True)
monkeypatch.setattr(registry_module, "_filter_config", _PluginFilterConfig())
register(_AlphaPlugin())
register(_BetaPlugin())
class _AlphaPlugin(BaseChartPlugin):
chart_type = "alpha"
display_name = "Alpha Chart"
native_viz_types = {"alpha_viz": "Alpha Viz"}
class _BetaPlugin(BaseChartPlugin):
chart_type = "beta"
display_name = "Beta Chart"
native_viz_types = {"beta_viz": "Beta Viz"}
# ---------------------------------------------------------------------------
# Static deny-list tests
# ---------------------------------------------------------------------------
def test_get_returns_plugin_when_enabled():
assert get("alpha") is not None
assert get("beta") is not None
def test_get_returns_none_for_disabled_plugin():
configure(disabled={"alpha"})
assert get("alpha") is None
def test_get_still_returns_other_plugins_when_one_is_disabled():
configure(disabled={"alpha"})
assert get("beta") is not None
def test_all_types_excludes_disabled():
configure(disabled={"alpha"})
types = all_types()
assert "alpha" not in types
assert "beta" in types
def test_all_types_empty_when_all_disabled():
configure(disabled={"alpha", "beta"})
assert all_types() == []
def test_is_registered_ignores_deny_list():
configure(disabled={"alpha"})
assert is_registered("alpha") is True
def test_is_enabled_returns_false_for_disabled():
configure(disabled={"alpha"})
assert is_enabled("alpha") is False
def test_is_enabled_returns_true_when_not_disabled():
configure(disabled={"alpha"})
assert is_enabled("beta") is True
def test_is_enabled_returns_false_for_unknown():
assert is_enabled("nonexistent") is False
# ---------------------------------------------------------------------------
# configure() accepts different iterable shapes
# ---------------------------------------------------------------------------
def test_configure_accepts_list():
configure(disabled=["alpha"])
assert get("alpha") is None
def test_configure_accepts_tuple():
configure(disabled=("alpha",))
assert get("alpha") is None
def test_configure_accepts_frozenset():
configure(disabled=frozenset({"alpha"}))
assert get("alpha") is None
def test_configure_accepts_none_disabled():
configure(disabled=None)
assert get("alpha") is not None
def test_configure_rejects_noncallable_enabled_func():
with pytest.raises(TypeError):
configure(enabled_func="not_a_callable")
# ---------------------------------------------------------------------------
# Dynamic callable hook tests
# ---------------------------------------------------------------------------
def test_enabled_func_overrides_deny_list():
# alpha is in deny-list but callable says True → should be visible
configure(disabled={"alpha"}, enabled_func=lambda ct: ct == "alpha")
assert get("alpha") is not None
def test_enabled_func_can_disable_plugin():
configure(enabled_func=lambda ct: ct != "beta")
assert get("beta") is None
assert get("alpha") is not None
def test_enabled_func_called_per_lookup():
calls = []
def hook(ct: str) -> bool:
calls.append(ct)
return True
configure(enabled_func=hook)
get("alpha")
get("alpha")
assert calls.count("alpha") == 2
def test_enabled_func_exception_fails_closed(caplog):
import logging
def bad_hook(ct: str) -> bool:
raise RuntimeError("Harness down")
configure(enabled_func=bad_hook)
with caplog.at_level(logging.WARNING, logger="superset.mcp_service.chart.registry"):
result = get("alpha")
assert result is None # fail closed
assert "failing closed" in caplog.text.lower() or "alpha" in caplog.text
def test_enabled_func_all_types_respects_hook():
configure(enabled_func=lambda ct: ct == "alpha")
assert all_types() == ["alpha"]
# ---------------------------------------------------------------------------
# display_name_for_viz_type is NOT filtered
# ---------------------------------------------------------------------------
def test_display_name_unaffected_by_deny_list():
configure(disabled={"alpha"})
# Even though alpha is disabled, its viz_type should still resolve
assert display_name_for_viz_type("alpha_viz") == "Alpha Viz"
def test_display_name_unaffected_by_callable():
configure(enabled_func=lambda ct: False)
assert display_name_for_viz_type("beta_viz") == "Beta Viz"
# ---------------------------------------------------------------------------
# configure() atomicity: replacing config is visible to next lookup
# ---------------------------------------------------------------------------
def test_reconfigure_replaces_previous_filter():
configure(disabled={"alpha"})
assert get("alpha") is None
configure(disabled=set())
assert get("alpha") is not None
def test_reconfigure_with_func_then_none_falls_back_to_deny_list():
configure(enabled_func=lambda ct: False)
assert get("alpha") is None
configure(disabled={"beta"}, enabled_func=None)
assert get("alpha") is not None
assert get("beta") is None
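
Review note: the registry tests above pin down a small contract. A static deny-list hides plugins, an optional `enabled_func` takes precedence over the deny-list when set, a hook that raises is treated as "disabled" (fail closed), and `is_registered` ignores both filters. The following is a minimal sketch of that contract under stated assumptions. `SketchRegistry` and its logger name are hypothetical and are not the code in `superset.mcp_service.chart.registry`; only the behaviours asserted by the tests are modelled.

```python
import logging
from typing import Callable, Iterable, Optional

logger = logging.getLogger("sketch.chart_registry")  # hypothetical logger name


class SketchRegistry:
    """Toy registry modelling the tested contract; not the Superset code."""

    def __init__(self) -> None:
        self._plugins: dict[str, object] = {}
        self._disabled: frozenset[str] = frozenset()
        self._enabled_func: Optional[Callable[[str], bool]] = None

    def register(self, chart_type: str, plugin: object) -> None:
        self._plugins[chart_type] = plugin

    def configure(
        self,
        disabled: Optional[Iterable[str]] = None,
        enabled_func: Optional[Callable[[str], bool]] = None,
    ) -> None:
        # Reject bad hooks up front instead of failing on the first lookup.
        if enabled_func is not None and not callable(enabled_func):
            raise TypeError("enabled_func must be callable or None")
        # Each call replaces the previous filter wholesale.
        self._disabled = frozenset(disabled or ())
        self._enabled_func = enabled_func

    def is_registered(self, chart_type: str) -> bool:
        # Deliberately ignores the deny-list and the hook.
        return chart_type in self._plugins

    def is_enabled(self, chart_type: str) -> bool:
        if chart_type not in self._plugins:
            return False
        if self._enabled_func is not None:
            try:
                # The hook wins over the deny-list and runs on every lookup.
                return bool(self._enabled_func(chart_type))
            except Exception:
                logger.warning(
                    "enabled_func raised for %r; failing closed",
                    chart_type,
                    exc_info=True,
                )
                return False
        return chart_type not in self._disabled

    def get(self, chart_type: str) -> Optional[object]:
        return self._plugins[chart_type] if self.is_enabled(chart_type) else None

    def all_types(self) -> list[str]:
        return [ct for ct in self._plugins if self.is_enabled(ct)]
```

Evaluating the hook on every lookup, rather than caching its answer, is what lets `test_enabled_func_called_per_lookup` observe two calls for two `get("alpha")` lookups.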

@@ -1175,6 +1175,11 @@ class TestUpdateChartValidationGate:
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@pytest.mark.asyncio
async def test_preview_path_validation_failure_skips_cache(
self,
@@ -1238,6 +1243,11 @@ class TestUpdateChartValidationGate:
)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@pytest.mark.asyncio
async def test_persist_path_validation_failure_skips_db_write(
self,
@@ -1288,6 +1298,173 @@ class TestUpdateChartValidationGate:
mock_update_cmd_cls.assert_not_called()
# ---------------------------------------------------------------------------
# Column normalization in update_chart
# ---------------------------------------------------------------------------
class TestUpdateChartColumnNormalization:
"""Column names are normalized to dataset canonical case before validation."""
@staticmethod
def _mock_chart(datasource_id: int | None = 10) -> Mock:
chart = Mock()
chart.id = 1
chart.datasource_id = datasource_id
chart.slice_name = "Existing"
chart.viz_type = "table"
chart.uuid = "abc-123"
chart.params = '{"viz_type": "table", "datasource": "10__table"}'
chart.datasource = Mock()
return chart
@patch.object(update_chart_module, "validate_and_compile")
@patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
@patch("superset.mcp_service.auth.check_chart_data_access", new_callable=Mock)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.normalize_column_names",
)
@pytest.mark.asyncio
async def test_normalization_called_with_guarded_datasource_id(
self,
mock_normalize,
mock_db_session,
mock_find_by_id,
mock_check_access,
mock_create_preview,
mock_validate,
mcp_server,
):
"""normalize_column_names receives the locally-guarded datasource_id, not
chart.datasource_id re-accessed from the ORM object (which could raise on
mock/partial objects)."""
from superset.mcp_service.chart.compile import CompileResult
chart = self._mock_chart(datasource_id=10)
mock_find_by_id.return_value = chart
mock_check_access.return_value = DatasetValidationResult(
is_valid=True, dataset_id=10, dataset_name="ds", warnings=[]
)
mock_validate.return_value = CompileResult(
success=True, error=None, error_code=None, tier="validation", error_obj=None
)
mock_create_preview.return_value = ("http://example.com/explore", None, [])
# normalize_column_names returns the config unchanged
def _passthrough(config, dataset_id):
return config
mock_normalize.side_effect = _passthrough
request = {
"identifier": 1,
"config": {
"chart_type": "xy",
"x": {"name": "ds"},
"y": [{"name": "num_boys", "aggregate": "SUM"}],
"kind": "line",
},
}
async with Client(mcp) as client:
await client.call_tool("update_chart", {"request": request})
mock_normalize.assert_called_once()
_, call_dataset_id = mock_normalize.call_args.args
assert call_dataset_id == 10 # guarded local var, not re-read from ORM
@patch.object(update_chart_module, "validate_and_compile")
@patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
@patch("superset.mcp_service.auth.check_chart_data_access", new_callable=Mock)
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
@patch("superset.db.session")
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.validate_against_dataset",
new=Mock(return_value=(True, None)),
)
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.normalize_column_names",
)
@pytest.mark.asyncio
async def test_normalization_exception_is_caught_gracefully(
self,
mock_normalize,
mock_db_session,
mock_find_by_id,
mock_check_access,
mock_create_preview,
mock_validate,
mcp_server,
):
"""A normalization failure must not propagate — chart update continues."""
from superset.mcp_service.chart.compile import CompileResult
chart = self._mock_chart(datasource_id=10)
mock_find_by_id.return_value = chart
mock_check_access.return_value = DatasetValidationResult(
is_valid=True, dataset_id=10, dataset_name="ds", warnings=[]
)
mock_validate.return_value = CompileResult(
success=True, error=None, error_code=None, tier="validation", error_obj=None
)
mock_create_preview.return_value = ("http://example.com/explore", None, [])
mock_normalize.side_effect = ValueError("DB connection failed")
request = {
"identifier": 1,
"config": {
"chart_type": "xy",
"x": {"name": "ds"},
"y": [{"name": "num_boys", "aggregate": "SUM"}],
"kind": "line",
},
}
async with Client(mcp) as client:
# Should not raise; normalization failure is a warning only
await client.call_tool("update_chart", {"request": request})
# Normalization failed but tool still attempted the update path
mock_normalize.assert_called_once()
@patch(
"superset.mcp_service.chart.validation.dataset_validator"
".DatasetValidator.normalize_column_names",
)
def test_normalization_skipped_when_no_datasource_id(self, mock_normalize):
"""normalize_column_names is never called when chart has no datasource_id."""
from superset.mcp_service.chart.schemas import XYChartConfig
chart = self._mock_chart(datasource_id=None)
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="num_boys", aggregate="SUM")],
kind="line",
)
# Simulate the guard from update_chart
chart_datasource_id = getattr(chart, "datasource_id", None)
if config is not None and chart_datasource_id is not None:
from superset.mcp_service.chart.validation.dataset_validator import (
DatasetValidator,
)
DatasetValidator.normalize_column_names(config, chart_datasource_id)
mock_normalize.assert_not_called()
# ---------------------------------------------------------------------------
# Custom SQL metrics (sql_expression) — Ticket #3, update_chart side.
# ---------------------------------------------------------------------------

@@ -117,83 +117,6 @@ class TestGetCanonicalColumnName:
assert result == "unknown_column"
class TestNormalizeXYConfig:
"""Test _normalize_xy_config static method."""
def test_normalize_x_axis_column(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that x-axis column name is normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "orderdate"},
"y": [{"name": "Sales", "aggregate": "SUM"}],
"kind": "line",
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["x"]["name"] == "OrderDate"
def test_normalize_y_axis_columns(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that y-axis column names are normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "OrderDate"},
"y": [
{"name": "sales", "aggregate": "SUM"},
{"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
],
"kind": "bar",
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["y"][0]["name"] == "Sales"
assert config_dict["y"][1]["name"] == "quantity_ordered"
def test_normalize_group_by_column(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that group_by column name is normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "xy",
"x": {"name": "OrderDate"},
"y": [{"name": "Sales", "aggregate": "SUM"}],
"kind": "line",
"group_by": [{"name": "productline"}],
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["group_by"][0]["name"] == "ProductLine"
class TestNormalizeTableConfig:
"""Test _normalize_table_config static method."""
def test_normalize_table_columns(
self, mock_dataset_context: DatasetContext
) -> None:
"""Test that table column names are normalized."""
config_dict: Dict[str, Any] = {
"chart_type": "table",
"columns": [
{"name": "orderdate"},
{"name": "PRODUCTLINE"},
{"name": "sales", "aggregate": "SUM"},
],
}
DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
assert config_dict["columns"][0]["name"] == "OrderDate"
assert config_dict["columns"][1]["name"] == "ProductLine"
assert config_dict["columns"][2]["name"] == "Sales"
class TestNormalizeFilters:
"""Test _normalize_filters static method."""
@@ -742,3 +665,242 @@ class TestValidateSavedMetrics:
assert not is_valid
assert error is not None
assert error.error_code == "INVALID_SAVED_METRIC"
class TestGetCanonicalMetricName:
"""Tests for _get_canonical_metric_name — metrics-only lookup."""
def test_exact_match(self, mock_dataset_context: DatasetContext) -> None:
result = DatasetValidator._get_canonical_metric_name(
"TotalRevenue", mock_dataset_context
)
assert result == "TotalRevenue"
def test_case_insensitive_match(self, mock_dataset_context: DatasetContext) -> None:
result = DatasetValidator._get_canonical_metric_name(
"totalrevenue", mock_dataset_context
)
assert result == "TotalRevenue"
def test_unknown_metric_returns_original(
self, mock_dataset_context: DatasetContext
) -> None:
result = DatasetValidator._get_canonical_metric_name(
"no_such_metric", mock_dataset_context
)
assert result == "no_such_metric"
def test_column_name_not_matched(
self, mock_dataset_context: DatasetContext
) -> None:
"""A name that matches a column but not a metric returns the original."""
result = DatasetValidator._get_canonical_metric_name(
"Sales", mock_dataset_context
)
assert result == "Sales"
@pytest.fixture
def collision_dataset_context() -> DatasetContext:
"""Dataset where a column and a metric share the same case-insensitive name
but have different casing — the scenario that exposed the saved-metric bug."""
return DatasetContext(
id=99,
table_name="sales_data",
schema="public",
database_name="examples",
available_columns=[
{"name": "totalrevenue", "type": "DECIMAL", "is_numeric": True},
],
available_metrics=[
{
"name": "TotalRevenue",
"expression": "SUM(amount)",
"description": None,
},
],
)
class TestSavedMetricNormalizationCorrectness:
"""Saved metrics must resolve against available_metrics, not available_columns.
When a column and a metric share the same case-insensitive name but have
different casing, _get_canonical_column_name (columns-first) returns the
column's casing. For saved_metric=True refs this is wrong — downstream
metric resolution is exact-name based and expects the metric's casing.
"""
@patch.object(DatasetValidator, "_get_dataset_context")
def test_xy_saved_metric_uses_metric_casing(
self,
mock_get_context: Any,
collision_dataset_context: DatasetContext,
) -> None:
mock_get_context.return_value = collision_dataset_context
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="totalrevenue"),
y=[ColumnRef(name="totalrevenue", saved_metric=True)],
)
normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)
# x is a regular column ref — gets column casing
assert normalized.x is not None
assert normalized.x.name == "totalrevenue"
# y is a saved metric — must get metric casing, not column casing
assert normalized.y[0].name == "TotalRevenue"
@patch.object(DatasetValidator, "_get_dataset_context")
def test_table_saved_metric_uses_metric_casing(
self,
mock_get_context: Any,
collision_dataset_context: DatasetContext,
) -> None:
from superset.mcp_service.chart.schemas import TableChartConfig
mock_get_context.return_value = collision_dataset_context
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="totalrevenue"),
ColumnRef(name="totalrevenue", saved_metric=True),
],
)
normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)
assert normalized.columns[0].name == "totalrevenue"
assert normalized.columns[1].name == "TotalRevenue"
@patch.object(DatasetValidator, "_get_dataset_context")
def test_pie_saved_metric_uses_metric_casing(
self,
mock_get_context: Any,
collision_dataset_context: DatasetContext,
) -> None:
from superset.mcp_service.chart.schemas import PieChartConfig
mock_get_context.return_value = collision_dataset_context
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="totalrevenue"),
metric=ColumnRef(name="totalrevenue", saved_metric=True),
)
normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)
assert normalized.dimension.name == "totalrevenue"
assert normalized.metric.name == "TotalRevenue"
@patch.object(DatasetValidator, "_get_dataset_context")
def test_big_number_saved_metric_uses_metric_casing(
self,
mock_get_context: Any,
collision_dataset_context: DatasetContext,
) -> None:
from superset.mcp_service.chart.schemas import BigNumberChartConfig
mock_get_context.return_value = collision_dataset_context
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="totalrevenue", saved_metric=True),
)
normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)
assert normalized.metric.name == "TotalRevenue"
@patch.object(DatasetValidator, "_get_dataset_context")
def test_mixed_timeseries_saved_metrics_use_metric_casing(
self,
mock_get_context: Any,
collision_dataset_context: DatasetContext,
) -> None:
from superset.mcp_service.chart.schemas import (
ColumnRef,
MixedTimeseriesChartConfig,
)
context = DatasetContext(
id=99,
table_name="sales_data",
schema="public",
database_name="examples",
available_columns=[
{"name": "ds", "type": "TIMESTAMP", "is_temporal": True},
{"name": "totalrevenue", "type": "DECIMAL", "is_numeric": True},
],
available_metrics=[
{
"name": "TotalRevenue",
"expression": "SUM(amount)",
"description": None,
},
{
"name": "OrderCount",
"expression": "COUNT(*)",
"description": None,
},
],
)
mock_get_context.return_value = context
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="totalrevenue", saved_metric=True)],
y_secondary=[ColumnRef(name="ordercount", saved_metric=True)],
)
normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)
assert normalized.y[0].name == "TotalRevenue"
assert normalized.y_secondary[0].name == "OrderCount"
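
Review note: the collision tests above rely on saved-metric refs being resolved against the metric list only, while ordinary refs are resolved columns-first. A minimal sketch of that rule follows. `canonical_ref_name` and `_match` are hypothetical helpers operating on plain name lists; they are not `DatasetValidator` methods, and the fallback from columns to metrics for ordinary refs is an assumption drawn from the "columns-first" wording in the class docstring.

```python
from typing import Optional, Sequence


def _match(name: str, candidates: Sequence[str]) -> Optional[str]:
    """Exact match first, then case-insensitive; None when nothing matches."""
    if name in candidates:
        return name
    lowered = name.lower()
    for candidate in candidates:
        if candidate.lower() == lowered:
            return candidate
    return None


def canonical_ref_name(
    name: str,
    saved_metric: bool,
    columns: Sequence[str],
    metrics: Sequence[str],
) -> str:
    """Return the dataset's casing for a ref, or the input when unknown."""
    if saved_metric:
        # Saved metrics never consult the column list, so a same-named
        # column with different casing cannot shadow the metric.
        return _match(name, metrics) or name
    # Ordinary refs: columns first, metrics as a fallback (assumed).
    return _match(name, columns) or _match(name, metrics) or name


columns = ["totalrevenue"]
metrics = ["TotalRevenue"]
print(canonical_ref_name("totalrevenue", False, columns, metrics))
print(canonical_ref_name("totalrevenue", True, columns, metrics))
```

With the collision dataset from the fixture, the first call keeps the column casing and the second returns the metric casing, which is the distinction each `*_saved_metric_uses_metric_casing` test asserts.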
class TestPreValidateAliasHandling:
"""pre_validate must accept schema field aliases, not just canonical names."""
def test_xy_pre_validate_accepts_metrics_alias(self) -> None:
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get("xy")
assert plugin is not None
config_with_alias = {
"chart_type": "xy",
"metrics": [{"name": "revenue", "aggregate": "SUM"}],
}
error = plugin.pre_validate(config_with_alias)
assert error is None, f"pre_validate rejected 'metrics' alias: {error}"
def test_mixed_timeseries_pre_validate_accepts_x_axis_alias(self) -> None:
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get("mixed_timeseries")
assert plugin is not None
config_with_alias = {
"chart_type": "mixed_timeseries",
"x_axis": {"name": "ds"},
"metrics": [{"name": "revenue", "aggregate": "SUM"}],
"metrics_b": [{"name": "orders", "aggregate": "COUNT"}],
}
error = plugin.pre_validate(config_with_alias)
assert error is None, f"pre_validate rejected aliases: {error}"
def test_mixed_timeseries_pre_validate_still_rejects_truly_missing(self) -> None:
from superset.mcp_service.chart.registry import get_registry
plugin = get_registry().get("mixed_timeseries")
assert plugin is not None
config_missing_secondary = {
"chart_type": "mixed_timeseries",
"x": {"name": "ds"},
"y": [{"name": "revenue", "aggregate": "SUM"}],
}
error = plugin.pre_validate(config_missing_secondary)
assert error is not None
assert "y_secondary" in error.message
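
Review note: the alias tests above require that a required-field check accept either the canonical key or one of its schema aliases, and still report a field that is absent under every name. A minimal sketch of such a check follows. `FIELD_ALIASES` and `missing_fields` are hypothetical; the alias names (`x_axis`, `metrics`, `metrics_b`) come from the test payloads, but the mapping of each alias to a canonical field is an assumption, not something read from the plugin code.

```python
from typing import Any, Mapping, Sequence

# Hypothetical alias table. The alias names appear in the tests above;
# which canonical field each one maps to is assumed for illustration.
FIELD_ALIASES: dict[str, tuple[str, ...]] = {
    "x": ("x_axis",),
    "y": ("metrics",),
    "y_secondary": ("metrics_b",),
}


def missing_fields(config: Mapping[str, Any], required: Sequence[str]) -> list[str]:
    """Return required canonical fields absent under every accepted name."""
    missing = []
    for field in required:
        accepted = (field, *FIELD_ALIASES.get(field, ()))
        if not any(key in config for key in accepted):
            missing.append(field)
    return missing


aliased = {
    "chart_type": "mixed_timeseries",
    "x_axis": {"name": "ds"},
    "metrics": [{"name": "revenue", "aggregate": "SUM"}],
    "metrics_b": [{"name": "orders", "aggregate": "COUNT"}],
}
print(missing_fields(aliased, ["x", "y", "y_secondary"]))
```

Reporting the canonical name in the result, not the alias, matches the last test, which expects `y_secondary` to appear in the error message.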

@@ -58,12 +58,12 @@ class TestRuntimeValidatorNonBlocking:
x_axis=AxisConfig(format="$,.2f"), # Currency format for date - mismatch
)
# Mock the format validator to return warnings
# Mock the plugin runtime dispatcher to return format warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format:
mock_format.return_value = [
"_validate_plugin_runtime"
) as mock_plugin:
mock_plugin.return_value = [
"Currency format '$,.2f' may not display dates correctly"
]
@@ -87,15 +87,14 @@ class TestRuntimeValidatorNonBlocking:
kind="bar",
)
# Mock the cardinality validator to return warnings
# Mock the plugin runtime dispatcher to return cardinality warnings
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality:
mock_cardinality.return_value = (
["High cardinality detected: 10000+ unique values"],
["Consider using aggregation or filtering"],
)
"_validate_plugin_runtime"
) as mock_plugin:
mock_plugin.return_value = [
"High cardinality detected: 10000+ unique values"
]
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -148,26 +147,21 @@ class TestRuntimeValidatorNonBlocking:
x_axis=AxisConfig(format="smart_date"), # Wrong format for user_id
)
# Mock all validators to return warnings
# Mock plugin runtime and chart type validators to return warnings
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality,
"_validate_plugin_runtime"
) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_type,
):
mock_format.return_value = ["Format mismatch warning"]
mock_cardinality.return_value = (
["High cardinality warning"],
["Cardinality suggestion"],
)
mock_plugin.return_value = [
"Format mismatch warning",
"High cardinality warning",
]
mock_type.return_value = (
["Chart type warning"],
["Chart type suggestion"],
@@ -197,13 +191,13 @@ class TestRuntimeValidatorNonBlocking:
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
"_validate_plugin_runtime"
) as mock_plugin,
patch(
"superset.mcp_service.chart.validation.runtime.logger"
) as mock_logger,
):
mock_format.return_value = ["Test warning message"]
mock_plugin.return_value = ["Test warning message"]
is_valid, warnings_metadata = RuntimeValidator.validate_runtime_issues(
config, 1
@@ -217,7 +211,7 @@ class TestRuntimeValidatorNonBlocking:
assert "warnings" in warnings_metadata
def test_validate_table_chart_skips_xy_validations(self):
"""Test that table charts skip XY-specific validations."""
"""Test that table charts produce no XY-specific runtime warnings."""
config = TableChartConfig(
chart_type="table",
columns=[
@@ -226,37 +220,29 @@ class TestRuntimeValidatorNonBlocking:
],
)
# These should not be called for table charts
with (
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_format_compatibility"
) as mock_format,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_cardinality"
) as mock_cardinality,
patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_chart_type,
):
# Mock chart type validator to return no warnings
# Plugin runtime dispatches to TableChartPlugin which returns no warnings.
# Chart type suggester is also stubbed to return no warnings.
with patch(
"superset.mcp_service.chart.validation.runtime.RuntimeValidator."
"_validate_chart_type"
) as mock_chart_type:
mock_chart_type.return_value = ([], [])
is_valid, error = RuntimeValidator.validate_runtime_issues(config, 1)
# Format and cardinality validation should not be called for table charts
mock_format.assert_not_called()
mock_cardinality.assert_not_called()
assert is_valid is True
assert error is None
def test_validate_cardinality_returns_cleanly_when_x_name_is_none(self) -> None:
"""The dimension-rejection guard on XYChartConfig normally forbids
x.name=None, but a model_construct bypass (or a future code path)
could land us here. The defensive guard must return cleanly without
calling into CardinalityValidator (which assumes a real column)."""
could land us here. The defensive guard in XYChartPlugin.get_runtime_warnings
must skip cardinality without crashing."""
from superset.mcp_service.chart.plugins.xy import XYChartPlugin
from superset.mcp_service.chart.validation.runtime.format_validator import (
FormatTypeValidator,
)
col = ColumnRef.model_construct(name=None)
config = XYChartConfig.model_construct(
chart_type="xy",
@@ -265,14 +251,19 @@ class TestRuntimeValidatorNonBlocking:
kind="line",
)
with patch(
"superset.mcp_service.chart.validation.runtime."
"cardinality_validator.CardinalityValidator.check_cardinality"
) as mock_check:
warnings, suggestions = RuntimeValidator._validate_cardinality(
config, dataset_id=1
)
plugin = XYChartPlugin()
with (
patch.object(
FormatTypeValidator,
"validate_format_compatibility",
return_value=(True, []),
),
patch(
"superset.mcp_service.chart.validation.runtime."
"cardinality_validator.CardinalityValidator.check_cardinality"
) as mock_check,
):
warnings = plugin.get_runtime_warnings(config, dataset_id=1)
assert warnings == []
assert suggestions == []
mock_check.assert_not_called()