diff --git a/superset/charts/client_processing.py b/superset/charts/client_processing.py index 5a873ab8986..a66531f7b5a 100644 --- a/superset/charts/client_processing.py +++ b/superset/charts/client_processing.py @@ -24,6 +24,7 @@ data they would see on Explore. In order to do that, we reproduce the post-processing in Python for these chart types. """ +import logging from io import StringIO from typing import Any, Optional, TYPE_CHECKING, Union @@ -44,6 +45,9 @@ if TYPE_CHECKING: from superset.models.sql_lab import Query +logger = logging.getLogger(__name__) + + def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]: """ Sort columns when combining metrics. @@ -182,11 +186,20 @@ def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-s for level in range(df.index.nlevels): subgroups = {group[:level] for group in groups} for subgroup in subgroups: - slice_ = df.index.get_loc(subgroup) + try: + slice_ = df.index.get_loc(subgroup) + except Exception: # pylint: disable=broad-except + logger.exception( + "Error getting location for subgroup %s from %s", + subgroup, + groups, + ) + raise + subtotal = pivot_v2_aggfunc_map[aggfunc]( df.iloc[slice_, :].apply(pd.to_numeric, errors="coerce"), axis=0 ) - depth = df.index.nlevels - len(subgroup) - 1 + depth = groups.nlevels - len(subgroup) - 1 total = metric_name if level == 0 else __("Subtotal") subtotal.name = tuple([*subgroup, total, *([""] * depth)]) # noqa: C409 # insert row after subgroup diff --git a/tests/unit_tests/charts/test_client_processing.py b/tests/unit_tests/charts/test_client_processing.py index e35229554b5..b38513a2e3c 100644 --- a/tests/unit_tests/charts/test_client_processing.py +++ b/tests/unit_tests/charts/test_client_processing.py @@ -987,8 +987,12 @@ def test_pivot_df_complex(): show_columns_total=False, apply_metrics_on_rows=False, ) + + # Sort the pivoted DataFrame to ensure deterministic output + pivoted_sorted = pivoted.sort_index() + assert ( - pivoted.to_markdown() + pivoted_sorted.to_markdown() == """ | | ('SUM(num)', 'CA') | ('SUM(num)', 'FL') | ('MAX(num)', 'CA') | ('MAX(num)', 'FL') | |:-------------------|---------------------:|---------------------:|---------------------:|---------------------:| @@ -2499,3 +2503,62 @@ def test_apply_client_processing_verbose_map(session: Session): } ] } + + +def test_pivot_multi_level_index(): + """ + Pivot table with multi-level indexing. + """ + arrays = [ + ["Region1", "Region1", "Region1", "Region2", "Region2", "Region2"], + ["State1", "State1", "State2", "State3", "State3", "State4"], + ["City1", "City2", "City3", "City4", "City5", "City6"], + ] + index = pd.MultiIndex.from_tuples( + list(zip(*arrays, strict=False)), + names=["Region", "State", "City"], + ) + + data = { + "Metric1": [10, 20, 30, 40, 50, 60], + "Metric2": [5, 10, 15, 20, 25, 30], + "Metric3": [None, None, None, None, None, None], + } + df = pd.DataFrame(data, index=index) + + pivoted = pivot_df( + df, + rows=["Region", "State", "City"], + columns=[], + metrics=["Metric1", "Metric2", "Metric3"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=True, + apply_metrics_on_rows=False, + ) + + # Sort the pivoted DataFrame to ensure deterministic output + pivoted_sorted = pivoted.sort_index() + + assert ( + pivoted_sorted.to_markdown() + == """ +| | ('Metric1',) | ('Metric2',) | ('Metric3',) | +|:----------------------------------|---------------:|---------------:|---------------:| +| ('Region1', 'State1', 'City1') | 10 | 5 | nan | +| ('Region1', 'State1', 'City2') | 20 | 10 | nan | +| ('Region1', 'State1', 'Subtotal') | 30 | 15 | 0 | +| ('Region1', 'State2', 'City3') | 30 | 15 | nan | +| ('Region1', 'State2', 'Subtotal') | 30 | 15 | 0 | +| ('Region1', 'Subtotal', '') | 60 | 30 | 0 | +| ('Region2', 'State3', 'City4') | 40 | 20 | nan | +| ('Region2', 'State3', 'City5') | 50 | 25 | nan | +| ('Region2', 'State3', 'Subtotal') | 90 | 45 | 0 | +| ('Region2', 'State4', 'City6') | 60 | 30 | nan | +| ('Region2', 'State4', 'Subtotal') | 60 | 30 | 0 | +| ('Region2', 'Subtotal', '') | 150 | 75 | 0 | +| ('Total (Sum)', '', '') | 210 | 105 | 0 | + """.strip() + )