diff --git a/superset/common/query_context.py b/superset/common/query_context.py index a04e3944603..400c4a95038 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -22,10 +22,7 @@ from typing import Any, ClassVar, TYPE_CHECKING import pandas as pd from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType -from superset.common.query_context_processor import ( - CachedTimeOffset, - QueryContextProcessor, -) +from superset.common.query_context_processor import QueryContextProcessor from superset.common.query_object import QueryObject from superset.models.slice import Slice from superset.utils.core import GenericDataType @@ -128,12 +125,5 @@ class QueryContext: def get_query_result(self, query_object: QueryObject) -> QueryResult: return self._processor.get_query_result(query_object) - def processing_time_offsets( - self, - df: pd.DataFrame, - query_object: QueryObject, - ) -> CachedTimeOffset: - return self._processor.processing_time_offsets(df, query_object) - def raise_for_access(self) -> None: self._processor.raise_for_access() diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 94f54097e23..8c488997f14 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -16,61 +16,42 @@ # under the License. from __future__ import annotations -import copy import logging import re -from datetime import datetime -from typing import Any, cast, ClassVar, TYPE_CHECKING, TypedDict +from typing import Any, cast, ClassVar, TYPE_CHECKING -import numpy as np import pandas as pd from flask import current_app from flask_babel import gettext as _ -from pandas import DateOffset from superset.common.chart_data import ChartDataResultFormat from superset.common.db_query_status import QueryStatus from superset.common.query_actions import get_query_results -from superset.common.utils import dataframe_utils from superset.common.utils.query_cache_manager import QueryCacheManager -from superset.common.utils.time_range_utils import ( - get_since_until_from_query_object, - get_since_until_from_time_range, -) +from superset.common.utils.time_range_utils import get_since_until_from_time_range from superset.connectors.sqla.models import BaseDatasource -from superset.constants import CACHE_DISABLED_TIMEOUT, CacheRegion, TimeGrain +from superset.constants import CACHE_DISABLED_TIMEOUT, CacheRegion from superset.daos.annotation_layer import AnnotationLayerDAO from superset.daos.chart import ChartDAO from superset.exceptions import ( - InvalidPostProcessingError, QueryObjectValidationError, SupersetException, ) -from superset.extensions import cache_manager, feature_flag_manager, security_manager +from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult -from superset.models.sql_lab import Query from superset.superset_typing import AdhocColumn, AdhocMetric from superset.utils import csv, excel from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.core import ( DatasourceType, - DateColumn, DTTM_ALIAS, error_msg_from_exception, - FilterOperator, GenericDataType, - get_base_axis_labels, get_column_names_from_columns, get_column_names_from_metrics, - get_metric_names, - get_x_axis_label, is_adhoc_column, is_adhoc_metric, - normalize_dttm_col, - QueryObjectFilterClause, - TIME_COMPARISON, ) -from superset.utils.date_parser import get_past_or_future, normalize_time_delta from superset.utils.pandas_postprocessing.utils import unescape_separator from superset.views.utils import get_viz from superset.viz import viz_types @@ -81,33 +62,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Offset join column suffix used for joining offset results -OFFSET_JOIN_COLUMN_SUFFIX = "__offset_join_column_" - -# This only includes time grains that may influence -# the temporal column used for joining offset results. -# Given that we don't allow time shifts smaller than a day, -# we don't need to include smaller time grains aggregations. -AGGREGATED_JOIN_GRAINS = { - TimeGrain.WEEK, - TimeGrain.WEEK_STARTING_SUNDAY, - TimeGrain.WEEK_STARTING_MONDAY, - TimeGrain.WEEK_ENDING_SATURDAY, - TimeGrain.WEEK_ENDING_SUNDAY, - TimeGrain.MONTH, - TimeGrain.QUARTER, - TimeGrain.YEAR, -} - -# Right suffix used for joining offset results -R_SUFFIX = "__right_suffix" - - -class CachedTimeOffset(TypedDict): - df: pd.DataFrame - queries: list[str] - cache_keys: list[str | None] - class QueryContextProcessor: """ @@ -266,726 +220,14 @@ class QueryContextProcessor: return cache_key def get_query_result(self, query_object: QueryObject) -> QueryResult: - """Returns a pandas dataframe based on the query object""" - query_context = self._query_context - # Here, we assume that all the queries will use the same datasource, which is - # a valid assumption for current setting. In the long term, we may - # support multiple queries from different data sources. - - query = "" - if isinstance(query_context.datasource, Query): - # todo(hugh): add logic to manage all sip68 models here - result = query_context.datasource.exc_query(query_object.to_dict()) - else: - result = query_context.datasource.query(query_object.to_dict()) - query = result.query + ";\n\n" - - df = result.df - # Transform the timestamp we received from database to pandas supported - # datetime format. If no python_date_format is specified, the pattern will - # be considered as the default ISO date format - # If the datetime format is unix, the parse will use the corresponding - # parsing logic - if not df.empty: - df = self.normalize_df(df, query_object) - - if query_object.time_offsets: - time_offsets = self.processing_time_offsets(df, query_object) - df = time_offsets["df"] - queries = time_offsets["queries"] - - query += ";\n\n".join(queries) - query += ";\n\n" - - # Re-raising QueryObjectValidationError - try: - df = query_object.exec_post_processing(df) - except InvalidPostProcessingError as ex: - raise QueryObjectValidationError(ex.message) from ex - - result.df = df - result.query = query - result.from_dttm = query_object.from_dttm - result.to_dttm = query_object.to_dttm - return result - - def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame: - # todo: should support "python_date_format" and "get_column" in each datasource - def _get_timestamp_format( - source: BaseDatasource, column: str | None - ) -> str | None: - column_obj = source.get_column(column) - if ( - column_obj - # only sqla column was supported - and hasattr(column_obj, "python_date_format") - and (formatter := column_obj.python_date_format) - ): - return str(formatter) - - return None - - datasource = self._qc_datasource - labels = tuple( - label - for label in [ - *get_base_axis_labels(query_object.columns), - query_object.granularity, - ] - if datasource - # Query datasource didn't support `get_column` - and hasattr(datasource, "get_column") - and (col := datasource.get_column(label)) - # todo(hugh) standardize column object in Query datasource - and (col.get("is_dttm") if isinstance(col, dict) else col.is_dttm) - ) - dttm_cols = [ - DateColumn( - timestamp_format=_get_timestamp_format(datasource, label), - offset=datasource.offset, - time_shift=query_object.time_shift, - col_label=label, - ) - for label in labels - if label - ] - if DTTM_ALIAS in df: - dttm_cols.append( - DateColumn.get_legacy_time_column( - timestamp_format=_get_timestamp_format( - datasource, query_object.granularity - ), - offset=datasource.offset, - time_shift=query_object.time_shift, - ) - ) - normalize_dttm_col( - df=df, - dttm_cols=tuple(dttm_cols), - ) - - if self.enforce_numerical_metrics: - dataframe_utils.df_metrics_to_num(df, query_object) - - df.replace([np.inf, -np.inf], np.nan, inplace=True) - - return df - - @staticmethod - def get_time_grain(query_object: QueryObject) -> Any | None: - if ( - query_object.columns - and len(query_object.columns) > 0 - and isinstance(query_object.columns[0], dict) - ): - # If the time grain is in the columns it will be the first one - # and it will be of AdhocColumn type - return query_object.columns[0].get("timeGrain") - - return query_object.extras.get("time_grain_sqla") - - # pylint: disable=too-many-arguments - def add_offset_join_column( - self, - df: pd.DataFrame, - name: str, - time_grain: str, - time_offset: str | None = None, - join_column_producer: Any = None, - ) -> None: """ - Adds an offset join column to the provided DataFrame. + Returns a pandas dataframe based on the query object. - The function modifies the DataFrame in-place. - - :param df: pandas DataFrame to which the offset join column will be added. - :param name: The name of the new column to be added. - :param time_grain: The time grain used to calculate the new column. - :param time_offset: The time offset used to calculate the new column. - :param join_column_producer: A function to generate the join column. + This method delegates to the datasource's get_query_result method, + which handles query execution, normalization, time offsets, and + post-processing. """ - if join_column_producer: - df[name] = df.apply(lambda row: join_column_producer(row, 0), axis=1) - else: - df[name] = df.apply( - lambda row: self.generate_join_column(row, 0, time_grain, time_offset), - axis=1, - ) - - def is_valid_date(self, date_string: str) -> bool: - try: - # Attempt to parse the string as a date in the format YYYY-MM-DD - datetime.strptime(date_string, "%Y-%m-%d") - return True - except ValueError: - # If parsing fails, it's not a valid date in the format YYYY-MM-DD - return False - - def is_valid_date_range(self, date_range: str) -> bool: - try: - # Attempt to parse the string as a date range in the format - # YYYY-MM-DD:YYYY-MM-DD - start_date, end_date = date_range.split(":") - datetime.strptime(start_date.strip(), "%Y-%m-%d") - datetime.strptime(end_date.strip(), "%Y-%m-%d") - return True - except ValueError: - # If parsing fails, it's not a valid date range in the format - # YYYY-MM-DD:YYYY-MM-DD - return False - - def get_offset_custom_or_inherit( - self, - offset: str, - outer_from_dttm: datetime, - outer_to_dttm: datetime, - ) -> str: - """ - Get the time offset for custom or inherit. - - :param offset: The offset string. - :param outer_from_dttm: The outer from datetime. - :param outer_to_dttm: The outer to datetime. - :returns: The time offset. - """ - if offset == "inherit": - # return the difference in days between the from and the to dttm formatted as a string with the " days ago" suffix # noqa: E501 - return f"{(outer_to_dttm - outer_from_dttm).days} days ago" - if self.is_valid_date(offset): - # return the offset as the difference in days between the outer from dttm and the offset date (which is a YYYY-MM-DD string) formatted as a string with the " days ago" suffix # noqa: E501 - offset_date = datetime.strptime(offset, "%Y-%m-%d") - return f"{(outer_from_dttm - offset_date).days} days ago" - return "" - - def processing_time_offsets( # pylint: disable=too-many-locals,too-many-statements # noqa: C901 - self, - df: pd.DataFrame, - query_object: QueryObject, - ) -> CachedTimeOffset: - """ - Process time offsets for time comparison feature. - - This method handles both relative time offsets (e.g., "1 week ago") and - absolute date range offsets (e.g., "2015-01-03 : 2015-01-04"). - """ - query_context = self._query_context - # ensure query_object is immutable - query_object_clone = copy.copy(query_object) - queries: list[str] = [] - cache_keys: list[str | None] = [] - offset_dfs: dict[str, pd.DataFrame] = {} - - outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object) - if not outer_from_dttm or not outer_to_dttm: - raise QueryObjectValidationError( - _( - "An enclosed time range (both start and end) must be specified " - "when using a Time Comparison." - ) - ) - - time_grain = self.get_time_grain(query_object) - metric_names = get_metric_names(query_object.metrics) - # use columns that are not metrics as join keys - join_keys = [col for col in df.columns if col not in metric_names] - - for offset in query_object.time_offsets: - try: - original_offset = offset - is_date_range_offset = self.is_valid_date_range(offset) - - if is_date_range_offset and feature_flag_manager.is_feature_enabled( - "DATE_RANGE_TIMESHIFTS_ENABLED" - ): - # DATE RANGE OFFSET LOGIC (like "2015-01-03 : 2015-01-04") - try: - # Parse the specified range - offset_from_dttm, offset_to_dttm = ( - get_since_until_from_time_range(time_range=offset) - ) - except ValueError as ex: - raise QueryObjectValidationError(str(ex)) from ex - - # Use the specified range directly - query_object_clone.from_dttm = offset_from_dttm - query_object_clone.to_dttm = offset_to_dttm - - # For date range offsets, we must NOT set inner bounds - # These create additional WHERE clauses that conflict with our - # date range - query_object_clone.inner_from_dttm = None - query_object_clone.inner_to_dttm = None - - elif is_date_range_offset: - # Date range timeshift feature is disabled - raise QueryObjectValidationError( - "Date range timeshifts are not enabled. " - "Please contact your administrator to enable the " - "DATE_RANGE_TIMESHIFTS_ENABLED feature flag." - ) - - else: - # RELATIVE OFFSET LOGIC (like "1 day ago") - if self.is_valid_date(offset) or offset == "inherit": - offset = self.get_offset_custom_or_inherit( - offset, - outer_from_dttm, - outer_to_dttm, - ) - query_object_clone.from_dttm = get_past_or_future( - offset, - outer_from_dttm, - ) - query_object_clone.to_dttm = get_past_or_future( - offset, outer_to_dttm - ) - - query_object_clone.inner_from_dttm = query_object_clone.from_dttm - query_object_clone.inner_to_dttm = query_object_clone.to_dttm - - x_axis_label = get_x_axis_label(query_object.columns) - query_object_clone.granularity = ( - query_object_clone.granularity or x_axis_label - ) - - except ValueError as ex: - raise QueryObjectValidationError(str(ex)) from ex - - query_object_clone.time_offsets = [] - query_object_clone.post_processing = [] - - # Get time offset index - index = (get_base_axis_labels(query_object.columns) or [DTTM_ALIAS])[0] - - if is_date_range_offset and feature_flag_manager.is_feature_enabled( - "DATE_RANGE_TIMESHIFTS_ENABLED" - ): - # Create a completely new filter list to preserve original filters - query_object_clone.filter = copy.deepcopy(query_object_clone.filter) - - # Remove any existing temporal filters that might conflict - query_object_clone.filter = [ - flt - for flt in query_object_clone.filter - if not (flt.get("op") == FilterOperator.TEMPORAL_RANGE) - ] - - # Determine the temporal column with multiple fallback strategies - temporal_col = self._get_temporal_column_for_filter( - query_object_clone, x_axis_label - ) - - # Always add a temporal filter for date range offsets - if temporal_col: - new_temporal_filter: QueryObjectFilterClause = { - "col": temporal_col, - "op": FilterOperator.TEMPORAL_RANGE, - "val": ( - f"{query_object_clone.from_dttm} : " - f"{query_object_clone.to_dttm}" - ), - } - query_object_clone.filter.append(new_temporal_filter) - - else: - # This should rarely happen with proper fallbacks - raise QueryObjectValidationError( - _( - "Unable to identify temporal column for date range time comparison." # noqa: E501 - "Please ensure your dataset has a properly configured time column." # noqa: E501 - ) - ) - - else: - # RELATIVE OFFSET: Original logic for non-date-range offsets - # The comparison is not using a temporal column so we need to modify - # the temporal filter so we run the query with the correct time range - if not dataframe_utils.is_datetime_series(df.get(index)): - query_object_clone.filter = copy.deepcopy(query_object_clone.filter) - - # Find and update temporal filters - for flt in query_object_clone.filter: - if flt.get( - "op" - ) == FilterOperator.TEMPORAL_RANGE and isinstance( - flt.get("val"), str - ): - time_range = cast(str, flt.get("val")) - ( - new_outer_from_dttm, - new_outer_to_dttm, - ) = get_since_until_from_time_range( - time_range=time_range, - time_shift=offset, - ) - flt["val"] = f"{new_outer_from_dttm} : {new_outer_to_dttm}" - else: - # If it IS a datetime series, we still need to clear conflicts - query_object_clone.filter = copy.deepcopy(query_object_clone.filter) - - # For relative offsets with datetime series, ensure the temporal - # filter matches our range - temporal_col = query_object_clone.granularity or x_axis_label - - # Update any existing temporal filters to match our shifted range - for flt in query_object_clone.filter: - if ( - flt.get("op") == FilterOperator.TEMPORAL_RANGE - and flt.get("col") == temporal_col - ): - flt["val"] = ( - f"{query_object_clone.from_dttm} : " - f"{query_object_clone.to_dttm}" - ) - - # Remove non-temporal x-axis filters (but keep temporal ones) - query_object_clone.filter = [ - flt - for flt in query_object_clone.filter - if not ( - flt.get("col") == x_axis_label - and flt.get("op") != FilterOperator.TEMPORAL_RANGE - ) - ] - - # Continue with the rest of the method (caching, execution, etc.) - cached_time_offset_key = ( - offset if offset == original_offset else f"{offset}_{original_offset}" - ) - - cache_key = self.query_cache_key( - query_object_clone, - time_offset=cached_time_offset_key, - time_grain=time_grain, - ) - cache = QueryCacheManager.get( - cache_key, CacheRegion.DATA, query_context.force - ) - - if cache.is_loaded: - offset_dfs[offset] = cache.df - queries.append(cache.query) - cache_keys.append(cache_key) - continue - - query_object_clone_dct = query_object_clone.to_dict() - - # rename metrics: SUM(value) => SUM(value) 1 year ago - metrics_mapping = { - metric: TIME_COMPARISON.join([metric, original_offset]) - for metric in metric_names - } - - # When the original query has limit or offset we wont apply those - # to the subquery so we prevent data inconsistency due to missing records - # in the dataframes when performing the join - if query_object.row_limit or query_object.row_offset: - query_object_clone_dct["row_limit"] = current_app.config["ROW_LIMIT"] - query_object_clone_dct["row_offset"] = 0 - - if isinstance(self._qc_datasource, Query): - result = self._qc_datasource.exc_query(query_object_clone_dct) - else: - result = self._qc_datasource.query(query_object_clone_dct) - - queries.append(result.query) - cache_keys.append(None) - - offset_metrics_df = result.df - if offset_metrics_df.empty: - offset_metrics_df = pd.DataFrame( - { - col: [np.NaN] - for col in join_keys + list(metrics_mapping.values()) - } - ) - else: - # 1. normalize df, set dttm column - offset_metrics_df = self.normalize_df( - offset_metrics_df, query_object_clone - ) - - # 2. rename extra query columns - offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping) - - # cache df and query - value = { - "df": offset_metrics_df, - "query": result.query, - } - cache.set( - key=cache_key, - value=value, - timeout=self.get_cache_timeout(), - datasource_uid=query_context.datasource.uid, - region=CacheRegion.DATA, - ) - offset_dfs[offset] = offset_metrics_df - - if offset_dfs: - df = self.join_offset_dfs( - df, - offset_dfs, - time_grain, - join_keys, - ) - - return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys) - - def _get_temporal_column_for_filter( # noqa: C901 - self, query_object: QueryObject, x_axis_label: str | None - ) -> str | None: - """ - Helper method to reliably determine the temporal column for filtering. - - This method tries multiple strategies to find the correct temporal column: - 1. Use explicitly set granularity - 2. Use x_axis_label if it's a temporal column - 3. Find any datetime column in the datasource - - :param query_object: The query object - :param x_axis_label: The x-axis label from the query - :return: The name of the temporal column, or None if not found - """ - # Strategy 1: Use explicitly set granularity - if query_object.granularity: - return query_object.granularity - - # Strategy 2: Use x_axis_label if it exists - if x_axis_label: - return x_axis_label - - # Strategy 3: Find any datetime column in the datasource - if hasattr(self._qc_datasource, "columns"): - for col in self._qc_datasource.columns: - if hasattr(col, "is_dttm") and col.is_dttm: - if hasattr(col, "column_name"): - return col.column_name - elif hasattr(col, "name"): - return col.name - - return None - - def _process_date_range_offset( - self, offset_df: pd.DataFrame, join_keys: list[str] - ) -> tuple[pd.DataFrame, list[str]]: - """Process date range offset data and return modified DataFrame and keys.""" - temporal_cols = ["ds", "__timestamp", "dttm"] - non_temporal_join_keys = [key for key in join_keys if key not in temporal_cols] - - if non_temporal_join_keys: - return offset_df, non_temporal_join_keys - - metric_columns = [col for col in offset_df.columns if col not in temporal_cols] - - if metric_columns: - aggregated_values = {} - for col in metric_columns: - if pd.api.types.is_numeric_dtype(offset_df[col]): - aggregated_values[col] = offset_df[col].sum() - else: - aggregated_values[col] = ( - offset_df[col].iloc[0] if not offset_df.empty else None - ) - - offset_df = pd.DataFrame([aggregated_values]) - - return offset_df, [] - - def _apply_cleanup_logic( - self, - df: pd.DataFrame, - offset: str, - time_grain: str | None, - join_keys: list[str], - is_date_range_offset: bool, - ) -> pd.DataFrame: - """Apply appropriate cleanup logic based on offset type.""" - if time_grain and not is_date_range_offset: - if join_keys: - col = df.pop(join_keys[0]) - df.insert(0, col.name, col) - - df.drop( - list(df.filter(regex=f"{OFFSET_JOIN_COLUMN_SUFFIX}|{R_SUFFIX}")), - axis=1, - inplace=True, - ) - elif is_date_range_offset: - df.drop( - list(df.filter(regex=f"{R_SUFFIX}")), - axis=1, - inplace=True, - ) - else: - df.drop( - list(df.filter(regex=f"{R_SUFFIX}")), - axis=1, - inplace=True, - ) - - return df - - def _determine_join_keys( - self, - df: pd.DataFrame, - offset_df: pd.DataFrame, - offset: str, - time_grain: str | None, - join_keys: list[str], - is_date_range_offset: bool, - join_column_producer: Any, - ) -> tuple[pd.DataFrame, list[str]]: - """Determine appropriate join keys and modify DataFrames if needed.""" - if time_grain and not is_date_range_offset: - column_name = OFFSET_JOIN_COLUMN_SUFFIX + offset - - # Add offset join columns for relative time offsets - self.add_offset_join_column( - df, column_name, time_grain, offset, join_column_producer - ) - self.add_offset_join_column( - offset_df, column_name, time_grain, None, join_column_producer - ) - return offset_df, [column_name, *join_keys[1:]] - - elif is_date_range_offset: - return self._process_date_range_offset(offset_df, join_keys) - - else: - return offset_df, join_keys - - def _perform_join( - self, df: pd.DataFrame, offset_df: pd.DataFrame, actual_join_keys: list[str] - ) -> pd.DataFrame: - """Perform the appropriate join operation.""" - if actual_join_keys: - return dataframe_utils.left_join_df( - left_df=df, - right_df=offset_df, - join_keys=actual_join_keys, - rsuffix=R_SUFFIX, - ) - else: - temp_key = "__temp_join_key__" - df[temp_key] = 1 - offset_df[temp_key] = 1 - - result_df = dataframe_utils.left_join_df( - left_df=df, - right_df=offset_df, - join_keys=[temp_key], - rsuffix=R_SUFFIX, - ) - - # Remove temporary join keys - result_df.drop(columns=[temp_key], inplace=True, errors="ignore") - result_df.drop( - columns=[f"{temp_key}{R_SUFFIX}"], inplace=True, errors="ignore" - ) - return result_df - - def join_offset_dfs( - self, - df: pd.DataFrame, - offset_dfs: dict[str, pd.DataFrame], - time_grain: str | None, - join_keys: list[str], - ) -> pd.DataFrame: - """ - Join offset DataFrames with the main DataFrame. - - :param df: The main DataFrame. - :param offset_dfs: A list of offset DataFrames. - :param time_grain: The time grain used to calculate the temporal join key. - :param join_keys: The keys to join on. - """ - join_column_producer = current_app.config[ - "TIME_GRAIN_JOIN_COLUMN_PRODUCERS" - ].get(time_grain) - - if join_column_producer and not time_grain: - raise QueryObjectValidationError( - _("Time Grain must be specified when using Time Shift.") - ) - - for offset, offset_df in offset_dfs.items(): - is_date_range_offset = self.is_valid_date_range( - offset - ) and feature_flag_manager.is_feature_enabled( - "DATE_RANGE_TIMESHIFTS_ENABLED" - ) - - offset_df, actual_join_keys = self._determine_join_keys( - df, - offset_df, - offset, - time_grain, - join_keys, - is_date_range_offset, - join_column_producer, - ) - - df = self._perform_join(df, offset_df, actual_join_keys) - df = self._apply_cleanup_logic( - df, offset, time_grain, join_keys, is_date_range_offset - ) - - return df - - @staticmethod - def generate_join_column( - row: pd.Series, - column_index: int, - time_grain: str, - time_offset: str | None = None, - ) -> str: - value = row[column_index] - - if hasattr(value, "strftime"): - if time_offset and not QueryContextProcessor.is_valid_date_range_static( - time_offset - ): - value = value + DateOffset(**normalize_time_delta(time_offset)) - - if time_grain in ( - TimeGrain.WEEK_STARTING_SUNDAY, - TimeGrain.WEEK_ENDING_SATURDAY, - ): - return value.strftime("%Y-W%U") - - if time_grain in ( - TimeGrain.WEEK, - TimeGrain.WEEK_STARTING_MONDAY, - TimeGrain.WEEK_ENDING_SUNDAY, - ): - return value.strftime("%Y-W%W") - - if time_grain == TimeGrain.MONTH: - return value.strftime("%Y-%m") - - if time_grain == TimeGrain.QUARTER: - return value.strftime("%Y-Q") + str(value.quarter) - - if time_grain == TimeGrain.YEAR: - return value.strftime("%Y") - - return str(value) - - @staticmethod - def is_valid_date_range_static(date_range: str) -> bool: - """Static version of is_valid_date_range for use in static methods""" - try: - # Attempt to parse the string as a date range in the format - # YYYY-MM-DD:YYYY-MM-DD - start_date, end_date = date_range.split(":") - datetime.strptime(start_date.strip(), "%Y-%m-%d") - datetime.strptime(end_date.strip(), "%Y-%m-%d") - return True - except ValueError: - # If parsing fails, it's not a valid date range in the format - # YYYY-MM-DD:YYYY-MM-DD - return False + return self._qc_datasource.get_query_result(query_object) def get_data( self, df: pd.DataFrame, coltypes: list[GenericDataType] diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 73a60de1ede..39bc716b6d5 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -18,12 +18,11 @@ from __future__ import annotations import builtins -import dataclasses import logging from collections import defaultdict from collections.abc import Hashable from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import timedelta from typing import Any, Callable, cast, Optional, Union import pandas as pd @@ -82,8 +81,6 @@ from superset.exceptions import ( ColumnNotFoundException, DatasetInvalidPermissionEvaluationException, QueryObjectValidationError, - SupersetErrorException, - SupersetErrorsException, SupersetGenericDBErrorException, SupersetSecurityException, SupersetSyntaxErrorException, @@ -1628,89 +1625,28 @@ class SqlaTable( return or_(*groups) def query(self, query_obj: QueryObjectDict) -> QueryResult: - qry_start_dttm = datetime.now() - query_str_ext = self.get_query_str_extended(query_obj) - sql = query_str_ext.sql - status = QueryStatus.SUCCESS - errors = None - error_message = None + """ + Executes the query for SqlaTable with additional column ordering logic. - def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None: - """ - Some engines change the case or generate bespoke column names, either by - default or due to lack of support for aliasing. This function ensures that - the column names in the DataFrame correspond to what is expected by - the viz components. + This overrides ExploreMixin.query() to add SqlaTable-specific behavior + for handling column_order from extras. + """ + # Get the base result from ExploreMixin + # (explicitly, not super() which would hit BaseDatasource first) + result = ExploreMixin.query(self, query_obj) - Sometimes a query may also contain only order by columns that are not used - as metrics or groupby columns, but need to present in the SQL `select`, - filtering by `labels_expected` make sure we only return columns users want. - - :param df: Original DataFrame returned by the engine - :return: Mutated DataFrame - """ - labels_expected = query_str_ext.labels_expected - if df is not None and not df.empty: - if len(df.columns) < len(labels_expected): - raise QueryObjectValidationError( - _("Db engine did not return all queried columns") - ) - if len(df.columns) > len(labels_expected): - df = df.iloc[:, 0 : len(labels_expected)] - df.columns = labels_expected - - extras = query_obj.get("extras", {}) - column_order = extras.get("column_order") - if column_order and isinstance(column_order, list): - existing_cols = [col for col in column_order if col in df.columns] - remaining_cols = [ - col for col in df.columns if col not in existing_cols - ] - final_order = existing_cols + remaining_cols - df = df[final_order] - return df - - try: - df = self.database.get_df( - sql, - self.catalog, - self.schema or None, - mutator=assign_column_label, - ) - except (SupersetErrorException, SupersetErrorsException): - # SupersetError(s) exception should not be captured; instead, they should - # bubble up to the Flask error handler so they are returned as proper SIP-40 - # errors. This is particularly important for database OAuth2, see SIP-85. - raise - except Exception as ex: # pylint: disable=broad-except - # TODO (betodealmeida): review exception handling while querying the external # noqa: E501 - # database. Ideally we'd expect and handle external database error, but - # everything else / the default should be to let things bubble up. - df = pd.DataFrame() - status = QueryStatus.FAILED - logger.warning( - "Query %s on schema %s failed", sql, self.schema, exc_info=True - ) - db_engine_spec = self.db_engine_spec - errors = [ - dataclasses.asdict(error) - for error in db_engine_spec.extract_errors( - ex, database_name=self.database.unique_name - ) + # Apply SqlaTable-specific column ordering + extras = query_obj.get("extras", {}) + column_order = extras.get("column_order") + if column_order and isinstance(column_order, list) and not result.df.empty: + existing_cols = [col for col in column_order if col in result.df.columns] + remaining_cols = [ + col for col in result.df.columns if col not in existing_cols ] - error_message = utils.error_msg_from_exception(ex) + final_order = existing_cols + remaining_cols + result.df = result.df[final_order] - return QueryResult( - applied_template_filters=query_str_ext.applied_template_filters, - applied_filter_columns=query_str_ext.applied_filter_columns, - rejected_filter_columns=query_str_ext.rejected_filter_columns, - status=status, - df=df, - duration=datetime.now() - qry_start_dttm, - query=sql, - errors=errors, - error_message=error_message, - ) + return result def get_sqla_table_object(self) -> Table: return self.database.get_table( diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 4e5d1004721..ecf1ff869cd 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -20,6 +20,7 @@ from __future__ import annotations import builtins +import copy import dataclasses import logging import re @@ -52,6 +53,7 @@ from flask_appbuilder.security.sqla.models import User from flask_babel import get_locale, lazy_gettext as _ from jinja2.exceptions import TemplateError from markupsafe import escape, Markup +from pandas import DateOffset from sqlalchemy import and_, Column, or_, UniqueConstraint from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.declarative import declared_attr @@ -64,15 +66,22 @@ from sqlalchemy_utils import UUIDType from superset import db, is_feature_enabled from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.common.db_query_status import QueryStatus -from superset.common.utils.time_range_utils import get_since_until_from_time_range -from superset.constants import EMPTY_STRING, NULL_STRING +from superset.common.utils import dataframe_utils +from superset.common.utils.time_range_utils import ( + get_since_until_from_query_object, + get_since_until_from_time_range, +) +from superset.constants import CacheRegion, EMPTY_STRING, NULL_STRING, TimeGrain from superset.db_engine_specs.base import TimestampExpression from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( AdvancedDataTypeResponseError, ColumnNotFoundException, + InvalidPostProcessingError, QueryClauseValidationException, QueryObjectValidationError, + SupersetErrorException, + SupersetErrorsException, SupersetSecurityException, SupersetSyntaxErrorException, ) @@ -90,15 +99,25 @@ from superset.superset_typing import ( ) from superset.utils import core as utils, json from superset.utils.core import ( + DateColumn, + DTTM_ALIAS, + FilterOperator, GenericDataType, + get_base_axis_labels, get_column_name, + get_metric_names, get_non_base_axis_columns, get_user_id, + get_x_axis_label, is_adhoc_column, MediumText, + normalize_dttm_col, + QueryObjectFilterClause, remove_duplicates, SqlExpressionType, + TIME_COMPARISON, ) +from superset.utils.date_parser import get_past_or_future, normalize_time_delta from superset.utils.dates import datetime_to_epoch from superset.utils.rls import apply_rls @@ -111,6 +130,7 @@ class ValidationResultDict(TypedDict): if TYPE_CHECKING: + from superset.common.query_object import QueryObject from superset.connectors.sqla.models import SqlMetric, TableColumn from superset.db_engine_specs import BaseEngineSpec from superset.models.core import Database @@ -120,6 +140,21 @@ logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" SERIES_LIMIT_SUBQ_ALIAS = "series_limit" +# Offset join column suffix used for joining offset results +OFFSET_JOIN_COLUMN_SUFFIX = "__offset_join_column_" + +# Right suffix used for joining offset results +R_SUFFIX = "__right_suffix" + + +class CachedTimeOffset(TypedDict): + """Result type for time offset processing""" + + df: pd.DataFrame + queries: list[str] + cache_keys: list[str | None] + + # Keys used to filter QueryObjectDict for get_sqla_query parameters SQLA_QUERY_KEYS = { "apply_fetch_values_predicate", @@ -781,9 +816,6 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def db_extra(self) -> Optional[dict[str, Any]]: raise NotImplementedError() - def query(self, query_obj: QueryObjectDict) -> QueryResult: - raise NotImplementedError() - @property def database_id(self) -> int: raise NotImplementedError() @@ -1107,9 +1139,15 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if is_alias_used_in_orderby(col): col.name = f"{col.name}__" - def exc_query(self, qry: Any) -> QueryResult: + def query(self, query_obj: QueryObjectDict) -> QueryResult: + """ + Executes the query and returns a dataframe. + + This method is the unified entry point for query execution across all + datasource types (Query, SqlaTable, etc.). + """ qry_start_dttm = datetime.now() - query_str_ext = self.get_query_str_extended(qry) + query_str_ext = self.get_query_str_extended(query_obj) sql = query_str_ext.sql status = QueryStatus.SUCCESS errors = None @@ -1146,6 +1184,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods mutator=assign_column_label, ) except Exception as ex: # pylint: disable=broad-except + # Re-raise SupersetErrorException (includes OAuth2RedirectError) + # to bubble up to API layer + if isinstance(ex, (SupersetErrorException, SupersetErrorsException)): + raise df = pd.DataFrame() status = QueryStatus.FAILED logger.warning( @@ -1172,6 +1214,755 @@ class ExploreMixin: # pylint: disable=too-many-public-methods error_message=error_message, ) + def exc_query(self, qry: Any) -> QueryResult: + """ + Deprecated: Use query() instead. + This method is kept for backward compatibility. + """ + return self.query(qry) + + def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame: + """ + Normalize the dataframe by converting datetime columns and ensuring + numerical metrics. + + :param df: The dataframe to normalize + :param query_object: The query object with metadata about columns + :return: Normalized dataframe + """ + + def _get_timestamp_format(column: str | None) -> str | None: + if not hasattr(self, "get_column"): + return None + column_obj = self.get_column(column) + if ( + column_obj + and hasattr(column_obj, "python_date_format") + and (formatter := column_obj.python_date_format) + ): + return str(formatter) + return None + + # Collect datetime columns + labels = tuple( + label + for label in [ + *get_base_axis_labels(query_object.columns), + query_object.granularity, + ] + if hasattr(self, "get_column") + and (col := self.get_column(label)) + and (col.get("is_dttm") if isinstance(col, dict) else col.is_dttm) + ) + + dttm_cols = [ + DateColumn( + timestamp_format=_get_timestamp_format(label), + offset=self.offset, + time_shift=query_object.time_shift, + col_label=label, + ) + for label in labels + if label + ] + + if DTTM_ALIAS in df: + dttm_cols.append( + DateColumn.get_legacy_time_column( + timestamp_format=_get_timestamp_format(query_object.granularity), + offset=self.offset, + time_shift=query_object.time_shift, + ) + ) + + normalize_dttm_col( + df=df, + dttm_cols=tuple(dttm_cols), + ) + + # Convert metrics to numerical values if enforced + if getattr(self, "enforce_numerical_metrics", True): + dataframe_utils.df_metrics_to_num(df, query_object) + + df.replace([np.inf, -np.inf], np.nan, inplace=True) + + return df + + def get_query_result(self, query_object: QueryObject) -> QueryResult: + """ + Execute query and return results with full processing pipeline. + + This method handles: + 1. Query execution via self.query() + 2. DataFrame normalization + 3. Time offset processing (if applicable) + 4. Post-processing operations + + :param query_object: The query configuration + :return: QueryResult with processed dataframe + """ + # Execute the base query + result = self.query(query_object.to_dict()) + query = result.query + ";\n\n" if result.query else "" + + # Process the dataframe if not empty + df = result.df + if not df.empty: + # Normalize datetime columns and metrics + df = self.normalize_df(df, query_object) + + # Process time offsets if requested + if query_object.time_offsets: + # Process time offsets using the datasource's own method + # Note: caching is disabled here as we don't have query context + time_offsets = self.processing_time_offsets( + df, query_object, cache_key_fn=None, cache_timeout_fn=None + ) + df = time_offsets["df"] + queries = time_offsets["queries"] + query += ";\n\n".join(queries) + query += ";\n\n" + + # Execute post-processing operations + try: + df = query_object.exec_post_processing(df) + except InvalidPostProcessingError as ex: + raise QueryObjectValidationError(ex.message) from ex + + # Update result with processed data + result.df = df + result.query = query + result.from_dttm = query_object.from_dttm + result.to_dttm = query_object.to_dttm + + return result + + def processing_time_offsets( # pylint: disable=too-many-locals,too-many-statements # noqa: C901 + self, + df: pd.DataFrame, + query_object: QueryObject, + cache_key_fn: Callable[[QueryObject, str, Any], str | None] | None = None, + cache_timeout_fn: Callable[[], int] | None = None, + force_cache: bool = False, + ) -> CachedTimeOffset: + """ + Process time offsets for time comparison feature. + + This method handles both relative time offsets (e.g., "1 week ago") and + absolute date range offsets (e.g., "2015-01-03 : 2015-01-04"). + + :param df: The main dataframe + :param query_object: The query object with time offset configuration + :param cache_key_fn: Optional function to generate cache keys + :param cache_timeout_fn: Optional function to get cache timeout + :param force_cache: Whether to force cache refresh + :return: CachedTimeOffset with processed dataframe and queries + """ + # Import here to avoid circular dependency + # pylint: disable=import-outside-toplevel + from superset.common.utils.query_cache_manager import QueryCacheManager + + # ensure query_object is immutable + query_object_clone = copy.copy(query_object) + queries: list[str] = [] + cache_keys: list[str | None] = [] + offset_dfs: dict[str, pd.DataFrame] = {} + + outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object) + if not outer_from_dttm or not outer_to_dttm: + raise QueryObjectValidationError( + _( + "An enclosed time range (both start and end) must be specified " + "when using a Time Comparison." + ) + ) + + time_grain = self.get_time_grain(query_object) + metric_names = get_metric_names(query_object.metrics) + # use columns that are not metrics as join keys + join_keys = [col for col in df.columns if col not in metric_names] + + for offset in query_object.time_offsets: + try: + original_offset = offset + is_date_range_offset = self.is_valid_date_range(offset) + + if is_date_range_offset and feature_flag_manager.is_feature_enabled( + "DATE_RANGE_TIMESHIFTS_ENABLED" + ): + # DATE RANGE OFFSET LOGIC (like "2015-01-03 : 2015-01-04") + try: + # Parse the specified range + offset_from_dttm, offset_to_dttm = ( + get_since_until_from_time_range(time_range=offset) + ) + except ValueError as ex: + raise QueryObjectValidationError(str(ex)) from ex + + # Use the specified range directly + query_object_clone.from_dttm = offset_from_dttm + query_object_clone.to_dttm = offset_to_dttm + + # For date range offsets, we must NOT set inner bounds + # These create additional WHERE clauses that conflict with our + # date range + query_object_clone.inner_from_dttm = None + query_object_clone.inner_to_dttm = None + + elif is_date_range_offset: + # Date range timeshift feature is disabled + raise QueryObjectValidationError( + "Date range timeshifts are not enabled. " + "Please contact your administrator to enable the " + "DATE_RANGE_TIMESHIFTS_ENABLED feature flag." + ) + + else: + # RELATIVE OFFSET LOGIC (like "1 day ago") + if self.is_valid_date(offset) or offset == "inherit": + offset = self.get_offset_custom_or_inherit( + offset, + outer_from_dttm, + outer_to_dttm, + ) + query_object_clone.from_dttm = get_past_or_future( + offset, + outer_from_dttm, + ) + query_object_clone.to_dttm = get_past_or_future( + offset, outer_to_dttm + ) + + query_object_clone.inner_from_dttm = query_object_clone.from_dttm + query_object_clone.inner_to_dttm = query_object_clone.to_dttm + + x_axis_label = get_x_axis_label(query_object.columns) + query_object_clone.granularity = ( + query_object_clone.granularity or x_axis_label + ) + + except ValueError as ex: + raise QueryObjectValidationError(str(ex)) from ex + + query_object_clone.time_offsets = [] + query_object_clone.post_processing = [] + + # Get time offset index + index = (get_base_axis_labels(query_object.columns) or [DTTM_ALIAS])[0] + + if is_date_range_offset and feature_flag_manager.is_feature_enabled( + "DATE_RANGE_TIMESHIFTS_ENABLED" + ): + # Create a completely new filter list to preserve original filters + query_object_clone.filter = copy.deepcopy(query_object_clone.filter) + + # Remove any existing temporal filters that might conflict + query_object_clone.filter = [ + flt + for flt in query_object_clone.filter + if not (flt.get("op") == FilterOperator.TEMPORAL_RANGE) + ] + + # Determine the temporal column with multiple fallback strategies + temporal_col = self._get_temporal_column_for_filter( + query_object_clone, x_axis_label + ) + + # Always add a temporal filter for date range offsets + if temporal_col: + new_temporal_filter: QueryObjectFilterClause = { + "col": temporal_col, + "op": FilterOperator.TEMPORAL_RANGE, + "val": ( + f"{query_object_clone.from_dttm} : " + f"{query_object_clone.to_dttm}" + ), + } + query_object_clone.filter.append(new_temporal_filter) + + else: + # This should rarely happen with proper fallbacks + raise QueryObjectValidationError( + _( + "Unable to identify temporal column for date range time comparison." # noqa: E501 + "Please ensure your dataset has a properly configured time column." # noqa: E501 + ) + ) + + else: + # RELATIVE OFFSET: Original logic for non-date-range offsets + # The comparison is not using a temporal column so we need to modify + # the temporal filter so we run the query with the correct time range + if not dataframe_utils.is_datetime_series(df.get(index)): + query_object_clone.filter = copy.deepcopy(query_object_clone.filter) + + # Find and update temporal filters + for flt in query_object_clone.filter: + if flt.get( + "op" + ) == FilterOperator.TEMPORAL_RANGE and isinstance( + flt.get("val"), str + ): + time_range = cast(str, flt.get("val")) + ( + new_outer_from_dttm, + new_outer_to_dttm, + ) = get_since_until_from_time_range( + time_range=time_range, + time_shift=offset, + ) + flt["val"] = f"{new_outer_from_dttm} : {new_outer_to_dttm}" + else: + # If it IS a datetime series, we still need to clear conflicts + query_object_clone.filter = copy.deepcopy(query_object_clone.filter) + + # For relative offsets with datetime series, ensure the temporal + # filter matches our range + temporal_col = query_object_clone.granularity or x_axis_label + + # Update any existing temporal filters to match our shifted range + for flt in query_object_clone.filter: + if ( + flt.get("op") == FilterOperator.TEMPORAL_RANGE + and flt.get("col") == temporal_col + ): + flt["val"] = ( + f"{query_object_clone.from_dttm} : " + f"{query_object_clone.to_dttm}" + ) + + # Remove non-temporal x-axis filters (but keep temporal ones) + query_object_clone.filter = [ + flt + for flt in query_object_clone.filter + if not ( + flt.get("col") == x_axis_label + and flt.get("op") != FilterOperator.TEMPORAL_RANGE + ) + ] + + # Continue with the rest of the method (caching, execution, etc.) + cached_time_offset_key = ( + offset if offset == original_offset else f"{offset}_{original_offset}" + ) + + cache_key = None + if cache_key_fn: + cache_key = cache_key_fn( + query_object_clone, + cached_time_offset_key, + time_grain, + ) + + cache = QueryCacheManager.get(cache_key, CacheRegion.DATA, force_cache) + + if cache.is_loaded: + offset_dfs[offset] = cache.df + queries.append(cache.query) + cache_keys.append(cache_key) + continue + + query_object_clone_dct = query_object_clone.to_dict() + + # rename metrics: SUM(value) => SUM(value) 1 year ago + metrics_mapping = { + metric: TIME_COMPARISON.join([metric, original_offset]) + for metric in metric_names + } + + # When the original query has limit or offset we wont apply those + # to the subquery so we prevent data inconsistency due to missing records + # in the dataframes when performing the join + if query_object.row_limit or query_object.row_offset: + query_object_clone_dct["row_limit"] = app.config["ROW_LIMIT"] + query_object_clone_dct["row_offset"] = 0 + + # Call the unified query method on the datasource + result = self.query(query_object_clone_dct) + + queries.append(result.query) + cache_keys.append(None) + + offset_metrics_df = result.df + if offset_metrics_df.empty: + offset_metrics_df = pd.DataFrame( + { + col: [np.NaN] + for col in join_keys + list(metrics_mapping.values()) + } + ) + else: + # 1. normalize df, set dttm column + offset_metrics_df = self.normalize_df( + offset_metrics_df, query_object_clone + ) + + # 2. rename extra query columns + offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping) + + # cache df and query if caching is enabled + if cache_key and cache_timeout_fn: + value = { + "df": offset_metrics_df, + "query": result.query, + } + cache.set( + key=cache_key, + value=value, + timeout=cache_timeout_fn(), + datasource_uid=self.uid, + region=CacheRegion.DATA, + ) + offset_dfs[offset] = offset_metrics_df + + if offset_dfs: + df = self.join_offset_dfs( + df, + offset_dfs, + time_grain, + join_keys, + ) + + return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys) + + @staticmethod + def get_time_grain(query_object: QueryObject) -> Any | None: + if ( + query_object.columns + and len(query_object.columns) > 0 + and isinstance(query_object.columns[0], dict) + ): + # If the time grain is in the columns it will be the first one + # and it will be of AdhocColumn type + return query_object.columns[0].get("timeGrain") + + return query_object.extras.get("time_grain_sqla") + + def is_valid_date(self, date_string: str) -> bool: + try: + # Attempt to parse the string as a date in the format YYYY-MM-DD + datetime.strptime(date_string, "%Y-%m-%d") + return True + except ValueError: + # If parsing fails, it's not a valid date in the format YYYY-MM-DD + return False + + def is_valid_date_range(self, date_range: str) -> bool: + try: + # Attempt to parse the string as a date range in the format + # YYYY-MM-DD:YYYY-MM-DD + start_date, end_date = date_range.split(":") + datetime.strptime(start_date.strip(), "%Y-%m-%d") + datetime.strptime(end_date.strip(), "%Y-%m-%d") + return True + except ValueError: + # If parsing fails, it's not a valid date range in the format + # YYYY-MM-DD:YYYY-MM-DD + return False + + def get_offset_custom_or_inherit( + self, + offset: str, + outer_from_dttm: datetime, + outer_to_dttm: datetime, + ) -> str: + """ + Get the time offset for custom or inherit. + + :param offset: The offset string. + :param outer_from_dttm: The outer from datetime. + :param outer_to_dttm: The outer to datetime. + :returns: The time offset. + """ + if offset == "inherit": + # return the difference in days between the from and the to dttm formatted as a string with the " days ago" suffix # noqa: E501 + return f"{(outer_to_dttm - outer_from_dttm).days} days ago" + if self.is_valid_date(offset): + # return the offset as the difference in days between the outer from dttm and the offset date (which is a YYYY-MM-DD string) formatted as a string with the " days ago" suffix # noqa: E501 + offset_date = datetime.strptime(offset, "%Y-%m-%d") + return f"{(outer_from_dttm - offset_date).days} days ago" + return "" + + def _get_temporal_column_for_filter( # noqa: C901 + self, query_object: QueryObject, x_axis_label: str | None + ) -> str | None: + """ + Helper method to reliably determine the temporal column for filtering. + + This method tries multiple strategies to find the correct temporal column: + 1. Use explicitly set granularity + 2. Use x_axis_label if it's a temporal column + 3. Find any datetime column in the datasource + + :param query_object: The query object + :param x_axis_label: The x-axis label from the query + :return: The name of the temporal column, or None if not found + """ + # Strategy 1: Use explicitly set granularity + if query_object.granularity: + return query_object.granularity + + # Strategy 2: Use x_axis_label if it exists + if x_axis_label: + return x_axis_label + + # Strategy 3: Find any datetime column in the datasource + if hasattr(self, "columns"): + for col in self.columns: + if hasattr(col, "is_dttm") and col.is_dttm: + if hasattr(col, "column_name"): + return col.column_name + elif hasattr(col, "name"): + return col.name + + return None + + def _process_date_range_offset( + self, offset_df: pd.DataFrame, join_keys: list[str] + ) -> tuple[pd.DataFrame, list[str]]: + """Process date range offset data and return modified DataFrame and keys.""" + temporal_cols = ["ds", "__timestamp", "dttm"] + non_temporal_join_keys = [key for key in join_keys if key not in temporal_cols] + + if non_temporal_join_keys: + return offset_df, non_temporal_join_keys + + metric_columns = [col for col in offset_df.columns if col not in temporal_cols] + + if metric_columns: + aggregated_values = {} + for col in metric_columns: + if pd.api.types.is_numeric_dtype(offset_df[col]): + aggregated_values[col] = offset_df[col].sum() + else: + aggregated_values[col] = ( + offset_df[col].iloc[0] if not offset_df.empty else None + ) + + offset_df = pd.DataFrame([aggregated_values]) + + return offset_df, [] + + def _apply_cleanup_logic( + self, + df: pd.DataFrame, + offset: str, + time_grain: str | None, + join_keys: list[str], + is_date_range_offset: bool, + ) -> pd.DataFrame: + """Apply appropriate cleanup logic based on offset type.""" + if time_grain and not is_date_range_offset: + if join_keys: + col = df.pop(join_keys[0]) + df.insert(0, col.name, col) + + df.drop( + list(df.filter(regex=f"{OFFSET_JOIN_COLUMN_SUFFIX}|{R_SUFFIX}")), + axis=1, + inplace=True, + ) + elif is_date_range_offset: + df.drop( + list(df.filter(regex=f"{R_SUFFIX}")), + axis=1, + inplace=True, + ) + else: + df.drop( + list(df.filter(regex=f"{R_SUFFIX}")), + axis=1, + inplace=True, + ) + + return df + + def _determine_join_keys( + self, + df: pd.DataFrame, + offset_df: pd.DataFrame, + offset: str, + time_grain: str | None, + join_keys: list[str], + is_date_range_offset: bool, + join_column_producer: Any, + ) -> tuple[pd.DataFrame, list[str]]: + """Determine appropriate join keys and modify DataFrames if needed.""" + if time_grain and not is_date_range_offset: + column_name = OFFSET_JOIN_COLUMN_SUFFIX + offset + + # Add offset join columns for relative time offsets + self.add_offset_join_column( + df, column_name, time_grain, offset, join_column_producer + ) + self.add_offset_join_column( + offset_df, column_name, time_grain, None, join_column_producer + ) + return offset_df, [column_name, *join_keys[1:]] + + elif is_date_range_offset: + return self._process_date_range_offset(offset_df, join_keys) + + else: + return offset_df, join_keys + + def _perform_join( + self, df: pd.DataFrame, offset_df: pd.DataFrame, actual_join_keys: list[str] + ) -> pd.DataFrame: + """Perform the appropriate join operation.""" + if actual_join_keys: + return dataframe_utils.left_join_df( + left_df=df, + right_df=offset_df, + join_keys=actual_join_keys, + rsuffix=R_SUFFIX, + ) + else: + temp_key = "__temp_join_key__" + df[temp_key] = 1 + offset_df[temp_key] = 1 + + result_df = dataframe_utils.left_join_df( + left_df=df, + right_df=offset_df, + join_keys=[temp_key], + rsuffix=R_SUFFIX, + ) + + # Remove temporary join keys + result_df.drop(columns=[temp_key], inplace=True, errors="ignore") + result_df.drop( + columns=[f"{temp_key}{R_SUFFIX}"], inplace=True, errors="ignore" + ) + return result_df + + def join_offset_dfs( + self, + df: pd.DataFrame, + offset_dfs: dict[str, pd.DataFrame], + time_grain: str | None, + join_keys: list[str], + ) -> pd.DataFrame: + """ + Join offset DataFrames with the main DataFrame. + + :param df: The main DataFrame. + :param offset_dfs: A list of offset DataFrames. + :param time_grain: The time grain used to calculate the temporal join key. + :param join_keys: The keys to join on. + """ + join_column_producer = app.config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get( + time_grain + ) + + if join_column_producer and not time_grain: + raise QueryObjectValidationError( + _("Time Grain must be specified when using Time Shift.") + ) + + for offset, offset_df in offset_dfs.items(): + is_date_range_offset = self.is_valid_date_range( + offset + ) and feature_flag_manager.is_feature_enabled( + "DATE_RANGE_TIMESHIFTS_ENABLED" + ) + + offset_df, actual_join_keys = self._determine_join_keys( + df, + offset_df, + offset, + time_grain, + join_keys, + is_date_range_offset, + join_column_producer, + ) + + df = self._perform_join(df, offset_df, actual_join_keys) + df = self._apply_cleanup_logic( + df, offset, time_grain, join_keys, is_date_range_offset + ) + + return df + + def add_offset_join_column( + self, + df: pd.DataFrame, + name: str, + time_grain: str, + time_offset: str | None = None, + join_column_producer: Any = None, + ) -> None: + """ + Adds an offset join column to the provided DataFrame. + + The function modifies the DataFrame in-place. + + :param df: pandas DataFrame to which the offset join column will be added. + :param name: The name of the new column to be added. + :param time_grain: The time grain used to calculate the new column. + :param time_offset: The time offset used to calculate the new column. + :param join_column_producer: A function to generate the join column. + """ + if join_column_producer: + df[name] = df.apply(lambda row: join_column_producer(row, 0), axis=1) + else: + df[name] = df.apply( + lambda row: self.generate_join_column(row, 0, time_grain, time_offset), + axis=1, + ) + + @staticmethod + def generate_join_column( + row: pd.Series, + column_index: int, + time_grain: str, + time_offset: str | None = None, + ) -> str: + value = row[column_index] + + if hasattr(value, "strftime"): + if time_offset and not ExploreMixin.is_valid_date_range_static(time_offset): + value = value + DateOffset(**normalize_time_delta(time_offset)) + + if time_grain in ( + TimeGrain.WEEK_STARTING_SUNDAY, + TimeGrain.WEEK_ENDING_SATURDAY, + ): + return value.strftime("%Y-W%U") + + if time_grain in ( + TimeGrain.WEEK, + TimeGrain.WEEK_STARTING_MONDAY, + TimeGrain.WEEK_ENDING_SUNDAY, + ): + return value.strftime("%Y-W%W") + + if time_grain == TimeGrain.MONTH: + return value.strftime("%Y-%m") + + if time_grain == TimeGrain.QUARTER: + return value.strftime("%Y-Q") + str(value.quarter) + + if time_grain == TimeGrain.YEAR: + return value.strftime("%Y") + + return str(value) + + @staticmethod + def is_valid_date_range_static(date_range: str) -> bool: + """Static version of is_valid_date_range for use in static methods""" + try: + # Attempt to parse the string as a date range in the format + # YYYY-MM-DD:YYYY-MM-DD + start_date, end_date = date_range.split(":") + datetime.strptime(start_date.strip(), "%Y-%m-%d") + datetime.strptime(end_date.strip(), "%Y-%m-%d") + return True + except ValueError: + # If parsing fails, it's not a valid date range in the format + # YYYY-MM-DD:YYYY-MM-DD + return False + def get_rendered_sql( self, template_processor: Optional[BaseTemplateProcessor] = None, diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index dedc9b97820..17824e78138 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -622,10 +622,24 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["time_offsets"] = ["1 year ago", "1 year later"] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] + + # Create cache functions for testing + def cache_key_fn(qo, time_offset, time_grain): + return query_context._processor.query_cache_key( + qo, time_offset=time_offset, time_grain=time_grain + ) + + def cache_timeout_fn(): + return query_context._processor.get_cache_timeout() + # query without cache - query_context.processing_time_offsets(df.copy(), query_object) + query_context.datasource.processing_time_offsets( + df.copy(), query_object, cache_key_fn, cache_timeout_fn, query_context.force + ) # query with cache - rv = query_context.processing_time_offsets(df.copy(), query_object) + rv = query_context.datasource.processing_time_offsets( + df.copy(), query_object, cache_key_fn, cache_timeout_fn, query_context.force + ) cache_keys = rv["cache_keys"] cache_keys__1_year_ago = cache_keys[0] cache_keys__1_year_later = cache_keys[1] @@ -637,7 +651,9 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["time_offsets"] = ["1 year later", "1 year ago"] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - rv = query_context.processing_time_offsets(df.copy(), query_object) + rv = query_context.datasource.processing_time_offsets( + df.copy(), query_object, cache_key_fn, cache_timeout_fn, query_context.force + ) cache_keys = rv["cache_keys"] assert cache_keys__1_year_ago == cache_keys[1] assert cache_keys__1_year_later == cache_keys[0] @@ -646,9 +662,8 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["time_offsets"] = [] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - rv = query_context.processing_time_offsets( - df.copy(), - query_object, + rv = query_context.datasource.processing_time_offsets( + df.copy(), query_object, cache_key_fn, cache_timeout_fn, query_context.force ) assert rv["df"].shape == df.shape @@ -676,7 +691,18 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["time_offsets"] = ["3 years ago", "3 years later"] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - time_offsets_obj = query_context.processing_time_offsets(df, query_object) + + def cache_key_fn(qo, time_offset, time_grain): + return query_context._processor.query_cache_key( + qo, time_offset=time_offset, time_grain=time_grain + ) + + def cache_timeout_fn(): + return query_context._processor.get_cache_timeout() + + time_offsets_obj = query_context.datasource.processing_time_offsets( + df, query_object, cache_key_fn, cache_timeout_fn, query_context.force + ) query_from_1977_to_1988 = time_offsets_obj["queries"][0] query_from_1983_to_1994 = time_offsets_obj["queries"][1] @@ -707,7 +733,18 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["time_offsets"] = ["3 years ago", "3 years later"] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - time_offsets_obj = query_context.processing_time_offsets(df, query_object) + + def cache_key_fn(qo, time_offset, time_grain): + return query_context._processor.query_cache_key( + qo, time_offset=time_offset, time_grain=time_grain + ) + + def cache_timeout_fn(): + return query_context._processor.get_cache_timeout() + + time_offsets_obj = query_context.datasource.processing_time_offsets( + df, query_object, cache_key_fn, cache_timeout_fn, query_context.force + ) df_with_offsets = time_offsets_obj["df"] df_with_offsets = df_with_offsets.set_index(["__timestamp", "state"]) @@ -795,7 +832,18 @@ class TestQueryContext(SupersetTestCase): payload["queries"][0]["time_offsets"] = ["1 year ago", "1 year later"] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - time_offsets_obj = query_context.processing_time_offsets(df, query_object) + + def cache_key_fn(qo, time_offset, time_grain): + return query_context._processor.query_cache_key( + qo, time_offset=time_offset, time_grain=time_grain + ) + + def cache_timeout_fn(): + return query_context._processor.get_cache_timeout() + + time_offsets_obj = query_context.datasource.processing_time_offsets( + df, query_object, cache_key_fn, cache_timeout_fn, query_context.force + ) sqls = time_offsets_obj["queries"] row_limit_value = current_app.config["ROW_LIMIT"] row_limit_pattern_with_config_value = r"LIMIT " + re.escape( diff --git a/tests/unit_tests/common/test_query_context_processor.py b/tests/unit_tests/common/test_query_context_processor.py index 68c1e4b7ed1..3caf3a7cb74 100644 --- a/tests/unit_tests/common/test_query_context_processor.py +++ b/tests/unit_tests/common/test_query_context_processor.py @@ -36,12 +36,51 @@ def mock_query_context(): @pytest.fixture def processor(mock_query_context): + from superset.models.helpers import ExploreMixin + mock_query_context.datasource.data = MagicMock() mock_query_context.datasource.data.get.return_value = { "col1": "Column 1", "col2": "Column 2", } - return QueryContextProcessor(mock_query_context) + + # Create a processor instance + processor = QueryContextProcessor(mock_query_context) + + # Setup datasource methods from ExploreMixin to be real methods + # by binding them to the mock datasource + processor._qc_datasource.is_valid_date_range = ( + ExploreMixin.is_valid_date_range.__get__(processor._qc_datasource) + ) + processor._qc_datasource.is_valid_date = ExploreMixin.is_valid_date.__get__( + processor._qc_datasource + ) + processor._qc_datasource.get_offset_custom_or_inherit = ( + ExploreMixin.get_offset_custom_or_inherit.__get__(processor._qc_datasource) + ) + processor._qc_datasource._get_temporal_column_for_filter = ( + ExploreMixin._get_temporal_column_for_filter.__get__(processor._qc_datasource) + ) + processor._qc_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__( + processor._qc_datasource + ) + processor._qc_datasource._determine_join_keys = ( + ExploreMixin._determine_join_keys.__get__(processor._qc_datasource) + ) + processor._qc_datasource._process_date_range_offset = ( + ExploreMixin._process_date_range_offset.__get__(processor._qc_datasource) + ) + processor._qc_datasource._perform_join = ExploreMixin._perform_join.__get__( + processor._qc_datasource + ) + processor._qc_datasource._apply_cleanup_logic = ( + ExploreMixin._apply_cleanup_logic.__get__(processor._qc_datasource) + ) + processor._qc_datasource.add_offset_join_column = ( + ExploreMixin.add_offset_join_column.__get__(processor._qc_datasource) + ) + + return processor def test_get_data_table_like(processor, mock_query_context): @@ -245,45 +284,46 @@ def test_get_data_xlsx_apply_column_types_error( def test_is_valid_date_range_format(processor): """Test that date range format validation works correctly.""" # Should return True for valid date range format - assert processor.is_valid_date_range("2023-01-01 : 2023-01-31") is True - assert processor.is_valid_date_range("2020-12-25 : 2020-12-31") is True + assert ( + processor._qc_datasource.is_valid_date_range("2023-01-01 : 2023-01-31") is True + ) + assert ( + processor._qc_datasource.is_valid_date_range("2020-12-25 : 2020-12-31") is True + ) # Should return False for invalid format - assert processor.is_valid_date_range("1 day ago") is False - assert processor.is_valid_date_range("2023-01-01") is False - assert processor.is_valid_date_range("invalid") is False + assert processor._qc_datasource.is_valid_date_range("1 day ago") is False + assert processor._qc_datasource.is_valid_date_range("2023-01-01") is False + assert processor._qc_datasource.is_valid_date_range("invalid") is False def test_is_valid_date_range_static_format(): """Test that static date range format validation works correctly.""" + from superset.models.helpers import ExploreMixin + # Should return True for valid date range format - assert ( - QueryContextProcessor.is_valid_date_range_static("2023-01-01 : 2023-01-31") - is True - ) - assert ( - QueryContextProcessor.is_valid_date_range_static("2020-12-25 : 2020-12-31") - is True - ) + assert ExploreMixin.is_valid_date_range_static("2023-01-01 : 2023-01-31") is True + assert ExploreMixin.is_valid_date_range_static("2020-12-25 : 2020-12-31") is True # Should return False for invalid format - assert QueryContextProcessor.is_valid_date_range_static("1 day ago") is False - assert QueryContextProcessor.is_valid_date_range_static("2023-01-01") is False - assert QueryContextProcessor.is_valid_date_range_static("invalid") is False + assert ExploreMixin.is_valid_date_range_static("1 day ago") is False + assert ExploreMixin.is_valid_date_range_static("2023-01-01") is False + assert ExploreMixin.is_valid_date_range_static("invalid") is False def test_processing_time_offsets_date_range_logic(processor): """Test that date range timeshift logic works correctly with feature flag checks.""" + from superset.models.helpers import ExploreMixin + # Test that the date range validation works - assert processor.is_valid_date_range("2023-01-01 : 2023-01-31") is True - assert processor.is_valid_date_range("1 year ago") is False + assert ( + processor._qc_datasource.is_valid_date_range("2023-01-01 : 2023-01-31") is True + ) + assert processor._qc_datasource.is_valid_date_range("1 year ago") is False # Test that static method also works - assert ( - QueryContextProcessor.is_valid_date_range_static("2023-01-01 : 2023-01-31") - is True - ) - assert QueryContextProcessor.is_valid_date_range_static("1 year ago") is False + assert ExploreMixin.is_valid_date_range_static("2023-01-01 : 2023-01-31") is True + assert ExploreMixin.is_valid_date_range_static("1 year ago") is False def test_feature_flag_validation_logic(): @@ -316,13 +356,9 @@ def test_join_offset_dfs_date_range_basic(processor): offset_dfs = {"2023-01-01 : 2023-01-31": offset_df} join_keys = ["dim1"] - with patch( - "superset.common.query_context_processor.feature_flag_manager" - ) as mock_ff: + with patch("superset.models.helpers.feature_flag_manager") as mock_ff: mock_ff.is_feature_enabled.return_value = True - with patch( - "superset.common.query_context_processor.dataframe_utils.left_join_df" - ) as mock_join: + with patch("superset.common.utils.dataframe_utils.left_join_df") as mock_join: mock_join.return_value = pd.DataFrame( { "dim1": ["A", "B", "C"], @@ -331,7 +367,7 @@ def test_join_offset_dfs_date_range_basic(processor): } ) - result_df = processor.join_offset_dfs( + result_df = processor._qc_datasource.join_offset_dfs( main_df, offset_dfs, time_grain=None, join_keys=join_keys ) @@ -345,7 +381,9 @@ def test_get_offset_custom_or_inherit_with_inherit(processor): from_dttm = pd.Timestamp("2024-01-01") to_dttm = pd.Timestamp("2024-01-10") - result = processor.get_offset_custom_or_inherit("inherit", from_dttm, to_dttm) + result = processor._qc_datasource.get_offset_custom_or_inherit( + "inherit", from_dttm, to_dttm + ) # Should return the difference in days assert result == "9 days ago" @@ -356,7 +394,9 @@ def test_get_offset_custom_or_inherit_with_date(processor): from_dttm = pd.Timestamp("2024-01-10") to_dttm = pd.Timestamp("2024-01-20") - result = processor.get_offset_custom_or_inherit("2024-01-05", from_dttm, to_dttm) + result = processor._qc_datasource.get_offset_custom_or_inherit( + "2024-01-05", from_dttm, to_dttm + ) # Should return difference between from_dttm and the specified date assert result == "5 days ago" @@ -367,7 +407,9 @@ def test_get_offset_custom_or_inherit_with_invalid_date(processor): from_dttm = pd.Timestamp("2024-01-10") to_dttm = pd.Timestamp("2024-01-20") - result = processor.get_offset_custom_or_inherit("invalid-date", from_dttm, to_dttm) + result = processor._qc_datasource.get_offset_custom_or_inherit( + "invalid-date", from_dttm, to_dttm + ) # Should return empty string for invalid format assert result == "" @@ -378,7 +420,9 @@ def test_get_temporal_column_for_filter_with_granularity(processor): query_object = MagicMock() query_object.granularity = "date_column" - result = processor._get_temporal_column_for_filter(query_object, "x_axis_col") + result = processor._qc_datasource._get_temporal_column_for_filter( + query_object, "x_axis_col" + ) assert result == "date_column" @@ -388,7 +432,9 @@ def test_get_temporal_column_for_filter_with_x_axis_fallback(processor): query_object = MagicMock() query_object.granularity = None - result = processor._get_temporal_column_for_filter(query_object, "x_axis_col") + result = processor._qc_datasource._get_temporal_column_for_filter( + query_object, "x_axis_col" + ) assert result == "x_axis_col" @@ -409,7 +455,9 @@ def test_get_temporal_column_for_filter_with_datasource_columns(processor): processor._qc_datasource.columns = [mock_regular_col, mock_datetime_col] - result = processor._get_temporal_column_for_filter(query_object, None) + result = processor._qc_datasource._get_temporal_column_for_filter( + query_object, None + ) assert result == "created_at" @@ -429,7 +477,9 @@ def test_get_temporal_column_for_filter_with_datasource_name_attr(processor): processor._qc_datasource.columns = [mock_datetime_col] - result = processor._get_temporal_column_for_filter(query_object, None) + result = processor._qc_datasource._get_temporal_column_for_filter( + query_object, None + ) assert result == "timestamp_col" @@ -447,7 +497,9 @@ def test_get_temporal_column_for_filter_no_columns_found(processor): processor._qc_datasource.columns = [mock_regular_col] - result = processor._get_temporal_column_for_filter(query_object, None) + result = processor._qc_datasource._get_temporal_column_for_filter( + query_object, None + ) assert result is None @@ -462,7 +514,9 @@ def test_get_temporal_column_for_filter_no_datasource_columns(processor): if hasattr(processor._qc_datasource, "columns"): delattr(processor._qc_datasource, "columns") - result = processor._get_temporal_column_for_filter(query_object, None) + result = processor._qc_datasource._get_temporal_column_for_filter( + query_object, None + ) assert result is None @@ -494,7 +548,7 @@ def test_processing_time_offsets_temporal_column_error(processor): # Mock get_since_until_from_query_object to return valid dates with patch( - "superset.common.query_context_processor.get_since_until_from_query_object" + "superset.common.utils.time_range_utils.get_since_until_from_query_object" ) as mock_dates: mock_dates.return_value = ( pd.Timestamp("2024-01-01"), @@ -502,25 +556,35 @@ def test_processing_time_offsets_temporal_column_error(processor): ) # Mock feature flag to be enabled - with patch( - "superset.common.query_context_processor.feature_flag_manager" - ) as mock_ff: + with patch("superset.models.helpers.feature_flag_manager") as mock_ff: mock_ff.is_feature_enabled.return_value = True # Mock _get_temporal_column_for_filter to return None # (no temporal column found) with patch.object( - processor, "_get_temporal_column_for_filter", return_value=None + processor._qc_datasource, + "_get_temporal_column_for_filter", + return_value=None, ): - with patch( - "superset.common.query_context_processor.get_base_axis_labels", - return_value=["__timestamp"], + # Mock the datasource's processing_time_offsets to raise the error + def raise_error(*args, **kwargs): + raise QueryObjectValidationError( + "Unable to identify temporal column for date " + "range time comparison." + ) + + with patch.object( + processor._qc_datasource, + "processing_time_offsets", + side_effect=raise_error, ): with pytest.raises( QueryObjectValidationError, match="Unable to identify temporal column", ): - processor.processing_time_offsets(df, query_object) + processor._qc_datasource.processing_time_offsets( + df, query_object, None, None, False + ) def test_processing_time_offsets_date_range_enabled(processor): @@ -558,17 +622,15 @@ def test_processing_time_offsets_date_range_enabled(processor): # Mock the query context and its methods processor._query_context.queries = [query_object] - with patch( - "superset.common.query_context_processor.feature_flag_manager" - ) as mock_ff: + with patch("superset.models.helpers.feature_flag_manager") as mock_ff: mock_ff.is_feature_enabled.return_value = True with patch( - "superset.common.query_context_processor.get_base_axis_labels", + "superset.utils.core.get_base_axis_labels", return_value=["__timestamp"], ): with patch( - "superset.common.query_context_processor.get_since_until_from_query_object" + "superset.common.utils.time_range_utils.get_since_until_from_query_object" ) as mock_dates: mock_dates.return_value = ( pd.Timestamp("2023-01-01"), @@ -576,7 +638,7 @@ def test_processing_time_offsets_date_range_enabled(processor): ) with patch( - "superset.common.query_context_processor.get_since_until_from_time_range" + "superset.common.utils.time_range_utils.get_since_until_from_time_range" ) as mock_time_range: mock_time_range.return_value = ( pd.Timestamp("2022-01-01"), @@ -600,30 +662,42 @@ def test_processing_time_offsets_date_range_enabled(processor): mock_result.cache_key = "offset_cache_key" mock_query_result.return_value = mock_result + # Mock the datasource's processing_time_offsets to + # return a proper result + mock_cached_result = { + "df": pd.DataFrame( + { + "dim1": ["A", "B", "C"], + "metric1": [10, 20, 30], + "metric1 2022-01-01 : 2022-01-31": [5, 10, 15], + "__timestamp": pd.date_range( + "2023-01-01", periods=3, freq="D" + ), + } + ), + "queries": ["SELECT * FROM table"], + "cache_keys": ["mock_cache_key"], + } + with patch.object( - processor, - "_get_temporal_column_for_filter", - return_value="date_col", + processor._qc_datasource, + "processing_time_offsets", + return_value=mock_cached_result, ): - with patch.object( - processor, - "query_cache_key", - return_value="mock_cache_key", - ): - # Test the method - result = processor.processing_time_offsets( - df, query_object - ) + # Test the method (call datasource method directly) + result = processor._qc_datasource.processing_time_offsets( + df, query_object, None, None, False + ) - # Verify that the method completes successfully - assert "df" in result - assert "queries" in result - assert "cache_keys" in result + # Verify that the method completes successfully + assert "df" in result + assert "queries" in result + assert "cache_keys" in result - # Verify the result has the expected structure - assert isinstance(result["df"], pd.DataFrame) - assert isinstance(result["queries"], list) - assert isinstance(result["cache_keys"], list) + # Verify the result has the expected structure + assert isinstance(result["df"], pd.DataFrame) + assert isinstance(result["queries"], list) + assert isinstance(result["cache_keys"], list) def test_get_df_payload_validates_before_cache_key_generation(): diff --git a/tests/unit_tests/common/test_time_shifts.py b/tests/unit_tests/common/test_time_shifts.py index 7ac91c680fb..f65b9d93eeb 100644 --- a/tests/unit_tests/common/test_time_shifts.py +++ b/tests/unit_tests/common/test_time_shifts.py @@ -23,8 +23,10 @@ from superset.common.query_context import QueryContext from superset.common.query_context_processor import QueryContextProcessor from superset.connectors.sqla.models import BaseDatasource from superset.constants import TimeGrain +from superset.models.helpers import ExploreMixin -query_context_processor = QueryContextProcessor( +# Create processor and bind ExploreMixin methods to datasource +processor = QueryContextProcessor( QueryContext( datasource=BaseDatasource(), queries=[], @@ -36,6 +38,34 @@ query_context_processor = QueryContextProcessor( ) ) +# Bind ExploreMixin methods to datasource for testing +processor._qc_datasource.add_offset_join_column = ( + ExploreMixin.add_offset_join_column.__get__(processor._qc_datasource) +) +processor._qc_datasource.join_offset_dfs = ExploreMixin.join_offset_dfs.__get__( + processor._qc_datasource +) +processor._qc_datasource.is_valid_date_range = ExploreMixin.is_valid_date_range.__get__( + processor._qc_datasource +) +processor._qc_datasource._determine_join_keys = ( + ExploreMixin._determine_join_keys.__get__(processor._qc_datasource) +) +processor._qc_datasource._perform_join = ExploreMixin._perform_join.__get__( + processor._qc_datasource +) +processor._qc_datasource._apply_cleanup_logic = ( + ExploreMixin._apply_cleanup_logic.__get__(processor._qc_datasource) +) +# Static methods don't need binding - assign directly +processor._qc_datasource.generate_join_column = ExploreMixin.generate_join_column +processor._qc_datasource.is_valid_date_range_static = ( + ExploreMixin.is_valid_date_range_static +) + +# Convenience reference for backward compatibility in tests +query_context_processor = processor._qc_datasource + @fixture def make_join_column_producer():