feat(advanced analysis): support MultiIndex column in post processing stage (#19116)

This commit is contained in:
Yongjie Zhao
2022-03-23 13:46:28 +08:00
committed by GitHub
parent 6083545e86
commit 375c03e084
55 changed files with 1267 additions and 772 deletions

View File

@@ -20,7 +20,7 @@ from typing import Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP
@@ -58,7 +58,7 @@ def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
prophet_logger.setLevel(logging.CRITICAL)
prophet_logger.setLevel(logging.NOTSET)
except ModuleNotFoundError as ex:
raise QueryObjectValidationError(_("`prophet` package not installed")) from ex
raise InvalidPostProcessingError(_("`prophet` package not installed")) from ex
model = Prophet(
interval_width=confidence_interval,
yearly_seasonality=yearly_seasonality,
@@ -111,24 +111,24 @@ def prophet( # pylint: disable=too-many-arguments
index = index or DTTM_ALIAS
# validate inputs
if not time_grain:
raise QueryObjectValidationError(_("Time grain missing"))
raise InvalidPostProcessingError(_("Time grain missing"))
if time_grain not in PROPHET_TIME_GRAIN_MAP:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
)
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marhsmallow schema not being able to handle
# union types
if not isinstance(periods, int) or periods < 0:
raise QueryObjectValidationError(_("Periods must be a whole number"))
raise InvalidPostProcessingError(_("Periods must be a whole number"))
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Confidence interval must be between 0 and 1 (exclusive)")
)
if index not in df.columns:
raise QueryObjectValidationError(_("DataFrame must include temporal column"))
raise InvalidPostProcessingError(_("DataFrame must include temporal column"))
if len(df.columns) < 2:
raise QueryObjectValidationError(_("DataFrame include at least one series"))
raise InvalidPostProcessingError(_("DataFrame include at least one series"))
target_df = DataFrame()
for column in [column for column in df.columns if column != index]: