feat: generate label map on the backend (#21124)

This commit is contained in:
Yongjie Zhao
2022-08-22 21:00:02 +08:00
committed by GitHub
parent 756ed0e36a
commit 11bf7b9125
8 changed files with 154 additions and 2 deletions

View File

@@ -358,3 +358,30 @@ def physical_dataset():
for ds in dataset:
db.session.delete(ds)
db.session.commit()
@pytest.fixture
def virtual_dataset_comma_in_column_value():
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
dataset = SqlaTable(
table_name="virtual_dataset",
sql=(
"SELECT 'col1,row1' as col1, 'col2, row1' as col2 "
"UNION ALL "
"SELECT 'col1,row2' as col1, 'col2, row2' as col2 "
"UNION ALL "
"SELECT 'col1,row3' as col1, 'col2, row3' as col2 "
),
database=get_example_database(),
)
TableColumn(column_name="col1", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
yield dataset
db.session.delete(dataset)
db.session.commit()

View File

@@ -25,6 +25,7 @@ from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_context_factory import QueryContextFactory
from superset.common.query_object import QueryObject
from superset.connectors.sqla.models import SqlMetric
from superset.datasource.dao import DatasourceDAO
@@ -35,6 +36,7 @@ from superset.utils.core import (
DatasourceType,
QueryStatus,
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -683,3 +685,46 @@ class TestQueryContext(SupersetTestCase):
row["sum__num__3 years later"]
== df_3_years_later.loc[index]["sum__num"]
)
def test_get_label_map(app_context, virtual_dataset_comma_in_column_value):
qc = QueryContextFactory().create(
datasource={
"type": virtual_dataset_comma_in_column_value.type,
"id": virtual_dataset_comma_in_column_value.id,
},
queries=[
{
"columns": ["col1", "col2"],
"metrics": ["count"],
"post_processing": [
{
"operation": "pivot",
"options": {
"aggregates": {"count": {"operator": "mean"}},
"columns": ["col2"],
"index": ["col1"],
},
},
{"operation": "flatten"},
],
}
],
result_type=ChartDataResultType.FULL,
force=True,
)
query_object = qc.queries[0]
df = qc.get_df_payload(query_object)["df"]
label_map = qc.get_df_payload(query_object)["label_map"]
assert list(df.columns.values) == [
"col1",
"count" + FLAT_COLUMN_SEPARATOR + "col2, row1",
"count" + FLAT_COLUMN_SEPARATOR + "col2, row2",
"count" + FLAT_COLUMN_SEPARATOR + "col2, row3",
]
assert label_map == {
"col1": ["col1"],
"count, col2, row1": ["count", "col2, row1"],
"count, col2, row2": ["count", "col2, row2"],
"count, col2, row3": ["count", "col2, row3"],
}