diff --git a/superset-frontend/package-lock.json b/superset-frontend/package-lock.json index 6541092670e..3b6d100373d 100644 --- a/superset-frontend/package-lock.json +++ b/superset-frontend/package-lock.json @@ -40446,9 +40446,9 @@ "license": "ISC" }, "node_modules/protocol-buffers-schema": { - "version": "3.6.0", - "resolved": "https://registry.npmjs.org/protocol-buffers-schema/-/protocol-buffers-schema-3.6.0.tgz", - "integrity": "sha512-TdDRD+/QNdrCGCE7v8340QyuXd4kIWIgapsE2+n/SaGiSSbomYl4TjHlvIoCWRpE7wFt02EpB35VVA2ImcBVqw==", + "version": "3.6.1", + "resolved": "https://registry.npmjs.org/protocol-buffers-schema/-/protocol-buffers-schema-3.6.1.tgz", + "integrity": "sha512-VG2K63Igkiv9p76tk1lilczEK1cT+kCjKtkdhw1dQZV3k3IXJbd3o6Ho8b9zJZaHSnT2hKe4I+ObmX9w6m5SmQ==", "license": "MIT" }, "node_modules/protocols": { diff --git a/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ParallelCoordinates.ts b/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ParallelCoordinates.ts index 67ae1c2686c..844cd05ac06 100644 --- a/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ParallelCoordinates.ts +++ b/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ParallelCoordinates.ts @@ -47,7 +47,9 @@ interface ParallelCoordinatesProps { width: number; height: number; colorMetric: string; + defaultLineColor: string; includeSeries: boolean; + isDarkMode: boolean; linearColorScheme: string; metrics: string[]; series: string; @@ -63,7 +65,9 @@ function ParallelCoordinates( width, height, colorMetric, + defaultLineColor, includeSeries, + isDarkMode, linearColorScheme, metrics, series, @@ -87,9 +91,25 @@ function ParallelCoordinates( (d: Record) => d[colorMetric] as number, ), ) - : () => 'grey'; - const color = (d: Record) => - (colorScale as Function)(d[colorMetric]); + : null; + + const brightenForDarkMode = (colorStr: string): string => { + const hsl = d3.hsl(colorStr); + if (hsl.l < 0.5) { + hsl.l = Math.min(1, hsl.l + 0.4); + return hsl.toString(); + } + return colorStr; + }; + + const color = (d: Record): string => { + if (!colorScale) { + return defaultLineColor; + } + const baseColor = (colorScale as Function)(d[colorMetric]) as string; + return isDarkMode ? brightenForDarkMode(baseColor) : baseColor; + }; + const container = d3 .select(element) .classed('superset-legacy-chart-parallel-coordinates', true); @@ -105,7 +125,7 @@ function ParallelCoordinates( .width(width) .color(color) .alpha(0.5) - .composite('darken') + .composite(isDarkMode ? 'screen' : 'darken') .height(effHeight) .data(data) .dimensions(cols) diff --git a/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ReactParallelCoordinates.tsx b/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ReactParallelCoordinates.tsx index f5e8a5c8a9b..cf64d198650 100644 --- a/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ReactParallelCoordinates.tsx +++ b/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/ReactParallelCoordinates.tsx @@ -64,6 +64,7 @@ export default styled(ParallelCoordinates)` .parcoords text.label { font: 100%; font-size: ${theme.fontSizeSM}px; + fill: ${theme.colorText}; cursor: drag; } .parcoords rect.background { @@ -85,6 +86,9 @@ export default styled(ParallelCoordinates)` stroke: ${theme.colorText}; shape-rendering: crispEdges; } + .parcoords .axis text { + fill: ${theme.colorText}; + } .parcoords canvas { opacity: 1; -moz-transition: opacity 0.3s; diff --git a/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/transformProps.ts b/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/transformProps.ts index afb4759bf88..94a065717ed 100644 --- a/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/transformProps.ts +++ b/superset-frontend/plugins/legacy-plugin-chart-parallel-coordinates/src/transformProps.ts @@ -17,9 +17,10 @@ * under the License. */ import { ChartProps } from '@superset-ui/core'; +import { isThemeDark } from '@apache-superset/core/theme'; export default function transformProps(chartProps: ChartProps) { - const { width, height, formData, queriesData } = chartProps; + const { width, height, formData, queriesData, theme } = chartProps; const { includeSeries, linearColorScheme, @@ -33,15 +34,14 @@ export default function transformProps(chartProps: ChartProps) { width, height, data: queriesData[0].data, + defaultLineColor: theme.colorTextTertiary, includeSeries, + isDarkMode: isThemeDark(theme), linearColorScheme, metrics: metrics.map((m: { label?: string } | string) => typeof m === 'string' ? m : m.label || m, ), - colorMetric: - secondaryMetric && secondaryMetric.label - ? secondaryMetric.label - : secondaryMetric, + colorMetric: secondaryMetric?.label || secondaryMetric, series, showDatatable, }; diff --git a/superset-frontend/src/SqlLab/components/TableExploreTree/TreeNodeRenderer.tsx b/superset-frontend/src/SqlLab/components/TableExploreTree/TreeNodeRenderer.tsx index ba562348738..06605fdd601 100644 --- a/superset-frontend/src/SqlLab/components/TableExploreTree/TreeNodeRenderer.tsx +++ b/superset-frontend/src/SqlLab/components/TableExploreTree/TreeNodeRenderer.tsx @@ -79,6 +79,7 @@ export interface TreeNodeRendererProps extends NodeRendererProps { searchTerm: string; catalog: string | null | undefined; pinnedTableKeys: Set; + pinnedSchemas: Set; selectStarMap: Record; handleRefreshTables: (params: { dbId: number; @@ -91,6 +92,11 @@ export interface TreeNodeRendererProps extends NodeRendererProps { catalogName: string | null, ) => void; handleUnpinTable: (tableName: string, schemaName: string) => void; + handlePinSchema: (schemaName: string) => void; + handleUnpinSchema: (schemaName: string) => void; + refreshTableSchema: (id: string) => void; + sortedTables: Record; + toggleSortColumns: (tableId: string) => void; } const TreeNodeRenderer: React.FC = ({ @@ -101,19 +107,23 @@ const TreeNodeRenderer: React.FC = ({ searchTerm, catalog, pinnedTableKeys, + pinnedSchemas, selectStarMap, handleRefreshTables, handlePinTable, handleUnpinTable, + handlePinSchema, + handleUnpinSchema, + refreshTableSchema, + sortedTables, + toggleSortColumns, }) => { const theme = useTheme(); const { data } = node; const parts = data.id.split(':'); const [identifier, _dbId, schema, tableName] = parts; - // Use manually tracked open state for icon display - // This prevents search auto-expansion from affecting the icon - const isManuallyOpen = manuallyOpenedNodes[data.id] ?? false; + const isManuallyOpen = node.isOpen && !node.data.disableCheckbox; const isLoading = loadingNodes[data.id] ?? false; const renderIcon = () => { @@ -135,12 +145,7 @@ const TreeNodeRenderer: React.FC = ({ ? Icons.FunctionOutlined : Icons.TableOutlined; if (isLoading) { - return ( - <> - - - - ); + return ; } return ; } @@ -233,7 +238,27 @@ const TreeNodeRenderer: React.FC = ({ {highlightText(data.name, searchTerm)} {identifier === 'schema' && ( -
+
e.stopPropagation()} + > + {pinnedSchemas.has(schema) && ( +
+ + } + onClick={() => handleUnpinSchema(schema)} + /> +
+ )}
{ @@ -246,6 +271,30 @@ const TreeNodeRenderer: React.FC = ({ }} tooltipContent={t('Force refresh table list')} /> + + ) : ( + + ) + } + onClick={() => + pinnedSchemas.has(schema) + ? handleUnpinSchema(schema) + : handlePinSchema(schema) + } + />
)} @@ -288,6 +337,31 @@ const TreeNodeRenderer: React.FC = ({ } /> )} + + } + onClick={() => toggleSortColumns(data.id)} + /> + } + onClick={() => refreshTableSchema(data.id)} + /> `${dbId ?? ''}:${catalog ?? ''}`; + +const getPinnedSchemasFromStorage = ( + dbId: number | undefined, + catalog: string | null | undefined, +): Set => { + if (!dbId) return new Set(); + const stored = getItem(LocalStorageKeys.SqllabPinnedSchemas, {}); + const key = getPinnedSchemasStorageKey(dbId, catalog); + const schemas = stored[key]; + return Array.isArray(schemas) ? new Set(schemas) : new Set(); +}; + +const savePinnedSchemasToStorage = ( + dbId: number | undefined, + catalog: string | null | undefined, + schemas: Set, +) => { + if (!dbId) return; + const stored = getItem(LocalStorageKeys.SqllabPinnedSchemas, {}); + const key = getPinnedSchemasStorageKey(dbId, catalog); + setItem(LocalStorageKeys.SqllabPinnedSchemas, { + ...stored, + [key]: [...schemas], + }); +}; + const TableExploreTree: React.FC = ({ queryEditorId }) => { const dispatch = useDispatch(); const theme = useTheme(); @@ -161,6 +196,7 @@ const TableExploreTree: React.FC = ({ queryEditorId }) => { selectStarMap, handleToggle, handleRefreshTables, + refreshTableSchema, errorPayload, } = useTreeData({ dbId, @@ -199,6 +235,83 @@ const TableExploreTree: React.FC = ({ queryEditorId }) => { }, [dispatch, tables, editorId, dbId], ); + const [pinnedSchemas, setPinnedSchemas] = useState>(() => + getPinnedSchemasFromStorage(dbId, catalog), + ); + + const previousDbIdRef = useRef(dbId); + const previousCatalogRef = useRef(catalog); + + // Single effect handles both loading and persisting pinned schemas. + // Using refs to detect source changes avoids the race condition where the + // persist branch would run with stale pinnedSchemas right after a dbId/catalog + // change, corrupting the new source's stored pins. + useEffect(() => { + const dbChanged = previousDbIdRef.current !== dbId; + const catalogChanged = previousCatalogRef.current !== catalog; + + if (dbChanged || catalogChanged) { + previousDbIdRef.current = dbId; + previousCatalogRef.current = catalog; + setPinnedSchemas(getPinnedSchemasFromStorage(dbId, catalog)); + return; + } + + savePinnedSchemasToStorage(dbId, catalog, pinnedSchemas); + }, [dbId, catalog, pinnedSchemas]); + + const handlePinSchema = useCallback((schemaName: string) => { + setPinnedSchemas(prev => new Set([...prev, schemaName])); + }, []); + + const handleUnpinSchema = useCallback((schemaName: string) => { + setPinnedSchemas(prev => { + const next = new Set(prev); + next.delete(schemaName); + return next; + }); + }, []); + + const sortedTreeData = useMemo(() => { + if (pinnedSchemas.size === 0) return treeData; + const pinned = treeData.filter(node => pinnedSchemas.has(node.name)); + const rest = treeData.filter(node => !pinnedSchemas.has(node.name)); + return [...pinned, ...rest]; + }, [treeData, pinnedSchemas]); + + const [sortedTables, setSortedTables] = useState>({}); + + useEffect(() => { + setSortedTables({}); + }, [dbId, catalog]); + + const toggleSortColumns = useCallback((tableId: string) => { + setSortedTables(prev => ({ ...prev, [tableId]: !prev[tableId] })); + }, []); + + const displayTreeData = useMemo(() => { + const activeSorted = Object.keys(sortedTables).filter( + id => sortedTables[id], + ); + if (activeSorted.length === 0) return sortedTreeData; + + const sortedSet = new Set(activeSorted); + return sortedTreeData.map(schemaNode => ({ + ...schemaNode, + children: schemaNode.children?.map(tableNode => { + if (tableNode.type !== 'table' || !sortedSet.has(tableNode.id)) { + return tableNode; + } + const { children } = tableNode; + if (!children || children.length <= 1) return tableNode; + return { + ...tableNode, + children: [...children].sort((a, b) => a.name.localeCompare(b.name)), + }; + }), + })); + }, [sortedTreeData, sortedTables]); + const [searchTerm, setSearchTerm] = useState(''); const handleSearchChange = useCallback( ({ target }: ChangeEvent) => setSearchTerm(target.value), @@ -270,8 +383,8 @@ const TableExploreTree: React.FC = ({ queryEditorId }) => { return false; }; - return treeData.some(node => checkNode(node)); - }, [searchTerm, treeData]); + return displayTreeData.some(node => checkNode(node)); + }, [searchTerm, displayTreeData]); // Node renderer for react-arborist const renderNode = useCallback( @@ -283,19 +396,31 @@ const TableExploreTree: React.FC = ({ queryEditorId }) => { searchTerm={searchTerm} catalog={catalog} pinnedTableKeys={pinnedTableKeys} + pinnedSchemas={pinnedSchemas} selectStarMap={selectStarMap} handleRefreshTables={handleRefreshTables} handlePinTable={handlePinTable} handleUnpinTable={handleUnpinTable} + handlePinSchema={handlePinSchema} + handleUnpinSchema={handleUnpinSchema} + refreshTableSchema={refreshTableSchema} + sortedTables={sortedTables} + toggleSortColumns={toggleSortColumns} /> ), [ catalog, pinnedTableKeys, + pinnedSchemas, selectStarMap, handleRefreshTables, handlePinTable, handleUnpinTable, + handlePinSchema, + handleUnpinSchema, + refreshTableSchema, + sortedTables, + toggleSortColumns, loadingNodes, manuallyOpenedNodes, searchTerm, @@ -369,7 +494,7 @@ const TableExploreTree: React.FC = ({ queryEditorId }) => { return ( ref={treeRef} - data={treeData} + data={displayTreeData} width="100%" height={height || 500} rowHeight={ROW_HEIGHT} diff --git a/superset-frontend/src/SqlLab/components/TableExploreTree/useTreeData.ts b/superset-frontend/src/SqlLab/components/TableExploreTree/useTreeData.ts index 9cc54ce599c..a61ea25c6ca 100644 --- a/superset-frontend/src/SqlLab/components/TableExploreTree/useTreeData.ts +++ b/superset-frontend/src/SqlLab/components/TableExploreTree/useTreeData.ts @@ -17,6 +17,7 @@ * under the License. */ import { useMemo, useReducer, useCallback } from 'react'; +import { useDispatch } from 'react-redux'; import { t } from '@apache-superset/core/translation'; import { Table, @@ -26,6 +27,7 @@ import { useLazyTableMetadataQuery, useLazyTableExtendedMetadataQuery, } from 'src/hooks/apiResources'; +import { addDangerToast } from 'src/SqlLab/actions/sqlLab'; import type { TreeNodeData } from './types'; import { SupersetError } from '@superset-ui/core'; @@ -42,6 +44,7 @@ interface TreeDataState { type TreeDataAction = | { type: 'SET_TABLE_DATA'; key: string; data: { options: Table[] } } | { type: 'SET_TABLE_SCHEMA_DATA'; key: string; data: TableMetaData } + | { type: 'CLEAR_TABLE_SCHEMA_DATA'; key: string } | { type: 'SET_LOADING_NODE'; nodeId: string; loading: boolean } | { type: 'SET_ERROR'; errorPayload: SupersetError | null }; @@ -71,6 +74,10 @@ function treeDataReducer( [action.key]: action.data, }, }; + case 'CLEAR_TABLE_SCHEMA_DATA': { + const { [action.key]: _, ...rest } = state.tableSchemaData; + return { ...state, tableSchemaData: rest }; + } case 'SET_LOADING_NODE': return { ...state, @@ -108,6 +115,7 @@ interface UseTreeDataResult { catalog: string | null | undefined; schema: string; }) => void; + refreshTableSchema: (id: string) => void; errorPayload: SupersetError | null; } @@ -122,6 +130,7 @@ const useTreeData = ({ catalog, pinnedTables, }: UseTreeDataParams): UseTreeDataResult => { + const reduxDispatch = useDispatch(); // Schema data from API const { currentData: schemaData, @@ -137,6 +146,64 @@ const useTreeData = ({ const [state, dispatch] = useReducer(treeDataReducer, initialState); const { tableData, tableSchemaData, loadingNodes, errorPayload } = state; + // Shared helper: fetch table metadata + extended metadata and store in state. + // preferCacheValue=true on initial open (use cached data if available), + // preferCacheValue=false on explicit refresh (bypass cache). + const fetchAndStoreTableSchema = useCallback( + (id: string, preferCacheValue: boolean) => { + if (loadingNodes[id]) return; + + const parts = id.split(':'); + const [, databaseId, schema, table] = parts; + const parsedDbId = Number(databaseId); + const tableKey = `${parsedDbId}:${schema}:${table}`; + + dispatch({ type: 'SET_LOADING_NODE', nodeId: id, loading: true }); + + // .unwrap() causes RTK Query to reject on error so .catch() fires. + // Without it RTK Query resolves with { error } instead of rejecting. + Promise.all([ + fetchTableMetadata( + { dbId: parsedDbId, catalog, schema, table }, + preferCacheValue, + ).unwrap(), + fetchTableExtendedMetadata( + { dbId: parsedDbId, catalog, schema, table }, + preferCacheValue, + ).unwrap(), + ]) + .then(([tableMetadata, tableExtendedMetadata]) => { + if (tableMetadata) { + dispatch({ + type: 'SET_TABLE_SCHEMA_DATA', + key: tableKey, + data: { ...tableMetadata, ...tableExtendedMetadata }, + }); + } + }) + .catch(() => { + reduxDispatch( + addDangerToast( + t( + 'An error occurred while fetching table metadata for %s', + table, + ), + ), + ); + }) + .finally(() => { + dispatch({ type: 'SET_LOADING_NODE', nodeId: id, loading: false }); + }); + }, + [ + catalog, + fetchTableExtendedMetadata, + fetchTableMetadata, + loadingNodes, + reduxDispatch, + ], + ); + // Handle async loading when node is toggled open const handleToggle = useCallback( async (id: string, isOpen: boolean) => { @@ -150,20 +217,14 @@ const useTreeData = ({ if (identifier === 'schema') { const schemaKey = `${parsedDbId}:${schema}`; if (!tableData?.[schemaKey]) { - // Set loading state dispatch({ type: 'SET_LOADING_NODE', nodeId: id, loading: true }); - // Fetch tables asynchronously fetchLazyTables( - { - dbId: parsedDbId, - catalog, - schema, - forceRefresh: false, - }, + { dbId: parsedDbId, catalog, schema, forceRefresh: false }, true, ) - .then(({ data }) => { + .unwrap() + .then(data => { if (data) { dispatch({ type: 'SET_TABLE_DATA', key: schemaKey, data }); } @@ -191,59 +252,14 @@ const useTreeData = ({ if (pinnedTables[tableKey]) return; if (!tableSchemaData[tableKey]) { - // Set loading state - dispatch({ type: 'SET_LOADING_NODE', nodeId: id, loading: true }); - - // Fetch metadata asynchronously - Promise.all([ - fetchTableMetadata( - { - dbId: parsedDbId, - catalog, - schema, - table, - }, - true, - ), - fetchTableExtendedMetadata( - { - dbId: parsedDbId, - catalog, - schema, - table, - }, - true, - ), - ]) - .then( - ([{ data: tableMetadata }, { data: tableExtendedMetadata }]) => { - if (tableMetadata) { - dispatch({ - type: 'SET_TABLE_SCHEMA_DATA', - key: tableKey, - data: { - ...tableMetadata, - ...tableExtendedMetadata, - }, - }); - } - }, - ) - .finally(() => { - dispatch({ - type: 'SET_LOADING_NODE', - nodeId: id, - loading: false, - }); - }); + fetchAndStoreTableSchema(id, true); } } }, [ catalog, + fetchAndStoreTableSchema, fetchLazyTables, - fetchTableExtendedMetadata, - fetchTableMetadata, pinnedTables, tableData, tableSchemaData, @@ -289,6 +305,13 @@ const useTreeData = ({ [fetchLazyTables], ); + const refreshTableSchema = useCallback( + (id: string) => { + fetchAndStoreTableSchema(id, false); + }, + [fetchAndStoreTableSchema], + ); + // Build tree data const treeData = useMemo((): TreeNodeData[] => { const data = schemaData?.map(schema => { @@ -378,6 +401,7 @@ const useTreeData = ({ selectStarMap, handleToggle, handleRefreshTables, + refreshTableSchema, errorPayload, }; }; diff --git a/superset-frontend/src/SqlLab/components/TablePreview/index.tsx b/superset-frontend/src/SqlLab/components/TablePreview/index.tsx index 9e38b83ad7a..9fbcc3c1264 100644 --- a/superset-frontend/src/SqlLab/components/TablePreview/index.tsx +++ b/superset-frontend/src/SqlLab/components/TablePreview/index.tsx @@ -286,7 +286,6 @@ const TablePreview: FC = ({ dbId, catalog, schema, tableName }) => { {backend} {databaseName} {catalog && {catalog}} - {schema && {schema}} <Icons.InsertRowAboveOutlined iconSize="l" /> + {schema ? `${schema}.` : ''} {tableName} {titleActions()} diff --git a/superset-frontend/src/utils/localStorageHelpers.ts b/superset-frontend/src/utils/localStorageHelpers.ts index 61e1bbc3505..3c59b04b9c4 100644 --- a/superset-frontend/src/utils/localStorageHelpers.ts +++ b/superset-frontend/src/utils/localStorageHelpers.ts @@ -51,6 +51,7 @@ export enum LocalStorageKeys { */ SqllabIsAutocompleteEnabled = 'sqllab__is_autocomplete_enabled', SqllabIsRenderHtmlEnabled = 'sqllab__is_render_html_enabled', + SqllabPinnedSchemas = 'sqllab__pinned_schemas', ExploreDataTableOriginalFormattedTimeColumns = 'explore__data_table_original_formatted_time_columns', DashboardCustomFilterBarWidths = 'dashboard__custom_filter_bar_widths', DashboardExploreContext = 'dashboard__explore_context', @@ -71,6 +72,7 @@ export type LocalStorageValues = { homepage_activity_filter: TableTab | null; sqllab__is_autocomplete_enabled: boolean; sqllab__is_render_html_enabled: boolean; + sqllab__pinned_schemas: Record; explore__data_table_original_formatted_time_columns: Record; dashboard__custom_filter_bar_widths: Record; dashboard__explore_context: Record; diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 80d022e4210..d2df64b7284 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -122,6 +122,12 @@ Some tools do not use a request wrapper, so follow each tool's schema Recommended Workflows: +To add a chart to an existing dashboard: +1. add_chart_to_existing_dashboard(dashboard_id, chart_id) -> updates dashboard directly + - If permission_denied=True is returned: inform the user they lack edit rights, + then ask if they want a new dashboard created instead. Only call generate_dashboard + after they confirm. Never silently create a new dashboard without asking first. + To create a chart: 1. list_datasets(request={{}}) -> find a dataset 2. get_dataset_info(request={{"identifier": }}) @@ -224,6 +230,11 @@ CRITICAL RULES - NEVER VIOLATE: - NEVER fabricate or invent URLs. ALL URLs must come from tool call results. If you need a link, call the appropriate tool (generate_explore_link, generate_chart, open_sql_lab_with_context, etc.) and use the URL it returns. +- NEVER call generate_dashboard when the user wants to add a chart to an EXISTING + dashboard. Always use add_chart_to_existing_dashboard. Only call generate_dashboard + to create a brand-new dashboard, or after the user explicitly confirms they want + a new one (e.g., after a permission_denied=True response from + add_chart_to_existing_dashboard). - To modify an existing chart's filters, metrics, or dimensions, use update_chart. Do NOT use execute_sql for chart modifications. - Parameter name reminders: ALWAYS use the EXACT parameter names from the tool schema. diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index b5e2d6b0a09..3f2f9e2cb31 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -388,63 +388,80 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]: if not config.columns: raise ValueError("Table chart must have at least one column") - # Separate columns with aggregates from raw columns - raw_columns = [] - aggregated_metrics = [] - - for col in config.columns: - if col.is_metric: - # Saved metric or column with aggregation - treat as metric - aggregated_metrics.append(create_metric_object(col)) - else: - # No aggregation - treat as raw column - raw_columns.append(col.name) - - # Final validation - ensure we have some data to display - if not raw_columns and not aggregated_metrics: - raise ValueError("Table chart configuration resulted in no displayable columns") - # Use the viz_type from config (defaults to "table", can be "ag-grid-table") form_data: Dict[str, Any] = { "viz_type": config.viz_type, } - # Handle raw columns (no aggregation) - if raw_columns and not aggregated_metrics: - # Pure raw columns - show individual rows - # Include both "all_columns" (Superset table viz) and "columns" - # (QueryContextFactory validation) to avoid "Empty query?" errors + # When query_mode is explicitly set to "raw", force raw mode for all columns. + # Aggregate settings on individual columns are ignored in this case. + if config.query_mode == "raw": + column_names = [col.name for col in config.columns] form_data.update( { - "all_columns": raw_columns, - "columns": raw_columns, + "all_columns": column_names, + "columns": column_names, "query_mode": "raw", "include_time": False, "order_desc": True, } ) + else: + # Auto-detect or explicit "aggregate": separate columns with aggregates + # from raw columns and build the appropriate form_data. + raw_columns = [] + aggregated_metrics = [] - # Handle aggregated columns only - elif aggregated_metrics and not raw_columns: - # Pure aggregation - show totals - form_data.update( - { - "metrics": aggregated_metrics, - "query_mode": "aggregate", - } - ) + for col in config.columns: + if col.is_metric: + # Saved metric or column with aggregation - treat as metric + aggregated_metrics.append(create_metric_object(col)) + else: + # No aggregation - treat as raw column + raw_columns.append(col.name) - # Handle mixed columns (raw + aggregated) - elif raw_columns and aggregated_metrics: - # Mixed mode - group by raw columns, aggregate metrics - form_data.update( - { - "all_columns": raw_columns, - "metrics": aggregated_metrics, - "groupby": raw_columns, - "query_mode": "aggregate", - } - ) + # Final validation - ensure we have some data to display + if not raw_columns and not aggregated_metrics: + raise ValueError( + "Table chart configuration resulted in no displayable columns" + ) + + # Handle raw columns (no aggregation) + if raw_columns and not aggregated_metrics: + # Pure raw columns - show individual rows + # Include both "all_columns" (Superset table viz) and "columns" + # (QueryContextFactory validation) to avoid "Empty query?" errors + form_data.update( + { + "all_columns": raw_columns, + "columns": raw_columns, + "query_mode": "raw", + "include_time": False, + "order_desc": True, + } + ) + + # Handle aggregated columns only + elif aggregated_metrics and not raw_columns: + # Pure aggregation - show totals + form_data.update( + { + "metrics": aggregated_metrics, + "query_mode": "aggregate", + } + ) + + # Handle mixed columns (raw + aggregated) + else: + # Mixed mode - group by raw columns, aggregate metrics + form_data.update( + { + "all_columns": raw_columns, + "metrics": aggregated_metrics, + "groupby": raw_columns, + "query_mode": "aggregate", + } + ) _add_adhoc_filters(form_data, config.filters) diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index c03599bb141..8e5e13d0c44 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -996,6 +996,17 @@ class TableChartConfig(UnknownFieldCheckMixin): viz_type: Literal["table", "ag-grid-table"] = Field( "table", description="'ag-grid-table' for interactive features" ) + query_mode: Literal["aggregate", "raw"] | None = Field( + None, + description=( + "Query mode: 'raw' returns individual rows without aggregation, " + "'aggregate' groups data using metrics. " + "When set to 'raw', all columns are treated as plain columns regardless " + "of any aggregate settings. " + "Defaults to auto-detection: 'raw' if no column has an aggregate " + "function, 'aggregate' otherwise." + ), + ) columns: List[ColumnRef] = Field( ..., min_length=1, diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index e61573dd7ad..6973245ad64 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -493,6 +493,15 @@ class AddChartToDashboardResponse(BaseModel): None, description="Position information for the added chart" ) error: str | None = Field(None, description="Error message, if operation failed") + permission_denied: bool = Field( + default=False, + description=( + "True when the operation failed because the current user does not " + "have edit rights on the target dashboard. When True, inform the " + "user and ask if they would like a new dashboard created instead. " + "Do NOT silently create a new dashboard — always confirm first." + ), + ) class GenerateDashboardRequest(BaseModel): diff --git a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py index d3d9293ce5b..da511c29eb3 100644 --- a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py +++ b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py @@ -315,6 +315,47 @@ def _ensure_layout_structure( layout["DASHBOARD_VERSION_KEY"] = "v2" +def _find_and_authorize_dashboard( + dashboard_id: int, +) -> tuple[Any, AddChartToDashboardResponse | None]: + """Return (dashboard, None) on success or (None, error_response) on failure. + + Handles both the not-found case and the ownership check so the main tool + function doesn't need two separate branches for these pre-conditions. + """ + from superset import security_manager + from superset.daos.dashboard import DashboardDAO + from superset.exceptions import SupersetSecurityException + + dashboard = DashboardDAO.find_by_id(dashboard_id) + if not dashboard: + return None, AddChartToDashboardResponse( + dashboard=None, + dashboard_url=None, + position=None, + error=f"Dashboard with ID {dashboard_id} not found", + ) + + try: + security_manager.raise_for_ownership(dashboard) + except SupersetSecurityException: + return None, AddChartToDashboardResponse( + dashboard=None, + dashboard_url=None, + position=None, + permission_denied=True, + error=( + f"You don't have permission to edit dashboard " + f"'{dashboard.dashboard_title}' (ID: {dashboard_id}). " + "Ask the user if they would like a new dashboard " + "created with this chart instead, and only proceed " + "if they confirm." + ), + ) + + return dashboard, None + + @tool( tags=["mutate"], class_permission_name="Dashboard", @@ -333,18 +374,12 @@ def add_chart_to_existing_dashboard( """ try: from superset.commands.dashboard.update import UpdateDashboardCommand - from superset.daos.dashboard import DashboardDAO - # Validate dashboard and chart exist + # Validate dashboard exists and user has edit permission with event_logger.log_context(action="mcp.add_chart_to_dashboard.validation"): - dashboard = DashboardDAO.find_by_id(request.dashboard_id) - if not dashboard: - return AddChartToDashboardResponse( - dashboard=None, - dashboard_url=None, - position=None, - error=(f"Dashboard with ID {request.dashboard_id} not found"), - ) + dashboard, auth_error = _find_and_authorize_dashboard(request.dashboard_id) + if auth_error is not None: + return auth_error # Get chart object for SQLAlchemy relationships and validation from superset import db @@ -442,6 +477,7 @@ def add_chart_to_existing_dashboard( # trigger lazy-loading on the same dead session. from sqlalchemy.orm import subqueryload + from superset.daos.dashboard import DashboardDAO from superset.models.dashboard import Dashboard from superset.models.slice import Slice diff --git a/superset/mcp_service/dashboard/tool/generate_dashboard.py b/superset/mcp_service/dashboard/tool/generate_dashboard.py index 1cec383f643..aa1bb900a31 100644 --- a/superset/mcp_service/dashboard/tool/generate_dashboard.py +++ b/superset/mcp_service/dashboard/tool/generate_dashboard.py @@ -189,9 +189,12 @@ def _generate_title_from_charts(chart_objects: List[Any]) -> str: def generate_dashboard( # noqa: C901 request: GenerateDashboardRequest, ctx: Context ) -> GenerateDashboardResponse: - """Create dashboard from chart IDs. + """Create a NEW dashboard from chart IDs. IMPORTANT: + - Use this tool ONLY when creating a brand-new dashboard. + - To add a chart to an EXISTING dashboard, use add_chart_to_existing_dashboard. + Never use this tool as a fallback when add_chart_to_existing_dashboard fails. - All charts must exist and be accessible to current user - Charts arranged automatically in 2-column grid layout diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py index caa1d1fc2d0..e052e373e39 100644 --- a/superset/mcp_service/mcp_config.py +++ b/superset/mcp_service/mcp_config.py @@ -261,6 +261,18 @@ MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = { # - Set compact_schemas=False to disable schema compaction only (full $defs # and descriptions in search results, tool search still active). # - Set max_description_length=0 to disable description truncation only. +# +# Summary Mode (include_schemas): +# -------------------------------- +# When include_schemas=False (default), search results omit inputSchema +# entirely and include a lightweight "parameters_hint" field listing +# top-level parameter names (e.g. "page, page_size, search, filters"). +# This reduces per-search token cost by ~80% vs compact mode while still +# conveying what parameters a tool accepts. Full schemas remain available +# when invoking the tool via call_tool. +# - Set include_schemas=True to restore full inputSchema in search results. +# - compact_schemas is ignored when include_schemas=False (no schema to +# compact); max_description_length still applies in summary mode. # ============================================================================= MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = { "enabled": True, # Enabled by default — reduces initial context by ~70% @@ -272,8 +284,9 @@ MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = { ], "search_tool_name": "search_tools", # Name of the search tool "call_tool_name": "call_tool", # Name of the call proxy tool - "compact_schemas": True, # Strip $defs and simplify $ref in search results + "compact_schemas": True, # Strip $defs/$ref (requires include_schemas=True) "max_description_length": 300, # Truncate tool descriptions (0 = no truncation) + "include_schemas": False, # False=summary mode (name+hint), True=full inputSchema } diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 074914de4e7..735f75c5df8 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -254,6 +254,20 @@ def _truncate_description(text: str, max_length: int) -> str: return truncated.rstrip() + "..." +def _extract_parameter_names(input_schema: dict[str, Any]) -> str: + """Extract top-level parameter names from a JSON Schema as a hint string. + + Returns a comma-separated string of property names from the schema's + ``properties`` key, or an empty string if none are found. + + Example: ``"page, page_size, search, filters, select_columns"`` + """ + properties = input_schema.get("properties", {}) + if not properties: + return "" + return ", ".join(properties.keys()) + + def _serialize_tools_without_output_schema( tools: Sequence[Any], ) -> list[dict[str, Any]]: @@ -265,7 +279,9 @@ def _serialize_tools_without_output_schema( """ results = [] for tool in tools: - data = tool.to_mcp_tool().model_dump(mode="json", exclude_none=True) + data = tool.to_mcp_tool().model_dump( + mode="json", exclude_none=True, exclude={"outputSchema"} + ) data.pop("outputSchema", None) if input_schema := data.get("inputSchema"): data["inputSchema"] = _strip_titles(input_schema) @@ -273,22 +289,62 @@ def _serialize_tools_without_output_schema( return results +def _build_summary_serializer(max_desc: int) -> Any: + """Build a summary-mode serializer that omits ``inputSchema``. + + Returns a callable that serializes each tool to ``name``, + ``description`` (optionally truncated), and a ``parameters_hint`` + string listing top-level parameter names. ``inputSchema`` and + ``outputSchema`` are stripped entirely. + """ + + def _summary_serializer(tools: Sequence[Any]) -> list[dict[str, Any]]: + results = [] + for tool in tools: + data = tool.to_mcp_tool().model_dump( + mode="json", exclude_none=True, exclude={"outputSchema"} + ) + data.pop("outputSchema", None) + if input_schema := data.pop("inputSchema", None): + hint = _extract_parameter_names(input_schema) + if hint: + data["parameters_hint"] = hint + if max_desc and (desc := data.get("description")): + data["description"] = _truncate_description(desc, max_desc) + results.append(data) + return results + + return _summary_serializer + + def _create_search_result_serializer( config: dict[str, Any], ) -> Any: """Build a search-result serializer from the tool-search config. - When ``compact_schemas`` is enabled (default), the serializer applies - additional compaction on top of the base serialization: + When ``include_schemas`` is False (default), delegates to + :func:`_build_summary_serializer`, which strips ``inputSchema`` + entirely and adds a ``parameters_hint`` field with comma-separated + top-level parameter names. This reduces per-search token cost by + ~80% vs compact mode while still conveying what parameters a tool + accepts. - * ``$defs`` sections and ``$ref`` pointers are collapsed - (see :func:`_compact_schema`). + When ``include_schemas`` is True, the full ``compact_schemas``/ + ``max_description_length`` pipeline applies (existing behavior): + + * ``$defs`` sections and ``$ref`` pointers are collapsed when + ``compact_schemas`` is True (see :func:`_compact_schema`). * Tool descriptions are truncated to ``max_description_length`` chars. - This reduces per-search-call token cost by ~40-60 % while keeping - enough detail for the LLM to identify the right tool and construct - a basic invocation. + Full schemas remain available when the tool is invoked via ``call_tool``. """ + include_schemas = config.get("include_schemas", False) + + if not include_schemas: + max_desc = config.get("max_description_length", 300) + return _build_summary_serializer(max_desc) + + # include_schemas=True: apply full compact_schemas/max_description_length pipeline compact = config.get("compact_schemas", True) # Description truncation defaults to 300 when compact_schemas is on, # but is disabled when compact_schemas is off (unless explicitly set). @@ -304,10 +360,8 @@ def _create_search_result_serializer( if compact: if input_schema := data.get("inputSchema"): data["inputSchema"] = _compact_schema(input_schema) - if max_desc and "description" in data: - data["description"] = _truncate_description( - data["description"], max_desc - ) + if max_desc and (desc := data.get("description")): + data["description"] = _truncate_description(desc, max_desc) return results return _serializer diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index 682471b4d4a..1d3b671fa7a 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -109,6 +109,44 @@ class TestTableChartConfig: columns=[ColumnRef(name="product")], ) + def test_explicit_raw_query_mode_accepted(self) -> None: + """Test that TableChartConfig accepts explicit query_mode='raw'.""" + config = TableChartConfig( + chart_type="table", + query_mode="raw", + columns=[ColumnRef(name="product"), ColumnRef(name="category")], + ) + assert config.query_mode == "raw" + assert len(config.columns) == 2 + + def test_explicit_aggregate_query_mode_accepted(self) -> None: + """Test that TableChartConfig accepts explicit query_mode='aggregate'.""" + config = TableChartConfig( + chart_type="table", + query_mode="aggregate", + columns=[ColumnRef(name="sales", aggregate="SUM")], + ) + assert config.query_mode == "aggregate" + + def test_default_query_mode_is_none(self) -> None: + """Test that default query_mode is None (auto-detect).""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + ) + assert config.query_mode is None + + def test_invalid_query_mode_rejected(self) -> None: + """Test that invalid query_mode values are rejected.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + TableChartConfig( + chart_type="table", + query_mode="invalid", + columns=[ColumnRef(name="product")], + ) + class TestXYChartConfig: """Test XYChartConfig validation.""" diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index 293adfb09c3..894440d4888 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -390,6 +390,62 @@ class TestMapTableConfig: assert result["metrics"] == ["total_revenue", "avg_order_value"] assert "all_columns" not in result + def test_map_table_config_explicit_raw_mode(self) -> None: + """Test that explicit query_mode='raw' forces raw mode.""" + config = TableChartConfig( + chart_type="table", + query_mode="raw", + columns=[ + ColumnRef(name="product"), + ColumnRef(name="category"), + ], + ) + + result = map_table_config(config) + + assert result["viz_type"] == "table" + assert result["query_mode"] == "raw" + assert result["all_columns"] == ["product", "category"] + assert result["columns"] == ["product", "category"] + assert "metrics" not in result + + def test_map_table_config_explicit_raw_mode_ignores_aggregates(self) -> None: + """Test that explicit query_mode='raw' ignores aggregate settings on columns.""" + config = TableChartConfig( + chart_type="table", + query_mode="raw", + columns=[ + ColumnRef(name="product"), + ColumnRef(name="sales", aggregate="SUM"), + ], + ) + + result = map_table_config(config) + + assert result["query_mode"] == "raw" + # Both columns treated as raw; aggregate setting on "sales" is ignored + assert result["all_columns"] == ["product", "sales"] + assert result["columns"] == ["product", "sales"] + assert "metrics" not in result + + def test_map_table_config_explicit_aggregate_mode(self) -> None: + """Test that explicit query_mode='aggregate' uses inference logic.""" + config = TableChartConfig( + chart_type="table", + query_mode="aggregate", + columns=[ + ColumnRef(name="product"), + ColumnRef(name="revenue", aggregate="SUM"), + ], + ) + + result = map_table_config(config) + + assert result["query_mode"] == "aggregate" + assert len(result["metrics"]) == 1 + assert result["metrics"][0]["aggregate"] == "SUM" + assert "product" in result["groupby"] + class TestAddAdhocFilters: """Test _add_adhoc_filters helper function""" diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_add_chart_to_existing_dashboard.py b/tests/unit_tests/mcp_service/dashboard/tool/test_add_chart_to_existing_dashboard.py new file mode 100644 index 00000000000..de8058c92de --- /dev/null +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_add_chart_to_existing_dashboard.py @@ -0,0 +1,297 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for add_chart_to_existing_dashboard MCP tool. + +Follows the same pattern used in test_dashboard_generation.py: +- Tests run through the async MCP Client (not direct function calls) +- Patches applied at source locations (superset.daos.dashboard.*, superset.db.*, etc.) +- auth is mocked via the autouse mock_auth fixture (same as other tool test files) + +Covers: +- Dashboard not found +- Permission denied (user does not own the dashboard) -> permission_denied=True +- Chart not found +- Chart already in dashboard +- Successful add (happy path) +""" + +import logging +from unittest.mock import Mock, patch + +import pytest +from fastmcp import Client + +from superset.mcp_service.app import mcp +from superset.mcp_service.chart.chart_utils import DatasetValidationResult + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mcp_server() -> object: + """Return the FastMCP app instance for use in MCP client tests.""" + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +@pytest.fixture(autouse=True) +def mock_chart_access(): + """Allow chart data access by default so tests focus on dashboard logic.""" + with patch( + "superset.mcp_service.auth.check_chart_data_access", + return_value=DatasetValidationResult( + is_valid=True, + dataset_id=1, + dataset_name="test_dataset", + warnings=[], + error=None, + ), + ): + yield + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_chart(id: int = 10, slice_name: str = "Test Chart") -> Mock: + """Create a minimal mock Slice object with the given ID and name.""" + chart = Mock() + chart.id = id + chart.slice_name = slice_name + chart.uuid = f"chart-uuid-{id}" + chart.tags = [] + chart.owners = [] + chart.viz_type = "table" + chart.datasource_name = None + chart.description = None + return chart + + +def _mock_dashboard( + id: int = 1, + title: str = "Sales Dashboard", + slices: list[Mock] | None = None, +) -> Mock: + """Create a minimal mock Dashboard object with the given ID, title and slices.""" + dashboard = Mock() + dashboard.id = id + dashboard.dashboard_title = title + dashboard.slug = f"test-dashboard-{id}" + dashboard.description = None + dashboard.published = True + dashboard.created_on = None + dashboard.changed_on = None + dashboard.created_by_name = "test_user" + dashboard.changed_by_name = "test_user" + dashboard.uuid = f"dashboard-uuid-{id}" + dashboard.slices = slices or [] + dashboard.owners = [] + dashboard.tags = [] + dashboard.roles = [] + dashboard.position_json = "{}" + dashboard.json_metadata = None + dashboard.css = None + dashboard.certified_by = None + dashboard.certification_details = None + dashboard.is_managed_externally = False + dashboard.external_url = None + return dashboard + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.dashboard.DashboardDAO.find_by_id") +@pytest.mark.asyncio +async def test_dashboard_not_found(mock_find_by_id: Mock, mcp_server: object) -> None: + """Returns a clear error when the target dashboard does not exist.""" + mock_find_by_id.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", + {"request": {"dashboard_id": 999, "chart_id": 10}}, + ) + + assert result.structured_content["dashboard"] is None + assert result.structured_content["dashboard_url"] is None + assert result.structured_content["permission_denied"] is False + assert "not found" in (result.structured_content["error"] or "").lower() + + +@patch("superset.security_manager.raise_for_ownership") +@patch("superset.daos.dashboard.DashboardDAO.find_by_id") +@pytest.mark.asyncio +async def test_permission_denied( + mock_find_by_id: Mock, mock_raise_for_ownership: Mock, mcp_server: object +) -> None: + """Returns permission_denied=True and an actionable error when the user + cannot edit the dashboard. + + This is the core regression test for the bug fix: before the fix the tool + returned a generic error that caused the LLM to silently call + generate_dashboard instead. After the fix it returns permission_denied=True + with a message that explicitly tells the LLM to ask the user first. + """ + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.exceptions import SupersetSecurityException + + dashboard = _mock_dashboard(id=1, title="Sales Dashboard") + mock_find_by_id.return_value = dashboard + mock_raise_for_ownership.side_effect = SupersetSecurityException( + SupersetError( + message="Changing this Dashboard is forbidden", + error_type=SupersetErrorType.GENERIC_BACKEND_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", + {"request": {"dashboard_id": 1, "chart_id": 10}}, + ) + + content = result.structured_content + assert content["dashboard"] is None + assert content["permission_denied"] is True, ( + "Expected permission_denied=True so the LLM knows to ask the user " + "before creating a new dashboard — this is the fix for the bug" + ) + assert content["error"] is not None + assert "Sales Dashboard" in content["error"] + assert "permission" in content["error"].lower() + assert "new dashboard" in content["error"].lower() + + +@patch("superset.db.session.get") +@patch("superset.security_manager.raise_for_ownership") +@patch("superset.daos.dashboard.DashboardDAO.find_by_id") +@pytest.mark.asyncio +async def test_chart_not_found( + mock_find_by_id: Mock, + mock_raise_for_ownership: Mock, + mock_session_get: Mock, + mcp_server: object, +) -> None: + """Returns an error when the requested chart does not exist.""" + dashboard = _mock_dashboard() + mock_find_by_id.return_value = dashboard + mock_raise_for_ownership.return_value = None + mock_session_get.return_value = None # chart not found + + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", + {"request": {"dashboard_id": 1, "chart_id": 99}}, + ) + + content = result.structured_content + assert content["dashboard"] is None + assert content["permission_denied"] is False + assert "99" in (content["error"] or "") + + +@patch("superset.db.session.get") +@patch("superset.security_manager.raise_for_ownership") +@patch("superset.daos.dashboard.DashboardDAO.find_by_id") +@pytest.mark.asyncio +async def test_chart_already_in_dashboard( + mock_find_by_id: Mock, + mock_raise_for_ownership: Mock, + mock_session_get: Mock, + mcp_server: object, +) -> None: + """Returns an error when the chart is already present on the dashboard.""" + chart = _mock_chart(id=10) + dashboard = _mock_dashboard(slices=[chart]) + mock_find_by_id.return_value = dashboard + mock_raise_for_ownership.return_value = None + mock_session_get.return_value = chart + + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", + {"request": {"dashboard_id": 1, "chart_id": 10}}, + ) + + content = result.structured_content + assert content["dashboard"] is None + assert content["permission_denied"] is False + assert "already" in (content["error"] or "").lower() + + +@patch("superset.commands.dashboard.update.UpdateDashboardCommand") +@patch("superset.db.session.get") +@patch("superset.security_manager.raise_for_ownership") +@patch("superset.daos.dashboard.DashboardDAO.find_by_id") +@pytest.mark.asyncio +async def test_successful_add( + mock_find_by_id: Mock, + mock_raise_for_ownership: Mock, + mock_session_get: Mock, + mock_update_cmd_cls: Mock, + mcp_server: object, +) -> None: + """Happy path: chart added, permission_denied=False, URL and position returned.""" + chart = _mock_chart(id=10) + dashboard = _mock_dashboard(id=1) + updated_dashboard = _mock_dashboard(id=1, slices=[chart]) + + mock_find_by_id.side_effect = [dashboard, updated_dashboard] + mock_raise_for_ownership.return_value = None + mock_session_get.return_value = chart + + mock_update_cmd = Mock() + mock_update_cmd.run.return_value = updated_dashboard + mock_update_cmd_cls.return_value = mock_update_cmd + + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", + {"request": {"dashboard_id": 1, "chart_id": 10}}, + ) + + content = result.structured_content + assert content["error"] is None + assert content["permission_denied"] is False + assert content["dashboard_url"] is not None + assert "/superset/dashboard/1/" in content["dashboard_url"] + assert content["position"] is not None + assert "chart_key" in content["position"] diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py index 04a3b5a568a..b5df3e91abc 100644 --- a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py @@ -52,11 +52,12 @@ def mcp_server(): def mock_auth(): """Mock authentication for all tests.""" with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: - mock_user = Mock() - mock_user.id = 1 - mock_user.username = "admin" - mock_get_user.return_value = mock_user - yield mock_get_user + with patch("superset.security_manager.raise_for_ownership"): + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user @pytest.fixture(autouse=True) diff --git a/tests/unit_tests/mcp_service/test_tool_search_transform.py b/tests/unit_tests/mcp_service/test_tool_search_transform.py index dff91c76f81..74de17b88e5 100644 --- a/tests/unit_tests/mcp_service/test_tool_search_transform.py +++ b/tests/unit_tests/mcp_service/test_tool_search_transform.py @@ -27,6 +27,7 @@ from superset.mcp_service.server import ( _apply_tool_search_transform, _compact_schema, _create_search_result_serializer, + _extract_parameter_names, _fix_call_tool_arguments, _normalize_call_tool_arguments, _serialize_tools_without_output_schema, @@ -44,6 +45,7 @@ def test_tool_search_config_defaults(): assert "get_instance_info" in MCP_TOOL_SEARCH_CONFIG["always_visible"] assert MCP_TOOL_SEARCH_CONFIG["search_tool_name"] == "search_tools" assert MCP_TOOL_SEARCH_CONFIG["call_tool_name"] == "call_tool" + assert MCP_TOOL_SEARCH_CONFIG["include_schemas"] is False def test_apply_bm25_transform(): @@ -514,7 +516,9 @@ def test_create_serializer_compacts_schemas(): }, ) - serializer = _create_search_result_serializer({"compact_schemas": True}) + serializer = _create_search_result_serializer( + {"include_schemas": True, "compact_schemas": True} + ) result = serializer([tool]) assert len(result) == 1 @@ -550,7 +554,7 @@ def test_create_serializer_disabled(): ) serializer = _create_search_result_serializer( - {"compact_schemas": False, "max_description_length": 0} + {"include_schemas": True, "compact_schemas": False, "max_description_length": 0} ) result = serializer([tool]) @@ -569,7 +573,9 @@ def test_create_serializer_compact_false_disables_truncation(): {"type": "object", "$defs": {"Model": {"type": "object"}}}, ) - serializer = _create_search_result_serializer({"compact_schemas": False}) + serializer = _create_search_result_serializer( + {"include_schemas": True, "compact_schemas": False} + ) result = serializer([tool]) # $defs should still be present (compaction disabled) @@ -588,7 +594,11 @@ def test_create_serializer_compact_false_explicit_truncation(): ) serializer = _create_search_result_serializer( - {"compact_schemas": False, "max_description_length": 200} + { + "include_schemas": True, + "compact_schemas": False, + "max_description_length": 200, + } ) result = serializer([tool]) @@ -599,7 +609,11 @@ def test_create_serializer_compact_false_explicit_truncation(): def test_create_serializer_uses_config_defaults(): - """Empty config uses defaults (compact=True, max_desc=300).""" + """Empty config defaults to summary mode (include_schemas=False). + + The new default omits inputSchema and adds parameters_hint instead. + Descriptions are still truncated to 300 chars. + """ long_desc = "First sentence. " + "x" * 500 tool = _make_mock_tool( "test_tool", @@ -614,7 +628,10 @@ def test_create_serializer_uses_config_defaults(): serializer = _create_search_result_serializer({}) result = serializer([tool]) - assert "$defs" not in result[0]["inputSchema"] + # Summary mode: no inputSchema, parameters_hint present + assert "inputSchema" not in result[0] + assert result[0]["parameters_hint"] == "x" + # Description still truncated to default 300 assert len(result[0]["description"]) <= 303 @@ -640,3 +657,194 @@ def test_apply_transform_uses_compact_serializer(): transform._search_result_serializer is not _serialize_tools_without_output_schema ) + + +# -- _extract_parameter_names tests -- + + +def test_extract_parameter_names_basic(): + """Returns comma-separated top-level property names.""" + schema = { + "type": "object", + "properties": { + "page": {"type": "integer"}, + "page_size": {"type": "integer"}, + "search": {"type": "string"}, + }, + } + + result = _extract_parameter_names(schema) + + assert result == "page, page_size, search" + + +def test_extract_parameter_names_empty_properties(): + """Returns empty string when properties dict is empty.""" + schema = {"type": "object", "properties": {}} + + result = _extract_parameter_names(schema) + + assert result == "" + + +def test_extract_parameter_names_no_properties_key(): + """Returns empty string when properties key is absent.""" + schema = {"type": "object"} + + result = _extract_parameter_names(schema) + + assert result == "" + + +def test_extract_parameter_names_with_refs(): + """Extracts names regardless of the shape of property values.""" + schema = { + "type": "object", + "properties": { + "filters": {"type": "array", "items": {"$ref": "#/$defs/ChartFilter"}}, + "select_columns": {"type": "array"}, + }, + "$defs": {"ChartFilter": {"type": "object"}}, + } + + result = _extract_parameter_names(schema) + + assert result == "filters, select_columns" + + +# -- _create_search_result_serializer summary mode (include_schemas=False) -- + + +def test_create_serializer_summary_mode_strips_input_schema(): + """When include_schemas=False, inputSchema is absent from results.""" + tool = _make_mock_tool( + "list_charts", + "List charts.", + { + "type": "object", + "properties": { + "page": {"type": "integer"}, + "search": {"type": "string"}, + }, + }, + ) + + serializer = _create_search_result_serializer({"include_schemas": False}) + result = serializer([tool]) + + assert len(result) == 1 + assert "inputSchema" not in result[0] + assert result[0]["name"] == "list_charts" + + +def test_create_serializer_summary_mode_adds_parameters_hint(): + """When include_schemas=False, parameters_hint lists top-level param names.""" + tool = _make_mock_tool( + "list_charts", + "List charts.", + { + "type": "object", + "properties": { + "page": {"type": "integer"}, + "page_size": {"type": "integer"}, + "search": {"type": "string"}, + }, + }, + ) + + serializer = _create_search_result_serializer({"include_schemas": False}) + result = serializer([tool]) + + assert result[0]["parameters_hint"] == "page, page_size, search" + + +def test_create_serializer_summary_mode_no_hint_when_no_properties(): + """When inputSchema has no properties, parameters_hint is absent.""" + tool = _make_mock_tool( + "health_check", + "Health check.", + {"type": "object"}, + ) + + serializer = _create_search_result_serializer({"include_schemas": False}) + result = serializer([tool]) + + assert "inputSchema" not in result[0] + assert "parameters_hint" not in result[0] + + +def test_create_serializer_summary_mode_truncates_description(): + """Summary mode still truncates descriptions to max_description_length.""" + long_desc = "First sentence. " + "x" * 500 + tool = _make_mock_tool( + "list_charts", + long_desc, + {"type": "object", "properties": {"page": {"type": "integer"}}}, + ) + + serializer = _create_search_result_serializer( + {"include_schemas": False, "max_description_length": 50} + ) + result = serializer([tool]) + + assert len(result[0]["description"]) <= 53 + + +def test_create_serializer_summary_mode_is_default(): + """Empty config defaults to summary mode (include_schemas=False).""" + tool = _make_mock_tool( + "list_charts", + "List charts.", + { + "type": "object", + "properties": {"page": {"type": "integer"}}, + }, + ) + + serializer = _create_search_result_serializer({}) + result = serializer([tool]) + + assert "inputSchema" not in result[0] + assert "parameters_hint" in result[0] + + +def test_create_serializer_include_schemas_true_restores_full_schema(): + """include_schemas=True preserves inputSchema in results.""" + schema = { + "type": "object", + "properties": {"page": {"type": "integer"}}, + "$defs": {"Model": {"type": "object"}}, + } + tool = _make_mock_tool("list_charts", "List charts.", schema) + + serializer = _create_search_result_serializer( + {"include_schemas": True, "compact_schemas": False, "max_description_length": 0} + ) + result = serializer([tool]) + + assert "inputSchema" in result[0] + assert "parameters_hint" not in result[0] + assert "$defs" in result[0]["inputSchema"] + + +def test_create_serializer_include_schemas_true_with_compact(): + """include_schemas=True + compact_schemas=True still compacts the schema.""" + schema = { + "type": "object", + "properties": { + "filters": {"type": "array", "items": {"$ref": "#/$defs/ChartFilter"}} + }, + "$defs": {"ChartFilter": {"type": "object"}}, + } + tool = _make_mock_tool("list_charts", "List charts.", schema) + + serializer = _create_search_result_serializer( + {"include_schemas": True, "compact_schemas": True} + ) + result = serializer([tool]) + + assert "inputSchema" in result[0] + assert "$defs" not in result[0]["inputSchema"] + assert result[0]["inputSchema"]["properties"]["filters"]["items"] == { + "type": "object" + }