@staticmethod
def left_join_df(
    left_df: pd.DataFrame, right_df: pd.DataFrame, join_keys: List[str],
) -> pd.DataFrame:
    """
    Left-join ``right_df`` onto ``left_df`` using ``join_keys``.

    Both frames are indexed by ``join_keys`` before joining so that rows
    align on those columns; the keys are restored as ordinary columns in
    the returned frame. Rows of ``left_df`` with no match keep NaN for
    the right-hand columns (standard left-join semantics).

    :param left_df: frame whose rows are all preserved
    :param right_df: frame providing the extra columns
    :param join_keys: column names shared by both frames to join on
    :return: a new joined DataFrame; neither input is mutated
    """
    indexed_left = left_df.set_index(join_keys)
    indexed_right = right_df.set_index(join_keys)
    # reset_index (non-inplace) moves the join keys back into columns
    return indexed_left.join(indexed_right).reset_index()
col in df.columns if col not in metrics_mapping.keys()] + + result = self.datasource.query(query_object_clone_dct) + queries.append(result.query) + cache_keys.append(None) offset_metrics_df = result.df if offset_metrics_df.empty: offset_metrics_df = pd.DataFrame( - {col: [np.NaN] for col in columns_name_mapping.values()} + { + col: [np.NaN] + for col in join_keys + list(metrics_mapping.values()) + } ) else: # 1. normalize df, set dttm column @@ -187,25 +191,23 @@ class QueryContext: offset_metrics_df, query_object_clone ) - # 2. extract `metrics` columns and `dttm` column from extra query - offset_metrics_df = offset_metrics_df[columns_name_mapping.keys()] + # 2. rename extra query columns + offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping) - # 3. rename extra query columns - offset_metrics_df = offset_metrics_df.rename( - columns=columns_name_mapping - ) - - # 4. set offset for dttm column + # 3. set time offset for dttm column offset_metrics_df[DTTM_ALIAS] = offset_metrics_df[ DTTM_ALIAS ] - DateOffset(**normalize_time_delta(offset)) - # df left join `offset_metrics_df` on `DTTM` - df = self.left_join_on_dttm(df, offset_metrics_df) + # df left join `offset_metrics_df` + offset_df = self.left_join_df( + left_df=df, right_df=offset_metrics_df, join_keys=join_keys, + ) + offset_slice = offset_df[metrics_mapping.values()] - # set offset df to cache. + # set offset_slice to cache and stack. 
value = { - "df": offset_metrics_df, + "df": offset_slice, "query": result.query, } cache.set( @@ -215,8 +217,10 @@ class QueryContext: datasource_uid=self.datasource.uid, region=CacheRegion.DATA, ) + rv_dfs.append(offset_slice) - return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys) + rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df + return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys) def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame: timestamp_format = None diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index ecc69b7b8d3..cd7654032c7 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -20,6 +20,7 @@ import time from typing import Any, Dict import pytest +from pandas import DateOffset from superset import db from superset.charts.schemas import ChartDataQueryContextSchema @@ -546,11 +547,15 @@ class TestQueryContext(SupersetTestCase): self.login(username="admin") payload = get_query_context("birth_names") payload["queries"][0]["metrics"] = ["sum__num"] + # should process empty dateframe correctly + # due to "name" is random generated, each time_offset slice will be empty payload["queries"][0]["groupby"] = ["name"] payload["queries"][0]["is_timeseries"] = True payload["queries"][0]["timeseries_limit"] = 5 payload["queries"][0]["time_offsets"] = [] payload["queries"][0]["time_range"] = "1990 : 1991" + payload["queries"][0]["granularity"] = "ds" + payload["queries"][0]["extras"]["time_grain_sqla"] = "P1Y" query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] query_result = query_context.get_query_result(query_object) @@ -588,3 +593,97 @@ class TestQueryContext(SupersetTestCase): self.assertIs(rv["df"], df) self.assertEqual(rv["queries"], []) self.assertEqual(rv["cache_keys"], []) + + 
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_time_offsets_sql(self):
    """Offset queries should shift the SQL date range by the given offset."""
    payload = get_query_context("birth_names")
    payload["queries"][0].update(
        {
            "metrics": ["sum__num"],
            "groupby": ["state"],
            "is_timeseries": True,
            "timeseries_limit": 5,
            "time_offsets": [],
            "time_range": "1980 : 1991",
            "granularity": "ds",
        }
    )
    payload["queries"][0]["extras"]["time_grain_sqla"] = "P1Y"
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    # main query dataframe, computed without any offsets applied
    main_df = query_context.get_query_result(query_object).df

    # reload the context with two symmetric offsets on the query object
    payload["queries"][0]["time_offsets"] = ["3 years ago", "3 years later"]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    offsets_result = query_context.processing_time_offsets(main_df, query_object)
    sql_3_years_ago = offsets_result["queries"][0]
    sql_3_years_later = offsets_result["queries"][1]

    # each generated offset query should cover the shifted date range
    assert "1977-01-01" in sql_3_years_ago
    assert "1988-01-01" in sql_3_years_ago
    assert "1983-01-01" in sql_3_years_later
    assert "1994-01-01" in sql_3_years_later

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_time_offsets_accuracy(self):
    """Offset metric columns should equal an explicitly shifted query's data."""
    payload = get_query_context("birth_names")
    payload["queries"][0].update(
        {
            "metrics": ["sum__num"],
            "groupby": ["state"],
            "is_timeseries": True,
            "timeseries_limit": 5,
            "time_offsets": [],
            "time_range": "1980 : 1991",
            "granularity": "ds",
        }
    )
    payload["queries"][0]["extras"]["time_grain_sqla"] = "P1Y"
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    # main query dataframe without offsets
    main_df = query_context.get_query_result(query_object).df

    # compute the offset-augmented dataframe
    payload["queries"][0]["time_offsets"] = ["3 years ago", "3 years later"]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_object = query_context.queries[0]
    offsets_result = query_context.processing_time_offsets(main_df, query_object)
    df_with_offsets = offsets_result["df"].set_index(["__timestamp", "state"])

    # run the shifted ranges as plain queries and compare value-by-value;
    # adding DateOffset(years=-3) is identical to subtracting 3 years
    for shift_years, time_range, offset_column in (
        (3, "1977 : 1988", "sum__num__3 years ago"),
        (-3, "1983 : 1994", "sum__num__3 years later"),
    ):
        payload["queries"][0]["time_offsets"] = []
        payload["queries"][0]["time_range"] = time_range
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        expected_df = query_context.get_query_result(query_object).df
        # re-align the shifted data onto the main query's time axis
        expected_df["__timestamp"] = expected_df["__timestamp"] + DateOffset(
            years=shift_years
        )
        expected_df = expected_df.set_index(["__timestamp", "state"])
        for index, row in df_with_offsets.iterrows():
            if index in expected_df.index:
                assert row[offset_column] == expected_df.loc[index]["sum__num"]