mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat(advanced analysis): support MultiIndex column in post processing stage (#19116)
This commit is contained in:
@@ -20,7 +20,7 @@ from typing import Optional, Union
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import DTTM_ALIAS
|
||||
from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP
|
||||
|
||||
@@ -58,7 +58,7 @@ def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
|
||||
prophet_logger.setLevel(logging.CRITICAL)
|
||||
prophet_logger.setLevel(logging.NOTSET)
|
||||
except ModuleNotFoundError as ex:
|
||||
raise QueryObjectValidationError(_("`prophet` package not installed")) from ex
|
||||
raise InvalidPostProcessingError(_("`prophet` package not installed")) from ex
|
||||
model = Prophet(
|
||||
interval_width=confidence_interval,
|
||||
yearly_seasonality=yearly_seasonality,
|
||||
@@ -111,24 +111,24 @@ def prophet( # pylint: disable=too-many-arguments
|
||||
index = index or DTTM_ALIAS
|
||||
# validate inputs
|
||||
if not time_grain:
|
||||
raise QueryObjectValidationError(_("Time grain missing"))
|
||||
raise InvalidPostProcessingError(_("Time grain missing"))
|
||||
if time_grain not in PROPHET_TIME_GRAIN_MAP:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
|
||||
)
|
||||
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
|
||||
# check type at runtime due to marshmallow schema not being able to handle
|
||||
# union types
|
||||
if not isinstance(periods, int) or periods < 0:
|
||||
raise QueryObjectValidationError(_("Periods must be a whole number"))
|
||||
raise InvalidPostProcessingError(_("Periods must be a whole number"))
|
||||
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Confidence interval must be between 0 and 1 (exclusive)")
|
||||
)
|
||||
if index not in df.columns:
|
||||
raise QueryObjectValidationError(_("DataFrame must include temporal column"))
|
||||
raise InvalidPostProcessingError(_("DataFrame must include temporal column"))
|
||||
if len(df.columns) < 2:
|
||||
raise QueryObjectValidationError(_("DataFrame include at least one series"))
|
||||
raise InvalidPostProcessingError(_("DataFrame include at least one series"))
|
||||
|
||||
target_df = DataFrame()
|
||||
for column in [column for column in df.columns if column != index]:
|
||||
|
||||
Reference in New Issue
Block a user