mirror of
https://github.com/apache/superset.git
synced 2026-04-07 10:31:50 +00:00
feat(chart): add toggle for percentage metric calculation mode in Table chart (#33656)
This commit is contained in:
@@ -97,6 +97,7 @@ interface _PostProcessingContribution {
|
||||
orientation?: 'row' | 'column';
|
||||
columns?: string[];
|
||||
rename_columns?: string[];
|
||||
contribution_totals?: Record<string, number>;
|
||||
};
|
||||
}
|
||||
export type PostProcessingContribution =
|
||||
|
||||
@@ -84,7 +84,7 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
return buildQueryContext(formDataCopy, baseQueryObject => {
|
||||
let { metrics, orderby = [], columns = [] } = baseQueryObject;
|
||||
const { extras = {} } = baseQueryObject;
|
||||
let postProcessing: PostProcessingRule[] = [];
|
||||
const postProcessing: PostProcessingRule[] = [];
|
||||
const nonCustomNorInheritShifts = ensureIsArray(
|
||||
formData.time_compare,
|
||||
).filter((shift: string) => shift !== 'custom' && shift !== 'inherit');
|
||||
@@ -129,6 +129,12 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
orderby = [[metrics[0], false]];
|
||||
}
|
||||
// add postprocessing for percent metrics only when in aggregation mode
|
||||
type PercentMetricCalculationMode = 'row_limit' | 'all_records';
|
||||
|
||||
const calculationMode: PercentMetricCalculationMode =
|
||||
(formData.percent_metric_calculation as PercentMetricCalculationMode) ||
|
||||
'row_limit';
|
||||
|
||||
if (percentMetrics && percentMetrics.length > 0) {
|
||||
const percentMetricsLabelsWithTimeComparison = isTimeComparison(
|
||||
formData,
|
||||
@@ -139,6 +145,7 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
timeOffsets,
|
||||
)
|
||||
: percentMetrics.map(getMetricLabel);
|
||||
|
||||
const percentMetricLabels = removeDuplicates(
|
||||
percentMetricsLabelsWithTimeComparison,
|
||||
);
|
||||
@@ -146,16 +153,26 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
metrics.concat(percentMetrics),
|
||||
getMetricLabel,
|
||||
);
|
||||
postProcessing = [
|
||||
{
|
||||
|
||||
if (calculationMode === 'all_records') {
|
||||
postProcessing.push({
|
||||
operation: 'contribution',
|
||||
options: {
|
||||
columns: percentMetricLabels,
|
||||
rename_columns: percentMetricLabels.map(x => `%${x}`),
|
||||
rename_columns: percentMetricLabels.map(m => `%${m}`),
|
||||
},
|
||||
},
|
||||
];
|
||||
});
|
||||
} else {
|
||||
postProcessing.push({
|
||||
operation: 'contribution',
|
||||
options: {
|
||||
columns: percentMetricLabels,
|
||||
rename_columns: percentMetricLabels.map(m => `%${m}`),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add the operator for the time comparison if some is selected
|
||||
if (!isEmpty(timeOffsets)) {
|
||||
postProcessing.push(timeCompareOperator(formData, baseQueryObject));
|
||||
@@ -252,6 +269,26 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
});
|
||||
|
||||
const extraQueries: QueryObject[] = [];
|
||||
|
||||
const calculationMode = formData.percent_metric_calculation || 'row_limit';
|
||||
|
||||
if (
|
||||
calculationMode === 'all_records' &&
|
||||
percentMetrics &&
|
||||
percentMetrics.length > 0
|
||||
) {
|
||||
extraQueries.push({
|
||||
...queryObject,
|
||||
columns: [],
|
||||
metrics: percentMetrics,
|
||||
post_processing: [],
|
||||
row_limit: 0,
|
||||
row_offset: 0,
|
||||
orderby: [],
|
||||
is_timeseries: false,
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
metrics?.length &&
|
||||
formData.show_totals &&
|
||||
@@ -263,8 +300,8 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
row_limit: 0,
|
||||
row_offset: 0,
|
||||
post_processing: [],
|
||||
order_desc: undefined, // we don't need orderby stuff here,
|
||||
orderby: undefined, // because this query will be used for get total aggregation.
|
||||
order_desc: undefined,
|
||||
orderby: undefined,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -161,12 +161,30 @@ const generateComparisonColumns = (colname: string) => [
|
||||
`△ ${colname}`,
|
||||
`% ${colname}`,
|
||||
];
|
||||
|
||||
/**
 * Build the data-type list for a run of comparison columns.
 *
 * @param count - how many comparison columns need a type entry
 * @returns an array of length `count` in which every entry is
 *   `GenericDataType.Numeric`
 */
const generateComparisonColumnTypes = (count: number) => {
  const columnTypes = new Array(count);
  return columnTypes.fill(GenericDataType.Numeric);
};
|
||||
|
||||
/**
 * Select control that chooses how percentage metrics are computed.
 *
 * The stored value is either `'row_limit'` (the default) or `'all_records'`.
 * The option names used in the description match the labels in `choices`,
 * so the help text and the dropdown read the same.
 */
const percentMetricCalculationControl: ControlConfig<'SelectControl'> = {
  type: 'SelectControl',
  label: t('Percentage metric calculation'),
  description: t(
    'Row limit: percentages are calculated based on the subset of data retrieved, respecting the row limit. ' +
      'All records: percentages are calculated based on the total dataset, ignoring the row limit.',
  ),
  default: 'row_limit',
  clearable: false,
  choices: [
    ['row_limit', t('Row limit')],
    ['all_records', t('All records')],
  ],
  // Shown only when `isAggMode` holds for the current controls.
  visibility: isAggMode,
  // NOTE(review): presumably kept false because this value changes the
  // queries that are built, not just the rendering; confirm.
  renderTrigger: false,
};
|
||||
|
||||
const processComparisonColumns = (columns: any[], suffix: string) =>
|
||||
columns
|
||||
.map(col => {
|
||||
@@ -433,6 +451,13 @@ const config: ControlPanelConfig = {
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
name: 'percent_metric_calculation',
|
||||
config: percentMetricCalculationControl,
|
||||
},
|
||||
],
|
||||
|
||||
[
|
||||
{
|
||||
name: 'show_totals',
|
||||
|
||||
@@ -148,5 +148,92 @@ describe('plugin-chart-table', () => {
|
||||
expect(queries[1].extras?.time_grain_sqla).toEqual(TimeGranularity.MONTH);
|
||||
expect(queries[1].extras?.where).toEqual("(status IN ('In Process'))");
|
||||
});
|
||||
|
||||
describe('Percent Metric Calculation Modes', () => {
  // Aggregate-mode form data with one regular metric and one percent metric.
  const baseFormDataWithPercents: TableChartFormData = {
    ...basicFormData,
    query_mode: QueryMode.Aggregate,
    metrics: ['count'],
    percent_metrics: ['sum_sales'],
    groupby: ['category'],
  };

  // The contribution rule expected on the main query in both modes.
  const expectedContributionRules = [
    {
      operation: 'contribution',
      options: {
        columns: ['sum_sales'],
        rename_columns: ['%sum_sales'],
      },
    },
  ];

  it('should default to row_limit mode with single query', () => {
    const { queries } = buildQuery(baseFormDataWithPercents);
    const [mainQuery] = queries;

    expect(queries).toHaveLength(1);
    expect(mainQuery.metrics).toEqual(['count', 'sum_sales']);
    expect(mainQuery.post_processing).toEqual(expectedContributionRules);
  });

  it('should create extra query in all_records mode', () => {
    const allRecordsFormData = {
      ...baseFormDataWithPercents,
      percent_metric_calculation: 'all_records',
    };

    const { queries } = buildQuery(allRecordsFormData);

    expect(queries).toHaveLength(2);

    // The main query keeps the same contribution rule as in row_limit mode.
    expect(queries[0].post_processing).toEqual(expectedContributionRules);

    // The extra query asks only for the percent metrics, ungrouped and
    // unlimited.
    const expectedTotalsQuery = {
      columns: [],
      metrics: ['sum_sales'],
      post_processing: [],
      row_limit: 0,
      row_offset: 0,
      orderby: [],
      is_timeseries: false,
    };
    expect(queries[1]).toMatchObject(expectedTotalsQuery);
  });

  it('should work with show_totals in all_records mode', () => {
    const allRecordsWithTotals = {
      ...baseFormDataWithPercents,
      percent_metric_calculation: 'all_records',
      show_totals: true,
    };

    const { queries } = buildQuery(allRecordsWithTotals);

    // Main query, then the percent-metric totals query, then the
    // show_totals query.
    expect(queries).toHaveLength(3);
    expect(queries[1].metrics).toEqual(['sum_sales']);
    expect(queries[2].metrics).toEqual(['count', 'sum_sales']);
  });

  it('should handle empty percent_metrics in all_records mode', () => {
    const noPercentMetricsFormData = {
      ...basicFormData,
      query_mode: QueryMode.Aggregate,
      metrics: ['count'],
      percent_metrics: [],
      percent_metric_calculation: 'all_records',
      groupby: ['category'],
    };

    const { queries } = buildQuery(noPercentMetricsFormData);

    // With no percent metrics there is no extra query and no contribution
    // rule.
    expect(queries).toHaveLength(1);
    expect(queries[0].post_processing).toEqual([]);
  });
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -741,6 +741,47 @@ class QueryContextProcessor:
|
||||
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
def ensure_totals_available(self) -> None:
    """
    Inject dataset-wide totals into every ``contribution`` post-processing rule.

    Scans the queries of the query context for two kinds of query:

    * queries with at least one ``contribution`` post-processing rule, and
    * "totals" queries: no columns, at least one metric, no post-processing.

    When both kinds are present, the first totals query is executed with its
    row limit removed. Every numeric column of its result is summed, and the
    resulting mapping is stored as ``options["contribution_totals"]`` on each
    contribution rule. Otherwise the method does nothing.

    Side effects: sets ``row_limit = None`` on the chosen totals query and
    mutates the ``post_processing`` rules of the affected queries in place.

    NOTE(review): any query with no columns, some metrics and no
    post-processing is treated as a totals query. Presumably that also
    matches the Table chart's ``show_totals`` query, in which case totals
    would be injected even when the chart is in ``row_limit`` mode. Confirm
    with the caller.
    """
    queries = self._query_context.queries

    def contribution_rules(query):
        # Tolerate a missing or None post_processing attribute.
        return [
            rule
            for rule in getattr(query, "post_processing", None) or []
            if rule.get("operation") == "contribution"
        ]

    queries_needing_totals = [
        query for query in queries if contribution_rules(query)
    ]
    totals_queries = [
        query
        for query in queries
        if not query.columns
        and query.metrics
        and not (getattr(query, "post_processing", None) or [])
    ]

    if not queries_needing_totals or not totals_queries:
        return

    totals_query = totals_queries[0]
    # The totals must cover the whole dataset, so drop any row limit.
    totals_query.row_limit = None

    df = self._query_context.get_query_result(totals_query).df

    totals = {}
    for col in df.columns:
        # Only bool/int/uint/float/complex columns can be summed meaningfully.
        if df[col].dtype.kind not in "biufc":
            continue
        total = df[col].sum()
        # Store native Python numbers rather than numpy scalars, so the value
        # matches the ``dict[str, float]`` that ``contribution`` declares.
        totals[col] = total.item() if hasattr(total, "item") else total

    for query in queries_needing_totals:
        for rule in contribution_rules(query):
            # A rule may legitimately have no options yet; create them
            # instead of raising KeyError.
            options = rule.get("options")
            if options is None:
                options = rule["options"] = {}
            options["contribution_totals"] = totals
|
||||
|
||||
def get_payload(
|
||||
self,
|
||||
cache_query_context: bool | None = False,
|
||||
@@ -748,7 +789,8 @@ class QueryContextProcessor:
|
||||
) -> dict[str, Any]:
|
||||
"""Returns the query results with both metadata and data"""
|
||||
|
||||
# Get all the payloads from the QueryObjects
|
||||
self.ensure_totals_available()
|
||||
|
||||
query_results = [
|
||||
get_query_results(
|
||||
query_obj.result_type or self._query_context.result_type,
|
||||
@@ -758,6 +800,7 @@ class QueryContextProcessor:
|
||||
)
|
||||
for query_obj in self._query_context.queries
|
||||
]
|
||||
|
||||
return_value = {"queries": query_results}
|
||||
|
||||
if cache_query_context:
|
||||
|
||||
@@ -37,6 +37,7 @@ def contribution(
|
||||
columns: list[str] | None = None,
|
||||
time_shifts: list[str] | None = None,
|
||||
rename_columns: list[str] | None = None,
|
||||
contribution_totals: dict[str, float] | None = None,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Calculate cell contribution to row/column total for numeric columns.
|
||||
@@ -82,10 +83,19 @@ def contribution(
|
||||
numeric_df_view = numeric_df[actual_columns]
|
||||
|
||||
if orientation == PostProcessingContributionOrientation.COLUMN:
|
||||
numeric_df_view = numeric_df_view / numeric_df_view.values.sum(
|
||||
axis=0, keepdims=True
|
||||
)
|
||||
contribution_df[rename_columns] = numeric_df_view
|
||||
if contribution_totals:
|
||||
for i, col in enumerate(numeric_df_view.columns):
|
||||
total = contribution_totals.get(col)
|
||||
rename_col = rename_columns[i]
|
||||
if total is None or total == 0:
|
||||
contribution_df[rename_col] = 0
|
||||
else:
|
||||
contribution_df[rename_col] = numeric_df_view[col] / total
|
||||
else:
|
||||
numeric_df_view = numeric_df_view / numeric_df_view.values.sum(
|
||||
axis=0, keepdims=True
|
||||
)
|
||||
contribution_df[rename_columns] = numeric_df_view
|
||||
return contribution_df
|
||||
|
||||
result = get_column_groups(numeric_df_view, time_shifts, rename_columns)
|
||||
|
||||
Reference in New Issue
Block a user