diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index edf06e9d88c..509dcba5a71 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -49,6 +49,7 @@ from superset.exceptions import ( from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult from superset.models.sql_lab import Query +from superset.superset_typing import AdhocColumn, AdhocMetric from superset.utils import csv, excel from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.core import ( @@ -63,6 +64,8 @@ from superset.utils.core import ( get_column_names_from_metrics, get_metric_names, get_x_axis_label, + is_adhoc_column, + is_adhoc_metric, normalize_dttm_col, TIME_COMPARISON, ) @@ -180,6 +183,30 @@ class QueryContextProcessor: ] for col in cache.df.columns.values } + label_map.update( + { + column_name: [ + str(query_obj.columns[idx]) + if not is_adhoc_column(query_obj.columns[idx]) + else cast(AdhocColumn, query_obj.columns[idx])["sqlExpression"], + ] + for idx, column_name in enumerate(query_obj.column_names) + } + ) + label_map.update( + { + metric_name: [ + str(query_obj.metrics[idx]) + if not is_adhoc_metric(query_obj.metrics[idx]) + else str(cast(AdhocMetric, query_obj.metrics[idx])["sqlExpression"]) + if cast(AdhocMetric, query_obj.metrics[idx])["expressionType"] + == "SQL" + else metric_name, + ] + for idx, metric_name in enumerate(query_obj.metric_names) + if query_obj and query_obj.metrics + } + ) cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values] return { diff --git a/superset/common/query_object.py b/superset/common/query_object.py index a9956d6f73f..40729109466 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -258,7 +258,12 @@ class QueryObject: # pylint: disable=too-many-instance-attributes @property def metric_names(self) -> list[str]: """Return metrics names (labels), coerce adhoc metrics to strings.""" - return get_metric_names(self.metrics or []) + return get_metric_names( + self.metrics or [], + self.datasource.verbose_map + if self.datasource and hasattr(self.datasource, "verbose_map") + else None, + ) @property def column_names(self) -> list[str]: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 6478fdf075d..4a006cd9f3a 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -424,7 +424,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= # pull out all required metrics from the form_data for metric_param in METRIC_FORM_DATA_PARAMS: for metric in utils.as_list(form_data.get(metric_param) or []): - metric_names.add(utils.get_metric_name(metric)) + metric_names.add(utils.get_metric_name(metric, self.verbose_map)) if utils.is_adhoc_metric(metric): column_ = metric.get("column") or {} if column_name := column_.get("column_name"): @@ -476,6 +476,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= metric for metric in data["metrics"] if metric["metric_name"] in metric_names + or metric["verbose_name"] in metric_names ] filtered_columns: list[Column] = [] @@ -1496,7 +1497,7 @@ class SqlaTable( :rtype: sqlalchemy.sql.column """ expression_type = metric.get("expressionType") - label = utils.get_metric_name(metric) + label = utils.get_metric_name(metric, self.verbose_map) if expression_type == utils.AdhocMetricExpressionType.SIMPLE: metric_column = metric.get("column") or {} diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index c744bb75a54..a56e8338352 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -786,6 +786,8 @@ def test_get_label_map(app_context, virtual_dataset_comma_in_column_value): "count, col2, row1": ["count", "col2, row1"], "count, col2, row2": ["count", "col2, row2"], "count, col2, row3": ["count", "col2, row3"], + "col2": ["col2"], + "count": ["count"], } diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py index 4d54f77de88..8ff362dec36 100644 --- a/tests/unit_tests/common/test_query_object_factory.py +++ b/tests/unit_tests/common/test_query_object_factory.py @@ -40,7 +40,9 @@ def app_config() -> dict[str, Any]: @fixture def connector_registry() -> Mock: - return Mock(spec=["get_datasource"]) + mock = Mock(spec=["get_datasource"]) + mock.get_datasource().verbose_map = {"sum__num": "SUM", "unused": "UNUSED"} + return mock def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int: @@ -66,6 +68,11 @@ def raw_query_context() -> dict[str, Any]: return QueryContextGenerator().generate("birth_names") +@fixture +def metric_label_raw_query_context() -> dict[str, Any]: + return QueryContextGenerator().generate("birth_names:metric_labels") + + class TestQueryObjectFactory: def test_query_context_limit_and_offset_defaults( self, @@ -107,3 +114,21 @@ class TestQueryObjectFactory: raw_query_context["result_type"], **raw_query_object ) assert query_object.post_processing == [] + + def test_query_context_metric_names( + self, + query_object_factory: QueryObjectFactory, + raw_query_context: dict[str, Any], + ): + raw_query_context["queries"][0]["metrics"] = [ + {"label": "sum__num"}, + {"label": "num_girls"}, + {"label": "num_boys"}, + ] + raw_query_object = raw_query_context["queries"][0] + query_object = query_object_factory.create( + raw_query_context["result_type"], + datasource=raw_query_context["datasource"], + **raw_query_object, + ) + assert query_object.metric_names == ["SUM", "num_girls", "num_boys"]