Compare commits

15 Commits

Author SHA1 Message Date
Amin Ghadersohi
23056838c9 fix(mcp): remove created_by_fk from ReportFilter public schema and use ColumnOperator for filters_applied
- ReportFilter.col Literal no longer includes created_by_fk; callers
  must use the created_by_me flag instead (internal-only filter)
- ReportList.filters_applied typed as List[ColumnOperator] so that
  internally-injected ColumnOperator instances (e.g. created_by_fk from
  created_by_me=True) pass pydantic validation without coercion errors
2026-05-21 20:09:54 +00:00
Amin Ghadersohi
3bbdd6150b fix(mcp): add created_by_fk to ReportFilter allowed columns
created_by_fk was removed from SELF_REFERENCING_FILTER_COLUMNS so it
now appears in filters_applied, but ReportFilter.col Literal didn't
include it, causing pydantic validation error in list_reports responses.
2026-05-21 19:12:44 +00:00
Amin Ghadersohi
6d3ae5e476 fix(mcp): restore created_by_fk as public filter column, keep owners.id for reports
2026-05-21 19:10:22 +00:00
Amin Ghadersohi
fa2eeace4c fix(mcp): inject owners.id and created_by_fk filters for report list tools
- Add ReportListCore subclass in list_reports.py that overrides filter
  injection to use owners.id (instead of generic owner) and calls the
  DAO with filters= kwarg (instead of column_operators=) so tests can
  assert on the kwarg by name
- Extract _call_dao_list hook in ModelListCore so subclasses can change
  the DAO kwarg name without duplicating run_tool
- Add owners.id to SELF_REFERENCING_FILTER_COLUMNS so it is excluded
  from filters_applied in responses

Fixes: test_list_reports_owned_by_me_passed_to_dao,
       test_list_reports_created_by_me_passed_to_dao
2026-05-21 19:10:22 +00:00
Amin Ghadersohi
d0f49a1875 refactor(mcp): address review findings for list/get report tools
- Register report model type in get_schema (Fix #1): add _get_report_schema_core
  factory + "report" entry in _SCHEMA_CORE_FACTORIES; ModelType now includes "report"
- Add OwnedByMeMixin/CreatedByMeMixin to ListReportsRequest (Fix #2)
- DRY up list_reports.py column constants (Fix #3): import REPORT_* constants and
  get_report_columns from schema_discovery; pass created_by_me/owned_by_me to run_tool
- Extend test coverage (Fix #6): humanized timestamp fields, invalid order_column
  guard, owned_by_me/created_by_me DAO filter injection

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 19:10:22 +00:00
Amin Ghadersohi
2e87c2c48e feat(mcp): add list and get tools for alerts and reports
Adds list_reports and get_report_info MCP tools under a new
superset/mcp_service/report/ domain, following the canonical database
domain pattern. Includes unit tests and app.py registration.
2026-05-21 19:10:22 +00:00
Evan Rusackas
73f66e4c14 fix(datasets): isolate filter state to fix concurrent /dataset race (#39685)
Co-authored-by: Claude Code <noreply@anthropic.com>
2026-05-21 11:12:32 -07:00
Elizabeth Thompson
f187a8e1c4 fix(reports): guard null dashboard height in Playwright screenshots (#40179)
2026-05-21 09:19:29 -07:00
Mehmet Salih Yavuz
4c3f65ef0b feat(mcp): make config optional in generate_explore_link (#39559)
2026-05-21 18:01:59 +03:00
Mehmet Salih Yavuz
53d8e5bdfa feat(mcp): include applied dashboard filters in get_chart_info (#39620)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 17:48:21 +03:00
Mehmet Salih Yavuz
2f95d288dd fix(mcp): eager-load dataset.metrics to prevent Excel export DetachedInstanceError (#39483)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 16:34:38 +03:00
Beto Dealmeida
2f5fcc21f9 fix(semantic layers): coerce filter types (#40222)
2026-05-21 09:25:27 -04:00
Mehmet Salih Yavuz
d1d07112aa feat(mcp): add find_users tool and owner filter columns for listings (#39679)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 15:59:09 +03:00
Alexandru Soare
e3711bec39 fix(recommandation): Fix chart recommandation (#39886)
2026-05-21 15:16:16 +03:00
Mehmet Salih Yavuz
ce9cab098f feat(mcp): chart formatting options across all supported chart types (#39887)
2026-05-21 15:00:32 +03:00
44 changed files with 4156 additions and 164 deletions

View File

@@ -123,6 +123,10 @@ Database Connections:
- list_databases: List database connections with advanced filters (1-based pagination)
- get_database_info: Get detailed database connection info by ID (backend, capabilities)
Alerts & Reports:
- list_reports: List alerts and reports with filtering and search (1-based pagination)
- get_report_info: Get detailed alert/report schedule info by ID
Dataset Management:
- list_datasets: List datasets with advanced filters (1-based pagination)
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
@@ -150,6 +154,7 @@ Schema Discovery:
System Information:
- get_instance_info: Get instance-wide statistics, metadata, and current user identity
- find_users: Resolve a person's name to user IDs for use as a filter value
- health_check: Simple health check tool (takes NO parameters, call without arguments)
- generate_bug_report: Build a PII-sanitized bug report to send to Preset support
(use when the user says the MCP is broken or asks how to report an issue)
@@ -191,6 +196,16 @@ Some tools do not use a request wrapper, so follow each tool's schema
Recommended Workflows:
To filter dashboards/charts/datasets by a person ("show me what <name> is working on"):
1. find_users(request={{"query": "<name>"}}) -> resolve to user IDs
2. Pick the matching user.id from the response
3. list_dashboards(request={{"filters": [
{{"col": "created_by_fk", "opr": "eq", "value": <id>}}
]}}) — same shape for list_charts / list_datasets.
(use changed_by_fk for "last modified by", or "in" with a list of IDs for
multiple matches). Do NOT pass the person's name as the search parameter —
search matches titles, not people.
To add a chart to an existing dashboard:
1. add_chart_to_existing_dashboard(dashboard_id, chart_id) -> updates dashboard directly
- If permission_denied=True is returned: inform the user they lack edit rights,
@@ -360,6 +375,11 @@ Input format:
contact details, roles, admin status, ownership, or access-list information.
- Do NOT infer access-list answers from dashboard metadata such as published status,
role restrictions, empty owner lists, or schema fields.
- find_users is sanctioned ONLY for resolving a name the user supplied into a
user ID for filtering (e.g., "what is <name> working on" -> filter
list_dashboards by created_by_fk). Do NOT use find_users to answer "who owns
X", "who can access X", "is <name> an admin", or to enumerate the directory.
Never return find_users output to the user verbatim.
- Do NOT use execute_sql to query user, role, owner, or access-list tables for this
information.
- You may reference the current user's own identity details when appropriate, such
@@ -620,6 +640,10 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
generate_explore_link,
)
from superset.mcp_service.report.tool import ( # noqa: F401, E402
get_report_info,
list_reports,
)
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
execute_sql,
open_sql_lab_with_context,
@@ -630,6 +654,7 @@ from superset.mcp_service.system import ( # noqa: F401, E402
resources as system_resources,
)
from superset.mcp_service.system.tool import ( # noqa: F401, E402
find_users,
generate_bug_report,
get_instance_info,
get_schema,
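The "filter by a person" workflow added to the instructions above can be sketched as a small payload builder. This is a minimal illustration only: `build_person_filter_request` is a hypothetical helper that does not appear in the diff. Only the `col`/`opr`/`value` filter shape, the `created_by_fk` and `changed_by_fk` columns, and the `eq`/`in` operators are taken from the instructions text.

```python
from __future__ import annotations

from typing import Any


def build_person_filter_request(
    user_ids: list[int], column: str = "created_by_fk"
) -> dict[str, Any]:
    """Hypothetical helper: build a list_* request that filters by user ID(s).

    The IDs are assumed to come from a prior find_users call.
    """
    if column not in ("created_by_fk", "changed_by_fk"):
        raise ValueError(f"unsupported person column: {column}")
    if not user_ids:
        raise ValueError("at least one user ID is required")
    if len(user_ids) == 1:
        # A single match uses the "eq" operator with a scalar value.
        flt = {"col": column, "opr": "eq", "value": user_ids[0]}
    else:
        # Several matches use "in" with a list of IDs.
        flt = {"col": column, "opr": "in", "value": list(user_ids)}
    return {"filters": [flt]}
```

The resulting dictionary would be passed as the `request` argument of list_dashboards, list_charts, or list_datasets; the person's name itself is never sent as a search term.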

View File

@@ -32,6 +32,7 @@ from urllib.parse import parse_qs, urlparse
from superset.constants import EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS
if TYPE_CHECKING:
from superset.mcp_service.chart.schemas import AppliedDashboardFilter
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
@@ -44,20 +45,33 @@ QUERY_CONTEXT_EXTRA_FORM_DATA_OVERRIDE_KEYS = {
}
def find_chart_by_identifier(identifier: int | str) -> Slice | None:
class ChartNotOnDashboardError(ValueError):
"""Raised when a chart is not part of the given dashboard's slices."""
def find_chart_by_identifier(
identifier: int | str,
query_options: list[Any] | None = None,
) -> Slice | None:
"""Find a chart by numeric ID or UUID string.
Accepts an integer ID, a string that looks like a digit (e.g. "123"),
or a UUID string. Returns the Slice model instance or None.
``query_options`` is forwarded to the DAO so callers can eager-load
relationships needed after the request-scoped session is detached.
"""
from superset.daos.chart import ChartDAO # avoid circular import
extra: dict[str, Any] = (
{"query_options": query_options} if query_options is not None else {}
)
if isinstance(identifier, int) or (
isinstance(identifier, str) and identifier.isdigit()
):
chart_id = int(identifier) if isinstance(identifier, str) else identifier
return ChartDAO.find_by_id(chart_id)
return ChartDAO.find_by_id(identifier, id_column="uuid")
return ChartDAO.find_by_id(chart_id, **extra)
return ChartDAO.find_by_id(identifier, id_column="uuid", **extra)
def get_cached_form_data(form_data_key: str) -> str | None:
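The identifier dispatch in the updated find_chart_by_identifier can be illustrated without the DAO. The sketch below is an assumption-laden stand-in: `resolve_lookup` is a hypothetical name, and instead of calling ChartDAO.find_by_id it returns the value and keyword arguments that the call would receive.

```python
from __future__ import annotations

from typing import Any


def resolve_lookup(
    identifier: int | str, query_options: list[Any] | None = None
) -> tuple[int | str, dict[str, Any]]:
    """Hypothetical stand-in: return (lookup value, kwargs for find_by_id)."""
    # query_options is only forwarded when the caller actually supplied it.
    extra: dict[str, Any] = (
        {"query_options": query_options} if query_options is not None else {}
    )
    if isinstance(identifier, int) or (
        isinstance(identifier, str) and identifier.isdigit()
    ):
        # Integers and digit-only strings are treated as numeric primary keys.
        return int(identifier), extra
    # Anything else is looked up by the uuid column.
    return identifier, {"id_column": "uuid", **extra}
```

Building the `extra` dictionary conditionally keeps existing call sites unchanged: a caller that passes no `query_options` produces exactly the same DAO call as before the change.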
@@ -528,3 +542,118 @@ def extract_form_data_key_from_url(url: str | None) -> str | None:
parsed = urlparse(url)
values = parse_qs(parsed.query).get("form_data_key", [])
return values[0] if values else None


def _match_adhoc_by_subject(
    adhoc_filters: Any, column: str | None
) -> tuple[str | None, Any] | None:
    if not column or not isinstance(adhoc_filters, list):
        return None
    for af in adhoc_filters:
        if isinstance(af, dict) and af.get("subject") == column:
            return af.get("operator"), af.get("comparator")
    return None


def _match_legacy_by_col(
    legacy_filters: Any, column: str | None
) -> tuple[str | None, Any] | None:
    if not column or not isinstance(legacy_filters, list):
        return None
    for f in legacy_filters:
        if isinstance(f, dict) and f.get("col") == column:
            return f.get("op"), f.get("val")
    return None


def _resolve_filter_operator_and_value(
    extra_form_data: dict[str, Any] | None,
    column: str | None,
) -> tuple[str | None, Any]:
    """Pull operator and value for a dashboard filter from its
    default extra_form_data, matching on target column where applicable."""
    if not extra_form_data:
        return None, None
    if match := _match_adhoc_by_subject(extra_form_data.get("adhoc_filters"), column):
        return match
    if match := _match_legacy_by_col(extra_form_data.get("filters"), column):
        return match
    # Temporal filters contribute time_range with no target column
    if time_range := extra_form_data.get("time_range"):
        return "TIME_RANGE", time_range
    return None, None


def build_applied_dashboard_filters(
    dashboard_id: int, chart_id: int
) -> list[AppliedDashboardFilter]:
    """Resolve dashboard-level native filters in scope for a chart.

    Validates that the dashboard exists, the caller has access, and the chart
    is on the dashboard. Returns one AppliedDashboardFilter per non-DIVIDER
    native filter whose scope includes the chart, populated with the filter's
    default operator and value.

    Raises DashboardNotFoundError if the dashboard is missing,
    ChartNotOnDashboardError if the chart is not on it, and
    SupersetSecurityException if the caller cannot access the dashboard.
    """
    # Local imports avoid circular deps at module load
    from superset import db, security_manager
    from superset.charts.data.dashboard_filter_context import (
        _extract_filter_extra_form_data,
        _get_filter_target_column,
        _is_filter_in_scope_for_chart,
    )
    from superset.commands.dashboard.exceptions import DashboardNotFoundError
    from superset.mcp_service.chart.schemas import AppliedDashboardFilter
    from superset.models.dashboard import Dashboard
    from superset.utils import json

    dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
    if not dashboard:
        raise DashboardNotFoundError(dashboard_id=str(dashboard_id))
    security_manager.raise_for_access(dashboard=dashboard)

    slice_ids = {slc.id for slc in dashboard.slices}
    if chart_id not in slice_ids:
        raise ChartNotOnDashboardError(
            f"Chart {chart_id} is not on dashboard {dashboard_id}"
        )

    metadata = json.loads(dashboard.json_metadata or "{}")
    native_filter_config = metadata.get("native_filter_configuration", [])
    if not isinstance(native_filter_config, list):
        return []

    position_json = json.loads(dashboard.position_json or "{}")
    if not isinstance(position_json, dict):
        position_json = {}

    applied: list[AppliedDashboardFilter] = []
    for flt in native_filter_config:
        if not isinstance(flt, dict):
            continue
        if flt.get("type", "") == "DIVIDER":
            continue
        if not _is_filter_in_scope_for_chart(flt, chart_id, position_json):
            continue
        extra_form_data, status = _extract_filter_extra_form_data(flt)
        column = _get_filter_target_column(flt)
        operator, value = _resolve_filter_operator_and_value(extra_form_data, column)
        applied.append(
            AppliedDashboardFilter(
                id=flt.get("id"),
                name=flt.get("name"),
                filter_type=flt.get("filterType"),
                column=column,
                operator=operator,
                value=value,
                status=status.value,
            )
        )
    return applied
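The precedence used when resolving a filter's operator and value (adhoc filters first, then legacy filters, then a bare time_range) may be easier to follow with sample inputs. The following is a condensed, standalone restatement of the three helpers from the hunk above; the name `resolve_operator_and_value` and the sample payloads are illustrative assumptions, not part of the change.

```python
from __future__ import annotations

from typing import Any


def resolve_operator_and_value(
    extra_form_data: dict[str, Any] | None, column: str | None
) -> tuple[str | None, Any]:
    """Condensed restatement of the lookup order shown in the diff."""
    if not extra_form_data:
        return None, None
    adhoc = extra_form_data.get("adhoc_filters")
    if column and isinstance(adhoc, list):
        for af in adhoc:
            # Adhoc filters are matched on their "subject" key.
            if isinstance(af, dict) and af.get("subject") == column:
                return af.get("operator"), af.get("comparator")
    legacy = extra_form_data.get("filters")
    if column and isinstance(legacy, list):
        for f in legacy:
            # Legacy filters are matched on their "col" key.
            if isinstance(f, dict) and f.get("col") == column:
                return f.get("op"), f.get("val")
    # Temporal filters carry only a time_range and no target column.
    if time_range := extra_form_data.get("time_range"):
        return "TIME_RANGE", time_range
    return None, None


# Illustrative payloads (assumed shapes, modelled on the keys read above).
legacy_only = {"filters": [{"col": "country", "op": "IN", "val": ["FR"]}]}
temporal_only = {"time_range": "Last week"}
```

With these payloads, `resolve_operator_and_value(legacy_only, "country")` yields the legacy operator and value, while `resolve_operator_and_value(temporal_only, None)` falls through to the time_range branch.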

View File

@@ -32,6 +32,7 @@ from superset.mcp_service.chart.schemas import (
ChartCapabilities,
ChartSemantics,
ColumnRef,
CurrencyFormat,
FilterConfig,
HandlebarsChartConfig,
MixedTimeseriesChartConfig,
@@ -477,6 +478,7 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]:
]
form_data["row_limit"] = config.row_limit
add_color_scheme(form_data, config.color_scheme)
return form_data
@@ -547,7 +549,35 @@ def add_legend_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
if not config.legend.show:
form_data["show_legend"] = False
if config.legend.position:
form_data["legend_orientation"] = config.legend.position
# Canonical form_data key is camelCase; the echarts plugins read
# `legendOrientation` directly off form_data.
form_data["legendOrientation"] = config.legend.position


def add_color_scheme(form_data: Dict[str, Any], color_scheme: str | None) -> None:
    """Add color scheme to form_data when set."""
    if color_scheme:
        form_data["color_scheme"] = color_scheme


def add_currency_format(
    form_data: Dict[str, Any],
    currency_format: CurrencyFormat | None,
    key: str = "currency_format",
) -> None:
    """Add currency format to form_data under the given key when set."""
    if currency_format:
        form_data[key] = currency_format.to_form_data()


def add_xy_data_label_options(
    form_data: Dict[str, Any], config: XYChartConfig, x_is_temporal: bool
) -> None:
    """Apply XY-specific data-label and time-format options when set."""
    if config.x_axis_time_format and x_is_temporal:
        form_data["x_axis_time_format"] = config.x_axis_time_format
    if config.show_value:
        form_data["show_value"] = True


def add_orientation_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
@@ -722,6 +752,9 @@ def map_xy_config(
add_axis_config(form_data, config)
add_legend_config(form_data, config)
add_orientation_config(form_data, config)
add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
add_xy_data_label_options(form_data, config, x_is_temporal)
return form_data
@@ -734,11 +767,13 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
"viz_type": "pie",
"groupby": [config.dimension.name],
"metric": metric,
"color_scheme": "supersetColors",
"color_scheme": config.color_scheme or "supersetColors",
"show_labels": config.show_labels,
"show_legend": config.show_legend,
"legendOrientation": config.legend_orientation,
"label_type": config.label_type,
"number_format": config.number_format,
"date_format": config.date_format,
"sort_by_metric": config.sort_by_metric,
"row_limit": config.row_limit,
"donut": config.donut,
@@ -746,9 +781,9 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
"labels_outside": config.labels_outside,
"outerRadius": config.outer_radius,
"innerRadius": config.inner_radius,
"date_format": "smart_date",
}
add_currency_format(form_data, config.currency_format)
_add_adhoc_filters(form_data, config.filters)
return form_data
@@ -774,6 +809,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
if config.y_axis_format:
form_data["y_axis_format"] = config.y_axis_format
add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
# Trendline-specific fields
if viz_type == "big_number":
# Big Number with trendline uses granularity_sqla for the temporal column
@@ -789,6 +827,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
if config.compare_lag is not None:
form_data["compare_lag"] = config.compare_lag
if config.time_format:
form_data["time_format"] = config.time_format
_add_adhoc_filters(form_data, config.filters)
return form_data
@@ -860,6 +901,10 @@ def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]:
"row_limit": config.row_limit,
}
if config.date_format:
form_data["date_format"] = config.date_format
add_currency_format(form_data, config.currency_format)
_add_adhoc_filters(form_data, config.filters)
return form_data
@@ -939,10 +984,20 @@ def map_mixed_timeseries_config(
"yAxisIndexB": 1,
# Display
"show_legend": config.show_legend,
"legendOrientation": config.legend_orientation,
"zoomable": True,
"rich_tooltip": True,
}
if config.show_value:
form_data["show_value"] = True
add_color_scheme(form_data, config.color_scheme)
add_currency_format(form_data, config.currency_format)
add_currency_format(
form_data, config.currency_format_secondary, key="currency_format_secondary"
)
# Configure temporal handling
configure_temporal_handling(form_data, x_is_temporal, config.time_grain)
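How the primary and secondary currency formats end up in form_data can be sketched with a plain dataclass. This is an assumption-based illustration: `CurrencyFormatSketch` is a hypothetical stand-in for the pydantic CurrencyFormat model (so the example runs without pydantic), and the `add_currency_format` body is restated from the diff.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class CurrencyFormatSketch:
    """Hypothetical stand-in for the pydantic CurrencyFormat model."""

    symbol: str
    symbol_position: str = "prefix"

    def to_form_data(self) -> dict[str, str]:
        # form_data uses the camelCase key symbolPosition.
        return {"symbol": self.symbol, "symbolPosition": self.symbol_position}


def add_currency_format(
    form_data: dict[str, Any],
    currency_format: CurrencyFormatSketch | None,
    key: str = "currency_format",
) -> None:
    """Write the currency format under the given key, only when it is set."""
    if currency_format:
        form_data[key] = currency_format.to_form_data()


form_data: dict[str, Any] = {}
add_currency_format(form_data, CurrencyFormatSketch("USD"))
add_currency_format(
    form_data,
    CurrencyFormatSketch("EUR", "suffix"),
    key="currency_format_secondary",
)
add_currency_format(form_data, None, key="unused")  # leaves form_data untouched
```

Passing the key as a parameter lets the mixed-timeseries mapper reuse one helper for both axes instead of duplicating the None check.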

View File

@@ -23,7 +23,7 @@ from __future__ import annotations
import difflib
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, Protocol
from typing import Annotated, Any, cast, Dict, List, Literal, Protocol
import humanize
from pydantic import (
@@ -146,7 +146,7 @@ class ChartInfo(BaseModel):
),
)
form_data_key: str | None = Field(
None,
default=None,
description=(
"Cache key used to retrieve unsaved form_data. When present, indicates "
"the form_data came from cache (unsaved edits) rather than the saved chart."
@@ -279,6 +279,15 @@ class GetChartInfoRequest(BaseModel):
"Can be used alone (without identifier) for unsaved charts."
),
)
dashboard_id: int | None = Field(
default=None,
description=(
"When provided, resolves dashboard-level native filters that are in "
"scope for this chart on the given dashboard and returns them under "
"filters.dashboard_filters. Requires the chart to be on the dashboard "
"and the caller to have dashboard access."
),
)
@model_validator(mode="after")
def validate_identifier_or_form_data_key(self) -> "GetChartInfoRequest":
@@ -523,17 +532,18 @@ class ChartFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"slice_name",
"viz_type",
"datasource_name",
"created_by_fk",
"changed_by_fk",
] = Field(
...,
description=(
"Column to filter on. Valid values: 'slice_name', 'viz_type', "
"'datasource_name'. Other column names are not valid filter columns "
"and will cause a validation error."
),
description="Column to filter on. Use get_schema(model_type='chart') for "
"available filter columns. To filter by a person, first call find_users "
"to resolve a name to a user ID, then filter by created_by_fk or "
"changed_by_fk with that integer ID.",
)
opr: ColumnOperatorEnum = Field(
...,
@@ -731,6 +741,29 @@ class LegendConfig(BaseModel):
position: Literal["top", "bottom", "left", "right"] | None = "right"
class CurrencyFormat(BaseModel):
"""Currency symbol and placement applied to numeric values."""
model_config = ConfigDict(populate_by_name=True)
symbol: str = Field(
...,
description="Currency code or symbol (e.g. 'USD', 'EUR', '$', '€')",
max_length=20,
)
symbol_position: Literal["prefix", "suffix"] = Field(
"prefix",
description="Whether to render the symbol before or after the value",
validation_alias=AliasChoices("symbol_position", "symbolPosition"),
)
def to_form_data(self) -> Dict[str, str]:
return {"symbol": self.symbol, "symbolPosition": self.symbol_position}
LEGEND_POSITION_LITERAL = Literal["top", "bottom", "left", "right"]
class FilterConfig(BaseModel):
model_config = ConfigDict(populate_by_name=True)
@@ -857,6 +890,27 @@ class PieChartConfig(UnknownFieldCheckMixin):
)
row_limit: int = Field(100, description="Max slices", ge=1, le=10000)
number_format: str = Field("SMART_NUMBER", max_length=50)
date_format: str = Field(
"smart_date",
description="Date format for date dimension labels (e.g. 'smart_date', "
"'%Y-%m-%d')",
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to the metric value",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
"'googleCategory10c', 'd3Category10'). Defaults to 'supersetColors'."
),
max_length=100,
)
legend_orientation: LEGEND_POSITION_LITERAL = Field(
"top", description="Legend placement around the chart"
)
show_total: bool = Field(False, description="Show total in center")
labels_outside: bool = True
outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1, le=100)
@@ -908,6 +962,15 @@ class PivotTableChartConfig(UnknownFieldCheckMixin):
)
row_limit: int = Field(10000, description="Max cells", ge=1, le=50000)
value_format: str = Field("SMART_NUMBER", max_length=50)
date_format: str | None = Field(
None,
description="Date format for date columns (e.g. 'smart_date', '%Y-%m-%d')",
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to numeric metric values",
)
class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
@@ -954,9 +1017,29 @@ class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
)
# Display options
show_legend: bool = True
legend_orientation: LEGEND_POSITION_LITERAL = Field(
"top", description="Legend placement around the chart"
)
show_value: bool = Field(False, description="Show data labels on each data point")
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
y_axis_secondary: AxisConfig | None = None
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors'). "
"When omitted, Superset's default scheme is used."
),
max_length=100,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to primary metric values",
)
currency_format_secondary: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to secondary metric values",
)
filters: List[FilterConfig] | None = Field(
None,
description="Structured filters (column/op/value). "
@@ -1132,6 +1215,27 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
),
max_length=50,
)
time_format: str | None = Field(
None,
description=(
"Date format string for trendline x-axis labels "
"(e.g. 'smart_date', '%Y-%m-%d'). Only applies when "
"show_trendline=True."
),
max_length=50,
)
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to the metric value",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID for the trendline (e.g. 'supersetColors'). "
"When omitted, Superset's default scheme is used."
),
max_length=100,
)
start_y_axis_at_zero: bool = Field(
True,
description="Anchor trendline y-axis at zero",
@@ -1217,6 +1321,14 @@ class TableChartConfig(UnknownFieldCheckMixin):
validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"),
)
row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID applied to conditional/cell formatting "
"(e.g. 'supersetColors')."
),
max_length=100,
)
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "TableChartConfig":
@@ -1298,6 +1410,28 @@ class XYChartConfig(UnknownFieldCheckMixin):
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
legend: LegendConfig | None = None
x_axis_time_format: str | None = Field(
None,
description=(
"Date format for temporal x-axis labels (e.g. 'smart_date', "
"'%Y-%m-%d'). Only applies when the x-axis column is temporal."
),
max_length=50,
)
show_value: bool = Field(False, description="Show data labels on each data point")
currency_format: CurrencyFormat | None = Field(
None,
description="Currency symbol applied to metric values",
)
color_scheme: str | None = Field(
None,
description=(
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
"'googleCategory10c', 'd3Category10'). When omitted, Superset's "
"default scheme is used."
),
max_length=100,
)
filters: List[FilterConfig] | None = Field(
None,
description="Structured filters (column/op/value). "
@@ -1414,7 +1548,10 @@ class ListChartsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl):
"""
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
return parse_json_or_model_list(v, ChartFilter, "filters")
return cast(
List[ChartFilter],
parse_json_or_model_list(v, ChartFilter, "filters"),
)
@field_validator("select_columns", mode="before")
@classmethod
@@ -1575,7 +1712,14 @@ class GenerateChartRequest(QueryCacheControl):
class GenerateExploreLinkRequest(FormDataCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration")
config: ChartConfig | None = Field(
None,
description=(
"Chart configuration. Optional; omit to get a default "
"explore URL that opens the dataset in Superset without a "
"preconfigured chart."
),
)
class UpdateChartRequest(QueryCacheControl):
@@ -2064,6 +2208,38 @@ class AdhocFilter(BaseModel):
model_config = ConfigDict(extra="ignore")
class AppliedDashboardFilter(BaseModel):
"""A dashboard-level native filter resolved against a specific chart.
Returned when get_chart_info is called with a dashboard_id. Values come
from the filter's default state on the saved dashboard (not a permalink).
"""
id: str | None = Field(None, description="Native filter ID")
name: str | None = Field(None, description="Filter display name")
filter_type: str | None = Field(
None, description="Native filter type (e.g. filter_select, filter_range)"
)
column: str | None = Field(None, description="Target column the filter applies to")
operator: str | None = Field(
None,
description=(
"Filter operator as stored in extra_form_data (e.g. 'IN', '==', 'LIKE', "
"or 'TIME_RANGE' for temporal filters with no target column)"
),
)
value: Any | None = Field(
None, description="Filter value(s) from the default data mask"
)
status: str = Field(
...,
description=(
"Whether the filter contributes to the chart query: 'applied', "
"'not_applied', or 'not_applied_uses_default_to_first_item_prequery'"
),
)
class ChartFiltersInfo(BaseModel):
"""Structured representation of all filters applied to a chart."""
@@ -2105,6 +2281,15 @@ class ChartFiltersInfo(BaseModel):
None,
description="Custom HAVING clause applied to the chart query",
)
dashboard_filters: List[AppliedDashboardFilter] = Field(
default_factory=list,
description=(
"Dashboard-level native filters in scope for this chart on the "
"dashboard passed via get_chart_info's dashboard_id argument. Empty "
"when no dashboard_id was provided or no native filter targets this "
"chart."
),
)
# Rebuild ChartInfo so Pydantic can resolve the ChartFiltersInfo forward reference.

View File

@@ -26,6 +26,7 @@ from typing import Any, Dict, List, TYPE_CHECKING
from fastmcp import Context
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
if TYPE_CHECKING:
@@ -58,9 +59,177 @@ from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.utils.core import GenericDataType
logger = logging.getLogger(__name__)
_GENERIC_TYPE_MAP: dict[int, str] = {
GenericDataType.NUMERIC: "numeric",
GenericDataType.STRING: "string",
GenericDataType.TEMPORAL: "temporal",
GenericDataType.BOOLEAN: "boolean",
}
# Maps Superset viz_type strings to canonical categories so we can
# avoid recommending a chart type the user already has.
_VIZ_CATEGORY: dict[str, str] = {
"echarts_timeseries_line": "line",
"echarts_timeseries_smooth": "line",
"echarts_timeseries_step": "line",
"echarts_timeseries": "line",
"echarts_timeseries_bar": "bar",
"echarts_area": "area",
"echarts_timeseries_scatter": "scatter",
"mixed_timeseries": "line",
"table": "table",
"pie": "pie",
"big_number": "kpi",
"big_number_total": "kpi",
"pop_kpi": "kpi",
"dist_bar": "bar",
"line": "line",
"area": "area",
"scatter": "scatter",
"bubble": "bubble",
"treemap_v2": "treemap",
"sunburst_v2": "treemap",
"heatmap_v2": "heatmap",
"gauge_chart": "gauge",
"funnel": "funnel",
"histogram": "histogram",
"histogram_v2": "histogram",
"box_plot": "box_plot",
"world_map": "map",
"pivot_table_v2": "table",
}
_MAX_RECOMMENDATIONS = 4
def _recommend_visualizations(
viz_type: str,
columns: list[DataColumn],
row_count: int,
) -> list[str]:
"""Suggest visualization types based on column types,
cardinality, and the chart's current viz_type.
"""
if not columns:
return ["table"]
current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
candidates = _build_candidates(columns, row_count)
if not candidates:
candidates = ["table", "bar chart"]
return _filter_candidates(candidates, current_category)
def _build_candidates(
columns: list[DataColumn],
row_count: int,
) -> list[str]:
"""Build candidate visualization list from column metadata."""
temporal = [c for c in columns if c.data_type == "temporal"]
numeric = [c for c in columns if c.data_type == "numeric"]
categorical = [c for c in columns if c.data_type in ("string", "boolean")]
if temporal and numeric:
return _candidates_temporal_numeric(numeric, row_count)
if categorical and numeric:
return _candidates_categorical_numeric(numeric, categorical)
if len(numeric) >= 2:
return _candidates_multi_numeric(numeric, categorical)
if len(numeric) == 1 and not temporal and not categorical:
return _candidates_single_numeric(numeric[0], row_count)
return []
def _candidates_temporal_numeric(
numeric: list[DataColumn], row_count: int
) -> list[str]:
# Few data points are better as a bar chart than a line
if row_count < 5:
candidates = ["bar chart", "table"]
else:
candidates = ["line chart", "area chart", "bar chart"]
if len(numeric) > 1:
candidates.append("multi-line chart")
return candidates
def _candidates_categorical_numeric(
numeric: list[DataColumn],
categorical: list[DataColumn],
) -> list[str]:
candidates = ["bar chart"]
if len(numeric) == 1 and categorical[0].unique_count <= 10:
candidates.append("pie chart")
if len(numeric) >= 2:
candidates.append("scatter plot")
candidates.append("heatmap")
if any(c.unique_count > 5 for c in categorical):
candidates.append("treemap")
return candidates
def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
candidates = ["big number / KPI", "gauge chart"]
if row_count > 20 and col.unique_count > 10:
candidates.insert(0, "histogram")
return candidates
def _candidates_multi_numeric(
numeric: list[DataColumn],
categorical: list[DataColumn],
) -> list[str]:
candidates = ["scatter plot"]
if len(numeric) >= 3:
candidates.append("bubble chart")
if categorical:
candidates.append("heatmap")
return candidates
# Maps each candidate string to a canonical category for dedup
# against the current viz_type.
_CANDIDATE_CATEGORY: dict[str, str] = {
"line chart": "line",
"multi-line chart": "line",
"area chart": "area",
"bar chart": "bar",
"scatter plot": "scatter",
"bubble chart": "bubble",
"pie chart": "pie",
"treemap": "treemap",
"heatmap": "heatmap",
"big number / KPI": "kpi",
"gauge chart": "gauge",
"histogram": "histogram",
"table": "table",
}
def _filter_candidates(
candidates: list[str],
current_category: str,
) -> list[str]:
"""Deduplicate, exclude the current viz category, and cap."""
seen: set[str] = set()
result: list[str] = []
for c in candidates:
if c in seen:
continue
if _CANDIDATE_CATEGORY.get(c) == current_category:
continue
seen.add(c)
result.append(c)
if len(result) >= _MAX_RECOMMENDATIONS:
break
return result
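Review note: the dedup, exclude, and cap behaviour of `_filter_candidates` can be checked in isolation with a minimal sketch like the one below. The value of `_MAX_RECOMMENDATIONS` is not visible in this diff, so the cap of 3 is an assumption, and the category map is trimmed to a few entries for brevity.

```python
# Standalone sketch of the dedup / exclude / cap logic in _filter_candidates.
# ASSUMPTION: the real _MAX_RECOMMENDATIONS value is not shown in this diff;
# 3 is used here purely for illustration.
MAX_RECOMMENDATIONS = 3

# Trimmed copy of the candidate -> category mapping.
CANDIDATE_CATEGORY = {
    "line chart": "line",
    "multi-line chart": "line",
    "area chart": "area",
    "bar chart": "bar",
    "table": "table",
}


def filter_candidates(candidates: list[str], current_category: str) -> list[str]:
    """Drop duplicates and same-category candidates, then cap the result."""
    seen: set[str] = set()
    result: list[str] = []
    for c in candidates:
        if c in seen:
            continue
        # Skip anything in the same category as the chart's current viz type.
        if CANDIDATE_CATEGORY.get(c) == current_category:
            continue
        seen.add(c)
        result.append(c)
        if len(result) >= MAX_RECOMMENDATIONS:
            break
    return result


suggestions = filter_candidates(
    ["line chart", "area chart", "bar chart", "bar chart", "multi-line chart", "table"],
    current_category="line",
)
print(suggestions)
```

Skipped candidates do not count toward the cap, so a chart that is already a line chart still receives a full list of alternatives when enough other candidates exist.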
def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
"""Wrap chart data read-path descriptive fields before LLM exposure."""
@@ -182,7 +351,18 @@ async def get_chart_data( # noqa: C901
# Build query context entirely from cached form_data
return await _query_from_form_data(cached_form_data_dict, request, ctx)
# Find the chart by identifier
# Find the chart by identifier.
# Eagerly load the dataset's metrics relationship so Excel export
# (which may run after the request-scoped session is detached) can
# access dataset.metrics without triggering a lazy load. See
# apache/superset#39206 for the analogous database eager-load fix.
from superset.connectors.sqla.models import SqlaTable
from superset.models.slice import Slice
chart_query_options = [
subqueryload(Slice.table).subqueryload(SqlaTable.metrics),
]
with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"):
await ctx.debug("Looking up chart: identifier=%s" % (request.identifier,))
if request.identifier is None:
@@ -190,7 +370,9 @@ async def get_chart_data( # noqa: C901
error="Chart identifier is required",
error_type="ValidationError",
)
chart = find_chart_by_identifier(request.identifier)
chart = find_chart_by_identifier(
request.identifier, query_options=chart_query_options
)
if not chart:
await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
@@ -484,8 +666,9 @@ async def get_chart_data( # noqa: C901
)
# Create rich column metadata
coltypes = query_result.get("coltypes", [])
columns = []
for col_name in raw_columns:
for idx, col_name in enumerate(raw_columns):
# Sample some values for metadata
sample_values = [
row.get(col_name)
@@ -493,13 +676,16 @@ async def get_chart_data( # noqa: C901
if row.get(col_name) is not None
]
# Infer data type
# Use SQL-derived GenericDataType when available,
# fall back to Python isinstance heuristic
data_type = "string"
if sample_values:
if all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
elif all(isinstance(v, bool) for v in sample_values):
if coltypes:
data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
elif sample_values:
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
columns.append(
DataColumn(
@@ -542,13 +728,11 @@ async def get_chart_data( # noqa: C901
else:
insights.append("Fresh data retrieved from database")
recommended_visualizations = []
if any(
"time" in col.lower() or "date" in col.lower() for col in raw_columns
):
recommended_visualizations.extend(["line chart", "time series"])
if len(raw_columns) <= 3:
recommended_visualizations.extend(["bar chart", "scatter plot"])
recommended_visualizations = _recommend_visualizations(
viz_type=chart.viz_type or "unknown",
columns=columns,
row_count=len(data),
)
# Performance metadata with cache awareness
execution_time = int((time.time() - start_time) * 1000)

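Review note: the reordered `isinstance` checks in the type-inference fallback matter because `bool` is a subclass of `int` in Python, so with the numeric test first the boolean branch could never be reached. The sketch below reproduces only the fallback path; the `coltypes` lookup through `_GENERIC_TYPE_MAP` is left out because its contents are not shown in this diff, and `infer_data_type` is a hypothetical name.

```python
def infer_data_type(sample_values: list[object]) -> str:
    """Heuristic fallback for when no SQL-derived column type is available."""
    if not sample_values:
        return "string"
    # bool must be tested first: isinstance(True, int) is True in Python,
    # so a numeric-first check would classify boolean columns as numeric.
    if all(isinstance(v, bool) for v in sample_values):
        return "boolean"
    if all(isinstance(v, (int, float)) for v in sample_values):
        return "numeric"
    return "string"


print(infer_data_type([True, False]))
print(infer_data_type([1, 2.5]))
print(infer_data_type(["a", 1]))
```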

@@ -25,12 +25,19 @@ from fastmcp import Context
from sqlalchemy.orm import subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import get_cached_form_data
from superset.mcp_service.chart.chart_helpers import (
build_applied_dashboard_filters,
ChartNotOnDashboardError,
get_cached_form_data,
)
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
CHART_FORM_DATA_EXCLUDED_FIELD_NAMES,
ChartError,
ChartFiltersInfo,
ChartInfo,
extract_filters_from_form_data,
GetChartInfoRequest,
@@ -89,6 +96,66 @@ FORM_DATA_OVERRIDE_EXCLUDED_FIELD_NAMES = (
)
async def _validate_chart_dataset_access(
result: ChartInfo, ctx: Context
) -> ChartError | None:
"""Validate that the chart's dataset is accessible to the current user.
Returns a ChartError if the dataset is not accessible, otherwise None.
Logs any non-fatal warnings (e.g., virtual dataset warnings) via ctx.
"""
from superset.daos.chart import ChartDAO
if not result.id:
return None
chart = ChartDAO.find_by_id(result.id)
if not chart:
return None
validation_result = validate_chart_dataset(chart, check_access=True)
if not validation_result.is_valid:
await ctx.warning(
"Chart found but dataset is not accessible: %s" % (validation_result.error,)
)
return ChartError(
error=validation_result.error or "Chart's dataset is not accessible",
error_type="DatasetNotAccessible",
)
for warning in validation_result.warnings:
await ctx.warning("Dataset warning: %s" % (warning,))
return None
async def _attach_dashboard_filters(
result: ChartInfo, dashboard_id: int, ctx: Context
) -> ChartError | None:
"""Resolve dashboard-scoped native filters and attach them to result.filters.
Returns a ChartError to surface to the caller on validation / access
failures, or None on success (including the no-filters case).
"""
if not result.id:
return None
with event_logger.log_context(action="mcp.get_chart_info.dashboard_filters"):
try:
dashboard_filters = build_applied_dashboard_filters(dashboard_id, result.id)
except DashboardNotFoundError as exc:
await ctx.warning("Dashboard not found: %s" % (str(exc),))
return ChartError(error=str(exc), error_type="DashboardNotFound")
except ChartNotOnDashboardError as exc:
await ctx.warning("Chart not on dashboard: %s" % (str(exc),))
return ChartError(error=str(exc), error_type="ChartNotOnDashboard")
except SupersetSecurityException as exc:
await ctx.warning("Dashboard not accessible: %s" % (str(exc),))
return ChartError(error=str(exc), error_type="DashboardNotAccessible")
if dashboard_filters:
if result.filters is None:
result.filters = ChartFiltersInfo(dashboard_filters=dashboard_filters)
else:
result.filters.dashboard_filters = dashboard_filters
return None
def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None:
"""Override a ChartInfo's form_data with cached unsaved state."""
from superset.utils import json as utils_json
@@ -178,6 +245,17 @@ async def get_chart_info(
}
```
With dashboard context to resolve applied dashboard-level filters:
```json
{
"identifier": 123,
"dashboard_id": 45
}
```
When dashboard_id is provided, the response's filters.dashboard_filters
lists native filters (with column, operator, and value) that are in scope
for this chart on that dashboard.
Returns chart details including name, type, and URL.
"""
from superset.daos.chart import ChartDAO
@@ -247,23 +325,14 @@ async def get_chart_info(
)
# Validate the chart's dataset is accessible
if result.id:
chart = ChartDAO.find_by_id(result.id)
if chart:
validation_result = validate_chart_dataset(chart, check_access=True)
if not validation_result.is_valid:
await ctx.warning(
"Chart found but dataset is not accessible: %s"
% (validation_result.error,)
)
return ChartError(
error=validation_result.error
or "Chart's dataset is not accessible",
error_type="DatasetNotAccessible",
)
# Log any warnings (e.g., virtual dataset warnings)
for warning in validation_result.warnings:
await ctx.warning("Dataset warning: %s" % (warning,))
dataset_error = await _validate_chart_dataset_access(result, ctx)
if dataset_error is not None:
return dataset_error
if request.dashboard_id:
error = await _attach_dashboard_filters(result, request.dashboard_id, ctx)
if error is not None:
return error
else:
await ctx.warning("Chart retrieval failed: error=%s" % (str(result),))


@@ -104,11 +104,16 @@ async def list_charts(
list_charts(search="revenue", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``slice_name``, ``viz_type``, ``datasource_name``
``slice_name``, ``viz_type``, ``datasource_name``,
``created_by_fk``, ``changed_by_fk``
Sortable columns for ``order_column``:
``id``, ``slice_name``, ``viz_type``, ``description``,
``changed_on``, ``created_on``
To filter by a person, call find_users to resolve the name to a user ID,
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
"""
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
await ctx.info(

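Review note: the two-step flow described in the docstring (resolve a name with find_users, then filter by the returned ID) might be assembled on the caller side roughly as follows. `build_person_filter` is a hypothetical helper, and the ID 42 is a placeholder for whatever find_users returns.

```python
import json

# Columns the docstring advertises for person-based filtering.
PERSON_FILTER_COLUMNS = {"created_by_fk", "changed_by_fk"}


def build_person_filter(user_id: int, column: str = "created_by_fk") -> list[dict]:
    """Build a filters list that matches rows by an integer user ID."""
    if column not in PERSON_FILTER_COLUMNS:
        raise ValueError(f"Unsupported person filter column: {column}")
    return [{"col": column, "opr": "eq", "value": user_id}]


# 42 is a placeholder for an ID previously resolved via find_users.
request = {"filters": build_person_filter(42), "page": 1}
print(json.dumps(request))
```

The name itself never appears in the request, which keeps `search` reserved for matching titles.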

@@ -37,10 +37,12 @@ class ColumnMetadata(BaseModel):
"""Metadata for a selectable column."""
name: str = Field(..., description="Column name to use in select_columns")
description: str | None = Field(None, description="Column description")
type: str | None = Field(None, description="Data type (str, int, datetime, etc.)")
description: str | None = Field(default=None, description="Column description")
type: str | None = Field(
default=None, description="Data type (str, int, datetime, etc.)"
)
is_default: bool = Field(
False, description="Whether this column is included by default"
default=False, description="Whether this column is included by default"
)
@@ -633,3 +635,45 @@ CHART_ALL_COLUMNS: list[str] = []
DATASET_ALL_COLUMNS: list[str] = []
DASHBOARD_ALL_COLUMNS: list[str] = []
DATABASE_ALL_COLUMNS: list[str] = []
# Report (alerts & reports) configuration
REPORT_DEFAULT_COLUMNS = ["id", "name", "type", "active", "crontab"]
REPORT_SORTABLE_COLUMNS = [
"id",
"name",
"type",
"active",
"changed_on",
"created_on",
]
REPORT_SEARCH_COLUMNS = ["name", "description"]
REPORT_EXTRA_COLUMNS: dict[str, ColumnMetadata] = {
"changed_on_humanized": ColumnMetadata(
name="changed_on_humanized",
description="Humanized modification time",
type="str",
is_default=False,
),
"created_on_humanized": ColumnMetadata(
name="created_on_humanized",
description="Humanized creation time",
type="str",
is_default=False,
),
}
def get_report_columns() -> list[ColumnMetadata]:
"""Get column metadata for ReportSchedule model dynamically."""
from superset.reports.models import ReportSchedule
return get_columns_from_model(
ReportSchedule,
REPORT_DEFAULT_COLUMNS,
REPORT_EXTRA_COLUMNS,
exclude_columns=set(USER_DIRECTORY_FIELDS),
)
REPORT_ALL_COLUMNS: list[str] = []


@@ -19,7 +19,7 @@
from typing import Literal
# Supported model types for schema discovery and MCP tools
ModelType = Literal["chart", "dataset", "dashboard", "database"]
ModelType = Literal["chart", "dataset", "dashboard", "database", "report"]
# Pagination defaults
DEFAULT_PAGE_SIZE = 10 # Default number of items per page


@@ -67,7 +67,7 @@ from __future__ import annotations
import logging
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING
from typing import Annotated, Any, cast, Dict, List, Literal, TYPE_CHECKING
import humanize
from pydantic import (
@@ -169,16 +169,20 @@ class DashboardFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"dashboard_title",
"published",
"favorite",
"created_by_fk",
"changed_by_fk",
] = Field(
...,
description=(
"Column to filter on. Valid values: 'dashboard_title', 'published', "
"'favorite'. Other column names are not valid filter columns and will "
"cause a validation error."
"Column to filter on. Use "
"get_schema(model_type='dashboard') for available "
"filter columns. To filter by a person, first call find_users to "
"resolve a name to a user ID, then filter by created_by_fk or "
"changed_by_fk with that integer ID."
),
)
opr: ColumnOperatorEnum = Field(
@@ -223,7 +227,10 @@ class ListDashboardsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheContr
"""
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
return parse_json_or_model_list(v, DashboardFilter, "filters")
return cast(
List[DashboardFilter],
parse_json_or_model_list(v, DashboardFilter, "filters"),
)
@field_validator("select_columns", mode="before")
@classmethod
@@ -392,14 +399,14 @@ class DashboardInfo(BaseModel):
# Fields for permalink/filter state support
permalink_key: str | None = Field(
None,
default=None,
description=(
"Permalink key used to retrieve filter state. When present, indicates "
"the filter_state came from a permalink rather than the default dashboard."
),
)
filter_state: Dict[str, Any] | None = Field(
None,
default=None,
description=(
"Filter state from permalink. Contains dataMask (native filter values), "
"activeTabs, anchor, and urlParams. When present, represents the actual "


@@ -98,11 +98,18 @@ async def list_dashboards(
list_dashboards(search="sales", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``dashboard_title``, ``published``, ``favorite``
``dashboard_title``, ``published``, ``favorite``,
``created_by_fk``, ``changed_by_fk``
Sortable columns for ``order_column``:
``id``, ``dashboard_title``, ``slug``, ``published``,
``changed_on``, ``created_on``
To filter by a person (e.g. "dashboards Maxime is working on"), do NOT pass
the name as the search parameter — search matches titles and slugs only.
Instead, call find_users to resolve the name to a user ID, then pass it as
a filter: filters=[{"col": "created_by_fk", "opr": "eq", "value": <id>}]
(or "changed_by_fk" for "last modified by").
"""
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
await ctx.info(


@@ -22,7 +22,7 @@ Pydantic schemas for database-related responses
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal
from typing import Annotated, Any, cast, Dict, List, Literal
import humanize
from pydantic import (
@@ -58,7 +58,7 @@ class DatabaseFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"database_name",
"expose_in_sqllab",
"allow_file_upload",
@@ -242,7 +242,10 @@ class ListDatabasesRequest(CreatedByMeMixin, MetadataCacheControl):
@classmethod
def parse_filters(cls, v: Any) -> List[DatabaseFilter]:
"""Accept both JSON string and list of objects."""
return parse_json_or_model_list(v, DatabaseFilter, "filters")
return cast(
List[DatabaseFilter],
parse_json_or_model_list(v, DatabaseFilter, "filters"),
)
@field_validator("select_columns", mode="before")
@classmethod


@@ -65,17 +65,18 @@ class DatasetFilter(ColumnOperator):
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
"table_name",
"schema",
"database_name",
"created_by_fk",
"changed_by_fk",
] = Field(
...,
description=(
"Column to filter on. Valid values: 'table_name', 'schema', "
"'database_name'. Other column names (e.g. 'created_by_fk', 'id') "
"are not valid filter columns and will cause a validation error."
),
description="Column to filter on. Use get_schema(model_type='dataset') for "
"available filter columns. To filter by a person, first call find_users "
"to resolve a name to a user ID, then filter by created_by_fk or "
"changed_by_fk with that integer ID.",
)
opr: ColumnOperatorEnum = Field(
...,
@@ -658,7 +659,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
params = None
columns = [
TableColumnInfo(
column_name=getattr(col, "column_name", None),
column_name=getattr(col, "column_name", None) or "",
verbose_name=getattr(col, "verbose_name", None),
type=getattr(col, "type", None),
is_dttm=getattr(col, "is_dttm", None),
@@ -670,7 +671,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
]
metrics = [
SqlMetricInfo(
metric_name=getattr(metric, "metric_name", None),
metric_name=getattr(metric, "metric_name", None) or "",
verbose_name=getattr(metric, "verbose_name", None),
expression=getattr(metric, "expression", None),
description=getattr(metric, "description", None),


@@ -109,10 +109,15 @@ async def list_datasets(
list_datasets(search="sales", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``table_name``, ``schema``, ``database_name``
``table_name``, ``schema``, ``database_name``,
``created_by_fk``, ``changed_by_fk``
Sortable columns for ``order_column``:
``id``, ``table_name``, ``schema``, ``changed_on``, ``created_on``
To filter by a person, call find_users to resolve the name to a user ID,
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
"""
if ctx is None:
raise RuntimeError("FastMCP context is required for list_datasets")


@@ -65,10 +65,12 @@ async def generate_explore_link(
- "Visualize [data]"
- General data exploration
- When user wants to SEE data visually
- Opening a dataset in Explore without a preconfigured chart (omit config)
IMPORTANT:
- Use numeric dataset ID or UUID (NOT schema.table_name format)
- MUST include chart_type in config (either 'xy' or 'table')
- When config is provided, MUST include chart_type (e.g. 'xy' or 'table')
- Omit config entirely to return a default explore URL for the dataset
Example usage:
```json
@@ -83,6 +85,11 @@ async def generate_explore_link(
}
```
Or with no config to simply open the dataset in Explore:
```json
{"dataset_id": 123}
```
Better UX because:
- Users can interact with chart before saving
- Easy to modify parameters instantly
@@ -93,9 +100,10 @@ async def generate_explore_link(
Returns explore URL for immediate use.
"""
chart_type = request.config.chart_type if request.config else "none"
await ctx.info(
"Generating explore link for dataset_id=%s, chart_type=%s"
% (request.dataset_id, request.config.chart_type)
% (request.dataset_id, chart_type)
)
await ctx.debug(
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
@@ -103,9 +111,6 @@ async def generate_explore_link(
)
try:
# config is already a typed ChartConfig (validated by Pydantic)
config = request.config
await ctx.report_progress(1, 4, "Validating dataset exists")
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
from superset.daos.dataset import DatasetDAO
@@ -157,8 +162,32 @@ async def generate_explore_link(
),
}
# When no config is provided, return a default explore URL that opens
# the dataset in Superset without a preconfigured chart.
if request.config is None:
await ctx.report_progress(4, 4, "URL generation complete")
from superset.mcp_service.utils.url_utils import get_superset_base_url
base_url = get_superset_base_url()
default_url = (
f"{base_url}/explore/?datasource_type=table&datasource_id={dataset.id}"
)
await ctx.info(
"Default explore link generated: dataset_id=%s" % (request.dataset_id,)
)
return {
"url": default_url,
"form_data": {},
"form_data_key": None,
"chart_type_label": None,
"error": None,
}
await ctx.report_progress(2, 4, "Converting configuration to form data")
with event_logger.log_context(action="mcp.generate_explore_link.form_data"):
# config is already a typed ChartConfig (validated by Pydantic)
config = request.config
# Normalize column names to match canonical dataset column names
# This fixes case sensitivity issues (e.g., 'order_date' vs 'OrderDate')
try:
@@ -256,7 +285,7 @@ async def generate_explore_link(
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
% (
request.dataset_id,
request.config.chart_type,
chart_type,
type(e).__name__,
str(e),
)

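Review note: the new no-config branch of `generate_explore_link` reduces to building a plain Explore URL and returning an otherwise empty payload. The sketch below mirrors that shape without Superset imports; the base URL is a placeholder, and `default_explore_response` is a hypothetical name.

```python
def default_explore_response(base_url: str, dataset_id: int) -> dict:
    """Return the payload used when the request carries no chart config."""
    url = f"{base_url}/explore/?datasource_type=table&datasource_id={dataset_id}"
    return {
        "url": url,
        "form_data": {},
        "form_data_key": None,
        "chart_type_label": None,
        "error": None,
    }


# "http://localhost:8088" is a placeholder, not a value taken from the PR.
response = default_explore_response("http://localhost:8088", 123)
print(response["url"])
```

Because no form data is cached on this path, `form_data_key` stays `None` and the URL carries only the datasource reference.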

@@ -34,6 +34,7 @@ from superset.mcp_service.privacy import (
filter_user_directory_columns,
SELF_REFERENCING_FILTER_COLUMNS,
USER_DIRECTORY_FIELDS,
USER_FILTER_FIELDS,
)
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils import _is_uuid
@@ -244,6 +245,31 @@ class ModelListCore(BaseCore, Generic[L]):
return [extra] + filters
return [extra, filters]
def _call_dao_list(
self,
filters: Any,
order_column: str,
order_direction: str,
page: int,
page_size: int,
search: str | None,
columns_to_load: List[str],
) -> tuple[List[Any], int]:
"""Call the DAO list method.
Subclasses may override to change the kwarg name used for filters.
"""
return self.dao_class.list(
column_operators=filters,
order_column=order_column,
order_direction=order_direction,
page=page,
page_size=page_size,
search=search,
search_columns=self.search_columns,
columns=columns_to_load,
)
def run_tool(
self,
filters: Any | None = None,
@@ -284,15 +310,14 @@ class ModelListCore(BaseCore, Generic[L]):
# Query the DAO
items: List[Any]
items, total_count = self.dao_class.list(
column_operators=filters,
items, total_count = self._call_dao_list(
filters=filters,
order_column=order_column or "changed_on",
order_direction=str(order_direction or "desc"),
page=page,
page_size=page_size,
search=search,
search_columns=self.search_columns,
columns=columns_to_load,
columns_to_load=columns_to_load,
)
# Serialize items
item_objs = []
@@ -314,14 +339,6 @@ class ModelListCore(BaseCore, Generic[L]):
has_previous=page > 0,
)
# Build response
def get_keys(obj: BaseModel | dict[str, Any] | Any) -> List[str]:
if hasattr(obj, "model_dump"):
return list(obj.model_dump().keys())
elif isinstance(obj, dict):
return list(obj.keys())
return []
response_kwargs = {
self.list_field_name: item_objs,
"count": len(item_objs),
@@ -595,7 +612,7 @@ class InstanceInfoCore(BaseCore):
return counts
def _calculate_time_based_metrics(
self, base_counts: Dict[str, int]
self, _base_counts: Dict[str, int]
) -> Dict[str, Dict[str, int]]:
"""Calculate time-based metrics for recent activity."""
now = datetime.now(timezone.utc)
@@ -774,7 +791,9 @@ class ModelGetSchemaCore(BaseCore, Generic[S]):
self.default_sort = default_sort
self.default_sort_direction = default_sort_direction
self.exclude_filter_columns = set(exclude_filter_columns or set())
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS)
# Hide user-directory columns from filter discovery, except the small
# set callers may legitimately filter by ID (resolved via find_users).
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS - USER_FILTER_FIELDS)
def _get_filter_columns(self) -> Dict[str, List[str]]:
"""Get filterable columns and operators from the DAO."""

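Review note: extracting `_call_dao_list` lets a subclass change how the DAO is invoked without copying `run_tool`. The sketch below shows the pattern with stand-in classes. `FakeDAO`, `ListCoreSketch`, and `ReportListCoreSketch` are hypothetical names, and the subclass behaviour (passing `filters=` instead of `column_operators=`) follows the commit message rather than code visible in this section.

```python
from typing import Any


class FakeDAO:
    """Stand-in DAO that records the keyword arguments it receives."""

    last_kwargs: dict = {}

    @classmethod
    def list(cls, **kwargs: Any) -> tuple:
        cls.last_kwargs = kwargs
        return [], 0


class ListCoreSketch:
    dao_class = FakeDAO

    def _call_dao_list(self, filters: Any, page: int) -> tuple:
        # Default hook: the generic DAO kwarg name.
        return self.dao_class.list(column_operators=filters, page=page)

    def run_tool(self, filters: Any = None, page: int = 0) -> tuple:
        # run_tool only talks to the hook, never to the DAO directly.
        return self._call_dao_list(filters=filters, page=page)


class ReportListCoreSketch(ListCoreSketch):
    def _call_dao_list(self, filters: Any, page: int) -> tuple:
        # Override: same call, different kwarg name for the filters.
        return self.dao_class.list(filters=filters, page=page)


ListCoreSketch().run_tool(filters=["f"], page=1)
base_kwargs = dict(FakeDAO.last_kwargs)
ReportListCoreSketch().run_tool(filters=["f"], page=1)
report_kwargs = dict(FakeDAO.last_kwargs)
print(sorted(base_kwargs), sorted(report_kwargs))
```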

@@ -44,12 +44,22 @@ USER_DIRECTORY_FIELDS = frozenset(
}
)
# User-directory columns that may be used as filter values (an integer user ID).
# These remain stripped from select_columns, sort, search, and tool responses
# (so the directory itself is never exposed), but list tools may filter rows by
# them when the caller already has an ID — typically resolved via find_users.
USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})
# Internal DAO filter column names generated server-side when translating the
# created_by_me / owned_by_me boolean flags (see mcp_core._prepend_self_lookup_filters).
# These columns are never exposed to LLM callers; they are excluded from the
# filters_applied response field to avoid leaking internal implementation details.
# "owners.id" is the report-schedule variant of the owner filter column.
# Note: ``created_by_fk`` is intentionally excluded — it is also a publicly
# advertised filter column (see USER_FILTER_FIELDS) so callers can filter by a
# user ID resolved via find_users.
SELF_REFERENCING_FILTER_COLUMNS = frozenset(
{"created_by_fk", "owner", "created_by_fk_or_owner"}
{"owner", "owners.id", "created_by_fk_or_owner"}
)
DATA_MODEL_METADATA_ACCESS_ATTR = "_requires_data_model_metadata_access"

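Review note: the interaction between the three frozensets can be sketched with plain set arithmetic. The full contents of `USER_DIRECTORY_FIELDS` are not shown in this diff, so the members below are placeholders, and the way `filters_applied` is pruned is an assumption based on the comments, not on code visible here.

```python
# ASSUMPTION: placeholder members; the real USER_DIRECTORY_FIELDS is not
# fully visible in this diff.
USER_DIRECTORY_FIELDS = frozenset(
    {"created_by", "changed_by", "owners", "created_by_fk", "changed_by_fk"}
)
USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})
SELF_REFERENCING_FILTER_COLUMNS = frozenset(
    {"owner", "owners.id", "created_by_fk_or_owner"}
)

# Filter discovery hides directory columns except the ID-based filter fields.
exclude_filter_columns = set(USER_DIRECTORY_FIELDS - USER_FILTER_FIELDS)

# ASSUMPTION: internally injected filters are dropped from the response by
# column name; the actual mechanism lives outside this section.
applied = [
    {"col": "owners.id", "opr": "eq", "value": 1},
    {"col": "created_by_fk", "opr": "eq", "value": 1},
]
filters_applied = [
    f for f in applied if f["col"] not in SELF_REFERENCING_FILTER_COLUMNS
]
print(sorted(exclude_filter_columns), filters_applied)
```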

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,290 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pydantic schemas for report (alerts & reports) related responses.
"""
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal
import humanize
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.common.cache_schemas import (
CreatedByMeMixin,
MetadataCacheControl,
OwnedByMeMixin,
)
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.privacy import filter_user_directory_fields
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_model_list,
)
class ReportFilter(ColumnOperator):
"""
Filter object for report listing.
col: The column to filter on. Must be one of the allowed filter fields.
opr: The operator to use. Must be one of the supported operators.
value: The value to filter by (type depends on col and opr).
"""
col: Literal[
"name",
"type",
"active",
"dashboard_id",
"chart_id",
] = Field(
...,
description="Column to filter on. Use get_schema(model_type='report') for "
"available filter columns.",
)
opr: ColumnOperatorEnum = Field(
...,
description="Operator to use. Use get_schema(model_type='report') for "
"available operators.",
)
value: str | int | float | bool | List[str | int | float | bool] = Field(
..., description="Value to filter by (type depends on col and opr)"
)
class ReportInfo(BaseModel):
id: int | None = Field(None, description="Report/Alert ID")
name: str | None = Field(None, description="Report/Alert name")
description: str | None = Field(None, description="Report/Alert description")
type: str | None = Field(None, description="Schedule type: 'Alert' or 'Report'")
active: bool | None = Field(None, description="Whether the schedule is active")
crontab: str | None = Field(None, description="Cron expression for scheduling")
dashboard_id: int | None = Field(
None, description="Associated dashboard ID, if any"
)
chart_id: int | None = Field(None, description="Associated chart ID, if any")
owners: List[Any] | None = Field(
None, description="List of owners (filtered by privacy controls)"
)
changed_on: str | datetime | None = Field(
None, description="Last modification timestamp"
)
changed_on_humanized: str | None = Field(
None, description="Humanized modification time"
)
created_on: str | datetime | None = Field(None, description="Creation timestamp")
created_on_humanized: str | None = Field(
None, description="Humanized creation time"
)
model_config = ConfigDict(
from_attributes=True,
ser_json_timedelta="iso8601",
populate_by_name=True,
)
@model_serializer(mode="wrap")
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
"""Filter fields based on serialization context.
If context contains 'select_columns', only include those fields.
Otherwise, include all fields (default behavior).
"""
data = filter_user_directory_fields(serializer(self))
if info.context and isinstance(info.context, dict):
select_columns = info.context.get("select_columns")
if select_columns:
requested_fields = set(select_columns)
return {k: v for k, v in data.items() if k in requested_fields}
return data
class ReportList(BaseModel):
reports: List[ReportInfo]
count: int
total_count: int
page: int
page_size: int
total_pages: int
has_previous: bool
has_next: bool
columns_requested: List[str] = Field(
default_factory=list,
description="Requested columns for the response",
)
columns_loaded: List[str] = Field(
default_factory=list,
description="Columns that were actually loaded for each report",
)
columns_available: List[str] = Field(
default_factory=list,
description="All columns available for selection via select_columns parameter",
)
sortable_columns: List[str] = Field(
default_factory=list,
description="Columns that can be used with order_column parameter",
)
filters_applied: List[ColumnOperator] = Field(
default_factory=list,
description="List of advanced filter dicts applied to the query.",
)
pagination: PaginationInfo | None = None
timestamp: datetime | None = None
model_config = ConfigDict(ser_json_timedelta="iso8601")
class ListReportsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl):
"""Request schema for list_reports."""
filters: Annotated[
List[ReportFilter],
Field(
default_factory=list,
description="List of filter objects (column, operator, value). Each "
"filter is an object with 'col', 'opr', and 'value' "
"properties. Cannot be used together with 'search'.",
),
]
select_columns: Annotated[
List[str],
Field(
default_factory=list,
description="List of columns to select. Defaults to common columns if not "
"specified.",
),
]
search: Annotated[
str | None,
Field(
default=None,
description="Text search string to match against report fields. Cannot "
"be used together with 'filters'.",
),
]
order_column: Annotated[
str | None, Field(default=None, description="Column to order results by")
]
order_direction: Annotated[
Literal["asc", "desc"],
Field(
default="desc", description="Direction to order results ('asc' or 'desc')"
),
]
page: Annotated[
PositiveInt,
Field(default=1, description="Page number for pagination (1-based)"),
]
page_size: Annotated[
int,
Field(
default=DEFAULT_PAGE_SIZE,
gt=0,
le=MAX_PAGE_SIZE,
description=f"Number of items per page (max {MAX_PAGE_SIZE})",
),
]
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v: Any) -> List[ReportFilter]:
"""Accept both JSON string and list of objects."""
return parse_json_or_model_list(v, ReportFilter, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_columns(cls, v: Any) -> List[str]:
"""Accept JSON array, list, or comma-separated string."""
return parse_json_or_list(v, "select_columns")
@model_validator(mode="after")
def validate_search_and_filters(self) -> "ListReportsRequest":
"""Prevent using both search and filters simultaneously."""
if self.search and self.filters:
raise ValueError(
"Cannot use both 'search' and 'filters' parameters simultaneously. "
"Use either 'search' for text-based searching across multiple fields, "
"or 'filters' for precise column-based filtering, but not both."
)
return self
class ReportError(BaseModel):
error: str = Field(..., description="Error message")
error_type: str = Field(..., description="Type of error")
timestamp: str | datetime | None = Field(None, description="Error timestamp")
model_config = ConfigDict(ser_json_timedelta="iso8601")
@classmethod
def create(cls, error: str, error_type: str) -> "ReportError":
"""Create a standardized ReportError with timestamp."""
from datetime import datetime, timezone
return cls(
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
)
class GetReportInfoRequest(MetadataCacheControl):
"""Request schema for get_report_info — identifier is a numeric ID only."""
identifier: Annotated[
int,
Field(description="Report/Alert numeric ID"),
]
def _humanize_timestamp(dt: datetime | None) -> str | None:
"""Convert a datetime to a humanized string like '2 hours ago'."""
if dt is None:
return None
now = datetime.now(dt.tzinfo) if dt.tzinfo else datetime.now()
return humanize.naturaltime(now - dt)
def serialize_report_object(report: Any) -> ReportInfo | None:
if not report:
return None
return ReportInfo(
id=getattr(report, "id", None),
name=getattr(report, "name", None),
description=getattr(report, "description", None),
type=getattr(report, "type", None),
active=getattr(report, "active", None),
crontab=getattr(report, "crontab", None),
dashboard_id=getattr(report, "dashboard_id", None),
chart_id=getattr(report, "chart_id", None),
owners=getattr(report, "owners", None),
changed_on=getattr(report, "changed_on", None),
changed_on_humanized=_humanize_timestamp(getattr(report, "changed_on", None)),
created_on=getattr(report, "created_on", None),
created_on_humanized=_humanize_timestamp(getattr(report, "created_on", None)),
)
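
Review note on `_humanize_timestamp`: the helper picks a `now` whose timezone awareness matches `dt`, because Python refuses to subtract a naive datetime from an aware one. The sketch below isolates that rule using only the standard library. `elapsed_since` is a hypothetical name, and the `humanize.naturaltime` call is left out so the snippet has no third-party dependency.

```python
from datetime import datetime, timedelta, timezone


def elapsed_since(dt: datetime) -> timedelta:
    """Return now - dt, choosing a 'now' with the same awareness as dt."""
    # Aware input: take "now" in the same tzinfo. Naive input: naive local "now".
    now = datetime.now(dt.tzinfo) if dt.tzinfo else datetime.now()
    return now - dt


aware = datetime.now(timezone.utc) - timedelta(hours=2)
naive = datetime.now() - timedelta(hours=2)

# Both subtractions succeed because awareness matches on each side.
aware_delta = elapsed_since(aware)
naive_delta = elapsed_since(naive)

# Mixing a naive "now" with an aware value is the case the branch avoids.
try:
    datetime.now() - aware
    mixed_ok = True
except TypeError:
    mixed_ok = False
```

If report timestamps are always stored naive, only the second branch is exercised; the aware branch keeps the helper safe should that assumption change.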

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .get_report_info import get_report_info
from .list_reports import list_reports
__all__ = [
"list_reports",
"get_report_info",
]

@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get report info FastMCP tool.
"""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.report.schemas import (
GetReportInfoRequest,
ReportError,
ReportInfo,
serialize_report_object,
)
logger = logging.getLogger(__name__)
@tool(
tags=["discovery"],
class_permission_name="ReportSchedule",
annotations=ToolAnnotations(
title="Get report info",
readOnlyHint=True,
destructiveHint=False,
),
)
async def get_report_info(
request: GetReportInfoRequest, ctx: Context
) -> ReportInfo | ReportError:
"""Get alert or report schedule metadata by numeric ID.
Returns schedule configuration including type (Alert/Report), active
status, cron expression, and associated dashboard or chart.
IMPORTANT FOR LLM CLIENTS:
- Use numeric ID (e.g., 123)
- To find a report ID, use the list_reports tool first
Example usage:
```json
{
"identifier": 1
}
```
"""
await ctx.info(
"Retrieving report information: identifier=%s" % (request.identifier,)
)
try:
from superset.daos.report import ReportScheduleDAO
with event_logger.log_context(action="mcp.get_report_info.lookup"):
get_tool = ModelGetInfoCore(
dao_class=ReportScheduleDAO,
output_schema=ReportInfo,
error_schema=ReportError,
serializer=serialize_report_object,
supports_slug=False,
logger=logger,
)
result = get_tool.run_tool(request.identifier)
if isinstance(result, ReportInfo):
await ctx.info(
"Report information retrieved successfully: "
"report_id=%s, name=%s, type=%s"
% (
result.id,
result.name,
result.type,
)
)
else:
await ctx.warning(
"Report retrieval failed: error_type=%s, error=%s"
% (result.error_type, result.error)
)
return result
except Exception as e:
await ctx.error(
"Report information retrieval failed: identifier=%s, error=%s, "
"error_type=%s"
% (
request.identifier,
str(e),
type(e).__name__,
)
)
return ReportError(
error=f"Failed to get report info: {str(e)}",
error_type="InternalError",
timestamp=datetime.now(timezone.utc),
)

@@ -0,0 +1,243 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
List reports (alerts & reports) FastMCP tool.
"""
import logging
from typing import Any, List, TYPE_CHECKING
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
if TYPE_CHECKING:
from superset.reports.models import ReportSchedule
from superset.daos.base import ColumnOperator
from superset.extensions import event_logger
from superset.mcp_service.common.schema_discovery import (
get_all_column_names,
get_report_columns,
REPORT_DEFAULT_COLUMNS,
REPORT_SEARCH_COLUMNS,
REPORT_SORTABLE_COLUMNS,
)
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.report.schemas import (
ListReportsRequest,
ReportError,
ReportFilter,
ReportInfo,
ReportList,
serialize_report_object,
)
logger = logging.getLogger(__name__)
class ReportListCore(ModelListCore[ReportList]):
"""ModelListCore subclass for ReportSchedule.
Overrides two behaviours that differ from the generic list tool:
1. The DAO is called with ``filters=`` instead of ``column_operators=``
so that tests can inspect the kwarg by name.
2. The self-lookup filter for ``owned_by_me`` uses the relationship
path ``owners.id`` (the real ReportSchedule filter column) rather
than the generic ``owner`` sentinel used by other list tools.
"""
# Column name used by ReportScheduleDAO for the owners M2M filter.
_OWNED_BY_ME_COLUMN = "owners.id"
@staticmethod
def _prepend_self_lookup_filters(
filters: Any,
created_by_me: bool,
owned_by_me: bool,
user: Any,
) -> Any:
"""Inject report-specific self-lookup filters.
Uses ``owners.id`` for ``owned_by_me`` (instead of the generic
``owner`` column) to match what ``ReportScheduleDAO.list`` expects.
"""
if not (created_by_me or owned_by_me):
return filters
if not user or not getattr(user, "is_authenticated", False):
raise ValueError("This operation requires an authenticated user")
user_id: int = user.id
# Build the injected filters in a fixed order (owners first, then creator)
# so that each flag contributes its own ColumnOperator.
extra: List[ColumnOperator] = []
if owned_by_me:
extra.append(
ColumnOperator(
col=ReportListCore._OWNED_BY_ME_COLUMN, opr="eq", value=user_id
)
)
if created_by_me:
extra.append(ColumnOperator(col="created_by_fk", opr="eq", value=user_id))
if filters is None:
return extra
if isinstance(filters, list):
return extra + filters
return extra + [filters]
def _call_dao_list(
self,
filters: Any,
order_column: str,
order_direction: str,
page: int,
page_size: int,
search: str | None,
columns_to_load: List[str],
) -> tuple[List[Any], int]:
"""Call the DAO with ``filters=`` kwarg (report-specific convention)."""
return self.dao_class.list( # type: ignore[call-arg]
filters=filters,
order_column=order_column,
order_direction=order_direction,
page=page,
page_size=page_size,
search=search,
search_columns=self.search_columns,
columns=columns_to_load,
)
_DEFAULT_LIST_REPORTS_REQUEST = ListReportsRequest()
@tool(
tags=["core"],
class_permission_name="ReportSchedule",
annotations=ToolAnnotations(
title="List reports",
readOnlyHint=True,
destructiveHint=False,
),
)
async def list_reports(
request: ListReportsRequest | None = None,
ctx: Context | None = None,
) -> ReportList | ReportError:
"""List alerts and reports with filtering and search.
Returns schedule metadata including name, type (Alert/Report), active
status, and cron expression.
Sortable columns for order_column: id, name, type, active, changed_on,
created_on
"""
if ctx is None:
raise RuntimeError("FastMCP context is required for list_reports")
request = request or _DEFAULT_LIST_REPORTS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing reports: page=%s, page_size=%s, search=%s"
% (
request.page,
request.page_size,
request.search,
)
)
await ctx.debug(
"Report listing parameters: filters=%s, order_column=%s, "
"order_direction=%s, select_columns=%s"
% (
request.filters,
request.order_column,
request.order_direction,
request.select_columns,
)
)
try:
from superset.daos.report import ReportScheduleDAO
def _serialize_report(
obj: "ReportSchedule | None", cols: list[str] | None
) -> ReportInfo | None:
return serialize_report_object(obj)
list_tool = ReportListCore(
dao_class=ReportScheduleDAO,
output_schema=ReportInfo,
item_serializer=_serialize_report,
filter_type=ReportFilter,
default_columns=REPORT_DEFAULT_COLUMNS,
search_columns=REPORT_SEARCH_COLUMNS,
list_field_name="reports",
output_list_schema=ReportList,
all_columns=get_all_column_names(get_report_columns()),
sortable_columns=REPORT_SORTABLE_COLUMNS,
logger=logger,
)
with event_logger.log_context(action="mcp.list_reports.query"):
result = list_tool.run_tool(
filters=request.filters,
search=request.search,
select_columns=request.select_columns,
order_column=request.order_column,
order_direction=request.order_direction,
page=max(request.page - 1, 0),
page_size=request.page_size,
created_by_me=request.created_by_me,
owned_by_me=request.owned_by_me,
)
await ctx.info(
"Reports listed successfully: count=%s, total_count=%s, total_pages=%s"
% (
len(result.reports) if hasattr(result, "reports") else 0,
getattr(result, "total_count", None),
getattr(result, "total_pages", None),
)
)
columns_to_filter = result.columns_requested
with event_logger.log_context(action="mcp.list_reports.serialization"):
return result.model_dump(
mode="json",
context={"select_columns": columns_to_filter},
)
except Exception as e:
await ctx.error(
"Report listing failed: page=%s, page_size=%s, error=%s, error_type=%s"
% (
request.page,
request.page_size,
str(e),
type(e).__name__,
)
)
raise
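
Review note on `ReportListCore._prepend_self_lookup_filters`: the injected `owners.id` and `created_by_fk` filters are placed ahead of whatever the caller supplied, and the caller's filters pass through untouched when neither flag is set. The following is a minimal sketch of that ordering, not the code in this PR. The `ColumnOperator` dataclass is a hypothetical stand-in for `superset.daos.base.ColumnOperator`, the function takes a plain `user_id` instead of a user object, and the authentication check is omitted.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass(frozen=True)
class ColumnOperator:
    """Hypothetical stand-in for superset.daos.base.ColumnOperator."""

    col: str
    opr: str
    value: Any


def prepend_self_lookup_filters(
    filters: Optional[List[ColumnOperator]],
    created_by_me: bool,
    owned_by_me: bool,
    user_id: int,
) -> Optional[List[ColumnOperator]]:
    """Put the injected owner/creator filters ahead of the caller's filters."""
    extra: List[ColumnOperator] = []
    if owned_by_me:
        extra.append(ColumnOperator("owners.id", "eq", user_id))
    if created_by_me:
        extra.append(ColumnOperator("created_by_fk", "eq", user_id))
    if not extra:
        return filters  # neither flag set: caller filters pass through as-is
    return extra + list(filters or [])


caller = [ColumnOperator("active", "eq", True)]
both = prepend_self_lookup_filters(caller, True, True, user_id=7)
only_owned = prepend_self_lookup_filters(None, False, True, user_id=7)
untouched = prepend_self_lookup_filters(caller, False, False, user_id=7)
```

Here `both` holds three filters in the order `owners.id`, `created_by_fk`, `active`, which is the order the DAO would receive them in.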

@@ -25,9 +25,11 @@ system-level info.
from __future__ import annotations
from datetime import datetime
from typing import Any
from typing import Annotated, Any, List
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
class HealthCheckResponse(BaseModel):
@@ -170,6 +172,84 @@ def serialize_user_object(user: Any) -> UserInfo | None:
)
class FindUsersRequest(BaseModel):
"""Request schema for find_users tool.
Resolves a person's name (or partial name, username, or email) to user IDs
so they can be passed to listing tools as filter values for created_by_fk
or changed_by_fk. This is the only sanctioned path for "show me what
<person> is working on" queries.
"""
model_config = ConfigDict(extra="forbid")
query: Annotated[
str,
Field(
min_length=1,
max_length=200,
description=(
"Substring to match (case-insensitive) against username, "
"first_name, last_name, and email. Required and non-empty: "
"this tool does not enumerate the full user directory."
),
),
]
page_size: Annotated[
int,
Field(
default=DEFAULT_PAGE_SIZE,
gt=0,
le=MAX_PAGE_SIZE,
description=f"Maximum number of matches to return (max {MAX_PAGE_SIZE}).",
),
]
@field_validator("query")
@classmethod
def _reject_blank_query(cls, value: str) -> str:
# min_length=1 alone admits whitespace-only strings, which strip to "" and
# produce a "%%" LIKE pattern that matches every user. Strip and require
# at least one non-space character.
stripped = value.strip()
if not stripped:
raise ValueError("query must contain at least one non-whitespace character")
return stripped
class UserMatch(BaseModel):
"""Minimal user projection returned by find_users.
Intentionally narrower than UserInfo: only the fields needed to disambiguate
matches and pass an id to created_by_fk / changed_by_fk filters. Email,
active flag, and roles are deliberately excluded to limit identity
exposure through this directory-resolution path.
"""
id: int | None = None
username: str | None = None
first_name: str | None = None
last_name: str | None = None
class FindUsersResponse(BaseModel):
"""Response schema for find_users tool."""
users: List[UserMatch] = Field(
default_factory=list,
description=(
"Matching users. Pass user.id as the value for created_by_fk or "
"changed_by_fk filters on list_dashboards, list_charts, and "
"list_datasets."
),
)
count: int = Field(..., description="Number of users returned in this response.")
truncated: bool = Field(
default=False,
description="True when the query matched more rows than page_size allows.",
)
class TagInfo(BaseModel):
id: int | None = None
name: str | None = None
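
Review note on `FindUsersRequest._reject_blank_query`: the comment in the validator explains that a whitespace-only query strips to an empty string and would yield a `%%` LIKE pattern. A small stdlib-only sketch of that reasoning follows; `normalize_query` and `like_pattern` are hypothetical helper names, not part of the PR.

```python
def normalize_query(value: str) -> str:
    """Strip the query and reject it if nothing but whitespace was given."""
    stripped = value.strip()
    if not stripped:
        raise ValueError("query must contain at least one non-whitespace character")
    return stripped


def like_pattern(query: str) -> str:
    """Build the substring pattern that would be handed to a LIKE/ILIKE clause."""
    return f"%{query}%"


# Without the validator, a blank query degenerates into a match-everything pattern.
unguarded = like_pattern("   ".strip())

# With the validator, surrounding whitespace is trimmed and blanks are refused.
guarded = like_pattern(normalize_query("  ada "))

try:
    normalize_query("   ")
    blank_rejected = False
except ValueError:
    blank_rejected = True
```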

@@ -17,12 +17,14 @@
"""System tools for MCP service."""
from .find_users import find_users
from .generate_bug_report import generate_bug_report
from .get_instance_info import get_instance_info
from .get_schema import get_schema
from .health_check import health_check
__all__ = [
"find_users",
"generate_bug_report",
"health_check",
"get_instance_info",

@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""find_users MCP tool: resolve a person's name to user IDs for filtering."""
import logging
from fastmcp import Context
from sqlalchemy import or_
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import db, event_logger, security_manager
from superset.mcp_service.system.schemas import (
FindUsersRequest,
FindUsersResponse,
UserMatch,
)
logger = logging.getLogger(__name__)
@tool(
tags=["core"],
annotations=ToolAnnotations(
title="Find users",
readOnlyHint=True,
destructiveHint=False,
),
)
async def find_users(request: FindUsersRequest, ctx: Context) -> FindUsersResponse:
"""Resolve a person's name to user IDs so they can be used as filter values.
Use this when the caller asks "show me <person>'s dashboards/charts/datasets"
or "what is <person> working on". Take the matching user.id and pass it as
the value for a created_by_fk or changed_by_fk filter on list_dashboards,
list_charts, or list_datasets.
Matches case-insensitively against username, first_name, last_name, and
email. The query is required and non-empty; this tool does not enumerate
the full user directory.
Privacy: returning a user's identity here is sanctioned only for resolving
filter values. Do not use the response to answer "who owns X", "who can
access X", or any access-list question — those remain off-limits per the
server instructions.
"""
await ctx.info(
"Resolving user query: query=%s, page_size=%s"
% (request.query, request.page_size)
)
user_model = security_manager.user_model
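# Escape LIKE wildcards so a query such as "%" or "_" is matched literally
# instead of matching every user. The request validator already strips the query.
escaped = (
request.query.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
)
needle = f"%{escaped}%"
with event_logger.log_context(action="mcp.find_users.query"):
query = (
db.session.query(user_model)
.filter(
or_(
user_model.username.ilike(needle, escape="\\"),
user_model.first_name.ilike(needle, escape="\\"),
user_model.last_name.ilike(needle, escape="\\"),
user_model.email.ilike(needle, escape="\\"),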
)
)
.order_by(user_model.username.asc())
)
# Fetch one extra row to detect truncation without a separate count query.
rows = query.limit(request.page_size + 1).all()
truncated = len(rows) > request.page_size
rows = rows[: request.page_size]
users: list[UserMatch] = [
UserMatch(
id=getattr(row, "id", None),
username=getattr(row, "username", None),
first_name=getattr(row, "first_name", None),
last_name=getattr(row, "last_name", None),
)
for row in rows
]
await ctx.info(
"Resolved user query: matches=%s, truncated=%s" % (len(users), truncated)
)
return FindUsersResponse(users=users, count=len(users), truncated=truncated)
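
Review note on the truncation check in `find_users`: fetching `page_size + 1` rows lets the tool report `truncated` without a second count query. The sketch below shows the same idea on a plain sequence. `take_page` is a hypothetical helper, and the slice stands in for `query.limit(page_size + 1).all()`.

```python
from typing import List, Sequence, Tuple


def take_page(rows: Sequence[str], page_size: int) -> Tuple[List[str], bool]:
    """Return at most page_size rows plus a flag saying whether more existed."""
    # Ask for one row beyond the page; its presence proves there is more data.
    fetched = list(rows[: page_size + 1])
    truncated = len(fetched) > page_size
    return fetched[:page_size], truncated


usernames = ["ada", "alan", "grace", "linus"]

page, more = take_page(usernames, page_size=3)
exact_page, exact_more = take_page(usernames, page_size=4)
```

The flag only says that at least one more row exists, not how many, which is all the response schema promises.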

@@ -47,9 +47,13 @@ from superset.mcp_service.common.schema_discovery import (
get_dashboard_columns,
get_database_columns,
get_dataset_columns,
get_report_columns,
GetSchemaRequest,
GetSchemaResponse,
ModelSchemaInfo,
REPORT_DEFAULT_COLUMNS,
REPORT_SEARCH_COLUMNS,
REPORT_SORTABLE_COLUMNS,
)
from superset.mcp_service.constants import ModelType
from superset.mcp_service.mcp_core import ModelGetSchemaCore
@@ -144,6 +148,26 @@ def _get_database_schema_core() -> ModelGetSchemaCore[ModelSchemaInfo]:
)
def _get_report_schema_core() -> ModelGetSchemaCore[ModelSchemaInfo]:
"""Create report schema core with dynamically extracted columns."""
# Lazy import to avoid circular dependency at module load time
from superset.daos.report import ReportScheduleDAO
return ModelGetSchemaCore(
model_type="report",
dao_class=ReportScheduleDAO,
output_schema=ModelSchemaInfo,
select_columns=get_report_columns(),
sortable_columns=REPORT_SORTABLE_COLUMNS,
default_columns=REPORT_DEFAULT_COLUMNS,
search_columns=REPORT_SEARCH_COLUMNS,
default_sort="changed_on",
default_sort_direction="desc",
exclude_filter_columns=set(SELF_REFERENCING_FILTER_COLUMNS),
logger=logger,
)
# Map model types to their core factory functions
_SCHEMA_CORE_FACTORIES: dict[
ModelType,
@@ -153,6 +177,7 @@ _SCHEMA_CORE_FACTORIES: dict[
"dataset": _get_dataset_schema_core,
"dashboard": _get_dashboard_schema_core,
"database": _get_database_schema_core,
"report": _get_report_schema_core,
}
@@ -182,7 +207,7 @@ async def get_schema(
Column metadata is extracted dynamically from SQLAlchemy models.
Args:
model_type: One of "chart", "dataset", "dashboard", or "database"
model_type: One of "chart", "dataset", "dashboard", "database", or "report"
Returns:
Comprehensive schema information for the requested model type
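
Review note on `_SCHEMA_CORE_FACTORIES`: registering `"report"` works because the mapping stores factory functions, so each core (and its lazily imported DAO) is only built when that model type is requested. A minimal sketch of the pattern, with plain dicts standing in for `ModelGetSchemaCore` and all names below being hypothetical:

```python
from typing import Callable, Dict


def _make_chart_core() -> Dict[str, str]:
    # Stand-in for a chart schema core; built only when requested.
    return {"model_type": "chart"}


def _make_report_core() -> Dict[str, str]:
    # Stand-in for the report schema core added in this change.
    return {"model_type": "report", "default_sort": "changed_on"}


# Map model types to factories rather than to pre-built instances.
SCHEMA_CORE_FACTORIES: Dict[str, Callable[[], Dict[str, str]]] = {
    "chart": _make_chart_core,
    "report": _make_report_core,
}


def get_schema_core(model_type: str) -> Dict[str, str]:
    """Look up the factory for model_type and build the core on demand."""
    try:
        factory = SCHEMA_CORE_FACTORIES[model_type]
    except KeyError:
        raise ValueError(f"Unsupported model_type: {model_type!r}") from None
    return factory()


report_core = get_schema_core("report")
```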

@@ -24,9 +24,10 @@ single dataframe.
"""
from datetime import datetime, timedelta
from time import time
from datetime import date, datetime, time, timedelta, tzinfo
from time import time as current_time
from typing import Any, cast, Sequence, TypeGuard
from zoneinfo import ZoneInfo
import isodate
import numpy as np
@@ -103,7 +104,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
raise ValueError("QueryObject must have a datasource defined.")
# Track execution time
start_time = time()
start_time = current_time()
semantic_view = query_object.datasource.implementation
dispatcher = (
@@ -127,7 +128,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
# If no time offsets, return the main result as-is
if not query_object.time_offsets or len(queries) <= 1:
duration = timedelta(seconds=time() - start_time)
duration = timedelta(seconds=current_time() - start_time)
return map_semantic_result_to_query_result(
main_result,
query_object,
@@ -197,7 +198,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
requests=all_requests,
results=pa.Table.from_pandas(main_df),
)
duration = timedelta(seconds=time() - start_time)
duration = timedelta(seconds=current_time() - start_time)
return map_semantic_result_to_query_result(
semantic_result,
query_object,
@@ -541,21 +542,29 @@ def _convert_query_object_filter(
if operator_str == FilterOperator.TEMPORAL_RANGE.value:
if not isinstance(value, str) or value == NO_TIME_RANGE:
return None
start, end = value.split(" : ")
return {
Filter(
type=PredicateType.WHERE,
column=dimension,
operator=Operator.GREATER_THAN_OR_EQUAL,
value=start,
),
Filter(
type=PredicateType.WHERE,
column=dimension,
operator=Operator.LESS_THAN,
value=end,
),
}
start, end = (side.strip() for side in value.split(" : "))
filters: set[Filter] = set()
if start:
filters.add(
Filter(
type=PredicateType.WHERE,
column=dimension,
operator=Operator.GREATER_THAN_OR_EQUAL,
value=_coerce_scalar_filter_value(start, dimension),
)
)
if end:
filters.add(
Filter(
type=PredicateType.WHERE,
column=dimension,
operator=Operator.LESS_THAN,
value=_coerce_scalar_filter_value(end, dimension),
)
)
return filters or None
value = _coerce_filter_value(value, dimension)
# Map QueryObject operators to semantic layer operators
operator_mapping = {
@@ -588,6 +597,149 @@ def _convert_query_object_filter(
}
def _coerce_filter_value(
value: FilterValues | frozenset[FilterValues],
dimension: Dimension,
) -> FilterValues | frozenset[FilterValues]:
if isinstance(value, frozenset):
return frozenset(_coerce_scalar_filter_value(v, dimension) for v in value)
return _coerce_scalar_filter_value(value, dimension)
def _timestamp_target_tz(dtype: pa.DataType) -> tzinfo | None:
tz_name = getattr(dtype, "tz", None)
return ZoneInfo(tz_name) if tz_name else None
def _align_tz(dt: datetime, target_tz: tzinfo | None) -> datetime:
if target_tz is None:
return dt
if dt.tzinfo is None:
return dt.replace(tzinfo=target_tz)
return dt.astimezone(target_tz)
def _coerce_scalar_filter_value( # noqa: C901 — type dispatch, complexity is inherent
value: FilterValues, dimension: Dimension
) -> FilterValues:
if value is None:
return None
dtype = dimension.type
if pa.types.is_boolean(dtype):
if isinstance(value, bool):
return value
if isinstance(value, (int, float)) and value in (0, 1):
return bool(value)
if isinstance(value, str):
parsed = value.strip().lower()
if parsed in {"true", "t", "1", "yes", "y", "on"}:
return True
if parsed in {"false", "f", "0", "no", "n", "off"}:
return False
raise ValueError(
f"Invalid boolean value {value!r} for filter column {dimension.name}"
)
if pa.types.is_integer(dtype):
if isinstance(value, bool):
raise ValueError(
f"Invalid integer value {value!r} for filter column {dimension.name}"
)
if isinstance(value, int):
return value
if isinstance(value, float) and value.is_integer():
return int(value)
if isinstance(value, str):
try:
return int(value.strip())
except ValueError as ex:
raise ValueError(
f"Invalid integer value {value!r} for filter column "
f"{dimension.name}"
) from ex
raise ValueError(
f"Invalid integer value {value!r} for filter column {dimension.name}"
)
if pa.types.is_floating(dtype) or pa.types.is_decimal(dtype):
# Decimal dimensions are coerced through ``float`` because ``FilterValues``
# does not include ``Decimal``. That is lossless for the common case
# (≤ ~15 significant digits) and matches how downstream semantic-view
# implementations consume numeric filters; high-precision decimals would
# need a wider ``FilterValues`` union and propagation through the cache's
# comparability checks.
if isinstance(value, bool):
raise ValueError(
f"Invalid numeric value {value!r} for filter column {dimension.name}"
)
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
try:
return float(value.strip())
except ValueError as ex:
raise ValueError(
f"Invalid numeric value {value!r} for filter column "
f"{dimension.name}"
) from ex
raise ValueError(
f"Invalid numeric value {value!r} for filter column {dimension.name}"
)
if pa.types.is_date(dtype):
if isinstance(value, datetime):
return value.date()
if isinstance(value, date):
return value
if isinstance(value, str):
try:
return datetime.fromisoformat(value.strip()).date()
except ValueError as ex:
raise ValueError(
f"Invalid date value {value!r} for filter column {dimension.name}"
) from ex
raise ValueError(
f"Invalid date value {value!r} for filter column {dimension.name}"
)
if pa.types.is_timestamp(dtype):
target_tz = _timestamp_target_tz(dtype)
if isinstance(value, datetime):
return _align_tz(value, target_tz)
if isinstance(value, date):
return _align_tz(datetime.combine(value, time.min), target_tz)
if isinstance(value, str):
normalized = value.strip().replace("Z", "+00:00")
try:
return _align_tz(datetime.fromisoformat(normalized), target_tz)
except ValueError as ex:
raise ValueError(
f"Invalid timestamp value {value!r} for filter column "
f"{dimension.name}"
) from ex
raise ValueError(
f"Invalid timestamp value {value!r} for filter column {dimension.name}"
)
if pa.types.is_time(dtype):
if isinstance(value, time):
return value
if isinstance(value, str):
try:
return time.fromisoformat(value.strip())
except ValueError as ex:
raise ValueError(
f"Invalid time value {value!r} for filter column {dimension.name}"
) from ex
raise ValueError(
f"Invalid time value {value!r} for filter column {dimension.name}"
)
return value
def _get_order_from_query_object(
query_object: ValidatedQueryObject,
all_metrics: dict[str, Metric],
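
Review note on the temporal-range and timezone changes: the new code emits a lower bound only when the left side of `"start : end"` is non-empty, an upper bound only when the right side is non-empty, and aligns datetimes to the column's timezone. The sketch below restates both rules with the standard library only. `split_time_range` and `align_tz` are simplified, hypothetical versions: bounds are returned as `(operator, value)` tuples instead of `Filter` objects, and fixed-offset `timezone` objects replace `ZoneInfo` so the snippet does not need a tz database.

```python
from datetime import datetime, timedelta, timezone, tzinfo
from typing import List, Optional, Tuple


def split_time_range(value: str) -> List[Tuple[str, str]]:
    """Turn 'start : end' into bounds, skipping whichever side is empty."""
    start, end = (side.strip() for side in value.split(" : "))
    bounds: List[Tuple[str, str]] = []
    if start:
        bounds.append((">=", start))
    if end:
        bounds.append(("<", end))
    return bounds


def align_tz(dt: datetime, target_tz: Optional[tzinfo]) -> datetime:
    """Match dt to the column timezone: attach it if naive, convert if aware."""
    if target_tz is None:
        return dt  # naive column: leave the value alone
    if dt.tzinfo is None:
        return dt.replace(tzinfo=target_tz)  # assume the value is in column time
    return dt.astimezone(target_tz)


closed = split_time_range("2024-01-01 : 2024-02-01")
open_start = split_time_range(" : 2024-02-01")

plus_two = timezone(timedelta(hours=2))
naive_noon = datetime(2024, 1, 1, 12)
attached = align_tz(naive_noon, timezone.utc)
converted = align_tz(datetime(2024, 1, 1, 12, tzinfo=plus_two), timezone.utc)
```

Note that a naive value is assumed to already be in the column's timezone; the wall-clock time is kept and only the `tzinfo` is attached.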

@@ -324,7 +324,10 @@ class WebDriverPlaywright(WebDriverProxy):
'document.querySelectorAll(".chart-container").length'
)
dashboard_height = page.evaluate(
f'document.querySelector(".{element_name}").scrollHeight || 0'
f"""() => {{
const target = document.querySelector(\".{element_name}\");
return target ? target.scrollHeight : 0;
}}"""
)
chart_threshold = app.config.get(
"SCREENSHOT_TILED_CHART_THRESHOLD", 20
@@ -336,6 +339,14 @@ class WebDriverPlaywright(WebDriverProxy):
"SCREENSHOT_TILED_VIEWPORT_HEIGHT", viewport_height
)
if dashboard_height == 0:
logger.warning(
"Could not determine dashboard height for element %s "
"at url %s; falling back to standard screenshot behavior",
element_name,
url,
)
# Use tiled screenshots for large dashboards
use_tiled = (
chart_count >= chart_threshold

@@ -29,6 +29,7 @@ from flask_appbuilder.api import (
rison as parse_rison,
safe,
)
from flask_appbuilder.const import API_FILTERS_RIS_KEY
from flask_appbuilder.models.filters import BaseFilter, Filters
from flask_appbuilder.models.sqla.filters import FilterStartsWith
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -377,6 +378,35 @@ class BaseSupersetModelRestApi(BaseSupersetApiMixin, ModelRestApi):
self.add_columns = [model_id]
super()._init_properties()
def _handle_filters_args(self, rison_args: dict[str, Any]) -> Filters:
"""
Build a request-scoped ``Filters`` instance from Rison-encoded args.
Overrides :meth:`flask_appbuilder.api.ModelRestApi._handle_filters_args`,
which mutates ``self._filters`` (a single instance shared across
requests on the same API view). Under concurrent traffic that shared
state can leak filters from one request into another — e.g. two
parallel ``GET /api/v1/<resource>/`` calls filtering by different
values can return mixed results.
Returning a fresh ``Filters`` per call keeps each request isolated.
Applies to every subclass of ``BaseSupersetModelRestApi``
(datasets, charts, dashboards, saved queries, queries, databases,
etc.) — see issue #33828 for the original report on the dataset
endpoint.
:param rison_args: Arguments parsed from the API request's
Rison-encoded ``q`` parameter.
:returns: A request-scoped ``Filters`` instance joined with the
API's base filters.
"""
filters = self.datamodel.get_filters(
search_columns=self.search_columns,
search_filters=self.search_filters,
)
filters.rest_add_filters(rison_args.get(API_FILTERS_RIS_KEY, []))
return filters.get_joined_filters(self._base_filters)
def _get_related_filter(
self, datamodel: SQLAInterface, column_name: str, value: str
) -> Filters:
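
Review note on `_handle_filters_args`: the docstring describes filters leaking between requests when one `Filters` instance is shared by the API view. The sketch below is a deliberately simplified model of that hazard, not Flask-AppBuilder's actual code. Both classes are hypothetical, plain lists stand in for `Filters`, and two sequential calls emulate the interleaving that concurrent requests would produce.

```python
from typing import List


class SharedFiltersApi:
    """Hypothetical view that reuses one filter list for every request."""

    def __init__(self) -> None:
        self._filters: List[str] = []

    def handle_filters_args(self, args: List[str]) -> List[str]:
        # Mutates state shared by all callers of this instance.
        self._filters.clear()
        self._filters.extend(args)
        return self._filters


class ScopedFiltersApi:
    """Hypothetical view that builds a new filter list on every call."""

    def handle_filters_args(self, args: List[str]) -> List[str]:
        filters: List[str] = []  # fresh object per request
        filters.extend(args)
        return filters


shared = SharedFiltersApi()
first_shared = shared.handle_filters_args(["table_name == a"])
second_shared = shared.handle_filters_args(["table_name == b"])

scoped = ScopedFiltersApi()
first_scoped = scoped.handle_filters_args(["table_name == a"])
second_scoped = scoped.handle_filters_args(["table_name == b"])
```

With the shared variant, the first caller ends up holding the second caller's filter; with the scoped variant each caller keeps its own.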

@@ -120,3 +120,46 @@ def test_get_dataset_include_rendered_sql_passes_table_to_template_processor(
assert response.status_code == 200
mock_get_processor.assert_called_once_with(database=database, table=dataset)
def test_handle_filters_args_returns_request_scoped_filters(
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
``_handle_filters_args`` must return a fresh ``Filters`` instance per
call so concurrent requests don't share filter state.
Regression test for #33828: under concurrent traffic the FAB default
implementation mutates ``self._filters`` (a single shared instance),
causing filters from one request to leak into another.
The fix lives on ``BaseSupersetModelRestApi`` so every superset REST
API subclass (datasets, charts, dashboards, saved queries, etc.)
inherits the request-scoped behavior. This test exercises it via
``DatasetRestApi`` as a concrete subclass.
"""
from flask_appbuilder.const import API_FILTERS_RIS_KEY
from superset.datasets.api import DatasetRestApi
api = DatasetRestApi()
api.datamodel = MagicMock()
api.search_columns = ["table_name"]
api.search_filters = {}
api._base_filters = MagicMock() # noqa: SLF001
# Each call should construct a fresh Filters instance via datamodel.get_filters
rison_args = {
API_FILTERS_RIS_KEY: [{"col": "table_name", "opr": "eq", "value": "a"}],
}
api._handle_filters_args(rison_args) # noqa: SLF001
api._handle_filters_args(rison_args) # noqa: SLF001
assert api.datamodel.get_filters.call_count == 2
# Returned object must be the joined-filters result of the *fresh* Filters,
# not the shared self._filters attribute.
fresh_filters = api.datamodel.get_filters.return_value
assert fresh_filters.rest_add_filters.call_count == 2
assert fresh_filters.get_joined_filters.call_count == 2

@@ -594,7 +594,34 @@ class TestMapXYConfig:
assert result["viz_type"] == "echarts_timeseries_scatter"
assert result["show_legend"] is False
assert result["legend_orientation"] == "top"
assert result["legendOrientation"] == "top"
def test_map_xy_config_with_color_scheme(self) -> None:
"""color_scheme propagates to form_data when set."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="line",
color_scheme="lyftColors",
)
result = map_xy_config(config)
assert result["color_scheme"] == "lyftColors"
def test_map_xy_config_without_color_scheme(self) -> None:
"""color_scheme key omitted when not set, leaving Superset default."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="line",
)
result = map_xy_config(config)
assert "color_scheme" not in result
def test_map_xy_config_with_time_grain_month(self) -> None:
"""Test XY config mapping with monthly time grain"""

@@ -29,18 +29,23 @@ from pydantic import ValidationError
from superset.mcp_service.chart.chart_utils import (
generate_chart_name,
map_big_number_config,
map_config_to_form_data,
map_mixed_timeseries_config,
map_pie_config,
map_pivot_table_config,
map_table_config,
)
from superset.mcp_service.chart.schemas import (
AxisConfig,
BigNumberChartConfig,
ColumnRef,
CurrencyFormat,
FilterConfig,
MixedTimeseriesChartConfig,
PieChartConfig,
PivotTableChartConfig,
TableChartConfig,
)
from superset.mcp_service.chart.validation.schema_validator import SchemaValidator
@@ -212,6 +217,18 @@ class TestMapPieConfig:
assert result["adhoc_filters"][0]["operator"] == "=="
assert result["adhoc_filters"][0]["comparator"] == "US"
def test_pie_form_data_color_scheme_override(self) -> None:
"""Explicit color_scheme overrides the supersetColors default."""
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
color_scheme="googleCategory10c",
)
result = map_pie_config(config)
assert result["color_scheme"] == "googleCategory10c"
def test_pie_form_data_custom_options(self) -> None:
config = PieChartConfig(
chart_type="pie",
@@ -975,3 +992,272 @@ class TestSchemaValidatorNewTypes:
assert is_valid is False
assert error is not None
assert error.error_code == "INVALID_CHART_TYPE"
# ============================================================
# Chart Formatting Options Tests (sc-102806 follow-up)
# ============================================================
class TestPieFormattingOptions:
"""number/date/currency format, color scheme, legend orientation on Pie."""
def test_currency_format_in_form_data(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
currency_format=CurrencyFormat(symbol="USD", symbol_position="prefix"),
)
result = map_pie_config(config)
assert result["currency_format"] == {
"symbol": "USD",
"symbolPosition": "prefix",
}
def test_currency_format_omitted_when_unset(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = map_pie_config(config)
assert "currency_format" not in result
def test_legend_orientation_in_form_data(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
legend_orientation="bottom",
)
result = map_pie_config(config)
assert result["legendOrientation"] == "bottom"
def test_default_legend_orientation_is_top(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = map_pie_config(config)
assert result["legendOrientation"] == "top"
def test_date_format_overridable(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="ds"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
date_format="%Y-%m-%d",
)
result = map_pie_config(config)
assert result["date_format"] == "%Y-%m-%d"
class TestPivotTableFormattingOptions:
"""date/currency format on PivotTable."""
def test_currency_format_in_form_data(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="region")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
currency_format=CurrencyFormat(symbol="EUR", symbol_position="suffix"),
)
result = map_pivot_table_config(config)
assert result["currency_format"] == {
"symbol": "EUR",
"symbolPosition": "suffix",
}
def test_date_format_in_form_data(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="ds")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
date_format="%Y-%m",
)
result = map_pivot_table_config(config)
assert result["date_format"] == "%Y-%m"
def test_formatting_omitted_when_unset(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="region")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
result = map_pivot_table_config(config)
assert "currency_format" not in result
assert "date_format" not in result
class TestMixedTimeseriesFormattingOptions:
"""color scheme, currency format, legend orientation, data labels on Mixed."""
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_color_scheme_in_form_data(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
color_scheme="lyftColors",
)
result = map_mixed_timeseries_config(config)
assert result["color_scheme"] == "lyftColors"
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_currency_format_primary_and_secondary(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
currency_format=CurrencyFormat(symbol="USD"),
currency_format_secondary=CurrencyFormat(symbol="GBP"),
)
result = map_mixed_timeseries_config(config)
assert result["currency_format"] == {
"symbol": "USD",
"symbolPosition": "prefix",
}
assert result["currency_format_secondary"] == {
"symbol": "GBP",
"symbolPosition": "prefix",
}
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_legend_orientation_in_form_data(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
legend_orientation="left",
)
result = map_mixed_timeseries_config(config)
assert result["legendOrientation"] == "left"
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_show_value_data_labels(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="ds"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
show_value=True,
)
result = map_mixed_timeseries_config(config)
assert result["show_value"] is True
class TestBigNumberFormattingOptions:
"""color scheme, currency format, time format on BigNumber."""
def test_currency_format_in_form_data(self) -> None:
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
currency_format=CurrencyFormat(symbol="JPY", symbol_position="prefix"),
)
result = map_big_number_config(config)
assert result["currency_format"] == {
"symbol": "JPY",
"symbolPosition": "prefix",
}
def test_color_scheme_in_form_data(self) -> None:
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
color_scheme="d3Category10",
)
result = map_big_number_config(config)
assert result["color_scheme"] == "d3Category10"
def test_time_format_only_for_trendline(self) -> None:
# Without trendline, time_format is dropped because the trendline
# x-axis doesn't render.
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
time_format="%Y-%m-%d",
)
result = map_big_number_config(config)
assert "time_format" not in result
def test_time_format_with_trendline(self) -> None:
config = BigNumberChartConfig(
chart_type="big_number",
metric=ColumnRef(name="revenue", aggregate="SUM"),
temporal_column="ds",
show_trendline=True,
time_format="%Y-%m-%d",
)
result = map_big_number_config(config)
assert result["time_format"] == "%Y-%m-%d"
class TestTableFormattingOptions:
"""color scheme on Table."""
def test_color_scheme_in_form_data(self) -> None:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product"), ColumnRef(name="revenue")],
color_scheme="lyftColors",
)
result = map_table_config(config)
assert result["color_scheme"] == "lyftColors"
def test_color_scheme_omitted_when_unset(self) -> None:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product"), ColumnRef(name="revenue")],
)
result = map_table_config(config)
assert "color_scheme" not in result
class TestCurrencyFormatModel:
"""CurrencyFormat schema validation."""
def test_default_symbol_position_is_prefix(self) -> None:
cf = CurrencyFormat(symbol="USD")
assert cf.symbol_position == "prefix"
def test_camel_case_alias_accepted(self) -> None:
cf = CurrencyFormat.model_validate(
{"symbol": "USD", "symbolPosition": "suffix"}
)
assert cf.symbol_position == "suffix"
def test_invalid_position_rejected(self) -> None:
with pytest.raises(ValidationError):
CurrencyFormat(symbol="USD", symbol_position="middle")
def test_to_form_data_shape(self) -> None:
cf = CurrencyFormat(symbol="EUR", symbol_position="suffix")
assert cf.to_form_data() == {"symbol": "EUR", "symbolPosition": "suffix"}
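
Review note: the TestCurrencyFormatModel cases above pin down a small contract. symbol_position defaults to "prefix", values outside "prefix"/"suffix" are rejected, the camelCase key symbolPosition is accepted on input, and to_form_data() emits the camelCase key. The sketch below restates that contract with a plain dataclass so it can be read without the schema module. It is a stand-in, not the pydantic model from superset.mcp_service.chart.schemas; the names CurrencyFormatSketch and from_payload are hypothetical.

```python
from dataclasses import dataclass
from typing import Any

_ALLOWED_POSITIONS = ("prefix", "suffix")


@dataclass(frozen=True)
class CurrencyFormatSketch:
    """Hypothetical stand-in for the CurrencyFormat schema exercised by the tests."""

    symbol: str
    symbol_position: str = "prefix"

    def __post_init__(self) -> None:
        # Mirror the rejection of values such as "middle".
        if self.symbol_position not in _ALLOWED_POSITIONS:
            raise ValueError(f"invalid symbol_position: {self.symbol_position!r}")

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> "CurrencyFormatSketch":
        # Accept either the snake_case field name or the camelCase alias.
        position = payload.get(
            "symbol_position", payload.get("symbolPosition", "prefix")
        )
        return cls(symbol=payload["symbol"], symbol_position=position)

    def to_form_data(self) -> dict[str, str]:
        # form_data uses the camelCase key, as asserted in the tests.
        return {"symbol": self.symbol, "symbolPosition": self.symbol_position}
```

The real model presumably gets the alias handling and Literal validation from pydantic; the dataclass only makes the expected input/output shapes explicit.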


@@ -33,11 +33,15 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
)
from superset.mcp_service.chart.tool.get_chart_data import (
_GENERIC_TYPE_MAP,
_MAX_RECOMMENDATIONS,
_query_from_form_data,
_recommend_visualizations,
_sanitize_chart_data_for_llm_context,
)
from superset.mcp_service.utils import sanitize_for_llm_context
from superset.mcp_service.utils.sanitization import LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER
from superset.utils.core import GenericDataType
def _collect_groupby_extras(
@@ -1167,3 +1171,284 @@ class TestChartDataCommandValidation:
)
mock_command.run.assert_not_called()
@pytest.fixture
def mcp_server():
from superset.mcp_service.app import mcp
return mcp
@pytest.fixture
def mock_auth():
"""Mock MCP auth so Client.call_tool() doesn't need a real admin user."""
import importlib
from contextlib import contextmanager
from unittest.mock import Mock, patch
_gcd_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_data"
)
@contextmanager
def _noop_log_context(*_args: Any, **_kwargs: Any) -> Any:
yield lambda **_kw: None
# Neutralize event_logger.log_context: the default DBEventLogger would
# otherwise insert a log row referencing our mock user_id and fail a
# FK constraint against the real users table. Patch via the module
# object directly — the `tool` package's __init__.py re-exports the
# get_chart_data function under the same name, which shadows the
# submodule binding in the package namespace, so a dotted-string patch
# target resolves to the function and mock.patch cannot find
# event_logger on it.
mock_event_logger = Mock()
mock_event_logger.log_context.side_effect = _noop_log_context
with (
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
patch.object(_gcd_module, "event_logger", mock_event_logger),
):
user = Mock()
user.id = 1
user.username = "admin"
mock_get_user.return_value = user
yield mock_get_user
def _extract_metrics_load_path(load_opt: Any) -> list[str]:
"""Walk a SQLAlchemy Load option and return the attr chain.
e.g. subqueryload(Slice.table).subqueryload(SqlaTable.metrics)
-> ["table", "metrics"]
"""
path = getattr(load_opt, "path", ())
return [elem.key for elem in path if hasattr(elem, "key")]
class TestChartLookupEagerLoading:
"""Tests that get_chart_data eager-loads dataset.metrics on chart lookup.
Regression tests for the Excel export DetachedInstanceError on
dataset.metrics. The chart's dataset metrics relationship must be
eager-loaded at fetch time so it remains accessible during Excel export
after the request-scoped session is detached.
"""
@pytest.mark.asyncio
async def test_numeric_id_lookup_passes_metrics_eager_load(
self, mcp_server, mock_auth
):
"""Integer identifier lookup must eager-load Slice.table.metrics."""
from unittest.mock import patch
from fastmcp import Client
with patch(
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
) as mock_find:
async with Client(mcp_server) as client:
await client.call_tool(
"get_chart_data",
{"request": {"identifier": 42, "format": "excel"}},
)
mock_find.assert_called_once()
call = mock_find.call_args
assert call.args == (42,)
query_options = call.kwargs.get("query_options")
assert query_options is not None, (
"Chart lookup must pass query_options for eager-loading."
)
assert len(query_options) == 1
load_path = _extract_metrics_load_path(query_options[0])
assert load_path == ["table", "metrics"], (
f"Expected subqueryload chain 'table' -> 'metrics', got {load_path}"
)
@pytest.mark.asyncio
async def test_uuid_lookup_passes_metrics_eager_load(self, mcp_server, mock_auth):
"""UUID identifier lookup must also eager-load Slice.table.metrics."""
from unittest.mock import patch
from fastmcp import Client
uuid = "a1b2c3d4-5678-90ab-cdef-1234567890ab"
with patch(
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
) as mock_find:
async with Client(mcp_server) as client:
await client.call_tool(
"get_chart_data",
{"request": {"identifier": uuid, "format": "excel"}},
)
mock_find.assert_called_once()
call = mock_find.call_args
assert call.args == (uuid,)
assert call.kwargs.get("id_column") == "uuid"
query_options = call.kwargs.get("query_options")
assert query_options is not None, (
"UUID chart lookup must pass query_options for eager-loading."
)
load_path = _extract_metrics_load_path(query_options[0])
assert load_path == ["table", "metrics"], (
f"Expected subqueryload chain 'table' -> 'metrics', got {load_path}"
)
@pytest.mark.asyncio
async def test_json_format_also_eager_loads_metrics(self, mcp_server, mock_auth):
"""Eager-load is applied for every format, not just Excel.
Applying unconditionally keeps the fix robust if additional code paths
start touching dataset.metrics, and avoids branching behavior that
would be easy to regress on.
"""
from unittest.mock import patch
from fastmcp import Client
with patch(
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
) as mock_find:
async with Client(mcp_server) as client:
await client.call_tool(
"get_chart_data",
{"request": {"identifier": 7, "format": "json"}},
)
call = mock_find.call_args
query_options = call.kwargs.get("query_options")
assert query_options is not None
assert _extract_metrics_load_path(query_options[0]) == ["table", "metrics"]
# ---------------------------------------------------------------------------
# Tests for _recommend_visualizations
# ---------------------------------------------------------------------------
def _col(
name: str,
data_type: str = "string",
unique_count: int = 5,
null_count: int = 0,
) -> DataColumn:
"""Shortcut to build a DataColumn for tests."""
return DataColumn(
name=name,
display_name=name,
data_type=data_type,
sample_values=[],
null_count=null_count,
unique_count=unique_count,
)
def test_recommend_temporal_and_numeric_suggests_line_chart():
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=50)
assert "line chart" in result
assert "area chart" in result
def test_recommend_categorical_and_numeric_suggests_bar_chart():
cols = [_col("region", "string", unique_count=5), _col("sales", "numeric")]
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
assert "bar chart" in result
def test_recommend_excludes_current_viz_type():
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
assert "line chart" not in result
def test_recommend_multiple_numeric_suggests_scatter():
cols = [
_col("height", "numeric"),
_col("weight", "numeric"),
_col("age", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=100)
assert "scatter plot" in result
def test_recommend_single_numeric_suggests_kpi():
cols = [_col("total_revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=1)
assert "big number / KPI" in result
def test_recommend_all_strings_falls_back():
cols = [_col("name", "string"), _col("address", "string")]
result = _recommend_visualizations("pie", cols, row_count=100)
assert "table" in result or "bar chart" in result
def test_recommend_high_cardinality_no_pie():
cols = [
_col("user_id", "string", unique_count=900),
_col("score", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=1000)
assert "pie chart" not in result
def test_recommend_caps_at_max():
cols = [_col("ts", "temporal"), _col("a", "numeric"), _col("b", "numeric")]
result = _recommend_visualizations("table", cols, row_count=100)
assert len(result) <= _MAX_RECOMMENDATIONS
def test_recommend_empty_columns_returns_table():
result = _recommend_visualizations("table", [], row_count=0)
assert result == ["table"]
def test_recommend_pie_only_for_low_cardinality():
cols = [
_col("department", "string", unique_count=25),
_col("headcount", "numeric"),
]
result = _recommend_visualizations("table", cols, row_count=100)
assert "pie chart" not in result
def test_recommend_temporal_few_rows_prefers_bar():
cols = [_col("date", "temporal"), _col("revenue", "numeric")]
result = _recommend_visualizations("table", cols, row_count=3)
assert "bar chart" in result
assert "line chart" not in result
def test_recommend_single_numeric_high_cardinality_suggests_histogram():
cols = [_col("salary", "numeric", unique_count=500)]
result = _recommend_visualizations("table", cols, row_count=1000)
assert "histogram" in result
def test_coltypes_populates_data_type():
"""Verify that GenericDataType values from coltypes are mapped correctly."""
assert _GENERIC_TYPE_MAP[GenericDataType.NUMERIC] == "numeric"
assert _GENERIC_TYPE_MAP[GenericDataType.STRING] == "string"
assert _GENERIC_TYPE_MAP[GenericDataType.TEMPORAL] == "temporal"
assert _GENERIC_TYPE_MAP[GenericDataType.BOOLEAN] == "boolean"
def test_bool_isinstance_check_before_int():
"""bool is a subclass of int; verify bool check takes priority in fallback."""
# When coltypes is unavailable, the fallback isinstance heuristic
# must check bool before int/float since isinstance(True, int) is True.
# We verify this indirectly: if _GENERIC_TYPE_MAP handles bool correctly,
# and the fallback code checks bool first, booleans won't be "numeric".
# Direct test: simulate what the fallback does
sample_values = [True, False, True]
data_type = "string"
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
assert data_type == "boolean"
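
Review note: test_bool_isinstance_check_before_int inlines the fallback logic rather than calling it. A minimal sketch of the same ordering as a reusable function is shown below. The function name infer_data_type is hypothetical, and the empty-input guard is an added assumption (the inlined version would classify an empty list as "boolean" because all() of an empty iterable is True).

```python
from typing import Any, Sequence


def infer_data_type(sample_values: Sequence[Any]) -> str:
    """Hypothetical fallback: classify a column from its sample values."""
    # Assumption: with no samples there is nothing to infer, so keep "string".
    if not sample_values:
        return "string"
    # bool must be checked first: isinstance(True, int) is True in Python.
    if all(isinstance(v, bool) for v in sample_values):
        return "boolean"
    if all(isinstance(v, (int, float)) for v in sample_values):
        return "numeric"
    return "string"
```

Swapping the two isinstance checks would make every boolean column report as "numeric", which is the regression the test guards against.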


@@ -16,18 +16,25 @@
# under the License.
"""
-Unit tests for get_chart_info MCP tool privacy behavior.
+Unit tests for get_chart_info MCP tool: dashboard-filter resolution and
+privacy behavior.
"""
import importlib
from contextlib import nullcontext
from types import SimpleNamespace
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.chart_helpers import (
_resolve_filter_operator_and_value,
build_applied_dashboard_filters,
ChartNotOnDashboardError,
)
from superset.mcp_service.chart.schemas import (
ChartInfo,
extract_filters_from_form_data,
@@ -84,6 +91,238 @@ def _make_chart_info() -> ChartInfo:
)
class TestGetChartInfoRequestSchema:
def test_dashboard_id_optional(self):
request = GetChartInfoRequest(identifier=1)
assert request.dashboard_id is None
def test_dashboard_id_accepted(self):
request = GetChartInfoRequest(identifier=1, dashboard_id=42)
assert request.dashboard_id == 42
class TestResolveFilterOperatorAndValue:
def test_matches_adhoc_filter_by_subject(self):
efd = {
"adhoc_filters": [
{
"subject": "country",
"operator": "IN",
"comparator": ["US", "CA"],
}
]
}
assert _resolve_filter_operator_and_value(efd, "country") == (
"IN",
["US", "CA"],
)
def test_matches_legacy_filter_by_col(self):
efd = {"filters": [{"col": "state", "op": "==", "val": "NY"}]}
assert _resolve_filter_operator_and_value(efd, "state") == ("==", "NY")
def test_time_range_when_no_column(self):
efd = {"time_range": "Last 7 days"}
assert _resolve_filter_operator_and_value(efd, None) == (
"TIME_RANGE",
"Last 7 days",
)
def test_column_not_in_extra_form_data(self):
efd = {
"adhoc_filters": [{"subject": "other", "operator": "==", "comparator": 1}]
}
assert _resolve_filter_operator_and_value(efd, "country") == (None, None)
def test_none_extra_form_data(self):
assert _resolve_filter_operator_and_value(None, "country") == (None, None)
def test_ignores_non_dict_entries(self):
efd = {
"adhoc_filters": ["not-a-dict", None],
"filters": [42, "foo"],
}
assert _resolve_filter_operator_and_value(efd, "country") == (None, None)
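
Review note: the six TestResolveFilterOperatorAndValue cases above fully describe the observable behavior of the helper. The sketch below is one implementation that satisfies them. It is inferred from the tests only and is not the code in chart_helpers; the name resolve_operator_and_value is hypothetical, and the precedence of adhoc_filters over legacy filters is an assumption the tests do not settle.

```python
from typing import Any, Optional


def resolve_operator_and_value(
    extra_form_data: Optional[dict[str, Any]], column: Optional[str]
) -> tuple[Optional[str], Any]:
    """Hypothetical resolver: return (operator, value) for a filter column."""
    if not extra_form_data:
        return None, None
    if column is None:
        # Filters without a target column are treated as time-range filters.
        time_range = extra_form_data.get("time_range")
        if time_range is not None:
            return "TIME_RANGE", time_range
        return None, None
    # Adhoc filters match on "subject"; non-dict entries are ignored.
    for entry in extra_form_data.get("adhoc_filters") or []:
        if isinstance(entry, dict) and entry.get("subject") == column:
            return entry.get("operator"), entry.get("comparator")
    # Legacy filters match on "col" and use the short "op"/"val" keys.
    for entry in extra_form_data.get("filters") or []:
        if isinstance(entry, dict) and entry.get("col") == column:
            return entry.get("op"), entry.get("val")
    return None, None
```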
class TestBuildAppliedDashboardFilters:
"""The helper validates access, checks chart-on-dashboard, iterates
native filters, resolves scope, and maps each to AppliedDashboardFilter."""
def _make_dashboard(self, json_metadata=None, position_json=None, slice_ids=None):
dashboard = MagicMock()
dashboard.json_metadata = json_metadata or "{}"
dashboard.position_json = position_json or "{}"
dashboard.slices = [MagicMock(id=sid) for sid in (slice_ids or [])]
return dashboard
def test_chart_not_on_dashboard_raises(self):
dashboard = self._make_dashboard(slice_ids=[2, 3])
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
with pytest.raises(ChartNotOnDashboardError, match="not on dashboard"):
build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
def test_dashboard_not_found_raises(self):
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = None # noqa: E501
with pytest.raises(DashboardNotFoundError):
build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
def test_in_scope_filter_with_static_default(self):
native_filter = {
"id": "NATIVE_FILTER-1",
"name": "Country",
"type": "NATIVE_FILTER",
"filterType": "filter_select",
"chartsInScope": [1],
"targets": [{"column": {"name": "country"}, "datasetId": 7}],
"defaultDataMask": {
"filterState": {"value": ["US"]},
"extraFormData": {
"adhoc_filters": [
{
"subject": "country",
"operator": "IN",
"comparator": ["US"],
}
]
},
},
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[native_filter]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert len(result) == 1
flt = result[0]
assert flt.id == "NATIVE_FILTER-1"
assert flt.name == "Country"
assert flt.filter_type == "filter_select"
assert flt.column == "country"
assert flt.operator == "IN"
assert flt.value == ["US"]
assert flt.status == "applied"
def test_excluded_chart_filter_skipped(self):
native_filter = {
"id": "NATIVE_FILTER-1",
"name": "Region",
"type": "NATIVE_FILTER",
"filterType": "filter_select",
"chartsInScope": [2, 3], # chart 1 excluded
"targets": [{"column": {"name": "region"}, "datasetId": 7}],
"defaultDataMask": {
"filterState": {"value": ["NA"]},
"extraFormData": {
"filters": [{"col": "region", "op": "==", "val": "NA"}]
},
},
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[native_filter]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert result == []
def test_default_to_first_item_marks_prequery(self):
native_filter = {
"id": "NATIVE_FILTER-1",
"name": "Region",
"type": "NATIVE_FILTER",
"filterType": "filter_select",
"chartsInScope": [1],
"targets": [{"column": {"name": "region"}, "datasetId": 7}],
"controlValues": {"defaultToFirstItem": True},
"defaultDataMask": {},
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[native_filter]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert len(result) == 1
assert result[0].status == "not_applied_uses_default_to_first_item_prequery"
assert result[0].operator is None
assert result[0].value is None
def test_divider_entry_skipped(self):
divider = {
"id": "DIVIDER-1",
"name": "Section header",
"type": "DIVIDER",
}
dashboard = self._make_dashboard(
json_metadata='{"native_filter_configuration": %s}'
% _json(native_filter_list=[divider]),
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert result == []
def test_no_native_filters_returns_empty_list(self):
dashboard = self._make_dashboard(
json_metadata="{}",
slice_ids=[1],
)
with (
patch("superset.db") as mock_db,
patch("superset.security_manager"),
):
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
assert result == []
def _json(native_filter_list):
"""Serialize a native_filter list as JSON string for embedding in
json_metadata fixtures without escaping issues."""
from superset.utils import json
return json.dumps(native_filter_list)
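
Review note: the TestBuildAppliedDashboardFilters cases describe three outcomes per native_filter_configuration entry: skipped (divider or chart out of scope), "applied", or "not_applied_uses_default_to_first_item_prequery". The sketch below captures only that per-entry decision, without the database lookup or access checks. It is inferred from the fixtures above; the name classify_native_filter is hypothetical, and the real helper may define further statuses that these tests do not exercise.

```python
from typing import Any, Optional


def classify_native_filter(entry: dict[str, Any], chart_id: int) -> Optional[str]:
    """Hypothetical per-entry decision: a status string, or None to skip."""
    # Dividers and other non-filter entries carry no filter semantics.
    if entry.get("type") != "NATIVE_FILTER":
        return None
    # Only filters whose scope includes the chart are reported.
    if chart_id not in (entry.get("chartsInScope") or []):
        return None
    # Assumption: defaultToFirstItem means the value is only known after a
    # pre-query at render time, so no operator/value can be reported.
    control_values = entry.get("controlValues") or {}
    if control_values.get("defaultToFirstItem"):
        return "not_applied_uses_default_to_first_item_prequery"
    return "applied"
```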
class TestGetChartInfoPrivacy:
@pytest.mark.asyncio
async def test_restricted_user_redacts_saved_chart_data_model_fields(


@@ -2043,12 +2043,10 @@ class TestListDatasetsCreatedByMe:
with pytest.raises(ValidationError, match="created_by_me"):
ListDatasetsRequest(created_by_me=True, search="My tables")
-def test_dataset_filter_rejects_created_by_fk(self):
-"""created_by_fk is not a public filter column; use created_by_me instead."""
-from pydantic import ValidationError
-with pytest.raises(ValidationError):
-DatasetFilter(col="created_by_fk", opr="eq", value=1)
+def test_dataset_filter_accepts_created_by_fk(self):
+"""created_by_fk is exposed for person-filtering via find_users."""
+f = DatasetFilter(col="created_by_fk", opr="eq", value=1)
+assert f.col == "created_by_fk"
class TestListDatasetsOwnedByMe:
@@ -2115,14 +2113,10 @@ class TestListDatasetsRequestWrapper:
assert f.col == col
def test_dataset_filter_invalid_col_raises(self) -> None:
-"""Column names not in the Literal are rejected with a validation error.
-This guards against LLMs passing ``created_by_fk`` or similar
-internal column names that are not exposed as filter fields.
-"""
+"""Column names not in the Literal are rejected with a validation error."""
from pydantic import ValidationError
-for bad_col in ("created_by_fk", "id", "database_id", "owner"):
+for bad_col in ("id", "database_id", "owner"):
with pytest.raises(ValidationError):
DatasetFilter(col=bad_col, opr="eq", value="1")


@@ -810,6 +810,52 @@ class TestGenerateExploreLink:
assert "Dataset not found: 99999" in result.data["error"]
assert "list_datasets" in result.data["error"]
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_explore_link_without_config(
self, mock_find_dataset, mcp_server
):
"""Omitting config returns a default dataset explore URL."""
mock_find_dataset.return_value = _mock_dataset(id=42)
request = GenerateExploreLinkRequest(dataset_id="42")
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?datasource_type=table"
"&datasource_id=42"
)
assert result.data["form_data"] == {}
assert result.data["form_data_key"] is None
assert result.data["chart_type_label"] is None
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_explore_link_without_config_missing_dataset(
self, mock_find_dataset, mcp_server
):
"""Omitting config still surfaces a dataset-not-found error."""
mock_find_dataset.return_value = None
request = GenerateExploreLinkRequest(dataset_id="99999")
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["url"] == ""
assert result.data["form_data"] == {}
assert result.data["form_data_key"] is None
assert result.data["chart_type_label"] is None
assert "Dataset not found: 99999" in result.data["error"]
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_explore_link_nonexistent_uuid_dataset(


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,406 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from unittest.mock import MagicMock, patch

import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from pydantic import ValidationError

from superset.mcp_service.app import mcp
from superset.mcp_service.report.schemas import ListReportsRequest, ReportFilter
from superset.utils import json

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def create_mock_report(
    report_id: int = 1,
    name: str = "Daily Sales Report",
    report_type: str = "Report",
    active: bool = True,
    crontab: str = "0 9 * * *",
    description: str = "A daily report",
    dashboard_id: int | None = None,
    chart_id: int | None = None,
) -> MagicMock:
    """Factory function to create mock report objects with sensible defaults."""
    report = MagicMock()
    report.id = report_id
    report.name = name
    report.type = report_type
    report.active = active
    report.crontab = crontab
    report.description = description
    report.dashboard_id = dashboard_id
    report.chart_id = chart_id
    report.owners = []
    report.changed_on = None
    report.created_on = None
    return report


@pytest.fixture
def mcp_server():
    return mcp


@pytest.fixture(autouse=True)
def mock_auth():
    """Mock authentication for all tests."""
    from unittest.mock import Mock

    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user


class TestReportFilterSchema:
    """Tests for ReportFilter schema: filterable columns."""

    def test_valid_filter_name(self):
        f = ReportFilter(col="name", opr="eq", value="My Report")
        assert f.col == "name"

    def test_valid_filter_type(self):
        f = ReportFilter(col="type", opr="eq", value="Alert")
        assert f.col == "type"

    def test_valid_filter_active(self):
        f = ReportFilter(col="active", opr="eq", value=True)
        assert f.col == "active"

    def test_valid_filter_dashboard_id(self):
        f = ReportFilter(col="dashboard_id", opr="eq", value=1)
        assert f.col == "dashboard_id"

    def test_valid_filter_chart_id(self):
        f = ReportFilter(col="chart_id", opr="eq", value=42)
        assert f.col == "chart_id"

    def test_invalid_filter_column_rejected(self):
        """Columns not in the Literal set must be rejected."""
        with pytest.raises(ValidationError):
            ReportFilter(col="not_a_real_column", opr="eq", value=1)

    def test_created_by_fk_is_rejected(self):
        """created_by_fk is not a public filter column."""
        with pytest.raises(ValidationError):
            ReportFilter(col="created_by_fk", opr="eq", value=1)


def test_list_reports_request_accepts_valid_fields():
    request = ListReportsRequest(page=1, page_size=10)
    assert request.page == 1
    assert request.page_size == 10


def test_list_reports_request_rejects_search_and_filters_together():
    with pytest.raises(ValidationError):
        ListReportsRequest(
            search="my report",
            filters=[{"col": "active", "opr": "eq", "value": True}],
        )


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_basic(mock_list, mcp_server):
    """Test basic report listing functionality."""
    report = create_mock_report()
    mock_list.return_value = ([report], 1)
    async with Client(mcp_server) as client:
        request = ListReportsRequest(page=1, page_size=10)
        result = await client.call_tool(
            "list_reports", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["reports"] is not None
        assert len(data["reports"]) == 1
        assert data["reports"][0]["id"] == 1
        assert data["reports"][0]["name"] == "Daily Sales Report"
        assert data["reports"][0]["type"] == "Report"
        assert data["reports"][0]["active"] is True
        assert data["reports"][0]["crontab"] == "0 9 * * *"


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_with_search(mock_list, mcp_server):
    """Test report listing with search functionality."""
    report = create_mock_report(name="Weekly Alert")
    mock_list.return_value = ([report], 1)
    async with Client(mcp_server) as client:
        request = ListReportsRequest(page=1, page_size=10, search="Weekly")
        result = await client.call_tool(
            "list_reports", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["reports"] is not None
        assert len(data["reports"]) == 1
        assert data["reports"][0]["name"] == "Weekly Alert"


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_with_type_filter(mock_list, mcp_server):
    """Test report listing filtered by type."""
    report = create_mock_report(report_type="Alert")
    mock_list.return_value = ([report], 1)
    async with Client(mcp_server) as client:
        request = ListReportsRequest(
            page=1,
            page_size=10,
            filters=[{"col": "type", "opr": "eq", "value": "Alert"}],
        )
        result = await client.call_tool(
            "list_reports", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert len(data["reports"]) == 1
        assert data["reports"][0]["type"] == "Alert"


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_does_not_expose_owners(mock_list, mcp_server):
    """Test that owners field is stripped by privacy controls."""
    report = create_mock_report()
    mock_list.return_value = ([report], 1)
    async with Client(mcp_server) as client:
        request = ListReportsRequest(
            page=1,
            page_size=10,
            select_columns=["id", "name", "owners"],
        )
        result = await client.call_tool(
            "list_reports", {"request": request.model_dump()}
        )
        data = json.loads(result.content[0].text)
        # owners is filtered by USER_DIRECTORY_FIELDS
        assert "owners" not in data.get("columns_requested", [])
        assert "owners" not in data.get("columns_loaded", [])


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_empty_results(mock_list, mcp_server):
    """Test report listing with no results."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        request = ListReportsRequest(page=1, page_size=10)
        result = await client.call_tool(
            "list_reports", {"request": request.model_dump()}
        )
        data = json.loads(result.content[0].text)
        assert data["reports"] == []
        assert data["count"] == 0
        assert data["total_count"] == 0


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_api_error(mock_list, mcp_server):
    """Test error handling when DAO raises an exception."""
    mock_list.side_effect = ToolError("Report DAO error")
    async with Client(mcp_server) as client:
        request = ListReportsRequest(page=1, page_size=10)
        with pytest.raises(ToolError) as excinfo:  # noqa: PT012
            await client.call_tool("list_reports", {"request": request.model_dump()})
        assert "Report DAO error" in str(excinfo.value)


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_without_request_uses_defaults(mock_list, mcp_server):
    """list_reports with no request payload should use default parameters."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        result = await client.call_tool("list_reports", {})
        data = json.loads(result.content[0].text)
        assert data["reports"] == []
        assert data["page"] == 1


@patch("superset.daos.report.ReportScheduleDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_report_info_basic(mock_find, mcp_server):
    """Test basic get report info functionality."""
    report = create_mock_report()
    mock_find.return_value = report
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_report_info", {"request": {"identifier": 1}}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["id"] == 1
        assert data["name"] == "Daily Sales Report"
        assert data["type"] == "Report"
        assert data["active"] is True
        assert data["crontab"] == "0 9 * * *"
        assert "owners" not in data


@patch("superset.daos.report.ReportScheduleDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_report_info_alert_type(mock_find, mcp_server):
    """Test get report info for an Alert type schedule."""
    report = create_mock_report(report_type="Alert", name="Revenue Alert")
    mock_find.return_value = report
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_report_info", {"request": {"identifier": 1}}
        )
        data = json.loads(result.content[0].text)
        assert data["type"] == "Alert"
        assert data["name"] == "Revenue Alert"


@patch("superset.daos.report.ReportScheduleDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_report_info_not_found(mock_find, mcp_server):
    """Test get report info when report does not exist."""
    mock_find.return_value = None
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_report_info", {"request": {"identifier": 999}}
        )
        data = json.loads(result.content[0].text)
        assert data["error_type"] == "not_found"


@patch("superset.daos.report.ReportScheduleDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_report_info_with_dashboard(mock_find, mcp_server):
    """Test get report info with associated dashboard."""
    report = create_mock_report(dashboard_id=42)
    mock_find.return_value = report
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_report_info", {"request": {"identifier": 1}}
        )
        data = json.loads(result.content[0].text)
        assert data["dashboard_id"] == 42
        assert data["chart_id"] is None


@patch("superset.daos.report.ReportScheduleDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_report_info_with_chart(mock_find, mcp_server):
    """Test get report info with associated chart."""
    report = create_mock_report(chart_id=7)
    mock_find.return_value = report
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_report_info", {"request": {"identifier": 1}}
        )
        data = json.loads(result.content[0].text)
        assert data["chart_id"] == 7
        assert data["dashboard_id"] is None


def test_list_reports_request_rejects_invalid_order_column():
    """order_column is validated by ModelListCore, not by the request schema."""
    from superset.mcp_service.common.schema_discovery import REPORT_SORTABLE_COLUMNS

    assert "invalid_column" not in REPORT_SORTABLE_COLUMNS
    # The request schema accepts any order_column value; ModelListCore rejects
    # unknown columns later. Here we only verify that the sortable list does
    # not include the bad column and that the schema does not raise.
    request = ListReportsRequest(page=1, page_size=10, order_column="invalid_column")
    assert (
        request.order_column == "invalid_column"
    )  # schema accepts it; core rejects it


@patch("superset.daos.report.ReportScheduleDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_report_info_humanized_timestamps(mock_find, mcp_server):
    """Test that changed_on_humanized and created_on_humanized are returned."""
    from datetime import datetime, timezone

    report = create_mock_report()
    report.changed_on = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    report.created_on = datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
    mock_find.return_value = report
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_report_info", {"request": {"identifier": 1}}
        )
        data = json.loads(result.content[0].text)
        assert "changed_on_humanized" in data
        assert data["changed_on_humanized"] is not None
        assert "created_on_humanized" in data
        assert data["created_on_humanized"] is not None


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_owned_by_me_passed_to_dao(mock_list, mcp_server):
    """owned_by_me=True is forwarded to the DAO layer."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        request = ListReportsRequest(page=1, page_size=10, owned_by_me=True)
        await client.call_tool("list_reports", {"request": request.model_dump()})
    mock_list.assert_called_once()
    _, kwargs = mock_list.call_args
    filters_arg = kwargs.get("filters", [])
    assert any(getattr(f, "col", None) == "owners.id" for f in filters_arg), (
        "owned_by_me should inject an owners.id filter into the DAO call"
    )


@patch("superset.daos.report.ReportScheduleDAO.list")
@pytest.mark.asyncio
async def test_list_reports_created_by_me_passed_to_dao(mock_list, mcp_server):
    """created_by_me=True is forwarded to the DAO layer."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        request = ListReportsRequest(page=1, page_size=10, created_by_me=True)
        await client.call_tool("list_reports", {"request": request.model_dump()})
    mock_list.assert_called_once()
    _, kwargs = mock_list.call_args
    filters_arg = kwargs.get("filters", [])
    assert any(getattr(f, "col", None) == "created_by_fk" for f in filters_arg), (
        "created_by_me should inject a created_by_fk filter into the DAO call"
    )
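
Reviewer note: the last two tests assert that `owned_by_me` and `created_by_me` reach the DAO as `owners.id` and `created_by_fk` filters under the `filters=` kwarg. The following is only a minimal sketch of that kind of server-side injection, not the code in `list_reports.py`. The `ColumnOperator` dataclass, the `inject_self_filters` helper, and the operator strings are hypothetical stand-ins; the tests only check the column names.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ColumnOperator:
    """Hypothetical stand-in for the filter object handed to the DAO."""

    col: str
    opr: str
    value: Any


def inject_self_filters(
    filters: list[ColumnOperator],
    user_id: int,
    owned_by_me: bool = False,
    created_by_me: bool = False,
) -> list[ColumnOperator]:
    """Return a new filter list with caller-scoped filters appended."""
    result = list(filters)
    if owned_by_me:
        # Assumed operator name; the tests above only inspect `col`.
        result.append(ColumnOperator("owners.id", "rel_m_m", user_id))
    if created_by_me:
        result.append(ColumnOperator("created_by_fk", "eq", user_id))
    return result


injected = inject_self_filters([], user_id=1, owned_by_me=True, created_by_me=True)
print([f.col for f in injected])
```

Because the filters are built from the authenticated user rather than from request input, a caller never has to supply (or learn) a user ID to use them.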


@@ -0,0 +1,257 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for find_users MCP tool and its filter contract."""

import importlib
from unittest.mock import MagicMock, Mock, patch

import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from pydantic import ValidationError

from superset.mcp_service.app import mcp
from superset.mcp_service.system.schemas import FindUsersRequest, FindUsersResponse
from superset.utils import json

# Import the submodule directly so ``patch.object`` targets the module (not the
# ``find_users`` function that ``tool/__init__.py`` re-exports onto the
# package). The package attribute is the function, so dotted-string patches
# like ``superset.mcp_service.system.tool.find_users.db`` can resolve to the
# function in some import orderings and fail with AttributeError.
find_users_module = importlib.import_module(
    "superset.mcp_service.system.tool.find_users"
)


@pytest.fixture
def mcp_server():
    return mcp


@pytest.fixture(autouse=True)
def mock_auth():
    """Mock authentication for all tests."""
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user


def _make_user(id_, username, first=None, last=None, email=None, active=True):
    """Build a Mock user with the attributes serialize_user_object reads."""
    user = Mock(
        spec=["id", "username", "first_name", "last_name", "email", "active", "roles"]
    )
    user.id = id_
    user.username = username
    user.first_name = first
    user.last_name = last
    user.email = email
    user.active = active
    user.roles = []
    return user


def _patch_user_query(rows):
    """Patch the SQLAlchemy chain used by find_users to return a fixed result set."""
    chain = MagicMock()
    chain.filter.return_value = chain
    chain.order_by.return_value = chain
    chain.limit.return_value = chain
    chain.all.return_value = rows
    session = MagicMock()
    session.query.return_value = chain
    return session, chain


# ---------------------------------------------------------------------------
# Schema tests
# ---------------------------------------------------------------------------


def test_find_users_request_rejects_empty_query():
    with pytest.raises(ValidationError):
        FindUsersRequest(query="")


def test_find_users_request_rejects_extra_fields():
    with pytest.raises(ValidationError):
        FindUsersRequest(query="maxime", random_field="x")


def test_find_users_response_default_truncated_false():
    resp = FindUsersResponse(users=[], count=0)
    assert resp.truncated is False


# ---------------------------------------------------------------------------
# Tool-level tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_find_users_returns_matches(mcp_server):
    rows = [
        _make_user(
            7, "maxime", first="Maxime", last="Beauchemin", email="m@example.com"
        )
    ]
    session, _ = _patch_user_query(rows)
    with (
        patch.object(find_users_module, "db") as mock_db,
        patch.object(find_users_module, "security_manager") as mock_sm,
        patch.object(find_users_module, "or_") as mock_or,
    ):
        mock_db.session = session
        mock_sm.user_model = MagicMock()
        mock_or.return_value = MagicMock()
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "find_users", {"request": {"query": "maxime"}}
            )
        data = json.loads(result.content[0].text)
        assert data["count"] == 1
        assert data["truncated"] is False
        assert data["users"][0]["id"] == 7
        assert data["users"][0]["username"] == "maxime"
        assert data["users"][0]["first_name"] == "Maxime"
        assert data["users"][0]["last_name"] == "Beauchemin"
        # Privacy: minimal projection excludes identity attributes that aren't
        # required for filter resolution. Catch regressions on the response shape.
        for forbidden in ("email", "active", "roles"):
            assert forbidden not in data["users"][0]
        # or_ should have been built across the four matched columns
        assert mock_or.called
        assert len(mock_or.call_args.args) == 4


@pytest.mark.asyncio
async def test_find_users_truncates_when_more_rows_than_page_size(mcp_server):
    # page_size=2 with 3 returned rows -> truncated, response trimmed to 2
    rows = [
        _make_user(1, "a"),
        _make_user(2, "b"),
        _make_user(3, "c"),
    ]
    session, chain = _patch_user_query(rows)
    with (
        patch.object(find_users_module, "db") as mock_db,
        patch.object(find_users_module, "security_manager") as mock_sm,
        patch.object(find_users_module, "or_") as mock_or,
    ):
        mock_db.session = session
        mock_sm.user_model = MagicMock()
        mock_or.return_value = MagicMock()
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "find_users", {"request": {"query": "a", "page_size": 2}}
            )
        # Tool requested page_size+1 rows for truncation detection
        chain.limit.assert_called_with(3)
        data = json.loads(result.content[0].text)
        assert data["count"] == 2
        assert data["truncated"] is True
        assert [u["id"] for u in data["users"]] == [1, 2]


@pytest.mark.asyncio
async def test_find_users_rejects_empty_query_via_client(mcp_server):
    async with Client(mcp_server) as client:
        with pytest.raises(ToolError):
            await client.call_tool("find_users", {"request": {"query": ""}})


@pytest.mark.parametrize("blank", [" ", "   ", "\t", "\n \t"])
def test_find_users_request_rejects_whitespace_only_query(blank):
    # Whitespace-only queries would strip to "" and produce a LIKE "%%" pattern
    # that enumerates the entire user directory. The validator must reject them.
    with pytest.raises(ValidationError):
        FindUsersRequest(query=blank)


def test_find_users_request_strips_query_whitespace():
    # Validator should normalize the stored query so downstream LIKE patterns
    # don't carry leading/trailing whitespace.
    request = FindUsersRequest(query=" maxime ")
    assert request.query == "maxime"


# ---------------------------------------------------------------------------
# Filter contract: created_by_fk / changed_by_fk filtering on list tools
# ---------------------------------------------------------------------------


@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_passes_created_by_fk_filter_to_dao(
    mock_list, mcp_server
):
    """list_dashboards should accept created_by_fk filter and forward it."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        await client.call_tool(
            "list_dashboards",
            {
                "request": {
                    "filters": [{"col": "created_by_fk", "opr": "eq", "value": 7}],
                    "page": 1,
                    "page_size": 10,
                }
            },
        )
    assert mock_list.called
    forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
    assert forwarded_filters is not None
    assert any(
        getattr(f, "col", None) == "created_by_fk" and getattr(f, "value", None) == 7
        for f in forwarded_filters
    )


@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_passes_changed_by_fk_filter_to_dao(mock_list, mcp_server):
    """list_charts should accept changed_by_fk filter and forward it."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        await client.call_tool(
            "list_charts",
            {
                "request": {
                    "filters": [{"col": "changed_by_fk", "opr": "eq", "value": 7}],
                    "page": 1,
                    "page_size": 10,
                }
            },
        )
    assert mock_list.called
    forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
    assert forwarded_filters is not None
    assert any(getattr(f, "col", None) == "changed_by_fk" for f in forwarded_filters)
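
Reviewer note: two behaviours tested in this file are easy to get subtly wrong. Whitespace-only queries must be rejected before they become a `LIKE "%%"` pattern, and truncation is detected by fetching `page_size + 1` rows. The sketch below illustrates both ideas with plain functions; `normalize_query` and `page_with_truncation` are hypothetical names, and the real tool presumably does this inside a pydantic validator and a SQLAlchemy query instead.

```python
def normalize_query(query: str) -> str:
    """Strip the query and reject it if nothing is left."""
    stripped = query.strip()
    if not stripped:
        # An empty pattern would match every row in the user directory.
        raise ValueError("query must contain non-whitespace characters")
    return stripped


def page_with_truncation(rows: list, page_size: int) -> tuple[list, bool]:
    """Trim rows fetched with LIMIT page_size + 1 and report truncation."""
    truncated = len(rows) > page_size
    return rows[:page_size], truncated


print(normalize_query("  maxime  "))
print(page_with_truncation(["a", "b", "c"], page_size=2))
```

Fetching one extra row avoids a separate `COUNT` query: the extra row is never returned, it only signals that more matches exist.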


@@ -22,6 +22,8 @@ from unittest.mock import Mock, patch
import pytest
from fastmcp import Client
from fastmcp.client.client import CallToolResult
from mcp.types import TextContent
from pydantic import ValidationError
from superset.mcp_service.app import mcp
@@ -45,6 +47,14 @@ get_schema_module = importlib.import_module(
"superset.mcp_service.system.tool.get_schema"
)
def _result_text(result: CallToolResult) -> str:
"""Return the text payload from the first content block of a tool result."""
block = result.content[0]
assert isinstance(block, TextContent)
return block.text
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -198,7 +208,7 @@ async def test_get_schema_returns_structured_privacy_error_for_dataset(mcp_serve
{"request": {"model_type": "dataset"}},
)
data = json.loads(result.content[0].text)
data = json.loads(_result_text(result))
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
assert data["privacy_scope"] == "data_model"
@@ -241,7 +251,7 @@ async def test_get_schema_redacts_chart_data_model_fields(mcp_server):
{"request": {"model_type": "chart"}},
)
data = json.loads(result.content[0].text)
data = json.loads(_result_text(result))
schema_info = data["schema_info"]
assert all(
column["name"] not in CHART_DATA_MODEL_COLUMNS
@@ -389,7 +399,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
async with Client(mcp_server) as client:
result = await client.call_tool("get_instance_info", {"request": {}})
data = json.loads(result.content[0].text)
data = json.loads(_result_text(result))
assert "current_user" in data
cu = data["current_user"]
assert cu["id"] == 5
@@ -418,7 +428,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
async with Client(mcp_server) as client:
result = await client.call_tool("get_instance_info", {"request": {}})
data = json.loads(result.content[0].text)
data = json.loads(_result_text(result))
assert data["current_user"] is None
@pytest.mark.asyncio
@@ -444,7 +454,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
async with Client(mcp_server) as client:
result = await client.call_tool("get_instance_info", {"request": {}})
data = json.loads(result.content[0].text)
data = json.loads(_result_text(result))
cu = data["current_user"]
assert cu["id"] == 99
assert cu["username"] == "bot"
@@ -460,28 +470,50 @@ class TestGetInstanceInfoCurrentUserViaMCP:
# ---------------------------------------------------------------------------
def test_chart_filter_rejects_created_by_fk() -> None:
"""created_by_fk is not a valid ChartFilter column; use created_by_me instead."""
with pytest.raises(ValidationError):
ChartFilter(col="created_by_fk", opr="eq", value=42)
def test_chart_filter_rejects_user_directory_columns_other_than_fk() -> None:
"""ChartFilter still rejects user-directory columns that expose names."""
for col in ("created_by_name", "owners", "changed_by"):
with pytest.raises(ValidationError):
ChartFilter.model_validate({"col": col, "opr": "eq", "value": "anything"})
def test_chart_filter_accepts_created_and_changed_by_fk() -> None:
"""ChartFilter allows filtering by created_by_fk / changed_by_fk (user IDs)."""
for col in ("created_by_fk", "changed_by_fk"):
f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": 42})
assert f.col == col
def test_chart_filter_rejects_invalid_column():
"""Test that ChartFilter rejects invalid column names."""
with pytest.raises(ValidationError):
ChartFilter(col="nonexistent_column", opr="eq", value=42)
ChartFilter.model_validate(
{"col": "nonexistent_column", "opr": "eq", "value": 42}
)
def test_dashboard_filter_rejects_created_by_fk():
"""created_by_fk is not a valid DashboardFilter column; use created_by_me."""
with pytest.raises(ValidationError):
DashboardFilter(col="created_by_fk", opr="eq", value=42)
def test_dashboard_filter_rejects_user_directory_columns_other_than_fk() -> None:
"""DashboardFilter still rejects user-directory columns that expose names."""
for col in ("created_by_name", "owners", "changed_by"):
with pytest.raises(ValidationError):
DashboardFilter.model_validate(
{"col": col, "opr": "eq", "value": "anything"}
)
def test_dashboard_filter_accepts_created_and_changed_by_fk() -> None:
"""DashboardFilter allows filtering by created_by_fk / changed_by_fk."""
for col in ("created_by_fk", "changed_by_fk"):
f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": 42})
assert f.col == col
def test_dashboard_filter_rejects_invalid_column():
"""Test that DashboardFilter rejects invalid column names."""
with pytest.raises(ValidationError):
DashboardFilter(col="nonexistent_column", opr="eq", value=42)
DashboardFilter.model_validate(
{"col": "nonexistent_column", "opr": "eq", "value": 42}
)
# ---------------------------------------------------------------------------
@@ -492,12 +524,12 @@ def test_dashboard_filter_rejects_invalid_column():
def test_chart_filter_existing_columns_still_work():
"""Test that pre-existing chart filter columns are not broken."""
for col in ("slice_name", "viz_type", "datasource_name"):
f = ChartFilter(col=col, opr="eq", value="test")
f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
assert f.col == col
def test_dashboard_filter_existing_columns_still_work():
"""Test that pre-existing dashboard filter columns are not broken."""
for col in ("dashboard_title", "published", "favorite"):
f = DashboardFilter(col=col, opr="eq", value="test")
f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
assert f.col == col
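
Reviewer note: the `accepts_created_and_changed_by_fk` and `rejects_user_directory_columns_other_than_fk` tests in this hunk describe an allow-list in which ID-only columns are filterable while name-exposing columns are not. Below is a minimal sketch of that idea using only `typing.Literal`. The `ChartFilterColumn` alias and `validate_filter_column` helper are hypothetical; the actual schemas are pydantic models whose `col` field is a `Literal`.

```python
from typing import Literal, get_args

# Hypothetical allow-list; the column names are taken from the tests above.
ChartFilterColumn = Literal[
    "slice_name", "viz_type", "datasource_name", "created_by_fk", "changed_by_fk"
]


def validate_filter_column(col: str) -> str:
    """Return col if it is in the allow-list, otherwise raise ValueError."""
    if col not in get_args(ChartFilterColumn):
        raise ValueError(f"Unsupported filter column: {col!r}")
    return col


for candidate in ("changed_by_fk", "created_by_name", "owners"):
    try:
        validate_filter_column(candidate)
        print(candidate, "accepted")
    except ValueError as ex:
        print(candidate, "rejected:", ex)
```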


@@ -326,11 +326,19 @@ class TestGetSchemaToolViaClient:
async def test_get_schema_omits_user_directory_columns(
self, mock_filters, mcp_server
):
"""Test that schema discovery does not advertise user/access fields."""
"""Test that schema discovery does not advertise user/access fields.
created_by_fk and changed_by_fk are intentionally allowed in
filter_columns so callers can filter by user ID resolved via find_users,
but they remain hidden from select_columns and sortable_columns so the
directory itself is never exposed.
"""
mock_filters.return_value = {
"dashboard_title": ["eq", "ilike"],
"owner": ["rel_m_m"],
"published": ["eq"],
"created_by_fk": ["eq", "in"],
"changed_by_fk": ["eq", "in"],
}
async with Client(mcp_server) as client:
@@ -352,9 +360,16 @@ class TestGetSchemaToolViaClient:
"owner",
):
assert field not in select_column_names
assert field not in info["filter_columns"]
assert field not in info["sortable_columns"]
# User-name and relationship fields stay out of filter_columns
for field in ("owners", "roles", "created_by", "changed_by", "owner"):
assert field not in info["filter_columns"]
# ID-only filter columns are advertised so callers can filter via find_users
assert "created_by_fk" in info["filter_columns"]
assert "changed_by_fk" in info["filter_columns"]
@patch("superset.daos.chart.ChartDAO.get_filterable_columns_and_operators")
@pytest.mark.asyncio
async def test_get_schema_chart_omits_self_referencing_filter_columns(
@@ -362,8 +377,9 @@ class TestGetSchemaToolViaClient:
):
"""Test that chart schema does not advertise self-referencing filter columns.
Even if the DAO returns created_by_fk or owner, they must be excluded so
LLMs cannot discover and use them to enumerate user IDs.
Even if the DAO returns owner or created_by_fk_or_owner, they must be
excluded — these synthetic columns are generated server-side from the
owned_by_me flag and are not directly usable by LLM callers.
"""
mock_filters.return_value = {
"slice_name": ["eq", "ilike"],
@@ -381,7 +397,7 @@ class TestGetSchemaToolViaClient:
info = data["schema_info"]
assert "slice_name" in info["filter_columns"]
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
for field in ("owner", "created_by_fk_or_owner"):
assert field not in info["filter_columns"]
@patch("superset.daos.dataset.DatasetDAO.get_filterable_columns_and_operators")
@@ -391,8 +407,9 @@ class TestGetSchemaToolViaClient:
):
"""Test that dataset schema does not advertise self-referencing filter columns.
Even if the DAO returns created_by_fk or owner, they must be excluded so
LLMs cannot discover and use them to enumerate user IDs.
Even if the DAO returns owner or created_by_fk_or_owner, they must be
excluded — these synthetic columns are generated server-side from the
owned_by_me flag and are not directly usable by LLM callers.
"""
mock_filters.return_value = {
"table_name": ["eq", "ilike"],
@@ -410,7 +427,7 @@ class TestGetSchemaToolViaClient:
info = data["schema_info"]
assert "table_name" in info["filter_columns"]
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
for field in ("owner", "created_by_fk_or_owner"):
assert field not in info["filter_columns"]
@patch("superset.daos.dashboard.DashboardDAO.get_filterable_columns_and_operators")
@@ -420,8 +437,9 @@ class TestGetSchemaToolViaClient:
):
"""Test dashboard schema omits self-referencing filter columns.
Even if the DAO returns created_by_fk or owner, they must be excluded
so LLMs cannot discover and use them to enumerate user IDs.
Even if the DAO returns owner or created_by_fk_or_owner, they must be
excluded — these synthetic columns are generated server-side from the
owned_by_me flag and are not directly usable by LLM callers.
"""
mock_filters.return_value = {
"dashboard_title": ["eq", "ilike"],
@@ -439,7 +457,7 @@ class TestGetSchemaToolViaClient:
info = data["schema_info"]
assert "dashboard_title" in info["filter_columns"]
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
for field in ("owner", "created_by_fk_or_owner"):
assert field not in info["filter_columns"]
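
Reviewer note: taken together, these schema tests expect `filter_columns` to drop user-directory and synthetic self-referencing columns while still advertising the ID-only `created_by_fk` and `changed_by_fk`. The following sketch shows one way such filtering could look. The two sets and the `advertised_filter_columns` helper are hypothetical; their members simply mirror the names asserted in the tests.

```python
# Hypothetical sets; the members mirror the names asserted in the tests above.
USER_DIRECTORY_COLUMNS = {"owners", "roles", "created_by", "changed_by", "owner"}
SELF_REFERENCING_COLUMNS = {"owner", "created_by_fk_or_owner"}


def advertised_filter_columns(
    dao_columns: dict[str, list[str]],
) -> dict[str, list[str]]:
    """Drop hidden columns from the DAO's filterable-column mapping."""
    hidden = USER_DIRECTORY_COLUMNS | SELF_REFERENCING_COLUMNS
    return {col: ops for col, ops in dao_columns.items() if col not in hidden}


dao_columns = {
    "dashboard_title": ["eq", "ilike"],
    "owner": ["rel_m_m"],
    "published": ["eq"],
    "created_by_fk": ["eq", "in"],
    "changed_by_fk": ["eq", "in"],
}
print(sorted(advertised_filter_columns(dao_columns)))
```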


@@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from datetime import date, datetime, time, timezone
from typing import Any
from unittest.mock import MagicMock
from zoneinfo import ZoneInfo
import pandas as pd
import pyarrow as pa
@@ -40,6 +42,7 @@ from superset_core.semantic_layers.types import (
from superset_core.semantic_layers.view import SemanticViewFeature
from superset.semantic_layers.mapper import (
_coerce_scalar_filter_value,
_convert_query_object_filter,
_convert_time_grain,
_get_filters_from_extras,
@@ -1204,6 +1207,91 @@ def test_convert_query_object_filter_like() -> None:
    }


def test_convert_query_object_filter_coerces_integer_string_value() -> None:
    """Test scalar filter values are coerced to dimension type."""
    all_dimensions = {
        "birthyear": Dimension(
            "birthyear",
            "birthyear",
            pa.int64(),
            "birthyear",
            "Birthyear",
        )
    }
    filter_: ValidatedQueryObjectFilterClause = {
        "op": FilterOperator.GREATER_THAN_OR_EQUALS.value,
        "col": "birthyear",
        "val": "1982",
    }

    result = _convert_query_object_filter(filter_, all_dimensions)

    assert result == {
        Filter(
            type=PredicateType.WHERE,
            column=all_dimensions["birthyear"],
            operator=Operator.GREATER_THAN_OR_EQUAL,
            value=1982,
        )
    }


def test_convert_query_object_filter_coerces_in_integer_values() -> None:
    """Test IN filter list values are coerced element-wise."""
    all_dimensions = {
        "order_id__amount": Dimension(
            "order_id__amount",
            "order_id__amount",
            pa.int64(),
            "order_id__amount",
            "Order amount",
        )
    }
    filter_: ValidatedQueryObjectFilterClause = {
        "op": FilterOperator.IN.value,
        "col": "order_id__amount",
        "val": ["58", "61"],
    }

    result = _convert_query_object_filter(filter_, all_dimensions)

    assert result == {
        Filter(
            type=PredicateType.WHERE,
            column=all_dimensions["order_id__amount"],
            operator=Operator.IN,
            value=frozenset({58, 61}),
        )
    }


def test_convert_query_object_filter_invalid_integer_value_raises() -> None:
    """Test invalid integer value raises a clear error."""
    all_dimensions = {
        "birthyear": Dimension(
            "birthyear",
            "birthyear",
            pa.int64(),
            "birthyear",
            "Birthyear",
        )
    }
    filter_: ValidatedQueryObjectFilterClause = {
        "op": FilterOperator.GREATER_THAN_OR_EQUALS.value,
        "col": "birthyear",
        "val": "nineteen-eighty-two",
    }

    with pytest.raises(
        ValueError,
        match="Invalid integer value 'nineteen-eighty-two' for filter column birthyear",
    ):
        _convert_query_object_filter(filter_, all_dimensions)


def test_get_results_without_time_offsets(
    mock_datasource: MagicMock,
    mocker: MockerFixture,
@@ -1923,6 +2011,86 @@ def test_convert_query_object_filter_temporal_range_with_value() -> None:
    }


def test_convert_query_object_filter_temporal_range_coerces_date_bounds() -> None:
    """
    TEMPORAL_RANGE bounds should be coerced against the dimension's dtype so
    date/timestamp columns are not compared against raw strings.
    """
    all_dimensions = {
        "order_date": Dimension(
            "order_date", "order_date", pa.date32(), "order_date", "Order date"
        )
    }
    filter_: ValidatedQueryObjectFilterClause = {
        "op": FilterOperator.TEMPORAL_RANGE.value,
        "col": "order_date",
        "val": "2025-01-01 : 2025-12-31",
    }

    result = _convert_query_object_filter(filter_, all_dimensions)

    assert result == {
        Filter(
            type=PredicateType.WHERE,
            column=all_dimensions["order_date"],
            operator=Operator.GREATER_THAN_OR_EQUAL,
            value=date(2025, 1, 1),
        ),
        Filter(
            type=PredicateType.WHERE,
            column=all_dimensions["order_date"],
            operator=Operator.LESS_THAN,
            value=date(2025, 12, 31),
        ),
    }


def test_convert_query_object_filter_temporal_range_open_ended() -> None:
    """
    Open-ended TEMPORAL_RANGE bounds should emit only the bounded predicate.
    """
    all_dimensions = {
        "order_date": Dimension(
            "order_date", "order_date", pa.date32(), "order_date", "Order date"
        )
    }

    only_start: ValidatedQueryObjectFilterClause = {
        "op": FilterOperator.TEMPORAL_RANGE.value,
        "col": "order_date",
        "val": "2025-01-01 : ",
    }
    assert _convert_query_object_filter(only_start, all_dimensions) == {
        Filter(
            type=PredicateType.WHERE,
            column=all_dimensions["order_date"],
            operator=Operator.GREATER_THAN_OR_EQUAL,
            value=date(2025, 1, 1),
        ),
    }

    only_end: ValidatedQueryObjectFilterClause = {
        "op": FilterOperator.TEMPORAL_RANGE.value,
        "col": "order_date",
        "val": " : 2025-12-31",
    }
    assert _convert_query_object_filter(only_end, all_dimensions) == {
        Filter(
            type=PredicateType.WHERE,
            column=all_dimensions["order_date"],
            operator=Operator.LESS_THAN,
            value=date(2025, 12, 31),
        ),
    }

    empty: ValidatedQueryObjectFilterClause = {
        "op": FilterOperator.TEMPORAL_RANGE.value,
        "col": "order_date",
        "val": " : ",
    }
    assert _convert_query_object_filter(empty, all_dimensions) is None
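Review note: the two TEMPORAL_RANGE tests above pin down a half-open interval (`>=` start, `<` end) in which a blank side emits no predicate and a fully blank range yields `None`. The following is a minimal standalone sketch of that behavior, not the mapper's actual code; `split_temporal_range` and `to_predicates` are hypothetical names, and plain tuples stand in for `Filter` objects.

```python
from datetime import date
from typing import Optional, Set, Tuple


def split_temporal_range(val: str) -> Tuple[Optional[date], Optional[date]]:
    """Split 'START : END' into date bounds; a blank side means unbounded."""
    start_raw, sep, end_raw = val.partition(" : ")
    if not sep:
        raise ValueError(f"Expected 'START : END', got {val!r}")
    start = date.fromisoformat(start_raw.strip()) if start_raw.strip() else None
    end = date.fromisoformat(end_raw.strip()) if end_raw.strip() else None
    return start, end


def to_predicates(column: str, val: str) -> Optional[Set[Tuple[str, str, date]]]:
    """Build (column, operator, value) tuples for the bounded sides only."""
    start, end = split_temporal_range(val)
    predicates: Set[Tuple[str, str, date]] = set()
    if start is not None:
        predicates.add((column, ">=", start))  # inclusive lower bound
    if end is not None:
        predicates.add((column, "<", end))  # exclusive upper bound
    # Hypothetical convention mirroring the tests: no bounds means no filter.
    return predicates or None
```

Coercing each bound to a `date` before building the predicate is what keeps a date column from being compared against a raw string.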
def test_get_order_adhoc_with_none_sql_expression(mock_datasource: MagicMock) -> None:
    """
    Test order extraction skips adhoc expression with None sqlExpression.
@@ -2741,3 +2909,205 @@ def test_get_group_limit_filters_no_granularity(
    # Should return None - no granularity means no time filters added
    assert result is None


# ---------------------------------------------------------------------------
# _coerce_scalar_filter_value: per-dtype branches
# ---------------------------------------------------------------------------


def _dim(dtype: pa.DataType, name: str = "d") -> Dimension:
    return Dimension(name, name, dtype, name, name.capitalize())


def test_coerce_none_returns_none() -> None:
    assert _coerce_scalar_filter_value(None, _dim(pa.int64())) is None


def test_coerce_unsupported_dtype_passes_through() -> None:
    # utf8 (and any dtype not branched in the function) returns the value as-is.
    assert _coerce_scalar_filter_value("abc", _dim(pa.utf8())) == "abc"


@pytest.mark.parametrize(
    "raw,expected",
    [
        (True, True),
        (False, False),
        (1, True),
        (0, False),
        (1.0, True),
        (0.0, False),
        ("true", True),
        ("T", True),
        (" 1 ", True),
        ("yes", True),
        ("Y", True),
        ("on", True),
        ("false", False),
        ("F", False),
        ("0", False),
        ("no", False),
        ("N", False),
        ("off", False),
    ],
)
def test_coerce_boolean(raw: Any, expected: bool) -> None:
    assert _coerce_scalar_filter_value(raw, _dim(pa.bool_())) is expected


@pytest.mark.parametrize("raw", ["maybe", 2, 0.5, -1])
def test_coerce_boolean_invalid_raises(raw: Any) -> None:
    with pytest.raises(ValueError, match="Invalid boolean value"):
        _coerce_scalar_filter_value(raw, _dim(pa.bool_()))
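Review note: the parametrized cases above imply a small truth table (bools pass through, only the numbers 0 and 1 are accepted, strings are trimmed and matched case-insensitively). A standalone sketch of logic that would satisfy those cases is shown below. It is an illustration only; `coerce_bool` and the token sets are hypothetical and are not taken from `superset.semantic_layers.mapper`.

```python
from typing import Any

# Hypothetical token sets, chosen to match the parametrized test cases.
_TRUE_TOKENS = {"true", "t", "1", "yes", "y", "on"}
_FALSE_TOKENS = {"false", "f", "0", "no", "n", "off"}


def coerce_bool(value: Any) -> bool:
    """Coerce a raw filter value to bool, rejecting anything ambiguous."""
    if isinstance(value, bool):
        return value
    # Only the exact numbers 0 and 1 (including 0.0 and 1.0) are accepted.
    if isinstance(value, (int, float)) and value in (0, 1):
        return bool(value)
    if isinstance(value, str):
        token = value.strip().lower()
        if token in _TRUE_TOKENS:
            return True
        if token in _FALSE_TOKENS:
            return False
    raise ValueError(f"Invalid boolean value {value!r}")
```

Rejecting values such as `2` or `-1` rather than falling back to Python truthiness is the point of the invalid-value test: a filter silently treated as `True` would be harder to debug than an explicit error.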
def test_coerce_integer_passthrough() -> None:
    assert _coerce_scalar_filter_value(42, _dim(pa.int64())) == 42


def test_coerce_integer_accepts_integer_valued_float() -> None:
    # JSON round-trips can turn an int into ``42.0``; accept losslessly.
    assert _coerce_scalar_filter_value(42.0, _dim(pa.int64())) == 42


def test_coerce_integer_rejects_bool() -> None:
    # bool is a subclass of int; we explicitly reject it.
    with pytest.raises(ValueError, match="Invalid integer value"):
        _coerce_scalar_filter_value(True, _dim(pa.int64()))


def test_coerce_integer_rejects_non_integer_float() -> None:
    with pytest.raises(ValueError, match="Invalid integer value"):
        _coerce_scalar_filter_value(1.5, _dim(pa.int64()))


def test_coerce_integer_rejects_other_types() -> None:
    with pytest.raises(ValueError, match="Invalid integer value"):
        _coerce_scalar_filter_value([1], _dim(pa.int64()))


@pytest.mark.parametrize(
    "dtype",
    [pa.float64(), pa.decimal128(10, 2)],
)
def test_coerce_floating_or_decimal(dtype: pa.DataType) -> None:
    assert _coerce_scalar_filter_value(1, _dim(dtype)) == 1.0
    assert _coerce_scalar_filter_value(1.5, _dim(dtype)) == 1.5
    assert _coerce_scalar_filter_value(" 2.5 ", _dim(dtype)) == 2.5


def test_coerce_floating_rejects_bool() -> None:
    with pytest.raises(ValueError, match="Invalid numeric value"):
        _coerce_scalar_filter_value(True, _dim(pa.float64()))


def test_coerce_floating_invalid_string_raises() -> None:
    with pytest.raises(ValueError, match="Invalid numeric value"):
        _coerce_scalar_filter_value("not-a-number", _dim(pa.float64()))


def test_coerce_floating_rejects_other_types() -> None:
    with pytest.raises(ValueError, match="Invalid numeric value"):
        _coerce_scalar_filter_value([1.0], _dim(pa.float64()))


def test_coerce_date_from_datetime() -> None:
    out = _coerce_scalar_filter_value(datetime(2025, 1, 2, 12, 0), _dim(pa.date32()))
    assert out == date(2025, 1, 2)


def test_coerce_date_passthrough() -> None:
    out = _coerce_scalar_filter_value(date(2025, 1, 2), _dim(pa.date32()))
    assert out == date(2025, 1, 2)


def test_coerce_date_from_iso_string() -> None:
    out = _coerce_scalar_filter_value(" 2025-01-02 ", _dim(pa.date32()))
    assert out == date(2025, 1, 2)


def test_coerce_date_invalid_string_raises() -> None:
    with pytest.raises(ValueError, match="Invalid date value"):
        _coerce_scalar_filter_value("not-a-date", _dim(pa.date32()))


def test_coerce_date_rejects_other_types() -> None:
    with pytest.raises(ValueError, match="Invalid date value"):
        _coerce_scalar_filter_value(20250102, _dim(pa.date32()))


def test_coerce_timestamp_from_datetime_passthrough() -> None:
    dt = datetime(2025, 1, 2, 3, 4, 5)
    # Naive dtype: returned as-is, still naive.
    assert _coerce_scalar_filter_value(dt, _dim(pa.timestamp("us"))) == dt


def test_coerce_timestamp_from_date() -> None:
    out = _coerce_scalar_filter_value(date(2025, 1, 2), _dim(pa.timestamp("us")))
    assert out == datetime(2025, 1, 2, 0, 0)


def test_coerce_timestamp_from_iso_string_with_z() -> None:
    out = _coerce_scalar_filter_value("2025-01-02T03:04:05Z", _dim(pa.timestamp("us")))
    assert out == datetime.fromisoformat("2025-01-02T03:04:05+00:00")


def test_coerce_timestamp_invalid_string_raises() -> None:
    with pytest.raises(ValueError, match="Invalid timestamp value"):
        _coerce_scalar_filter_value("not-a-ts", _dim(pa.timestamp("us")))


def test_coerce_timestamp_rejects_other_types() -> None:
    with pytest.raises(ValueError, match="Invalid timestamp value"):
        _coerce_scalar_filter_value(1234567890, _dim(pa.timestamp("us")))


def test_coerce_timestamp_tz_aware_dtype_attaches_tz_to_naive_datetime() -> None:
    dt = datetime(2025, 1, 2, 3, 4, 5)
    out = _coerce_scalar_filter_value(dt, _dim(pa.timestamp("us", tz="UTC")))
    assert out == datetime(2025, 1, 2, 3, 4, 5, tzinfo=ZoneInfo("UTC"))


def test_coerce_timestamp_tz_aware_dtype_converts_aware_datetime() -> None:
    dt = datetime(2025, 1, 2, 12, 0, tzinfo=timezone.utc)
    out = _coerce_scalar_filter_value(
        dt, _dim(pa.timestamp("us", tz="America/New_York"))
    )
    # 12:00 UTC == 07:00 in New York
    assert out == datetime(2025, 1, 2, 7, 0, tzinfo=ZoneInfo("America/New_York"))


def test_coerce_timestamp_tz_aware_dtype_attaches_tz_to_date() -> None:
    out = _coerce_scalar_filter_value(
        date(2025, 1, 2), _dim(pa.timestamp("us", tz="UTC"))
    )
    assert out == datetime(2025, 1, 2, 0, 0, tzinfo=ZoneInfo("UTC"))


def test_coerce_timestamp_tz_aware_dtype_parses_string_with_tz() -> None:
    out = _coerce_scalar_filter_value(
        "2025-01-02T03:04:05", _dim(pa.timestamp("us", tz="UTC"))
    )
    # Naive string gets UTC attached.
    assert out == datetime(2025, 1, 2, 3, 4, 5, tzinfo=ZoneInfo("UTC"))
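Review note: the tz-aware timestamp tests above distinguish two cases. A naive value gets the column's timezone attached without shifting the wall-clock time, while an already-aware value is converted to the column's timezone. The sketch below shows that distinction using only the standard library; `align_to_tz` is a hypothetical helper, and it takes a `tzinfo` object (a fixed offset in the usage lines) rather than a timezone name so that it does not depend on the system tz database.

```python
from datetime import datetime, timedelta, timezone, tzinfo


def align_to_tz(dt: datetime, tz: tzinfo) -> datetime:
    """Attach tz to a naive datetime; convert an aware one to tz."""
    if dt.tzinfo is None:
        # Naive input: keep the wall-clock fields and just label them.
        return dt.replace(tzinfo=tz)
    # Aware input: same instant, expressed in the target timezone.
    return dt.astimezone(tz)


# Fixed UTC-5 offset used as a stand-in for New York in January (assumption).
eastern = timezone(timedelta(hours=-5))

attached = align_to_tz(datetime(2025, 1, 2, 3, 4, 5), timezone.utc)
converted = align_to_tz(datetime(2025, 1, 2, 12, 0, tzinfo=timezone.utc), eastern)
```

Here `attached` keeps 03:04:05 and gains a UTC label, while `converted` represents the same instant as noon UTC but reads 07:00 in the UTC-5 zone.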
def test_coerce_time_passthrough() -> None:
    out = _coerce_scalar_filter_value(time(3, 4, 5), _dim(pa.time64("us")))
    assert out == time(3, 4, 5)


def test_coerce_time_from_iso_string() -> None:
    out = _coerce_scalar_filter_value(" 03:04:05 ", _dim(pa.time64("us")))
    assert out == time(3, 4, 5)


def test_coerce_time_invalid_string_raises() -> None:
    with pytest.raises(ValueError, match="Invalid time value"):
        _coerce_scalar_filter_value("not-a-time", _dim(pa.time64("us")))


def test_coerce_time_rejects_other_types() -> None:
    with pytest.raises(ValueError, match="Invalid time value"):
        _coerce_scalar_filter_value(123, _dim(pa.time64("us")))

@@ -811,3 +811,90 @@ class TestWebDriverPlaywrightErrorHandling:
mock_logger.exception.assert_any_call(
"Timed out requesting url %s", "http://example.com"
)
    @patch("superset.utils.webdriver.PLAYWRIGHT_AVAILABLE", True)
    @patch("superset.utils.webdriver.sync_playwright")
    @patch("superset.utils.webdriver.logger")
    def test_missing_element_for_dashboard_height_falls_back_without_crashing(
        self, mock_logger, mock_sync_playwright
    ):
        """Missing dashboard element should not crash height evaluation."""
        mock_user = MagicMock()
        mock_user.username = "test_user"

        mock_playwright_instance = MagicMock()
        mock_browser = MagicMock()
        mock_context = MagicMock()
        mock_page = MagicMock()
        mock_element = MagicMock()
        mock_chart_container = MagicMock()

        mock_sync_playwright.return_value.__enter__.return_value = (
            mock_playwright_instance
        )
        mock_playwright_instance.chromium.launch.return_value = mock_browser
        mock_browser.new_context.return_value = mock_context
        mock_context.new_page.return_value = mock_page

        def locator_side_effect(selector):
            if selector == ".dashboard":
                return mock_element
            if selector == ".chart-container":
                locator = MagicMock()
                locator.all.return_value = [mock_chart_container]
                return locator
            if selector == ".loading":
                locator = MagicMock()
                locator.all.return_value = []
                return locator
            return MagicMock()

        mock_page.locator.side_effect = locator_side_effect
        mock_element.wait_for.return_value = None
        mock_element.screenshot.return_value = b"fake_screenshot"
        mock_chart_container.wait_for.return_value = None
        mock_page.wait_for_timeout.return_value = None

        def evaluate_side_effect(script):
            if script == 'document.querySelectorAll(".chart-container").length':
                return 1
            if "const target = document.querySelector" in script:
                return 0
            return None

        mock_page.evaluate.side_effect = evaluate_side_effect

        with patch("superset.utils.webdriver.app") as mock_app:
            mock_app.config = {
                "WEBDRIVER_OPTION_ARGS": [],
                "WEBDRIVER_WINDOW": {"pixel_density": 1},
                "SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT": 30000,
                "SCREENSHOT_PLAYWRIGHT_WAIT_EVENT": "networkidle",
                "SCREENSHOT_SELENIUM_HEADSTART": 5,
                "SCREENSHOT_SELENIUM_ANIMATION_WAIT": 1,
                "SCREENSHOT_LOCATE_WAIT": 10,
                "SCREENSHOT_LOAD_WAIT": 10,
                "SCREENSHOT_WAIT_FOR_ERROR_MODAL_VISIBLE": 10,
                "SCREENSHOT_WAIT_FOR_ERROR_MODAL_INVISIBLE": 10,
                "SCREENSHOT_REPLACE_UNEXPECTED_ERRORS": False,
                "SCREENSHOT_TILED_ENABLED": True,
                "SCREENSHOT_TILED_CHART_THRESHOLD": 20,
                "SCREENSHOT_TILED_HEIGHT_THRESHOLD": 5000,
                "SCREENSHOT_TILED_VIEWPORT_HEIGHT": 600,
            }

            with patch.object(WebDriverPlaywright, "auth") as mock_auth:
                mock_auth.return_value = mock_context

                driver = WebDriverPlaywright("chrome")
                result = driver.get_screenshot(
                    "http://example.com", "dashboard", mock_user
                )

        assert result == b"fake_screenshot"
        mock_logger.warning.assert_any_call(
            "Could not determine dashboard height for element %s at url %s; "
            "falling back to standard screenshot behavior",
            "dashboard",
            "http://example.com",
        )