fix(mcp): Support form_data_key without chart identifier for unsaved charts (#38628)

(cherry picked from commit af5e05db2e)
This commit is contained in:
Kamil Gabryjelski
2026-03-13 11:58:12 +01:00
committed by Michael S. Molina
parent 39bdccd7ee
commit e41653faff
4 changed files with 481 additions and 75 deletions

View File

@@ -109,9 +109,42 @@ async def get_chart_data( # noqa: C901
from superset.daos.chart import ChartDAO
from superset.utils import json as utils_json
# Find the chart
chart = None
# Handle unsaved chart (form_data_key only, no identifier)
if not request.identifier and request.form_data_key:
with event_logger.log_context(
action="mcp.get_chart_data.unsaved_chart_from_cache"
):
await ctx.info(
"No chart identifier - querying data from unsaved chart cache: "
"form_data_key=%s" % (request.form_data_key,)
)
cached_form_data = _get_cached_form_data(request.form_data_key)
if not cached_form_data:
return ChartError(
error="No cached chart data found for form_data_key. "
"The cache may have expired.",
error_type="NotFound",
)
try:
cached_form_data_dict = utils_json.loads(cached_form_data)
except (TypeError, ValueError) as e:
return ChartError(
error=f"Failed to parse cached form_data: {e}",
error_type="ParseError",
)
if not isinstance(cached_form_data_dict, dict):
return ChartError(
error="Cached form_data is not a valid JSON object.",
error_type="ParseError",
)
# Build query context entirely from cached form_data
return await _query_from_form_data(cached_form_data_dict, request, ctx)
# Find the chart by identifier
with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"):
chart = None
if isinstance(request.identifier, int) or (
isinstance(request.identifier, str) and request.identifier.isdigit()
):
@@ -124,7 +157,7 @@ async def get_chart_data( # noqa: C901
"Performing ID-based chart lookup: chart_id=%s" % (chart_id,)
)
chart = ChartDAO.find_by_id(chart_id)
else:
elif isinstance(request.identifier, str):
await ctx.debug(
"Performing UUID-based chart lookup: uuid=%s"
% (request.identifier,)
@@ -178,36 +211,40 @@ async def get_chart_data( # noqa: C901
# Check if form_data_key is provided - use cached form_data instead
if request.form_data_key:
await ctx.info(
"Retrieving unsaved chart state from cache: form_data_key=%s"
% (request.form_data_key,)
)
if cached_form_data := _get_cached_form_data(request.form_data_key):
try:
parsed_form_data = utils_json.loads(cached_form_data)
# Only use if it's actually a dict (not null, list, etc.)
if isinstance(parsed_form_data, dict):
cached_form_data_dict = parsed_form_data
using_unsaved_state = True
await ctx.info(
"Using cached form_data from form_data_key "
"for data query"
)
else:
await ctx.warning(
"Cached form_data is not a JSON object. "
"Falling back to saved chart configuration."
)
except (TypeError, ValueError) as e:
await ctx.warning(
"Failed to parse cached form_data: %s. "
"Falling back to saved chart configuration." % str(e)
)
else:
await ctx.warning(
"form_data_key provided but no cached data found. "
"The cache may have expired. Using saved chart configuration."
with event_logger.log_context(
action="mcp.get_chart_data.unsaved_state_override"
):
await ctx.info(
"Retrieving unsaved chart state from cache: form_data_key=%s"
% (request.form_data_key,)
)
if cached_form_data := _get_cached_form_data(request.form_data_key):
try:
parsed_form_data = utils_json.loads(cached_form_data)
# Only use if it's actually a dict (not null, list, etc.)
if isinstance(parsed_form_data, dict):
cached_form_data_dict = parsed_form_data
using_unsaved_state = True
await ctx.info(
"Using cached form_data from form_data_key "
"for data query"
)
else:
await ctx.warning(
"Cached form_data is not a JSON object. "
"Falling back to saved chart configuration."
)
except (TypeError, ValueError) as e:
await ctx.warning(
"Failed to parse cached form_data: %s. "
"Falling back to saved chart configuration." % str(e)
)
else:
await ctx.warning(
"form_data_key provided but no cached data found. "
"The cache may have expired. Using saved chart "
"configuration."
)
# Use the chart's saved query_context - this is the key!
# The query_context contains all the information needed to reproduce
@@ -707,6 +744,152 @@ async def get_chart_data( # noqa: C901
)
async def _query_from_form_data(
form_data: Dict[str, Any],
request: GetChartDataRequest,
ctx: Context,
) -> ChartData | ChartError:
"""Query chart data using only cached form_data (no saved chart).
Used for unsaved charts where we only have form_data_key.
"""
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.common.query_context_factory import QueryContextFactory
datasource_id = form_data.get("datasource_id")
datasource_type: str = form_data.get("datasource_type") or "table"
# Handle combined datasource field (e.g., "1__table")
if not datasource_id and form_data.get("datasource"):
parts = str(form_data["datasource"]).split("__")
if len(parts) == 2:
datasource_id, datasource_type = parts[0], parts[1]
if not datasource_id:
return ChartError(
error="Cached form_data does not contain datasource information.",
error_type="InvalidFormData",
)
viz_type = form_data.get("viz_type", "unknown")
row_limit = (
request.limit or form_data.get("row_limit") or current_app.config["ROW_LIMIT"]
)
# Extract metrics and groupby based on chart type
if viz_type in ("big_number", "big_number_total", "pop_kpi"):
metric = form_data.get("metric")
metrics = [metric] if metric else []
groupby: list[str] = []
else:
metrics = form_data.get("metrics", [])
groupby = list(form_data.get("groupby") or [])
try:
factory = QueryContextFactory()
query_context = factory.create(
datasource={"id": datasource_id, "type": datasource_type},
queries=[
{
"filters": form_data.get("filters", []),
"columns": groupby,
"metrics": metrics,
"row_limit": row_limit,
"order_desc": form_data.get("order_desc", True),
}
],
form_data=form_data,
force=request.force_refresh,
)
await ctx.report_progress(3, 4, "Executing data query")
with event_logger.log_context(action="mcp.get_chart_data.query_execution"):
command = ChartDataCommand(query_context)
command.validate()
result = command.run()
if not result or "queries" not in result or len(result["queries"]) == 0:
return ChartError(
error="No query results returned for unsaved chart.",
error_type="EmptyQuery",
)
query_result = result["queries"][0]
data = query_result.get("data", [])
raw_columns = query_result.get("colnames", [])
if not data:
return ChartError(
error="No data available for unsaved chart.",
error_type="NoData",
)
columns = []
for col_name in raw_columns:
sample_values = [
row.get(col_name) for row in data[:3] if row.get(col_name) is not None
]
data_type = "string"
if sample_values and all(
isinstance(v, (int, float)) for v in sample_values
):
data_type = "numeric"
columns.append(
DataColumn(
name=col_name,
display_name=col_name.replace("_", " ").title(),
data_type=data_type,
sample_values=sample_values[:3],
null_count=sum(1 for row in data if row.get(col_name) is None),
unique_count=len({str(row.get(col_name)) for row in data}),
)
)
cache_status = get_cache_status_from_result(
query_result, force_refresh=request.force_refresh
)
chart_name = form_data.get("slice_name", "Unsaved chart")
summary = (
f"Unsaved chart ({viz_type}). "
f"Contains {len(data)} rows across {len(raw_columns)} columns."
)
await ctx.report_progress(4, 4, "Building response")
return ChartData(
chart_id=0,
chart_name=chart_name,
chart_type=viz_type,
columns=columns,
data=data[: request.limit] if request.limit else data,
row_count=len(data),
total_rows=query_result.get("rowcount"),
summary=summary,
insights=["This is an unsaved chart queried from cached form_data."],
data_quality={
"completeness": 1.0
- (
sum(col.null_count for col in columns)
/ max(len(data) * len(columns), 1)
)
},
recommended_visualizations=[],
data_freshness=None,
performance=PerformanceMetadata(
query_duration_ms=0,
cache_status="fresh_query",
),
cache_status=cache_status,
)
except (CommandException, SupersetException, ValueError) as e:
logger.error("Error querying unsaved chart data: %s", e)
return ChartError(
error=f"Error querying unsaved chart data: {e}",
error_type="DataError",
)
def _export_data_as_csv(
chart: "Slice",
data: List[Dict[str, Any]],