feat(advanced analysis): support MultiIndex column in post processing stage (#19116)

Yongjie Zhao
2022-03-23 13:46:28 +08:00
committed by Ville Brofeldt
parent f8a92de75c
commit 9bc76337cf
55 changed files with 1272 additions and 772 deletions

superset/utils/pandas_postprocessing/__init__.py

@@ -20,6 +20,7 @@ from superset.utils.pandas_postprocessing.compare import compare
from superset.utils.pandas_postprocessing.contribution import contribution
from superset.utils.pandas_postprocessing.cum import cum
from superset.utils.pandas_postprocessing.diff import diff
+from superset.utils.pandas_postprocessing.flatten import flatten
from superset.utils.pandas_postprocessing.geography import (
geodetic_parse,
geohash_decode,
@@ -49,5 +50,6 @@ __all__ = [
"rolling",
"select",
"sort",
"flatten",
"_flatten_column_after_pivot",
]

superset/utils/pandas_postprocessing/aggregate.py

@@ -35,7 +35,7 @@ def aggregate(
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
-:raises QueryObjectValidationError: If the request is incorrect
+:raises InvalidPostProcessingError: If the request is incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)

superset/utils/pandas_postprocessing/boxplot.py

@@ -20,7 +20,7 @@ import numpy as np
from flask_babel import gettext as _
from pandas import DataFrame, Series
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing.aggregate import aggregate
@@ -84,7 +84,7 @@ def boxplot(
or not isinstance(percentiles[1], (int, float))
or percentiles[0] >= percentiles[1]
):
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_(
"percentiles must be a list or tuple with two numeric values, "
"of which the first is lower than the second value"

superset/utils/pandas_postprocessing/compare.py

@@ -21,7 +21,7 @@ from flask_babel import gettext as _
from pandas import DataFrame
from superset.constants import PandasPostprocessingCompare
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import TIME_COMPARISION
from superset.utils.pandas_postprocessing.utils import validate_column_args
@@ -31,7 +31,7 @@ def compare( # pylint: disable=too-many-arguments
df: DataFrame,
source_columns: List[str],
compare_columns: List[str],
-compare_type: Optional[PandasPostprocessingCompare],
+compare_type: PandasPostprocessingCompare,
drop_original_columns: Optional[bool] = False,
precision: Optional[int] = 4,
) -> DataFrame:
@@ -46,31 +46,38 @@ def compare( # pylint: disable=too-many-arguments
compare columns.
:param precision: Round a change rate to a variable number of decimal places.
:return: DataFrame with compared columns.
-:raises QueryObjectValidationError: If the request is incorrect.
+:raises InvalidPostProcessingError: If the request is incorrect.
"""
if len(source_columns) != len(compare_columns):
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("`compare_columns` must have the same length as `source_columns`.")
)
if compare_type not in tuple(PandasPostprocessingCompare):
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("`compare_type` must be `difference`, `percentage` or `ratio`")
)
if len(source_columns) == 0:
return df
for s_col, c_col in zip(source_columns, compare_columns):
+s_df = df.loc[:, [s_col]]
+s_df.rename(columns={s_col: "__intermediate"}, inplace=True)
+c_df = df.loc[:, [c_col]]
+c_df.rename(columns={c_col: "__intermediate"}, inplace=True)
if compare_type == PandasPostprocessingCompare.DIFF:
-diff_series = df[s_col] - df[c_col]
+diff_df = c_df - s_df
elif compare_type == PandasPostprocessingCompare.PCT:
-diff_series = (
-((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision)
-)
+# https://en.wikipedia.org/wiki/Relative_change_and_difference#Percentage_change
+diff_df = ((c_df - s_df) / s_df).astype(float).round(precision)
else:
# compare_type == "ratio"
-diff_series = (df[s_col] / df[c_col]).astype(float).round(precision)
-diff_df = diff_series.to_frame(
-name=TIME_COMPARISION.join([compare_type, s_col, c_col])
-)
+diff_df = (c_df / s_df).astype(float).round(precision)
+diff_df.rename(
+columns={
+"__intermediate": TIME_COMPARISION.join([compare_type, s_col, c_col])
+},
+inplace=True,
+)
+df = pd.concat([df, diff_df], axis=1)
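The "__intermediate" rename above is what makes the arithmetic MultiIndex-safe: pandas aligns DataFrame operands on their column labels, so giving both operands the same single label lets c_df - s_df compute element-wise even when the original labels are tuples. A minimal standalone sketch of the idea (the column names sales_prev and sales are made up for illustration):

import pandas as pd

df = pd.DataFrame({"sales_prev": [10.0, 20.0], "sales": [12.0, 25.0]})
s_df = df.loc[:, ["sales_prev"]].rename(columns={"sales_prev": "__intermediate"})
c_df = df.loc[:, ["sales"]].rename(columns={"sales": "__intermediate"})
# Without the shared label this subtraction would align on the column union
# and produce all-NaN columns; with it, the values line up row by row.
diff_df = c_df - s_df  # __intermediate: [2.0, 5.0]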

superset/utils/pandas_postprocessing/contribution.py

@@ -20,7 +20,7 @@ from typing import List, Optional
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing.utils import validate_column_args
@@ -55,7 +55,7 @@ def contribution(
numeric_columns = numeric_df.columns.tolist()
for col in columns:
if col not in numeric_columns:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_(
'Column "%(column)s" is not numeric or does not '
"exists in the query results.",
@@ -65,7 +65,7 @@ def contribution(
columns = columns or numeric_df.columns
rename_columns = rename_columns or columns
if len(rename_columns) != len(columns):
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("`rename_columns` must have the same length as `columns`.")
)
# limit to selected columns

superset/utils/pandas_postprocessing/cum.py

@@ -14,27 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional
+from typing import Dict
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_append_columns,
-_flatten_column_after_pivot,
ALLOWLIST_CUMULATIVE_FUNCTIONS,
validate_column_args,
)
@validate_column_args("columns")
-def cum(
-df: DataFrame,
-operator: str,
-columns: Optional[Dict[str, str]] = None,
-is_pivot_df: bool = False,
-) -> DataFrame:
+def cum(df: DataFrame, operator: str, columns: Dict[str, str],) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.
@@ -45,29 +39,16 @@ def cum(
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
-:param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with cumulated columns
"""
-columns = columns or {}
-if is_pivot_df:
-df_cum = df
-else:
-df_cum = df[columns.keys()]
+df_cum = df.loc[:, columns.keys()]
operation = "cum" + operator
if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
df_cum, operation
):
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
-if is_pivot_df:
-df_cum = getattr(df_cum, operation)()
-agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
-agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
-df_cum.columns = [
-_flatten_column_after_pivot(col, agg) for col in df_cum.columns
-]
-df_cum.reset_index(level=0, inplace=True)
-else:
-df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
+df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
return df_cum
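With is_pivot_df gone, cum always selects its source columns by label and leaves label flattening to the new flatten stage. A hedged usage sketch of the new signature (assuming a Superset checkout at this commit):

from pandas import DataFrame
from superset.utils.pandas_postprocessing import cum

df = DataFrame({"y": [1, 2, 3]})
# Writes the running sum into a new column "y2" and leaves "y" unchanged.
df = cum(df, operator="sum", columns={"y": "y2"})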

superset/utils/pandas_postprocessing/diff.py

@@ -44,7 +44,7 @@ def diff(
:param periods: periods to shift for calculating difference.
:param axis: 0 for row, 1 for column. default 0.
:return: DataFrame with diffed columns
-:raises QueryObjectValidationError: If the request is incorrect
+:raises InvalidPostProcessingError: If the request is incorrect
"""
df_diff = df[columns.keys()]
df_diff = df_diff.diff(periods=periods, axis=axis)

superset/utils/pandas_postprocessing/flatten.py (new file)

@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
from superset.utils.pandas_postprocessing.utils import (
_is_multi_index_on_columns,
FLAT_COLUMN_SEPARATOR,
)
def flatten(df: pd.DataFrame, reset_index: bool = True,) -> pd.DataFrame:
"""
Convert N-dimensional DataFrame to a flat DataFrame
:param df: N-dimensional DataFrame.
:param reset_index: Convert index to column when df.index isn't RangeIndex
:return: a flat DataFrame
Examples
-----------
Convert DatetimeIndex into columns.
>>> index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03",])
>>> index.name = "__timestamp"
>>> df = pd.DataFrame(index=index, data={"metric": [1, 2, 3]})
>>> df
metric
__timestamp
2021-01-01 1
2021-01-02 2
2021-01-03 3
>>> df = flatten(df)
>>> df
__timestamp metric
0 2021-01-01 1
1 2021-01-02 2
2 2021-01-03 3
Convert DatetimeIndex and MultiIndex into columns
>>> iterables = [["foo", "bar"], ["one", "two"]]
>>> columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
>>> df = pd.DataFrame(index=index, columns=columns, data=1)
>>> df
level1 foo bar
level2 one two one two
__timestamp
2021-01-01 1 1 1 1
2021-01-02 1 1 1 1
2021-01-03 1 1 1 1
>>> flatten(df)
__timestamp foo, one foo, two bar, one bar, two
0 2021-01-01 1 1 1 1
1 2021-01-02 1 1 1 1
2 2021-01-03 1 1 1 1
"""
if _is_multi_index_on_columns(df):
# every cell should be converted to string
df.columns = [
FLAT_COLUMN_SEPARATOR.join([str(cell) for cell in series])
for series in df.columns.to_flat_index()
]
if reset_index and not isinstance(df.index, pd.RangeIndex):
df = df.reset_index(level=0)
return df

superset/utils/pandas_postprocessing/geography.py

@@ -21,7 +21,7 @@ from flask_babel import gettext as _
from geopy.point import Point
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import _append_columns
@@ -46,7 +46,7 @@ def geohash_decode(
df, lonlat_df, {"latitude": latitude, "longitude": longitude}
)
except ValueError as ex:
-raise QueryObjectValidationError(_("Invalid geohash string")) from ex
+raise InvalidPostProcessingError(_("Invalid geohash string")) from ex
def geohash_encode(
@@ -69,7 +69,7 @@ def geohash_encode(
)
return _append_columns(df, encode_df, {"geohash": geohash})
except ValueError as ex:
-raise QueryObjectValidationError(_("Invalid longitude/latitude")) from ex
+raise InvalidPostProcessingError(_("Invalid longitude/latitude")) from ex
def geodetic_parse(
@@ -111,4 +111,4 @@ def geodetic_parse(
columns["altitude"] = altitude
return _append_columns(df, geodetic_df, columns)
except ValueError as ex:
-raise QueryObjectValidationError(_("Invalid geodetic string")) from ex
+raise InvalidPostProcessingError(_("Invalid geodetic string")) from ex

superset/utils/pandas_postprocessing/pivot.py

@@ -20,7 +20,7 @@ from flask_babel import gettext as _
from pandas import DataFrame
from superset.constants import NULL_STRING, PandasAxis
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_flatten_column_after_pivot,
_get_aggregate_funcs,
@@ -64,14 +64,14 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
:param flatten_columns: Convert column names to strings
:param reset_index: Convert index to column
:return: A pivot table
-:raises QueryObjectValidationError: If the request is incorrect
+:raises InvalidPostProcessingError: If the request is incorrect
"""
if not index:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Pivot operation requires at least one index")
)
if not aggregates:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Pivot operation must include at least one aggregate")
)

superset/utils/pandas_postprocessing/prophet.py

@@ -20,7 +20,7 @@ from typing import Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP
@@ -58,7 +58,7 @@ def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
prophet_logger.setLevel(logging.CRITICAL)
prophet_logger.setLevel(logging.NOTSET)
except ModuleNotFoundError as ex:
-raise QueryObjectValidationError(_("`prophet` package not installed")) from ex
+raise InvalidPostProcessingError(_("`prophet` package not installed")) from ex
model = Prophet(
interval_width=confidence_interval,
yearly_seasonality=yearly_seasonality,
@@ -111,24 +111,24 @@ def prophet( # pylint: disable=too-many-arguments
index = index or DTTM_ALIAS
# validate inputs
if not time_grain:
-raise QueryObjectValidationError(_("Time grain missing"))
+raise InvalidPostProcessingError(_("Time grain missing"))
if time_grain not in PROPHET_TIME_GRAIN_MAP:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
)
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marshmallow schema not being able to handle
# union types
if not isinstance(periods, int) or periods < 0:
-raise QueryObjectValidationError(_("Periods must be a whole number"))
+raise InvalidPostProcessingError(_("Periods must be a whole number"))
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Confidence interval must be between 0 and 1 (exclusive)")
)
if index not in df.columns:
-raise QueryObjectValidationError(_("DataFrame must include temporal column"))
+raise InvalidPostProcessingError(_("DataFrame must include temporal column"))
if len(df.columns) < 2:
-raise QueryObjectValidationError(_("DataFrame must include at least one series"))
+raise InvalidPostProcessingError(_("DataFrame must include at least one series"))
target_df = DataFrame()
for column in [column for column in df.columns if column != index]:

superset/utils/pandas_postprocessing/resample.py

@@ -14,48 +14,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
-from pandas import DataFrame
+import pandas as pd
+from flask_babel import gettext as _
-from superset.utils.pandas_postprocessing.utils import validate_column_args
+from superset.exceptions import InvalidPostProcessingError
-@validate_column_args("groupby_columns")
-def resample(  # pylint: disable=too-many-arguments
-df: DataFrame,
+def resample(
+df: pd.DataFrame,
rule: str,
method: str,
-time_column: str,
-groupby_columns: Optional[Tuple[Optional[str], ...]] = None,
fill_value: Optional[Union[float, int]] = None,
-) -> DataFrame:
+) -> pd.DataFrame:
"""
support upsampling in resample
:param df: DataFrame to resample.
:param rule: The offset string representing target conversion.
:param method: How to fill the NaN value after resample.
-:param time_column: existing columns in DataFrame.
-:param groupby_columns: columns except time_column in dataframe
:param fill_value: What value to fill missing values with.
:return: DataFrame after resample
-:raises QueryObjectValidationError: If the request is incorrect
+:raises InvalidPostProcessingError: If the request is incorrect
"""
+if not isinstance(df.index, pd.DatetimeIndex):
+raise InvalidPostProcessingError(_("Resample operation requires DatetimeIndex"))
-def _upsampling(_df: DataFrame) -> DataFrame:
-_df = _df.set_index(time_column)
-if method == "asfreq" and fill_value is not None:
-return _df.resample(rule).asfreq(fill_value=fill_value)
-return getattr(_df.resample(rule), method)()
-if groupby_columns:
-df = (
-df.set_index(keys=list(groupby_columns))
-.groupby(by=list(groupby_columns))
-.apply(_upsampling)
-)
-df = df.reset_index().set_index(time_column).sort_index()
+if method == "asfreq" and fill_value is not None:
+_df = df.resample(rule).asfreq(fill_value=fill_value)
else:
-df = _upsampling(df)
-return df.reset_index()
+_df = getattr(df.resample(rule), method)()
+return _df
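The rewritten resample is now a thin wrapper over pandas' DatetimeIndex-based resampling, which is why the DatetimeIndex guard above is sufficient validation. A standalone pandas sketch of the asfreq-plus-fill_value path (dates made up for illustration):

import pandas as pd

index = pd.to_datetime(["2021-01-01", "2021-01-03"])
df = pd.DataFrame(index=index, data={"metric": [1, 3]})
# Upsample to daily frequency; the missing 2021-01-02 row is filled with 0.
upsampled = df.resample("1D").asfreq(fill_value=0)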

superset/utils/pandas_postprocessing/rolling.py

@@ -19,10 +19,9 @@ from typing import Any, Dict, Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_append_columns,
-_flatten_column_after_pivot,
DENYLIST_ROLLING_FUNCTIONS,
validate_column_args,
)
@@ -32,13 +31,12 @@ from superset.utils.pandas_postprocessing.utils import (
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
rolling_type: str,
-columns: Optional[Dict[str, str]] = None,
+columns: Dict[str, str],
window: Optional[int] = None,
rolling_type_options: Optional[Dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
-is_pivot_df: bool = False,
) -> DataFrame:
"""
Apply a rolling window on the dataset. See the Pandas docs for further details:
@@ -58,21 +56,17 @@ def rolling( # pylint: disable=too-many-arguments
:param win_type: Type of window function.
:param min_periods: The minimum amount of periods required for a row to be included
in the result set.
-:param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with the rolling columns
-:raises QueryObjectValidationError: If the request is incorrect
+:raises InvalidPostProcessingError: If the request is incorrect
"""
rolling_type_options = rolling_type_options or {}
-columns = columns or {}
-if is_pivot_df:
-df_rolling = df
-else:
-df_rolling = df[columns.keys()]
+df_rolling = df.loc[:, columns.keys()]
kwargs: Dict[str, Union[str, int]] = {}
if window is None:
-raise QueryObjectValidationError(_("Undefined window for rolling operation"))
+raise InvalidPostProcessingError(_("Undefined window for rolling operation"))
if window == 0:
-raise QueryObjectValidationError(_("Window must be > 0"))
+raise InvalidPostProcessingError(_("Window must be > 0"))
kwargs["window"] = window
if min_periods is not None:
@@ -86,13 +80,13 @@ def rolling( # pylint: disable=too-many-arguments
if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr(
df_rolling, rolling_type
):
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Invalid rolling_type: %(type)s", type=rolling_type)
)
try:
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
except TypeError as ex:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_(
"Invalid options for %(rolling_type)s: %(options)s",
rolling_type=rolling_type,
@@ -100,15 +94,7 @@ def rolling( # pylint: disable=too-many-arguments
)
) from ex
-if is_pivot_df:
-agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
-agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
-df_rolling.columns = [
-_flatten_column_after_pivot(col, agg) for col in df_rolling.columns
-]
-df_rolling.reset_index(level=0, inplace=True)
-else:
-df_rolling = _append_columns(df, df_rolling, columns)
+df_rolling = _append_columns(df, df_rolling, columns)
if min_periods:
df_rolling = df_rolling[min_periods:]
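As with cum, rolling now requires an explicit columns mapping and always appends results through _append_columns. A hedged usage sketch of the new signature (assuming a Superset checkout at this commit):

from pandas import DataFrame
from superset.utils.pandas_postprocessing import rolling

df = DataFrame({"y": [1.0, 2.0, 3.0, 4.0]})
# Two-row rolling mean written back over "y"; the first row becomes NaN
# because the default min_periods equals the window size.
df = rolling(df, rolling_type="mean", columns={"y": "y"}, window=2)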

superset/utils/pandas_postprocessing/select.py

@@ -42,7 +42,7 @@ def select(
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
-:raises QueryObjectValidationError: If the request is incorrect
+:raises InvalidPostProcessingError: If the request is incorrect
"""
df_select = df.copy(deep=False)
if columns:

superset/utils/pandas_postprocessing/sort.py

@@ -30,6 +30,6 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
:param columns: columns by which to sort. The key specifies the column name,
value specifies if sorting in ascending order.
:return: Sorted DataFrame
-:raises QueryObjectValidationError: If the request is incorrect
+:raises InvalidPostProcessingError: If the request is incorrect
"""
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))

superset/utils/pandas_postprocessing/utils.py

@@ -18,10 +18,11 @@ from functools import partial
from typing import Any, Callable, Dict, Tuple, Union
import numpy as np
+import pandas as pd
from flask_babel import gettext as _
from pandas import DataFrame, NamedAgg, Timestamp
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
NUMPY_FUNCTIONS = {
"average": np.average,
@@ -91,6 +92,8 @@ PROPHET_TIME_GRAIN_MAP = {
"P1W/1970-01-04T00:00:00Z": "W",
}
+FLAT_COLUMN_SEPARATOR = ", "
def _flatten_column_after_pivot(
column: Union[float, Timestamp, str, Tuple[str, ...]],
@@ -113,21 +116,26 @@ def _flatten_column_after_pivot(
# drop aggregate for single aggregate pivots with multiple groupings
# from column name (aggregates always come first in column name)
column = column[1:]
return ", ".join([str(col) for col in column])
return FLAT_COLUMN_SEPARATOR.join([str(col) for col in column])
+def _is_multi_index_on_columns(df: DataFrame) -> bool:
+return isinstance(df.columns, pd.MultiIndex)
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
if options.get("is_pivot_df"):
# skip validation when pivot Dataframe
return func(df, **options)
columns = df.columns.tolist()
if _is_multi_index_on_columns(df):
# MultiIndex column validate first level
columns = df.columns.get_level_values(0)
else:
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options.get(name) or []
):
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Referenced columns not available in DataFrame.")
)
return func(df, **options)
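Validating only the first level matches how pivoted frames are shaped: the metric names that post-processing options refer to sit on level 0 of the MultiIndex, with group values on the levels below. A standalone sketch of that check (the name "metric" is made up for illustration):

import pandas as pd

columns = pd.MultiIndex.from_product([["metric"], ["one", "two"]])
df = pd.DataFrame(columns=columns)
# Level 0 carries the metric names, so that is where options are checked.
assert "metric" in df.columns.get_level_values(0)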
@@ -152,14 +160,14 @@ def _get_aggregate_funcs(
for name, agg_obj in aggregates.items():
column = agg_obj.get("column", name)
if column not in df:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_(
"Column referenced by aggregate is undefined: %(column)s",
column=column,
)
)
if "operator" not in agg_obj:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Operator undefined for aggregator: %(name)s", name=name,)
)
operator = agg_obj["operator"]
@@ -168,7 +176,7 @@ def _get_aggregate_funcs(
else:
func = NUMPY_FUNCTIONS.get(operator)
if not func:
-raise QueryObjectValidationError(
+raise InvalidPostProcessingError(
_("Invalid numpy function: %(operator)s", operator=operator,)
)
options = agg_obj.get("options", {})
@@ -186,6 +194,8 @@ def _append_columns(
assign method, which overwrites the original column in `base_df` if the column
already exists, and appends the column if the name is not defined.
+Note that this is a memory-intensive operation.
:param base_df: DataFrame which to use as the base
:param append_df: DataFrame from which to select data.
:param columns: columns on which to append, mapping source column to
@@ -196,6 +206,10 @@ def _append_columns(
in `base_df` unchanged.
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
-return base_df.assign(
-**{target: append_df[source] for source, target in columns.items()}
-)
+if all(key == value for key, value in columns.items()):
+# make sure to return a new DataFrame instead of changing the `base_df`.
+_base_df = base_df.copy()
+_base_df.loc[:, columns.keys()] = append_df
+return _base_df
+append_df = append_df.rename(columns=columns)
+return pd.concat([base_df, append_df], axis="columns")
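The switch from DataFrame.assign to rename plus pd.concat matters for MultiIndex targets: assign takes keyword arguments, so target labels must be strings, whereas concat happily carries tuple labels. A standalone sketch contrasting the two branches (data made up for illustration):

import pandas as pd

base = pd.DataFrame({"y": [1, 2]})
append = pd.DataFrame({"y": [10, 20]})
# Identity mapping: overwrite "y" on a copy so base itself is untouched.
overwritten = base.copy()
overwritten.loc[:, ["y"]] = append
# Non-identity mapping: rename, then concatenate as a new column.
combined = pd.concat([base, append.rename(columns={"y": "y2"})], axis="columns")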