diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts b/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts index 79bcabdaff1..96dd51bf7ba 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts @@ -97,6 +97,7 @@ interface _PostProcessingContribution { orientation?: 'row' | 'column'; columns?: string[]; rename_columns?: string[]; + contribution_totals?: Record; }; } export type PostProcessingContribution = diff --git a/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts b/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts index 439df369f5a..e4b60e7893c 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts +++ b/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts @@ -84,7 +84,7 @@ const buildQuery: BuildQuery = ( return buildQueryContext(formDataCopy, baseQueryObject => { let { metrics, orderby = [], columns = [] } = baseQueryObject; const { extras = {} } = baseQueryObject; - let postProcessing: PostProcessingRule[] = []; + const postProcessing: PostProcessingRule[] = []; const nonCustomNorInheritShifts = ensureIsArray( formData.time_compare, ).filter((shift: string) => shift !== 'custom' && shift !== 'inherit'); @@ -129,6 +129,12 @@ const buildQuery: BuildQuery = ( orderby = [[metrics[0], false]]; } // add postprocessing for percent metrics only when in aggregation mode + type PercentMetricCalculationMode = 'row_limit' | 'all_records'; + + const calculationMode: PercentMetricCalculationMode = + (formData.percent_metric_calculation as PercentMetricCalculationMode) || + 'row_limit'; + if (percentMetrics && percentMetrics.length > 0) { const percentMetricsLabelsWithTimeComparison = isTimeComparison( formData, @@ -139,6 +145,7 @@ const buildQuery: BuildQuery = ( timeOffsets, ) : percentMetrics.map(getMetricLabel); + const percentMetricLabels = removeDuplicates( percentMetricsLabelsWithTimeComparison, ); @@ -146,16 +153,26 @@ const buildQuery: BuildQuery = ( metrics.concat(percentMetrics), getMetricLabel, ); - postProcessing = [ - { + + if (calculationMode === 'all_records') { + postProcessing.push({ operation: 'contribution', options: { columns: percentMetricLabels, - rename_columns: percentMetricLabels.map(x => `%${x}`), + rename_columns: percentMetricLabels.map(m => `%${m}`), }, - }, - ]; + }); + } else { + postProcessing.push({ + operation: 'contribution', + options: { + columns: percentMetricLabels, + rename_columns: percentMetricLabels.map(m => `%${m}`), + }, + }); + } } + // Add the operator for the time comparison if some is selected if (!isEmpty(timeOffsets)) { postProcessing.push(timeCompareOperator(formData, baseQueryObject)); @@ -252,6 +269,26 @@ const buildQuery: BuildQuery = ( }); const extraQueries: QueryObject[] = []; + + const calculationMode = formData.percent_metric_calculation || 'row_limit'; + + if ( + calculationMode === 'all_records' && + percentMetrics && + percentMetrics.length > 0 + ) { + extraQueries.push({ + ...queryObject, + columns: [], + metrics: percentMetrics, + post_processing: [], + row_limit: 0, + row_offset: 0, + orderby: [], + is_timeseries: false, + }); + } + if ( metrics?.length && formData.show_totals && @@ -263,8 +300,8 @@ const buildQuery: BuildQuery = ( row_limit: 0, row_offset: 0, post_processing: [], - order_desc: undefined, // we don't need orderby stuff here, - orderby: undefined, // because this query will be used for get total aggregation. + order_desc: undefined, + orderby: undefined, }); } diff --git a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx index 1f9425354e5..1c7cb036a23 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx @@ -161,12 +161,30 @@ const generateComparisonColumns = (colname: string) => [ `△ ${colname}`, `% ${colname}`, ]; + /** * Generate column types for the comparison columns. */ const generateComparisonColumnTypes = (count: number) => Array(count).fill(GenericDataType.Numeric); +const percentMetricCalculationControl: ControlConfig<'SelectControl'> = { + type: 'SelectControl', + label: t('Percentage metric calculation'), + description: t( + 'Row Limit: percentages are calculated based on the subset of data retrieved, respecting the row limit. ' + + 'All Records: Percentages are calculated based on the total dataset, ignoring the row limit.', + ), + default: 'row_limit', + clearable: false, + choices: [ + ['row_limit', t('Row limit')], + ['all_records', t('All records')], + ], + visibility: isAggMode, + renderTrigger: false, +}; + const processComparisonColumns = (columns: any[], suffix: string) => columns .map(col => { @@ -433,6 +451,13 @@ const config: ControlPanelConfig = { }, }, ], + [ + { + name: 'percent_metric_calculation', + config: percentMetricCalculationControl, + }, + ], + [ { name: 'show_totals', diff --git a/superset-frontend/plugins/plugin-chart-table/test/buildQuery.test.ts b/superset-frontend/plugins/plugin-chart-table/test/buildQuery.test.ts index 4badcc673ac..200bf7d9e8a 100644 --- a/superset-frontend/plugins/plugin-chart-table/test/buildQuery.test.ts +++ b/superset-frontend/plugins/plugin-chart-table/test/buildQuery.test.ts @@ -148,5 +148,92 @@ describe('plugin-chart-table', () => { expect(queries[1].extras?.time_grain_sqla).toEqual(TimeGranularity.MONTH); expect(queries[1].extras?.where).toEqual("(status IN ('In Process'))"); }); + + describe('Percent Metric Calculation Modes', () => { + const baseFormDataWithPercents: TableChartFormData = { + ...basicFormData, + query_mode: QueryMode.Aggregate, + metrics: ['count'], + percent_metrics: ['sum_sales'], + groupby: ['category'], + }; + + it('should default to row_limit mode with single query', () => { + const { queries } = buildQuery(baseFormDataWithPercents); + + expect(queries).toHaveLength(1); + expect(queries[0].metrics).toEqual(['count', 'sum_sales']); + expect(queries[0].post_processing).toEqual([ + { + operation: 'contribution', + options: { + columns: ['sum_sales'], + rename_columns: ['%sum_sales'], + }, + }, + ]); + }); + + it('should create extra query in all_records mode', () => { + const formData = { + ...baseFormDataWithPercents, + percent_metric_calculation: 'all_records', + }; + + const { queries } = buildQuery(formData); + + expect(queries).toHaveLength(2); + + expect(queries[0].post_processing).toEqual([ + { + operation: 'contribution', + options: { + columns: ['sum_sales'], + rename_columns: ['%sum_sales'], + }, + }, + ]); + + expect(queries[1]).toMatchObject({ + columns: [], + metrics: ['sum_sales'], + post_processing: [], + row_limit: 0, + row_offset: 0, + orderby: [], + is_timeseries: false, + }); + }); + + it('should work with show_totals in all_records mode', () => { + const formData = { + ...baseFormDataWithPercents, + percent_metric_calculation: 'all_records', + show_totals: true, + }; + + const { queries } = buildQuery(formData); + + expect(queries).toHaveLength(3); + expect(queries[1].metrics).toEqual(['sum_sales']); + expect(queries[2].metrics).toEqual(['count', 'sum_sales']); + }); + + it('should handle empty percent_metrics in all_records mode', () => { + const formData = { + ...basicFormData, + query_mode: QueryMode.Aggregate, + metrics: ['count'], + percent_metrics: [], + percent_metric_calculation: 'all_records', + groupby: ['category'], + }; + + const { queries } = buildQuery(formData); + + expect(queries).toHaveLength(1); + expect(queries[0].post_processing).toEqual([]); + }); + }); }); }); diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 509dcba5a71..15625ad838a 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -741,6 +741,47 @@ class QueryContextProcessor: return df.to_dict(orient="records") + def ensure_totals_available(self) -> None: + queries_needing_totals = [] + totals_queries = [] + + for i, query in enumerate(self._query_context.queries): + needs_totals = any( + pp.get("operation") == "contribution" + for pp in getattr(query, "post_processing", []) or [] + ) + + if needs_totals: + queries_needing_totals.append(i) + + is_totals_query = ( + not query.columns and query.metrics and not query.post_processing + ) + if is_totals_query: + totals_queries.append(i) + + if not queries_needing_totals or not totals_queries: + return + + totals_idx = totals_queries[0] + totals_query = self._query_context.queries[totals_idx] + + totals_query.row_limit = None + + result = self._query_context.get_query_result(totals_query) + df = result.df + + totals = { + col: df[col].sum() for col in df.columns if df[col].dtype.kind in "biufc" + } + + for idx in queries_needing_totals: + query = self._query_context.queries[idx] + if hasattr(query, "post_processing") and query.post_processing: + for pp in query.post_processing: + if pp.get("operation") == "contribution": + pp["options"]["contribution_totals"] = totals + def get_payload( self, cache_query_context: bool | None = False, @@ -748,7 +789,8 @@ class QueryContextProcessor: ) -> dict[str, Any]: """Returns the query results with both metadata and data""" - # Get all the payloads from the QueryObjects + self.ensure_totals_available() + query_results = [ get_query_results( query_obj.result_type or self._query_context.result_type, @@ -758,6 +800,7 @@ class QueryContextProcessor: ) for query_obj in self._query_context.queries ] + return_value = {"queries": query_results} if cache_query_context: diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py index 3c0ea04f102..ec6716fcba0 100644 --- a/superset/utils/pandas_postprocessing/contribution.py +++ b/superset/utils/pandas_postprocessing/contribution.py @@ -37,6 +37,7 @@ def contribution( columns: list[str] | None = None, time_shifts: list[str] | None = None, rename_columns: list[str] | None = None, + contribution_totals: dict[str, float] | None = None, ) -> DataFrame: """ Calculate cell contribution to row/column total for numeric columns. @@ -82,10 +83,19 @@ def contribution( numeric_df_view = numeric_df[actual_columns] if orientation == PostProcessingContributionOrientation.COLUMN: - numeric_df_view = numeric_df_view / numeric_df_view.values.sum( - axis=0, keepdims=True - ) - contribution_df[rename_columns] = numeric_df_view + if contribution_totals: + for i, col in enumerate(numeric_df_view.columns): + total = contribution_totals.get(col) + rename_col = rename_columns[i] + if total is None or total == 0: + contribution_df[rename_col] = 0 + else: + contribution_df[rename_col] = numeric_df_view[col] / total + else: + numeric_df_view = numeric_df_view / numeric_df_view.values.sum( + axis=0, keepdims=True + ) + contribution_df[rename_columns] = numeric_df_view return contribution_df result = get_column_groups(numeric_df_view, time_shifts, rename_columns)