fix: Normalize prequery result type (#17312)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2021-11-03 13:58:40 -07:00
committed by GitHub
parent cb34a22684
commit 36f489eea0
3 changed files with 113 additions and 42 deletions

View File

@@ -16,11 +16,18 @@
# under the License.
# isort:skip_file
import re
from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union
from unittest.mock import patch
import pytest
import numpy as np
import pandas as pd
import sqlalchemy as sa
from flask import Flask
from pytest_mock import MockFixture
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause
from superset import db
from superset.connectors.sqla.models import SqlaTable, TableColumn
@@ -33,6 +40,7 @@ from superset.utils.core import (
FilterOperator,
GenericDataType,
get_example_database,
TemporalType,
)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -484,3 +492,70 @@ class TestDatabaseModel(SupersetTestCase):
)
assert None not in without_null
assert len(without_null) == 2
@pytest.mark.parametrize(
"row,dimension,result",
[
(pd.Series({"foo": "abc"}), "foo", "abc"),
(pd.Series({"bar": True}), "bar", True),
(pd.Series({"baz": 123}), "baz", 123),
(pd.Series({"baz": np.int16(123)}), "baz", 123),
(pd.Series({"baz": np.uint32(123)}), "baz", 123),
(pd.Series({"baz": np.int64(123)}), "baz", 123),
(pd.Series({"qux": 123.456}), "qux", 123.456),
(pd.Series({"qux": np.float32(123.456)}), "qux", 123.45600128173828),
(pd.Series({"qux": np.float64(123.456)}), "qux", 123.456),
(pd.Series({"quux": "2021-01-01"}), "quux", "2021-01-01"),
(
pd.Series({"quuz": "2021-01-01T00:00:00"}),
"quuz",
text("TIME_PARSE('2021-01-01T00:00:00')"),
),
],
)
def test__normalize_prequery_result_type(
app_context: Flask,
mocker: MockFixture,
row: pd.Series,
dimension: str,
result: Any,
) -> None:
def _convert_dttm(target_type: str, dttm: datetime) -> Optional[str]:
if target_type.upper() == TemporalType.TIMESTAMP:
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
return None
table = SqlaTable(table_name="foobar", database=get_example_database())
mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)
columns_by_name = {
"foo": TableColumn(
column_name="foo", is_dttm=False, table=table, type="STRING",
),
"bar": TableColumn(
column_name="bar", is_dttm=False, table=table, type="BOOLEAN",
),
"baz": TableColumn(
column_name="baz", is_dttm=False, table=table, type="INTEGER",
),
"qux": TableColumn(
column_name="qux", is_dttm=False, table=table, type="FLOAT",
),
"quux": TableColumn(
column_name="quuz", is_dttm=True, table=table, type="STRING",
),
"quuz": TableColumn(
column_name="quux", is_dttm=True, table=table, type="TIMESTAMP",
),
}
normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name,)
assert type(normalized) == type(result)
if isinstance(normalized, TextClause):
assert str(normalized) == str(result)
else:
assert normalized == result