Compare commits

...

17 Commits

Author SHA1 Message Date
Amin Ghadersohi
95b9efe1a3 fix(mcp): prevent ValueError when select_columns contains only USER_DIRECTORY_FIELDS in list_rls_filters
When the caller passes select_columns that consists entirely of USER_DIRECTORY_FIELDS
columns (e.g. ["roles"]), ModelListCore raises ValueError because its privacy filter
strips all columns, leaving an empty list.

Strip USER_DIRECTORY_FIELDS from select_columns before passing to run_tool (falling
back to None/defaults when the filtered list is empty). The existing bypass mechanism
already restores these fields in the final serialized output using ALL_RLS_COLUMNS.

Adds a regression test for the ["roles"]-only select_columns edge case.
2026-05-27 18:33:25 +00:00
Amin Ghadersohi
336d30fbdc fix(mcp): correct serializer signatures and naming per review feedback
- Fix `_serialize` in list_plugins and list_rls_filters: change
  `cols: list[str] | None` to `cols: list[str]` to match the
  `Callable[[T, List[str]], S | None]` signature expected by ModelListCore
- Rename `_serialize` to `_serialize_rls_filter` in list_rls_filters for
  clarity and consistency with other list tool conventions
2026-05-27 15:20:59 +00:00
Amin Ghadersohi
ecd073b3d2 ci: trigger CI for fix 2026-05-27 15:20:59 +00:00
Evan Rusackas
768a3df21b fix(datasets): isolate filter state to fix concurrent /dataset race (#39685)
Co-authored-by: Claude Code <noreply@anthropic.com>
2026-05-27 15:20:59 +00:00
Elizabeth Thompson
e33d0d4f44 fix(reports): guard null dashboard height in Playwright screenshots (#40179) 2026-05-27 15:20:59 +00:00
Mehmet Salih Yavuz
776fcf2f8b feat(mcp): make config optional in generate_explore_link (#39559) 2026-05-27 15:20:59 +00:00
Mehmet Salih Yavuz
2ebb3bcd9f feat(mcp): include applied dashboard filters in get_chart_info (#39620)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 15:20:59 +00:00
Mehmet Salih Yavuz
ff2c37e52b fix(mcp): eager-load dataset.metrics to prevent Excel export DetachedInstanceError (#39483)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 15:20:59 +00:00
Beto Dealmeida
e14a92f881 fix(semantic layers): coerce filter types (#40222) 2026-05-27 15:20:59 +00:00
Mehmet Salih Yavuz
3d528921ba feat(mcp): add find_users tool and owner filter columns for listings (#39679)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-27 15:20:59 +00:00
Alexandru Soare
81aad1c540 fix(recommandation): Fix chart recommandation (#39886) 2026-05-27 15:20:59 +00:00
Mehmet Salih Yavuz
c8c805c7e8 feat(mcp): chart formatting options across all supported chart types (#39887) 2026-05-27 15:20:59 +00:00
Amin Ghadersohi
ab69e3dbcc docs(mcp): document that list_rls_filters and list_plugins have inline column docs 2026-05-21 08:14:03 +00:00
Amin Ghadersohi
c470b7a8d4 fix(mcp): restore 'roles' to USER_DIRECTORY_FIELDS and bypass filter in RLS list tool
'roles' on a dashboard/chart exposes who has access to the resource and
should be stripped by the USER_DIRECTORY_FIELDS privacy filter.

'roles' in an RLS filter is which roles the filter applies to — it is
core filter data, not user-directory metadata. The RLS list tool now
derives its column selection directly from ALL_RLS_COLUMNS (bypassing
ModelListCore's USER_DIRECTORY_FIELDS filtering) so that RLS roles are
selectable while dashboard roles remain hidden.

Fixes three failing unit tests:
- test_list_dashboards_omits_requested_user_directory_fields
- test_get_allowed_fields_always_denies_user_directory_fields
- test_filter_sensitive_data_strips_user_directory_fields_even_if_allowed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 08:14:03 +00:00
Amin Ghadersohi
7da1bca909 ci: trigger CI for fix 2026-05-21 08:14:03 +00:00
Amin Ghadersohi
06d72bcc72 fix(mcp): remove 'roles' from USER_DIRECTORY_FIELDS to allow RLS filter roles to be returned
RLS filter `roles` (which roles a filter applies to) are core RLS data,
not user-directory metadata. Including 'roles' in USER_DIRECTORY_FIELDS
caused filter_user_directory_columns() to strip it from any requested
select_columns list, making it impossible to retrieve via list_rls_filters.

No dashboard/chart/dataset schema defines a 'roles' field, so removing it
from the block set has no privacy impact on other tools.

Fixes test_list_rls_filters_returns_tables_and_roles.
2026-05-21 08:14:03 +00:00
Amin Ghadersohi
473456b6ea feat(mcp): add list and get tools for row level security and plugins
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 08:14:03 +00:00
51 changed files with 4466 additions and 162 deletions
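
The select_columns fix described in commit 95b9efe1a3 can be sketched roughly as below. This is a minimal illustration, not the code from the PR: the helper name `strip_user_directory_fields` is hypothetical, and the contents of `USER_DIRECTORY_FIELDS` here are a stand-in (only "roles" is confirmed by the commit messages).

```python
from typing import Optional

# Assumed stand-in for the real USER_DIRECTORY_FIELDS set; only "roles"
# is confirmed by the commit messages in this PR.
USER_DIRECTORY_FIELDS = frozenset({"roles"})


def strip_user_directory_fields(
    select_columns: Optional[list[str]],
) -> Optional[list[str]]:
    """Drop privacy-filtered columns before they reach the list core.

    Returns None (meaning "use the default columns") when the caller
    passed nothing, or when filtering leaves an empty list, so the
    downstream code never receives an empty selection.
    """
    if select_columns is None:
        return None
    kept = [col for col in select_columns if col not in USER_DIRECTORY_FIELDS]
    return kept or None


print(strip_user_directory_fields(["roles"]))
print(strip_user_directory_fields(["id", "roles", "name"]))
```

Per the commit message, the stripped fields are restored later in the serialized output, so falling back to the defaults here should not hide RLS roles from the caller.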

View File

@@ -123,6 +123,14 @@ Database Connections:
- list_databases: List database connections with advanced filters (1-based pagination)
- get_database_info: Get detailed database connection info by ID (backend, capabilities)
Row Level Security (Admin only):
- list_rls_filters: List RLS filters with filtering and search (1-based pagination)
- get_rls_filter_info: Get detailed RLS filter info by ID (tables, roles, clause)
Plugins (Admin only):
- list_plugins: List dynamic plugins with filtering and search (1-based pagination)
- get_plugin_info: Get detailed plugin info by ID (name, key, bundle URL)
Dataset Management:
- list_datasets: List datasets with advanced filters (1-based pagination)
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
@@ -150,6 +158,7 @@ Schema Discovery:
System Information:
- get_instance_info: Get instance-wide statistics, metadata, and current user identity
- find_users: Resolve a person's name to user IDs for use as a filter value
- health_check: Simple health check tool (takes NO parameters, call without arguments)
- generate_bug_report: Build a PII-sanitized bug report to send to Preset support
(use when the user says the MCP is broken or asks how to report an issue)
@@ -191,6 +200,16 @@ Some tools do not use a request wrapper, so follow each tool's schema
Recommended Workflows:
To filter dashboards/charts/datasets by a person ("show me what <name> is working on"):
1. find_users(request={{"query": "<name>"}}) -> resolve to user IDs
2. Pick the matching user.id from the response
3. list_dashboards(request={{"filters": [
{{"col": "created_by_fk", "opr": "eq", "value": <id>}}
]}}) — same shape for list_charts / list_datasets.
(use changed_by_fk for "last modified by", or "in" with a list of IDs for
multiple matches). Do NOT pass the person's name as the search parameter —
search matches titles, not people.
To add a chart to an existing dashboard:
1. add_chart_to_existing_dashboard(dashboard_id, chart_id) -> updates dashboard directly
- If permission_denied=True is returned: inform the user they lack edit rights,
@@ -342,7 +361,10 @@ IMPORTANT - Tool-Only Interaction:
General usage tips:
- All listing tools use 1-based pagination (first page is 1)
-- Use get_schema to discover filterable columns, sortable columns, and default columns
- Use get_schema (chart/dataset/dashboard/database) to discover filterable columns,
sortable columns, and default columns for those resource types
- For list_rls_filters and list_plugins, filterable/sortable columns are listed
inline in each tool's docstring — get_schema does not cover these tools
- Use 'filters' parameter for advanced queries with filter columns from get_schema
- IDs can be integer or UUID format where supported
- All tools return structured, Pydantic-typed responses
@@ -360,6 +382,11 @@ Input format:
contact details, roles, admin status, ownership, or access-list information.
- Do NOT infer access-list answers from dashboard metadata such as published status,
role restrictions, empty owner lists, or schema fields.
- find_users is sanctioned ONLY for resolving a name the user supplied into a
user ID for filtering (e.g., "what is <name> working on" -> filter
list_dashboards by created_by_fk). Do NOT use find_users to answer "who owns
X", "who can access X", "is <name> an admin", or to enumerate the directory.
Never return find_users output to the user verbatim.
- Do NOT use execute_sql to query user, role, owner, or access-list tables for this
information.
- You may reference the current user's own identity details when appropriate, such
@@ -620,6 +647,14 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
generate_explore_link,
)
from superset.mcp_service.plugin.tool import ( # noqa: F401, E402
get_plugin_info,
list_plugins,
)
from superset.mcp_service.rls.tool import ( # noqa: F401, E402
get_rls_filter_info,
list_rls_filters,
)
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
execute_sql,
open_sql_lab_with_context,
@@ -630,6 +665,7 @@ from superset.mcp_service.system import ( # noqa: F401, E402
resources as system_resources,
)
from superset.mcp_service.system.tool import ( # noqa: F401, E402
find_users,
generate_bug_report,
get_instance_info,
get_schema,

View File

@@ -32,6 +32,7 @@ from urllib.parse import parse_qs, urlparse
from superset.constants import EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS
if TYPE_CHECKING:
from superset.mcp_service.chart.schemas import AppliedDashboardFilter
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
@@ -44,20 +45,33 @@ QUERY_CONTEXT_EXTRA_FORM_DATA_OVERRIDE_KEYS = {
}
-def find_chart_by_identifier(identifier: int | str) -> Slice | None:
class ChartNotOnDashboardError(ValueError):
"""Raised when a chart is not part of the given dashboard's slices."""
def find_chart_by_identifier(
identifier: int | str,
query_options: list[Any] | None = None,
) -> Slice | None:
"""Find a chart by numeric ID or UUID string.
Accepts an integer ID, a string that looks like a digit (e.g. "123"),
or a UUID string. Returns the Slice model instance or None.
``query_options`` is forwarded to the DAO so callers can eager-load
relationships needed after the request-scoped session is detached.
"""
from superset.daos.chart import ChartDAO # avoid circular import
extra: dict[str, Any] = (
{"query_options": query_options} if query_options is not None else {}
)
if isinstance(identifier, int) or (
isinstance(identifier, str) and identifier.isdigit()
):
chart_id = int(identifier) if isinstance(identifier, str) else identifier
-return ChartDAO.find_by_id(chart_id)
-return ChartDAO.find_by_id(identifier, id_column="uuid")
return ChartDAO.find_by_id(chart_id, **extra)
return ChartDAO.find_by_id(identifier, id_column="uuid", **extra)
def get_cached_form_data(form_data_key: str) -> str | None:
@@ -528,3 +542,118 @@ def extract_form_data_key_from_url(url: str | None) -> str | None:
parsed = urlparse(url)
values = parse_qs(parsed.query).get("form_data_key", [])
return values[0] if values else None
def _match_adhoc_by_subject(
adhoc_filters: Any, column: str | None
) -> tuple[str | None, Any] | None:
if not column or not isinstance(adhoc_filters, list):
return None
for af in adhoc_filters:
if isinstance(af, dict) and af.get("subject") == column:
return af.get("operator"), af.get("comparator")
return None
def _match_legacy_by_col(
legacy_filters: Any, column: str | None
) -> tuple[str | None, Any] | None:
if not column or not isinstance(legacy_filters, list):
return None
for f in legacy_filters:
if isinstance(f, dict) and f.get("col") == column:
return f.get("op"), f.get("val")
return None
def _resolve_filter_operator_and_value(
extra_form_data: dict[str, Any] | None,
column: str | None,
) -> tuple[str | None, Any]:
"""Pull operator and value for a dashboard filter from its
default extra_form_data, matching on target column where applicable."""
if not extra_form_data:
return None, None
if match := _match_adhoc_by_subject(extra_form_data.get("adhoc_filters"), column):
return match
if match := _match_legacy_by_col(extra_form_data.get("filters"), column):
return match
# Temporal filters contribute time_range with no target column
if time_range := extra_form_data.get("time_range"):
return "TIME_RANGE", time_range
return None, None
def build_applied_dashboard_filters(
dashboard_id: int, chart_id: int
) -> list[AppliedDashboardFilter]:
"""Resolve dashboard-level native filters in scope for a chart.
Validates that the dashboard exists, the caller has access, and the chart
is on the dashboard. Returns one AppliedDashboardFilter per non-DIVIDER
native filter whose scope includes the chart, populated with the filter's
default operator and value.
Raises DashboardNotFoundError if the dashboard is missing,
ChartNotOnDashboardError if the chart is not on it, and
SupersetSecurityException if the caller cannot access the dashboard.
"""
# Local imports avoid circular deps at module load
from superset import db, security_manager
from superset.charts.data.dashboard_filter_context import (
_extract_filter_extra_form_data,
_get_filter_target_column,
_is_filter_in_scope_for_chart,
)
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.mcp_service.chart.schemas import AppliedDashboardFilter
from superset.models.dashboard import Dashboard
from superset.utils import json
dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
if not dashboard:
raise DashboardNotFoundError(dashboard_id=str(dashboard_id))
security_manager.raise_for_access(dashboard=dashboard)
slice_ids = {slc.id for slc in dashboard.slices}
if chart_id not in slice_ids:
raise ChartNotOnDashboardError(
f"Chart {chart_id} is not on dashboard {dashboard_id}"
)
metadata = json.loads(dashboard.json_metadata or "{}")
native_filter_config = metadata.get("native_filter_configuration", [])
if not isinstance(native_filter_config, list):
return []
position_json = json.loads(dashboard.position_json or "{}")
if not isinstance(position_json, dict):
position_json = {}
applied: list[AppliedDashboardFilter] = []
for flt in native_filter_config:
if not isinstance(flt, dict):
continue
if flt.get("type", "") == "DIVIDER":
continue
if not _is_filter_in_scope_for_chart(flt, chart_id, position_json):
continue
extra_form_data, status = _extract_filter_extra_form_data(flt)
column = _get_filter_target_column(flt)
operator, value = _resolve_filter_operator_and_value(extra_form_data, column)
applied.append(
AppliedDashboardFilter(
id=flt.get("id"),
name=flt.get("name"),
filter_type=flt.get("filterType"),
column=column,
operator=operator,
value=value,
status=status.value,
)
)
return applied
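
To illustrate the lookup order used when resolving a dashboard filter's operator and value (adhoc filters first, then legacy filters, then a bare time range), here is a condensed, standalone restatement. It is a sketch, not the PR's code: the function name `resolve_operator_and_value` and the sample payloads are made up for illustration, and the two matching helpers are folded into one loop.

```python
from typing import Any, Dict, Optional, Tuple

# (extra_form_data key, column key, operator key, value key), in lookup order.
_SOURCES = (
    ("adhoc_filters", "subject", "operator", "comparator"),
    ("filters", "col", "op", "val"),
)


def resolve_operator_and_value(
    extra_form_data: Optional[Dict[str, Any]], column: Optional[str]
) -> Tuple[Optional[str], Any]:
    """Return (operator, value) for the filter targeting ``column``."""
    if not extra_form_data:
        return None, None
    for list_key, col_key, op_key, val_key in _SOURCES:
        entries = extra_form_data.get(list_key)
        if not column or not isinstance(entries, list):
            continue
        for entry in entries:
            if isinstance(entry, dict) and entry.get(col_key) == column:
                return entry.get(op_key), entry.get(val_key)
    # Temporal filters carry a time_range and have no target column.
    if time_range := extra_form_data.get("time_range"):
        return "TIME_RANGE", time_range
    return None, None


legacy = {"filters": [{"col": "country", "op": "IN", "val": ["US", "CA"]}]}
print(resolve_operator_and_value(legacy, "country"))
print(resolve_operator_and_value({"time_range": "Last week"}, None))
```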

View File

@@ -32,6 +32,7 @@ from superset.mcp_service.chart.schemas import (
ChartCapabilities,
ChartSemantics,
ColumnRef,
CurrencyFormat,
FilterConfig,
HandlebarsChartConfig,
MixedTimeseriesChartConfig,
@@ -477,6 +478,7 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]:
]
form_data["row_limit"] = config.row_limit
add_color_scheme(form_data, config.color_scheme)
return form_data
@@ -547,7 +549,35 @@ def add_legend_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
if not config.legend.show:
form_data["show_legend"] = False
if config.legend.position:
-form_data["legend_orientation"] = config.legend.position
# Canonical form_data key is camelCase; the echarts plugins read
# `legendOrientation` directly off form_data.
form_data["legendOrientation"] = config.legend.position
def add_color_scheme(form_data: Dict[str, Any], color_scheme: str | None) -> None:
"""Add color scheme to form_data when set."""
if color_scheme:
form_data["color_scheme"] = color_scheme
def add_currency_format(
form_data: Dict[str, Any],
currency_format: CurrencyFormat | None,
key: str = "currency_format",
) -> None:
"""Add currency format to form_data under the given key when set."""
if currency_format:
form_data[key] = currency_format.to_form_data()
def add_xy_data_label_options(
form_data: Dict[str, Any], config: XYChartConfig, x_is_temporal: bool
) -> None:
"""Apply XY-specific data-label and time-format options when set."""
if config.x_axis_time_format and x_is_temporal:
form_data["x_axis_time_format"] = config.x_axis_time_format
if config.show_value:
form_data["show_value"] = True
def add_orientation_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
@@ -722,6 +752,9 @@ def map_xy_config(
add_axis_config(form_data, config)
add_legend_config(form_data, config)
add_orientation_config(form_data, config)
add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
add_xy_data_label_options(form_data, config, x_is_temporal)
return form_data
@@ -734,11 +767,13 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
"viz_type": "pie",
"groupby": [config.dimension.name],
"metric": metric,
-"color_scheme": "supersetColors",
"color_scheme": config.color_scheme or "supersetColors",
"show_labels": config.show_labels,
"show_legend": config.show_legend,
"legendOrientation": config.legend_orientation,
"label_type": config.label_type,
"number_format": config.number_format,
"date_format": config.date_format,
"sort_by_metric": config.sort_by_metric,
"row_limit": config.row_limit,
"donut": config.donut,
@@ -746,9 +781,9 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
"labels_outside": config.labels_outside,
"outerRadius": config.outer_radius,
"innerRadius": config.inner_radius,
-"date_format": "smart_date",
}
add_currency_format(form_data, config.currency_format)
_add_adhoc_filters(form_data, config.filters)
return form_data
@@ -774,6 +809,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
if config.y_axis_format:
form_data["y_axis_format"] = config.y_axis_format
add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
# Trendline-specific fields
if viz_type == "big_number":
# Big Number with trendline uses granularity_sqla for the temporal column
@@ -789,6 +827,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
if config.compare_lag is not None:
form_data["compare_lag"] = config.compare_lag
if config.time_format:
form_data["time_format"] = config.time_format
_add_adhoc_filters(form_data, config.filters)
return form_data
@@ -860,6 +901,10 @@ def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]:
"row_limit": config.row_limit,
}
if config.date_format:
form_data["date_format"] = config.date_format
add_currency_format(form_data, config.currency_format)
_add_adhoc_filters(form_data, config.filters)
return form_data
@@ -939,10 +984,20 @@ def map_mixed_timeseries_config(
"yAxisIndexB": 1,
# Display
"show_legend": config.show_legend,
"legendOrientation": config.legend_orientation,
"zoomable": True,
"rich_tooltip": True,
}
if config.show_value:
form_data["show_value"] = True
add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
add_currency_format(
form_data, config.currency_format_secondary, key="currency_format_secondary"
)
# Configure temporal handling
configure_temporal_handling(form_data, x_is_temporal, config.time_grain)

View File

@@ -23,7 +23,7 @@ from __future__ import annotations
import difflib
from datetime import datetime
-from typing import Annotated, Any, Dict, List, Literal, Protocol
from typing import Annotated, Any, cast, Dict, List, Literal, Protocol
import humanize
from pydantic import (
@@ -146,7 +146,7 @@ class ChartInfo(BaseModel):
),
)
form_data_key: str | None = Field(
-None,
default=None,
description=(
"Cache key used to retrieve unsaved form_data. When present, indicates "
"the form_data came from cache (unsaved edits) rather than the saved chart."
@@ -279,6 +279,15 @@ class GetChartInfoRequest(BaseModel):
"Can be used alone (without identifier) for unsaved charts."
),
)
dashboard_id: int | None = Field(
default=None,
description=(
"When provided, resolves dashboard-level native filters that are in "
"scope for this chart on the given dashboard and returns them under "
"filters.dashboard_filters. Requires the chart to be on the dashboard "
"and the caller to have dashboard access."
),
)
@model_validator(mode="after")
def validate_identifier_or_form_data_key(self) -> "GetChartInfoRequest":
@@ -523,17 +532,18 @@ class ChartFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
-col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"slice_name",
"viz_type",
"datasource_name",
"created_by_fk",
"changed_by_fk",
] = Field(
...,
-description=(
-"Column to filter on. Valid values: 'slice_name', 'viz_type', "
-"'datasource_name'. Other column names are not valid filter columns "
-"and will cause a validation error."
-),
description="Column to filter on. Use get_schema(model_type='chart') for "
"available filter columns. To filter by a person, first call find_users "
"to resolve a name to a user ID, then filter by created_by_fk or "
"changed_by_fk with that integer ID.",
)
opr: ColumnOperatorEnum = Field(
...,
@@ -731,6 +741,29 @@ class LegendConfig(BaseModel):
position: Literal["top", "bottom", "left", "right"] | None = "right"
class CurrencyFormat(BaseModel):
"""Currency symbol and placement applied to numeric values."""
model_config = ConfigDict(populate_by_name=True)
symbol: str = Field(
...,
description="Currency code or symbol (e.g. 'USD', 'EUR', '$', '€')",
max_length=20,
)
symbol_position: Literal["prefix", "suffix"] = Field(
"prefix",
description="Whether to render the symbol before or after the value",
validation_alias=AliasChoices("symbol_position", "symbolPosition"),
)
def to_form_data(self) -> Dict[str, str]:
return {"symbol": self.symbol, "symbolPosition": self.symbol_position}
LEGEND_POSITION_LITERAL = Literal["top", "bottom", "left", "right"]
class FilterConfig(BaseModel):
model_config = ConfigDict(populate_by_name=True)
@@ -857,6 +890,27 @@ class PieChartConfig(UnknownFieldCheckMixin):
)
row_limit: int = Field(100, description="Max slices", ge=1, le=10000)
number_format: str = Field("SMART_NUMBER", max_length=50)
date_format: str = Field(
"smart_date",
description="Date format for date dimension labels (e.g. 'smart_date', "
"'%Y-%m-%d')",
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to the metric value",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
"'googleCategory10c', 'd3Category10'). Defaults to 'supersetColors'."
),
max_length=100,
)
legend_orientation: LEGEND_POSITION_LITERAL = Field(
"top", description="Legend placement around the chart"
)
show_total: bool = Field(False, description="Show total in center")
labels_outside: bool = True
outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1, le=100)
@@ -908,6 +962,15 @@ class PivotTableChartConfig(UnknownFieldCheckMixin):
)
row_limit: int = Field(10000, description="Max cells", ge=1, le=50000)
value_format: str = Field("SMART_NUMBER", max_length=50)
date_format: str | None = Field(
None,
description="Date format for date columns (e.g. 'smart_date', '%Y-%m-%d')",
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to numeric metric values",
)
class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
@@ -954,9 +1017,29 @@ class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
)
# Display options
show_legend: bool = True
legend_orientation: LEGEND_POSITION_LITERAL = Field(
"top", description="Legend placement around the chart"
)
show_value: bool = Field(False, description="Show data labels on each data point")
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
y_axis_secondary: AxisConfig | None = None
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors'). "
"When omitted, Superset's default scheme is used."
),
max_length=100,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to primary metric values",
)
currency_format_secondary: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to secondary metric values",
)
filters: List[FilterConfig] | None = Field(
None,
description="Structured filters (column/op/value). "
@@ -1132,6 +1215,27 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
max_length=50,
)
time_format: str | None = Field(
None,
description=(
"Date format string for trendline x-axis labels "
"(e.g. 'smart_date', '%Y-%m-%d'). Only applies when "
"show_trendline=True."
),
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to the metric value",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID for the trendline (e.g. 'supersetColors'). "
"When omitted, Superset's default scheme is used."
),
max_length=100,
)
start_y_axis_at_zero: bool = Field(
True,
description="Anchor trendline y-axis at zero",
@@ -1217,6 +1321,14 @@ class TableChartConfig(UnknownFieldCheckMixin):
validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"),
)
row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID applied to conditional/cell formatting "
"(e.g. 'supersetColors')."
),
max_length=100,
)
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "TableChartConfig":
@@ -1298,6 +1410,28 @@ class XYChartConfig(UnknownFieldCheckMixin):
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
legend: LegendConfig | None = None
x_axis_time_format: str | None = Field(
None,
description=(
"Date format for temporal x-axis labels (e.g. 'smart_date', "
"'%Y-%m-%d'). Only applies when the x-axis column is temporal."
),
max_length=50,
)
show_value: bool = Field(False, description="Show data labels on each data point")
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to metric values",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
"'googleCategory10c', 'd3Category10'). When omitted, Superset's "
"default scheme is used."
),
max_length=100,
)
filters: List[FilterConfig] | None = Field(
None,
description="Structured filters (column/op/value). "
@@ -1414,7 +1548,10 @@ class ListChartsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl):
"""
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
-return parse_json_or_model_list(v, ChartFilter, "filters")
return cast(
List[ChartFilter],
parse_json_or_model_list(v, ChartFilter, "filters"),
)
@field_validator("select_columns", mode="before")
@classmethod
@@ -1575,7 +1712,14 @@ class GenerateChartRequest(QueryCacheControl):
class GenerateExploreLinkRequest(FormDataCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
-config: ChartConfig = Field(..., description="Chart configuration")
config: ChartConfig | None = Field(
None,
description=(
"Chart configuration. Optional; omit to get a default "
"explore URL that opens the dataset in Superset without a "
"preconfigured chart."
),
)
class UpdateChartRequest(QueryCacheControl):
@@ -2064,6 +2208,38 @@ class AdhocFilter(BaseModel):
model_config = ConfigDict(extra="ignore")
class AppliedDashboardFilter(BaseModel):
"""A dashboard-level native filter resolved against a specific chart.
Returned when get_chart_info is called with a dashboard_id. Values come
from the filter's default state on the saved dashboard (not a permalink).
"""
id: str | None = Field(None, description="Native filter ID")
name: str | None = Field(None, description="Filter display name")
filter_type: str | None = Field(
None, description="Native filter type (e.g. filter_select, filter_range)"
)
column: str | None = Field(None, description="Target column the filter applies to")
operator: str | None = Field(
None,
description=(
"Filter operator as stored in extra_form_data (e.g. 'IN', '==', 'LIKE', "
"or 'TIME_RANGE' for temporal filters with no target column)"
),
)
value: Any | None = Field(
None, description="Filter value(s) from the default data mask"
)
status: str = Field(
...,
description=(
"Whether the filter contributes to the chart query: 'applied', "
"'not_applied', or 'not_applied_uses_default_to_first_item_prequery'"
),
)
class ChartFiltersInfo(BaseModel):
"""Structured representation of all filters applied to a chart."""
@@ -2105,6 +2281,15 @@ class ChartFiltersInfo(BaseModel):
None,
description="Custom HAVING clause applied to the chart query",
)
dashboard_filters: List[AppliedDashboardFilter] = Field(
default_factory=list,
description=(
"Dashboard-level native filters in scope for this chart on the "
"dashboard passed via get_chart_info's dashboard_id argument. Empty "
"when no dashboard_id was provided or no native filter targets this "
"chart."
),
)
# Rebuild ChartInfo so Pydantic can resolve the ChartFiltersInfo forward reference.

View File

@@ -26,6 +26,7 @@ from typing import Any, Dict, List, TYPE_CHECKING
from fastmcp import Context
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
if TYPE_CHECKING:
@@ -58,9 +59,177 @@ from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.utils.core import GenericDataType
logger = logging.getLogger(__name__)
_GENERIC_TYPE_MAP: dict[int, str] = {
GenericDataType.NUMERIC: "numeric",
GenericDataType.STRING: "string",
GenericDataType.TEMPORAL: "temporal",
GenericDataType.BOOLEAN: "boolean",
}
# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
_VIZ_CATEGORY: dict[str, str] = {
"echarts_timeseries_line": "line",
"echarts_timeseries_smooth": "line",
"echarts_timeseries_step": "line",
"echarts_timeseries": "line",
"echarts_timeseries_bar": "bar",
"echarts_area": "area",
"echarts_timeseries_scatter": "scatter",
"mixed_timeseries": "line",
"table": "table",
"pie": "pie",
"big_number": "kpi",
"big_number_total": "kpi",
"pop_kpi": "kpi",
"dist_bar": "bar",
"line": "line",
"area": "area",
"scatter": "scatter",
"bubble": "bubble",
"treemap_v2": "treemap",
"sunburst_v2": "treemap",
"heatmap_v2": "heatmap",
"gauge_chart": "gauge",
"funnel": "funnel",
"histogram": "histogram",
"histogram_v2": "histogram",
"box_plot": "box_plot",
"world_map": "map",
"pivot_table_v2": "table",
}
_MAX_RECOMMENDATIONS = 4
def _recommend_visualizations(
viz_type: str,
columns: list[DataColumn],
row_count: int,
) -> list[str]:
"""Suggest visualization types based on column types,
cardinality, and the chart's current viz_type.
"""
if not columns:
return ["table"]
current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
candidates = _build_candidates(columns, row_count)
if not candidates:
candidates = ["table", "bar chart"]
return _filter_candidates(candidates, current_category)
def _build_candidates(
columns: list[DataColumn],
row_count: int,
) -> list[str]:
"""Build candidate visualization list from column metadata."""
temporal = [c for c in columns if c.data_type == "temporal"]
numeric = [c for c in columns if c.data_type == "numeric"]
categorical = [c for c in columns if c.data_type in ("string", "boolean")]
if temporal and numeric:
return _candidates_temporal_numeric(numeric, row_count)
if categorical and numeric:
return _candidates_categorical_numeric(numeric, categorical)
if len(numeric) >= 2:
return _candidates_multi_numeric(numeric, categorical)
if len(numeric) == 1 and not temporal and not categorical:
return _candidates_single_numeric(numeric[0], row_count)
return []
def _candidates_temporal_numeric(
numeric: list[DataColumn], row_count: int
) -> list[str]:
# Few data points are better as a bar chart than a line
if row_count < 5:
candidates = ["bar chart", "table"]
else:
candidates = ["line chart", "area chart", "bar chart"]
if len(numeric) > 1:
candidates.append("multi-line chart")
return candidates
def _candidates_categorical_numeric(
numeric: list[DataColumn],
categorical: list[DataColumn],
) -> list[str]:
candidates = ["bar chart"]
if len(numeric) == 1 and categorical[0].unique_count <= 10:
candidates.append("pie chart")
if len(numeric) >= 2:
candidates.append("scatter plot")
candidates.append("heatmap")
if any(c.unique_count > 5 for c in categorical):
candidates.append("treemap")
return candidates
def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
candidates = ["big number / KPI", "gauge chart"]
if row_count > 20 and col.unique_count > 10:
candidates.insert(0, "histogram")
return candidates
def _candidates_multi_numeric(
numeric: list[DataColumn],
categorical: list[DataColumn],
) -> list[str]:
candidates = ["scatter plot"]
if len(numeric) >= 3:
candidates.append("bubble chart")
if categorical:
candidates.append("heatmap")
return candidates
# Maps each candidate string to a canonical category for dedup
# against the current viz_type.
_CANDIDATE_CATEGORY: dict[str, str] = {
"line chart": "line",
"multi-line chart": "line",
"area chart": "area",
"bar chart": "bar",
"scatter plot": "scatter",
"bubble chart": "bubble",
"pie chart": "pie",
"treemap": "treemap",
"heatmap": "heatmap",
"big number / KPI": "kpi",
"gauge chart": "gauge",
"histogram": "histogram",
"table": "table",
}
def _filter_candidates(
candidates: list[str],
current_category: str,
) -> list[str]:
"""Deduplicate, exclude the current viz category, and cap."""
seen: set[str] = set()
result: list[str] = []
for c in candidates:
if c in seen:
continue
if _CANDIDATE_CATEGORY.get(c) == current_category:
continue
seen.add(c)
result.append(c)
if len(result) >= _MAX_RECOMMENDATIONS:
break
return result
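Review note: the dedup, exclude-current, and cap behavior of `_filter_candidates` can be tried in isolation. The block below is a minimal standalone sketch, not the code from this change. `filter_candidates`, `CANDIDATE_CATEGORY`, and `MAX_RECOMMENDATIONS` are trimmed-down local stand-ins for the private names in the diff, and the sample inputs are made up for illustration.

```python
# Standalone sketch of the dedup / exclude-current-category / cap step.
# The names are local stand-ins for the private helpers in the diff.
MAX_RECOMMENDATIONS = 4

# Reduced mapping; the diff maps more candidate strings than shown here.
CANDIDATE_CATEGORY = {
    "line chart": "line",
    "multi-line chart": "line",
    "area chart": "area",
    "bar chart": "bar",
    "table": "table",
}


def filter_candidates(candidates: list[str], current_category: str) -> list[str]:
    seen: set[str] = set()
    result: list[str] = []
    for c in candidates:
        if c in seen:
            continue  # drop repeated suggestions
        if CANDIDATE_CATEGORY.get(c) == current_category:
            continue  # do not suggest what the chart already is
        seen.add(c)
        result.append(c)
        if len(result) >= MAX_RECOMMENDATIONS:
            break  # cap the list
    return result


# A chart that is currently a line chart (illustrative input).
suggestions = filter_candidates(
    ["line chart", "area chart", "bar chart", "multi-line chart", "bar chart"],
    current_category="line",
)
print(suggestions)
```

Both "line chart" and "multi-line chart" map to the `line` category, so neither is suggested for a chart that is already a line chart.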
def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
"""Wrap chart data read-path descriptive fields before LLM exposure."""
@@ -182,7 +351,18 @@ async def get_chart_data( # noqa: C901
# Build query context entirely from cached form_data
return await _query_from_form_data(cached_form_data_dict, request, ctx)
# Find the chart by identifier
# Find the chart by identifier.
# Eagerly load the dataset's metrics relationship so Excel export
# (which may run after the request-scoped session is detached) can
# access dataset.metrics without triggering a lazy load. See
# apache/superset#39206 for the analogous database eager-load fix.
from superset.connectors.sqla.models import SqlaTable
from superset.models.slice import Slice
chart_query_options = [
subqueryload(Slice.table).subqueryload(SqlaTable.metrics),
]
with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"):
await ctx.debug("Looking up chart: identifier=%s" % (request.identifier,))
if request.identifier is None:
@@ -190,7 +370,9 @@ async def get_chart_data( # noqa: C901
error="Chart identifier is required",
error_type="ValidationError",
)
chart = find_chart_by_identifier(request.identifier)
chart = find_chart_by_identifier(
request.identifier, query_options=chart_query_options
)
if not chart:
await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
@@ -484,8 +666,9 @@ async def get_chart_data( # noqa: C901
)
# Create rich column metadata
coltypes = query_result.get("coltypes", [])
columns = []
for col_name in raw_columns:
for idx, col_name in enumerate(raw_columns):
# Sample some values for metadata
sample_values = [
row.get(col_name)
@@ -493,13 +676,16 @@ async def get_chart_data( # noqa: C901
if row.get(col_name) is not None
]
# Infer data type
# Use SQL-derived GenericDataType when available,
# fall back to Python isinstance heuristic
data_type = "string"
if sample_values:
if all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
elif all(isinstance(v, bool) for v in sample_values):
if coltypes:
data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
elif sample_values:
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
columns.append(
DataColumn(
@@ -542,13 +728,11 @@ async def get_chart_data( # noqa: C901
else:
insights.append("Fresh data retrieved from database")
recommended_visualizations = []
if any(
"time" in col.lower() or "date" in col.lower() for col in raw_columns
):
recommended_visualizations.extend(["line chart", "time series"])
if len(raw_columns) <= 3:
recommended_visualizations.extend(["bar chart", "scatter plot"])
recommended_visualizations = _recommend_visualizations(
viz_type=chart.viz_type or "unknown",
columns=columns,
row_count=len(data),
)
# Performance metadata with cache awareness
execution_time = int((time.time() - start_time) * 1000)
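Review note: the reordered fallback heuristic in the column-metadata hunk matters because `bool` is a subclass of `int` in Python, so a numeric check that runs first also matches boolean samples. The sketch below covers only the fallback path (no `coltypes`); `infer_data_type` is a hypothetical helper name, not a function from this change.

```python
# Sketch of the isinstance fallback only; the diff prefers SQL-derived
# coltypes when they are available. infer_data_type is a hypothetical name.
def infer_data_type(sample_values: list[object]) -> str:
    if not sample_values:
        return "string"
    # Check bool first: isinstance(True, int) is True in Python.
    if all(isinstance(v, bool) for v in sample_values):
        return "boolean"
    if all(isinstance(v, (int, float)) for v in sample_values):
        return "numeric"
    return "string"


print(infer_data_type([True, False]))
print(infer_data_type([1, 2.5]))
print(infer_data_type(["a", "b"]))
```

With the previous ordering (numeric check first), the boolean branch could never be reached for non-empty samples.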


@@ -25,12 +25,19 @@ from fastmcp import Context
from sqlalchemy.orm import subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import get_cached_form_data
from superset.mcp_service.chart.chart_helpers import (
build_applied_dashboard_filters,
ChartNotOnDashboardError,
get_cached_form_data,
)
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
CHART_FORM_DATA_EXCLUDED_FIELD_NAMES,
ChartError,
ChartFiltersInfo,
ChartInfo,
extract_filters_from_form_data,
GetChartInfoRequest,
@@ -89,6 +96,66 @@ FORM_DATA_OVERRIDE_EXCLUDED_FIELD_NAMES = (
)
async def _validate_chart_dataset_access(
result: ChartInfo, ctx: Context
) -> ChartError | None:
"""Validate that the chart's dataset is accessible to the current user.
Returns a ChartError if the dataset is not accessible, otherwise None.
Logs any non-fatal warnings (e.g., virtual dataset warnings) via ctx.
"""
from superset.daos.chart import ChartDAO
if not result.id:
return None
chart = ChartDAO.find_by_id(result.id)
if not chart:
return None
validation_result = validate_chart_dataset(chart, check_access=True)
if not validation_result.is_valid:
await ctx.warning(
"Chart found but dataset is not accessible: %s" % (validation_result.error,)
)
return ChartError(
error=validation_result.error or "Chart's dataset is not accessible",
error_type="DatasetNotAccessible",
)
for warning in validation_result.warnings:
await ctx.warning("Dataset warning: %s" % (warning,))
return None
async def _attach_dashboard_filters(
result: ChartInfo, dashboard_id: int, ctx: Context
) -> ChartError | None:
"""Resolve dashboard-scoped native filters and attach them to result.filters.
Returns a ChartError to surface to the caller on validation / access
failures, or None on success (including the no-filters case).
"""
if not result.id:
return None
with event_logger.log_context(action="mcp.get_chart_info.dashboard_filters"):
try:
dashboard_filters = build_applied_dashboard_filters(dashboard_id, result.id)
except DashboardNotFoundError as exc:
await ctx.warning("Dashboard not found: %s" % (str(exc),))
return ChartError(error=str(exc), error_type="DashboardNotFound")
except ChartNotOnDashboardError as exc:
await ctx.warning("Chart not on dashboard: %s" % (str(exc),))
return ChartError(error=str(exc), error_type="ChartNotOnDashboard")
except SupersetSecurityException as exc:
await ctx.warning("Dashboard not accessible: %s" % (str(exc),))
return ChartError(error=str(exc), error_type="DashboardNotAccessible")
if dashboard_filters:
if result.filters is None:
result.filters = ChartFiltersInfo(dashboard_filters=dashboard_filters)
else:
result.filters.dashboard_filters = dashboard_filters
return None
def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None:
"""Override a ChartInfo's form_data with cached unsaved state."""
from superset.utils import json as utils_json
@@ -178,6 +245,17 @@ async def get_chart_info(
}
```
With dashboard context to resolve applied dashboard-level filters:
```json
{
"identifier": 123,
"dashboard_id": 45
}
```
When dashboard_id is provided, the response's filters.dashboard_filters
lists native filters (with column, operator, and value) that are in scope
for this chart on that dashboard.
Returns chart details including name, type, and URL.
"""
from superset.daos.chart import ChartDAO
@@ -247,23 +325,14 @@ async def get_chart_info(
)
# Validate the chart's dataset is accessible
if result.id:
chart = ChartDAO.find_by_id(result.id)
if chart:
validation_result = validate_chart_dataset(chart, check_access=True)
if not validation_result.is_valid:
await ctx.warning(
"Chart found but dataset is not accessible: %s"
% (validation_result.error,)
)
return ChartError(
error=validation_result.error
or "Chart's dataset is not accessible",
error_type="DatasetNotAccessible",
)
# Log any warnings (e.g., virtual dataset warnings)
for warning in validation_result.warnings:
await ctx.warning("Dataset warning: %s" % (warning,))
dataset_error = await _validate_chart_dataset_access(result, ctx)
if dataset_error is not None:
return dataset_error
if request.dashboard_id:
error = await _attach_dashboard_filters(result, request.dashboard_id, ctx)
if error is not None:
return error
else:
await ctx.warning("Chart retrieval failed: error=%s" % (str(result),))


@@ -104,11 +104,16 @@ async def list_charts(
list_charts(search="revenue", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``slice_name``, ``viz_type``, ``datasource_name``
``slice_name``, ``viz_type``, ``datasource_name``,
``created_by_fk``, ``changed_by_fk``
Sortable columns for ``order_column``:
``id``, ``slice_name``, ``viz_type``, ``description``,
``changed_on``, ``created_on``
To filter by a person, call find_users to resolve the name to a user ID,
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
"""
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
await ctx.info(


@@ -37,10 +37,12 @@ class ColumnMetadata(BaseModel):
"""Metadata for a selectable column."""
name: str = Field(..., description="Column name to use in select_columns")
description: str | None = Field(None, description="Column description")
type: str | None = Field(None, description="Data type (str, int, datetime, etc.)")
description: str | None = Field(default=None, description="Column description")
type: str | None = Field(
default=None, description="Data type (str, int, datetime, etc.)"
)
is_default: bool = Field(
False, description="Whether this column is included by default"
default=False, description="Whether this column is included by default"
)


@@ -67,7 +67,7 @@ from __future__ import annotations
import logging
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING
from typing import Annotated, Any, cast, Dict, List, Literal, TYPE_CHECKING
import humanize
from pydantic import (
@@ -169,16 +169,20 @@ class DashboardFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"dashboard_title",
"published",
"favorite",
"created_by_fk",
"changed_by_fk",
] = Field(
...,
description=(
"Column to filter on. Valid values: 'dashboard_title', 'published', "
"'favorite'. Other column names are not valid filter columns and will "
"cause a validation error."
"Column to filter on. Use "
"get_schema(model_type='dashboard') for available "
"filter columns. To filter by a person, first call find_users to "
"resolve a name to a user ID, then filter by created_by_fk or "
"changed_by_fk with that integer ID."
),
)
opr: ColumnOperatorEnum = Field(
@@ -223,7 +227,10 @@ class ListDashboardsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheContr
"""
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
return parse_json_or_model_list(v, DashboardFilter, "filters")
return cast(
List[DashboardFilter],
parse_json_or_model_list(v, DashboardFilter, "filters"),
)
@field_validator("select_columns", mode="before")
@classmethod
@@ -392,14 +399,14 @@ class DashboardInfo(BaseModel):
# Fields for permalink/filter state support
permalink_key: str | None = Field(
None,
default=None,
description=(
"Permalink key used to retrieve filter state. When present, indicates "
"the filter_state came from a permalink rather than the default dashboard."
),
)
filter_state: Dict[str, Any] | None = Field(
None,
default=None,
description=(
"Filter state from permalink. Contains dataMask (native filter values), "
"activeTabs, anchor, and urlParams. When present, represents the actual "


@@ -98,11 +98,18 @@ async def list_dashboards(
list_dashboards(search="sales", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``dashboard_title``, ``published``, ``favorite``
``dashboard_title``, ``published``, ``favorite``,
``created_by_fk``, ``changed_by_fk``
Sortable columns for ``order_column``:
``id``, ``dashboard_title``, ``slug``, ``published``,
``changed_on``, ``created_on``
To filter by a person (e.g. "dashboards Maxime is working on"), do NOT pass
the name as the search parameter — search matches titles and slugs only.
Instead, call find_users to resolve the name to a user ID, then pass it as
a filter: filters=[{"col": "created_by_fk", "opr": "eq", "value": <id>}]
(or "changed_by_fk" for "last modified by").
"""
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
await ctx.info(


@@ -22,7 +22,7 @@ Pydantic schemas for database-related responses
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal
from typing import Annotated, Any, cast, Dict, List, Literal
import humanize
from pydantic import (
@@ -58,7 +58,7 @@ class DatabaseFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"database_name",
"expose_in_sqllab",
"allow_file_upload",
@@ -242,7 +242,10 @@ class ListDatabasesRequest(CreatedByMeMixin, MetadataCacheControl):
@classmethod
def parse_filters(cls, v: Any) -> List[DatabaseFilter]:
"""Accept both JSON string and list of objects."""
return parse_json_or_model_list(v, DatabaseFilter, "filters")
return cast(
List[DatabaseFilter],
parse_json_or_model_list(v, DatabaseFilter, "filters"),
)
@field_validator("select_columns", mode="before")
@classmethod


@@ -65,17 +65,18 @@ class DatasetFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"table_name",
"schema",
"database_name",
"created_by_fk",
"changed_by_fk",
] = Field(
...,
description=(
"Column to filter on. Valid values: 'table_name', 'schema', "
"'database_name'. Other column names (e.g. 'created_by_fk', 'id') "
"are not valid filter columns and will cause a validation error."
),
description="Column to filter on. Use get_schema(model_type='dataset') for "
"available filter columns. To filter by a person, first call find_users "
"to resolve a name to a user ID, then filter by created_by_fk or "
"changed_by_fk with that integer ID.",
)
opr: ColumnOperatorEnum = Field(
...,
@@ -658,7 +659,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
params = None
columns = [
TableColumnInfo(
column_name=getattr(col, "column_name", None),
column_name=getattr(col, "column_name", None) or "",
verbose_name=getattr(col, "verbose_name", None),
type=getattr(col, "type", None),
is_dttm=getattr(col, "is_dttm", None),
@@ -670,7 +671,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
]
metrics = [
SqlMetricInfo(
metric_name=getattr(metric, "metric_name", None),
metric_name=getattr(metric, "metric_name", None) or "",
verbose_name=getattr(metric, "verbose_name", None),
expression=getattr(metric, "expression", None),
description=getattr(metric, "description", None),


@@ -109,10 +109,15 @@ async def list_datasets(
list_datasets(search="sales", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``table_name``, ``schema``, ``database_name``
``table_name``, ``schema``, ``database_name``,
``created_by_fk``, ``changed_by_fk``
Sortable columns for ``order_column``:
``id``, ``table_name``, ``schema``, ``changed_on``, ``created_on``
To filter by a person, call find_users to resolve the name to a user ID,
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
"""
if ctx is None:
raise RuntimeError("FastMCP context is required for list_datasets")
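Review note: the "filter by a person" guidance in the docstring boils down to a two-step flow: resolve the name to a user ID first, then pass that ID as a filter value. The sketch below only builds the filter payload. `person_filter` is a hypothetical helper, the user ID 42 is made up, and the `find_users` call itself is not shown because its request shape does not appear in this diff.

```python
# Hypothetical helper that builds the filters list described in the docstring.
# Assumes the user ID was already resolved (for example via find_users).
def person_filter(user_id: int, modified: bool = False) -> list[dict[str, object]]:
    # created_by_fk for "created by", changed_by_fk for "last modified by"
    col = "changed_by_fk" if modified else "created_by_fk"
    return [{"col": col, "opr": "eq", "value": user_id}]


# Made-up user ID; in practice it comes from the find_users result.
request_payload = {"filters": person_filter(42), "page": 1}
print(request_payload)
```

The name itself is never passed as `search`; only the integer ID reaches the list tool.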


@@ -65,10 +65,12 @@ async def generate_explore_link(
- "Visualize [data]"
- General data exploration
- When user wants to SEE data visually
- Opening a dataset in Explore without a preconfigured chart (omit config)
IMPORTANT:
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- MUST include chart_type in config (either 'xy' or 'table')
- When config is provided, MUST include chart_type (e.g. 'xy' or 'table')
- Omit config entirely to return a default explore URL for the dataset
Example usage:
```json
@@ -83,6 +85,11 @@ async def generate_explore_link(
}
```
Or with no config to simply open the dataset in Explore:
```json
{"dataset_id": 123}
```
Better UX because:
- Users can interact with chart before saving
- Easy to modify parameters instantly
@@ -93,9 +100,10 @@ async def generate_explore_link(
Returns explore URL for immediate use.
"""
chart_type = request.config.chart_type if request.config else "none"
await ctx.info(
"Generating explore link for dataset_id=%s, chart_type=%s"
% (request.dataset_id, request.config.chart_type)
% (request.dataset_id, chart_type)
)
await ctx.debug(
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
@@ -103,9 +111,6 @@ async def generate_explore_link(
)
try:
# config is already a typed ChartConfig (validated by Pydantic)
config = request.config
await ctx.report_progress(1, 4, "Validating dataset exists")
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
from superset.daos.dataset import DatasetDAO
@@ -157,8 +162,32 @@ async def generate_explore_link(
),
}
# When no config is provided, return a default explore URL that opens
# the dataset in Superset without a preconfigured chart.
if request.config is None:
await ctx.report_progress(4, 4, "URL generation complete")
from superset.mcp_service.utils.url_utils import get_superset_base_url
base_url = get_superset_base_url()
default_url = (
f"{base_url}/explore/?datasource_type=table&datasource_id={dataset.id}"
)
await ctx.info(
"Default explore link generated: dataset_id=%s" % (request.dataset_id,)
)
return {
"url": default_url,
"form_data": {},
"form_data_key": None,
"chart_type_label": None,
"error": None,
}
await ctx.report_progress(2, 4, "Converting configuration to form data")
with event_logger.log_context(action="mcp.generate_explore_link.form_data"):
# config is already a typed ChartConfig (validated by Pydantic)
config = request.config
# Normalize column names to match canonical dataset column names
# This fixes case sensitivity issues (e.g., 'order_date' vs 'OrderDate')
try:
@@ -256,7 +285,7 @@ async def generate_explore_link(
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
% (
request.dataset_id,
request.config.chart_type,
chart_type,
type(e).__name__,
str(e),
)
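Review note: when `config` is omitted, the tool returns a plain Explore URL for the dataset with empty form data. The following sketch reproduces only the URL construction and the response shape seen in the diff. `default_explore_response` is a hypothetical function name and the base URL is a placeholder; the real code obtains it from `get_superset_base_url()`.

```python
from typing import Any


# Hypothetical helper mirroring the no-config branch in the diff.
def default_explore_response(base_url: str, dataset_id: int) -> dict[str, Any]:
    url = f"{base_url}/explore/?datasource_type=table&datasource_id={dataset_id}"
    return {
        "url": url,
        "form_data": {},  # nothing preconfigured
        "form_data_key": None,
        "chart_type_label": None,
        "error": None,
    }


# Placeholder base URL for illustration only.
response = default_explore_response("https://superset.example.com", 123)
print(response["url"])
```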


@@ -34,6 +34,7 @@ from superset.mcp_service.privacy import (
filter_user_directory_columns,
SELF_REFERENCING_FILTER_COLUMNS,
USER_DIRECTORY_FIELDS,
USER_FILTER_FIELDS,
)
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils import _is_uuid
@@ -314,14 +315,6 @@ class ModelListCore(BaseCore, Generic[L]):
has_previous=page > 0,
)
# Build response
def get_keys(obj: BaseModel | dict[str, Any] | Any) -> List[str]:
if hasattr(obj, "model_dump"):
return list(obj.model_dump().keys())
elif isinstance(obj, dict):
return list(obj.keys())
return []
response_kwargs = {
self.list_field_name: item_objs,
"count": len(item_objs),
@@ -595,7 +588,7 @@ class InstanceInfoCore(BaseCore):
return counts
def _calculate_time_based_metrics(
self, base_counts: Dict[str, int]
self, _base_counts: Dict[str, int]
) -> Dict[str, Dict[str, int]]:
"""Calculate time-based metrics for recent activity."""
now = datetime.now(timezone.utc)
@@ -774,7 +767,9 @@ class ModelGetSchemaCore(BaseCore, Generic[S]):
self.default_sort = default_sort
self.default_sort_direction = default_sort_direction
self.exclude_filter_columns = set(exclude_filter_columns or set())
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS)
# Hide user-directory columns from filter discovery, except the small
# set callers may legitimately filter by ID (resolved via find_users).
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS - USER_FILTER_FIELDS)
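Review note: the set difference above keeps a small allow-list of user columns discoverable as filters while hiding the rest. The sketch below shows the effect with invented set contents. The real `USER_DIRECTORY_FIELDS` and `USER_FILTER_FIELDS` are defined in `superset.mcp_service.privacy` and are not shown in this diff, so every value here is an assumption.

```python
# Assumed, illustrative contents; the real sets are not shown in the diff.
USER_DIRECTORY_FIELDS = frozenset(
    {"created_by", "changed_by", "owners", "created_by_fk", "changed_by_fk"}
)
USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})

# Caller-supplied exclusions (also made up for the example).
exclude_filter_columns = {"uuid"}

# Hide user-directory columns except the ones that may be filtered by ID.
exclude_filter_columns.update(USER_DIRECTORY_FIELDS - USER_FILTER_FIELDS)
print(sorted(exclude_filter_columns))
```

Under these assumed values, `created_by_fk` and `changed_by_fk` stay available for filter discovery while the other user columns remain hidden.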
def _get_filter_columns(self) -> Dict[str, List[str]]:
"""Get filterable columns and operators from the DAO."""


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.daos.base import BaseDAO
from superset.models.dynamic_plugins import DynamicPlugin
class DynamicPluginDAO(BaseDAO[DynamicPlugin]):
pass


@@ -0,0 +1,213 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pydantic schemas for dynamic plugin responses.
"""
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_model_list,
)
DEFAULT_PLUGIN_COLUMNS = ["id", "name", "key", "bundle_url"]
ALL_PLUGIN_COLUMNS = [
"id",
"name",
"key",
"bundle_url",
"changed_on",
"created_on",
]
SORTABLE_PLUGIN_COLUMNS = ["id", "name", "key", "changed_on", "created_on"]
class PluginColumnFilter(ColumnOperator):
"""Filter object for plugin listing."""
col: Literal["name", "key"] = Field(..., description="Column to filter on.")
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
value: str | int | float | bool | List[str | int | float | bool] = Field(
..., description="Value to filter by"
)
class PluginInfo(BaseModel):
id: int | None = Field(None, description="Plugin ID")
name: str | None = Field(None, description="Plugin display name")
key: str | None = Field(None, description="Plugin key (corresponds to viz_type)")
bundle_url: str | None = Field(None, description="URL to the plugin bundle")
changed_on: str | datetime | None = Field(
None, description="Last modification timestamp"
)
created_on: str | datetime | None = Field(None, description="Creation timestamp")
model_config = ConfigDict(
from_attributes=True,
ser_json_timedelta="iso8601",
populate_by_name=True,
)
@model_serializer(mode="wrap")
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
data = serializer(self)
if info.context and isinstance(info.context, dict):
select_columns = info.context.get("select_columns")
if select_columns:
requested_fields = set(select_columns)
return {k: v for k, v in data.items() if k in requested_fields}
return data
class PluginList(BaseModel):
plugins: List[PluginInfo]
count: int
total_count: int
page: int
page_size: int
total_pages: int
has_previous: bool
has_next: bool
columns_requested: List[str] = Field(default_factory=list)
columns_loaded: List[str] = Field(default_factory=list)
columns_available: List[str] = Field(default_factory=list)
sortable_columns: List[str] = Field(default_factory=list)
filters_applied: List[PluginColumnFilter] = Field(default_factory=list)
pagination: PaginationInfo | None = None
timestamp: datetime | None = None
model_config = ConfigDict(ser_json_timedelta="iso8601")
class ListPluginsRequest(BaseModel):
"""Request schema for list_plugins."""
filters: Annotated[
List[PluginColumnFilter],
Field(
default_factory=list,
description="List of filter objects (col, opr, value). "
"Cannot be used with search.",
),
]
select_columns: Annotated[
List[str],
Field(
default_factory=list,
description="Columns to include in response. Defaults to common columns.",
),
]
search: Annotated[
str | None,
Field(
default=None,
description="Text search on plugin name or key. "
"Cannot be used with filters.",
),
]
order_column: Annotated[
str | None, Field(default=None, description="Column to order results by")
]
order_direction: Annotated[
Literal["asc", "desc"],
Field(default="desc", description="Sort direction"),
]
page: Annotated[
PositiveInt,
Field(default=1, description="Page number (1-based)"),
]
page_size: Annotated[
int,
Field(
default=DEFAULT_PAGE_SIZE,
gt=0,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
]
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v: Any) -> List[PluginColumnFilter]:
return parse_json_or_model_list(v, PluginColumnFilter, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_columns(cls, v: Any) -> List[str]:
return parse_json_or_list(v, "select_columns")
@model_validator(mode="after")
def validate_search_and_filters(self) -> "ListPluginsRequest":
if self.search and self.filters:
raise ValueError("Cannot use both 'search' and 'filters' simultaneously.")
return self
class PluginError(BaseModel):
error: str = Field(..., description="Error message")
error_type: str = Field(..., description="Type of error")
timestamp: str | datetime | None = Field(None, description="Error timestamp")
model_config = ConfigDict(ser_json_timedelta="iso8601")
@classmethod
def create(cls, error: str, error_type: str) -> "PluginError":
from datetime import timezone
return cls(
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
)
class GetPluginInfoRequest(BaseModel):
"""Request schema for get_plugin_info."""
identifier: Annotated[
int,
Field(description="Plugin ID"),
]
def serialize_plugin_object(plugin: Any) -> PluginInfo | None:
if not plugin:
return None
return PluginInfo(
id=getattr(plugin, "id", None),
name=getattr(plugin, "name", None),
key=getattr(plugin, "key", None),
bundle_url=getattr(plugin, "bundle_url", None),
changed_on=getattr(plugin, "changed_on", None),
created_on=getattr(plugin, "created_on", None),
)
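Review note: `PluginInfo._filter_fields_by_context` trims the serialized output to the columns named in the serialization context. The sketch below restates that rule on plain dictionaries so it can be run without Pydantic. `filter_fields` is a hypothetical stand-in and the sample row is invented.

```python
from typing import Any, Optional


# Plain-dict stand-in for the context-driven field filter in PluginInfo.
def filter_fields(
    data: dict[str, Any], context: Optional[dict[str, Any]]
) -> dict[str, Any]:
    if context and isinstance(context, dict):
        select_columns = context.get("select_columns")
        if select_columns:
            requested = set(select_columns)
            return {k: v for k, v in data.items() if k in requested}
    # No context or no select_columns: return every field unchanged.
    return data


# Invented sample row for illustration.
row = {"id": 1, "name": "Demo", "key": "demo_viz", "bundle_url": "https://example.com/b.js"}
print(filter_fields(row, {"select_columns": ["id", "key"]}))
```

An empty or missing `select_columns` falls through to the full payload, which matches the `if select_columns:` guard in the schema.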


@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .get_plugin_info import get_plugin_info
from .list_plugins import list_plugins
__all__ = [
"list_plugins",
"get_plugin_info",
]


@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get plugin info FastMCP tool.
"""
import logging
from datetime import datetime, timezone

from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations

from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.plugin.schemas import (
    GetPluginInfoRequest,
    PluginError,
    PluginInfo,
    serialize_plugin_object,
)

logger = logging.getLogger(__name__)


@tool(
    tags=["discovery"],
    class_permission_name="DynamicPlugin",
    annotations=ToolAnnotations(
        title="Get plugin info",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_plugin_info(
    request: GetPluginInfoRequest, ctx: Context
) -> PluginInfo | PluginError:
    """Get dynamic plugin details by ID. Requires admin access.

    Returns full plugin configuration including name, key, and bundle URL.

    Example usage:
    ```json
    {"identifier": 1}
    ```
    """
    await ctx.info(
        "Retrieving plugin information: identifier=%s" % (request.identifier,)
    )

    try:
        from superset.mcp_service.plugin.dao import DynamicPluginDAO

        with event_logger.log_context(action="mcp.get_plugin_info.lookup"):
            get_tool = ModelGetInfoCore(
                dao_class=DynamicPluginDAO,
                output_schema=PluginInfo,
                error_schema=PluginError,
                serializer=serialize_plugin_object,
                supports_slug=False,
                logger=logger,
            )
            result = get_tool.run_tool(request.identifier)

        if isinstance(result, PluginInfo):
            await ctx.info(
                "Plugin retrieved: id=%s, name=%s, key=%s"
                % (result.id, result.name, result.key)
            )
        else:
            await ctx.warning(
                "Plugin retrieval failed: error_type=%s, error=%s"
                % (result.error_type, result.error)
            )
        return result

    except Exception as e:
        await ctx.error(
            "Plugin info retrieval failed: identifier=%s, error=%s"
            % (request.identifier, str(e))
        )
        return PluginError(
            error=f"Failed to get plugin info: {str(e)}",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )


@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
List plugins FastMCP tool.
"""
import logging

from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations

from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.plugin.schemas import (
    ALL_PLUGIN_COLUMNS,
    DEFAULT_PLUGIN_COLUMNS,
    ListPluginsRequest,
    PluginColumnFilter,
    PluginError,
    PluginInfo,
    PluginList,
    serialize_plugin_object,
    SORTABLE_PLUGIN_COLUMNS,
)

logger = logging.getLogger(__name__)

_DEFAULT_LIST_PLUGINS_REQUEST = ListPluginsRequest()


@tool(
    tags=["core"],
    class_permission_name="DynamicPlugin",
    annotations=ToolAnnotations(
        title="List plugins",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def list_plugins(
    request: ListPluginsRequest | None = None,
    ctx: Context | None = None,
) -> PluginList | PluginError:
    """List dynamic plugins registered in this Superset instance. Requires admin access.

    Returns plugin metadata including name, key, and bundle URL.

    Sortable columns for order_column: id, name, key, changed_on, created_on
    """
    if ctx is None:
        raise RuntimeError("FastMCP context is required for list_plugins")

    request = request or _DEFAULT_LIST_PLUGINS_REQUEST.model_copy(deep=True)

    await ctx.info(
        "Listing plugins: page=%s, page_size=%s, search=%s"
        % (request.page, request.page_size, request.search)
    )

    try:
        from superset.mcp_service.plugin.dao import DynamicPluginDAO

        def _serialize(obj: object, cols: list[str]) -> PluginInfo | None:
            return serialize_plugin_object(obj)

        list_tool = ModelListCore(
            dao_class=DynamicPluginDAO,
            output_schema=PluginInfo,
            item_serializer=_serialize,
            filter_type=PluginColumnFilter,
            default_columns=DEFAULT_PLUGIN_COLUMNS,
            search_columns=["name", "key"],
            list_field_name="plugins",
            output_list_schema=PluginList,
            all_columns=ALL_PLUGIN_COLUMNS,
            sortable_columns=SORTABLE_PLUGIN_COLUMNS,
            logger=logger,
        )

        with event_logger.log_context(action="mcp.list_plugins.query"):
            result = list_tool.run_tool(
                filters=request.filters,
                search=request.search,
                select_columns=request.select_columns,
                order_column=request.order_column,
                order_direction=request.order_direction,
                page=max(request.page - 1, 0),
                page_size=request.page_size,
            )

        await ctx.info(
            "Plugins listed: count=%s, total_count=%s"
            % (len(result.plugins), result.total_count)
        )

        columns_to_filter = result.columns_requested
        with event_logger.log_context(action="mcp.list_plugins.serialization"):
            return result.model_dump(
                mode="json",
                context={"select_columns": columns_to_filter},
            )

    except Exception as e:
        await ctx.error(
            "Plugin listing failed: error=%s, error_type=%s"
            % (str(e), type(e).__name__)
        )
        raise


@@ -44,13 +44,20 @@ USER_DIRECTORY_FIELDS = frozenset(
     }
 )

+# User-directory columns that may be used as filter values (an integer user ID).
+# These remain stripped from select_columns, sort, search, and tool responses
+# (so the directory itself is never exposed), but list tools may filter rows by
+# them when the caller already has an ID, typically resolved via find_users.
+USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})
+
 # Internal DAO filter column names generated server-side when translating the
 # created_by_me / owned_by_me boolean flags (see mcp_core._prepend_self_lookup_filters).
 # These columns are never exposed to LLM callers; they are excluded from the
 # filters_applied response field to avoid leaking internal implementation details.
-SELF_REFERENCING_FILTER_COLUMNS = frozenset(
-    {"created_by_fk", "owner", "created_by_fk_or_owner"}
-)
+# Note: ``created_by_fk`` is intentionally excluded: it is also a publicly
+# advertised filter column (see USER_FILTER_FIELDS) so callers can filter by a
+# user ID resolved via find_users.
+SELF_REFERENCING_FILTER_COLUMNS = frozenset({"owner", "created_by_fk_or_owner"})

 DATA_MODEL_METADATA_ACCESS_ATTR = "_requires_data_model_metadata_access"
 DATA_MODEL_METADATA_ERROR_TYPE = "DataModelMetadataRestricted"
@@ -133,7 +140,7 @@ def user_can_view_data_model_metadata() -> bool:
 def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]:
-    """Remove fields that expose users, roles, owners, or access metadata."""
+    """Remove fields that expose users, owners, or access metadata."""
     return {
         key: value for key, value in data.items() if key not in USER_DIRECTORY_FIELDS
     }
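
The hunks above separate two concerns: which fields are stripped from responses, and which of those fields may still be used as filter columns. A minimal sketch of that split follows. The contents of `USER_DIRECTORY_FIELDS` are elided in the diff, so the set used here is a hypothetical stand-in, and `is_allowed_filter_column` is an illustrative helper that does not appear in the change.

```python
from typing import Any

# Hypothetical contents: the real USER_DIRECTORY_FIELDS is elided in the hunk.
USER_DIRECTORY_FIELDS = frozenset(
    {"owners", "created_by", "created_by_fk", "changed_by_fk"}
)
# Taken from the diff: directory columns that may still carry a filter value.
USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})


def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]:
    """Drop directory fields from a response payload."""
    return {k: v for k, v in data.items() if k not in USER_DIRECTORY_FIELDS}


def is_allowed_filter_column(col: str, public_columns: frozenset[str]) -> bool:
    """Hypothetical check: public columns plus the sanctioned user-ID columns."""
    return col in public_columns or col in USER_FILTER_FIELDS
```

Under these assumptions, `created_by_fk` can narrow a listing by user ID while never appearing in the serialized rows.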


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,255 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pydantic schemas for row level security filter responses.
"""
from __future__ import annotations

from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    field_validator,
    model_serializer,
    model_validator,
    PositiveInt,
)

from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils.schema_utils import (
    parse_json_or_list,
    parse_json_or_model_list,
)

DEFAULT_RLS_COLUMNS = ["id", "name", "filter_type", "clause"]
ALL_RLS_COLUMNS = [
    "id",
    "name",
    "filter_type",
    "tables",
    "roles",
    "clause",
    "group_key",
    "changed_on",
]
SORTABLE_RLS_COLUMNS = ["id", "name", "filter_type", "changed_on"]

class RlsColumnFilter(ColumnOperator):
    """Filter object for RLS filter listing."""

    col: Literal["name", "filter_type"] = Field(
        ...,
        description="Column to filter on.",
    )
    opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by"
    )


class RlsTableRef(BaseModel):
    id: int | None = Field(None, description="Table ID")
    table_name: str | None = Field(None, description="Table name")
    model_config = ConfigDict(from_attributes=True)


class RlsRoleRef(BaseModel):
    id: int | None = Field(None, description="Role ID")
    name: str | None = Field(None, description="Role name")
    model_config = ConfigDict(from_attributes=True)


class RlsFilterInfo(BaseModel):
    id: int | None = Field(None, description="RLS filter ID")
    name: str | None = Field(None, description="RLS filter name")
    filter_type: str | None = Field(None, description="Filter type: Regular or Base")
    tables: List[RlsTableRef] | None = Field(
        None, description="Tables this filter applies to"
    )
    roles: List[RlsRoleRef] | None = Field(
        None, description="Roles this filter applies to"
    )
    clause: str | None = Field(None, description="SQL WHERE clause")
    group_key: str | None = Field(
        None, description="Group key for Base filter grouping"
    )
    changed_on: str | datetime | None = Field(
        None, description="Last modification timestamp"
    )
    model_config = ConfigDict(
        from_attributes=True,
        ser_json_timedelta="iso8601",
        populate_by_name=True,
    )

    @model_serializer(mode="wrap")
    def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
        data = serializer(self)
        if info.context and isinstance(info.context, dict):
            select_columns = info.context.get("select_columns")
            if select_columns:
                requested_fields = set(select_columns)
                return {k: v for k, v in data.items() if k in requested_fields}
        return data


class RlsFilterList(BaseModel):
    rls_filters: List[RlsFilterInfo]
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    columns_requested: List[str] = Field(default_factory=list)
    columns_loaded: List[str] = Field(default_factory=list)
    columns_available: List[str] = Field(default_factory=list)
    sortable_columns: List[str] = Field(default_factory=list)
    filters_applied: List[RlsColumnFilter] = Field(default_factory=list)
    pagination: PaginationInfo | None = None
    timestamp: datetime | None = None
    model_config = ConfigDict(ser_json_timedelta="iso8601")

class ListRlsFiltersRequest(BaseModel):
    """Request schema for list_rls_filters."""

    filters: Annotated[
        List[RlsColumnFilter],
        Field(
            default_factory=list,
            description="List of filter objects (col, opr, value). "
            "Cannot be used with search.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=list,
            description="Columns to include in response. Defaults to common columns.",
        ),
    ]
    search: Annotated[
        str | None,
        Field(
            default=None,
            description="Text search on filter name. Cannot be used with filters.",
        ),
    ]
    order_column: Annotated[
        str | None, Field(default=None, description="Column to order results by")
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(default="desc", description="Sort direction"),
    ]
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number (1-based)"),
    ]
    page_size: Annotated[
        int,
        Field(
            default=DEFAULT_PAGE_SIZE,
            gt=0,
            le=MAX_PAGE_SIZE,
            description=f"Items per page (max {MAX_PAGE_SIZE})",
        ),
    ]

    @field_validator("filters", mode="before")
    @classmethod
    def parse_filters(cls, v: Any) -> List[RlsColumnFilter]:
        return parse_json_or_model_list(v, RlsColumnFilter, "filters")

    @field_validator("select_columns", mode="before")
    @classmethod
    def parse_columns(cls, v: Any) -> List[str]:
        return parse_json_or_list(v, "select_columns")

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListRlsFiltersRequest":
        if self.search and self.filters:
            raise ValueError("Cannot use both 'search' and 'filters' simultaneously.")
        return self


class RlsFilterError(BaseModel):
    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    timestamp: str | datetime | None = Field(None, description="Error timestamp")
    model_config = ConfigDict(ser_json_timedelta="iso8601")

    @classmethod
    def create(cls, error: str, error_type: str) -> "RlsFilterError":
        from datetime import timezone

        return cls(
            error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
        )


class GetRlsFilterInfoRequest(BaseModel):
    """Request schema for get_rls_filter_info."""

    identifier: Annotated[
        int,
        Field(description="RLS filter ID"),
    ]

def serialize_rls_filter_object(rls_filter: Any) -> RlsFilterInfo | None:
    if not rls_filter:
        return None

    tables = [
        RlsTableRef(
            id=getattr(t, "id", None),
            table_name=getattr(t, "table_name", None),
        )
        for t in (getattr(rls_filter, "tables", None) or [])
    ]
    roles = [
        RlsRoleRef(
            id=getattr(r, "id", None),
            name=getattr(r, "name", None),
        )
        for r in (getattr(rls_filter, "roles", None) or [])
    ]

    return RlsFilterInfo(
        id=getattr(rls_filter, "id", None),
        name=getattr(rls_filter, "name", None),
        filter_type=getattr(rls_filter, "filter_type", None),
        tables=tables,
        roles=roles,
        clause=getattr(rls_filter, "clause", None),
        group_key=getattr(rls_filter, "group_key", None),
        changed_on=getattr(rls_filter, "changed_on", None),
    )


@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from .get_rls_filter_info import get_rls_filter_info
from .list_rls_filters import list_rls_filters

__all__ = [
    "list_rls_filters",
    "get_rls_filter_info",
]


@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get RLS filter info FastMCP tool.
"""
import logging
from datetime import datetime, timezone

from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations

from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.rls.schemas import (
    GetRlsFilterInfoRequest,
    RlsFilterError,
    RlsFilterInfo,
    serialize_rls_filter_object,
)

logger = logging.getLogger(__name__)


@tool(
    tags=["discovery"],
    class_permission_name="Row Level Security",
    annotations=ToolAnnotations(
        title="Get RLS filter info",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_rls_filter_info(
    request: GetRlsFilterInfoRequest, ctx: Context
) -> RlsFilterInfo | RlsFilterError:
    """Get row level security filter details by ID. Requires admin access.

    Returns full RLS filter configuration including name, type, tables, roles,
    and clause.

    Example usage:
    ```json
    {"identifier": 1}
    ```
    """
    await ctx.info(
        "Retrieving RLS filter information: identifier=%s" % (request.identifier,)
    )

    try:
        from superset.daos.security import RLSDAO

        with event_logger.log_context(action="mcp.get_rls_filter_info.lookup"):
            get_tool = ModelGetInfoCore(
                dao_class=RLSDAO,
                output_schema=RlsFilterInfo,
                error_schema=RlsFilterError,
                serializer=serialize_rls_filter_object,
                supports_slug=False,
                logger=logger,
            )
            result = get_tool.run_tool(request.identifier)

        if isinstance(result, RlsFilterInfo):
            await ctx.info(
                "RLS filter retrieved: id=%s, name=%s" % (result.id, result.name)
            )
        else:
            await ctx.warning(
                "RLS filter retrieval failed: error_type=%s, error=%s"
                % (result.error_type, result.error)
            )
        return result

    except Exception as e:
        await ctx.error(
            "RLS filter info retrieval failed: identifier=%s, error=%s"
            % (request.identifier, str(e))
        )
        return RlsFilterError(
            error=f"Failed to get RLS filter info: {str(e)}",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )


@@ -0,0 +1,149 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
List RLS filters FastMCP tool.
"""
import logging

from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations

from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.privacy import USER_DIRECTORY_FIELDS
from superset.mcp_service.rls.schemas import (
    ALL_RLS_COLUMNS,
    DEFAULT_RLS_COLUMNS,
    ListRlsFiltersRequest,
    RlsColumnFilter,
    RlsFilterError,
    RlsFilterInfo,
    RlsFilterList,
    serialize_rls_filter_object,
    SORTABLE_RLS_COLUMNS,
)

logger = logging.getLogger(__name__)

_DEFAULT_LIST_RLS_FILTERS_REQUEST = ListRlsFiltersRequest()


@tool(
    tags=["core"],
    class_permission_name="Row Level Security",
    annotations=ToolAnnotations(
        title="List RLS filters",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def list_rls_filters(
    request: ListRlsFiltersRequest | None = None,
    ctx: Context | None = None,
) -> RlsFilterList | RlsFilterError:
    """List row level security filters. Requires admin access.

    Returns RLS filter metadata including name, filter type, tables, roles, and clause.

    Sortable columns for order_column: id, name, filter_type, changed_on
    """
    if ctx is None:
        raise RuntimeError("FastMCP context is required for list_rls_filters")

    request = request or _DEFAULT_LIST_RLS_FILTERS_REQUEST.model_copy(deep=True)

    await ctx.info(
        "Listing RLS filters: page=%s, page_size=%s, search=%s"
        % (request.page, request.page_size, request.search)
    )

    try:
        from superset.daos.security import RLSDAO

        def _serialize_rls_filter(obj: object, cols: list[str]) -> RlsFilterInfo | None:
            return serialize_rls_filter_object(obj)

        list_tool = ModelListCore(
            dao_class=RLSDAO,
            output_schema=RlsFilterInfo,
            item_serializer=_serialize_rls_filter,
            filter_type=RlsColumnFilter,
            default_columns=DEFAULT_RLS_COLUMNS,
            search_columns=["name"],
            list_field_name="rls_filters",
            output_list_schema=RlsFilterList,
            all_columns=ALL_RLS_COLUMNS,
            sortable_columns=SORTABLE_RLS_COLUMNS,
            logger=logger,
        )

        # RLS 'roles' is valid filter data but lives in USER_DIRECTORY_FIELDS,
        # so ModelListCore would raise ValueError for a column list that reduces
        # to empty after privacy filtering (e.g. select_columns=["roles"]).
        # Strip directory-field columns here; the bypass below restores them in
        # the final serialized output from ALL_RLS_COLUMNS.
        run_tool_columns = None
        if request.select_columns:
            non_directory = [
                c for c in request.select_columns if c not in USER_DIRECTORY_FIELDS
            ]
            run_tool_columns = non_directory if non_directory else None

        with event_logger.log_context(action="mcp.list_rls_filters.query"):
            result = list_tool.run_tool(
                filters=request.filters,
                search=request.search,
                select_columns=run_tool_columns,
                order_column=request.order_column,
                order_direction=request.order_direction,
                page=max(request.page - 1, 0),
                page_size=request.page_size,
            )

        await ctx.info(
            "RLS filters listed: count=%s, total_count=%s"
            % (len(result.rls_filters), result.total_count)
        )

        # Build column selection using ALL_RLS_COLUMNS as the source of truth,
        # bypassing the USER_DIRECTORY_FIELDS privacy filter applied by
        # ModelListCore. 'roles' in an RLS filter is which roles the filter
        # applies to (core filter data), not user-directory metadata (like
        # dashboard.roles, which exposes who has access to the resource).
        if request.select_columns:
            columns_to_filter = [
                c for c in request.select_columns if c in ALL_RLS_COLUMNS
            ]
            if not columns_to_filter:
                columns_to_filter = list(DEFAULT_RLS_COLUMNS)
        else:
            columns_to_filter = list(DEFAULT_RLS_COLUMNS)

        with event_logger.log_context(action="mcp.list_rls_filters.serialization"):
            return result.model_dump(
                mode="json",
                context={"select_columns": columns_to_filter},
            )

    except Exception as e:
        await ctx.error(
            "RLS filter listing failed: error=%s, error_type=%s"
            % (str(e), type(e).__name__)
        )
        raise
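
The two-stage column handling in `list_rls_filters` can be isolated as a pure function, which may make the `["roles"]`-only edge case from the commit message easier to reason about. This is a sketch under stated assumptions: `split_select_columns` is a hypothetical name that does not appear in the diff, and the directory set passed in the usage below is an assumption, since the real `USER_DIRECTORY_FIELDS` is not shown.

```python
from typing import Optional, Sequence


def split_select_columns(
    select_columns: Sequence[str],
    directory_fields: frozenset[str],
    all_columns: Sequence[str],
    default_columns: Sequence[str],
) -> tuple[Optional[list[str]], list[str]]:
    """Return (columns for the query stage, columns for the serialized output)."""
    # Stage 1: never hand the list core a selection it would filter to empty.
    non_directory = [c for c in select_columns if c not in directory_fields]
    run_tool_columns = non_directory or None  # None means "use defaults"

    # Stage 2: the output honours the caller's request against all known columns.
    output_columns = [c for c in select_columns if c in all_columns]
    return run_tool_columns, output_columns or list(default_columns)
```

With `directory_fields=frozenset({"roles"})`, a request for `["roles"]` yields `None` for the query stage and `["roles"]` for the output stage, which matches the behaviour the regression test in the commit message describes.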


@@ -25,9 +25,11 @@ system-level info.
 from __future__ import annotations

 from datetime import datetime
-from typing import Any
+from typing import Annotated, Any, List

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE


 class HealthCheckResponse(BaseModel):
@@ -170,6 +172,84 @@ def serialize_user_object(user: Any) -> UserInfo | None:
    )


class FindUsersRequest(BaseModel):
    """Request schema for find_users tool.

    Resolves a person's name (or partial name, username, or email) to user IDs
    so they can be passed to listing tools as filter values for created_by_fk
    or changed_by_fk. This is the only sanctioned path for "show me what
    <person> is working on" queries.
    """

    model_config = ConfigDict(extra="forbid")

    query: Annotated[
        str,
        Field(
            min_length=1,
            max_length=200,
            description=(
                "Substring to match (case-insensitive) against username, "
                "first_name, last_name, and email. Required and non-empty: "
                "this tool does not enumerate the full user directory."
            ),
        ),
    ]
    page_size: Annotated[
        int,
        Field(
            default=DEFAULT_PAGE_SIZE,
            gt=0,
            le=MAX_PAGE_SIZE,
            description=f"Maximum number of matches to return (max {MAX_PAGE_SIZE}).",
        ),
    ]

    @field_validator("query")
    @classmethod
    def _reject_blank_query(cls, value: str) -> str:
        # min_length=1 alone admits whitespace-only strings, which strip to "" and
        # produce a "%%" LIKE pattern that matches every user. Strip and require
        # at least one non-space character.
        stripped = value.strip()
        if not stripped:
            raise ValueError("query must contain at least one non-whitespace character")
        return stripped


class UserMatch(BaseModel):
    """Minimal user projection returned by find_users.

    Intentionally narrower than UserInfo: only the fields needed to disambiguate
    matches and pass an id to created_by_fk / changed_by_fk filters. Email,
    active flag, and roles are deliberately excluded to limit identity
    exposure through this directory-resolution path.
    """

    id: int | None = None
    username: str | None = None
    first_name: str | None = None
    last_name: str | None = None


class FindUsersResponse(BaseModel):
    """Response schema for find_users tool."""

    users: List[UserMatch] = Field(
        default_factory=list,
        description=(
            "Matching users. Pass user.id as the value for created_by_fk or "
            "changed_by_fk filters on list_dashboards, list_charts, and "
            "list_datasets."
        ),
    )
    count: int = Field(..., description="Number of users returned in this response.")
    truncated: bool = Field(
        default=False,
        description="True when the query matched more rows than page_size allows.",
    )


class TagInfo(BaseModel):
    id: int | None = None
    name: str | None = None
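
The comment on `_reject_blank_query` explains why `min_length=1` is not sufficient on its own. A dependency-free sketch of the same check, paired with the pattern construction used later by the tool, may help illustrate the failure mode. `normalize_query` and `like_pattern` are hypothetical names chosen for this example; only the strip-and-reject logic and the `%...%` pattern shape come from the diff.

```python
def normalize_query(value: str) -> str:
    """Strip the query and reject input that is empty after stripping."""
    stripped = value.strip()
    if not stripped:
        raise ValueError("query must contain at least one non-whitespace character")
    return stripped


def like_pattern(query: str) -> str:
    """Wrap an already-normalized query in substring wildcards."""
    return f"%{query}%"


# Without normalization, a whitespace-only query collapses to a match-all pattern.
unguarded = like_pattern("   ".strip())

# With normalization, surrounding whitespace is dropped before wrapping.
guarded = like_pattern(normalize_query("  ada "))
```

Note that a query consisting only of `%` or `_` still passes this check, because those characters are not whitespace. Escaping LIKE wildcards would be a separate hardening step that the diff does not include.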


@@ -17,12 +17,14 @@
"""System tools for MCP service."""
from .find_users import find_users
from .generate_bug_report import generate_bug_report
from .get_instance_info import get_instance_info
from .get_schema import get_schema
from .health_check import health_check

__all__ = [
    "find_users",
    "generate_bug_report",
    "health_check",
    "get_instance_info",


@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""find_users MCP tool: resolve a person's name to user IDs for filtering."""

import logging

from fastmcp import Context
from sqlalchemy import or_
from superset_core.mcp.decorators import tool, ToolAnnotations

from superset.extensions import db, event_logger, security_manager
from superset.mcp_service.system.schemas import (
    FindUsersRequest,
    FindUsersResponse,
    UserMatch,
)

logger = logging.getLogger(__name__)


@tool(
    tags=["core"],
    annotations=ToolAnnotations(
        title="Find users",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def find_users(request: FindUsersRequest, ctx: Context) -> FindUsersResponse:
    """Resolve a person's name to user IDs so they can be used as filter values.

    Use this when the caller asks "show me <person>'s dashboards/charts/datasets"
    or "what is <person> working on". Take the matching user.id and pass it as
    the value for a created_by_fk or changed_by_fk filter on list_dashboards,
    list_charts, or list_datasets.

    Matches case-insensitively against username, first_name, last_name, and
    email. The query is required and non-empty; this tool does not enumerate
    the full user directory.

    Privacy: returning a user's identity here is sanctioned only for resolving
    filter values. Do not use the response to answer "who owns X", "who can
    access X", or any access-list question; those remain off-limits per the
    server instructions.
    """
    await ctx.info(
        "Resolving user query: query=%s, page_size=%s"
        % (request.query, request.page_size)
    )

    user_model = security_manager.user_model
    needle = f"%{request.query.strip()}%"

    with event_logger.log_context(action="mcp.find_users.query"):
        query = (
            db.session.query(user_model)
            .filter(
                or_(
                    user_model.username.ilike(needle),
                    user_model.first_name.ilike(needle),
                    user_model.last_name.ilike(needle),
                    user_model.email.ilike(needle),
                )
            )
            .order_by(user_model.username.asc())
        )
        # Fetch one extra row to detect truncation without a separate count query.
        rows = query.limit(request.page_size + 1).all()

    truncated = len(rows) > request.page_size
    rows = rows[: request.page_size]

    users: list[UserMatch] = [
        UserMatch(
            id=getattr(row, "id", None),
            username=getattr(row, "username", None),
            first_name=getattr(row, "first_name", None),
            last_name=getattr(row, "last_name", None),
        )
        for row in rows
    ]

    await ctx.info(
        "Resolved user query: matches=%s, truncated=%s" % (len(users), truncated)
    )
    return FindUsersResponse(users=users, count=len(users), truncated=truncated)
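
The "fetch one extra row" trick used by `find_users` is independent of SQLAlchemy, so it can be sketched against any iterable. `take_page` is a hypothetical helper name introduced for this example; in the tool itself the extra row comes from `query.limit(request.page_size + 1)`.

```python
from itertools import islice
from typing import Iterable, TypeVar

T = TypeVar("T")


def take_page(rows: Iterable[T], page_size: int) -> tuple[list[T], bool]:
    """Return at most page_size rows plus a flag saying whether more existed."""
    # Ask for one row beyond the page; its presence proves there is more data.
    fetched = list(islice(rows, page_size + 1))
    truncated = len(fetched) > page_size
    return fetched[:page_size], truncated
```

The trade-off is that the caller learns only that more rows exist, not how many, which is why the response carries a boolean `truncated` rather than a total count.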


@@ -24,9 +24,10 @@ single dataframe.
"""

-from datetime import datetime, timedelta
-from time import time
+from datetime import date, datetime, time, timedelta, tzinfo
+from time import time as current_time
 from typing import Any, cast, Sequence, TypeGuard
+from zoneinfo import ZoneInfo

 import isodate
 import numpy as np
@@ -103,7 +104,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
         raise ValueError("QueryObject must have a datasource defined.")

     # Track execution time
-    start_time = time()
+    start_time = current_time()

     semantic_view = query_object.datasource.implementation
     dispatcher = (
@@ -127,7 +128,7 @@ def get_results(query_object: QueryObject) -> QueryResult:

     # If no time offsets, return the main result as-is
     if not query_object.time_offsets or len(queries) <= 1:
-        duration = timedelta(seconds=time() - start_time)
+        duration = timedelta(seconds=current_time() - start_time)
         return map_semantic_result_to_query_result(
             main_result,
             query_object,
@@ -197,7 +198,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
         requests=all_requests,
         results=pa.Table.from_pandas(main_df),
     )
-    duration = timedelta(seconds=time() - start_time)
+    duration = timedelta(seconds=current_time() - start_time)
     return map_semantic_result_to_query_result(
         semantic_result,
         query_object,
@@ -541,21 +542,29 @@ def _convert_query_object_filter(
     if operator_str == FilterOperator.TEMPORAL_RANGE.value:
         if not isinstance(value, str) or value == NO_TIME_RANGE:
             return None
-        start, end = value.split(" : ")
-        return {
-            Filter(
-                type=PredicateType.WHERE,
-                column=dimension,
-                operator=Operator.GREATER_THAN_OR_EQUAL,
-                value=start,
-            ),
-            Filter(
-                type=PredicateType.WHERE,
-                column=dimension,
-                operator=Operator.LESS_THAN,
-                value=end,
-            ),
-        }
+        start, end = (side.strip() for side in value.split(" : "))
+        filters: set[Filter] = set()
+        if start:
+            filters.add(
+                Filter(
+                    type=PredicateType.WHERE,
+                    column=dimension,
+                    operator=Operator.GREATER_THAN_OR_EQUAL,
+                    value=_coerce_scalar_filter_value(start, dimension),
+                )
+            )
+        if end:
+            filters.add(
+                Filter(
+                    type=PredicateType.WHERE,
+                    column=dimension,
+                    operator=Operator.LESS_THAN,
+                    value=_coerce_scalar_filter_value(end, dimension),
+                )
+            )
+        return filters or None
+
+    value = _coerce_filter_value(value, dimension)

     # Map QueryObject operators to semantic layer operators
     operator_mapping = {
@@ -588,6 +597,149 @@ def _convert_query_object_filter(
}


def _coerce_filter_value(
    value: FilterValues | frozenset[FilterValues],
    dimension: Dimension,
) -> FilterValues | frozenset[FilterValues]:
    if isinstance(value, frozenset):
        return frozenset(_coerce_scalar_filter_value(v, dimension) for v in value)
    return _coerce_scalar_filter_value(value, dimension)


def _timestamp_target_tz(dtype: pa.DataType) -> tzinfo | None:
    tz_name = getattr(dtype, "tz", None)
    return ZoneInfo(tz_name) if tz_name else None


def _align_tz(dt: datetime, target_tz: tzinfo | None) -> datetime:
    if target_tz is None:
        return dt
    if dt.tzinfo is None:
        return dt.replace(tzinfo=target_tz)
    return dt.astimezone(target_tz)
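
The three branches of `_align_tz` behave quite differently, so a standalone sketch may be useful. It copies the helper's logic under the name `align_tz` and uses fixed-offset `timezone` objects instead of `ZoneInfo` so that it runs without a tz database; that substitution is an assumption made for this example only.

```python
from datetime import datetime, timedelta, timezone, tzinfo
from typing import Optional


def align_tz(dt: datetime, target_tz: Optional[tzinfo]) -> datetime:
    """Mirror of the helper above: leave, attach, or convert the timezone."""
    if target_tz is None:
        return dt  # tz-naive column: pass the value through unchanged
    if dt.tzinfo is None:
        return dt.replace(tzinfo=target_tz)  # attach: wall-clock time is kept
    return dt.astimezone(target_tz)  # convert: the instant is kept


naive = datetime(2026, 1, 1, 12, 0)
plus_two = timezone(timedelta(hours=2))
aware = datetime(2026, 1, 1, 12, 0, tzinfo=plus_two)

attached = align_tz(naive, timezone.utc)   # 12:00 wall clock, now marked UTC
converted = align_tz(aware, timezone.utc)  # same instant as 12:00+02:00, in UTC
```

The design choice worth noting is that a naive value is interpreted as already being in the column's timezone, whereas an aware value is converted so that it still denotes the same instant.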


def _coerce_scalar_filter_value(  # noqa: C901 - type dispatch, complexity is inherent
    value: FilterValues, dimension: Dimension
) -> FilterValues:
    if value is None:
        return None

    dtype = dimension.type

    if pa.types.is_boolean(dtype):
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)) and value in (0, 1):
            return bool(value)
        if isinstance(value, str):
            parsed = value.strip().lower()
            if parsed in {"true", "t", "1", "yes", "y", "on"}:
                return True
            if parsed in {"false", "f", "0", "no", "n", "off"}:
                return False
        raise ValueError(
            f"Invalid boolean value {value!r} for filter column {dimension.name}"
        )

    if pa.types.is_integer(dtype):
        if isinstance(value, bool):
            raise ValueError(
                f"Invalid integer value {value!r} for filter column {dimension.name}"
            )
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError as ex:
                raise ValueError(
                    f"Invalid integer value {value!r} for filter column "
                    f"{dimension.name}"
                ) from ex
        raise ValueError(
            f"Invalid integer value {value!r} for filter column {dimension.name}"
        )

    if pa.types.is_floating(dtype) or pa.types.is_decimal(dtype):
        # Decimal dimensions are coerced through ``float`` because ``FilterValues``
        # does not include ``Decimal``. That is lossless for the common case
        # (≤ ~15 significant digits) and matches how downstream semantic-view
        # implementations consume numeric filters; high-precision decimals would
        # need a wider ``FilterValues`` union and propagation through the cache's
        # comparability checks.
        if isinstance(value, bool):
            raise ValueError(
                f"Invalid numeric value {value!r} for filter column {dimension.name}"
            )
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            try:
                return float(value.strip())
            except ValueError as ex:
                raise ValueError(
                    f"Invalid numeric value {value!r} for filter column "
                    f"{dimension.name}"
                ) from ex
        raise ValueError(
            f"Invalid numeric value {value!r} for filter column {dimension.name}"
        )

    if pa.types.is_date(dtype):
        if isinstance(value, datetime):
            return value.date()
        if isinstance(value, date):
            return value
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value.strip()).date()
            except ValueError as ex:
                raise ValueError(
                    f"Invalid date value {value!r} for filter column {dimension.name}"
                ) from ex
        raise ValueError(
            f"Invalid date value {value!r} for filter column {dimension.name}"
        )

    if pa.types.is_timestamp(dtype):
        target_tz = _timestamp_target_tz(dtype)
        if isinstance(value, datetime):
            return _align_tz(value, target_tz)
        if isinstance(value, date):
            return _align_tz(datetime.combine(value, time.min), target_tz)
        if isinstance(value, str):
            normalized = value.strip().replace("Z", "+00:00")
            try:
                return _align_tz(datetime.fromisoformat(normalized), target_tz)
            except ValueError as ex:
                raise ValueError(
                    f"Invalid timestamp value {value!r} for filter column "
                    f"{dimension.name}"
                ) from ex
raise ValueError(
f"Invalid timestamp value {value!r} for filter column {dimension.name}"
)
if pa.types.is_time(dtype):
if isinstance(value, time):
return value
if isinstance(value, str):
try:
return time.fromisoformat(value.strip())
except ValueError as ex:
raise ValueError(
f"Invalid time value {value!r} for filter column {dimension.name}"
) from ex
raise ValueError(
f"Invalid time value {value!r} for filter column {dimension.name}"
)
return value
def _get_order_from_query_object(
query_object: ValidatedQueryObject,
all_metrics: dict[str, Metric],

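The timezone rule in `_align_tz` above can be checked in isolation. The following is a minimal standalone sketch that copies the three branches shown in the diff. The fixed UTC+02:00 offset and the sample datetimes are assumptions chosen for illustration; the real code derives the target zone from the Arrow dtype through `ZoneInfo`.

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone, tzinfo


def align_tz(dt: datetime, target_tz: tzinfo | None) -> datetime:
    """Standalone copy of the alignment rule shown in the diff."""
    if target_tz is None:
        # Column has no timezone: leave the value untouched.
        return dt
    if dt.tzinfo is None:
        # Naive value: interpret the wall-clock time in the column's zone.
        return dt.replace(tzinfo=target_tz)
    # Aware value: convert the instant into the column's zone.
    return dt.astimezone(target_tz)


# Illustrative fixed offset standing in for a named zone.
PLUS_TWO = timezone(timedelta(hours=2))

naive = datetime(2026, 1, 1, 12, 0)
aware_utc = datetime(2026, 1, 1, 12, 0, tzinfo=timezone.utc)

stamped = align_tz(naive, PLUS_TWO)        # wall clock kept, zone attached
converted = align_tz(aware_utc, PLUS_TWO)  # same instant, shifted wall clock
untouched = align_tz(naive, None)          # no target zone, value unchanged
```

The difference between the two non-trivial branches matters for filters: a naive value keeps its wall-clock reading, while an aware value keeps its instant.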

@@ -324,7 +324,10 @@ class WebDriverPlaywright(WebDriverProxy):
'document.querySelectorAll(".chart-container").length'
)
dashboard_height = page.evaluate(
f'document.querySelector(".{element_name}").scrollHeight || 0'
f"""() => {{
const target = document.querySelector(\".{element_name}\");
return target ? target.scrollHeight : 0;
}}"""
)
chart_threshold = app.config.get(
"SCREENSHOT_TILED_CHART_THRESHOLD", 20
@@ -336,6 +339,14 @@ class WebDriverPlaywright(WebDriverProxy):
"SCREENSHOT_TILED_VIEWPORT_HEIGHT", viewport_height
)
if dashboard_height == 0:
logger.warning(
"Could not determine dashboard height for element %s "
"at url %s; falling back to standard screenshot behavior",
element_name,
url,
)
# Use tiled screenshots for large dashboards
use_tiled = (
chart_count >= chart_threshold

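The guarded height lookup above replaces an expression that would throw in the browser when `querySelector` returns `null`. As a rough sketch, the script string can be built and inspected without a browser. `build_height_script` and the class name `standalone` are hypothetical names used only for this illustration; they are not part of the patch.

```python
def build_height_script(element_name: str) -> str:
    """Return a JS function source that yields 0 when the element is missing.

    Hypothetical helper for illustration; the patch inlines this string
    in the page.evaluate() call.
    """
    return f"""() => {{
        const target = document.querySelector(".{element_name}");
        return target ? target.scrollHeight : 0;
    }}"""


# Illustrative class name; any CSS class without quotes works the same way.
script = build_height_script("standalone")
```

Because the string is a function expression, the null check runs inside the page and the Python side only ever receives a number.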

@@ -29,6 +29,7 @@ from flask_appbuilder.api import (
rison as parse_rison,
safe,
)
from flask_appbuilder.const import API_FILTERS_RIS_KEY
from flask_appbuilder.models.filters import BaseFilter, Filters
from flask_appbuilder.models.sqla.filters import FilterStartsWith
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -377,6 +378,35 @@ class BaseSupersetModelRestApi(BaseSupersetApiMixin, ModelRestApi):
self.add_columns = [model_id]
super()._init_properties()
def _handle_filters_args(self, rison_args: dict[str, Any]) -> Filters:
"""
Build a request-scoped ``Filters`` instance from Rison-encoded args.
Overrides :meth:`flask_appbuilder.api.ModelRestApi._handle_filters_args`,
which mutates ``self._filters`` (a single instance shared across
requests on the same API view). Under concurrent traffic that shared
state can leak filters from one request into another — e.g. two
parallel ``GET /api/v1/<resource>/`` calls filtering by different
values can return mixed results.
Returning a fresh ``Filters`` per call keeps each request isolated.
Applies to every subclass of ``BaseSupersetModelRestApi``
(datasets, charts, dashboards, saved queries, queries, databases,
etc.) — see issue #33828 for the original report on the dataset
endpoint.
:param rison_args: Arguments parsed from the API request's
Rison-encoded ``q`` parameter.
:returns: A request-scoped ``Filters`` instance joined with the
API's base filters.
"""
filters = self.datamodel.get_filters(
search_columns=self.search_columns,
search_filters=self.search_filters,
)
filters.rest_add_filters(rison_args.get(API_FILTERS_RIS_KEY, []))
return filters.get_joined_filters(self._base_filters)
def _get_related_filter(
self, datamodel: SQLAInterface, column_name: str, value: str
) -> Filters:

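The leak described in the `_handle_filters_args` docstring can be reproduced with a toy model. The sketch below is not Flask-AppBuilder code: `ToyFilters`, `SharedStateView`, and `RequestScopedView` are hypothetical stand-ins, and the interleaving of two requests is simulated sequentially rather than with real threads.

```python
class ToyFilters:
    """Hypothetical stand-in for a filter container."""

    def __init__(self) -> None:
        self.values: list[str] = []


class SharedStateView:
    """Keeps one filter container on the view and refills it per request."""

    def __init__(self) -> None:
        self._filters = ToyFilters()

    def prepare(self, value: str) -> ToyFilters:
        # Reset and refill the shared instance, then hand it back.
        self._filters.values = [value]
        return self._filters


class RequestScopedView:
    """Builds a fresh filter container on every call."""

    def prepare(self, value: str) -> ToyFilters:
        filters = ToyFilters()
        filters.values = [value]
        return filters


# Request A prepares its filters, then request B runs before A reads them.
shared = SharedStateView()
shared_a = shared.prepare("a")
shared_b = shared.prepare("b")
shared_seen_by_a = list(shared_a.values)  # A now sees B's filter

scoped = RequestScopedView()
scoped_a = scoped.prepare("a")
scoped_b = scoped.prepare("b")
scoped_seen_by_a = list(scoped_a.values)  # A still sees its own filter
```

Returning a new object per call removes the shared mutable state, so no locking is needed around filter construction.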

@@ -120,3 +120,46 @@ def test_get_dataset_include_rendered_sql_passes_table_to_template_processor(
assert response.status_code == 200
mock_get_processor.assert_called_once_with(database=database, table=dataset)
def test_handle_filters_args_returns_request_scoped_filters(
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
``_handle_filters_args`` must return a fresh ``Filters`` instance per
call so concurrent requests don't share filter state.
Regression test for #33828: under concurrent traffic the FAB default
implementation mutates ``self._filters`` (a single shared instance),
causing filters from one request to leak into another.
The fix lives on ``BaseSupersetModelRestApi`` so every Superset REST
API subclass (datasets, charts, dashboards, saved queries, etc.)
inherits the request-scoped behavior. This test exercises it via
``DatasetRestApi`` as a concrete subclass.
"""
from flask_appbuilder.const import API_FILTERS_RIS_KEY
from superset.datasets.api import DatasetRestApi
api = DatasetRestApi()
api.datamodel = MagicMock()
api.search_columns = ["table_name"]
api.search_filters = {}
api._base_filters = MagicMock() # noqa: SLF001
# Each call should construct a fresh Filters instance via datamodel.get_filters
rison_args = {
API_FILTERS_RIS_KEY: [{"col": "table_name", "opr": "eq", "value": "a"}],
}
api._handle_filters_args(rison_args) # noqa: SLF001
api._handle_filters_args(rison_args) # noqa: SLF001
assert api.datamodel.get_filters.call_count == 2
# Returned object must be the joined-filters result of the *fresh* Filters,
# not the shared self._filters attribute.
fresh_filters = api.datamodel.get_filters.return_value
assert fresh_filters.rest_add_filters.call_count == 2
assert fresh_filters.get_joined_filters.call_count == 2


@@ -594,7 +594,34 @@ class TestMapXYConfig:
assert result["viz_type"] == "echarts_timeseries_scatter"
assert result["show_legend"] is False
assert result["legend_orientation"] == "top"
assert result["legendOrientation"] == "top"
def test_map_xy_config_with_color_scheme(self) -> None:
"""color_scheme propagates to form_data when set."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="line",
color_scheme="lyftColors",
)
result = map_xy_config(config)
assert result["color_scheme"] == "lyftColors"
def test_map_xy_config_without_color_scheme(self) -> None:
"""color_scheme key omitted when not set, leaving Superset default."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="line",
)
result = map_xy_config(config)
assert "color_scheme" not in result
def test_map_xy_config_with_time_grain_month(self) -> None:
"""Test XY config mapping with monthly time grain"""

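The two `color_scheme` tests above pin down an "emit only when set" rule for optional form-data keys. A minimal sketch of that pattern follows; `add_if_set` is a hypothetical helper name, and how `map_xy_config` actually implements the rule is not shown in this diff.

```python
from typing import Any, Optional


def add_if_set(form_data: dict[str, Any], key: str, value: Optional[Any]) -> None:
    """Write key into form_data only when a value was provided."""
    if value is not None:
        form_data[key] = value


# With an explicit scheme the key is written.
with_scheme: dict[str, Any] = {"viz_type": "echarts_timeseries_line"}
add_if_set(with_scheme, "color_scheme", "lyftColors")

# Without one the key is absent, so the frontend default applies.
without_scheme: dict[str, Any] = {"viz_type": "echarts_timeseries_line"}
add_if_set(without_scheme, "color_scheme", None)
```

Omitting the key, rather than writing `None`, is what lets the second test assert `"color_scheme" not in result`.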

@@ -29,18 +29,23 @@ from pydantic import ValidationError
from superset.mcp_service.chart.chart_utils import (
generate_chart_name,
map_big_number_config,
map_config_to_form_data,
map_mixed_timeseries_config,
map_pie_config,
map_pivot_table_config,
map_table_config,
)
from superset.mcp_service.chart.schemas import (
AxisConfig,
BigNumberChartConfig,
ColumnRef,
CurrencyFormat,
FilterConfig,
MixedTimeseriesChartConfig,
PieChartConfig,
PivotTableChartConfig,
TableChartConfig,
)
from superset.mcp_service.chart.validation.schema_validator import SchemaValidator
@@ -212,6 +217,18 @@ class TestMapPieConfig:
assert result["adhoc_filters"][0]["operator"] == "=="
assert result["adhoc_filters"][0]["comparator"] == "US"
def test_pie_form_data_color_scheme_override(self) -> None:
"""Explicit color_scheme overrides the supersetColors default."""
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
color_scheme="googleCategory10c",
)
result = map_pie_config(config)
assert result["color_scheme"] == "googleCategory10c"
def test_pie_form_data_custom_options(self) -> None:
config = PieChartConfig(
chart_type="pie",
@@ -975,3 +992,272 @@ class TestSchemaValidatorNewTypes:
assert is_valid is False
assert error is not None
assert error.error_code == "INVALID_CHART_TYPE"
# ============================================================
# Chart Formatting Options Tests (sc-102806 follow-up)
# ============================================================
class TestPieFormattingOptions:
"""number/date/currency format, color scheme, legend orientation on Pie."""
def test_currency_format_in_form_data(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
currency_format=CurrencyFormat(symbol="USD", symbol_position="prefix"),
)
result = map_pie_config(config)
assert result["currency_format"] == {
"symbol": "USD",
"symbolPosition": "prefix",
}
def test_currency_format_omitted_when_unset(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = map_pie_config(config)
assert "currency_format" not in result
def test_legend_orientation_in_form_data(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
legend_orientation="bottom",
)
result = map_pie_config(config)
assert result["legendOrientation"] == "bottom"
def test_default_legend_orientation_is_top(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = map_pie_config(config)
assert result["legendOrientation"] == "top"
def test_date_format_overridable(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="ds"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
date_format="%Y-%m-%d",
)
result = map_pie_config(config)
assert result["date_format"] == "%Y-%m-%d"
class TestPivotTableFormattingOptions:
"""date/currency format on PivotTable."""
def test_currency_format_in_form_data(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="region")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
currency_format=CurrencyFormat(symbol="EUR", symbol_position="suffix"),
)
result = map_pivot_table_config(config)
assert result["currency_format"] == {
"symbol": "EUR",
"symbolPosition": "suffix",
}
def test_date_format_in_form_data(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="ds")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
date_format="%Y-%m",
)
result = map_pivot_table_config(config)
assert result["date_format"] == "%Y-%m"
def test_formatting_omitted_when_unset(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="region")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
result = map_pivot_table_config(config)
assert "currency_format" not in result
assert "date_format" not in result
class TestMixedTimeseriesFormattingOptions:
"""color scheme, currency format, legend orientation, data labels on Mixed."""
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_color_scheme_in_form_data(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
color_scheme="lyftColors",
)
result = map_mixed_timeseries_config(config)
assert result["color_scheme"] == "lyftColors"
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_currency_format_primary_and_secondary(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
currency_format=CurrencyFormat(symbol="USD"),
currency_format_secondary=CurrencyFormat(symbol="GBP"),
)
result = map_mixed_timeseries_config(config)
assert result["currency_format"] == {
"symbol": "USD",
"symbolPosition": "prefix",
}
assert result["currency_format_secondary"] == {
"symbol": "GBP",
"symbolPosition": "prefix",
}
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_legend_orientation_in_form_data(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
legend_orientation="left",
)
result = map_mixed_timeseries_config(config)
assert result["legendOrientation"] == "left"
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_show_value_data_labels(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
show_value=True,
)
result = map_mixed_timeseries_config(config)
assert result["show_value"] is True
class TestBigNumberFormattingOptions:
"""color scheme, currency format, time format on BigNumber."""
def test_currency_format_in_form_data(self) -> None:
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
currency_format=CurrencyFormat(symbol="JPY", symbol_position="prefix"),
)
result = map_big_number_config(config)
assert result["currency_format"] == {
"symbol": "JPY",
"symbolPosition": "prefix",
}
def test_color_scheme_in_form_data(self) -> None:
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
color_scheme="d3Category10",
)
result = map_big_number_config(config)
assert result["color_scheme"] == "d3Category10"
def test_time_format_only_for_trendline(self) -> None:
# Without trendline, time_format is dropped because the trendline
# x-axis doesn't render.
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
time_format="%Y-%m-%d",
)
result = map_big_number_config(config)
assert "time_format" not in result
def test_time_format_with_trendline(self) -> None:
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
temporal_column="ds",
show_trendline=True,
time_format="%Y-%m-%d",
)
result = map_big_number_config(config)
assert result["time_format"] == "%Y-%m-%d"
class TestTableFormattingOptions:
"""color scheme on Table."""
def test_color_scheme_in_form_data(self) -> None:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product"), ColumnRef(name="revenue")],
color_scheme="lyftColors",
)
result = map_table_config(config)
assert result["color_scheme"] == "lyftColors"
def test_color_scheme_omitted_when_unset(self) -> None:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product"), ColumnRef(name="revenue")],
)
result = map_table_config(config)
assert "color_scheme" not in result
class TestCurrencyFormatModel:
"""CurrencyFormat schema validation."""
def test_default_symbol_position_is_prefix(self) -> None:
cf = CurrencyFormat(symbol="USD")
assert cf.symbol_position == "prefix"
def test_camel_case_alias_accepted(self) -> None:
cf = CurrencyFormat.model_validate(
{"symbol": "USD", "symbolPosition": "suffix"}
)
assert cf.symbol_position == "suffix"
def test_invalid_position_rejected(self) -> None:
with pytest.raises(ValidationError):
CurrencyFormat(symbol="USD", symbol_position="middle")
def test_to_form_data_shape(self) -> None:
cf = CurrencyFormat(symbol="EUR", symbol_position="suffix")
assert cf.to_form_data() == {"symbol": "EUR", "symbolPosition": "suffix"}

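The behaviour asserted in `TestCurrencyFormatModel` can be summarised with a stdlib-only stand-in. This sketch is not the real `CurrencyFormat` (which is a pydantic model and also accepts the `symbolPosition` alias); `CurrencyFormatSketch` is a hypothetical dataclass, and the assumption that only `prefix` and `suffix` are valid positions is inferred from the tests above.

```python
from __future__ import annotations

from dataclasses import dataclass

# Assumed set of valid positions, inferred from the tests.
_ALLOWED_POSITIONS = ("prefix", "suffix")


@dataclass(frozen=True)
class CurrencyFormatSketch:
    """Stdlib stand-in mirroring the tested defaults and output shape."""

    symbol: str
    symbol_position: str = "prefix"

    def __post_init__(self) -> None:
        if self.symbol_position not in _ALLOWED_POSITIONS:
            raise ValueError(
                f"symbol_position must be one of {_ALLOWED_POSITIONS}, "
                f"got {self.symbol_position!r}"
            )

    def to_form_data(self) -> dict[str, str]:
        # form_data uses camelCase for the position key.
        return {"symbol": self.symbol, "symbolPosition": self.symbol_position}


default_position = CurrencyFormatSketch(symbol="USD").symbol_position
euro_suffix = CurrencyFormatSketch(symbol="EUR", symbol_position="suffix").to_form_data()

try:
    CurrencyFormatSketch(symbol="USD", symbol_position="middle")
    rejected = False
except ValueError:
    rejected = True
```

The snake_case field with a camelCase serialized key matches the shape the mapper tests expect under `currency_format`.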

@@ -33,11 +33,15 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
)
from superset.mcp_service.chart.tool.get_chart_data import (
_GENERIC_TYPE_MAP,
_MAX_RECOMMENDATIONS,
_query_from_form_data,
_recommend_visualizations,
_sanitize_chart_data_for_llm_context,
)
from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.sanitization import LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER
from superset.utils.core import GenericDataType
def _collect_groupby_extras(
@@ -1167,3 +1171,284 @@ class TestChartDataCommandValidation:
)
mock_command.run.assert_not_called()
@pytest.fixture
def mcp_server():
from superset.mcp_service.app import mcp
return mcp
@pytest.fixture
def mock_auth():
"""Mock MCP auth so Client.call_tool() doesn't need a real admin user."""
import importlib
from contextlib import contextmanager
from unittest.mock import Mock, patch
_gcd_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_data"
)
@contextmanager
def _noop_log_context(*_args: Any, **_kwargs: Any) -> Any:
yield lambda **_kw: None
# Neutralize event_logger.log_context: the default DBEventLogger would
# otherwise insert a log row referencing our mock user_id and fail a
# FK constraint against the real users table. Patch via the module
# object directly — the `tool` package's __init__.py re-exports the
# get_chart_data function under the same name, which shadows the
# submodule binding in the package namespace, so a dotted-string patch
# target resolves to the function and mock.patch cannot find
# event_logger on it.
mock_event_logger = Mock()
mock_event_logger.log_context.side_effect = _noop_log_context
with (
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
patch.object(_gcd_module, "event_logger", mock_event_logger),
):
user = Mock()
user.id = 1
user.username = "admin"
mock_get_user.return_value = user
yield mock_get_user
def _extract_metrics_load_path(load_opt: Any) -> list[str]:
"""Walk a SQLAlchemy Load option and return the attr chain.
e.g. subqueryload(Slice.table).subqueryload(SqlaTable.metrics)
-> ["table", "metrics"]
"""
path = getattr(load_opt, "path", ())
return [elem.key for elem in path if hasattr(elem, "key")]
class TestChartLookupEagerLoading:
"""Tests that get_chart_data eager-loads dataset.metrics on chart lookup.
Regression tests for the Excel export DetachedInstanceError on
dataset.metrics. The chart's dataset metrics relationship must be
eager-loaded at fetch time so it remains accessible during Excel export
after the request-scoped session is detached.
"""
@pytest.mark.asyncio
async def test_numeric_id_lookup_passes_metrics_eager_load(
self, mcp_server, mock_auth
):
"""Integer identifier lookup must eager-load Slice.table.metrics."""
from unittest.mock import patch
from fastmcp import Client
with patch(
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
) as mock_find:
async with Client(mcp_server) as client:
await client.call_tool(
"get_chart_data",
{"request": {"identifier": 42, "format": "excel"}},
)
mock_find.assert_called_once()
call = mock_find.call_args
assert call.args == (42,)
query_options = call.kwargs.get("query_options")
assert query_options is not None, (
"Chart lookup must pass query_options for eager-loading."
)
assert len(query_options) == 1
load_path = _extract_metrics_load_path(query_options[0])
assert load_path == ["table", "metrics"], (
f"Expected subqueryload chain 'table' -> 'metrics', got {load_path}"
)
@pytest.mark.asyncio
async def test_uuid_lookup_passes_metrics_eager_load(self, mcp_server, mock_auth):
"""UUID identifier lookup must also eager-load Slice.table.metrics."""
from unittest.mock import patch
from fastmcp import Client
uuid = "a1b2c3d4-5678-90ab-cdef-1234567890ab"
with patch(
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
) as mock_find:
async with Client(mcp_server) as client:
await client.call_tool(
"get_chart_data",
{"request": {"identifier": uuid, "format": "excel"}},
)
mock_find.assert_called_once()
call = mock_find.call_args
assert call.args == (uuid,)
assert call.kwargs.get("id_column") == "uuid"
query_options = call.kwargs.get("query_options")
assert query_options is not None, (
"UUID chart lookup must pass query_options for eager-loading."
)
load_path = _extract_metrics_load_path(query_options[0])
assert load_path == ["table", "metrics"], (
f"Expected subqueryload chain 'table' -> 'metrics', got {load_path}"
)
@pytest.mark.asyncio
async def test_json_format_also_eager_loads_metrics(self, mcp_server, mock_auth):
"""Eager-load is applied for every format, not just Excel.
Applying unconditionally keeps the fix robust if additional code paths
start touching dataset.metrics, and avoids branching behavior that
would be easy to regress on.
"""
from unittest.mock import patch
from fastmcp import Client
with patch(
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
) as mock_find:
async with Client(mcp_server) as client:
await client.call_tool(
"get_chart_data",
{"request": {"identifier": 7, "format": "json"}},
)
call = mock_find.call_args
query_options = call.kwargs.get("query_options")
assert query_options is not None
assert _extract_metrics_load_path(query_options[0]) == ["table", "metrics"]
# ---------------------------------------------------------------------------
# Tests for _recommend_visualizations
# ---------------------------------------------------------------------------
def _col(
name: str,
data_type: str = "string",
unique_count: int = 5,
null_count: int = 0,
) -> DataColumn:
"""Shortcut to build a DataColumn for tests."""
return DataColumn(
name=name,
display_name=name,
data_type=data_type,
sample_values=[],
null_count=null_count,
unique_count=unique_count,
)
def test_recommend_temporal_and_numeric_suggests_line_chart():
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=50)
assert "line chart" in result
assert "area chart" in result
def test_recommend_categorical_and_numeric_suggests_bar_chart():
cols = [_col("region", "string", unique_count=5), _col("sales", "numeric")]
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
assert "bar chart" in result
def test_recommend_excludes_current_viz_type():
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
assert "line chart" not in result
def test_recommend_multiple_numeric_suggests_scatter():
cols = [
_col("height", "numeric"),
_col("weight", "numeric"),
_col("age", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=100)
assert "scatter plot" in result
def test_recommend_single_numeric_suggests_kpi():
cols = [_col("total_revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=1)
assert "big number / KPI" in result
def test_recommend_all_strings_falls_back():
cols = [_col("name", "string"), _col("address", "string")]
result = _recommend_visualizations("pie", cols, row_count=100)
assert "table" in result or "bar chart" in result
def test_recommend_high_cardinality_no_pie():
cols = [
_col("user_id", "string", unique_count=900),
_col("score", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=1000)
assert "pie chart" not in result
def test_recommend_caps_at_max():
cols = [_col("ts", "temporal"), _col("a", "numeric"), _col("b", "numeric")]
result = _recommend_visualizations("table", cols, row_count=100)
assert len(result) <= _MAX_RECOMMENDATIONS
def test_recommend_empty_columns_returns_table():
result = _recommend_visualizations("table", [], row_count=0)
assert result == ["table"]
def test_recommend_pie_only_for_low_cardinality():
cols = [
_col("department", "string", unique_count=25),
_col("headcount", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=100)
assert "pie chart" not in result
def test_recommend_temporal_few_rows_prefers_bar():
cols = [_col("date", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=3)
assert "bar chart" in result
assert "line chart" not in result
def test_recommend_single_numeric_high_cardinality_suggests_histogram():
cols = [_col("salary", "numeric", unique_count=500)]
result = _recommend_visualizations("table", cols, row_count=1000)
assert "histogram" in result
def test_coltypes_populates_data_type():
"""Verify that GenericDataType values from coltypes are mapped correctly."""
assert _GENERIC_TYPE_MAP[GenericDataType.NUMERIC] == "numeric"
assert _GENERIC_TYPE_MAP[GenericDataType.STRING] == "string"
assert _GENERIC_TYPE_MAP[GenericDataType.TEMPORAL] == "temporal"
assert _GENERIC_TYPE_MAP[GenericDataType.BOOLEAN] == "boolean"
def test_bool_isinstance_check_before_int():
"""bool is a subclass of int; verify bool check takes priority in fallback."""
# When coltypes is unavailable, the fallback isinstance heuristic
# must check bool before int/float since isinstance(True, int) is True.
# We verify this indirectly: if _GENERIC_TYPE_MAP handles bool correctly,
# and the fallback code checks bool first, booleans won't be "numeric".
# Direct test: simulate what the fallback does
sample_values = [True, False, True]
data_type = "string"
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
assert data_type == "boolean"

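The ordering concern in `test_bool_isinstance_check_before_int` can be written as a small function. This is a sketch of the fallback heuristic only: `infer_data_type` is a hypothetical name, and returning `"string"` for an empty sample is an assumption added here so that `all()` over an empty list does not report `"boolean"`.

```python
from typing import Any


def infer_data_type(sample_values: list[Any]) -> str:
    """Classify a column from sample values, checking bool before int/float."""
    if not sample_values:
        # Assumption: no evidence means the generic default.
        return "string"
    # bool is a subclass of int, so this check has to come first.
    if all(isinstance(v, bool) for v in sample_values):
        return "boolean"
    if all(isinstance(v, (int, float)) for v in sample_values):
        return "numeric"
    return "string"


booleans = infer_data_type([True, False, True])
numbers = infer_data_type([1, 2.5, 3])
mixed_bool_int = infer_data_type([True, 1])
texts = infer_data_type(["a", "b"])
empty = infer_data_type([])
```

A sample mixing `True` with plain integers falls through to `"numeric"`, since every element still passes the `int` check.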

@@ -16,18 +16,25 @@
# under the License.
"""
Unit tests for get_chart_info MCP tool privacy behavior.
Unit tests for get_chart_info MCP tool: dashboard-filter resolution and
privacy behavior.
"""
import importlib
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.chart_helpers import (
_resolve_filter_operator_and_value,
build_applied_dashboard_filters,
ChartNotOnDashboardError,
)
from superset.mcp_service.chart.schemas import (
ChartInfo,
extract_filters_from_form_data,
@@ -84,6 +91,238 @@ def _make_chart_info() -> ChartInfo:
)
class TestGetChartInfoRequestSchema:
def test_dashboard_id_optional(self):
request = GetChartInfoRequest(identifier=1)
assert request.dashboard_id is None
def test_dashboard_id_accepted(self):
request = GetChartInfoRequest(identifier=1, dashboard_id=42)
assert request.dashboard_id == 42
class TestResolveFilterOperatorAndValue:
def test_matches_adhoc_filter_by_subject(self):
efd = {
"adhoc_filters": [
{
"subject": "country",
"operator": "IN",
"comparator": ["US", "CA"],
}
]
}
assert _resolve_filter_operator_and_value(efd, "country") == (
"IN",
["US", "CA"],
)
def test_matches_legacy_filter_by_col(self):
efd = {"filters": [{"col": "state", "op": "==", "val": "NY"}]}
assert _resolve_filter_operator_and_value(efd, "state") == ("==", "NY")
def test_time_range_when_no_column(self):
efd = {"time_range": "Last 7 days"}
assert _resolve_filter_operator_and_value(efd, None) == (
"TIME_RANGE",
"Last 7 days",
)
def test_column_not_in_extra_form_data(self):
efd = {
"adhoc_filters": [{"subject": "other", "operator": "==", "comparator": 1}]
}
assert _resolve_filter_operator_and_value(efd, "country") == (None, None)
def test_none_extra_form_data(self):
assert _resolve_filter_operator_and_value(None, "country") == (None, None)
def test_ignores_non_dict_entries(self):
efd = {
"adhoc_filters": ["not-a-dict", None],
"filters": [42, "foo"],
}
assert _resolve_filter_operator_and_value(efd, "country") == (None, None)
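Taken together, the six cases in `TestResolveFilterOperatorAndValue` describe a lookup order. The function below is a sketch reconstructed from those tests alone, not the implementation in `chart_helpers`; `resolve_operator_and_value` is a hypothetical name, and checking `time_range` only when no column is given is an assumption the tests do not fully settle.

```python
from __future__ import annotations

from typing import Any


def resolve_operator_and_value(
    extra_form_data: dict[str, Any] | None,
    column: str | None,
) -> tuple[Any, Any]:
    """Return (operator, value) for a column, or (None, None) if absent."""
    if not isinstance(extra_form_data, dict):
        return None, None

    if column is None:
        # Assumption: column-less filters are time-range filters.
        time_range = extra_form_data.get("time_range")
        if time_range is not None:
            return "TIME_RANGE", time_range
        return None, None

    # Adhoc filters identify the column under "subject".
    for entry in extra_form_data.get("adhoc_filters") or []:
        if isinstance(entry, dict) and entry.get("subject") == column:
            return entry.get("operator"), entry.get("comparator")

    # Legacy filters identify the column under "col".
    for entry in extra_form_data.get("filters") or []:
        if isinstance(entry, dict) and entry.get("col") == column:
            return entry.get("op"), entry.get("val")

    return None, None


adhoc = resolve_operator_and_value(
    {"adhoc_filters": [{"subject": "country", "operator": "IN", "comparator": ["US", "CA"]}]},
    "country",
)
legacy = resolve_operator_and_value(
    {"filters": [{"col": "state", "op": "==", "val": "NY"}]}, "state"
)
time_only = resolve_operator_and_value({"time_range": "Last 7 days"}, None)
junk = resolve_operator_and_value(
    {"adhoc_filters": ["not-a-dict", None], "filters": [42, "foo"]}, "country"
)
missing = resolve_operator_and_value(None, "country")
```

The `isinstance(entry, dict)` guards are what make malformed entries in stored dashboard metadata harmless instead of raising `AttributeError`.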
class TestBuildAppliedDashboardFilters:
"""The helper validates access, checks chart-on-dashboard, iterates
native filters, resolves scope, and maps each to AppliedDashboardFilter."""
def _make_dashboard(self, json_metadata=None, position_json=None, slice_ids=None):
dashboard = MagicMock()
dashboard.json_metadata = json_metadata or "{}"
dashboard.position_json = position_json or "{}"
dashboard.slices = [MagicMock(id=sid) for sid in (slice_ids or [])]
return dashboard
def test_chart_not_on_dashboard_raises(self):
dashboard = self._make_dashboard(slice_ids=[2, 3])
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
with pytest.raises(ChartNotOnDashboardError, match="not on dashboard"):
build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
def test_dashboard_not_found_raises(self):
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = None # noqa: E501
with pytest.raises(DashboardNotFoundError):
build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
def test_in_scope_filter_with_static_default(self):
native_filter = {
"id": "NATIVE_FILTER-1",
"name": "Country",
"type": "NATIVE_FILTER",
"filterType": "filter_select",
"chartsInScope": [1],
"targets": [{"column": {"name": "country"}, "datasetId": 7}],
"defaultDataMask": {
"filterState": {"value": ["US"]},
"extraFormData": {
"adhoc_filters": [
{
"subject": "country",
"operator": "IN",
"comparator": ["US"],
}
]
},
},
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[native_filter]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert len(result) == 1
flt = result[0]
assert flt.id == "NATIVE_FILTER-1"
assert flt.name == "Country"
assert flt.filter_type == "filter_select"
assert flt.column == "country"
assert flt.operator == "IN"
assert flt.value == ["US"]
assert flt.status == "applied"
def test_excluded_chart_filter_skipped(self):
native_filter = {
"id": "NATIVE_FILTER-1",
"name": "Region",
"type": "NATIVE_FILTER",
"filterType": "filter_select",
"chartsInScope": [2, 3], # chart 1 excluded
"targets": [{"column": {"name": "region"}, "datasetId": 7}],
"defaultDataMask": {
"filterState": {"value": ["NA"]},
"extraFormData": {
"filters": [{"col": "region", "op": "==", "val": "NA"}]
},
},
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[native_filter]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert result == []
def test_default_to_first_item_marks_prequery(self):
native_filter = {
"id": "NATIVE_FILTER-1",
"name": "Region",
"type": "NATIVE_FILTER",
"filterType": "filter_select",
"chartsInScope": [1],
"targets": [{"column": {"name": "region"}, "datasetId": 7}],
"controlValues": {"defaultToFirstItem": True},
"defaultDataMask": {},
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[native_filter]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert len(result) == 1
assert result[0].status == "not_applied_uses_default_to_first_item_prequery"
assert result[0].operator is None
assert result[0].value is None
def test_divider_entry_skipped(self):
divider = {
"id": "DIVIDER-1",
"name": "Section header",
"type": "DIVIDER",
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[divider]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert result == []
def test_no_native_filters_returns_empty_list(self):
dashboard = self._make_dashboard(
json_metadata="{}",
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert result == []
def _json(native_filter_list):
"""Serialize a native_filter list as JSON string for embedding in
json_metadata fixtures without escaping issues."""
from superset.utils import json
return json.dumps(native_filter_list)
class TestGetChartInfoPrivacy:
@pytest.mark.asyncio
async def test_restricted_user_redacts_saved_chart_data_model_fields(

@@ -2043,12 +2043,10 @@ class TestListDatasetsCreatedByMe:
with pytest.raises(ValidationError, match="created_by_me"):
ListDatasetsRequest(created_by_me=True, search="My tables")
-def test_dataset_filter_rejects_created_by_fk(self):
-"""created_by_fk is not a public filter column; use created_by_me instead."""
-from pydantic import ValidationError
-with pytest.raises(ValidationError):
-DatasetFilter(col="created_by_fk", opr="eq", value=1)
+def test_dataset_filter_accepts_created_by_fk(self):
+"""created_by_fk is exposed for person-filtering via find_users."""
+f = DatasetFilter(col="created_by_fk", opr="eq", value=1)
+assert f.col == "created_by_fk"
class TestListDatasetsOwnedByMe:
@@ -2115,14 +2113,10 @@ class TestListDatasetsRequestWrapper:
assert f.col == col
def test_dataset_filter_invalid_col_raises(self) -> None:
-"""Column names not in the Literal are rejected with a validation error.
-This guards against LLMs passing ``created_by_fk`` or similar
-internal column names that are not exposed as filter fields.
-"""
+"""Column names not in the Literal are rejected with a validation error."""
from pydantic import ValidationError
-for bad_col in ("created_by_fk", "id", "database_id", "owner"):
+for bad_col in ("id", "database_id", "owner"):
with pytest.raises(ValidationError):
DatasetFilter(col=bad_col, opr="eq", value="1")

@@ -810,6 +810,52 @@ class TestGenerateExploreLink:
assert "Dataset not found: 99999" in result.data["error"]
assert "list_datasets" in result.data["error"]
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_explore_link_without_config(
self, mock_find_dataset, mcp_server
):
"""Omitting config returns a default dataset explore URL."""
mock_find_dataset.return_value = _mock_dataset(id=42)
request = GenerateExploreLinkRequest(dataset_id="42")
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?datasource_type=table"
"&datasource_id=42"
)
assert result.data["form_data"] == {}
assert result.data["form_data_key"] is None
assert result.data["chart_type_label"] is None
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_explore_link_without_config_missing_dataset(
self, mock_find_dataset, mcp_server
):
"""Omitting config still surfaces a dataset-not-found error."""
mock_find_dataset.return_value = None
request = GenerateExploreLinkRequest(dataset_id="99999")
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["url"] == ""
assert result.data["form_data"] == {}
assert result.data["form_data_key"] is None
assert result.data["chart_type_label"] is None
assert "Dataset not found: 99999" in result.data["error"]
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_explore_link_nonexistent_uuid_dataset(

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@@ -0,0 +1,172 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from pydantic import ValidationError
from superset.mcp_service.app import mcp
from superset.mcp_service.plugin.schemas import ListPluginsRequest, PluginColumnFilter
from superset.utils import json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def create_mock_plugin(
plugin_id: int = 1,
name: str = "My Plugin",
key: str = "my_plugin",
bundle_url: str = "https://example.com/plugin.js",
) -> MagicMock:
plugin = MagicMock()
plugin.id = plugin_id
plugin.name = name
plugin.key = key
plugin.bundle_url = bundle_url
plugin.changed_on = None
plugin.created_on = None
return plugin
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
class TestPluginColumnFilterSchema:
def test_invalid_filter_column_rejected(self):
with pytest.raises(ValidationError):
PluginColumnFilter(col="bundle_url", opr="eq", value="test")
def test_valid_name_filter(self):
f = PluginColumnFilter(col="name", opr="eq", value="test")
assert f.col == "name"
def test_valid_key_filter(self):
f = PluginColumnFilter(col="key", opr="eq", value="my_plugin")
assert f.col == "key"
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
@pytest.mark.asyncio
async def test_list_plugins_basic(mock_list, mcp_server):
plugin = create_mock_plugin()
mock_list.return_value = ([plugin], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_plugins", {})
data = json.loads(result.content[0].text)
assert "plugins" in data
assert len(data["plugins"]) == 1
assert data["plugins"][0]["id"] == 1
assert data["plugins"][0]["name"] == "My Plugin"
assert data["plugins"][0]["key"] == "my_plugin"
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
@pytest.mark.asyncio
async def test_list_plugins_with_request(mock_list, mcp_server):
plugin = create_mock_plugin()
mock_list.return_value = ([plugin], 1)
async with Client(mcp_server) as client:
request = ListPluginsRequest(page=1, page_size=10)
result = await client.call_tool(
"list_plugins", {"request": request.model_dump()}
)
data = json.loads(result.content[0].text)
assert data["count"] == 1
assert data["total_count"] == 1
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
@pytest.mark.asyncio
async def test_list_plugins_with_search(mock_list, mcp_server):
plugin = create_mock_plugin(name="Custom Chart")
mock_list.return_value = ([plugin], 1)
async with Client(mcp_server) as client:
request = ListPluginsRequest(page=1, page_size=10, search="custom")
result = await client.call_tool(
"list_plugins", {"request": request.model_dump()}
)
data = json.loads(result.content[0].text)
assert data["plugins"][0]["name"] == "Custom Chart"
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
@pytest.mark.asyncio
async def test_list_plugins_empty(mock_list, mcp_server):
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
result = await client.call_tool("list_plugins", {})
data = json.loads(result.content[0].text)
assert data["count"] == 0
assert data["plugins"] == []
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_plugin_info_basic(mock_find, mcp_server):
plugin = create_mock_plugin()
mock_find.return_value = plugin
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_plugin_info", {"request": {"identifier": 1}}
)
data = json.loads(result.content[0].text)
assert data["id"] == 1
assert data["name"] == "My Plugin"
assert data["key"] == "my_plugin"
assert data["bundle_url"] == "https://example.com/plugin.js"
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_plugin_info_not_found(mock_find, mcp_server):
mock_find.return_value = None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_plugin_info", {"request": {"identifier": 999}}
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "not_found"
def test_list_plugins_request_rejects_search_and_filters():
with pytest.raises(ValidationError):
ListPluginsRequest(
search="test",
filters=[{"col": "name", "opr": "eq", "value": "x"}],
)

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

@@ -0,0 +1,244 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from pydantic import ValidationError
from superset.mcp_service.app import mcp
from superset.mcp_service.rls.schemas import ListRlsFiltersRequest, RlsColumnFilter
from superset.utils import json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def create_mock_rls_filter(
filter_id: int = 1,
name: str = "test_filter",
filter_type: str = "Regular",
clause: str = "user_id = {{current_user_id()}}",
group_key: str | None = None,
) -> MagicMock:
rls_filter = MagicMock()
rls_filter.id = filter_id
rls_filter.name = name
rls_filter.filter_type = filter_type
rls_filter.clause = clause
rls_filter.group_key = group_key
rls_filter.changed_on = None
table = MagicMock()
table.id = 1
table.table_name = "sales"
rls_filter.tables = [table]
role = MagicMock()
role.id = 1
role.name = "Alpha"
rls_filter.roles = [role]
return rls_filter
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
class TestRlsColumnFilterSchema:
def test_invalid_filter_column_rejected(self):
with pytest.raises(ValidationError):
RlsColumnFilter(col="clause", opr="eq", value="test")
def test_valid_name_filter(self):
f = RlsColumnFilter(col="name", opr="eq", value="test")
assert f.col == "name"
def test_valid_filter_type_filter(self):
f = RlsColumnFilter(col="filter_type", opr="eq", value="Regular")
assert f.col == "filter_type"
@patch("superset.daos.security.RLSDAO.list")
@pytest.mark.asyncio
async def test_list_rls_filters_basic(mock_list, mcp_server):
rls_filter = create_mock_rls_filter()
mock_list.return_value = ([rls_filter], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_rls_filters", {})
assert result.content is not None
data = json.loads(result.content[0].text)
assert "rls_filters" in data
assert len(data["rls_filters"]) == 1
assert data["rls_filters"][0]["id"] == 1
assert data["rls_filters"][0]["name"] == "test_filter"
@patch("superset.daos.security.RLSDAO.list")
@pytest.mark.asyncio
async def test_list_rls_filters_with_request(mock_list, mcp_server):
rls_filter = create_mock_rls_filter()
mock_list.return_value = ([rls_filter], 1)
async with Client(mcp_server) as client:
request = ListRlsFiltersRequest(page=1, page_size=10)
result = await client.call_tool(
"list_rls_filters", {"request": request.model_dump()}
)
data = json.loads(result.content[0].text)
assert data["count"] == 1
assert data["total_count"] == 1
@patch("superset.daos.security.RLSDAO.list")
@pytest.mark.asyncio
async def test_list_rls_filters_with_search(mock_list, mcp_server):
rls_filter = create_mock_rls_filter(name="user_filter")
mock_list.return_value = ([rls_filter], 1)
async with Client(mcp_server) as client:
request = ListRlsFiltersRequest(page=1, page_size=10, search="user")
result = await client.call_tool(
"list_rls_filters", {"request": request.model_dump()}
)
data = json.loads(result.content[0].text)
assert data["rls_filters"][0]["name"] == "user_filter"
@patch("superset.daos.security.RLSDAO.list")
@pytest.mark.asyncio
async def test_list_rls_filters_returns_tables_and_roles(mock_list, mcp_server):
rls_filter = create_mock_rls_filter()
mock_list.return_value = ([rls_filter], 1)
async with Client(mcp_server) as client:
request = ListRlsFiltersRequest(
page=1,
page_size=10,
select_columns=["id", "name", "tables", "roles"],
)
result = await client.call_tool(
"list_rls_filters", {"request": request.model_dump()}
)
data = json.loads(result.content[0].text)
item = data["rls_filters"][0]
assert "tables" in item
assert item["tables"][0]["table_name"] == "sales"
assert "roles" in item
assert item["roles"][0]["name"] == "Alpha"
@patch("superset.daos.security.RLSDAO.list")
@pytest.mark.asyncio
async def test_list_rls_filters_roles_only_select_columns(mock_list, mcp_server):
"""Requesting only 'roles' must not raise ValueError from the privacy filter."""
rls_filter = create_mock_rls_filter()
mock_list.return_value = ([rls_filter], 1)
async with Client(mcp_server) as client:
request = ListRlsFiltersRequest(
page=1,
page_size=10,
select_columns=["roles"],
)
result = await client.call_tool(
"list_rls_filters", {"request": request.model_dump()}
)
data = json.loads(result.content[0].text)
item = data["rls_filters"][0]
assert "roles" in item
assert item["roles"][0]["name"] == "Alpha"
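
The regression test above covers the case where `select_columns` contains only user-directory columns. A minimal sketch of the fix described in the commit message follows (strip those columns before the list query, and fall back to the defaults when nothing is left). The helper name and the contents of `USER_DIRECTORY_FIELDS` are assumptions for illustration; the real constant and call site are not shown in this diff.

```python
from typing import List, Optional

# Assumed contents for illustration only; the real constant is not shown here.
USER_DIRECTORY_FIELDS = frozenset({"roles", "owners"})


def strip_user_directory_fields(
    select_columns: Optional[List[str]],
) -> Optional[List[str]]:
    """Hypothetical helper: drop user-directory columns before listing.

    Returns None (meaning "use the default columns") when the caller passed
    None or when every requested column was a user-directory column, so the
    downstream privacy filter never receives an empty list.
    """
    if select_columns is None:
        return None
    kept = [col for col in select_columns if col not in USER_DIRECTORY_FIELDS]
    return kept or None


roles_only = strip_user_directory_fields(["roles"])
mixed = strip_user_directory_fields(["id", "name", "roles"])
```

With this shape, `["roles"]` degrades to the default column set instead of an empty list; per the commit message, the stripped fields are restored later when the final output is serialized.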
@patch("superset.daos.security.RLSDAO.list")
@pytest.mark.asyncio
async def test_list_rls_filters_empty(mock_list, mcp_server):
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
result = await client.call_tool("list_rls_filters", {})
data = json.loads(result.content[0].text)
assert data["count"] == 0
assert data["rls_filters"] == []
@patch("superset.daos.security.RLSDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_rls_filter_info_basic(mock_find, mcp_server):
rls_filter = create_mock_rls_filter()
mock_find.return_value = rls_filter
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_rls_filter_info", {"request": {"identifier": 1}}
)
data = json.loads(result.content[0].text)
assert data["id"] == 1
assert data["name"] == "test_filter"
assert data["filter_type"] == "Regular"
assert data["clause"] == "user_id = {{current_user_id()}}"
@patch("superset.daos.security.RLSDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_rls_filter_info_not_found(mock_find, mcp_server):
mock_find.return_value = None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_rls_filter_info", {"request": {"identifier": 999}}
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "not_found"
@patch("superset.daos.security.RLSDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_rls_filter_info_includes_tables_and_roles(mock_find, mcp_server):
rls_filter = create_mock_rls_filter()
mock_find.return_value = rls_filter
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_rls_filter_info", {"request": {"identifier": 1}}
)
data = json.loads(result.content[0].text)
assert data["tables"][0]["table_name"] == "sales"
assert data["roles"][0]["name"] == "Alpha"
def test_list_rls_filters_request_rejects_search_and_filters():
with pytest.raises(ValidationError):
ListRlsFiltersRequest(
search="test",
filters=[{"col": "name", "opr": "eq", "value": "x"}],
)

@@ -0,0 +1,257 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for find_users MCP tool and its filter contract."""
import importlib
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from pydantic import ValidationError
from superset.mcp_service.app import mcp
from superset.mcp_service.system.schemas import FindUsersRequest, FindUsersResponse
from superset.utils import json
# Import the submodule directly so ``patch.object`` targets the module (not the
# ``find_users`` function that ``tool/__init__.py`` re-exports onto the
# package). The package attribute is the function, so dotted-string patches
# like ``superset.mcp_service.system.tool.find_users.db`` can resolve to the
# function in some import orderings and fail with AttributeError.
find_users_module = importlib.import_module(
"superset.mcp_service.system.tool.find_users"
)
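
The comment above describes a package whose `__init__` re-exports a function under the same name as its submodule. The sketch below reproduces that layout with the standard library only. The `demo_tool` package is hypothetical and built in memory purely for illustration; it is not Superset code.

```python
import importlib
import sys
import types
from unittest.mock import patch

# Hypothetical in-memory package standing in for a tool package whose
# __init__ re-exports a function named like its submodule.
package = types.ModuleType("demo_tool")
package.__path__ = []  # an (empty) __path__ marks the module as a package
submodule = types.ModuleType("demo_tool.find_users")
exec(
    "db = 'real-db'\n"
    "def find_users():\n"
    "    return db\n",
    submodule.__dict__,
)
# The re-export: the package attribute is now the function, not the module.
package.find_users = submodule.find_users
sys.modules["demo_tool"] = package
sys.modules["demo_tool.find_users"] = submodule

# import_module finds the dotted name in sys.modules, so it returns the
# module object even though the package attribute is the function.
module = importlib.import_module("demo_tool.find_users")

# patch.object on the module replaces the global that find_users() reads.
with patch.object(module, "db", "fake-db"):
    patched_value = module.find_users()
restored_value = module.find_users()
```

Patching through the module object avoids depending on how a dotted patch target string is resolved when a package attribute and a submodule share a name.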
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
"""Mock authentication for all tests."""
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
def _make_user(id_, username, first=None, last=None, email=None, active=True):
"""Build a Mock user with the attributes serialize_user_object reads."""
user = Mock(
spec=["id", "username", "first_name", "last_name", "email", "active", "roles"]
)
user.id = id_
user.username = username
user.first_name = first
user.last_name = last
user.email = email
user.active = active
user.roles = []
return user
def _patch_user_query(rows):
"""Patch the SQLAlchemy chain used by find_users to return a fixed result set."""
chain = MagicMock()
chain.filter.return_value = chain
chain.order_by.return_value = chain
chain.limit.return_value = chain
chain.all.return_value = rows
session = MagicMock()
session.query.return_value = chain
return session, chain
# ---------------------------------------------------------------------------
# Schema tests
# ---------------------------------------------------------------------------
def test_find_users_request_rejects_empty_query():
with pytest.raises(ValidationError):
FindUsersRequest(query="")
def test_find_users_request_rejects_extra_fields():
with pytest.raises(ValidationError):
FindUsersRequest(query="maxime", random_field="x")
def test_find_users_response_default_truncated_false():
resp = FindUsersResponse(users=[], count=0)
assert resp.truncated is False
# ---------------------------------------------------------------------------
# Tool-level tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_users_returns_matches(mcp_server):
rows = [
_make_user(
7, "maxime", first="Maxime", last="Beauchemin", email="m@example.com"
)
]
session, _ = _patch_user_query(rows)
with (
patch.object(find_users_module, "db") as mock_db,
patch.object(find_users_module, "security_manager") as mock_sm,
patch.object(find_users_module, "or_") as mock_or,
):
mock_db.session = session
mock_sm.user_model = MagicMock()
mock_or.return_value = MagicMock()
async with Client(mcp_server) as client:
result = await client.call_tool(
"find_users", {"request": {"query": "maxime"}}
)
data = json.loads(result.content[0].text)
assert data["count"] == 1
assert data["truncated"] is False
assert data["users"][0]["id"] == 7
assert data["users"][0]["username"] == "maxime"
assert data["users"][0]["first_name"] == "Maxime"
assert data["users"][0]["last_name"] == "Beauchemin"
# Privacy: minimal projection excludes identity attributes that aren't
# required for filter resolution. Catch regressions on the response shape.
for forbidden in ("email", "active", "roles"):
assert forbidden not in data["users"][0]
# or_ should have been built across the four matched columns
assert mock_or.called
assert len(mock_or.call_args.args) == 4
@pytest.mark.asyncio
async def test_find_users_truncates_when_more_rows_than_page_size(mcp_server):
# page_size=2 with 3 returned rows -> truncated, response trimmed to 2
rows = [
_make_user(1, "a"),
_make_user(2, "b"),
_make_user(3, "c"),
]
session, chain = _patch_user_query(rows)
with (
patch.object(find_users_module, "db") as mock_db,
patch.object(find_users_module, "security_manager") as mock_sm,
patch.object(find_users_module, "or_") as mock_or,
):
mock_db.session = session
mock_sm.user_model = MagicMock()
mock_or.return_value = MagicMock()
async with Client(mcp_server) as client:
result = await client.call_tool(
"find_users", {"request": {"query": "a", "page_size": 2}}
)
# Tool requested page_size+1 rows for truncation detection
chain.limit.assert_called_with(3)
data = json.loads(result.content[0].text)
assert data["count"] == 2
assert data["truncated"] is True
assert [u["id"] for u in data["users"]] == [1, 2]
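
The truncation test above relies on fetching one row more than the requested page size. A minimal sketch of that pattern follows, using a plain sequence in place of the SQLAlchemy query. The `fetch_page` helper is hypothetical and is not the tool's actual implementation.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def fetch_page(rows: Sequence[T], page_size: int) -> Tuple[List[T], bool]:
    """Hypothetical helper: return one page plus a truncation flag."""
    # Ask for one more row than the caller wants (the LIMIT page_size + 1 step).
    fetched = list(rows[: page_size + 1])
    # An extra row means more matches exist beyond this page.
    truncated = len(fetched) > page_size
    # Trim the sentinel row so the response never exceeds page_size.
    return fetched[:page_size], truncated


page, truncated = fetch_page([1, 2, 3], page_size=2)
exact_page, exact_truncated = fetch_page([1, 2], page_size=2)
```

The extra row acts only as a signal, which lets the caller report truncation without a separate count query.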
@pytest.mark.asyncio
async def test_find_users_rejects_empty_query_via_client(mcp_server):
async with Client(mcp_server) as client:
with pytest.raises(ToolError):
await client.call_tool("find_users", {"request": {"query": ""}})
@pytest.mark.parametrize("blank", [" ", " ", "\t", "\n \t"])
def test_find_users_request_rejects_whitespace_only_query(blank):
# Whitespace-only queries would strip to "" and produce a LIKE "%%" pattern
# that enumerates the entire user directory. The validator must reject them.
with pytest.raises(ValidationError):
FindUsersRequest(query=blank)
def test_find_users_request_strips_query_whitespace():
# Validator should normalize the stored query so downstream LIKE patterns
# don't carry leading/trailing whitespace.
request = FindUsersRequest(query=" maxime ")
assert request.query == "maxime"
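
The two schema tests above pin down a normalize-then-reject rule for the query string. The sketch below shows that rule with plain functions rather than a Pydantic validator. The names `normalize_query` and `like_pattern` are hypothetical and do not reflect the actual `FindUsersRequest` code.

```python
def normalize_query(query: str) -> str:
    """Hypothetical validator: strip the query and reject blank input."""
    stripped = query.strip()
    if not stripped:
        # A blank query would become LIKE '%%' and match every user.
        raise ValueError("query must contain non-whitespace characters")
    return stripped


def like_pattern(query: str) -> str:
    """Build a substring-match pattern from the normalized query."""
    return f"%{normalize_query(query)}%"


normalized = normalize_query("  maxime  ")
pattern = like_pattern(" maxime ")
```

Rejecting blank input at validation time means the query layer never has to guard against a match-all pattern itself.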
# ---------------------------------------------------------------------------
# Filter contract: created_by_fk / changed_by_fk filtering on list tools
# ---------------------------------------------------------------------------
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_passes_created_by_fk_filter_to_dao(
mock_list, mcp_server
):
"""list_dashboards should accept created_by_fk filter and forward it."""
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
await client.call_tool(
"list_dashboards",
{
"request": {
"filters": [{"col": "created_by_fk", "opr": "eq", "value": 7}],
"page": 1,
"page_size": 10,
}
},
)
assert mock_list.called
forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
assert forwarded_filters is not None
assert any(
getattr(f, "col", None) == "created_by_fk" and getattr(f, "value", None) == 7
for f in forwarded_filters
)
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_passes_changed_by_fk_filter_to_dao(mock_list, mcp_server):
"""list_charts should accept changed_by_fk filter and forward it."""
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
await client.call_tool(
"list_charts",
{
"request": {
"filters": [{"col": "changed_by_fk", "opr": "eq", "value": 7}],
"page": 1,
"page_size": 10,
}
},
)
assert mock_list.called
forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
assert forwarded_filters is not None
assert any(getattr(f, "col", None) == "changed_by_fk" for f in forwarded_filters)

@@ -22,6 +22,8 @@ from unittest.mock import Mock, patch
import pytest
from fastmcp import Client
+from fastmcp.client.client import CallToolResult
+from mcp.types import TextContent
from pydantic import ValidationError
from superset.mcp_service.app import mcp
@@ -45,6 +47,14 @@ get_schema_module = importlib.import_module(
"superset.mcp_service.system.tool.get_schema"
)
+def _result_text(result: CallToolResult) -> str:
+"""Return the text payload from the first content block of a tool result."""
+block = result.content[0]
+assert isinstance(block, TextContent)
+return block.text
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -198,7 +208,7 @@ async def test_get_schema_returns_structured_privacy_error_for_dataset(mcp_serve
{"request": {"model_type": "dataset"}},
)
-data = json.loads(result.content[0].text)
+data = json.loads(_result_text(result))
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
assert data["privacy_scope"] == "data_model"
@@ -241,7 +251,7 @@ async def test_get_schema_redacts_chart_data_model_fields(mcp_server):
{"request": {"model_type": "chart"}},
)
-data = json.loads(result.content[0].text)
+data = json.loads(_result_text(result))
schema_info = data["schema_info"]
assert all(
column["name"] not in CHART_DATA_MODEL_COLUMNS
@@ -389,7 +399,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
async with Client(mcp_server) as client:
result = await client.call_tool("get_instance_info", {"request": {}})
-data = json.loads(result.content[0].text)
+data = json.loads(_result_text(result))
assert "current_user" in data
cu = data["current_user"]
assert cu["id"] == 5
@@ -418,7 +428,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
async with Client(mcp_server) as client:
result = await client.call_tool("get_instance_info", {"request": {}})
-data = json.loads(result.content[0].text)
+data = json.loads(_result_text(result))
assert data["current_user"] is None
@pytest.mark.asyncio
@@ -444,7 +454,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
async with Client(mcp_server) as client:
result = await client.call_tool("get_instance_info", {"request": {}})
-data = json.loads(result.content[0].text)
+data = json.loads(_result_text(result))
cu = data["current_user"]
assert cu["id"] == 99
assert cu["username"] == "bot"
@@ -460,28 +470,50 @@ class TestGetInstanceInfoCurrentUserViaMCP:
# ---------------------------------------------------------------------------
-def test_chart_filter_rejects_created_by_fk() -> None:
-"""created_by_fk is not a valid ChartFilter column; use created_by_me instead."""
-with pytest.raises(ValidationError):
-ChartFilter(col="created_by_fk", opr="eq", value=42)
+def test_chart_filter_rejects_user_directory_columns_other_than_fk() -> None:
+"""ChartFilter still rejects user-directory columns that expose names."""
+for col in ("created_by_name", "owners", "changed_by"):
+with pytest.raises(ValidationError):
+ChartFilter.model_validate({"col": col, "opr": "eq", "value": "anything"})
+def test_chart_filter_accepts_created_and_changed_by_fk() -> None:
+"""ChartFilter allows filtering by created_by_fk / changed_by_fk (user IDs)."""
+for col in ("created_by_fk", "changed_by_fk"):
+f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": 42})
+assert f.col == col
def test_chart_filter_rejects_invalid_column():
"""Test that ChartFilter rejects invalid column names."""
with pytest.raises(ValidationError):
-ChartFilter(col="nonexistent_column", opr="eq", value=42)
+ChartFilter.model_validate(
+{"col": "nonexistent_column", "opr": "eq", "value": 42}
+)
-def test_dashboard_filter_rejects_created_by_fk():
-"""created_by_fk is not a valid DashboardFilter column; use created_by_me."""
-with pytest.raises(ValidationError):
-DashboardFilter(col="created_by_fk", opr="eq", value=42)
+def test_dashboard_filter_rejects_user_directory_columns_other_than_fk() -> None:
+"""DashboardFilter still rejects user-directory columns that expose names."""
+for col in ("created_by_name", "owners", "changed_by"):
+with pytest.raises(ValidationError):
+DashboardFilter.model_validate(
+{"col": col, "opr": "eq", "value": "anything"}
+)
+def test_dashboard_filter_accepts_created_and_changed_by_fk() -> None:
+"""DashboardFilter allows filtering by created_by_fk / changed_by_fk."""
+for col in ("created_by_fk", "changed_by_fk"):
+f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": 42})
+assert f.col == col
def test_dashboard_filter_rejects_invalid_column():
"""Test that DashboardFilter rejects invalid column names."""
with pytest.raises(ValidationError):
-DashboardFilter(col="nonexistent_column", opr="eq", value=42)
+DashboardFilter.model_validate(
+{"col": "nonexistent_column", "opr": "eq", "value": 42}
+)
# ---------------------------------------------------------------------------
@@ -492,12 +524,12 @@ def test_dashboard_filter_rejects_invalid_column():
def test_chart_filter_existing_columns_still_work():
"""Test that pre-existing chart filter columns are not broken."""
for col in ("slice_name", "viz_type", "datasource_name"):
f = ChartFilter(col=col, opr="eq", value="test")
f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
assert f.col == col
def test_dashboard_filter_existing_columns_still_work():
"""Test that pre-existing dashboard filter columns are not broken."""
for col in ("dashboard_title", "published", "favorite"):
f = DashboardFilter(col=col, opr="eq", value="test")
f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
assert f.col == col
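
Reviewer note: the filter tests above pin down an allow-list contract. ID-only columns (`created_by_fk`, `changed_by_fk`) are accepted, while name-bearing user columns and unknown columns are rejected at validation time. A minimal stdlib sketch of that contract follows. `SketchFilter` and `SKETCH_ALLOWED_COLUMNS` are hypothetical names, and the allow-list contains only the columns these tests exercise; the real `ChartFilter` is a Pydantic model whose full column list is not shown in this diff.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical allow-list, inferred only from the columns the tests exercise.
SKETCH_ALLOWED_COLUMNS = frozenset(
    {"slice_name", "viz_type", "datasource_name", "created_by_fk", "changed_by_fk"}
)


@dataclass(frozen=True)
class SketchFilter:
    """Stand-in for a filter model; rejects columns outside the allow-list."""

    col: str
    opr: str
    value: Any

    def __post_init__(self) -> None:
        if self.col not in SKETCH_ALLOWED_COLUMNS:
            raise ValueError(f"Invalid filter column: {self.col!r}")

    @classmethod
    def model_validate(cls, payload: dict) -> "SketchFilter":
        # Mirrors the dict-based construction the updated tests use.
        return cls(col=payload["col"], opr=payload["opr"], value=payload["value"])
```

Building from a dict, as the updated tests do, routes every input through the same validation path whether it arrives as keyword arguments or as a decoded JSON payload.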

@@ -326,11 +326,19 @@ class TestGetSchemaToolViaClient:
async def test_get_schema_omits_user_directory_columns(
self, mock_filters, mcp_server
):
"""Test that schema discovery does not advertise user/access fields."""
"""Test that schema discovery does not advertise user/access fields.
created_by_fk and changed_by_fk are intentionally allowed in
filter_columns so callers can filter by user ID resolved via find_users,
but they remain hidden from select_columns and sortable_columns so the
directory itself is never exposed.
"""
mock_filters.return_value = {
"dashboard_title": ["eq", "ilike"],
"owner": ["rel_m_m"],
"published": ["eq"],
"created_by_fk": ["eq", "in"],
"changed_by_fk": ["eq", "in"],
}
async with Client(mcp_server) as client:
@@ -352,9 +360,16 @@ class TestGetSchemaToolViaClient:
"owner",
):
assert field not in select_column_names
assert field not in info["filter_columns"]
assert field not in info["sortable_columns"]
# User-name and relationship fields stay out of filter_columns
for field in ("owners", "roles", "created_by", "changed_by", "owner"):
assert field not in info["filter_columns"]
# ID-only filter columns are advertised so callers can filter via find_users
assert "created_by_fk" in info["filter_columns"]
assert "changed_by_fk" in info["filter_columns"]
@patch("superset.daos.chart.ChartDAO.get_filterable_columns_and_operators")
@pytest.mark.asyncio
async def test_get_schema_chart_omits_self_referencing_filter_columns(
@@ -362,8 +377,9 @@ class TestGetSchemaToolViaClient:
):
"""Test that chart schema does not advertise self-referencing filter columns.
Even if the DAO returns created_by_fk or owner, they must be excluded so
LLMs cannot discover and use them to enumerate user IDs.
Even if the DAO returns owner or created_by_fk_or_owner, they must be
excluded — these synthetic columns are generated server-side from the
owned_by_me flag and are not directly usable by LLM callers.
"""
mock_filters.return_value = {
"slice_name": ["eq", "ilike"],
@@ -381,7 +397,7 @@ class TestGetSchemaToolViaClient:
info = data["schema_info"]
assert "slice_name" in info["filter_columns"]
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
for field in ("owner", "created_by_fk_or_owner"):
assert field not in info["filter_columns"]
@patch("superset.daos.dataset.DatasetDAO.get_filterable_columns_and_operators")
@@ -391,8 +407,9 @@ class TestGetSchemaToolViaClient:
):
"""Test that dataset schema does not advertise self-referencing filter columns.
Even if the DAO returns created_by_fk or owner, they must be excluded so
LLMs cannot discover and use them to enumerate user IDs.
Even if the DAO returns owner or created_by_fk_or_owner, they must be
excluded — these synthetic columns are generated server-side from the
owned_by_me flag and are not directly usable by LLM callers.
"""
mock_filters.return_value = {
"table_name": ["eq", "ilike"],
@@ -410,7 +427,7 @@ class TestGetSchemaToolViaClient:
info = data["schema_info"]
assert "table_name" in info["filter_columns"]
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
for field in ("owner", "created_by_fk_or_owner"):
assert field not in info["filter_columns"]
@patch("superset.daos.dashboard.DashboardDAO.get_filterable_columns_and_operators")
@@ -420,8 +437,9 @@ class TestGetSchemaToolViaClient:
):
"""Test dashboard schema omits self-referencing filter columns.
Even if the DAO returns created_by_fk or owner, they must be excluded
so LLMs cannot discover and use them to enumerate user IDs.
Even if the DAO returns owner or created_by_fk_or_owner, they must be
excluded — these synthetic columns are generated server-side from the
owned_by_me flag and are not directly usable by LLM callers.
"""
mock_filters.return_value = {
"dashboard_title": ["eq", "ilike"],
@@ -439,7 +457,7 @@ class TestGetSchemaToolViaClient:
info = data["schema_info"]
assert "dashboard_title" in info["filter_columns"]
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
for field in ("owner", "created_by_fk_or_owner"):
assert field not in info["filter_columns"]
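
Reviewer note: the schema tests above describe a per-surface rule. User-directory fields are hidden from `select_columns` and `sortable_columns`, the two ID-only columns stay visible in `filter_columns`, and the synthetic `created_by_fk_or_owner` column is never advertised. A sketch of that rule, under stated assumptions: every constant and function name below is hypothetical, and the set contents are inferred from the assertions in these tests rather than taken from the service code.

```python
# Hypothetical constants; the real sets are defined in the MCP service
# and are not shown in this diff.
SKETCH_USER_DIRECTORY_FIELDS = frozenset(
    {
        "owners",
        "roles",
        "created_by",
        "changed_by",
        "owner",
        "created_by_fk",
        "changed_by_fk",
    }
)
SKETCH_ID_ONLY_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})
SKETCH_SYNTHETIC_FILTER_FIELDS = frozenset({"created_by_fk_or_owner"})


def advertised_filter_columns(dao_columns: dict) -> dict:
    """Drop name/relationship and synthetic columns; keep ID-only ones."""
    hidden = (
        SKETCH_USER_DIRECTORY_FIELDS - SKETCH_ID_ONLY_FILTER_FIELDS
    ) | SKETCH_SYNTHETIC_FILTER_FIELDS
    return {name: ops for name, ops in dao_columns.items() if name not in hidden}


def advertised_select_columns(columns: list) -> list:
    """Hide every user-directory field, including the ID-only ones."""
    return [name for name in columns if name not in SKETCH_USER_DIRECTORY_FIELDS]
```

Keeping the ID-only exception as a separate set makes the asymmetry between the filter surface and the select surface explicit instead of burying it in two diverging lists.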

@@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from datetime import date, datetime, time, timezone
from typing import Any
from unittest.mock import MagicMock
from zoneinfo import ZoneInfo
import pandas as pd
import pyarrow as pa
@@ -40,6 +42,7 @@ from superset_core.semantic_layers.types import (
from superset_core.semantic_layers.view import SemanticViewFeature
from superset.semantic_layers.mapper import (
_coerce_scalar_filter_value,
_convert_query_object_filter,
_convert_time_grain,
_get_filters_from_extras,
@@ -1204,6 +1207,91 @@ def test_convert_query_object_filter_like() -> None:
}
def test_convert_query_object_filter_coerces_integer_string_value() -> None:
"""Test scalar filter values are coerced to dimension type."""
all_dimensions = {
"birthyear": Dimension(
"birthyear",
"birthyear",
pa.int64(),
"birthyear",
"Birthyear",
)
}
filter_: ValidatedQueryObjectFilterClause = {
"op": FilterOperator.GREATER_THAN_OR_EQUALS.value,
"col": "birthyear",
"val": "1982",
}
result = _convert_query_object_filter(filter_, all_dimensions)
assert result == {
Filter(
type=PredicateType.WHERE,
column=all_dimensions["birthyear"],
operator=Operator.GREATER_THAN_OR_EQUAL,
value=1982,
)
}
def test_convert_query_object_filter_coerces_in_integer_values() -> None:
"""Test IN filter list values are coerced element-wise."""
all_dimensions = {
"order_id__amount": Dimension(
"order_id__amount",
"order_id__amount",
pa.int64(),
"order_id__amount",
"Order amount",
)
}
filter_: ValidatedQueryObjectFilterClause = {
"op": FilterOperator.IN.value,
"col": "order_id__amount",
"val": ["58", "61"],
}
result = _convert_query_object_filter(filter_, all_dimensions)
assert result == {
Filter(
type=PredicateType.WHERE,
column=all_dimensions["order_id__amount"],
operator=Operator.IN,
value=frozenset({58, 61}),
)
}
def test_convert_query_object_filter_invalid_integer_value_raises() -> None:
"""Test invalid integer value raises a clear error."""
all_dimensions = {
"birthyear": Dimension(
"birthyear",
"birthyear",
pa.int64(),
"birthyear",
"Birthyear",
)
}
filter_: ValidatedQueryObjectFilterClause = {
"op": FilterOperator.GREATER_THAN_OR_EQUALS.value,
"col": "birthyear",
"val": "nineteen-eighty-two",
}
with pytest.raises(
ValueError,
match="Invalid integer value 'nineteen-eighty-two' for filter column birthyear",
):
_convert_query_object_filter(filter_, all_dimensions)
def test_get_results_without_time_offsets(
mock_datasource: MagicMock,
mocker: MockerFixture,
@@ -1923,6 +2011,86 @@ def test_convert_query_object_filter_temporal_range_with_value() -> None:
}
def test_convert_query_object_filter_temporal_range_coerces_date_bounds() -> None:
"""
TEMPORAL_RANGE bounds should be coerced against the dimension's dtype so
date/timestamp columns are not compared against raw strings.
"""
all_dimensions = {
"order_date": Dimension(
"order_date", "order_date", pa.date32(), "order_date", "Order date"
)
}
filter_: ValidatedQueryObjectFilterClause = {
"op": FilterOperator.TEMPORAL_RANGE.value,
"col": "order_date",
"val": "2025-01-01 : 2025-12-31",
}
result = _convert_query_object_filter(filter_, all_dimensions)
assert result == {
Filter(
type=PredicateType.WHERE,
column=all_dimensions["order_date"],
operator=Operator.GREATER_THAN_OR_EQUAL,
value=date(2025, 1, 1),
),
Filter(
type=PredicateType.WHERE,
column=all_dimensions["order_date"],
operator=Operator.LESS_THAN,
value=date(2025, 12, 31),
),
}
def test_convert_query_object_filter_temporal_range_open_ended() -> None:
"""
Open-ended TEMPORAL_RANGE bounds should emit only the bounded predicate.
"""
all_dimensions = {
"order_date": Dimension(
"order_date", "order_date", pa.date32(), "order_date", "Order date"
)
}
only_start: ValidatedQueryObjectFilterClause = {
"op": FilterOperator.TEMPORAL_RANGE.value,
"col": "order_date",
"val": "2025-01-01 : ",
}
assert _convert_query_object_filter(only_start, all_dimensions) == {
Filter(
type=PredicateType.WHERE,
column=all_dimensions["order_date"],
operator=Operator.GREATER_THAN_OR_EQUAL,
value=date(2025, 1, 1),
),
}
only_end: ValidatedQueryObjectFilterClause = {
"op": FilterOperator.TEMPORAL_RANGE.value,
"col": "order_date",
"val": " : 2025-12-31",
}
assert _convert_query_object_filter(only_end, all_dimensions) == {
Filter(
type=PredicateType.WHERE,
column=all_dimensions["order_date"],
operator=Operator.LESS_THAN,
value=date(2025, 12, 31),
),
}
empty: ValidatedQueryObjectFilterClause = {
"op": FilterOperator.TEMPORAL_RANGE.value,
"col": "order_date",
"val": " : ",
}
assert _convert_query_object_filter(empty, all_dimensions) is None
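
Reviewer note: the two TEMPORAL_RANGE tests above imply a half-open interval. The start bound becomes `>=`, the end bound becomes `<`, a blank side emits no predicate, and a fully blank range yields `None`. A sketch of that behaviour, assuming ISO date strings only: `split_temporal_range` and `to_predicates` are hypothetical helpers, and plain tuples stand in for the `Filter` objects used by the mapper.

```python
from datetime import date
from typing import Optional


def split_temporal_range(value: str) -> tuple:
    """Split "start : end" into date bounds; a blank side means unbounded."""
    start_raw, _, end_raw = value.partition(" : ")

    def parse(raw: str) -> Optional[date]:
        raw = raw.strip()
        return date.fromisoformat(raw) if raw else None

    return parse(start_raw), parse(end_raw)


def to_predicates(column: str, value: str) -> Optional[set]:
    """Return (column, operator, bound) tuples, or None when both sides are blank."""
    start, end = split_temporal_range(value)
    predicates = set()
    if start is not None:
        predicates.add((column, ">=", start))  # inclusive lower bound
    if end is not None:
        predicates.add((column, "<", end))  # exclusive upper bound
    return predicates or None
```

Coercing each bound before building the predicate is what keeps a `date32` column from being compared against a raw string.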
def test_get_order_adhoc_with_none_sql_expression(mock_datasource: MagicMock) -> None:
"""
Test order extraction skips adhoc expression with None sqlExpression.
@@ -2741,3 +2909,205 @@ def test_get_group_limit_filters_no_granularity(
# Should return None - no granularity means no time filters added
assert result is None
# ---------------------------------------------------------------------------
# _coerce_scalar_filter_value: per-dtype branches
# ---------------------------------------------------------------------------
def _dim(dtype: pa.DataType, name: str = "d") -> Dimension:
return Dimension(name, name, dtype, name, name.capitalize())
def test_coerce_none_returns_none() -> None:
assert _coerce_scalar_filter_value(None, _dim(pa.int64())) is None
def test_coerce_unsupported_dtype_passes_through() -> None:
# utf8 (and any dtype not branched in the function) returns the value as-is.
assert _coerce_scalar_filter_value("abc", _dim(pa.utf8())) == "abc"
@pytest.mark.parametrize(
"raw,expected",
[
(True, True),
(False, False),
(1, True),
(0, False),
(1.0, True),
(0.0, False),
("true", True),
("T", True),
(" 1 ", True),
("yes", True),
("Y", True),
("on", True),
("false", False),
("F", False),
("0", False),
("no", False),
("N", False),
("off", False),
],
)
def test_coerce_boolean(raw: Any, expected: bool) -> None:
assert _coerce_scalar_filter_value(raw, _dim(pa.bool_())) is expected
@pytest.mark.parametrize("raw", ["maybe", 2, 0.5, -1])
def test_coerce_boolean_invalid_raises(raw: Any) -> None:
with pytest.raises(ValueError, match="Invalid boolean value"):
_coerce_scalar_filter_value(raw, _dim(pa.bool_()))
def test_coerce_integer_passthrough() -> None:
assert _coerce_scalar_filter_value(42, _dim(pa.int64())) == 42
def test_coerce_integer_accepts_integer_valued_float() -> None:
# JSON round-trips can turn an int into ``42.0``; accept losslessly.
assert _coerce_scalar_filter_value(42.0, _dim(pa.int64())) == 42
def test_coerce_integer_rejects_bool() -> None:
# bool is a subclass of int; we explicitly reject it.
with pytest.raises(ValueError, match="Invalid integer value"):
_coerce_scalar_filter_value(True, _dim(pa.int64()))
def test_coerce_integer_rejects_non_integer_float() -> None:
with pytest.raises(ValueError, match="Invalid integer value"):
_coerce_scalar_filter_value(1.5, _dim(pa.int64()))
def test_coerce_integer_rejects_other_types() -> None:
with pytest.raises(ValueError, match="Invalid integer value"):
_coerce_scalar_filter_value([1], _dim(pa.int64()))
@pytest.mark.parametrize(
"dtype",
[pa.float64(), pa.decimal128(10, 2)],
)
def test_coerce_floating_or_decimal(dtype: pa.DataType) -> None:
assert _coerce_scalar_filter_value(1, _dim(dtype)) == 1.0
assert _coerce_scalar_filter_value(1.5, _dim(dtype)) == 1.5
assert _coerce_scalar_filter_value(" 2.5 ", _dim(dtype)) == 2.5
def test_coerce_floating_rejects_bool() -> None:
with pytest.raises(ValueError, match="Invalid numeric value"):
_coerce_scalar_filter_value(True, _dim(pa.float64()))
def test_coerce_floating_invalid_string_raises() -> None:
with pytest.raises(ValueError, match="Invalid numeric value"):
_coerce_scalar_filter_value("not-a-number", _dim(pa.float64()))
def test_coerce_floating_rejects_other_types() -> None:
with pytest.raises(ValueError, match="Invalid numeric value"):
_coerce_scalar_filter_value([1.0], _dim(pa.float64()))
def test_coerce_date_from_datetime() -> None:
out = _coerce_scalar_filter_value(datetime(2025, 1, 2, 12, 0), _dim(pa.date32()))
assert out == date(2025, 1, 2)
def test_coerce_date_passthrough() -> None:
out = _coerce_scalar_filter_value(date(2025, 1, 2), _dim(pa.date32()))
assert out == date(2025, 1, 2)
def test_coerce_date_from_iso_string() -> None:
out = _coerce_scalar_filter_value(" 2025-01-02 ", _dim(pa.date32()))
assert out == date(2025, 1, 2)
def test_coerce_date_invalid_string_raises() -> None:
with pytest.raises(ValueError, match="Invalid date value"):
_coerce_scalar_filter_value("not-a-date", _dim(pa.date32()))
def test_coerce_date_rejects_other_types() -> None:
with pytest.raises(ValueError, match="Invalid date value"):
_coerce_scalar_filter_value(20250102, _dim(pa.date32()))
def test_coerce_timestamp_from_datetime_passthrough() -> None:
dt = datetime(2025, 1, 2, 3, 4, 5)
# Naive dtype: returned as-is, still naive.
assert _coerce_scalar_filter_value(dt, _dim(pa.timestamp("us"))) == dt
def test_coerce_timestamp_from_date() -> None:
out = _coerce_scalar_filter_value(date(2025, 1, 2), _dim(pa.timestamp("us")))
assert out == datetime(2025, 1, 2, 0, 0)
def test_coerce_timestamp_from_iso_string_with_z() -> None:
out = _coerce_scalar_filter_value("2025-01-02T03:04:05Z", _dim(pa.timestamp("us")))
assert out == datetime.fromisoformat("2025-01-02T03:04:05+00:00")
def test_coerce_timestamp_invalid_string_raises() -> None:
with pytest.raises(ValueError, match="Invalid timestamp value"):
_coerce_scalar_filter_value("not-a-ts", _dim(pa.timestamp("us")))
def test_coerce_timestamp_rejects_other_types() -> None:
with pytest.raises(ValueError, match="Invalid timestamp value"):
_coerce_scalar_filter_value(1234567890, _dim(pa.timestamp("us")))
def test_coerce_timestamp_tz_aware_dtype_attaches_tz_to_naive_datetime() -> None:
dt = datetime(2025, 1, 2, 3, 4, 5)
out = _coerce_scalar_filter_value(dt, _dim(pa.timestamp("us", tz="UTC")))
assert out == datetime(2025, 1, 2, 3, 4, 5, tzinfo=ZoneInfo("UTC"))
def test_coerce_timestamp_tz_aware_dtype_converts_aware_datetime() -> None:
dt = datetime(2025, 1, 2, 12, 0, tzinfo=timezone.utc)
out = _coerce_scalar_filter_value(
dt, _dim(pa.timestamp("us", tz="America/New_York"))
)
# 12:00 UTC == 07:00 in New York
assert out == datetime(2025, 1, 2, 7, 0, tzinfo=ZoneInfo("America/New_York"))
def test_coerce_timestamp_tz_aware_dtype_attaches_tz_to_date() -> None:
out = _coerce_scalar_filter_value(
date(2025, 1, 2), _dim(pa.timestamp("us", tz="UTC"))
)
assert out == datetime(2025, 1, 2, 0, 0, tzinfo=ZoneInfo("UTC"))
def test_coerce_timestamp_tz_aware_dtype_parses_string_with_tz() -> None:
out = _coerce_scalar_filter_value(
"2025-01-02T03:04:05", _dim(pa.timestamp("us", tz="UTC"))
)
# Naive string gets UTC attached.
assert out == datetime(2025, 1, 2, 3, 4, 5, tzinfo=ZoneInfo("UTC"))
def test_coerce_time_passthrough() -> None:
out = _coerce_scalar_filter_value(time(3, 4, 5), _dim(pa.time64("us")))
assert out == time(3, 4, 5)
def test_coerce_time_from_iso_string() -> None:
out = _coerce_scalar_filter_value(" 03:04:05 ", _dim(pa.time64("us")))
assert out == time(3, 4, 5)
def test_coerce_time_invalid_string_raises() -> None:
with pytest.raises(ValueError, match="Invalid time value"):
_coerce_scalar_filter_value("not-a-time", _dim(pa.time64("us")))
def test_coerce_time_rejects_other_types() -> None:
with pytest.raises(ValueError, match="Invalid time value"):
_coerce_scalar_filter_value(123, _dim(pa.time64("us")))
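
Reviewer note: the parametrized cases above effectively specify the boolean and integer branches of the coercion. A stdlib sketch that satisfies those cases is below. It is not the body of `_coerce_scalar_filter_value`, which this diff does not show. `coerce_bool` and `coerce_int` are hypothetical names, and only the error-message prefixes are taken from the tests.

```python
from typing import Any

_TRUE_STRINGS = frozenset({"true", "t", "1", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"false", "f", "0", "no", "n", "off"})


def coerce_bool(value: Any) -> bool:
    """Accept real bools, numeric 0/1, and common true/false spellings."""
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)) and value in (0, 1):
        return bool(value)
    if isinstance(value, str):
        text = value.strip().lower()
        if text in _TRUE_STRINGS:
            return True
        if text in _FALSE_STRINGS:
            return False
    raise ValueError(f"Invalid boolean value {value!r}")


def coerce_int(value: Any, column: str) -> int:
    """Accept ints, integer-valued floats, and integer strings; reject bools."""
    # bool is a subclass of int, so it has to be excluded before the int check.
    if not isinstance(value, bool):
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                pass
    raise ValueError(f"Invalid integer value {value!r} for filter column {column}")
```

Rejecting `True` for an integer column is deliberate: silently turning it into `1` would hide a mistyped filter payload.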

@@ -811,3 +811,90 @@ class TestWebDriverPlaywrightErrorHandling:
mock_logger.exception.assert_any_call(
"Timed out requesting url %s", "http://example.com"
)
@patch("superset.utils.webdriver.PLAYWRIGHT_AVAILABLE", True)
@patch("superset.utils.webdriver.sync_playwright")
@patch("superset.utils.webdriver.logger")
def test_missing_element_for_dashboard_height_falls_back_without_crashing(
self, mock_logger, mock_sync_playwright
):
"""Missing dashboard element should not crash height evaluation."""
mock_user = MagicMock()
mock_user.username = "test_user"
mock_playwright_instance = MagicMock()
mock_browser = MagicMock()
mock_context = MagicMock()
mock_page = MagicMock()
mock_element = MagicMock()
mock_chart_container = MagicMock()
mock_sync_playwright.return_value.__enter__.return_value = (
mock_playwright_instance
)
mock_playwright_instance.chromium.launch.return_value = mock_browser
mock_browser.new_context.return_value = mock_context
mock_context.new_page.return_value = mock_page
def locator_side_effect(selector):
if selector == ".dashboard":
return mock_element
if selector == ".chart-container":
locator = MagicMock()
locator.all.return_value = [mock_chart_container]
return locator
if selector == ".loading":
locator = MagicMock()
locator.all.return_value = []
return locator
return MagicMock()
mock_page.locator.side_effect = locator_side_effect
mock_element.wait_for.return_value = None
mock_element.screenshot.return_value = b"fake_screenshot"
mock_chart_container.wait_for.return_value = None
mock_page.wait_for_timeout.return_value = None
def evaluate_side_effect(script):
if script == 'document.querySelectorAll(".chart-container").length':
return 1
if "const target = document.querySelector" in script:
return 0
return None
mock_page.evaluate.side_effect = evaluate_side_effect
with patch("superset.utils.webdriver.app") as mock_app:
mock_app.config = {
"WEBDRIVER_OPTION_ARGS": [],
"WEBDRIVER_WINDOW": {"pixel_density": 1},
"SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT": 30000,
"SCREENSHOT_PLAYWRIGHT_WAIT_EVENT": "networkidle",
"SCREENSHOT_SELENIUM_HEADSTART": 5,
"SCREENSHOT_SELENIUM_ANIMATION_WAIT": 1,
"SCREENSHOT_LOCATE_WAIT": 10,
"SCREENSHOT_LOAD_WAIT": 10,
"SCREENSHOT_WAIT_FOR_ERROR_MODAL_VISIBLE": 10,
"SCREENSHOT_WAIT_FOR_ERROR_MODAL_INVISIBLE": 10,
"SCREENSHOT_REPLACE_UNEXPECTED_ERRORS": False,
"SCREENSHOT_TILED_ENABLED": True,
"SCREENSHOT_TILED_CHART_THRESHOLD": 20,
"SCREENSHOT_TILED_HEIGHT_THRESHOLD": 5000,
"SCREENSHOT_TILED_VIEWPORT_HEIGHT": 600,
}
with patch.object(WebDriverPlaywright, "auth") as mock_auth:
mock_auth.return_value = mock_context
driver = WebDriverPlaywright("chrome")
result = driver.get_screenshot(
"http://example.com", "dashboard", mock_user
)
assert result == b"fake_screenshot"
mock_logger.warning.assert_any_call(
"Could not determine dashboard height for element %s at url %s; "
"falling back to standard screenshot behavior",
"dashboard",
"http://example.com",
)
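
Reviewer note: the test above expects the screenshot to succeed when the height script returns nothing usable, with a warning logged instead of a crash. A sketch of such a guard follows, under stated assumptions: `resolve_dashboard_height` and the JavaScript snippet are hypothetical, `evaluate` stands in for Playwright's `page.evaluate`, and only the warning text is copied from the test.

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)

# Hypothetical script: returns 0 instead of throwing when the element is missing.
HEIGHT_SCRIPT = """
(() => {
  const target = document.querySelector(".dashboard");
  return target ? target.scrollHeight : 0;
})()
"""


def resolve_dashboard_height(
    evaluate: Callable[[str], Any], element: str, url: str
) -> Optional[int]:
    """Return a positive pixel height, or None to signal the standard fallback."""
    height = evaluate(HEIGHT_SCRIPT)
    usable = (
        isinstance(height, (int, float))
        and not isinstance(height, bool)
        and height > 0
    )
    if not usable:
        logger.warning(
            "Could not determine dashboard height for element %s at url %s; "
            "falling back to standard screenshot behavior",
            element,
            url,
        )
        return None
    return int(height)
```

A caller would skip tiled capture when this returns `None` and take the ordinary element screenshot instead.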