From e41653fafff122d482d78af0436d6920713ce131 Mon Sep 17 00:00:00 2001 From: Kamil Gabryjelski Date: Fri, 13 Mar 2026 11:58:12 +0100 Subject: [PATCH] fix(mcp): Support form_data_key without chart identifier for unsaved charts (#38628) (cherry picked from commit af5e05db2e4cdbc3864a04fe432d221898301c71) --- superset/mcp_service/chart/schemas.py | 86 +++++- .../mcp_service/chart/tool/get_chart_data.py | 247 +++++++++++++++--- .../mcp_service/chart/tool/get_chart_info.py | 111 ++++++-- .../chart/tool/get_chart_preview.py | 112 +++++++- 4 files changed, 481 insertions(+), 75 deletions(-) diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 10b694d22c2..ac40d1f1662 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -111,6 +111,14 @@ class ChartInfo(BaseModel): owners: List[UserInfo] = Field(default_factory=list, description="Chart owners") # Fields for unsaved state support + form_data: Dict[str, Any] | None = Field( + None, + description=( + "The chart's form_data configuration. When form_data_key is provided, " + "this contains the unsaved (cached) configuration rather than the " + "saved version." + ), + ) form_data_key: str | None = Field( None, description=( @@ -218,27 +226,45 @@ class VersionedResponse(BaseModel): class GetChartInfoRequest(BaseModel): - """Request schema for get_chart_info with support for ID or UUID. + """Request schema for get_chart_info with support for ID, UUID, or form_data_key. When form_data_key is provided, the tool will retrieve the unsaved chart state from cache, allowing you to explain what the user actually sees (not the saved version). This is useful when a user edits a chart in Explore but hasn't saved yet. + + For unsaved charts (no chart ID), provide only form_data_key to retrieve the + current chart configuration from cache. """ identifier: Annotated[ - int | str, - Field(description="Chart identifier - can be numeric ID or UUID string"), + int | str | None, + Field( + default=None, + description=( + "Chart identifier - can be numeric ID or UUID string. " + "Optional when form_data_key is provided (for unsaved charts)." + ), + ), ] form_data_key: str | None = Field( default=None, description=( - "Optional cache key for retrieving unsaved chart state. When a user " + "Cache key for retrieving unsaved chart state. When a user " "edits a chart in Explore but hasn't saved, the current state is stored " "with this key. If provided, the tool returns the current unsaved " - "configuration instead of the saved version." + "configuration instead of the saved version. " + "Can be used alone (without identifier) for unsaved charts." ), ) + @model_validator(mode="after") + def validate_identifier_or_form_data_key(self) -> "GetChartInfoRequest": + if not self.identifier and not self.form_data_key: + raise ValueError( + "At least one of 'identifier' or 'form_data_key' must be provided." + ) + return self + def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None: if not chart: @@ -1010,18 +1036,37 @@ class GetChartDataRequest(QueryCacheControl): from cache to query data, allowing you to get data for what the user actually sees (not the saved version). This is useful when a user edits a chart in Explore but hasn't saved yet. + + For unsaved charts (no chart ID), provide only form_data_key to query data using + the current chart configuration from cache. """ - identifier: int | str = Field(description="Chart identifier (ID, UUID)") + identifier: int | str | None = Field( + default=None, + description=( + "Chart identifier (ID, UUID). " + "Optional when form_data_key is provided (for unsaved charts)." + ), + ) form_data_key: str | None = Field( default=None, description=( - "Optional cache key for retrieving unsaved chart state. When a user " + "Cache key for retrieving unsaved chart state. When a user " "edits a chart in Explore but hasn't saved, the current state is stored " "with this key. If provided, the tool uses this configuration to query " - "data instead of the saved chart configuration." + "data instead of the saved chart configuration. " + "Can be used alone (without identifier) for unsaved charts." ), ) + + @model_validator(mode="after") + def validate_identifier_or_form_data_key(self) -> "GetChartDataRequest": + if not self.identifier and not self.form_data_key: + raise ValueError( + "At least one of 'identifier' or 'form_data_key' must be provided." + ) + return self + limit: int | None = Field( default=None, description=( @@ -1103,18 +1148,37 @@ class GetChartPreviewRequest(QueryCacheControl): chart configuration from cache, allowing you to preview what the user actually sees (not the saved version). This is useful when a user edits a chart in Explore but hasn't saved yet. + + For unsaved charts (no chart ID), provide only form_data_key to render a preview + using the current chart configuration from cache. """ - identifier: int | str = Field(description="Chart identifier (ID, UUID)") + identifier: int | str | None = Field( + default=None, + description=( + "Chart identifier (ID, UUID). " + "Optional when form_data_key is provided (for unsaved charts)." + ), + ) form_data_key: str | None = Field( default=None, description=( - "Optional cache key for retrieving unsaved chart state. When a user " + "Cache key for retrieving unsaved chart state. When a user " "edits a chart in Explore but hasn't saved, the current state is stored " "with this key. If provided, the tool renders a preview using this " - "configuration instead of the saved chart configuration." + "configuration instead of the saved chart configuration. " + "Can be used alone (without identifier) for unsaved charts." ), ) + + @model_validator(mode="after") + def validate_identifier_or_form_data_key(self) -> "GetChartPreviewRequest": + if not self.identifier and not self.form_data_key: + raise ValueError( + "At least one of 'identifier' or 'form_data_key' must be provided." + ) + return self + format: Literal["url", "ascii", "table", "vega_lite"] = Field( default="url", description=( diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index 9239a278d9e..c2cf960dbeb 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -109,9 +109,42 @@ async def get_chart_data( # noqa: C901 from superset.daos.chart import ChartDAO from superset.utils import json as utils_json - # Find the chart + chart = None + + # Handle unsaved chart (form_data_key only, no identifier) + if not request.identifier and request.form_data_key: + with event_logger.log_context( + action="mcp.get_chart_data.unsaved_chart_from_cache" + ): + await ctx.info( + "No chart identifier - querying data from unsaved chart cache: " + "form_data_key=%s" % (request.form_data_key,) + ) + cached_form_data = _get_cached_form_data(request.form_data_key) + if not cached_form_data: + return ChartError( + error="No cached chart data found for form_data_key. " + "The cache may have expired.", + error_type="NotFound", + ) + try: + cached_form_data_dict = utils_json.loads(cached_form_data) + except (TypeError, ValueError) as e: + return ChartError( + error=f"Failed to parse cached form_data: {e}", + error_type="ParseError", + ) + if not isinstance(cached_form_data_dict, dict): + return ChartError( + error="Cached form_data is not a valid JSON object.", + error_type="ParseError", + ) + + # Build query context entirely from cached form_data + return await _query_from_form_data(cached_form_data_dict, request, ctx) + + # Find the chart by identifier with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"): - chart = None if isinstance(request.identifier, int) or ( isinstance(request.identifier, str) and request.identifier.isdigit() ): @@ -124,7 +157,7 @@ async def get_chart_data( # noqa: C901 "Performing ID-based chart lookup: chart_id=%s" % (chart_id,) ) chart = ChartDAO.find_by_id(chart_id) - else: + elif isinstance(request.identifier, str): await ctx.debug( "Performing UUID-based chart lookup: uuid=%s" % (request.identifier,) @@ -178,36 +211,40 @@ async def get_chart_data( # noqa: C901 # Check if form_data_key is provided - use cached form_data instead if request.form_data_key: - await ctx.info( - "Retrieving unsaved chart state from cache: form_data_key=%s" - % (request.form_data_key,) - ) - if cached_form_data := _get_cached_form_data(request.form_data_key): - try: - parsed_form_data = utils_json.loads(cached_form_data) - # Only use if it's actually a dict (not null, list, etc.) - if isinstance(parsed_form_data, dict): - cached_form_data_dict = parsed_form_data - using_unsaved_state = True - await ctx.info( - "Using cached form_data from form_data_key " - "for data query" - ) - else: - await ctx.warning( - "Cached form_data is not a JSON object. " - "Falling back to saved chart configuration." - ) - except (TypeError, ValueError) as e: - await ctx.warning( - "Failed to parse cached form_data: %s. " - "Falling back to saved chart configuration." % str(e) - ) - else: - await ctx.warning( - "form_data_key provided but no cached data found. " - "The cache may have expired. Using saved chart configuration." + with event_logger.log_context( + action="mcp.get_chart_data.unsaved_state_override" + ): + await ctx.info( + "Retrieving unsaved chart state from cache: form_data_key=%s" + % (request.form_data_key,) ) + if cached_form_data := _get_cached_form_data(request.form_data_key): + try: + parsed_form_data = utils_json.loads(cached_form_data) + # Only use if it's actually a dict (not null, list, etc.) + if isinstance(parsed_form_data, dict): + cached_form_data_dict = parsed_form_data + using_unsaved_state = True + await ctx.info( + "Using cached form_data from form_data_key " + "for data query" + ) + else: + await ctx.warning( + "Cached form_data is not a JSON object. " + "Falling back to saved chart configuration." + ) + except (TypeError, ValueError) as e: + await ctx.warning( + "Failed to parse cached form_data: %s. " + "Falling back to saved chart configuration." % str(e) + ) + else: + await ctx.warning( + "form_data_key provided but no cached data found. " + "The cache may have expired. Using saved chart " + "configuration." + ) # Use the chart's saved query_context - this is the key! # The query_context contains all the information needed to reproduce @@ -707,6 +744,152 @@ async def get_chart_data( # noqa: C901 ) +async def _query_from_form_data( + form_data: Dict[str, Any], + request: GetChartDataRequest, + ctx: Context, +) -> ChartData | ChartError: + """Query chart data using only cached form_data (no saved chart). + + Used for unsaved charts where we only have form_data_key. + """ + from superset.commands.chart.data.get_data_command import ChartDataCommand + from superset.common.query_context_factory import QueryContextFactory + + datasource_id = form_data.get("datasource_id") + datasource_type: str = form_data.get("datasource_type") or "table" + + # Handle combined datasource field (e.g., "1__table") + if not datasource_id and form_data.get("datasource"): + parts = str(form_data["datasource"]).split("__") + if len(parts) == 2: + datasource_id, datasource_type = parts[0], parts[1] + + if not datasource_id: + return ChartError( + error="Cached form_data does not contain datasource information.", + error_type="InvalidFormData", + ) + + viz_type = form_data.get("viz_type", "unknown") + row_limit = ( + request.limit or form_data.get("row_limit") or current_app.config["ROW_LIMIT"] + ) + + # Extract metrics and groupby based on chart type + if viz_type in ("big_number", "big_number_total", "pop_kpi"): + metric = form_data.get("metric") + metrics = [metric] if metric else [] + groupby: list[str] = [] + else: + metrics = form_data.get("metrics", []) + groupby = list(form_data.get("groupby") or []) + + try: + factory = QueryContextFactory() + query_context = factory.create( + datasource={"id": datasource_id, "type": datasource_type}, + queries=[ + { + "filters": form_data.get("filters", []), + "columns": groupby, + "metrics": metrics, + "row_limit": row_limit, + "order_desc": form_data.get("order_desc", True), + } + ], + form_data=form_data, + force=request.force_refresh, + ) + + await ctx.report_progress(3, 4, "Executing data query") + with event_logger.log_context(action="mcp.get_chart_data.query_execution"): + command = ChartDataCommand(query_context) + command.validate() + result = command.run() + + if not result or "queries" not in result or len(result["queries"]) == 0: + return ChartError( + error="No query results returned for unsaved chart.", + error_type="EmptyQuery", + ) + + query_result = result["queries"][0] + data = query_result.get("data", []) + raw_columns = query_result.get("colnames", []) + + if not data: + return ChartError( + error="No data available for unsaved chart.", + error_type="NoData", + ) + + columns = [] + for col_name in raw_columns: + sample_values = [ + row.get(col_name) for row in data[:3] if row.get(col_name) is not None + ] + data_type = "string" + if sample_values and all( + isinstance(v, (int, float)) for v in sample_values + ): + data_type = "numeric" + columns.append( + DataColumn( + name=col_name, + display_name=col_name.replace("_", " ").title(), + data_type=data_type, + sample_values=sample_values[:3], + null_count=sum(1 for row in data if row.get(col_name) is None), + unique_count=len({str(row.get(col_name)) for row in data}), + ) + ) + + cache_status = get_cache_status_from_result( + query_result, force_refresh=request.force_refresh + ) + + chart_name = form_data.get("slice_name", "Unsaved chart") + summary = ( + f"Unsaved chart ({viz_type}). " + f"Contains {len(data)} rows across {len(raw_columns)} columns." + ) + + await ctx.report_progress(4, 4, "Building response") + return ChartData( + chart_id=0, + chart_name=chart_name, + chart_type=viz_type, + columns=columns, + data=data[: request.limit] if request.limit else data, + row_count=len(data), + total_rows=query_result.get("rowcount"), + summary=summary, + insights=["This is an unsaved chart queried from cached form_data."], + data_quality={ + "completeness": 1.0 + - ( + sum(col.null_count for col in columns) + / max(len(data) * len(columns), 1) + ) + }, + recommended_visualizations=[], + data_freshness=None, + performance=PerformanceMetadata( + query_duration_ms=0, + cache_status="fresh_query", + ), + cache_status=cache_status, + ) + + except (CommandException, SupersetException, ValueError) as e: + logger.error("Error querying unsaved chart data: %s", e) + return ChartError( + error=f"Error querying unsaved chart data: {e}", + error_type="DataError", + ) + + def _export_data_as_csv( chart: "Slice", data: List[Dict[str, Any]], diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py index 95d0af23ff3..7029de2ffa6 100644 --- a/superset/mcp_service/chart/tool/get_chart_info.py +++ b/superset/mcp_service/chart/tool/get_chart_info.py @@ -56,6 +56,65 @@ def _get_cached_form_data(form_data_key: str) -> str | None: return None +def _build_unsaved_chart_info(form_data_key: str) -> ChartInfo | ChartError: + """Build a ChartInfo from cached form_data when no chart identifier exists.""" + from superset.utils import json as utils_json + + cached_form_data = _get_cached_form_data(form_data_key) + if not cached_form_data: + return ChartError( + error="No cached chart data found for form_data_key. " + "The cache may have expired.", + error_type="NotFound", + ) + try: + form_data = utils_json.loads(cached_form_data) + except (TypeError, ValueError) as e: + return ChartError( + error=f"Failed to parse cached form_data: {e}", + error_type="ParseError", + ) + if not isinstance(form_data, dict): + return ChartError( + error="Cached form_data is not a valid JSON object.", + error_type="ParseError", + ) + return ChartInfo( + viz_type=form_data.get("viz_type"), + datasource_name=form_data.get("datasource_name"), + datasource_type=form_data.get("datasource_type"), + form_data=form_data, + form_data_key=form_data_key, + is_unsaved_state=True, + ) + + +def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None: + """Override a ChartInfo's form_data with cached unsaved state.""" + from superset.utils import json as utils_json + + if cached_form_data := _get_cached_form_data(form_data_key): + try: + result.form_data = utils_json.loads(cached_form_data) + result.form_data_key = form_data_key + result.is_unsaved_state = True + + # Update viz_type from cached form_data if present + if result.form_data and "viz_type" in result.form_data: + result.viz_type = result.form_data["viz_type"] + except (TypeError, ValueError) as e: + logger.warning( + "Failed to parse cached form_data: %s. " + "Using saved chart configuration.", + e, + ) + else: + logger.warning( + "form_data_key provided but no cached data found. " + "The cache may have expired. Using saved chart configuration." + ) + + @tool(tags=["discovery"]) @parse_request(GetChartInfoRequest) async def get_chart_info( @@ -96,13 +155,28 @@ async def get_chart_info( """ from superset.daos.chart import ChartDAO from superset.models.slice import Slice - from superset.utils import json as utils_json await ctx.info( "Retrieving chart information: identifier=%s, form_data_key=%s" % (request.identifier, request.form_data_key) ) + # Handle unsaved chart (form_data_key only, no identifier) + if not request.identifier and request.form_data_key: + with event_logger.log_context( + action="mcp.get_chart_info.unsaved_chart_from_cache" + ): + await ctx.info( + "No chart identifier provided - retrieving unsaved chart from cache: " + "form_data_key=%s" % (request.form_data_key,) + ) + return _build_unsaved_chart_info(request.form_data_key) + + # At this point identifier must be set (validator ensures at least one + # of identifier/form_data_key is provided, and the form_data_key-only + # branch returned above). + assert request.identifier is not None + # Eager load owners and tags to avoid N+1 queries during serialization eager_options = [ subqueryload(Slice.owners), @@ -125,35 +199,14 @@ async def get_chart_info( if isinstance(result, ChartInfo): # If form_data_key is provided, override form_data with cached version if request.form_data_key: - await ctx.info( - "Retrieving unsaved chart state from cache: form_data_key=%s" - % (request.form_data_key,) - ) - cached_form_data = _get_cached_form_data(request.form_data_key) - - if cached_form_data: - try: - result.form_data = utils_json.loads(cached_form_data) - result.form_data_key = request.form_data_key - result.is_unsaved_state = True - - # Update viz_type from cached form_data if present - if result.form_data and "viz_type" in result.form_data: - result.viz_type = result.form_data["viz_type"] - - await ctx.info( - "Chart form_data overridden with unsaved state from cache" - ) - except (TypeError, ValueError) as e: - await ctx.warning( - "Failed to parse cached form_data: %s. " - "Using saved chart configuration." % (str(e),) - ) - else: - await ctx.warning( - "form_data_key provided but no cached data found. " - "The cache may have expired. Using saved chart configuration." + with event_logger.log_context( + action="mcp.get_chart_info.unsaved_state_override" + ): + await ctx.info( + "Retrieving unsaved chart state from cache: form_data_key=%s" + % (request.form_data_key,) ) + _apply_unsaved_state_override(result, request.form_data_key) await ctx.info( "Chart information retrieved successfully: chart_name=%s, " diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index 6ee3ece4df4..adb7c1c15ac 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -1837,7 +1837,71 @@ async def _get_chart_preview_internal( # noqa: C901 # Find the chart with event_logger.log_context(action="mcp.get_chart_preview.chart_lookup"): chart: Any = None - if isinstance(request.identifier, int) or ( + + # Handle unsaved chart (form_data_key only, no identifier) + if not request.identifier and request.form_data_key: + with event_logger.log_context( + action="mcp.get_chart_preview.unsaved_chart_from_cache" + ): + await ctx.info( + "No chart identifier - creating transient chart from " + "form_data_key=%s" % (request.form_data_key,) + ) + from superset.commands.explore.form_data.get import ( + GetFormDataCommand, + ) + from superset.commands.explore.form_data.parameters import ( + CommandParameters, + ) + from superset.utils import json as utils_json + + try: + cmd_params = CommandParameters(key=request.form_data_key) + form_data_json = GetFormDataCommand(cmd_params).run() + if form_data_json: + form_data = utils_json.loads(form_data_json) + + class TransientChartFromKey: + def __init__(self, fd: Dict[str, Any]): + self.id = 0 + self.slice_name = "Unsaved Chart Preview" + self.viz_type = fd.get("viz_type", "table") + ds = fd.get("datasource", "") + parts = str(ds).split("__") if ds else [] + self.datasource_id = ( + int(parts[0]) + if len(parts) == 2 + else fd.get("datasource_id") + ) + self.datasource_type = ( + parts[1] + if len(parts) == 2 + else fd.get("datasource_type", "table") + ) + self.params = utils_json.dumps(fd) + self.form_data = fd + self.uuid = None + + chart = TransientChartFromKey(form_data) + except ( + CommandException, + ValueError, + KeyError, + AttributeError, + TypeError, + ) as e: + logger.warning( + "Failed to get form data for key %s: %s", + request.form_data_key, + e, + ) + return ChartError( + error="No cached chart data found for form_data_key. " + "The cache may have expired.", + error_type="NotFound", + ) + + elif isinstance(request.identifier, int) or ( isinstance(request.identifier, str) and request.identifier.isdigit() ): chart_id = ( @@ -1849,7 +1913,7 @@ async def _get_chart_preview_internal( # noqa: C901 "Performing ID-based chart lookup: chart_id=%s" % (chart_id,) ) chart = ChartDAO.find_by_id(chart_id) - else: + elif isinstance(request.identifier, str): await ctx.debug( "Performing UUID-based chart lookup: uuid=%s" % (request.identifier,) @@ -1883,7 +1947,7 @@ async def _get_chart_preview_internal( # noqa: C901 # Create a transient chart object from form data class TransientChart: def __init__(self, form_data: Dict[str, Any]): - self.id = None + self.id = 0 self.slice_name = "Unsaved Chart Preview" self.viz_type = form_data.get("viz_type", "table") self.datasource_id = None @@ -1951,6 +2015,48 @@ async def _get_chart_preview_internal( # noqa: C901 for warning in validation_result.warnings: await ctx.warning("Dataset warning: %s" % (warning,)) + # If form_data_key is provided, override chart.params with cached + # form_data so the preview reflects what the user actually sees + if request.form_data_key and getattr(chart, "id", None) is not None: + with event_logger.log_context( + action="mcp.get_chart_preview.unsaved_state_override" + ): + await ctx.info( + "Retrieving unsaved chart state from cache: form_data_key=%s" + % (request.form_data_key,) + ) + from superset.commands.explore.form_data.get import ( + GetFormDataCommand, + ) + from superset.commands.explore.form_data.parameters import ( + CommandParameters, + ) + + try: + cmd_params = CommandParameters(key=request.form_data_key) + cached_form_data = GetFormDataCommand(cmd_params).run() + if cached_form_data: + chart.params = cached_form_data + from superset.utils import json as utils_json + + parsed = utils_json.loads(cached_form_data) + if isinstance(parsed, dict) and "viz_type" in parsed: + chart.viz_type = parsed["viz_type"] + await ctx.info( + "Chart params overridden with unsaved state from cache" + ) + else: + await ctx.warning( + "form_data_key provided but no cached data found. " + "The cache may have expired. Using saved chart " + "configuration." + ) + except (CommandException, ValueError, KeyError) as e: + await ctx.warning( + "Failed to retrieve cached form_data: %s. " + "Using saved chart configuration." % (str(e),) + ) + import time start_time = time.time()