feat(chart-data-api): make pivoted columns flattenable (#10255)

* feat(chart-data-api): make pivoted columns flattenable

* Linting + improve tests
This commit is contained in:
Ville Brofeldt
2020-07-08 13:35:53 +03:00
committed by GitHub
parent 4252770d50
commit baeacc3c56
3 changed files with 134 additions and 21 deletions

View File

@@ -72,13 +72,38 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
)
def _flatten_column_after_pivot(
column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
) -> str:
"""
Function for flattening column names into a single string. This step is necessary
to be able to properly serialize a DataFrame. If the column is a string, return
element unchanged. For multi-element columns, join column elements with a comma,
with the exception of pivots made with a single aggregate, in which case the
aggregate column name is omitted.
:param column: single element from `DataFrame.columns`
:param aggregates: aggregates
:return:
"""
if isinstance(column, str):
return column
if len(column) == 1:
return column[0]
if len(aggregates) == 1 and len(column) > 1:
# drop aggregate for single aggregate pivots with multiple groupings
# from column name (aggregates always come first in column name)
column = column[1:]
return ", ".join(column)
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options[name]
elem in columns for elem in options.get(name) or []
):
raise QueryObjectValidationError(
_("Referenced columns not available in DataFrame.")
@@ -154,14 +179,15 @@ def _append_columns(
def pivot( # pylint: disable=too-many-arguments
df: DataFrame,
index: List[str],
columns: List[str],
aggregates: Dict[str, Dict[str, Any]],
columns: Optional[List[str]] = None,
metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = None,
drop_missing_columns: Optional[bool] = True,
combine_value_with_metric: bool = False,
marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None,
flatten_columns: bool = True,
) -> DataFrame:
"""
Perform a pivot operation on a DataFrame.
@@ -179,6 +205,7 @@ def pivot( # pylint: disable=too-many-arguments
:param marginal_distributions: Add totals for row/column. Default to False
:param marginal_distribution_name: Name of row/column with marginal distribution.
Default to 'All'.
:param flatten_columns: Convert column names to strings
:return: A pivot table
:raises ChartDataValidationError: If the request is incorrect
"""
@@ -186,10 +213,6 @@ def pivot( # pylint: disable=too-many-arguments
raise QueryObjectValidationError(
_("Pivot operation requires at least one index")
)
if not columns:
raise QueryObjectValidationError(
_("Pivot operation requires at least one column")
)
if not aggregates:
raise QueryObjectValidationError(
_("Pivot operation must include at least one aggregate")
@@ -218,6 +241,13 @@ def pivot( # pylint: disable=too-many-arguments
if combine_value_with_metric:
df = df.stack(0).unstack()
# Make index regular column
if flatten_columns:
df.columns = [
_flatten_column_after_pivot(col, aggregates) for col in df.columns
]
# return index as regular column
df.reset_index(level=0, inplace=True)
return df