diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 8fd281cf308..77f85b430ec 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -17,7 +17,7 @@ # pylint: disable=R import logging from datetime import datetime, timedelta -from typing import Any, Dict, List, NamedTuple, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional from flask_babel import gettext as _ from pandas import DataFrame @@ -103,7 +103,7 @@ class QueryObject: applied_time_extras: Optional[Dict[str, str]] = None, apply_fetch_values_predicate: bool = False, granularity: Optional[str] = None, - metrics: Optional[List[Union[Dict[str, Any], str]]] = None, + metrics: Optional[List[Metric]] = None, groupby: Optional[List[str]] = None, filters: Optional[List[Dict[str, Any]]] = None, time_range: Optional[str] = None, diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 1839f763e70..14eda046739 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -55,7 +55,13 @@ from superset.exceptions import SupersetException from superset.extensions import encrypted_field_factory from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult -from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict +from superset.typing import ( + AdhocMetric, + FilterValues, + Granularity, + Metric, + QueryObjectDict, +) from superset.utils import core as utils from superset.utils.date_parser import parse_human_datetime, parse_human_timedelta @@ -1010,7 +1016,7 @@ class DruidDatasource(Model, BaseDatasource): return ret @staticmethod - def druid_type_from_adhoc_metric(adhoc_metric: Dict[str, Any]) -> str: + def druid_type_from_adhoc_metric(adhoc_metric: AdhocMetric) -> str: column_type = adhoc_metric["column"]["type"].lower() aggregate = adhoc_metric["aggregate"].lower() @@ -1025,7 +1031,7 @@ class DruidDatasource(Model, BaseDatasource): def get_aggregations( metrics_dict: Dict[str, Any], saved_metrics: Set[str], - adhoc_metrics: Optional[List[Dict[str, Any]]] = None, + adhoc_metrics: Optional[List[AdhocMetric]] = None, ) -> "OrderedDict[str, Any]": """ Returns a dictionary of aggregation metric names to aggregation json objects diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 2f7b6d1498d..31f475f405d 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -22,7 +22,18 @@ from collections import defaultdict, OrderedDict from contextlib import closing from dataclasses import dataclass, field # pylint: disable=wrong-import-order from datetime import datetime, timedelta -from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Type, Union +from typing import ( + Any, + cast, + Dict, + Hashable, + List, + NamedTuple, + Optional, + Tuple, + Type, + Union, +) import pandas as pd import sqlalchemy as sa @@ -241,7 +252,9 @@ class TableColumn(Model, BaseColumn): column_spec = db_engine_spec.get_column_spec(self.type) type_ = column_spec.sqla_type if column_spec else None if self.expression: - col = literal_column(self.expression, type_=type_) + tp = self.table.get_template_processor() + expression = tp.process_template(self.expression) + col = literal_column(expression, type_=type_) else: col = column(self.column_name, type_=type_) col = self.table.make_sqla_column_compatible(col, label) @@ -879,7 +892,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at label = utils.get_metric_name(metric) if expression_type == utils.AdhocMetricExpressionType.SIMPLE: - column_name = metric["column"].get("column_name") + column_name = cast(str, metric["column"].get("column_name")) table_column: Optional[TableColumn] = columns_by_name.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() @@ -887,7 +900,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) elif expression_type == utils.AdhocMetricExpressionType.SQL: - sqla_metric = literal_column(metric.get("sqlExpression")) + tp = self.get_template_processor() + expression = tp.process_template(cast(str, metric["sqlExpression"])) + sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") @@ -1060,8 +1075,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at # Since orderby may use adhoc metrics, too; we need to process them first orderby_exprs: List[ColumnElement] = [] for orig_col, ascending in orderby: - col: Union[Metric, ColumnElement] = orig_col + col: Union[AdhocMetric, ColumnElement] = orig_col if isinstance(col, dict): + col = cast(AdhocMetric, col) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists col = self.adhoc_metric_to_sqla(col, columns_by_name) diff --git a/superset/typing.py b/superset/typing.py index 0a7ef598ad3..f428831fd20 100644 --- a/superset/typing.py +++ b/superset/typing.py @@ -19,8 +19,23 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from flask import Flask from flask_caching import Cache +from typing_extensions import TypedDict from werkzeug.wrappers import Response + +class AdhocMetricColumn(TypedDict): + column_name: Optional[str] + type: str + + +class AdhocMetric(TypedDict): + aggregate: str + column: AdhocMetricColumn + expressionType: str + label: str + sqlExpression: Optional[str] + + CacheConfig = Union[Callable[[Flask], Cache], Dict[str, Any]] DbapiDescriptionRow = Tuple[ str, str, Optional[str], Optional[str], Optional[int], Optional[int], bool @@ -31,7 +46,6 @@ FilterValue = Union[datetime, float, int, str] FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] FormData = Dict[str, Any] Granularity = Union[str, Dict[str, Union[str, float]]] -AdhocMetric = Dict[str, Any] Metric = Union[AdhocMetric, str] OrderBy = Tuple[Metric, bool] QueryObjectDict = Dict[str, Any] diff --git a/superset/utils/core.py b/superset/utils/core.py index c8c352d3682..12c66e41a72 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -96,7 +96,7 @@ from superset.exceptions import ( SupersetException, SupersetTimeoutException, ) -from superset.typing import FlaskResponse, FormData, Metric +from superset.typing import AdhocMetric, FlaskResponse, FormData, Metric from superset.utils.dates import datetime_to_epoch, EPOCH from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str @@ -1494,7 +1494,7 @@ def get_column_name_from_metric(metric: Metric) -> Optional[str]: :return: column name if simple metric, otherwise None """ if is_adhoc_metric(metric): - metric = cast(Dict[str, Any], metric) + metric = cast(AdhocMetric, metric) if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE: return cast(Dict[str, Any], metric["column"])["column_name"] return None diff --git a/superset/viz.py b/superset/viz.py index 53ea2ca4dee..04adab3d2a1 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -66,7 +66,7 @@ from superset.exceptions import ( from superset.extensions import cache_manager, security_manager from superset.models.cache import CacheKey from superset.models.helpers import QueryResult -from superset.typing import QueryObjectDict, VizData, VizPayload +from superset.typing import Metric, QueryObjectDict, VizData, VizPayload from superset.utils import core as utils, csv from superset.utils.cache import set_and_log_cache from superset.utils.core import ( @@ -526,10 +526,7 @@ class BaseViz: for col in (query_obj.get("columns") or []) + (query_obj.get("groupby") or []) + utils.get_column_names_from_metrics( - cast( - List[Union[str, Dict[str, Any]]], - query_obj.get("metrics") or [], - ) + cast(List[Metric], query_obj.get("metrics") or [],) ) if col not in self.datasource.column_names ] diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index a759270809c..edd9dce5f05 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -26,7 +26,12 @@ from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.db_engine_specs.druid import DruidEngineSpec from superset.exceptions import QueryObjectValidationError from superset.models.core import Database -from superset.utils.core import GenericDataType, get_example_database, FilterOperator +from superset.utils.core import ( + AdhocMetricExpressionType, + FilterOperator, + GenericDataType, + get_example_database, +) from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from .base_tests import SupersetTestCase @@ -168,6 +173,50 @@ class TestDatabaseModel(SupersetTestCase): db.session.delete(table) db.session.commit() + @patch("superset.jinja_context.g") + def test_jinja_metrics_and_calc_columns(self, flask_g): + flask_g.user.username = "abc" + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["user", "expr"], + "metrics": [ + { + "expressionType": AdhocMetricExpressionType.SQL, + "sqlExpression": "SUM(case when user = '{{ current_username() }}' " + "then 1 else 0 end)", + "label": "SUM(userid)", + } + ], + "is_timeseries": False, + "filter": [], + } + + table = SqlaTable( + table_name="test_has_jinja_metric_and_expr", + sql="SELECT '{{ current_username() }}' as user", + database=get_example_database(), + ) + TableColumn( + column_name="expr", + expression="case when '{{ current_username() }}' = 'abc' " + "then 'yes' else 'no' end", + type="VARCHAR(100)", + table=table, + ) + db.session.commit() + + sqla_query = table.get_sqla_query(**base_query_obj) + query = table.database.compile_sqla_query(sqla_query.sqla_query) + # assert expression + assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query + # assert metric + assert "SUM(case when user = 'abc' then 1 else 0 end)" in query + # Cleanup + db.session.delete(table) + db.session.commit() + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_where_operators(self): class FilterTestCase(NamedTuple):