diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index ea5459cfbc4..155d455f372 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -1084,6 +1084,14 @@ class GetChartDataRequest(QueryCacheControl): "chart's configured row limit." ), ) + extra_form_data: dict[str, Any] | None = Field( + default=None, + description=( + "Extra form data to merge into the chart query, typically from " + "dashboard native filters. Format: " + '{"filters": [{"col": "country", "op": "IN", "val": ["US"]}]}' + ), + ) format: Literal["json", "csv", "excel"] = Field( default="json", description="Data export format" ) diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index ed7a4e20c33..b6edb257f08 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -44,10 +44,21 @@ from superset.mcp_service.chart.schemas import ( ) from superset.mcp_service.utils.cache_utils import get_cache_status_from_result from superset.mcp_service.utils.schema_utils import parse_request +from superset.utils.core import merge_extra_filters logger = logging.getLogger(__name__) +def _apply_extra_form_data( + form_data: dict[str, Any], extra_form_data: dict[str, Any] | None +) -> None: + """Merge dashboard native filters into chart form_data in-place.""" + if not extra_form_data: + return + form_data["extra_form_data"] = extra_form_data + merge_extra_filters(form_data) + + def _get_cached_form_data(form_data_key: str) -> str | None: """Retrieve form_data from cache using form_data_key. @@ -286,20 +297,26 @@ async def get_chart_data( # noqa: C901 cached_metrics = cached_form_data_dict.get("metrics", []) cached_groupby = cached_form_data_dict.get("groupby", []) + _apply_extra_form_data(cached_form_data_dict, request.extra_form_data) + + cached_query: dict[str, Any] = { + "filters": cached_form_data_dict.get("filters", []), + "columns": cached_groupby, + "metrics": cached_metrics, + "row_limit": row_limit, + "order_desc": cached_form_data_dict.get("order_desc", True), + } + # Include adhoc_filters so dashboard native filters are applied + cached_adhoc = cached_form_data_dict.get("adhoc_filters") + if cached_adhoc: + cached_query["adhoc_filters"] = cached_adhoc + query_context = factory.create( datasource={ "id": datasource_id, "type": datasource_type, }, - queries=[ - { - "filters": cached_form_data_dict.get("filters", []), - "columns": cached_groupby, - "metrics": cached_metrics, - "row_limit": row_limit, - "order_desc": cached_form_data_dict.get("order_desc", True), - } - ], + queries=[cached_query], form_data=cached_form_data_dict, force=request.force_refresh, ) @@ -454,20 +471,26 @@ async def get_chart_data( # noqa: C901 error_type="MissingQueryContext", ) + _apply_extra_form_data(form_data, request.extra_form_data) + + fallback_query: dict[str, Any] = { + "filters": form_data.get("filters", []), + "columns": query_columns, + "metrics": metrics, + "row_limit": row_limit, + "order_desc": True, + } + # Include adhoc_filters so dashboard native filters are applied + fallback_adhoc = form_data.get("adhoc_filters") + if fallback_adhoc: + fallback_query["adhoc_filters"] = fallback_adhoc + query_context = factory.create( datasource={ "id": chart.datasource_id, "type": chart.datasource_type, }, - queries=[ - { - "filters": form_data.get("filters", []), - "columns": query_columns, - "metrics": metrics, - "row_limit": row_limit, - "order_desc": True, - } - ], + queries=[fallback_query], form_data=form_data, force=request.force_refresh, ) @@ -480,6 +503,10 @@ async def get_chart_data( # noqa: C901 for query in query_context_json.get("queries", []): query["row_limit"] = request.limit + # Merge dashboard native filters into query_context's form_data + qc_form_data = query_context_json.setdefault("form_data", {}) + _apply_extra_form_data(qc_form_data, request.extra_form_data) + # Create QueryContext from the saved context using the schema # This is exactly how the API does it query_context = ChartDataQueryContextSchema().load(query_context_json)