mirror of
https://github.com/apache/superset.git
synced 2026-06-11 18:49:15 +00:00
Compare commits
17 Commits
enxdev/cha
...
mcp-rls-pl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95b9efe1a3 | ||
|
|
336d30fbdc | ||
|
|
ecd073b3d2 | ||
|
|
768a3df21b | ||
|
|
e33d0d4f44 | ||
|
|
776fcf2f8b | ||
|
|
2ebb3bcd9f | ||
|
|
ff2c37e52b | ||
|
|
e14a92f881 | ||
|
|
3d528921ba | ||
|
|
81aad1c540 | ||
|
|
c8c805c7e8 | ||
|
|
ab69e3dbcc | ||
|
|
c470b7a8d4 | ||
|
|
7da1bca909 | ||
|
|
06d72bcc72 | ||
|
|
473456b6ea |
@@ -123,6 +123,14 @@ Database Connections:
|
||||
- list_databases: List database connections with advanced filters (1-based pagination)
|
||||
- get_database_info: Get detailed database connection info by ID (backend, capabilities)
|
||||
|
||||
Row Level Security (Admin only):
|
||||
- list_rls_filters: List RLS filters with filtering and search (1-based pagination)
|
||||
- get_rls_filter_info: Get detailed RLS filter info by ID (tables, roles, clause)
|
||||
|
||||
Plugins (Admin only):
|
||||
- list_plugins: List dynamic plugins with filtering and search (1-based pagination)
|
||||
- get_plugin_info: Get detailed plugin info by ID (name, key, bundle URL)
|
||||
|
||||
Dataset Management:
|
||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||
@@ -150,6 +158,7 @@ Schema Discovery:
|
||||
|
||||
System Information:
|
||||
- get_instance_info: Get instance-wide statistics, metadata, and current user identity
|
||||
- find_users: Resolve a person's name to user IDs for use as a filter value
|
||||
- health_check: Simple health check tool (takes NO parameters, call without arguments)
|
||||
- generate_bug_report: Build a PII-sanitized bug report to send to Preset support
|
||||
(use when the user says the MCP is broken or asks how to report an issue)
|
||||
@@ -191,6 +200,16 @@ Some tools do not use a request wrapper, so follow each tool's schema
|
||||
|
||||
Recommended Workflows:
|
||||
|
||||
To filter dashboards/charts/datasets by a person ("show me what <name> is working on"):
|
||||
1. find_users(request={{"query": "<name>"}}) -> resolve to user IDs
|
||||
2. Pick the matching user.id from the response
|
||||
3. list_dashboards(request={{"filters": [
|
||||
{{"col": "created_by_fk", "opr": "eq", "value": <id>}}
|
||||
]}}) — same shape for list_charts / list_datasets.
|
||||
(use changed_by_fk for "last modified by", or "in" with a list of IDs for
|
||||
multiple matches). Do NOT pass the person's name as the search parameter —
|
||||
search matches titles, not people.
|
||||
|
||||
To add a chart to an existing dashboard:
|
||||
1. add_chart_to_existing_dashboard(dashboard_id, chart_id) -> updates dashboard directly
|
||||
- If permission_denied=True is returned: inform the user they lack edit rights,
|
||||
@@ -342,7 +361,10 @@ IMPORTANT - Tool-Only Interaction:
|
||||
|
||||
General usage tips:
|
||||
- All listing tools use 1-based pagination (first page is 1)
|
||||
- Use get_schema to discover filterable columns, sortable columns, and default columns
|
||||
- Use get_schema (chart/dataset/dashboard/database) to discover filterable columns,
|
||||
sortable columns, and default columns for those resource types
|
||||
- For list_rls_filters and list_plugins, filterable/sortable columns are listed
|
||||
inline in each tool's docstring — get_schema does not cover these tools
|
||||
- Use 'filters' parameter for advanced queries with filter columns from get_schema
|
||||
- IDs can be integer or UUID format where supported
|
||||
- All tools return structured, Pydantic-typed responses
|
||||
@@ -360,6 +382,11 @@ Input format:
|
||||
contact details, roles, admin status, ownership, or access-list information.
|
||||
- Do NOT infer access-list answers from dashboard metadata such as published status,
|
||||
role restrictions, empty owner lists, or schema fields.
|
||||
- find_users is sanctioned ONLY for resolving a name the user supplied into a
|
||||
user ID for filtering (e.g., "what is <name> working on" -> filter
|
||||
list_dashboards by created_by_fk). Do NOT use find_users to answer "who owns
|
||||
X", "who can access X", "is <name> an admin", or to enumerate the directory.
|
||||
Never return find_users output to the user verbatim.
|
||||
- Do NOT use execute_sql to query user, role, owner, or access-list tables for this
|
||||
information.
|
||||
- You may reference the current user's own identity details when appropriate, such
|
||||
@@ -620,6 +647,14 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
|
||||
generate_explore_link,
|
||||
)
|
||||
from superset.mcp_service.plugin.tool import ( # noqa: F401, E402
|
||||
get_plugin_info,
|
||||
list_plugins,
|
||||
)
|
||||
from superset.mcp_service.rls.tool import ( # noqa: F401, E402
|
||||
get_rls_filter_info,
|
||||
list_rls_filters,
|
||||
)
|
||||
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
|
||||
execute_sql,
|
||||
open_sql_lab_with_context,
|
||||
@@ -630,6 +665,7 @@ from superset.mcp_service.system import ( # noqa: F401, E402
|
||||
resources as system_resources,
|
||||
)
|
||||
from superset.mcp_service.system.tool import ( # noqa: F401, E402
|
||||
find_users,
|
||||
generate_bug_report,
|
||||
get_instance_info,
|
||||
get_schema,
|
||||
|
||||
@@ -32,6 +32,7 @@ from urllib.parse import parse_qs, urlparse
|
||||
from superset.constants import EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.mcp_service.chart.schemas import AppliedDashboardFilter
|
||||
from superset.models.slice import Slice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,20 +45,33 @@ QUERY_CONTEXT_EXTRA_FORM_DATA_OVERRIDE_KEYS = {
|
||||
}
|
||||
|
||||
|
||||
def find_chart_by_identifier(identifier: int | str) -> Slice | None:
|
||||
class ChartNotOnDashboardError(ValueError):
|
||||
"""Raised when a chart is not part of the given dashboard's slices."""
|
||||
|
||||
|
||||
def find_chart_by_identifier(
|
||||
identifier: int | str,
|
||||
query_options: list[Any] | None = None,
|
||||
) -> Slice | None:
|
||||
"""Find a chart by numeric ID or UUID string.
|
||||
|
||||
Accepts an integer ID, a string that looks like a digit (e.g. "123"),
|
||||
or a UUID string. Returns the Slice model instance or None.
|
||||
|
||||
``query_options`` is forwarded to the DAO so callers can eager-load
|
||||
relationships needed after the request-scoped session is detached.
|
||||
"""
|
||||
from superset.daos.chart import ChartDAO # avoid circular import
|
||||
|
||||
extra: dict[str, Any] = (
|
||||
{"query_options": query_options} if query_options is not None else {}
|
||||
)
|
||||
if isinstance(identifier, int) or (
|
||||
isinstance(identifier, str) and identifier.isdigit()
|
||||
):
|
||||
chart_id = int(identifier) if isinstance(identifier, str) else identifier
|
||||
return ChartDAO.find_by_id(chart_id)
|
||||
return ChartDAO.find_by_id(identifier, id_column="uuid")
|
||||
return ChartDAO.find_by_id(chart_id, **extra)
|
||||
return ChartDAO.find_by_id(identifier, id_column="uuid", **extra)
|
||||
|
||||
|
||||
def get_cached_form_data(form_data_key: str) -> str | None:
|
||||
@@ -528,3 +542,118 @@ def extract_form_data_key_from_url(url: str | None) -> str | None:
|
||||
parsed = urlparse(url)
|
||||
values = parse_qs(parsed.query).get("form_data_key", [])
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def _match_adhoc_by_subject(
|
||||
adhoc_filters: Any, column: str | None
|
||||
) -> tuple[str | None, Any] | None:
|
||||
if not column or not isinstance(adhoc_filters, list):
|
||||
return None
|
||||
for af in adhoc_filters:
|
||||
if isinstance(af, dict) and af.get("subject") == column:
|
||||
return af.get("operator"), af.get("comparator")
|
||||
return None
|
||||
|
||||
|
||||
def _match_legacy_by_col(
|
||||
legacy_filters: Any, column: str | None
|
||||
) -> tuple[str | None, Any] | None:
|
||||
if not column or not isinstance(legacy_filters, list):
|
||||
return None
|
||||
for f in legacy_filters:
|
||||
if isinstance(f, dict) and f.get("col") == column:
|
||||
return f.get("op"), f.get("val")
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_filter_operator_and_value(
|
||||
extra_form_data: dict[str, Any] | None,
|
||||
column: str | None,
|
||||
) -> tuple[str | None, Any]:
|
||||
"""Pull operator and value for a dashboard filter from its
|
||||
default extra_form_data, matching on target column where applicable."""
|
||||
if not extra_form_data:
|
||||
return None, None
|
||||
|
||||
if match := _match_adhoc_by_subject(extra_form_data.get("adhoc_filters"), column):
|
||||
return match
|
||||
if match := _match_legacy_by_col(extra_form_data.get("filters"), column):
|
||||
return match
|
||||
# Temporal filters contribute time_range with no target column
|
||||
if time_range := extra_form_data.get("time_range"):
|
||||
return "TIME_RANGE", time_range
|
||||
return None, None
|
||||
|
||||
|
||||
def build_applied_dashboard_filters(
|
||||
dashboard_id: int, chart_id: int
|
||||
) -> list[AppliedDashboardFilter]:
|
||||
"""Resolve dashboard-level native filters in scope for a chart.
|
||||
|
||||
Validates that the dashboard exists, the caller has access, and the chart
|
||||
is on the dashboard. Returns one AppliedDashboardFilter per non-DIVIDER
|
||||
native filter whose scope includes the chart, populated with the filter's
|
||||
default operator and value.
|
||||
|
||||
Raises DashboardNotFoundError if the dashboard is missing,
|
||||
ChartNotOnDashboardError if the chart is not on it, and
|
||||
SupersetSecurityException if the caller cannot access the dashboard.
|
||||
"""
|
||||
# Local imports avoid circular deps at module load
|
||||
from superset import db, security_manager
|
||||
from superset.charts.data.dashboard_filter_context import (
|
||||
_extract_filter_extra_form_data,
|
||||
_get_filter_target_column,
|
||||
_is_filter_in_scope_for_chart,
|
||||
)
|
||||
from superset.commands.dashboard.exceptions import DashboardNotFoundError
|
||||
from superset.mcp_service.chart.schemas import AppliedDashboardFilter
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.utils import json
|
||||
|
||||
dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
|
||||
if not dashboard:
|
||||
raise DashboardNotFoundError(dashboard_id=str(dashboard_id))
|
||||
|
||||
security_manager.raise_for_access(dashboard=dashboard)
|
||||
|
||||
slice_ids = {slc.id for slc in dashboard.slices}
|
||||
if chart_id not in slice_ids:
|
||||
raise ChartNotOnDashboardError(
|
||||
f"Chart {chart_id} is not on dashboard {dashboard_id}"
|
||||
)
|
||||
|
||||
metadata = json.loads(dashboard.json_metadata or "{}")
|
||||
native_filter_config = metadata.get("native_filter_configuration", [])
|
||||
if not isinstance(native_filter_config, list):
|
||||
return []
|
||||
position_json = json.loads(dashboard.position_json or "{}")
|
||||
if not isinstance(position_json, dict):
|
||||
position_json = {}
|
||||
|
||||
applied: list[AppliedDashboardFilter] = []
|
||||
for flt in native_filter_config:
|
||||
if not isinstance(flt, dict):
|
||||
continue
|
||||
if flt.get("type", "") == "DIVIDER":
|
||||
continue
|
||||
if not _is_filter_in_scope_for_chart(flt, chart_id, position_json):
|
||||
continue
|
||||
|
||||
extra_form_data, status = _extract_filter_extra_form_data(flt)
|
||||
column = _get_filter_target_column(flt)
|
||||
operator, value = _resolve_filter_operator_and_value(extra_form_data, column)
|
||||
|
||||
applied.append(
|
||||
AppliedDashboardFilter(
|
||||
id=flt.get("id"),
|
||||
name=flt.get("name"),
|
||||
filter_type=flt.get("filterType"),
|
||||
column=column,
|
||||
operator=operator,
|
||||
value=value,
|
||||
status=status.value,
|
||||
)
|
||||
)
|
||||
|
||||
return applied
|
||||
|
||||
@@ -32,6 +32,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
ChartCapabilities,
|
||||
ChartSemantics,
|
||||
ColumnRef,
|
||||
CurrencyFormat,
|
||||
FilterConfig,
|
||||
HandlebarsChartConfig,
|
||||
MixedTimeseriesChartConfig,
|
||||
@@ -477,6 +478,7 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]:
|
||||
]
|
||||
|
||||
form_data["row_limit"] = config.row_limit
|
||||
add_color_scheme(form_data, config.color_scheme)
|
||||
|
||||
return form_data
|
||||
|
||||
@@ -547,7 +549,35 @@ def add_legend_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
|
||||
if not config.legend.show:
|
||||
form_data["show_legend"] = False
|
||||
if config.legend.position:
|
||||
form_data["legend_orientation"] = config.legend.position
|
||||
# Canonical form_data key is camelCase; the echarts plugins read
|
||||
# `legendOrientation` directly off form_data.
|
||||
form_data["legendOrientation"] = config.legend.position
|
||||
|
||||
|
||||
def add_color_scheme(form_data: Dict[str, Any], color_scheme: str | None) -> None:
|
||||
"""Add color scheme to form_data when set."""
|
||||
if color_scheme:
|
||||
form_data["color_scheme"] = color_scheme
|
||||
|
||||
|
||||
def add_currency_format(
|
||||
form_data: Dict[str, Any],
|
||||
currency_format: CurrencyFormat | None,
|
||||
key: str = "currency_format",
|
||||
) -> None:
|
||||
"""Add currency format to form_data under the given key when set."""
|
||||
if currency_format:
|
||||
form_data[key] = currency_format.to_form_data()
|
||||
|
||||
|
||||
def add_xy_data_label_options(
|
||||
form_data: Dict[str, Any], config: XYChartConfig, x_is_temporal: bool
|
||||
) -> None:
|
||||
"""Apply XY-specific data-label and time-format options when set."""
|
||||
if config.x_axis_time_format and x_is_temporal:
|
||||
form_data["x_axis_time_format"] = config.x_axis_time_format
|
||||
if config.show_value:
|
||||
form_data["show_value"] = True
|
||||
|
||||
|
||||
def add_orientation_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
|
||||
@@ -722,6 +752,9 @@ def map_xy_config(
|
||||
add_axis_config(form_data, config)
|
||||
add_legend_config(form_data, config)
|
||||
add_orientation_config(form_data, config)
|
||||
add_color_scheme(form_data, config.color_scheme)
|
||||
add_currency_format(form_data, config.currency_format)
|
||||
add_xy_data_label_options(form_data, config, x_is_temporal)
|
||||
|
||||
return form_data
|
||||
|
||||
@@ -734,11 +767,13 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
|
||||
"viz_type": "pie",
|
||||
"groupby": [config.dimension.name],
|
||||
"metric": metric,
|
||||
"color_scheme": "supersetColors",
|
||||
"color_scheme": config.color_scheme or "supersetColors",
|
||||
"show_labels": config.show_labels,
|
||||
"show_legend": config.show_legend,
|
||||
"legendOrientation": config.legend_orientation,
|
||||
"label_type": config.label_type,
|
||||
"number_format": config.number_format,
|
||||
"date_format": config.date_format,
|
||||
"sort_by_metric": config.sort_by_metric,
|
||||
"row_limit": config.row_limit,
|
||||
"donut": config.donut,
|
||||
@@ -746,9 +781,9 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
|
||||
"labels_outside": config.labels_outside,
|
||||
"outerRadius": config.outer_radius,
|
||||
"innerRadius": config.inner_radius,
|
||||
"date_format": "smart_date",
|
||||
}
|
||||
|
||||
add_currency_format(form_data, config.currency_format)
|
||||
_add_adhoc_filters(form_data, config.filters)
|
||||
|
||||
return form_data
|
||||
@@ -774,6 +809,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
|
||||
if config.y_axis_format:
|
||||
form_data["y_axis_format"] = config.y_axis_format
|
||||
|
||||
add_color_scheme(form_data, config.color_scheme)
|
||||
add_currency_format(form_data, config.currency_format)
|
||||
|
||||
# Trendline-specific fields
|
||||
if viz_type == "big_number":
|
||||
# Big Number with trendline uses granularity_sqla for the temporal column
|
||||
@@ -789,6 +827,9 @@ def map_big_number_config(config: BigNumberChartConfig) -> Dict[str, Any]:
|
||||
if config.compare_lag is not None:
|
||||
form_data["compare_lag"] = config.compare_lag
|
||||
|
||||
if config.time_format:
|
||||
form_data["time_format"] = config.time_format
|
||||
|
||||
_add_adhoc_filters(form_data, config.filters)
|
||||
|
||||
return form_data
|
||||
@@ -860,6 +901,10 @@ def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]:
|
||||
"row_limit": config.row_limit,
|
||||
}
|
||||
|
||||
if config.date_format:
|
||||
form_data["date_format"] = config.date_format
|
||||
|
||||
add_currency_format(form_data, config.currency_format)
|
||||
_add_adhoc_filters(form_data, config.filters)
|
||||
|
||||
return form_data
|
||||
@@ -939,10 +984,20 @@ def map_mixed_timeseries_config(
|
||||
"yAxisIndexB": 1,
|
||||
# Display
|
||||
"show_legend": config.show_legend,
|
||||
"legendOrientation": config.legend_orientation,
|
||||
"zoomable": True,
|
||||
"rich_tooltip": True,
|
||||
}
|
||||
|
||||
if config.show_value:
|
||||
form_data["show_value"] = True
|
||||
|
||||
add_color_scheme(form_data, config.color_scheme)
|
||||
add_currency_format(form_data, config.currency_format)
|
||||
add_currency_format(
|
||||
form_data, config.currency_format_secondary, key="currency_format_secondary"
|
||||
)
|
||||
|
||||
# Configure temporal handling
|
||||
configure_temporal_handling(form_data, x_is_temporal, config.time_grain)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Protocol
|
||||
from typing import Annotated, Any, cast, Dict, List, Literal, Protocol
|
||||
|
||||
import humanize
|
||||
from pydantic import (
|
||||
@@ -146,7 +146,7 @@ class ChartInfo(BaseModel):
|
||||
),
|
||||
)
|
||||
form_data_key: str | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description=(
|
||||
"Cache key used to retrieve unsaved form_data. When present, indicates "
|
||||
"the form_data came from cache (unsaved edits) rather than the saved chart."
|
||||
@@ -279,6 +279,15 @@ class GetChartInfoRequest(BaseModel):
|
||||
"Can be used alone (without identifier) for unsaved charts."
|
||||
),
|
||||
)
|
||||
dashboard_id: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"When provided, resolves dashboard-level native filters that are in "
|
||||
"scope for this chart on the given dashboard and returns them under "
|
||||
"filters.dashboard_filters. Requires the chart to be on the dashboard "
|
||||
"and the caller to have dashboard access."
|
||||
),
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_identifier_or_form_data_key(self) -> "GetChartInfoRequest":
|
||||
@@ -523,17 +532,18 @@ class ChartFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"slice_name",
|
||||
"viz_type",
|
||||
"datasource_name",
|
||||
"created_by_fk",
|
||||
"changed_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Column to filter on. Valid values: 'slice_name', 'viz_type', "
|
||||
"'datasource_name'. Other column names are not valid filter columns "
|
||||
"and will cause a validation error."
|
||||
),
|
||||
description="Column to filter on. Use get_schema(model_type='chart') for "
|
||||
"available filter columns. To filter by a person, first call find_users "
|
||||
"to resolve a name to a user ID, then filter by created_by_fk or "
|
||||
"changed_by_fk with that integer ID.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
@@ -731,6 +741,29 @@ class LegendConfig(BaseModel):
|
||||
position: Literal["top", "bottom", "left", "right"] | None = "right"
|
||||
|
||||
|
||||
class CurrencyFormat(BaseModel):
|
||||
"""Currency symbol and placement applied to numeric values."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
symbol: str = Field(
|
||||
...,
|
||||
description="Currency code or symbol (e.g. 'USD', 'EUR', '$', '€')",
|
||||
max_length=20,
|
||||
)
|
||||
symbol_position: Literal["prefix", "suffix"] = Field(
|
||||
"prefix",
|
||||
description="Whether to render the symbol before or after the value",
|
||||
validation_alias=AliasChoices("symbol_position", "symbolPosition"),
|
||||
)
|
||||
|
||||
def to_form_data(self) -> Dict[str, str]:
|
||||
return {"symbol": self.symbol, "symbolPosition": self.symbol_position}
|
||||
|
||||
|
||||
LEGEND_POSITION_LITERAL = Literal["top", "bottom", "left", "right"]
|
||||
|
||||
|
||||
class FilterConfig(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@@ -857,6 +890,27 @@ class PieChartConfig(UnknownFieldCheckMixin):
|
||||
)
|
||||
row_limit: int = Field(100, description="Max slices", ge=1, le=10000)
|
||||
number_format: str = Field("SMART_NUMBER", max_length=50)
|
||||
date_format: str = Field(
|
||||
"smart_date",
|
||||
description="Date format for date dimension labels (e.g. 'smart_date', "
|
||||
"'%Y-%m-%d')",
|
||||
max_length=50,
|
||||
)
|
||||
currency_format: CurrencyFormat | None = Field(
|
||||
None,
|
||||
description="Currency symbol applied to the metric value",
|
||||
)
|
||||
color_scheme: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
|
||||
"'googleCategory10c', 'd3Category10'). Defaults to 'supersetColors'."
|
||||
),
|
||||
max_length=100,
|
||||
)
|
||||
legend_orientation: LEGEND_POSITION_LITERAL = Field(
|
||||
"top", description="Legend placement around the chart"
|
||||
)
|
||||
show_total: bool = Field(False, description="Show total in center")
|
||||
labels_outside: bool = True
|
||||
outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1, le=100)
|
||||
@@ -908,6 +962,15 @@ class PivotTableChartConfig(UnknownFieldCheckMixin):
|
||||
)
|
||||
row_limit: int = Field(10000, description="Max cells", ge=1, le=50000)
|
||||
value_format: str = Field("SMART_NUMBER", max_length=50)
|
||||
date_format: str | None = Field(
|
||||
None,
|
||||
description="Date format for date columns (e.g. 'smart_date', '%Y-%m-%d')",
|
||||
max_length=50,
|
||||
)
|
||||
currency_format: CurrencyFormat | None = Field(
|
||||
None,
|
||||
description="Currency symbol applied to numeric metric values",
|
||||
)
|
||||
|
||||
|
||||
class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
|
||||
@@ -954,9 +1017,29 @@ class MixedTimeseriesChartConfig(UnknownFieldCheckMixin):
|
||||
)
|
||||
# Display options
|
||||
show_legend: bool = True
|
||||
legend_orientation: LEGEND_POSITION_LITERAL = Field(
|
||||
"top", description="Legend placement around the chart"
|
||||
)
|
||||
show_value: bool = Field(False, description="Show data labels on each data point")
|
||||
x_axis: AxisConfig | None = None
|
||||
y_axis: AxisConfig | None = None
|
||||
y_axis_secondary: AxisConfig | None = None
|
||||
color_scheme: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors'). "
|
||||
"When omitted, Superset's default scheme is used."
|
||||
),
|
||||
max_length=100,
|
||||
)
|
||||
currency_format: CurrencyFormat | None = Field(
|
||||
None,
|
||||
description="Currency symbol applied to primary metric values",
|
||||
)
|
||||
currency_format_secondary: CurrencyFormat | None = Field(
|
||||
None,
|
||||
description="Currency symbol applied to secondary metric values",
|
||||
)
|
||||
filters: List[FilterConfig] | None = Field(
|
||||
None,
|
||||
description="Structured filters (column/op/value). "
|
||||
@@ -1132,6 +1215,27 @@ class BigNumberChartConfig(UnknownFieldCheckMixin):
|
||||
),
|
||||
max_length=50,
|
||||
)
|
||||
time_format: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Date format string for trendline x-axis labels "
|
||||
"(e.g. 'smart_date', '%Y-%m-%d'). Only applies when "
|
||||
"show_trendline=True."
|
||||
),
|
||||
max_length=50,
|
||||
)
|
||||
currency_format: CurrencyFormat | None = Field(
|
||||
None,
|
||||
description="Currency symbol applied to the metric value",
|
||||
)
|
||||
color_scheme: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Superset color scheme ID for the trendline (e.g. 'supersetColors'). "
|
||||
"When omitted, Superset's default scheme is used."
|
||||
),
|
||||
max_length=100,
|
||||
)
|
||||
start_y_axis_at_zero: bool = Field(
|
||||
True,
|
||||
description="Anchor trendline y-axis at zero",
|
||||
@@ -1217,6 +1321,14 @@ class TableChartConfig(UnknownFieldCheckMixin):
|
||||
validation_alias=AliasChoices("sort_by", "order_by_cols", "order_by"),
|
||||
)
|
||||
row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000)
|
||||
color_scheme: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Superset color scheme ID applied to conditional/cell formatting "
|
||||
"(e.g. 'supersetColors')."
|
||||
),
|
||||
max_length=100,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_unique_column_labels(self) -> "TableChartConfig":
|
||||
@@ -1298,6 +1410,28 @@ class XYChartConfig(UnknownFieldCheckMixin):
|
||||
x_axis: AxisConfig | None = None
|
||||
y_axis: AxisConfig | None = None
|
||||
legend: LegendConfig | None = None
|
||||
x_axis_time_format: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Date format for temporal x-axis labels (e.g. 'smart_date', "
|
||||
"'%Y-%m-%d'). Only applies when the x-axis column is temporal."
|
||||
),
|
||||
max_length=50,
|
||||
)
|
||||
show_value: bool = Field(False, description="Show data labels on each data point")
|
||||
currency_format: CurrencyFormat | None = Field(
|
||||
None,
|
||||
description="Currency symbol applied to metric values",
|
||||
)
|
||||
color_scheme: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Superset color scheme ID (e.g. 'supersetColors', 'lyftColors', "
|
||||
"'googleCategory10c', 'd3Category10'). When omitted, Superset's "
|
||||
"default scheme is used."
|
||||
),
|
||||
max_length=100,
|
||||
)
|
||||
filters: List[FilterConfig] | None = Field(
|
||||
None,
|
||||
description="Structured filters (column/op/value). "
|
||||
@@ -1414,7 +1548,10 @@ class ListChartsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl):
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
|
||||
|
||||
return parse_json_or_model_list(v, ChartFilter, "filters")
|
||||
return cast(
|
||||
List[ChartFilter],
|
||||
parse_json_or_model_list(v, ChartFilter, "filters"),
|
||||
)
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
@@ -1575,7 +1712,14 @@ class GenerateChartRequest(QueryCacheControl):
|
||||
|
||||
class GenerateExploreLinkRequest(FormDataCacheControl):
|
||||
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
||||
config: ChartConfig = Field(..., description="Chart configuration")
|
||||
config: ChartConfig | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Chart configuration. Optional; omit to get a default "
|
||||
"explore URL that opens the dataset in Superset without a "
|
||||
"preconfigured chart."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UpdateChartRequest(QueryCacheControl):
|
||||
@@ -2064,6 +2208,38 @@ class AdhocFilter(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
class AppliedDashboardFilter(BaseModel):
|
||||
"""A dashboard-level native filter resolved against a specific chart.
|
||||
|
||||
Returned when get_chart_info is called with a dashboard_id. Values come
|
||||
from the filter's default state on the saved dashboard (not a permalink).
|
||||
"""
|
||||
|
||||
id: str | None = Field(None, description="Native filter ID")
|
||||
name: str | None = Field(None, description="Filter display name")
|
||||
filter_type: str | None = Field(
|
||||
None, description="Native filter type (e.g. filter_select, filter_range)"
|
||||
)
|
||||
column: str | None = Field(None, description="Target column the filter applies to")
|
||||
operator: str | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Filter operator as stored in extra_form_data (e.g. 'IN', '==', 'LIKE', "
|
||||
"or 'TIME_RANGE' for temporal filters with no target column)"
|
||||
),
|
||||
)
|
||||
value: Any | None = Field(
|
||||
None, description="Filter value(s) from the default data mask"
|
||||
)
|
||||
status: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether the filter contributes to the chart query: 'applied', "
|
||||
"'not_applied', or 'not_applied_uses_default_to_first_item_prequery'"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChartFiltersInfo(BaseModel):
|
||||
"""Structured representation of all filters applied to a chart."""
|
||||
|
||||
@@ -2105,6 +2281,15 @@ class ChartFiltersInfo(BaseModel):
|
||||
None,
|
||||
description="Custom HAVING clause applied to the chart query",
|
||||
)
|
||||
dashboard_filters: List[AppliedDashboardFilter] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Dashboard-level native filters in scope for this chart on the "
|
||||
"dashboard passed via get_chart_info's dashboard_id argument. Empty "
|
||||
"when no dashboard_id was provided or no native filter targets this "
|
||||
"chart."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Rebuild ChartInfo so Pydantic can resolve the ChartFiltersInfo forward reference.
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Any, Dict, List, TYPE_CHECKING
|
||||
from fastmcp import Context
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import subqueryload
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -58,9 +59,177 @@ from superset.mcp_service.utils.oauth2_utils import (
|
||||
build_oauth2_redirect_message,
|
||||
OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
)
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GENERIC_TYPE_MAP: dict[int, str] = {
|
||||
GenericDataType.NUMERIC: "numeric",
|
||||
GenericDataType.STRING: "string",
|
||||
GenericDataType.TEMPORAL: "temporal",
|
||||
GenericDataType.BOOLEAN: "boolean",
|
||||
}
|
||||
|
||||
# Maps Superset viz_type strings to canonical categories so we can
|
||||
# avoid recommending a chart type the user already has.
|
||||
_VIZ_CATEGORY: dict[str, str] = {
|
||||
"echarts_timeseries_line": "line",
|
||||
"echarts_timeseries_smooth": "line",
|
||||
"echarts_timeseries_step": "line",
|
||||
"echarts_timeseries": "line",
|
||||
"echarts_timeseries_bar": "bar",
|
||||
"echarts_area": "area",
|
||||
"echarts_timeseries_scatter": "scatter",
|
||||
"mixed_timeseries": "line",
|
||||
"table": "table",
|
||||
"pie": "pie",
|
||||
"big_number": "kpi",
|
||||
"big_number_total": "kpi",
|
||||
"pop_kpi": "kpi",
|
||||
"dist_bar": "bar",
|
||||
"line": "line",
|
||||
"area": "area",
|
||||
"scatter": "scatter",
|
||||
"bubble": "bubble",
|
||||
"treemap_v2": "treemap",
|
||||
"sunburst_v2": "treemap",
|
||||
"heatmap_v2": "heatmap",
|
||||
"gauge_chart": "gauge",
|
||||
"funnel": "funnel",
|
||||
"histogram": "histogram",
|
||||
"histogram_v2": "histogram",
|
||||
"box_plot": "box_plot",
|
||||
"world_map": "map",
|
||||
"pivot_table_v2": "table",
|
||||
}
|
||||
|
||||
_MAX_RECOMMENDATIONS = 4
|
||||
|
||||
|
||||
def _recommend_visualizations(
|
||||
viz_type: str,
|
||||
columns: list[DataColumn],
|
||||
row_count: int,
|
||||
) -> list[str]:
|
||||
"""Suggest visualization types based on column types,
|
||||
cardinality, and the chart's current viz_type.
|
||||
"""
|
||||
if not columns:
|
||||
return ["table"]
|
||||
|
||||
current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
|
||||
candidates = _build_candidates(columns, row_count)
|
||||
|
||||
if not candidates:
|
||||
candidates = ["table", "bar chart"]
|
||||
|
||||
return _filter_candidates(candidates, current_category)
|
||||
|
||||
|
||||
def _build_candidates(
|
||||
columns: list[DataColumn],
|
||||
row_count: int,
|
||||
) -> list[str]:
|
||||
"""Build candidate visualization list from column metadata."""
|
||||
temporal = [c for c in columns if c.data_type == "temporal"]
|
||||
numeric = [c for c in columns if c.data_type == "numeric"]
|
||||
categorical = [c for c in columns if c.data_type in ("string", "boolean")]
|
||||
|
||||
if temporal and numeric:
|
||||
return _candidates_temporal_numeric(numeric, row_count)
|
||||
if categorical and numeric:
|
||||
return _candidates_categorical_numeric(numeric, categorical)
|
||||
if len(numeric) >= 2:
|
||||
return _candidates_multi_numeric(numeric, categorical)
|
||||
if len(numeric) == 1 and not temporal and not categorical:
|
||||
return _candidates_single_numeric(numeric[0], row_count)
|
||||
return []
|
||||
|
||||
|
||||
def _candidates_temporal_numeric(
|
||||
numeric: list[DataColumn], row_count: int
|
||||
) -> list[str]:
|
||||
# Few data points are better as a bar chart than a line
|
||||
if row_count < 5:
|
||||
candidates = ["bar chart", "table"]
|
||||
else:
|
||||
candidates = ["line chart", "area chart", "bar chart"]
|
||||
if len(numeric) > 1:
|
||||
candidates.append("multi-line chart")
|
||||
return candidates
|
||||
|
||||
|
||||
def _candidates_categorical_numeric(
|
||||
numeric: list[DataColumn],
|
||||
categorical: list[DataColumn],
|
||||
) -> list[str]:
|
||||
candidates = ["bar chart"]
|
||||
if len(numeric) == 1 and categorical[0].unique_count <= 10:
|
||||
candidates.append("pie chart")
|
||||
if len(numeric) >= 2:
|
||||
candidates.append("scatter plot")
|
||||
candidates.append("heatmap")
|
||||
if any(c.unique_count > 5 for c in categorical):
|
||||
candidates.append("treemap")
|
||||
return candidates
|
||||
|
||||
|
||||
def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
|
||||
candidates = ["big number / KPI", "gauge chart"]
|
||||
if row_count > 20 and col.unique_count > 10:
|
||||
candidates.insert(0, "histogram")
|
||||
return candidates
|
||||
|
||||
|
||||
def _candidates_multi_numeric(
|
||||
numeric: list[DataColumn],
|
||||
categorical: list[DataColumn],
|
||||
) -> list[str]:
|
||||
candidates = ["scatter plot"]
|
||||
if len(numeric) >= 3:
|
||||
candidates.append("bubble chart")
|
||||
if categorical:
|
||||
candidates.append("heatmap")
|
||||
return candidates
|
||||
|
||||
|
||||
# Maps each candidate string to a canonical category for dedup
|
||||
# against the current viz_type.
|
||||
_CANDIDATE_CATEGORY: dict[str, str] = {
|
||||
"line chart": "line",
|
||||
"multi-line chart": "line",
|
||||
"area chart": "area",
|
||||
"bar chart": "bar",
|
||||
"scatter plot": "scatter",
|
||||
"bubble chart": "bubble",
|
||||
"pie chart": "pie",
|
||||
"treemap": "treemap",
|
||||
"heatmap": "heatmap",
|
||||
"big number / KPI": "kpi",
|
||||
"gauge chart": "gauge",
|
||||
"histogram": "histogram",
|
||||
"table": "table",
|
||||
}
|
||||
|
||||
|
||||
def _filter_candidates(
|
||||
candidates: list[str],
|
||||
current_category: str,
|
||||
) -> list[str]:
|
||||
"""Deduplicate, exclude the current viz category, and cap."""
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for c in candidates:
|
||||
if c in seen:
|
||||
continue
|
||||
if _CANDIDATE_CATEGORY.get(c) == current_category:
|
||||
continue
|
||||
seen.add(c)
|
||||
result.append(c)
|
||||
if len(result) >= _MAX_RECOMMENDATIONS:
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
|
||||
"""Wrap chart data read-path descriptive fields before LLM exposure."""
|
||||
@@ -182,7 +351,18 @@ async def get_chart_data( # noqa: C901
|
||||
# Build query context entirely from cached form_data
|
||||
return await _query_from_form_data(cached_form_data_dict, request, ctx)
|
||||
|
||||
# Find the chart by identifier
|
||||
# Find the chart by identifier.
|
||||
# Eagerly load the dataset's metrics relationship so Excel export
|
||||
# (which may run after the request-scoped session is detached) can
|
||||
# access dataset.metrics without triggering a lazy load. See
|
||||
# apache/superset#39206 for the analogous database eager-load fix.
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.slice import Slice
|
||||
|
||||
chart_query_options = [
|
||||
subqueryload(Slice.table).subqueryload(SqlaTable.metrics),
|
||||
]
|
||||
|
||||
with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"):
|
||||
await ctx.debug("Looking up chart: identifier=%s" % (request.identifier,))
|
||||
if request.identifier is None:
|
||||
@@ -190,7 +370,9 @@ async def get_chart_data( # noqa: C901
|
||||
error="Chart identifier is required",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
chart = find_chart_by_identifier(request.identifier)
|
||||
chart = find_chart_by_identifier(
|
||||
request.identifier, query_options=chart_query_options
|
||||
)
|
||||
|
||||
if not chart:
|
||||
await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
|
||||
@@ -484,8 +666,9 @@ async def get_chart_data( # noqa: C901
|
||||
)
|
||||
|
||||
# Create rich column metadata
|
||||
coltypes = query_result.get("coltypes", [])
|
||||
columns = []
|
||||
for col_name in raw_columns:
|
||||
for idx, col_name in enumerate(raw_columns):
|
||||
# Sample some values for metadata
|
||||
sample_values = [
|
||||
row.get(col_name)
|
||||
@@ -493,13 +676,16 @@ async def get_chart_data( # noqa: C901
|
||||
if row.get(col_name) is not None
|
||||
]
|
||||
|
||||
# Infer data type
|
||||
# Use SQL-derived GenericDataType when available,
|
||||
# fall back to Python isinstance heuristic
|
||||
data_type = "string"
|
||||
if sample_values:
|
||||
if all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
elif all(isinstance(v, bool) for v in sample_values):
|
||||
if coltypes:
|
||||
data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
|
||||
elif sample_values:
|
||||
if all(isinstance(v, bool) for v in sample_values):
|
||||
data_type = "boolean"
|
||||
elif all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
|
||||
columns.append(
|
||||
DataColumn(
|
||||
@@ -542,13 +728,11 @@ async def get_chart_data( # noqa: C901
|
||||
else:
|
||||
insights.append("Fresh data retrieved from database")
|
||||
|
||||
recommended_visualizations = []
|
||||
if any(
|
||||
"time" in col.lower() or "date" in col.lower() for col in raw_columns
|
||||
):
|
||||
recommended_visualizations.extend(["line chart", "time series"])
|
||||
if len(raw_columns) <= 3:
|
||||
recommended_visualizations.extend(["bar chart", "scatter plot"])
|
||||
recommended_visualizations = _recommend_visualizations(
|
||||
viz_type=chart.viz_type or "unknown",
|
||||
columns=columns,
|
||||
row_count=len(data),
|
||||
)
|
||||
|
||||
# Performance metadata with cache awareness
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -25,12 +25,19 @@ from fastmcp import Context
|
||||
from sqlalchemy.orm import subqueryload
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.dashboard.exceptions import DashboardNotFoundError
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_helpers import get_cached_form_data
|
||||
from superset.mcp_service.chart.chart_helpers import (
|
||||
build_applied_dashboard_filters,
|
||||
ChartNotOnDashboardError,
|
||||
get_cached_form_data,
|
||||
)
|
||||
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
CHART_FORM_DATA_EXCLUDED_FIELD_NAMES,
|
||||
ChartError,
|
||||
ChartFiltersInfo,
|
||||
ChartInfo,
|
||||
extract_filters_from_form_data,
|
||||
GetChartInfoRequest,
|
||||
@@ -89,6 +96,66 @@ FORM_DATA_OVERRIDE_EXCLUDED_FIELD_NAMES = (
|
||||
)
|
||||
|
||||
|
||||
async def _validate_chart_dataset_access(
|
||||
result: ChartInfo, ctx: Context
|
||||
) -> ChartError | None:
|
||||
"""Validate that the chart's dataset is accessible to the current user.
|
||||
|
||||
Returns a ChartError if the dataset is not accessible, otherwise None.
|
||||
Logs any non-fatal warnings (e.g., virtual dataset warnings) via ctx.
|
||||
"""
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
if not result.id:
|
||||
return None
|
||||
chart = ChartDAO.find_by_id(result.id)
|
||||
if not chart:
|
||||
return None
|
||||
validation_result = validate_chart_dataset(chart, check_access=True)
|
||||
if not validation_result.is_valid:
|
||||
await ctx.warning(
|
||||
"Chart found but dataset is not accessible: %s" % (validation_result.error,)
|
||||
)
|
||||
return ChartError(
|
||||
error=validation_result.error or "Chart's dataset is not accessible",
|
||||
error_type="DatasetNotAccessible",
|
||||
)
|
||||
for warning in validation_result.warnings:
|
||||
await ctx.warning("Dataset warning: %s" % (warning,))
|
||||
return None
|
||||
|
||||
|
||||
async def _attach_dashboard_filters(
|
||||
result: ChartInfo, dashboard_id: int, ctx: Context
|
||||
) -> ChartError | None:
|
||||
"""Resolve dashboard-scoped native filters and attach them to result.filters.
|
||||
|
||||
Returns a ChartError to surface to the caller on validation / access
|
||||
failures, or None on success (including the no-filters case).
|
||||
"""
|
||||
if not result.id:
|
||||
return None
|
||||
with event_logger.log_context(action="mcp.get_chart_info.dashboard_filters"):
|
||||
try:
|
||||
dashboard_filters = build_applied_dashboard_filters(dashboard_id, result.id)
|
||||
except DashboardNotFoundError as exc:
|
||||
await ctx.warning("Dashboard not found: %s" % (str(exc),))
|
||||
return ChartError(error=str(exc), error_type="DashboardNotFound")
|
||||
except ChartNotOnDashboardError as exc:
|
||||
await ctx.warning("Chart not on dashboard: %s" % (str(exc),))
|
||||
return ChartError(error=str(exc), error_type="ChartNotOnDashboard")
|
||||
except SupersetSecurityException as exc:
|
||||
await ctx.warning("Dashboard not accessible: %s" % (str(exc),))
|
||||
return ChartError(error=str(exc), error_type="DashboardNotAccessible")
|
||||
|
||||
if dashboard_filters:
|
||||
if result.filters is None:
|
||||
result.filters = ChartFiltersInfo(dashboard_filters=dashboard_filters)
|
||||
else:
|
||||
result.filters.dashboard_filters = dashboard_filters
|
||||
return None
|
||||
|
||||
|
||||
def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None:
|
||||
"""Override a ChartInfo's form_data with cached unsaved state."""
|
||||
from superset.utils import json as utils_json
|
||||
@@ -178,6 +245,17 @@ async def get_chart_info(
|
||||
}
|
||||
```
|
||||
|
||||
With dashboard context to resolve applied dashboard-level filters:
|
||||
```json
|
||||
{
|
||||
"identifier": 123,
|
||||
"dashboard_id": 45
|
||||
}
|
||||
```
|
||||
When dashboard_id is provided, the response's filters.dashboard_filters
|
||||
lists native filters (with column, operator, and value) that are in scope
|
||||
for this chart on that dashboard.
|
||||
|
||||
Returns chart details including name, type, and URL.
|
||||
"""
|
||||
from superset.daos.chart import ChartDAO
|
||||
@@ -247,23 +325,14 @@ async def get_chart_info(
|
||||
)
|
||||
|
||||
# Validate the chart's dataset is accessible
|
||||
if result.id:
|
||||
chart = ChartDAO.find_by_id(result.id)
|
||||
if chart:
|
||||
validation_result = validate_chart_dataset(chart, check_access=True)
|
||||
if not validation_result.is_valid:
|
||||
await ctx.warning(
|
||||
"Chart found but dataset is not accessible: %s"
|
||||
% (validation_result.error,)
|
||||
)
|
||||
return ChartError(
|
||||
error=validation_result.error
|
||||
or "Chart's dataset is not accessible",
|
||||
error_type="DatasetNotAccessible",
|
||||
)
|
||||
# Log any warnings (e.g., virtual dataset warnings)
|
||||
for warning in validation_result.warnings:
|
||||
await ctx.warning("Dataset warning: %s" % (warning,))
|
||||
dataset_error = await _validate_chart_dataset_access(result, ctx)
|
||||
if dataset_error is not None:
|
||||
return dataset_error
|
||||
|
||||
if request.dashboard_id:
|
||||
error = await _attach_dashboard_filters(result, request.dashboard_id, ctx)
|
||||
if error is not None:
|
||||
return error
|
||||
else:
|
||||
await ctx.warning("Chart retrieval failed: error=%s" % (str(result),))
|
||||
|
||||
|
||||
@@ -104,11 +104,16 @@ async def list_charts(
|
||||
list_charts(search="revenue", page=1) # DO NOT DO THIS
|
||||
|
||||
Valid filter columns for ``filters[].col``:
|
||||
``slice_name``, ``viz_type``, ``datasource_name``
|
||||
``slice_name``, ``viz_type``, ``datasource_name``,
|
||||
``created_by_fk``, ``changed_by_fk``
|
||||
|
||||
Sortable columns for ``order_column``:
|
||||
``id``, ``slice_name``, ``viz_type``, ``description``,
|
||||
``changed_on``, ``created_on``
|
||||
|
||||
To filter by a person, call find_users to resolve the name to a user ID,
|
||||
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
|
||||
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
|
||||
"""
|
||||
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
|
||||
await ctx.info(
|
||||
|
||||
@@ -37,10 +37,12 @@ class ColumnMetadata(BaseModel):
|
||||
"""Metadata for a selectable column."""
|
||||
|
||||
name: str = Field(..., description="Column name to use in select_columns")
|
||||
description: str | None = Field(None, description="Column description")
|
||||
type: str | None = Field(None, description="Data type (str, int, datetime, etc.)")
|
||||
description: str | None = Field(default=None, description="Column description")
|
||||
type: str | None = Field(
|
||||
default=None, description="Data type (str, int, datetime, etc.)"
|
||||
)
|
||||
is_default: bool = Field(
|
||||
False, description="Whether this column is included by default"
|
||||
default=False, description="Whether this column is included by default"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING
|
||||
from typing import Annotated, Any, cast, Dict, List, Literal, TYPE_CHECKING
|
||||
|
||||
import humanize
|
||||
from pydantic import (
|
||||
@@ -169,16 +169,20 @@ class DashboardFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"dashboard_title",
|
||||
"published",
|
||||
"favorite",
|
||||
"created_by_fk",
|
||||
"changed_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Column to filter on. Valid values: 'dashboard_title', 'published', "
|
||||
"'favorite'. Other column names are not valid filter columns and will "
|
||||
"cause a validation error."
|
||||
"Column to filter on. Use "
|
||||
"get_schema(model_type='dashboard') for available "
|
||||
"filter columns. To filter by a person, first call find_users to "
|
||||
"resolve a name to a user ID, then filter by created_by_fk or "
|
||||
"changed_by_fk with that integer ID."
|
||||
),
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
@@ -223,7 +227,10 @@ class ListDashboardsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheContr
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
|
||||
|
||||
return parse_json_or_model_list(v, DashboardFilter, "filters")
|
||||
return cast(
|
||||
List[DashboardFilter],
|
||||
parse_json_or_model_list(v, DashboardFilter, "filters"),
|
||||
)
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
@@ -392,14 +399,14 @@ class DashboardInfo(BaseModel):
|
||||
|
||||
# Fields for permalink/filter state support
|
||||
permalink_key: str | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description=(
|
||||
"Permalink key used to retrieve filter state. When present, indicates "
|
||||
"the filter_state came from a permalink rather than the default dashboard."
|
||||
),
|
||||
)
|
||||
filter_state: Dict[str, Any] | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description=(
|
||||
"Filter state from permalink. Contains dataMask (native filter values), "
|
||||
"activeTabs, anchor, and urlParams. When present, represents the actual "
|
||||
|
||||
@@ -98,11 +98,18 @@ async def list_dashboards(
|
||||
list_dashboards(search="sales", page=1) # DO NOT DO THIS
|
||||
|
||||
Valid filter columns for ``filters[].col``:
|
||||
``dashboard_title``, ``published``, ``favorite``
|
||||
``dashboard_title``, ``published``, ``favorite``,
|
||||
``created_by_fk``, ``changed_by_fk``
|
||||
|
||||
Sortable columns for ``order_column``:
|
||||
``id``, ``dashboard_title``, ``slug``, ``published``,
|
||||
``changed_on``, ``created_on``
|
||||
|
||||
To filter by a person (e.g. "dashboards Maxime is working on"), do NOT pass
|
||||
the name as the search parameter — search matches titles and slugs only.
|
||||
Instead, call find_users to resolve the name to a user ID, then pass it as
|
||||
a filter: filters=[{"col": "created_by_fk", "opr": "eq", "value": <id>}]
|
||||
(or "changed_by_fk" for "last modified by").
|
||||
"""
|
||||
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
|
||||
await ctx.info(
|
||||
|
||||
@@ -22,7 +22,7 @@ Pydantic schemas for database-related responses
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
from typing import Annotated, Any, cast, Dict, List, Literal
|
||||
|
||||
import humanize
|
||||
from pydantic import (
|
||||
@@ -58,7 +58,7 @@ class DatabaseFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"database_name",
|
||||
"expose_in_sqllab",
|
||||
"allow_file_upload",
|
||||
@@ -242,7 +242,10 @@ class ListDatabasesRequest(CreatedByMeMixin, MetadataCacheControl):
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[DatabaseFilter]:
|
||||
"""Accept both JSON string and list of objects."""
|
||||
return parse_json_or_model_list(v, DatabaseFilter, "filters")
|
||||
return cast(
|
||||
List[DatabaseFilter],
|
||||
parse_json_or_model_list(v, DatabaseFilter, "filters"),
|
||||
)
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -65,17 +65,18 @@ class DatasetFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"table_name",
|
||||
"schema",
|
||||
"database_name",
|
||||
"created_by_fk",
|
||||
"changed_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Column to filter on. Valid values: 'table_name', 'schema', "
|
||||
"'database_name'. Other column names (e.g. 'created_by_fk', 'id') "
|
||||
"are not valid filter columns and will cause a validation error."
|
||||
),
|
||||
description="Column to filter on. Use get_schema(model_type='dataset') for "
|
||||
"available filter columns. To filter by a person, first call find_users "
|
||||
"to resolve a name to a user ID, then filter by created_by_fk or "
|
||||
"changed_by_fk with that integer ID.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
@@ -658,7 +659,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
|
||||
params = None
|
||||
columns = [
|
||||
TableColumnInfo(
|
||||
column_name=getattr(col, "column_name", None),
|
||||
column_name=getattr(col, "column_name", None) or "",
|
||||
verbose_name=getattr(col, "verbose_name", None),
|
||||
type=getattr(col, "type", None),
|
||||
is_dttm=getattr(col, "is_dttm", None),
|
||||
@@ -670,7 +671,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
|
||||
]
|
||||
metrics = [
|
||||
SqlMetricInfo(
|
||||
metric_name=getattr(metric, "metric_name", None),
|
||||
metric_name=getattr(metric, "metric_name", None) or "",
|
||||
verbose_name=getattr(metric, "verbose_name", None),
|
||||
expression=getattr(metric, "expression", None),
|
||||
description=getattr(metric, "description", None),
|
||||
|
||||
@@ -109,10 +109,15 @@ async def list_datasets(
|
||||
list_datasets(search="sales", page=1) # DO NOT DO THIS
|
||||
|
||||
Valid filter columns for ``filters[].col``:
|
||||
``table_name``, ``schema``, ``database_name``
|
||||
``table_name``, ``schema``, ``database_name``,
|
||||
``created_by_fk``, ``changed_by_fk``
|
||||
|
||||
Sortable columns for ``order_column``:
|
||||
``id``, ``table_name``, ``schema``, ``changed_on``, ``created_on``
|
||||
|
||||
To filter by a person, call find_users to resolve the name to a user ID,
|
||||
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
|
||||
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_datasets")
|
||||
|
||||
@@ -65,10 +65,12 @@ async def generate_explore_link(
|
||||
- "Visualize [data]"
|
||||
- General data exploration
|
||||
- When user wants to SEE data visually
|
||||
- Opening a dataset in Explore without a preconfigured chart (omit config)
|
||||
|
||||
IMPORTANT:
|
||||
- Use numeric dataset ID or UUID (NOT schema.table_name format)
|
||||
- MUST include chart_type in config (either 'xy' or 'table')
|
||||
- When config is provided, MUST include chart_type (e.g. 'xy' or 'table')
|
||||
- Omit config entirely to return a default explore URL for the dataset
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
@@ -83,6 +85,11 @@ async def generate_explore_link(
|
||||
}
|
||||
```
|
||||
|
||||
Or with no config to simply open the dataset in Explore:
|
||||
```json
|
||||
{"dataset_id": 123}
|
||||
```
|
||||
|
||||
Better UX because:
|
||||
- Users can interact with chart before saving
|
||||
- Easy to modify parameters instantly
|
||||
@@ -93,9 +100,10 @@ async def generate_explore_link(
|
||||
|
||||
Returns explore URL for immediate use.
|
||||
"""
|
||||
chart_type = request.config.chart_type if request.config else "none"
|
||||
await ctx.info(
|
||||
"Generating explore link for dataset_id=%s, chart_type=%s"
|
||||
% (request.dataset_id, request.config.chart_type)
|
||||
% (request.dataset_id, chart_type)
|
||||
)
|
||||
await ctx.debug(
|
||||
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
|
||||
@@ -103,9 +111,6 @@ async def generate_explore_link(
|
||||
)
|
||||
|
||||
try:
|
||||
# config is already a typed ChartConfig (validated by Pydantic)
|
||||
config = request.config
|
||||
|
||||
await ctx.report_progress(1, 4, "Validating dataset exists")
|
||||
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
@@ -157,8 +162,32 @@ async def generate_explore_link(
|
||||
),
|
||||
}
|
||||
|
||||
# When no config is provided, return a default explore URL that opens
|
||||
# the dataset in Superset without a preconfigured chart.
|
||||
if request.config is None:
|
||||
await ctx.report_progress(4, 4, "URL generation complete")
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
|
||||
base_url = get_superset_base_url()
|
||||
default_url = (
|
||||
f"{base_url}/explore/?datasource_type=table&datasource_id={dataset.id}"
|
||||
)
|
||||
await ctx.info(
|
||||
"Default explore link generated: dataset_id=%s" % (request.dataset_id,)
|
||||
)
|
||||
return {
|
||||
"url": default_url,
|
||||
"form_data": {},
|
||||
"form_data_key": None,
|
||||
"chart_type_label": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
await ctx.report_progress(2, 4, "Converting configuration to form data")
|
||||
with event_logger.log_context(action="mcp.generate_explore_link.form_data"):
|
||||
# config is already a typed ChartConfig (validated by Pydantic)
|
||||
config = request.config
|
||||
|
||||
# Normalize column names to match canonical dataset column names
|
||||
# This fixes case sensitivity issues (e.g., 'order_date' vs 'OrderDate')
|
||||
try:
|
||||
@@ -256,7 +285,7 @@ async def generate_explore_link(
|
||||
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.config.chart_type,
|
||||
chart_type,
|
||||
type(e).__name__,
|
||||
str(e),
|
||||
)
|
||||
|
||||
@@ -34,6 +34,7 @@ from superset.mcp_service.privacy import (
|
||||
filter_user_directory_columns,
|
||||
SELF_REFERENCING_FILTER_COLUMNS,
|
||||
USER_DIRECTORY_FIELDS,
|
||||
USER_FILTER_FIELDS,
|
||||
)
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils import _is_uuid
|
||||
@@ -314,14 +315,6 @@ class ModelListCore(BaseCore, Generic[L]):
|
||||
has_previous=page > 0,
|
||||
)
|
||||
|
||||
# Build response
|
||||
def get_keys(obj: BaseModel | dict[str, Any] | Any) -> List[str]:
|
||||
if hasattr(obj, "model_dump"):
|
||||
return list(obj.model_dump().keys())
|
||||
elif isinstance(obj, dict):
|
||||
return list(obj.keys())
|
||||
return []
|
||||
|
||||
response_kwargs = {
|
||||
self.list_field_name: item_objs,
|
||||
"count": len(item_objs),
|
||||
@@ -595,7 +588,7 @@ class InstanceInfoCore(BaseCore):
|
||||
return counts
|
||||
|
||||
def _calculate_time_based_metrics(
|
||||
self, base_counts: Dict[str, int]
|
||||
self, _base_counts: Dict[str, int]
|
||||
) -> Dict[str, Dict[str, int]]:
|
||||
"""Calculate time-based metrics for recent activity."""
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -774,7 +767,9 @@ class ModelGetSchemaCore(BaseCore, Generic[S]):
|
||||
self.default_sort = default_sort
|
||||
self.default_sort_direction = default_sort_direction
|
||||
self.exclude_filter_columns = set(exclude_filter_columns or set())
|
||||
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS)
|
||||
# Hide user-directory columns from filter discovery, except the small
|
||||
# set callers may legitimately filter by ID (resolved via find_users).
|
||||
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS - USER_FILTER_FIELDS)
|
||||
|
||||
def _get_filter_columns(self) -> Dict[str, List[str]]:
|
||||
"""Get filterable columns and operators from the DAO."""
|
||||
|
||||
16
superset/mcp_service/plugin/__init__.py
Normal file
16
superset/mcp_service/plugin/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
23
superset/mcp_service/plugin/dao.py
Normal file
23
superset/mcp_service/plugin/dao.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from superset.daos.base import BaseDAO
|
||||
from superset.models.dynamic_plugins import DynamicPlugin
|
||||
|
||||
|
||||
class DynamicPluginDAO(BaseDAO[DynamicPlugin]):
|
||||
pass
|
||||
213
superset/mcp_service/plugin/schemas.py
Normal file
213
superset/mcp_service/plugin/schemas.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for dynamic plugin responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_PLUGIN_COLUMNS = ["id", "name", "key", "bundle_url"]
|
||||
|
||||
ALL_PLUGIN_COLUMNS = [
|
||||
"id",
|
||||
"name",
|
||||
"key",
|
||||
"bundle_url",
|
||||
"changed_on",
|
||||
"created_on",
|
||||
]
|
||||
|
||||
SORTABLE_PLUGIN_COLUMNS = ["id", "name", "key", "changed_on", "created_on"]
|
||||
|
||||
|
||||
class PluginColumnFilter(ColumnOperator):
|
||||
"""Filter object for plugin listing."""
|
||||
|
||||
col: Literal["name", "key"] = Field(..., description="Column to filter on.")
|
||||
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
|
||||
value: str | int | float | bool | List[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by"
|
||||
)
|
||||
|
||||
|
||||
class PluginInfo(BaseModel):
|
||||
id: int | None = Field(None, description="Plugin ID")
|
||||
name: str | None = Field(None, description="Plugin display name")
|
||||
key: str | None = Field(None, description="Plugin key (corresponds to viz_type)")
|
||||
bundle_url: str | None = Field(None, description="URL to the plugin bundle")
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
created_on: str | datetime | None = Field(None, description="Creation timestamp")
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
|
||||
data = serializer(self)
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
return data
|
||||
|
||||
|
||||
class PluginList(BaseModel):
|
||||
plugins: List[PluginInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: List[str] = Field(default_factory=list)
|
||||
columns_loaded: List[str] = Field(default_factory=list)
|
||||
columns_available: List[str] = Field(default_factory=list)
|
||||
sortable_columns: List[str] = Field(default_factory=list)
|
||||
filters_applied: List[PluginColumnFilter] = Field(default_factory=list)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListPluginsRequest(BaseModel):
|
||||
"""Request schema for list_plugins."""
|
||||
|
||||
filters: Annotated[
|
||||
List[PluginColumnFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of filter objects (col, opr, value). "
|
||||
"Cannot be used with search.",
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
List[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="Columns to include in response. Defaults to common columns.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="Text search on plugin name or key. "
|
||||
"Cannot be used with filters.",
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None, Field(default=None, description="Column to order results by")
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(default="desc", description="Sort direction"),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[PluginColumnFilter]:
|
||||
return parse_json_or_model_list(v, PluginColumnFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> List[str]:
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListPluginsRequest":
|
||||
if self.search and self.filters:
|
||||
raise ValueError("Cannot use both 'search' and 'filters' simultaneously.")
|
||||
return self
|
||||
|
||||
|
||||
class PluginError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Type of error")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "PluginError":
|
||||
from datetime import timezone
|
||||
|
||||
return cls(
|
||||
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class GetPluginInfoRequest(BaseModel):
|
||||
"""Request schema for get_plugin_info."""
|
||||
|
||||
identifier: Annotated[
|
||||
int,
|
||||
Field(description="Plugin ID"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_plugin_object(plugin: Any) -> PluginInfo | None:
|
||||
if not plugin:
|
||||
return None
|
||||
|
||||
return PluginInfo(
|
||||
id=getattr(plugin, "id", None),
|
||||
name=getattr(plugin, "name", None),
|
||||
key=getattr(plugin, "key", None),
|
||||
bundle_url=getattr(plugin, "bundle_url", None),
|
||||
changed_on=getattr(plugin, "changed_on", None),
|
||||
created_on=getattr(plugin, "created_on", None),
|
||||
)
|
||||
24
superset/mcp_service/plugin/tool/__init__.py
Normal file
24
superset/mcp_service/plugin/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_plugin_info import get_plugin_info
|
||||
from .list_plugins import list_plugins
|
||||
|
||||
__all__ = [
|
||||
"list_plugins",
|
||||
"get_plugin_info",
|
||||
]
|
||||
101
superset/mcp_service/plugin/tool/get_plugin_info.py
Normal file
101
superset/mcp_service/plugin/tool/get_plugin_info.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get plugin info FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.plugin.schemas import (
|
||||
GetPluginInfoRequest,
|
||||
PluginError,
|
||||
PluginInfo,
|
||||
serialize_plugin_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="DynamicPlugin",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get plugin info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_plugin_info(
|
||||
request: GetPluginInfoRequest, ctx: Context
|
||||
) -> PluginInfo | PluginError:
|
||||
"""Get dynamic plugin details by ID. Requires admin access.
|
||||
|
||||
Returns full plugin configuration including name, key, and bundle URL.
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{"identifier": 1}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Retrieving plugin information: identifier=%s" % (request.identifier,)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.mcp_service.plugin.dao import DynamicPluginDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_plugin_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=DynamicPluginDAO,
|
||||
output_schema=PluginInfo,
|
||||
error_schema=PluginError,
|
||||
serializer=serialize_plugin_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, PluginInfo):
|
||||
await ctx.info(
|
||||
"Plugin retrieved: id=%s, name=%s, key=%s"
|
||||
% (result.id, result.name, result.key)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"Plugin retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Plugin info retrieval failed: identifier=%s, error=%s"
|
||||
% (request.identifier, str(e))
|
||||
)
|
||||
return PluginError(
|
||||
error=f"Failed to get plugin info: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
123
superset/mcp_service/plugin/tool/list_plugins.py
Normal file
123
superset/mcp_service/plugin/tool/list_plugins.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List plugins FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.plugin.schemas import (
|
||||
ALL_PLUGIN_COLUMNS,
|
||||
DEFAULT_PLUGIN_COLUMNS,
|
||||
ListPluginsRequest,
|
||||
PluginColumnFilter,
|
||||
PluginError,
|
||||
PluginInfo,
|
||||
PluginList,
|
||||
serialize_plugin_object,
|
||||
SORTABLE_PLUGIN_COLUMNS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_PLUGINS_REQUEST = ListPluginsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="DynamicPlugin",
|
||||
annotations=ToolAnnotations(
|
||||
title="List plugins",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_plugins(
|
||||
request: ListPluginsRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> PluginList | PluginError:
|
||||
"""List dynamic plugins registered in this Superset instance. Requires admin access.
|
||||
|
||||
Returns plugin metadata including name, key, and bundle URL.
|
||||
|
||||
Sortable columns for order_column: id, name, key, changed_on, created_on
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_plugins")
|
||||
|
||||
request = request or _DEFAULT_LIST_PLUGINS_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing plugins: page=%s, page_size=%s, search=%s"
|
||||
% (request.page, request.page_size, request.search)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.mcp_service.plugin.dao import DynamicPluginDAO
|
||||
|
||||
def _serialize(obj: object, cols: list[str]) -> PluginInfo | None:
|
||||
return serialize_plugin_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=DynamicPluginDAO,
|
||||
output_schema=PluginInfo,
|
||||
item_serializer=_serialize,
|
||||
filter_type=PluginColumnFilter,
|
||||
default_columns=DEFAULT_PLUGIN_COLUMNS,
|
||||
search_columns=["name", "key"],
|
||||
list_field_name="plugins",
|
||||
output_list_schema=PluginList,
|
||||
all_columns=ALL_PLUGIN_COLUMNS,
|
||||
sortable_columns=SORTABLE_PLUGIN_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_plugins.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=request.select_columns,
|
||||
order_column=request.order_column,
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Plugins listed: count=%s, total_count=%s"
|
||||
% (len(result.plugins), result.total_count)
|
||||
)
|
||||
|
||||
columns_to_filter = result.columns_requested
|
||||
with event_logger.log_context(action="mcp.list_plugins.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Plugin listing failed: error=%s, error_type=%s"
|
||||
% (str(e), type(e).__name__)
|
||||
)
|
||||
raise
|
||||
@@ -44,13 +44,20 @@ USER_DIRECTORY_FIELDS = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
# User-directory columns that may be used as filter values (an integer user ID).
|
||||
# These remain stripped from select_columns, sort, search, and tool responses
|
||||
# (so the directory itself is never exposed), but list tools may filter rows by
|
||||
# them when the caller already has an ID — typically resolved via find_users.
|
||||
USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})
|
||||
|
||||
# Internal DAO filter column names generated server-side when translating the
|
||||
# created_by_me / owned_by_me boolean flags (see mcp_core._prepend_self_lookup_filters).
|
||||
# These columns are never exposed to LLM callers; they are excluded from the
|
||||
# filters_applied response field to avoid leaking internal implementation details.
|
||||
SELF_REFERENCING_FILTER_COLUMNS = frozenset(
|
||||
{"created_by_fk", "owner", "created_by_fk_or_owner"}
|
||||
)
|
||||
# Note: ``created_by_fk`` is intentionally excluded — it is also a publicly
|
||||
# advertised filter column (see USER_FILTER_FIELDS) so callers can filter by a
|
||||
# user ID resolved via find_users.
|
||||
SELF_REFERENCING_FILTER_COLUMNS = frozenset({"owner", "created_by_fk_or_owner"})
|
||||
|
||||
DATA_MODEL_METADATA_ACCESS_ATTR = "_requires_data_model_metadata_access"
|
||||
DATA_MODEL_METADATA_ERROR_TYPE = "DataModelMetadataRestricted"
|
||||
@@ -133,7 +140,7 @@ def user_can_view_data_model_metadata() -> bool:
|
||||
|
||||
|
||||
def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove fields that expose users, roles, owners, or access metadata."""
|
||||
"""Remove fields that expose users, owners, or access metadata."""
|
||||
return {
|
||||
key: value for key, value in data.items() if key not in USER_DIRECTORY_FIELDS
|
||||
}
|
||||
|
||||
16
superset/mcp_service/rls/__init__.py
Normal file
16
superset/mcp_service/rls/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
255
superset/mcp_service/rls/schemas.py
Normal file
255
superset/mcp_service/rls/schemas.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for row level security filter responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_RLS_COLUMNS = ["id", "name", "filter_type", "clause"]
|
||||
|
||||
ALL_RLS_COLUMNS = [
|
||||
"id",
|
||||
"name",
|
||||
"filter_type",
|
||||
"tables",
|
||||
"roles",
|
||||
"clause",
|
||||
"group_key",
|
||||
"changed_on",
|
||||
]
|
||||
|
||||
SORTABLE_RLS_COLUMNS = ["id", "name", "filter_type", "changed_on"]
|
||||
|
||||
|
||||
class RlsColumnFilter(ColumnOperator):
|
||||
"""Filter object for RLS filter listing."""
|
||||
|
||||
col: Literal["name", "filter_type"] = Field(
|
||||
...,
|
||||
description="Column to filter on.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
|
||||
value: str | int | float | bool | List[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by"
|
||||
)
|
||||
|
||||
|
||||
class RlsTableRef(BaseModel):
|
||||
id: int | None = Field(None, description="Table ID")
|
||||
table_name: str | None = Field(None, description="Table name")
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RlsRoleRef(BaseModel):
|
||||
id: int | None = Field(None, description="Role ID")
|
||||
name: str | None = Field(None, description="Role name")
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RlsFilterInfo(BaseModel):
|
||||
id: int | None = Field(None, description="RLS filter ID")
|
||||
name: str | None = Field(None, description="RLS filter name")
|
||||
filter_type: str | None = Field(None, description="Filter type: Regular or Base")
|
||||
tables: List[RlsTableRef] | None = Field(
|
||||
None, description="Tables this filter applies to"
|
||||
)
|
||||
roles: List[RlsRoleRef] | None = Field(
|
||||
None, description="Roles this filter applies to"
|
||||
)
|
||||
clause: str | None = Field(None, description="SQL WHERE clause")
|
||||
group_key: str | None = Field(
|
||||
None, description="Group key for Base filter grouping"
|
||||
)
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
|
||||
data = serializer(self)
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
return data
|
||||
|
||||
|
||||
class RlsFilterList(BaseModel):
|
||||
rls_filters: List[RlsFilterInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: List[str] = Field(default_factory=list)
|
||||
columns_loaded: List[str] = Field(default_factory=list)
|
||||
columns_available: List[str] = Field(default_factory=list)
|
||||
sortable_columns: List[str] = Field(default_factory=list)
|
||||
filters_applied: List[RlsColumnFilter] = Field(default_factory=list)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListRlsFiltersRequest(BaseModel):
|
||||
"""Request schema for list_rls_filters."""
|
||||
|
||||
filters: Annotated[
|
||||
List[RlsColumnFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of filter objects (col, opr, value). "
|
||||
"Cannot be used with search.",
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
List[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="Columns to include in response. Defaults to common columns.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="Text search on filter name. Cannot be used with filters.",
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None, Field(default=None, description="Column to order results by")
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(default="desc", description="Sort direction"),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[RlsColumnFilter]:
|
||||
return parse_json_or_model_list(v, RlsColumnFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> List[str]:
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListRlsFiltersRequest":
|
||||
if self.search and self.filters:
|
||||
raise ValueError("Cannot use both 'search' and 'filters' simultaneously.")
|
||||
return self
|
||||
|
||||
|
||||
class RlsFilterError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Type of error")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "RlsFilterError":
|
||||
from datetime import timezone
|
||||
|
||||
return cls(
|
||||
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class GetRlsFilterInfoRequest(BaseModel):
|
||||
"""Request schema for get_rls_filter_info."""
|
||||
|
||||
identifier: Annotated[
|
||||
int,
|
||||
Field(description="RLS filter ID"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_rls_filter_object(rls_filter: Any) -> RlsFilterInfo | None:
|
||||
if not rls_filter:
|
||||
return None
|
||||
|
||||
tables = [
|
||||
RlsTableRef(
|
||||
id=getattr(t, "id", None),
|
||||
table_name=getattr(t, "table_name", None),
|
||||
)
|
||||
for t in (getattr(rls_filter, "tables", None) or [])
|
||||
]
|
||||
|
||||
roles = [
|
||||
RlsRoleRef(
|
||||
id=getattr(r, "id", None),
|
||||
name=getattr(r, "name", None),
|
||||
)
|
||||
for r in (getattr(rls_filter, "roles", None) or [])
|
||||
]
|
||||
|
||||
return RlsFilterInfo(
|
||||
id=getattr(rls_filter, "id", None),
|
||||
name=getattr(rls_filter, "name", None),
|
||||
filter_type=getattr(rls_filter, "filter_type", None),
|
||||
tables=tables,
|
||||
roles=roles,
|
||||
clause=getattr(rls_filter, "clause", None),
|
||||
group_key=getattr(rls_filter, "group_key", None),
|
||||
changed_on=getattr(rls_filter, "changed_on", None),
|
||||
)
|
||||
24
superset/mcp_service/rls/tool/__init__.py
Normal file
24
superset/mcp_service/rls/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_rls_filter_info import get_rls_filter_info
|
||||
from .list_rls_filters import list_rls_filters
|
||||
|
||||
__all__ = [
|
||||
"list_rls_filters",
|
||||
"get_rls_filter_info",
|
||||
]
|
||||
101
superset/mcp_service/rls/tool/get_rls_filter_info.py
Normal file
101
superset/mcp_service/rls/tool/get_rls_filter_info.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get RLS filter info FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.rls.schemas import (
|
||||
GetRlsFilterInfoRequest,
|
||||
RlsFilterError,
|
||||
RlsFilterInfo,
|
||||
serialize_rls_filter_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="Row Level Security",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get RLS filter info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_rls_filter_info(
|
||||
request: GetRlsFilterInfoRequest, ctx: Context
|
||||
) -> RlsFilterInfo | RlsFilterError:
|
||||
"""Get row level security filter details by ID. Requires admin access.
|
||||
|
||||
Returns full RLS filter configuration including name, type, tables, roles,
|
||||
and clause.
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{"identifier": 1}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Retrieving RLS filter information: identifier=%s" % (request.identifier,)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.security import RLSDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_rls_filter_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=RLSDAO,
|
||||
output_schema=RlsFilterInfo,
|
||||
error_schema=RlsFilterError,
|
||||
serializer=serialize_rls_filter_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, RlsFilterInfo):
|
||||
await ctx.info(
|
||||
"RLS filter retrieved: id=%s, name=%s" % (result.id, result.name)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"RLS filter retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"RLS filter info retrieval failed: identifier=%s, error=%s"
|
||||
% (request.identifier, str(e))
|
||||
)
|
||||
return RlsFilterError(
|
||||
error=f"Failed to get RLS filter info: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
149
superset/mcp_service/rls/tool/list_rls_filters.py
Normal file
149
superset/mcp_service/rls/tool/list_rls_filters.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List RLS filters FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.privacy import USER_DIRECTORY_FIELDS
|
||||
from superset.mcp_service.rls.schemas import (
|
||||
ALL_RLS_COLUMNS,
|
||||
DEFAULT_RLS_COLUMNS,
|
||||
ListRlsFiltersRequest,
|
||||
RlsColumnFilter,
|
||||
RlsFilterError,
|
||||
RlsFilterInfo,
|
||||
RlsFilterList,
|
||||
serialize_rls_filter_object,
|
||||
SORTABLE_RLS_COLUMNS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_RLS_FILTERS_REQUEST = ListRlsFiltersRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="Row Level Security",
|
||||
annotations=ToolAnnotations(
|
||||
title="List RLS filters",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_rls_filters(
|
||||
request: ListRlsFiltersRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> RlsFilterList | RlsFilterError:
|
||||
"""List row level security filters. Requires admin access.
|
||||
|
||||
Returns RLS filter metadata including name, filter type, tables, roles, and clause.
|
||||
|
||||
Sortable columns for order_column: id, name, filter_type, changed_on
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_rls_filters")
|
||||
|
||||
request = request or _DEFAULT_LIST_RLS_FILTERS_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing RLS filters: page=%s, page_size=%s, search=%s"
|
||||
% (request.page, request.page_size, request.search)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.security import RLSDAO
|
||||
|
||||
def _serialize_rls_filter(obj: object, cols: list[str]) -> RlsFilterInfo | None:
|
||||
return serialize_rls_filter_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=RLSDAO,
|
||||
output_schema=RlsFilterInfo,
|
||||
item_serializer=_serialize_rls_filter,
|
||||
filter_type=RlsColumnFilter,
|
||||
default_columns=DEFAULT_RLS_COLUMNS,
|
||||
search_columns=["name"],
|
||||
list_field_name="rls_filters",
|
||||
output_list_schema=RlsFilterList,
|
||||
all_columns=ALL_RLS_COLUMNS,
|
||||
sortable_columns=SORTABLE_RLS_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# RLS 'roles' is valid filter data but lives in USER_DIRECTORY_FIELDS,
|
||||
# so ModelListCore would raise ValueError for a column list that reduces
|
||||
# to empty after privacy filtering (e.g. select_columns=["roles"]).
|
||||
# Strip directory-field columns here; the bypass below restores them in
|
||||
# the final serialized output from ALL_RLS_COLUMNS.
|
||||
run_tool_columns = None
|
||||
if request.select_columns:
|
||||
non_directory = [
|
||||
c for c in request.select_columns if c not in USER_DIRECTORY_FIELDS
|
||||
]
|
||||
run_tool_columns = non_directory if non_directory else None
|
||||
|
||||
with event_logger.log_context(action="mcp.list_rls_filters.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=run_tool_columns,
|
||||
order_column=request.order_column,
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"RLS filters listed: count=%s, total_count=%s"
|
||||
% (len(result.rls_filters), result.total_count)
|
||||
)
|
||||
|
||||
# Build column selection using ALL_RLS_COLUMNS as the source of truth,
|
||||
# bypassing the USER_DIRECTORY_FIELDS privacy filter applied by
|
||||
# ModelListCore. 'roles' in an RLS filter is which roles the filter
|
||||
# applies to — core filter data — not user-directory metadata (like
|
||||
# dashboard.roles, which exposes who has access to the resource).
|
||||
if request.select_columns:
|
||||
columns_to_filter = [
|
||||
c for c in request.select_columns if c in ALL_RLS_COLUMNS
|
||||
]
|
||||
if not columns_to_filter:
|
||||
columns_to_filter = list(DEFAULT_RLS_COLUMNS)
|
||||
else:
|
||||
columns_to_filter = list(DEFAULT_RLS_COLUMNS)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_rls_filters.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"RLS filter listing failed: error=%s, error_type=%s"
|
||||
% (str(e), type(e).__name__)
|
||||
)
|
||||
raise
|
||||
@@ -25,9 +25,11 @@ system-level info.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Annotated, Any, List
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
@@ -170,6 +172,84 @@ def serialize_user_object(user: Any) -> UserInfo | None:
|
||||
)
|
||||
|
||||
|
||||
class FindUsersRequest(BaseModel):
|
||||
"""Request schema for find_users tool.
|
||||
|
||||
Resolves a person's name (or partial name, username, or email) to user IDs
|
||||
so they can be passed to listing tools as filter values for created_by_fk
|
||||
or changed_by_fk. This is the only sanctioned path for "show me what
|
||||
<person> is working on" queries.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
max_length=200,
|
||||
description=(
|
||||
"Substring to match (case-insensitive) against username, "
|
||||
"first_name, last_name, and email. Required and non-empty: "
|
||||
"this tool does not enumerate the full user directory."
|
||||
),
|
||||
),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Maximum number of matches to return (max {MAX_PAGE_SIZE}).",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("query")
|
||||
@classmethod
|
||||
def _reject_blank_query(cls, value: str) -> str:
|
||||
# min_length=1 alone admits whitespace-only strings, which strip to "" and
|
||||
# produce a "%%" LIKE pattern that matches every user. Strip and require
|
||||
# at least one non-space character.
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
raise ValueError("query must contain at least one non-whitespace character")
|
||||
return stripped
|
||||
|
||||
|
||||
class UserMatch(BaseModel):
|
||||
"""Minimal user projection returned by find_users.
|
||||
|
||||
Intentionally narrower than UserInfo: only the fields needed to disambiguate
|
||||
matches and pass an id to created_by_fk / changed_by_fk filters. Email,
|
||||
active flag, and roles are deliberately excluded to limit identity
|
||||
exposure through this directory-resolution path.
|
||||
"""
|
||||
|
||||
id: int | None = None
|
||||
username: str | None = None
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
|
||||
|
||||
class FindUsersResponse(BaseModel):
|
||||
"""Response schema for find_users tool."""
|
||||
|
||||
users: List[UserMatch] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Matching users. Pass user.id as the value for created_by_fk or "
|
||||
"changed_by_fk filters on list_dashboards, list_charts, and "
|
||||
"list_datasets."
|
||||
),
|
||||
)
|
||||
count: int = Field(..., description="Number of users returned in this response.")
|
||||
truncated: bool = Field(
|
||||
default=False,
|
||||
description="True when the query matched more rows than page_size allows.",
|
||||
)
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
|
||||
@@ -17,12 +17,14 @@
|
||||
|
||||
"""System tools for MCP service."""
|
||||
|
||||
from .find_users import find_users
|
||||
from .generate_bug_report import generate_bug_report
|
||||
from .get_instance_info import get_instance_info
|
||||
from .get_schema import get_schema
|
||||
from .health_check import health_check
|
||||
|
||||
__all__ = [
|
||||
"find_users",
|
||||
"generate_bug_report",
|
||||
"health_check",
|
||||
"get_instance_info",
|
||||
|
||||
101
superset/mcp_service/system/tool/find_users.py
Normal file
101
superset/mcp_service/system/tool/find_users.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""find_users MCP tool: resolve a person's name to user IDs for filtering."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy import or_
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import db, event_logger, security_manager
|
||||
from superset.mcp_service.system.schemas import (
|
||||
FindUsersRequest,
|
||||
FindUsersResponse,
|
||||
UserMatch,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
annotations=ToolAnnotations(
|
||||
title="Find users",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def find_users(request: FindUsersRequest, ctx: Context) -> FindUsersResponse:
|
||||
"""Resolve a person's name to user IDs so they can be used as filter values.
|
||||
|
||||
Use this when the caller asks "show me <person>'s dashboards/charts/datasets"
|
||||
or "what is <person> working on". Take the matching user.id and pass it as
|
||||
the value for a created_by_fk or changed_by_fk filter on list_dashboards,
|
||||
list_charts, or list_datasets.
|
||||
|
||||
Matches case-insensitively against username, first_name, last_name, and
|
||||
email. The query is required and non-empty; this tool does not enumerate
|
||||
the full user directory.
|
||||
|
||||
Privacy: returning a user's identity here is sanctioned only for resolving
|
||||
filter values. Do not use the response to answer "who owns X", "who can
|
||||
access X", or any access-list question — those remain off-limits per the
|
||||
server instructions.
|
||||
"""
|
||||
await ctx.info(
|
||||
"Resolving user query: query=%s, page_size=%s"
|
||||
% (request.query, request.page_size)
|
||||
)
|
||||
|
||||
user_model = security_manager.user_model
|
||||
needle = f"%{request.query.strip()}%"
|
||||
|
||||
with event_logger.log_context(action="mcp.find_users.query"):
|
||||
query = (
|
||||
db.session.query(user_model)
|
||||
.filter(
|
||||
or_(
|
||||
user_model.username.ilike(needle),
|
||||
user_model.first_name.ilike(needle),
|
||||
user_model.last_name.ilike(needle),
|
||||
user_model.email.ilike(needle),
|
||||
)
|
||||
)
|
||||
.order_by(user_model.username.asc())
|
||||
)
|
||||
# Fetch one extra row to detect truncation without a separate count query.
|
||||
rows = query.limit(request.page_size + 1).all()
|
||||
|
||||
truncated = len(rows) > request.page_size
|
||||
rows = rows[: request.page_size]
|
||||
|
||||
users: list[UserMatch] = [
|
||||
UserMatch(
|
||||
id=getattr(row, "id", None),
|
||||
username=getattr(row, "username", None),
|
||||
first_name=getattr(row, "first_name", None),
|
||||
last_name=getattr(row, "last_name", None),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
await ctx.info(
|
||||
"Resolved user query: matches=%s, truncated=%s" % (len(users), truncated)
|
||||
)
|
||||
return FindUsersResponse(users=users, count=len(users), truncated=truncated)
|
||||
@@ -24,9 +24,10 @@ single dataframe.
|
||||
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from time import time
|
||||
from datetime import date, datetime, time, timedelta, tzinfo
|
||||
from time import time as current_time
|
||||
from typing import Any, cast, Sequence, TypeGuard
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import isodate
|
||||
import numpy as np
|
||||
@@ -103,7 +104,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
|
||||
raise ValueError("QueryObject must have a datasource defined.")
|
||||
|
||||
# Track execution time
|
||||
start_time = time()
|
||||
start_time = current_time()
|
||||
|
||||
semantic_view = query_object.datasource.implementation
|
||||
dispatcher = (
|
||||
@@ -127,7 +128,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
|
||||
|
||||
# If no time offsets, return the main result as-is
|
||||
if not query_object.time_offsets or len(queries) <= 1:
|
||||
duration = timedelta(seconds=time() - start_time)
|
||||
duration = timedelta(seconds=current_time() - start_time)
|
||||
return map_semantic_result_to_query_result(
|
||||
main_result,
|
||||
query_object,
|
||||
@@ -197,7 +198,7 @@ def get_results(query_object: QueryObject) -> QueryResult:
|
||||
requests=all_requests,
|
||||
results=pa.Table.from_pandas(main_df),
|
||||
)
|
||||
duration = timedelta(seconds=time() - start_time)
|
||||
duration = timedelta(seconds=current_time() - start_time)
|
||||
return map_semantic_result_to_query_result(
|
||||
semantic_result,
|
||||
query_object,
|
||||
@@ -541,21 +542,29 @@ def _convert_query_object_filter(
|
||||
if operator_str == FilterOperator.TEMPORAL_RANGE.value:
|
||||
if not isinstance(value, str) or value == NO_TIME_RANGE:
|
||||
return None
|
||||
start, end = value.split(" : ")
|
||||
return {
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=dimension,
|
||||
operator=Operator.GREATER_THAN_OR_EQUAL,
|
||||
value=start,
|
||||
),
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=dimension,
|
||||
operator=Operator.LESS_THAN,
|
||||
value=end,
|
||||
),
|
||||
}
|
||||
start, end = (side.strip() for side in value.split(" : "))
|
||||
filters: set[Filter] = set()
|
||||
if start:
|
||||
filters.add(
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=dimension,
|
||||
operator=Operator.GREATER_THAN_OR_EQUAL,
|
||||
value=_coerce_scalar_filter_value(start, dimension),
|
||||
)
|
||||
)
|
||||
if end:
|
||||
filters.add(
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=dimension,
|
||||
operator=Operator.LESS_THAN,
|
||||
value=_coerce_scalar_filter_value(end, dimension),
|
||||
)
|
||||
)
|
||||
return filters or None
|
||||
|
||||
value = _coerce_filter_value(value, dimension)
|
||||
|
||||
# Map QueryObject operators to semantic layer operators
|
||||
operator_mapping = {
|
||||
@@ -588,6 +597,149 @@ def _convert_query_object_filter(
|
||||
}
|
||||
|
||||
|
||||
def _coerce_filter_value(
|
||||
value: FilterValues | frozenset[FilterValues],
|
||||
dimension: Dimension,
|
||||
) -> FilterValues | frozenset[FilterValues]:
|
||||
if isinstance(value, frozenset):
|
||||
return frozenset(_coerce_scalar_filter_value(v, dimension) for v in value)
|
||||
return _coerce_scalar_filter_value(value, dimension)
|
||||
|
||||
|
||||
def _timestamp_target_tz(dtype: pa.DataType) -> tzinfo | None:
|
||||
tz_name = getattr(dtype, "tz", None)
|
||||
return ZoneInfo(tz_name) if tz_name else None
|
||||
|
||||
|
||||
def _align_tz(dt: datetime, target_tz: tzinfo | None) -> datetime:
|
||||
if target_tz is None:
|
||||
return dt
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=target_tz)
|
||||
return dt.astimezone(target_tz)
|
||||
|
||||
|
||||
def _coerce_scalar_filter_value( # noqa: C901 — type dispatch, complexity is inherent
|
||||
value: FilterValues, dimension: Dimension
|
||||
) -> FilterValues:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
dtype = dimension.type
|
||||
|
||||
if pa.types.is_boolean(dtype):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)) and value in (0, 1):
|
||||
return bool(value)
|
||||
if isinstance(value, str):
|
||||
parsed = value.strip().lower()
|
||||
if parsed in {"true", "t", "1", "yes", "y", "on"}:
|
||||
return True
|
||||
if parsed in {"false", "f", "0", "no", "n", "off"}:
|
||||
return False
|
||||
raise ValueError(
|
||||
f"Invalid boolean value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
|
||||
if pa.types.is_integer(dtype):
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"Invalid integer value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, float) and value.is_integer():
|
||||
return int(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value.strip())
|
||||
except ValueError as ex:
|
||||
raise ValueError(
|
||||
f"Invalid integer value {value!r} for filter column "
|
||||
f"{dimension.name}"
|
||||
) from ex
|
||||
raise ValueError(
|
||||
f"Invalid integer value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
|
||||
if pa.types.is_floating(dtype) or pa.types.is_decimal(dtype):
|
||||
# Decimal dimensions are coerced through ``float`` because ``FilterValues``
|
||||
# does not include ``Decimal``. That is lossless for the common case
|
||||
# (≤ ~15 significant digits) and matches how downstream semantic-view
|
||||
# implementations consume numeric filters; high-precision decimals would
|
||||
# need a wider ``FilterValues`` union and propagation through the cache's
|
||||
# comparability checks.
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"Invalid numeric value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value.strip())
|
||||
except ValueError as ex:
|
||||
raise ValueError(
|
||||
f"Invalid numeric value {value!r} for filter column "
|
||||
f"{dimension.name}"
|
||||
) from ex
|
||||
raise ValueError(
|
||||
f"Invalid numeric value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
|
||||
if pa.types.is_date(dtype):
|
||||
if isinstance(value, datetime):
|
||||
return value.date()
|
||||
if isinstance(value, date):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return datetime.fromisoformat(value.strip()).date()
|
||||
except ValueError as ex:
|
||||
raise ValueError(
|
||||
f"Invalid date value {value!r} for filter column {dimension.name}"
|
||||
) from ex
|
||||
raise ValueError(
|
||||
f"Invalid date value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
|
||||
if pa.types.is_timestamp(dtype):
|
||||
target_tz = _timestamp_target_tz(dtype)
|
||||
if isinstance(value, datetime):
|
||||
return _align_tz(value, target_tz)
|
||||
if isinstance(value, date):
|
||||
return _align_tz(datetime.combine(value, time.min), target_tz)
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
try:
|
||||
return _align_tz(datetime.fromisoformat(normalized), target_tz)
|
||||
except ValueError as ex:
|
||||
raise ValueError(
|
||||
f"Invalid timestamp value {value!r} for filter column "
|
||||
f"{dimension.name}"
|
||||
) from ex
|
||||
raise ValueError(
|
||||
f"Invalid timestamp value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
|
||||
if pa.types.is_time(dtype):
|
||||
if isinstance(value, time):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return time.fromisoformat(value.strip())
|
||||
except ValueError as ex:
|
||||
raise ValueError(
|
||||
f"Invalid time value {value!r} for filter column {dimension.name}"
|
||||
) from ex
|
||||
raise ValueError(
|
||||
f"Invalid time value {value!r} for filter column {dimension.name}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _get_order_from_query_object(
|
||||
query_object: ValidatedQueryObject,
|
||||
all_metrics: dict[str, Metric],
|
||||
|
||||
@@ -324,7 +324,10 @@ class WebDriverPlaywright(WebDriverProxy):
|
||||
'document.querySelectorAll(".chart-container").length'
|
||||
)
|
||||
dashboard_height = page.evaluate(
|
||||
f'document.querySelector(".{element_name}").scrollHeight || 0'
|
||||
f"""() => {{
|
||||
const target = document.querySelector(\".{element_name}\");
|
||||
return target ? target.scrollHeight : 0;
|
||||
}}"""
|
||||
)
|
||||
chart_threshold = app.config.get(
|
||||
"SCREENSHOT_TILED_CHART_THRESHOLD", 20
|
||||
@@ -336,6 +339,14 @@ class WebDriverPlaywright(WebDriverProxy):
|
||||
"SCREENSHOT_TILED_VIEWPORT_HEIGHT", viewport_height
|
||||
)
|
||||
|
||||
if dashboard_height == 0:
|
||||
logger.warning(
|
||||
"Could not determine dashboard height for element %s "
|
||||
"at url %s; falling back to standard screenshot behavior",
|
||||
element_name,
|
||||
url,
|
||||
)
|
||||
|
||||
# Use tiled screenshots for large dashboards
|
||||
use_tiled = (
|
||||
chart_count >= chart_threshold
|
||||
|
||||
@@ -29,6 +29,7 @@ from flask_appbuilder.api import (
|
||||
rison as parse_rison,
|
||||
safe,
|
||||
)
|
||||
from flask_appbuilder.const import API_FILTERS_RIS_KEY
|
||||
from flask_appbuilder.models.filters import BaseFilter, Filters
|
||||
from flask_appbuilder.models.sqla.filters import FilterStartsWith
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
@@ -377,6 +378,35 @@ class BaseSupersetModelRestApi(BaseSupersetApiMixin, ModelRestApi):
|
||||
self.add_columns = [model_id]
|
||||
super()._init_properties()
|
||||
|
||||
def _handle_filters_args(self, rison_args: dict[str, Any]) -> Filters:
|
||||
"""
|
||||
Build a request-scoped ``Filters`` instance from Rison-encoded args.
|
||||
|
||||
Overrides :meth:`flask_appbuilder.api.ModelRestApi._handle_filters_args`,
|
||||
which mutates ``self._filters`` (a single instance shared across
|
||||
requests on the same API view). Under concurrent traffic that shared
|
||||
state can leak filters from one request into another — e.g. two
|
||||
parallel ``GET /api/v1/<resource>/`` calls filtering by different
|
||||
values can return mixed results.
|
||||
|
||||
Returning a fresh ``Filters`` per call keeps each request isolated.
|
||||
Applies to every subclass of ``BaseSupersetModelRestApi``
|
||||
(datasets, charts, dashboards, saved queries, queries, databases,
|
||||
etc.) — see issue #33828 for the original report on the dataset
|
||||
endpoint.
|
||||
|
||||
:param rison_args: Arguments parsed from the API request's
|
||||
Rison-encoded ``q`` parameter.
|
||||
:returns: A request-scoped ``Filters`` instance joined with the
|
||||
API's base filters.
|
||||
"""
|
||||
filters = self.datamodel.get_filters(
|
||||
search_columns=self.search_columns,
|
||||
search_filters=self.search_filters,
|
||||
)
|
||||
filters.rest_add_filters(rison_args.get(API_FILTERS_RIS_KEY, []))
|
||||
return filters.get_joined_filters(self._base_filters)
|
||||
|
||||
def _get_related_filter(
|
||||
self, datamodel: SQLAInterface, column_name: str, value: str
|
||||
) -> Filters:
|
||||
|
||||
@@ -120,3 +120,46 @@ def test_get_dataset_include_rendered_sql_passes_table_to_template_processor(
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_get_processor.assert_called_once_with(database=database, table=dataset)
|
||||
|
||||
|
||||
def test_handle_filters_args_returns_request_scoped_filters(
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
``_handle_filters_args`` must return a fresh ``Filters`` instance per
|
||||
call so concurrent requests don't share filter state.
|
||||
|
||||
Regression test for #33828: under concurrent traffic the FAB default
|
||||
implementation mutates ``self._filters`` (a single shared instance),
|
||||
causing filters from one request to leak into another.
|
||||
|
||||
The fix lives on ``BaseSupersetModelRestApi`` so every superset REST
|
||||
API subclass (datasets, charts, dashboards, saved queries, etc.)
|
||||
inherits the request-scoped behavior. This test exercises it via
|
||||
``DatasetRestApi`` as a concrete subclass.
|
||||
"""
|
||||
from flask_appbuilder.const import API_FILTERS_RIS_KEY
|
||||
|
||||
from superset.datasets.api import DatasetRestApi
|
||||
|
||||
api = DatasetRestApi()
|
||||
api.datamodel = MagicMock()
|
||||
api.search_columns = ["table_name"]
|
||||
api.search_filters = {}
|
||||
api._base_filters = MagicMock() # noqa: SLF001
|
||||
|
||||
# Each call should construct a fresh Filters instance via datamodel.get_filters
|
||||
rison_args = {
|
||||
API_FILTERS_RIS_KEY: [{"col": "table_name", "opr": "eq", "value": "a"}],
|
||||
}
|
||||
api._handle_filters_args(rison_args) # noqa: SLF001
|
||||
api._handle_filters_args(rison_args) # noqa: SLF001
|
||||
|
||||
assert api.datamodel.get_filters.call_count == 2
|
||||
# Returned object must be the joined-filters result of the *fresh* Filters,
|
||||
# not the shared self._filters attribute.
|
||||
fresh_filters = api.datamodel.get_filters.return_value
|
||||
assert fresh_filters.rest_add_filters.call_count == 2
|
||||
assert fresh_filters.get_joined_filters.call_count == 2
|
||||
|
||||
@@ -594,7 +594,34 @@ class TestMapXYConfig:
|
||||
|
||||
assert result["viz_type"] == "echarts_timeseries_scatter"
|
||||
assert result["show_legend"] is False
|
||||
assert result["legend_orientation"] == "top"
|
||||
assert result["legendOrientation"] == "top"
|
||||
|
||||
def test_map_xy_config_with_color_scheme(self) -> None:
|
||||
"""color_scheme propagates to form_data when set."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue")],
|
||||
kind="line",
|
||||
color_scheme="lyftColors",
|
||||
)
|
||||
|
||||
result = map_xy_config(config)
|
||||
|
||||
assert result["color_scheme"] == "lyftColors"
|
||||
|
||||
def test_map_xy_config_without_color_scheme(self) -> None:
|
||||
"""color_scheme key omitted when not set, leaving Superset default."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue")],
|
||||
kind="line",
|
||||
)
|
||||
|
||||
result = map_xy_config(config)
|
||||
|
||||
assert "color_scheme" not in result
|
||||
|
||||
def test_map_xy_config_with_time_grain_month(self) -> None:
|
||||
"""Test XY config mapping with monthly time grain"""
|
||||
|
||||
@@ -29,18 +29,23 @@ from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
generate_chart_name,
|
||||
map_big_number_config,
|
||||
map_config_to_form_data,
|
||||
map_mixed_timeseries_config,
|
||||
map_pie_config,
|
||||
map_pivot_table_config,
|
||||
map_table_config,
|
||||
)
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
BigNumberChartConfig,
|
||||
ColumnRef,
|
||||
CurrencyFormat,
|
||||
FilterConfig,
|
||||
MixedTimeseriesChartConfig,
|
||||
PieChartConfig,
|
||||
PivotTableChartConfig,
|
||||
TableChartConfig,
|
||||
)
|
||||
from superset.mcp_service.chart.validation.schema_validator import SchemaValidator
|
||||
|
||||
@@ -212,6 +217,18 @@ class TestMapPieConfig:
|
||||
assert result["adhoc_filters"][0]["operator"] == "=="
|
||||
assert result["adhoc_filters"][0]["comparator"] == "US"
|
||||
|
||||
def test_pie_form_data_color_scheme_override(self) -> None:
|
||||
"""Explicit color_scheme overrides the supersetColors default."""
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
color_scheme="googleCategory10c",
|
||||
)
|
||||
result = map_pie_config(config)
|
||||
|
||||
assert result["color_scheme"] == "googleCategory10c"
|
||||
|
||||
def test_pie_form_data_custom_options(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
@@ -975,3 +992,272 @@ class TestSchemaValidatorNewTypes:
|
||||
assert is_valid is False
|
||||
assert error is not None
|
||||
assert error.error_code == "INVALID_CHART_TYPE"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Chart Formatting Options Tests (sc-102806 follow-up)
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestPieFormattingOptions:
|
||||
"""number/date/currency format, color scheme, legend orientation on Pie."""
|
||||
|
||||
def test_currency_format_in_form_data(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
currency_format=CurrencyFormat(symbol="USD", symbol_position="prefix"),
|
||||
)
|
||||
result = map_pie_config(config)
|
||||
|
||||
assert result["currency_format"] == {
|
||||
"symbol": "USD",
|
||||
"symbolPosition": "prefix",
|
||||
}
|
||||
|
||||
def test_currency_format_omitted_when_unset(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
)
|
||||
result = map_pie_config(config)
|
||||
|
||||
assert "currency_format" not in result
|
||||
|
||||
def test_legend_orientation_in_form_data(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
legend_orientation="bottom",
|
||||
)
|
||||
result = map_pie_config(config)
|
||||
|
||||
assert result["legendOrientation"] == "bottom"
|
||||
|
||||
def test_default_legend_orientation_is_top(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
)
|
||||
result = map_pie_config(config)
|
||||
|
||||
assert result["legendOrientation"] == "top"
|
||||
|
||||
def test_date_format_overridable(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="ds"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
date_format="%Y-%m-%d",
|
||||
)
|
||||
result = map_pie_config(config)
|
||||
|
||||
assert result["date_format"] == "%Y-%m-%d"
|
||||
|
||||
|
||||
class TestPivotTableFormattingOptions:
|
||||
"""date/currency format on PivotTable."""
|
||||
|
||||
def test_currency_format_in_form_data(self) -> None:
|
||||
config = PivotTableChartConfig(
|
||||
chart_type="pivot_table",
|
||||
rows=[ColumnRef(name="region")],
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
currency_format=CurrencyFormat(symbol="EUR", symbol_position="suffix"),
|
||||
)
|
||||
result = map_pivot_table_config(config)
|
||||
|
||||
assert result["currency_format"] == {
|
||||
"symbol": "EUR",
|
||||
"symbolPosition": "suffix",
|
||||
}
|
||||
|
||||
def test_date_format_in_form_data(self) -> None:
|
||||
config = PivotTableChartConfig(
|
||||
chart_type="pivot_table",
|
||||
rows=[ColumnRef(name="ds")],
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
date_format="%Y-%m",
|
||||
)
|
||||
result = map_pivot_table_config(config)
|
||||
|
||||
assert result["date_format"] == "%Y-%m"
|
||||
|
||||
def test_formatting_omitted_when_unset(self) -> None:
|
||||
config = PivotTableChartConfig(
|
||||
chart_type="pivot_table",
|
||||
rows=[ColumnRef(name="region")],
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
)
|
||||
result = map_pivot_table_config(config)
|
||||
|
||||
assert "currency_format" not in result
|
||||
assert "date_format" not in result
|
||||
|
||||
|
||||
class TestMixedTimeseriesFormattingOptions:
|
||||
"""color scheme, currency format, legend orientation, data labels on Mixed."""
|
||||
|
||||
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
|
||||
def test_color_scheme_in_form_data(self, mock_is_temporal) -> None:
|
||||
mock_is_temporal.return_value = True
|
||||
config = MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="ds"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
color_scheme="lyftColors",
|
||||
)
|
||||
result = map_mixed_timeseries_config(config)
|
||||
|
||||
assert result["color_scheme"] == "lyftColors"
|
||||
|
||||
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
|
||||
def test_currency_format_primary_and_secondary(self, mock_is_temporal) -> None:
|
||||
mock_is_temporal.return_value = True
|
||||
config = MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="ds"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
currency_format=CurrencyFormat(symbol="USD"),
|
||||
currency_format_secondary=CurrencyFormat(symbol="GBP"),
|
||||
)
|
||||
result = map_mixed_timeseries_config(config)
|
||||
|
||||
assert result["currency_format"] == {
|
||||
"symbol": "USD",
|
||||
"symbolPosition": "prefix",
|
||||
}
|
||||
assert result["currency_format_secondary"] == {
|
||||
"symbol": "GBP",
|
||||
"symbolPosition": "prefix",
|
||||
}
|
||||
|
||||
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
|
||||
def test_legend_orientation_in_form_data(self, mock_is_temporal) -> None:
|
||||
mock_is_temporal.return_value = True
|
||||
config = MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="ds"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
legend_orientation="left",
|
||||
)
|
||||
result = map_mixed_timeseries_config(config)
|
||||
|
||||
assert result["legendOrientation"] == "left"
|
||||
|
||||
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
|
||||
def test_show_value_data_labels(self, mock_is_temporal) -> None:
|
||||
mock_is_temporal.return_value = True
|
||||
config = MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="ds"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
show_value=True,
|
||||
)
|
||||
result = map_mixed_timeseries_config(config)
|
||||
|
||||
assert result["show_value"] is True
|
||||
|
||||
|
||||
class TestBigNumberFormattingOptions:
|
||||
"""color scheme, currency format, time format on BigNumber."""
|
||||
|
||||
def test_currency_format_in_form_data(self) -> None:
|
||||
config = BigNumberChartConfig(
|
||||
chart_type="big_number",
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
currency_format=CurrencyFormat(symbol="JPY", symbol_position="prefix"),
|
||||
)
|
||||
result = map_big_number_config(config)
|
||||
|
||||
assert result["currency_format"] == {
|
||||
"symbol": "JPY",
|
||||
"symbolPosition": "prefix",
|
||||
}
|
||||
|
||||
def test_color_scheme_in_form_data(self) -> None:
|
||||
config = BigNumberChartConfig(
|
||||
chart_type="big_number",
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
color_scheme="d3Category10",
|
||||
)
|
||||
result = map_big_number_config(config)
|
||||
|
||||
assert result["color_scheme"] == "d3Category10"
|
||||
|
||||
def test_time_format_only_for_trendline(self) -> None:
|
||||
# Without trendline, time_format is dropped because the trendline
|
||||
# x-axis doesn't render.
|
||||
config = BigNumberChartConfig(
|
||||
chart_type="big_number",
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
time_format="%Y-%m-%d",
|
||||
)
|
||||
result = map_big_number_config(config)
|
||||
|
||||
assert "time_format" not in result
|
||||
|
||||
def test_time_format_with_trendline(self) -> None:
|
||||
config = BigNumberChartConfig(
|
||||
chart_type="big_number",
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
temporal_column="ds",
|
||||
show_trendline=True,
|
||||
time_format="%Y-%m-%d",
|
||||
)
|
||||
result = map_big_number_config(config)
|
||||
|
||||
assert result["time_format"] == "%Y-%m-%d"
|
||||
|
||||
|
||||
class TestTableFormattingOptions:
|
||||
"""color scheme on Table."""
|
||||
|
||||
def test_color_scheme_in_form_data(self) -> None:
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product"), ColumnRef(name="revenue")],
|
||||
color_scheme="lyftColors",
|
||||
)
|
||||
result = map_table_config(config)
|
||||
|
||||
assert result["color_scheme"] == "lyftColors"
|
||||
|
||||
def test_color_scheme_omitted_when_unset(self) -> None:
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product"), ColumnRef(name="revenue")],
|
||||
)
|
||||
result = map_table_config(config)
|
||||
|
||||
assert "color_scheme" not in result
|
||||
|
||||
|
||||
class TestCurrencyFormatModel:
|
||||
"""CurrencyFormat schema validation."""
|
||||
|
||||
def test_default_symbol_position_is_prefix(self) -> None:
|
||||
cf = CurrencyFormat(symbol="USD")
|
||||
assert cf.symbol_position == "prefix"
|
||||
|
||||
def test_camel_case_alias_accepted(self) -> None:
|
||||
cf = CurrencyFormat.model_validate(
|
||||
{"symbol": "USD", "symbolPosition": "suffix"}
|
||||
)
|
||||
assert cf.symbol_position == "suffix"
|
||||
|
||||
def test_invalid_position_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CurrencyFormat(symbol="USD", symbol_position="middle")
|
||||
|
||||
def test_to_form_data_shape(self) -> None:
|
||||
cf = CurrencyFormat(symbol="EUR", symbol_position="suffix")
|
||||
assert cf.to_form_data() == {"symbol": "EUR", "symbolPosition": "suffix"}
|
||||
|
||||
@@ -33,11 +33,15 @@ from superset.mcp_service.chart.schemas import (
|
||||
PerformanceMetadata,
|
||||
)
|
||||
from superset.mcp_service.chart.tool.get_chart_data import (
|
||||
_GENERIC_TYPE_MAP,
|
||||
_MAX_RECOMMENDATIONS,
|
||||
_query_from_form_data,
|
||||
_recommend_visualizations,
|
||||
_sanitize_chart_data_for_llm_context,
|
||||
)
|
||||
from superset.mcp_service.utils import sanitize_for_llm_context
|
||||
from superset.mcp_service.utils.sanitization import LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
|
||||
def _collect_groupby_extras(
|
||||
@@ -1167,3 +1171,284 @@ class TestChartDataCommandValidation:
|
||||
)
|
||||
|
||||
mock_command.run.assert_not_called()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
from superset.mcp_service.app import mcp
|
||||
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
"""Mock MCP auth so Client.call_tool() doesn't need a real admin user."""
|
||||
import importlib
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
_gcd_module = importlib.import_module(
|
||||
"superset.mcp_service.chart.tool.get_chart_data"
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _noop_log_context(*_args: Any, **_kwargs: Any) -> Any:
|
||||
yield lambda **_kw: None
|
||||
|
||||
# Neutralize event_logger.log_context: the default DBEventLogger would
|
||||
# otherwise insert a log row referencing our mock user_id and fail a
|
||||
# FK constraint against the real users table. Patch via the module
|
||||
# object directly — the `tool` package's __init__.py re-exports the
|
||||
# get_chart_data function under the same name, which shadows the
|
||||
# submodule binding in the package namespace, so a dotted-string patch
|
||||
# target resolves to the function and mock.patch cannot find
|
||||
# event_logger on it.
|
||||
mock_event_logger = Mock()
|
||||
mock_event_logger.log_context.side_effect = _noop_log_context
|
||||
with (
|
||||
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
|
||||
patch.object(_gcd_module, "event_logger", mock_event_logger),
|
||||
):
|
||||
user = Mock()
|
||||
user.id = 1
|
||||
user.username = "admin"
|
||||
mock_get_user.return_value = user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _extract_metrics_load_path(load_opt: Any) -> list[str]:
|
||||
"""Walk a SQLAlchemy Load option and return the attr chain.
|
||||
|
||||
e.g. subqueryload(Slice.table).subqueryload(SqlaTable.metrics)
|
||||
-> ["table", "metrics"]
|
||||
"""
|
||||
path = getattr(load_opt, "path", ())
|
||||
return [elem.key for elem in path if hasattr(elem, "key")]
|
||||
|
||||
|
||||
class TestChartLookupEagerLoading:
|
||||
"""Tests that get_chart_data eager-loads dataset.metrics on chart lookup.
|
||||
|
||||
Regression tests for the Excel export DetachedInstanceError on
|
||||
dataset.metrics. The chart's dataset metrics relationship must be
|
||||
eager-loaded at fetch time so it remains accessible during Excel export
|
||||
after the request-scoped session is detached.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_id_lookup_passes_metrics_eager_load(
|
||||
self, mcp_server, mock_auth
|
||||
):
|
||||
"""Integer identifier lookup must eager-load Slice.table.metrics."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastmcp import Client
|
||||
|
||||
with patch(
|
||||
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
|
||||
) as mock_find:
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"get_chart_data",
|
||||
{"request": {"identifier": 42, "format": "excel"}},
|
||||
)
|
||||
|
||||
mock_find.assert_called_once()
|
||||
call = mock_find.call_args
|
||||
assert call.args == (42,)
|
||||
query_options = call.kwargs.get("query_options")
|
||||
assert query_options is not None, (
|
||||
"Chart lookup must pass query_options for eager-loading."
|
||||
)
|
||||
assert len(query_options) == 1
|
||||
load_path = _extract_metrics_load_path(query_options[0])
|
||||
assert load_path == ["table", "metrics"], (
|
||||
f"Expected subqueryload chain 'table' -> 'metrics', got {load_path}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uuid_lookup_passes_metrics_eager_load(self, mcp_server, mock_auth):
|
||||
"""UUID identifier lookup must also eager-load Slice.table.metrics."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastmcp import Client
|
||||
|
||||
uuid = "a1b2c3d4-5678-90ab-cdef-1234567890ab"
|
||||
|
||||
with patch(
|
||||
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
|
||||
) as mock_find:
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"get_chart_data",
|
||||
{"request": {"identifier": uuid, "format": "excel"}},
|
||||
)
|
||||
|
||||
mock_find.assert_called_once()
|
||||
call = mock_find.call_args
|
||||
assert call.args == (uuid,)
|
||||
assert call.kwargs.get("id_column") == "uuid"
|
||||
query_options = call.kwargs.get("query_options")
|
||||
assert query_options is not None, (
|
||||
"UUID chart lookup must pass query_options for eager-loading."
|
||||
)
|
||||
load_path = _extract_metrics_load_path(query_options[0])
|
||||
assert load_path == ["table", "metrics"], (
|
||||
f"Expected subqueryload chain 'table' -> 'metrics', got {load_path}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_format_also_eager_loads_metrics(self, mcp_server, mock_auth):
|
||||
"""Eager-load is applied for every format, not just Excel.
|
||||
|
||||
Applying unconditionally keeps the fix robust if additional code paths
|
||||
start touching dataset.metrics, and avoids branching behavior that
|
||||
would be easy to regress on.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastmcp import Client
|
||||
|
||||
with patch(
|
||||
"superset.daos.chart.ChartDAO.find_by_id", return_value=None
|
||||
) as mock_find:
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"get_chart_data",
|
||||
{"request": {"identifier": 7, "format": "json"}},
|
||||
)
|
||||
|
||||
call = mock_find.call_args
|
||||
query_options = call.kwargs.get("query_options")
|
||||
assert query_options is not None
|
||||
assert _extract_metrics_load_path(query_options[0]) == ["table", "metrics"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _recommend_visualizations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _col(
|
||||
name: str,
|
||||
data_type: str = "string",
|
||||
unique_count: int = 5,
|
||||
null_count: int = 0,
|
||||
) -> DataColumn:
|
||||
"""Shortcut to build a DataColumn for tests."""
|
||||
return DataColumn(
|
||||
name=name,
|
||||
display_name=name,
|
||||
data_type=data_type,
|
||||
sample_values=[],
|
||||
null_count=null_count,
|
||||
unique_count=unique_count,
|
||||
)
|
||||
|
||||
|
||||
def test_recommend_temporal_and_numeric_suggests_line_chart():
|
||||
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
|
||||
result = _recommend_visualizations("table", cols, row_count=50)
|
||||
assert "line chart" in result
|
||||
assert "area chart" in result
|
||||
|
||||
|
||||
def test_recommend_categorical_and_numeric_suggests_bar_chart():
|
||||
cols = [_col("region", "string", unique_count=5), _col("sales", "numeric")]
|
||||
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
|
||||
assert "bar chart" in result
|
||||
|
||||
|
||||
def test_recommend_excludes_current_viz_type():
|
||||
cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
|
||||
result = _recommend_visualizations("echarts_timeseries_line", cols, row_count=50)
|
||||
assert "line chart" not in result
|
||||
|
||||
|
||||
def test_recommend_multiple_numeric_suggests_scatter():
|
||||
cols = [
|
||||
_col("height", "numeric"),
|
||||
_col("weight", "numeric"),
|
||||
_col("age", "numeric"),
|
||||
]
|
||||
result = _recommend_visualizations("table", cols, row_count=100)
|
||||
assert "scatter plot" in result
|
||||
|
||||
|
||||
def test_recommend_single_numeric_suggests_kpi():
|
||||
cols = [_col("total_revenue", "numeric")]
|
||||
result = _recommend_visualizations("table", cols, row_count=1)
|
||||
assert "big number / KPI" in result
|
||||
|
||||
|
||||
def test_recommend_all_strings_falls_back():
|
||||
cols = [_col("name", "string"), _col("address", "string")]
|
||||
result = _recommend_visualizations("pie", cols, row_count=100)
|
||||
assert "table" in result or "bar chart" in result
|
||||
|
||||
|
||||
def test_recommend_high_cardinality_no_pie():
|
||||
cols = [
|
||||
_col("user_id", "string", unique_count=900),
|
||||
_col("score", "numeric"),
|
||||
]
|
||||
result = _recommend_visualizations("table", cols, row_count=1000)
|
||||
assert "pie chart" not in result
|
||||
|
||||
|
||||
def test_recommend_caps_at_max():
|
||||
cols = [_col("ts", "temporal"), _col("a", "numeric"), _col("b", "numeric")]
|
||||
result = _recommend_visualizations("table", cols, row_count=100)
|
||||
assert len(result) <= _MAX_RECOMMENDATIONS
|
||||
|
||||
|
||||
def test_recommend_empty_columns_returns_table():
|
||||
result = _recommend_visualizations("table", [], row_count=0)
|
||||
assert result == ["table"]
|
||||
|
||||
|
||||
def test_recommend_pie_only_for_low_cardinality():
|
||||
cols = [
|
||||
_col("department", "string", unique_count=25),
|
||||
_col("headcount", "numeric"),
|
||||
]
|
||||
result = _recommend_visualizations("table", cols, row_count=100)
|
||||
assert "pie chart" not in result
|
||||
|
||||
|
||||
def test_recommend_temporal_few_rows_prefers_bar():
|
||||
cols = [_col("date", "temporal"), _col("revenue", "numeric")]
|
||||
result = _recommend_visualizations("table", cols, row_count=3)
|
||||
assert "bar chart" in result
|
||||
assert "line chart" not in result
|
||||
|
||||
|
||||
def test_recommend_single_numeric_high_cardinality_suggests_histogram():
|
||||
cols = [_col("salary", "numeric", unique_count=500)]
|
||||
result = _recommend_visualizations("table", cols, row_count=1000)
|
||||
assert "histogram" in result
|
||||
|
||||
|
||||
def test_coltypes_populates_data_type():
|
||||
"""Verify that GenericDataType values from coltypes are mapped correctly."""
|
||||
assert _GENERIC_TYPE_MAP[GenericDataType.NUMERIC] == "numeric"
|
||||
assert _GENERIC_TYPE_MAP[GenericDataType.STRING] == "string"
|
||||
assert _GENERIC_TYPE_MAP[GenericDataType.TEMPORAL] == "temporal"
|
||||
assert _GENERIC_TYPE_MAP[GenericDataType.BOOLEAN] == "boolean"
|
||||
|
||||
|
||||
def test_bool_isinstance_check_before_int():
|
||||
"""bool is a subclass of int; verify bool check takes priority in fallback."""
|
||||
|
||||
# When coltypes is unavailable, the fallback isinstance heuristic
|
||||
# must check bool before int/float since isinstance(True, int) is True.
|
||||
# We verify this indirectly: if _GENERIC_TYPE_MAP handles bool correctly,
|
||||
# and the fallback code checks bool first, booleans won't be "numeric".
|
||||
# Direct test: simulate what the fallback does
|
||||
sample_values = [True, False, True]
|
||||
data_type = "string"
|
||||
if all(isinstance(v, bool) for v in sample_values):
|
||||
data_type = "boolean"
|
||||
elif all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
assert data_type == "boolean"
|
||||
|
||||
@@ -16,18 +16,25 @@
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for get_chart_info MCP tool privacy behavior.
|
||||
Unit tests for get_chart_info MCP tool: dashboard-filter resolution and
|
||||
privacy behavior.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.commands.dashboard.exceptions import DashboardNotFoundError
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.chart.chart_helpers import (
|
||||
_resolve_filter_operator_and_value,
|
||||
build_applied_dashboard_filters,
|
||||
ChartNotOnDashboardError,
|
||||
)
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ChartInfo,
|
||||
extract_filters_from_form_data,
|
||||
@@ -84,6 +91,238 @@ def _make_chart_info() -> ChartInfo:
|
||||
)
|
||||
|
||||
|
||||
class TestGetChartInfoRequestSchema:
|
||||
def test_dashboard_id_optional(self):
|
||||
request = GetChartInfoRequest(identifier=1)
|
||||
assert request.dashboard_id is None
|
||||
|
||||
def test_dashboard_id_accepted(self):
|
||||
request = GetChartInfoRequest(identifier=1, dashboard_id=42)
|
||||
assert request.dashboard_id == 42
|
||||
|
||||
|
||||
class TestResolveFilterOperatorAndValue:
|
||||
def test_matches_adhoc_filter_by_subject(self):
|
||||
efd = {
|
||||
"adhoc_filters": [
|
||||
{
|
||||
"subject": "country",
|
||||
"operator": "IN",
|
||||
"comparator": ["US", "CA"],
|
||||
}
|
||||
]
|
||||
}
|
||||
assert _resolve_filter_operator_and_value(efd, "country") == (
|
||||
"IN",
|
||||
["US", "CA"],
|
||||
)
|
||||
|
||||
def test_matches_legacy_filter_by_col(self):
|
||||
efd = {"filters": [{"col": "state", "op": "==", "val": "NY"}]}
|
||||
assert _resolve_filter_operator_and_value(efd, "state") == ("==", "NY")
|
||||
|
||||
def test_time_range_when_no_column(self):
|
||||
efd = {"time_range": "Last 7 days"}
|
||||
assert _resolve_filter_operator_and_value(efd, None) == (
|
||||
"TIME_RANGE",
|
||||
"Last 7 days",
|
||||
)
|
||||
|
||||
def test_column_not_in_extra_form_data(self):
|
||||
efd = {
|
||||
"adhoc_filters": [{"subject": "other", "operator": "==", "comparator": 1}]
|
||||
}
|
||||
assert _resolve_filter_operator_and_value(efd, "country") == (None, None)
|
||||
|
||||
def test_none_extra_form_data(self):
|
||||
assert _resolve_filter_operator_and_value(None, "country") == (None, None)
|
||||
|
||||
def test_ignores_non_dict_entries(self):
|
||||
efd = {
|
||||
"adhoc_filters": ["not-a-dict", None],
|
||||
"filters": [42, "foo"],
|
||||
}
|
||||
assert _resolve_filter_operator_and_value(efd, "country") == (None, None)
|
||||
|
||||
|
||||
class TestBuildAppliedDashboardFilters:
|
||||
"""The helper validates access, checks chart-on-dashboard, iterates
|
||||
native filters, resolves scope, and maps each to AppliedDashboardFilter."""
|
||||
|
||||
def _make_dashboard(self, json_metadata=None, position_json=None, slice_ids=None):
|
||||
dashboard = MagicMock()
|
||||
dashboard.json_metadata = json_metadata or "{}"
|
||||
dashboard.position_json = position_json or "{}"
|
||||
dashboard.slices = [MagicMock(id=sid) for sid in (slice_ids or [])]
|
||||
return dashboard
|
||||
|
||||
def test_chart_not_on_dashboard_raises(self):
|
||||
dashboard = self._make_dashboard(slice_ids=[2, 3])
|
||||
with (
|
||||
patch("superset.db") as mock_db,
|
||||
patch("superset.security_manager"),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
|
||||
with pytest.raises(ChartNotOnDashboardError, match="not on dashboard"):
|
||||
build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
|
||||
|
||||
def test_dashboard_not_found_raises(self):
|
||||
with (
|
||||
patch("superset.db") as mock_db,
|
||||
patch("superset.security_manager"),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = None # noqa: E501
|
||||
with pytest.raises(DashboardNotFoundError):
|
||||
build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
|
||||
|
||||
def test_in_scope_filter_with_static_default(self):
|
||||
native_filter = {
|
||||
"id": "NATIVE_FILTER-1",
|
||||
"name": "Country",
|
||||
"type": "NATIVE_FILTER",
|
||||
"filterType": "filter_select",
|
||||
"chartsInScope": [1],
|
||||
"targets": [{"column": {"name": "country"}, "datasetId": 7}],
|
||||
"defaultDataMask": {
|
||||
"filterState": {"value": ["US"]},
|
||||
"extraFormData": {
|
||||
"adhoc_filters": [
|
||||
{
|
||||
"subject": "country",
|
||||
"operator": "IN",
|
||||
"comparator": ["US"],
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
dashboard = self._make_dashboard(
|
||||
json_metadata='{"native_filter_configuration": %s}'
|
||||
% _json(native_filter_list=[native_filter]),
|
||||
slice_ids=[1],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("superset.db") as mock_db,
|
||||
patch("superset.security_manager"),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
|
||||
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
|
||||
|
||||
assert len(result) == 1
|
||||
flt = result[0]
|
||||
assert flt.id == "NATIVE_FILTER-1"
|
||||
assert flt.name == "Country"
|
||||
assert flt.filter_type == "filter_select"
|
||||
assert flt.column == "country"
|
||||
assert flt.operator == "IN"
|
||||
assert flt.value == ["US"]
|
||||
assert flt.status == "applied"
|
||||
|
||||
def test_excluded_chart_filter_skipped(self):
|
||||
native_filter = {
|
||||
"id": "NATIVE_FILTER-1",
|
||||
"name": "Region",
|
||||
"type": "NATIVE_FILTER",
|
||||
"filterType": "filter_select",
|
||||
"chartsInScope": [2, 3], # chart 1 excluded
|
||||
"targets": [{"column": {"name": "region"}, "datasetId": 7}],
|
||||
"defaultDataMask": {
|
||||
"filterState": {"value": ["NA"]},
|
||||
"extraFormData": {
|
||||
"filters": [{"col": "region", "op": "==", "val": "NA"}]
|
||||
},
|
||||
},
|
||||
}
|
||||
dashboard = self._make_dashboard(
|
||||
json_metadata='{"native_filter_configuration": %s}'
|
||||
% _json(native_filter_list=[native_filter]),
|
||||
slice_ids=[1],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("superset.db") as mock_db,
|
||||
patch("superset.security_manager"),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
|
||||
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_default_to_first_item_marks_prequery(self):
|
||||
native_filter = {
|
||||
"id": "NATIVE_FILTER-1",
|
||||
"name": "Region",
|
||||
"type": "NATIVE_FILTER",
|
||||
"filterType": "filter_select",
|
||||
"chartsInScope": [1],
|
||||
"targets": [{"column": {"name": "region"}, "datasetId": 7}],
|
||||
"controlValues": {"defaultToFirstItem": True},
|
||||
"defaultDataMask": {},
|
||||
}
|
||||
dashboard = self._make_dashboard(
|
||||
json_metadata='{"native_filter_configuration": %s}'
|
||||
% _json(native_filter_list=[native_filter]),
|
||||
slice_ids=[1],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("superset.db") as mock_db,
|
||||
patch("superset.security_manager"),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
|
||||
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "not_applied_uses_default_to_first_item_prequery"
|
||||
assert result[0].operator is None
|
||||
assert result[0].value is None
|
||||
|
||||
def test_divider_entry_skipped(self):
|
||||
divider = {
|
||||
"id": "DIVIDER-1",
|
||||
"name": "Section header",
|
||||
"type": "DIVIDER",
|
||||
}
|
||||
dashboard = self._make_dashboard(
|
||||
json_metadata='{"native_filter_configuration": %s}'
|
||||
% _json(native_filter_list=[divider]),
|
||||
slice_ids=[1],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("superset.db") as mock_db,
|
||||
patch("superset.security_manager"),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
|
||||
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_no_native_filters_returns_empty_list(self):
|
||||
dashboard = self._make_dashboard(
|
||||
json_metadata="{}",
|
||||
slice_ids=[1],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("superset.db") as mock_db,
|
||||
patch("superset.security_manager"),
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.one_or_none.return_value = dashboard # noqa: E501
|
||||
result = build_applied_dashboard_filters(dashboard_id=10, chart_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
def _json(native_filter_list):
|
||||
"""Serialize a native_filter list as JSON string for embedding in
|
||||
json_metadata fixtures without escaping issues."""
|
||||
from superset.utils import json
|
||||
|
||||
return json.dumps(native_filter_list)
|
||||
|
||||
|
||||
class TestGetChartInfoPrivacy:
|
||||
@pytest.mark.asyncio
|
||||
async def test_restricted_user_redacts_saved_chart_data_model_fields(
|
||||
|
||||
@@ -2043,12 +2043,10 @@ class TestListDatasetsCreatedByMe:
|
||||
with pytest.raises(ValidationError, match="created_by_me"):
|
||||
ListDatasetsRequest(created_by_me=True, search="My tables")
|
||||
|
||||
def test_dataset_filter_rejects_created_by_fk(self):
|
||||
"""created_by_fk is not a public filter column; use created_by_me instead."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
DatasetFilter(col="created_by_fk", opr="eq", value=1)
|
||||
def test_dataset_filter_accepts_created_by_fk(self):
|
||||
"""created_by_fk is exposed for person-filtering via find_users."""
|
||||
f = DatasetFilter(col="created_by_fk", opr="eq", value=1)
|
||||
assert f.col == "created_by_fk"
|
||||
|
||||
|
||||
class TestListDatasetsOwnedByMe:
|
||||
@@ -2115,14 +2113,10 @@ class TestListDatasetsRequestWrapper:
|
||||
assert f.col == col
|
||||
|
||||
def test_dataset_filter_invalid_col_raises(self) -> None:
|
||||
"""Column names not in the Literal are rejected with a validation error.
|
||||
|
||||
This guards against LLMs passing ``created_by_fk`` or similar
|
||||
internal column names that are not exposed as filter fields.
|
||||
"""
|
||||
"""Column names not in the Literal are rejected with a validation error."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
for bad_col in ("created_by_fk", "id", "database_id", "owner"):
|
||||
for bad_col in ("id", "database_id", "owner"):
|
||||
with pytest.raises(ValidationError):
|
||||
DatasetFilter(col=bad_col, opr="eq", value="1")
|
||||
|
||||
|
||||
@@ -810,6 +810,52 @@ class TestGenerateExploreLink:
|
||||
assert "Dataset not found: 99999" in result.data["error"]
|
||||
assert "list_datasets" in result.data["error"]
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_without_config(
|
||||
self, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Omitting config returns a default dataset explore URL."""
|
||||
mock_find_dataset.return_value = _mock_dataset(id=42)
|
||||
|
||||
request = GenerateExploreLinkRequest(dataset_id="42")
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?datasource_type=table"
|
||||
"&datasource_id=42"
|
||||
)
|
||||
assert result.data["form_data"] == {}
|
||||
assert result.data["form_data_key"] is None
|
||||
assert result.data["chart_type_label"] is None
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_without_config_missing_dataset(
|
||||
self, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Omitting config still surfaces a dataset-not-found error."""
|
||||
mock_find_dataset.return_value = None
|
||||
|
||||
request = GenerateExploreLinkRequest(dataset_id="99999")
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["url"] == ""
|
||||
assert result.data["form_data"] == {}
|
||||
assert result.data["form_data_key"] is None
|
||||
assert result.data["chart_type_label"] is None
|
||||
assert "Dataset not found: 99999" in result.data["error"]
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_nonexistent_uuid_dataset(
|
||||
|
||||
16
tests/unit_tests/mcp_service/plugin/__init__.py
Normal file
16
tests/unit_tests/mcp_service/plugin/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/plugin/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/plugin/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
172
tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py
Normal file
172
tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.plugin.schemas import ListPluginsRequest, PluginColumnFilter
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mock_plugin(
|
||||
plugin_id: int = 1,
|
||||
name: str = "My Plugin",
|
||||
key: str = "my_plugin",
|
||||
bundle_url: str = "https://example.com/plugin.js",
|
||||
) -> MagicMock:
|
||||
plugin = MagicMock()
|
||||
plugin.id = plugin_id
|
||||
plugin.name = name
|
||||
plugin.key = key
|
||||
plugin.bundle_url = bundle_url
|
||||
plugin.changed_on = None
|
||||
plugin.created_on = None
|
||||
return plugin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
class TestPluginColumnFilterSchema:
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
PluginColumnFilter(col="bundle_url", opr="eq", value="test")
|
||||
|
||||
def test_valid_name_filter(self):
|
||||
f = PluginColumnFilter(col="name", opr="eq", value="test")
|
||||
assert f.col == "name"
|
||||
|
||||
def test_valid_key_filter(self):
|
||||
f = PluginColumnFilter(col="key", opr="eq", value="my_plugin")
|
||||
assert f.col == "key"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_basic(mock_list, mcp_server):
|
||||
plugin = create_mock_plugin()
|
||||
mock_list.return_value = ([plugin], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_plugins", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "plugins" in data
|
||||
assert len(data["plugins"]) == 1
|
||||
assert data["plugins"][0]["id"] == 1
|
||||
assert data["plugins"][0]["name"] == "My Plugin"
|
||||
assert data["plugins"][0]["key"] == "my_plugin"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_with_request(mock_list, mcp_server):
|
||||
plugin = create_mock_plugin()
|
||||
mock_list.return_value = ([plugin], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListPluginsRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_plugins", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 1
|
||||
assert data["total_count"] == 1
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_with_search(mock_list, mcp_server):
|
||||
plugin = create_mock_plugin(name="Custom Chart")
|
||||
mock_list.return_value = ([plugin], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListPluginsRequest(page=1, page_size=10, search="custom")
|
||||
result = await client.call_tool(
|
||||
"list_plugins", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["plugins"][0]["name"] == "Custom Chart"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_empty(mock_list, mcp_server):
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_plugins", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 0
|
||||
assert data["plugins"] == []
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_info_basic(mock_find, mcp_server):
|
||||
plugin = create_mock_plugin()
|
||||
mock_find.return_value = plugin
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_plugin_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["name"] == "My Plugin"
|
||||
assert data["key"] == "my_plugin"
|
||||
assert data["bundle_url"] == "https://example.com/plugin.js"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_info_not_found(mock_find, mcp_server):
|
||||
mock_find.return_value = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_plugin_info", {"request": {"identifier": 999}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "not_found"
|
||||
|
||||
|
||||
def test_list_plugins_request_rejects_search_and_filters():
|
||||
with pytest.raises(ValidationError):
|
||||
ListPluginsRequest(
|
||||
search="test",
|
||||
filters=[{"col": "name", "opr": "eq", "value": "x"}],
|
||||
)
|
||||
16
tests/unit_tests/mcp_service/rls/__init__.py
Normal file
16
tests/unit_tests/mcp_service/rls/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/rls/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/rls/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
244
tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py
Normal file
244
tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.rls.schemas import ListRlsFiltersRequest, RlsColumnFilter
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mock_rls_filter(
|
||||
filter_id: int = 1,
|
||||
name: str = "test_filter",
|
||||
filter_type: str = "Regular",
|
||||
clause: str = "user_id = {{current_user_id()}}",
|
||||
group_key: str | None = None,
|
||||
) -> MagicMock:
|
||||
rls_filter = MagicMock()
|
||||
rls_filter.id = filter_id
|
||||
rls_filter.name = name
|
||||
rls_filter.filter_type = filter_type
|
||||
rls_filter.clause = clause
|
||||
rls_filter.group_key = group_key
|
||||
rls_filter.changed_on = None
|
||||
|
||||
table = MagicMock()
|
||||
table.id = 1
|
||||
table.table_name = "sales"
|
||||
rls_filter.tables = [table]
|
||||
|
||||
role = MagicMock()
|
||||
role.id = 1
|
||||
role.name = "Alpha"
|
||||
rls_filter.roles = [role]
|
||||
|
||||
return rls_filter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
class TestRlsColumnFilterSchema:
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
RlsColumnFilter(col="clause", opr="eq", value="test")
|
||||
|
||||
def test_valid_name_filter(self):
|
||||
f = RlsColumnFilter(col="name", opr="eq", value="test")
|
||||
assert f.col == "name"
|
||||
|
||||
def test_valid_filter_type_filter(self):
|
||||
f = RlsColumnFilter(col="filter_type", opr="eq", value="Regular")
|
||||
assert f.col == "filter_type"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_basic(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_rls_filters", {})
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "rls_filters" in data
|
||||
assert len(data["rls_filters"]) == 1
|
||||
assert data["rls_filters"][0]["id"] == 1
|
||||
assert data["rls_filters"][0]["name"] == "test_filter"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_with_request(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 1
|
||||
assert data["total_count"] == 1
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_with_search(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter(name="user_filter")
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(page=1, page_size=10, search="user")
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["rls_filters"][0]["name"] == "user_filter"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_returns_tables_and_roles(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(
|
||||
page=1,
|
||||
page_size=10,
|
||||
select_columns=["id", "name", "tables", "roles"],
|
||||
)
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
item = data["rls_filters"][0]
|
||||
assert "tables" in item
|
||||
assert item["tables"][0]["table_name"] == "sales"
|
||||
assert "roles" in item
|
||||
assert item["roles"][0]["name"] == "Alpha"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_roles_only_select_columns(mock_list, mcp_server):
|
||||
"""Requesting only 'roles' must not raise ValueError from the privacy filter."""
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(
|
||||
page=1,
|
||||
page_size=10,
|
||||
select_columns=["roles"],
|
||||
)
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
item = data["rls_filters"][0]
|
||||
assert "roles" in item
|
||||
assert item["roles"][0]["name"] == "Alpha"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_empty(mock_list, mcp_server):
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_rls_filters", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 0
|
||||
assert data["rls_filters"] == []
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rls_filter_info_basic(mock_find, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_find.return_value = rls_filter
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_rls_filter_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["name"] == "test_filter"
|
||||
assert data["filter_type"] == "Regular"
|
||||
assert data["clause"] == "user_id = {{current_user_id()}}"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rls_filter_info_not_found(mock_find, mcp_server):
|
||||
mock_find.return_value = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_rls_filter_info", {"request": {"identifier": 999}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rls_filter_info_includes_tables_and_roles(mock_find, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_find.return_value = rls_filter
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_rls_filter_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tables"][0]["table_name"] == "sales"
|
||||
assert data["roles"][0]["name"] == "Alpha"
|
||||
|
||||
|
||||
def test_list_rls_filters_request_rejects_search_and_filters():
|
||||
with pytest.raises(ValidationError):
|
||||
ListRlsFiltersRequest(
|
||||
search="test",
|
||||
filters=[{"col": "name", "opr": "eq", "value": "x"}],
|
||||
)
|
||||
257
tests/unit_tests/mcp_service/system/tool/test_find_users.py
Normal file
257
tests/unit_tests/mcp_service/system/tool/test_find_users.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for find_users MCP tool and its filter contract."""
|
||||
|
||||
import importlib
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.system.schemas import FindUsersRequest, FindUsersResponse
|
||||
from superset.utils import json
|
||||
|
||||
# Import the submodule directly so ``patch.object`` targets the module (not the
|
||||
# ``find_users`` function that ``tool/__init__.py`` re-exports onto the
|
||||
# package). The package attribute is the function, so dotted-string patches
|
||||
# like ``superset.mcp_service.system.tool.find_users.db`` can resolve to the
|
||||
# function in some import orderings and fail with AttributeError.
|
||||
find_users_module = importlib.import_module(
|
||||
"superset.mcp_service.system.tool.find_users"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests."""
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_user(id_, username, first=None, last=None, email=None, active=True):
|
||||
"""Build a Mock user with the attributes serialize_user_object reads."""
|
||||
user = Mock(
|
||||
spec=["id", "username", "first_name", "last_name", "email", "active", "roles"]
|
||||
)
|
||||
user.id = id_
|
||||
user.username = username
|
||||
user.first_name = first
|
||||
user.last_name = last
|
||||
user.email = email
|
||||
user.active = active
|
||||
user.roles = []
|
||||
return user
|
||||
|
||||
|
||||
def _patch_user_query(rows):
|
||||
"""Patch the SQLAlchemy chain used by find_users to return a fixed result set."""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
chain.order_by.return_value = chain
|
||||
chain.limit.return_value = chain
|
||||
chain.all.return_value = rows
|
||||
session = MagicMock()
|
||||
session.query.return_value = chain
|
||||
return session, chain
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_find_users_request_rejects_empty_query():
|
||||
with pytest.raises(ValidationError):
|
||||
FindUsersRequest(query="")
|
||||
|
||||
|
||||
def test_find_users_request_rejects_extra_fields():
|
||||
with pytest.raises(ValidationError):
|
||||
FindUsersRequest(query="maxime", random_field="x")
|
||||
|
||||
|
||||
def test_find_users_response_default_truncated_false():
|
||||
resp = FindUsersResponse(users=[], count=0)
|
||||
assert resp.truncated is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool-level tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_users_returns_matches(mcp_server):
|
||||
rows = [
|
||||
_make_user(
|
||||
7, "maxime", first="Maxime", last="Beauchemin", email="m@example.com"
|
||||
)
|
||||
]
|
||||
session, _ = _patch_user_query(rows)
|
||||
|
||||
with (
|
||||
patch.object(find_users_module, "db") as mock_db,
|
||||
patch.object(find_users_module, "security_manager") as mock_sm,
|
||||
patch.object(find_users_module, "or_") as mock_or,
|
||||
):
|
||||
mock_db.session = session
|
||||
mock_sm.user_model = MagicMock()
|
||||
mock_or.return_value = MagicMock()
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"find_users", {"request": {"query": "maxime"}}
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 1
|
||||
assert data["truncated"] is False
|
||||
assert data["users"][0]["id"] == 7
|
||||
assert data["users"][0]["username"] == "maxime"
|
||||
assert data["users"][0]["first_name"] == "Maxime"
|
||||
assert data["users"][0]["last_name"] == "Beauchemin"
|
||||
# Privacy: minimal projection excludes identity attributes that aren't
|
||||
# required for filter resolution. Catch regressions on the response shape.
|
||||
for forbidden in ("email", "active", "roles"):
|
||||
assert forbidden not in data["users"][0]
|
||||
# or_ should have been built across the four matched columns
|
||||
assert mock_or.called
|
||||
assert len(mock_or.call_args.args) == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_users_truncates_when_more_rows_than_page_size(mcp_server):
|
||||
# page_size=2 with 3 returned rows -> truncated, response trimmed to 2
|
||||
rows = [
|
||||
_make_user(1, "a"),
|
||||
_make_user(2, "b"),
|
||||
_make_user(3, "c"),
|
||||
]
|
||||
session, chain = _patch_user_query(rows)
|
||||
|
||||
with (
|
||||
patch.object(find_users_module, "db") as mock_db,
|
||||
patch.object(find_users_module, "security_manager") as mock_sm,
|
||||
patch.object(find_users_module, "or_") as mock_or,
|
||||
):
|
||||
mock_db.session = session
|
||||
mock_sm.user_model = MagicMock()
|
||||
mock_or.return_value = MagicMock()
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"find_users", {"request": {"query": "a", "page_size": 2}}
|
||||
)
|
||||
|
||||
# Tool requested page_size+1 rows for truncation detection
|
||||
chain.limit.assert_called_with(3)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 2
|
||||
assert data["truncated"] is True
|
||||
assert [u["id"] for u in data["users"]] == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_users_rejects_empty_query_via_client(mcp_server):
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError):
|
||||
await client.call_tool("find_users", {"request": {"query": ""}})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blank", [" ", " ", "\t", "\n \t"])
|
||||
def test_find_users_request_rejects_whitespace_only_query(blank):
|
||||
# Whitespace-only queries would strip to "" and produce a LIKE "%%" pattern
|
||||
# that enumerates the entire user directory. The validator must reject them.
|
||||
with pytest.raises(ValidationError):
|
||||
FindUsersRequest(query=blank)
|
||||
|
||||
|
||||
def test_find_users_request_strips_query_whitespace():
|
||||
# Validator should normalize the stored query so downstream LIKE patterns
|
||||
# don't carry leading/trailing whitespace.
|
||||
request = FindUsersRequest(query=" maxime ")
|
||||
assert request.query == "maxime"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter contract: created_by_fk / changed_by_fk filtering on list tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_dashboards_passes_created_by_fk_filter_to_dao(
|
||||
mock_list, mcp_server
|
||||
):
|
||||
"""list_dashboards should accept created_by_fk filter and forward it."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"list_dashboards",
|
||||
{
|
||||
"request": {
|
||||
"filters": [{"col": "created_by_fk", "opr": "eq", "value": 7}],
|
||||
"page": 1,
|
||||
"page_size": 10,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_list.called
|
||||
forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
|
||||
assert forwarded_filters is not None
|
||||
assert any(
|
||||
getattr(f, "col", None) == "created_by_fk" and getattr(f, "value", None) == 7
|
||||
for f in forwarded_filters
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_charts_passes_changed_by_fk_filter_to_dao(mock_list, mcp_server):
|
||||
"""list_charts should accept changed_by_fk filter and forward it."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"list_charts",
|
||||
{
|
||||
"request": {
|
||||
"filters": [{"col": "changed_by_fk", "opr": "eq", "value": 7}],
|
||||
"page": 1,
|
||||
"page_size": 10,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_list.called
|
||||
forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
|
||||
assert forwarded_filters is not None
|
||||
assert any(getattr(f, "col", None) == "changed_by_fk" for f in forwarded_filters)
|
||||
@@ -22,6 +22,8 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.client import CallToolResult
|
||||
from mcp.types import TextContent
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
@@ -45,6 +47,14 @@ get_schema_module = importlib.import_module(
|
||||
"superset.mcp_service.system.tool.get_schema"
|
||||
)
|
||||
|
||||
|
||||
def _result_text(result: CallToolResult) -> str:
|
||||
"""Return the text payload from the first content block of a tool result."""
|
||||
block = result.content[0]
|
||||
assert isinstance(block, TextContent)
|
||||
return block.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -198,7 +208,7 @@ async def test_get_schema_returns_structured_privacy_error_for_dataset(mcp_serve
|
||||
{"request": {"model_type": "dataset"}},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
|
||||
assert data["privacy_scope"] == "data_model"
|
||||
|
||||
@@ -241,7 +251,7 @@ async def test_get_schema_redacts_chart_data_model_fields(mcp_server):
|
||||
{"request": {"model_type": "chart"}},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
schema_info = data["schema_info"]
|
||||
assert all(
|
||||
column["name"] not in CHART_DATA_MODEL_COLUMNS
|
||||
@@ -389,7 +399,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
assert "current_user" in data
|
||||
cu = data["current_user"]
|
||||
assert cu["id"] == 5
|
||||
@@ -418,7 +428,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
assert data["current_user"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -444,7 +454,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
cu = data["current_user"]
|
||||
assert cu["id"] == 99
|
||||
assert cu["username"] == "bot"
|
||||
@@ -460,28 +470,50 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_chart_filter_rejects_created_by_fk() -> None:
|
||||
"""created_by_fk is not a valid ChartFilter column; use created_by_me instead."""
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter(col="created_by_fk", opr="eq", value=42)
|
||||
def test_chart_filter_rejects_user_directory_columns_other_than_fk() -> None:
|
||||
"""ChartFilter still rejects user-directory columns that expose names."""
|
||||
for col in ("created_by_name", "owners", "changed_by"):
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter.model_validate({"col": col, "opr": "eq", "value": "anything"})
|
||||
|
||||
|
||||
def test_chart_filter_accepts_created_and_changed_by_fk() -> None:
|
||||
"""ChartFilter allows filtering by created_by_fk / changed_by_fk (user IDs)."""
|
||||
for col in ("created_by_fk", "changed_by_fk"):
|
||||
f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": 42})
|
||||
assert f.col == col
|
||||
|
||||
|
||||
def test_chart_filter_rejects_invalid_column():
|
||||
"""Test that ChartFilter rejects invalid column names."""
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter(col="nonexistent_column", opr="eq", value=42)
|
||||
ChartFilter.model_validate(
|
||||
{"col": "nonexistent_column", "opr": "eq", "value": 42}
|
||||
)
|
||||
|
||||
|
||||
def test_dashboard_filter_rejects_created_by_fk():
|
||||
"""created_by_fk is not a valid DashboardFilter column; use created_by_me."""
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter(col="created_by_fk", opr="eq", value=42)
|
||||
def test_dashboard_filter_rejects_user_directory_columns_other_than_fk() -> None:
|
||||
"""DashboardFilter still rejects user-directory columns that expose names."""
|
||||
for col in ("created_by_name", "owners", "changed_by"):
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter.model_validate(
|
||||
{"col": col, "opr": "eq", "value": "anything"}
|
||||
)
|
||||
|
||||
|
||||
def test_dashboard_filter_accepts_created_and_changed_by_fk() -> None:
|
||||
"""DashboardFilter allows filtering by created_by_fk / changed_by_fk."""
|
||||
for col in ("created_by_fk", "changed_by_fk"):
|
||||
f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": 42})
|
||||
assert f.col == col
|
||||
|
||||
|
||||
def test_dashboard_filter_rejects_invalid_column():
|
||||
"""Test that DashboardFilter rejects invalid column names."""
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter(col="nonexistent_column", opr="eq", value=42)
|
||||
DashboardFilter.model_validate(
|
||||
{"col": "nonexistent_column", "opr": "eq", "value": 42}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -492,12 +524,12 @@ def test_dashboard_filter_rejects_invalid_column():
|
||||
def test_chart_filter_existing_columns_still_work():
|
||||
"""Test that pre-existing chart filter columns are not broken."""
|
||||
for col in ("slice_name", "viz_type", "datasource_name"):
|
||||
f = ChartFilter(col=col, opr="eq", value="test")
|
||||
f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
|
||||
assert f.col == col
|
||||
|
||||
|
||||
def test_dashboard_filter_existing_columns_still_work():
|
||||
"""Test that pre-existing dashboard filter columns are not broken."""
|
||||
for col in ("dashboard_title", "published", "favorite"):
|
||||
f = DashboardFilter(col=col, opr="eq", value="test")
|
||||
f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
|
||||
assert f.col == col
|
||||
|
||||
@@ -326,11 +326,19 @@ class TestGetSchemaToolViaClient:
|
||||
async def test_get_schema_omits_user_directory_columns(
|
||||
self, mock_filters, mcp_server
|
||||
):
|
||||
"""Test that schema discovery does not advertise user/access fields."""
|
||||
"""Test that schema discovery does not advertise user/access fields.
|
||||
|
||||
created_by_fk and changed_by_fk are intentionally allowed in
|
||||
filter_columns so callers can filter by user ID resolved via find_users,
|
||||
but they remain hidden from select_columns and sortable_columns so the
|
||||
directory itself is never exposed.
|
||||
"""
|
||||
mock_filters.return_value = {
|
||||
"dashboard_title": ["eq", "ilike"],
|
||||
"owner": ["rel_m_m"],
|
||||
"published": ["eq"],
|
||||
"created_by_fk": ["eq", "in"],
|
||||
"changed_by_fk": ["eq", "in"],
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -352,9 +360,16 @@ class TestGetSchemaToolViaClient:
|
||||
"owner",
|
||||
):
|
||||
assert field not in select_column_names
|
||||
assert field not in info["filter_columns"]
|
||||
assert field not in info["sortable_columns"]
|
||||
|
||||
# User-name and relationship fields stay out of filter_columns
|
||||
for field in ("owners", "roles", "created_by", "changed_by", "owner"):
|
||||
assert field not in info["filter_columns"]
|
||||
|
||||
# ID-only filter columns are advertised so callers can filter via find_users
|
||||
assert "created_by_fk" in info["filter_columns"]
|
||||
assert "changed_by_fk" in info["filter_columns"]
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.get_filterable_columns_and_operators")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_schema_chart_omits_self_referencing_filter_columns(
|
||||
@@ -362,8 +377,9 @@ class TestGetSchemaToolViaClient:
|
||||
):
|
||||
"""Test that chart schema does not advertise self-referencing filter columns.
|
||||
|
||||
Even if the DAO returns created_by_fk or owner, they must be excluded so
|
||||
LLMs cannot discover and use them to enumerate user IDs.
|
||||
Even if the DAO returns owner or created_by_fk_or_owner, they must be
|
||||
excluded — these synthetic columns are generated server-side from the
|
||||
owned_by_me flag and are not directly usable by LLM callers.
|
||||
"""
|
||||
mock_filters.return_value = {
|
||||
"slice_name": ["eq", "ilike"],
|
||||
@@ -381,7 +397,7 @@ class TestGetSchemaToolViaClient:
|
||||
info = data["schema_info"]
|
||||
|
||||
assert "slice_name" in info["filter_columns"]
|
||||
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
|
||||
for field in ("owner", "created_by_fk_or_owner"):
|
||||
assert field not in info["filter_columns"]
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.get_filterable_columns_and_operators")
|
||||
@@ -391,8 +407,9 @@ class TestGetSchemaToolViaClient:
|
||||
):
|
||||
"""Test that dataset schema does not advertise self-referencing filter columns.
|
||||
|
||||
Even if the DAO returns created_by_fk or owner, they must be excluded so
|
||||
LLMs cannot discover and use them to enumerate user IDs.
|
||||
Even if the DAO returns owner or created_by_fk_or_owner, they must be
|
||||
excluded — these synthetic columns are generated server-side from the
|
||||
owned_by_me flag and are not directly usable by LLM callers.
|
||||
"""
|
||||
mock_filters.return_value = {
|
||||
"table_name": ["eq", "ilike"],
|
||||
@@ -410,7 +427,7 @@ class TestGetSchemaToolViaClient:
|
||||
info = data["schema_info"]
|
||||
|
||||
assert "table_name" in info["filter_columns"]
|
||||
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
|
||||
for field in ("owner", "created_by_fk_or_owner"):
|
||||
assert field not in info["filter_columns"]
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.get_filterable_columns_and_operators")
|
||||
@@ -420,8 +437,9 @@ class TestGetSchemaToolViaClient:
|
||||
):
|
||||
"""Test dashboard schema omits self-referencing filter columns.
|
||||
|
||||
Even if the DAO returns created_by_fk or owner, they must be excluded
|
||||
so LLMs cannot discover and use them to enumerate user IDs.
|
||||
Even if the DAO returns owner or created_by_fk_or_owner, they must be
|
||||
excluded — these synthetic columns are generated server-side from the
|
||||
owned_by_me flag and are not directly usable by LLM callers.
|
||||
"""
|
||||
mock_filters.return_value = {
|
||||
"dashboard_title": ["eq", "ilike"],
|
||||
@@ -439,7 +457,7 @@ class TestGetSchemaToolViaClient:
|
||||
info = data["schema_info"]
|
||||
|
||||
assert "dashboard_title" in info["filter_columns"]
|
||||
for field in ("created_by_fk", "owner", "created_by_fk_or_owner"):
|
||||
for field in ("owner", "created_by_fk_or_owner"):
|
||||
assert field not in info["filter_columns"]
|
||||
|
||||
|
||||
|
||||
@@ -15,8 +15,10 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime, time, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
@@ -40,6 +42,7 @@ from superset_core.semantic_layers.types import (
|
||||
from superset_core.semantic_layers.view import SemanticViewFeature
|
||||
|
||||
from superset.semantic_layers.mapper import (
|
||||
_coerce_scalar_filter_value,
|
||||
_convert_query_object_filter,
|
||||
_convert_time_grain,
|
||||
_get_filters_from_extras,
|
||||
@@ -1204,6 +1207,91 @@ def test_convert_query_object_filter_like() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_convert_query_object_filter_coerces_integer_string_value() -> None:
|
||||
"""Test scalar filter values are coerced to dimension type."""
|
||||
all_dimensions = {
|
||||
"birthyear": Dimension(
|
||||
"birthyear",
|
||||
"birthyear",
|
||||
pa.int64(),
|
||||
"birthyear",
|
||||
"Birthyear",
|
||||
)
|
||||
}
|
||||
|
||||
filter_: ValidatedQueryObjectFilterClause = {
|
||||
"op": FilterOperator.GREATER_THAN_OR_EQUALS.value,
|
||||
"col": "birthyear",
|
||||
"val": "1982",
|
||||
}
|
||||
|
||||
result = _convert_query_object_filter(filter_, all_dimensions)
|
||||
|
||||
assert result == {
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=all_dimensions["birthyear"],
|
||||
operator=Operator.GREATER_THAN_OR_EQUAL,
|
||||
value=1982,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_convert_query_object_filter_coerces_in_integer_values() -> None:
|
||||
"""Test IN filter list values are coerced element-wise."""
|
||||
all_dimensions = {
|
||||
"order_id__amount": Dimension(
|
||||
"order_id__amount",
|
||||
"order_id__amount",
|
||||
pa.int64(),
|
||||
"order_id__amount",
|
||||
"Order amount",
|
||||
)
|
||||
}
|
||||
|
||||
filter_: ValidatedQueryObjectFilterClause = {
|
||||
"op": FilterOperator.IN.value,
|
||||
"col": "order_id__amount",
|
||||
"val": ["58", "61"],
|
||||
}
|
||||
|
||||
result = _convert_query_object_filter(filter_, all_dimensions)
|
||||
|
||||
assert result == {
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=all_dimensions["order_id__amount"],
|
||||
operator=Operator.IN,
|
||||
value=frozenset({58, 61}),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_convert_query_object_filter_invalid_integer_value_raises() -> None:
|
||||
"""Test invalid integer value raises a clear error."""
|
||||
all_dimensions = {
|
||||
"birthyear": Dimension(
|
||||
"birthyear",
|
||||
"birthyear",
|
||||
pa.int64(),
|
||||
"birthyear",
|
||||
"Birthyear",
|
||||
)
|
||||
}
|
||||
|
||||
filter_: ValidatedQueryObjectFilterClause = {
|
||||
"op": FilterOperator.GREATER_THAN_OR_EQUALS.value,
|
||||
"col": "birthyear",
|
||||
"val": "nineteen-eighty-two",
|
||||
}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Invalid integer value 'nineteen-eighty-two' for filter column birthyear",
|
||||
):
|
||||
_convert_query_object_filter(filter_, all_dimensions)
|
||||
|
||||
|
||||
def test_get_results_without_time_offsets(
|
||||
mock_datasource: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
@@ -1923,6 +2011,86 @@ def test_convert_query_object_filter_temporal_range_with_value() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_convert_query_object_filter_temporal_range_coerces_date_bounds() -> None:
|
||||
"""
|
||||
TEMPORAL_RANGE bounds should be coerced against the dimension's dtype so
|
||||
date/timestamp columns are not compared against raw strings.
|
||||
"""
|
||||
all_dimensions = {
|
||||
"order_date": Dimension(
|
||||
"order_date", "order_date", pa.date32(), "order_date", "Order date"
|
||||
)
|
||||
}
|
||||
filter_: ValidatedQueryObjectFilterClause = {
|
||||
"op": FilterOperator.TEMPORAL_RANGE.value,
|
||||
"col": "order_date",
|
||||
"val": "2025-01-01 : 2025-12-31",
|
||||
}
|
||||
|
||||
result = _convert_query_object_filter(filter_, all_dimensions)
|
||||
|
||||
assert result == {
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=all_dimensions["order_date"],
|
||||
operator=Operator.GREATER_THAN_OR_EQUAL,
|
||||
value=date(2025, 1, 1),
|
||||
),
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=all_dimensions["order_date"],
|
||||
operator=Operator.LESS_THAN,
|
||||
value=date(2025, 12, 31),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_convert_query_object_filter_temporal_range_open_ended() -> None:
|
||||
"""
|
||||
Open-ended TEMPORAL_RANGE bounds should emit only the bounded predicate.
|
||||
"""
|
||||
all_dimensions = {
|
||||
"order_date": Dimension(
|
||||
"order_date", "order_date", pa.date32(), "order_date", "Order date"
|
||||
)
|
||||
}
|
||||
|
||||
only_start: ValidatedQueryObjectFilterClause = {
|
||||
"op": FilterOperator.TEMPORAL_RANGE.value,
|
||||
"col": "order_date",
|
||||
"val": "2025-01-01 : ",
|
||||
}
|
||||
assert _convert_query_object_filter(only_start, all_dimensions) == {
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=all_dimensions["order_date"],
|
||||
operator=Operator.GREATER_THAN_OR_EQUAL,
|
||||
value=date(2025, 1, 1),
|
||||
),
|
||||
}
|
||||
|
||||
only_end: ValidatedQueryObjectFilterClause = {
|
||||
"op": FilterOperator.TEMPORAL_RANGE.value,
|
||||
"col": "order_date",
|
||||
"val": " : 2025-12-31",
|
||||
}
|
||||
assert _convert_query_object_filter(only_end, all_dimensions) == {
|
||||
Filter(
|
||||
type=PredicateType.WHERE,
|
||||
column=all_dimensions["order_date"],
|
||||
operator=Operator.LESS_THAN,
|
||||
value=date(2025, 12, 31),
|
||||
),
|
||||
}
|
||||
|
||||
empty: ValidatedQueryObjectFilterClause = {
|
||||
"op": FilterOperator.TEMPORAL_RANGE.value,
|
||||
"col": "order_date",
|
||||
"val": " : ",
|
||||
}
|
||||
assert _convert_query_object_filter(empty, all_dimensions) is None
|
||||
|
||||
|
||||
def test_get_order_adhoc_with_none_sql_expression(mock_datasource: MagicMock) -> None:
|
||||
"""
|
||||
Test order extraction skips adhoc expression with None sqlExpression.
|
||||
@@ -2741,3 +2909,205 @@ def test_get_group_limit_filters_no_granularity(
|
||||
|
||||
# Should return None - no granularity means no time filters added
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _coerce_scalar_filter_value: per-dtype branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _dim(dtype: pa.DataType, name: str = "d") -> Dimension:
|
||||
return Dimension(name, name, dtype, name, name.capitalize())
|
||||
|
||||
|
||||
def test_coerce_none_returns_none() -> None:
|
||||
assert _coerce_scalar_filter_value(None, _dim(pa.int64())) is None
|
||||
|
||||
|
||||
def test_coerce_unsupported_dtype_passes_through() -> None:
|
||||
# utf8 (and any dtype not branched in the function) returns the value as-is.
|
||||
assert _coerce_scalar_filter_value("abc", _dim(pa.utf8())) == "abc"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw,expected",
|
||||
[
|
||||
(True, True),
|
||||
(False, False),
|
||||
(1, True),
|
||||
(0, False),
|
||||
(1.0, True),
|
||||
(0.0, False),
|
||||
("true", True),
|
||||
("T", True),
|
||||
(" 1 ", True),
|
||||
("yes", True),
|
||||
("Y", True),
|
||||
("on", True),
|
||||
("false", False),
|
||||
("F", False),
|
||||
("0", False),
|
||||
("no", False),
|
||||
("N", False),
|
||||
("off", False),
|
||||
],
|
||||
)
|
||||
def test_coerce_boolean(raw: Any, expected: bool) -> None:
|
||||
assert _coerce_scalar_filter_value(raw, _dim(pa.bool_())) is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw", ["maybe", 2, 0.5, -1])
|
||||
def test_coerce_boolean_invalid_raises(raw: Any) -> None:
|
||||
with pytest.raises(ValueError, match="Invalid boolean value"):
|
||||
_coerce_scalar_filter_value(raw, _dim(pa.bool_()))
|
||||
|
||||
|
||||
def test_coerce_integer_passthrough() -> None:
|
||||
assert _coerce_scalar_filter_value(42, _dim(pa.int64())) == 42
|
||||
|
||||
|
||||
def test_coerce_integer_accepts_integer_valued_float() -> None:
|
||||
# JSON round-trips can turn an int into ``42.0``; accept losslessly.
|
||||
assert _coerce_scalar_filter_value(42.0, _dim(pa.int64())) == 42
|
||||
|
||||
|
||||
def test_coerce_integer_rejects_bool() -> None:
|
||||
# bool is a subclass of int; we explicitly reject it.
|
||||
with pytest.raises(ValueError, match="Invalid integer value"):
|
||||
_coerce_scalar_filter_value(True, _dim(pa.int64()))
|
||||
|
||||
|
||||
def test_coerce_integer_rejects_non_integer_float() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid integer value"):
|
||||
_coerce_scalar_filter_value(1.5, _dim(pa.int64()))
|
||||
|
||||
|
||||
def test_coerce_integer_rejects_other_types() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid integer value"):
|
||||
_coerce_scalar_filter_value([1], _dim(pa.int64()))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype",
|
||||
[pa.float64(), pa.decimal128(10, 2)],
|
||||
)
|
||||
def test_coerce_floating_or_decimal(dtype: pa.DataType) -> None:
|
||||
assert _coerce_scalar_filter_value(1, _dim(dtype)) == 1.0
|
||||
assert _coerce_scalar_filter_value(1.5, _dim(dtype)) == 1.5
|
||||
assert _coerce_scalar_filter_value(" 2.5 ", _dim(dtype)) == 2.5
|
||||
|
||||
|
||||
def test_coerce_floating_rejects_bool() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid numeric value"):
|
||||
_coerce_scalar_filter_value(True, _dim(pa.float64()))
|
||||
|
||||
|
||||
def test_coerce_floating_invalid_string_raises() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid numeric value"):
|
||||
_coerce_scalar_filter_value("not-a-number", _dim(pa.float64()))
|
||||
|
||||
|
||||
def test_coerce_floating_rejects_other_types() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid numeric value"):
|
||||
_coerce_scalar_filter_value([1.0], _dim(pa.float64()))
|
||||
|
||||
|
||||
def test_coerce_date_from_datetime() -> None:
|
||||
out = _coerce_scalar_filter_value(datetime(2025, 1, 2, 12, 0), _dim(pa.date32()))
|
||||
assert out == date(2025, 1, 2)
|
||||
|
||||
|
||||
def test_coerce_date_passthrough() -> None:
|
||||
out = _coerce_scalar_filter_value(date(2025, 1, 2), _dim(pa.date32()))
|
||||
assert out == date(2025, 1, 2)
|
||||
|
||||
|
||||
def test_coerce_date_from_iso_string() -> None:
|
||||
out = _coerce_scalar_filter_value(" 2025-01-02 ", _dim(pa.date32()))
|
||||
assert out == date(2025, 1, 2)
|
||||
|
||||
|
||||
def test_coerce_date_invalid_string_raises() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid date value"):
|
||||
_coerce_scalar_filter_value("not-a-date", _dim(pa.date32()))
|
||||
|
||||
|
||||
def test_coerce_date_rejects_other_types() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid date value"):
|
||||
_coerce_scalar_filter_value(20250102, _dim(pa.date32()))
|
||||
|
||||
|
||||
def test_coerce_timestamp_from_datetime_passthrough() -> None:
|
||||
dt = datetime(2025, 1, 2, 3, 4, 5)
|
||||
# Naive dtype: returned as-is, still naive.
|
||||
assert _coerce_scalar_filter_value(dt, _dim(pa.timestamp("us"))) == dt
|
||||
|
||||
|
||||
def test_coerce_timestamp_from_date() -> None:
|
||||
out = _coerce_scalar_filter_value(date(2025, 1, 2), _dim(pa.timestamp("us")))
|
||||
assert out == datetime(2025, 1, 2, 0, 0)
|
||||
|
||||
|
||||
def test_coerce_timestamp_from_iso_string_with_z() -> None:
|
||||
out = _coerce_scalar_filter_value("2025-01-02T03:04:05Z", _dim(pa.timestamp("us")))
|
||||
assert out == datetime.fromisoformat("2025-01-02T03:04:05+00:00")
|
||||
|
||||
|
||||
def test_coerce_timestamp_invalid_string_raises() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid timestamp value"):
|
||||
_coerce_scalar_filter_value("not-a-ts", _dim(pa.timestamp("us")))
|
||||
|
||||
|
||||
def test_coerce_timestamp_rejects_other_types() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid timestamp value"):
|
||||
_coerce_scalar_filter_value(1234567890, _dim(pa.timestamp("us")))
|
||||
|
||||
|
||||
def test_coerce_timestamp_tz_aware_dtype_attaches_tz_to_naive_datetime() -> None:
|
||||
dt = datetime(2025, 1, 2, 3, 4, 5)
|
||||
out = _coerce_scalar_filter_value(dt, _dim(pa.timestamp("us", tz="UTC")))
|
||||
assert out == datetime(2025, 1, 2, 3, 4, 5, tzinfo=ZoneInfo("UTC"))
|
||||
|
||||
|
||||
def test_coerce_timestamp_tz_aware_dtype_converts_aware_datetime() -> None:
|
||||
dt = datetime(2025, 1, 2, 12, 0, tzinfo=timezone.utc)
|
||||
out = _coerce_scalar_filter_value(
|
||||
dt, _dim(pa.timestamp("us", tz="America/New_York"))
|
||||
)
|
||||
# 12:00 UTC == 07:00 in New York
|
||||
assert out == datetime(2025, 1, 2, 7, 0, tzinfo=ZoneInfo("America/New_York"))
|
||||
|
||||
|
||||
def test_coerce_timestamp_tz_aware_dtype_attaches_tz_to_date() -> None:
|
||||
out = _coerce_scalar_filter_value(
|
||||
date(2025, 1, 2), _dim(pa.timestamp("us", tz="UTC"))
|
||||
)
|
||||
assert out == datetime(2025, 1, 2, 0, 0, tzinfo=ZoneInfo("UTC"))
|
||||
|
||||
|
||||
def test_coerce_timestamp_tz_aware_dtype_parses_string_with_tz() -> None:
|
||||
out = _coerce_scalar_filter_value(
|
||||
"2025-01-02T03:04:05", _dim(pa.timestamp("us", tz="UTC"))
|
||||
)
|
||||
# Naive string gets UTC attached.
|
||||
assert out == datetime(2025, 1, 2, 3, 4, 5, tzinfo=ZoneInfo("UTC"))
|
||||
|
||||
|
||||
def test_coerce_time_passthrough() -> None:
|
||||
out = _coerce_scalar_filter_value(time(3, 4, 5), _dim(pa.time64("us")))
|
||||
assert out == time(3, 4, 5)
|
||||
|
||||
|
||||
def test_coerce_time_from_iso_string() -> None:
|
||||
out = _coerce_scalar_filter_value(" 03:04:05 ", _dim(pa.time64("us")))
|
||||
assert out == time(3, 4, 5)
|
||||
|
||||
|
||||
def test_coerce_time_invalid_string_raises() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid time value"):
|
||||
_coerce_scalar_filter_value("not-a-time", _dim(pa.time64("us")))
|
||||
|
||||
|
||||
def test_coerce_time_rejects_other_types() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid time value"):
|
||||
_coerce_scalar_filter_value(123, _dim(pa.time64("us")))
|
||||
|
||||
@@ -811,3 +811,90 @@ class TestWebDriverPlaywrightErrorHandling:
|
||||
mock_logger.exception.assert_any_call(
|
||||
"Timed out requesting url %s", "http://example.com"
|
||||
)
|
||||
|
||||
@patch("superset.utils.webdriver.PLAYWRIGHT_AVAILABLE", True)
|
||||
@patch("superset.utils.webdriver.sync_playwright")
|
||||
@patch("superset.utils.webdriver.logger")
|
||||
def test_missing_element_for_dashboard_height_falls_back_without_crashing(
|
||||
self, mock_logger, mock_sync_playwright
|
||||
):
|
||||
"""Missing dashboard element should not crash height evaluation."""
|
||||
mock_user = MagicMock()
|
||||
mock_user.username = "test_user"
|
||||
|
||||
mock_playwright_instance = MagicMock()
|
||||
mock_browser = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_element = MagicMock()
|
||||
mock_chart_container = MagicMock()
|
||||
|
||||
mock_sync_playwright.return_value.__enter__.return_value = (
|
||||
mock_playwright_instance
|
||||
)
|
||||
mock_playwright_instance.chromium.launch.return_value = mock_browser
|
||||
mock_browser.new_context.return_value = mock_context
|
||||
mock_context.new_page.return_value = mock_page
|
||||
|
||||
def locator_side_effect(selector):
|
||||
if selector == ".dashboard":
|
||||
return mock_element
|
||||
if selector == ".chart-container":
|
||||
locator = MagicMock()
|
||||
locator.all.return_value = [mock_chart_container]
|
||||
return locator
|
||||
if selector == ".loading":
|
||||
locator = MagicMock()
|
||||
locator.all.return_value = []
|
||||
return locator
|
||||
return MagicMock()
|
||||
|
||||
mock_page.locator.side_effect = locator_side_effect
|
||||
mock_element.wait_for.return_value = None
|
||||
mock_element.screenshot.return_value = b"fake_screenshot"
|
||||
mock_chart_container.wait_for.return_value = None
|
||||
mock_page.wait_for_timeout.return_value = None
|
||||
|
||||
def evaluate_side_effect(script):
|
||||
if script == 'document.querySelectorAll(".chart-container").length':
|
||||
return 1
|
||||
if "const target = document.querySelector" in script:
|
||||
return 0
|
||||
return None
|
||||
|
||||
mock_page.evaluate.side_effect = evaluate_side_effect
|
||||
|
||||
with patch("superset.utils.webdriver.app") as mock_app:
|
||||
mock_app.config = {
|
||||
"WEBDRIVER_OPTION_ARGS": [],
|
||||
"WEBDRIVER_WINDOW": {"pixel_density": 1},
|
||||
"SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT": 30000,
|
||||
"SCREENSHOT_PLAYWRIGHT_WAIT_EVENT": "networkidle",
|
||||
"SCREENSHOT_SELENIUM_HEADSTART": 5,
|
||||
"SCREENSHOT_SELENIUM_ANIMATION_WAIT": 1,
|
||||
"SCREENSHOT_LOCATE_WAIT": 10,
|
||||
"SCREENSHOT_LOAD_WAIT": 10,
|
||||
"SCREENSHOT_WAIT_FOR_ERROR_MODAL_VISIBLE": 10,
|
||||
"SCREENSHOT_WAIT_FOR_ERROR_MODAL_INVISIBLE": 10,
|
||||
"SCREENSHOT_REPLACE_UNEXPECTED_ERRORS": False,
|
||||
"SCREENSHOT_TILED_ENABLED": True,
|
||||
"SCREENSHOT_TILED_CHART_THRESHOLD": 20,
|
||||
"SCREENSHOT_TILED_HEIGHT_THRESHOLD": 5000,
|
||||
"SCREENSHOT_TILED_VIEWPORT_HEIGHT": 600,
|
||||
}
|
||||
|
||||
with patch.object(WebDriverPlaywright, "auth") as mock_auth:
|
||||
mock_auth.return_value = mock_context
|
||||
|
||||
driver = WebDriverPlaywright("chrome")
|
||||
result = driver.get_screenshot(
|
||||
"http://example.com", "dashboard", mock_user
|
||||
)
|
||||
|
||||
assert result == b"fake_screenshot"
|
||||
mock_logger.warning.assert_any_call(
|
||||
"Could not determine dashboard height for element %s at url %s; "
|
||||
"falling back to standard screenshot behavior",
|
||||
"dashboard",
|
||||
"http://example.com",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user