fix: Handling of column types for Presto, Trino, et al. (#28653)

This commit is contained in:
John Bodley
2024-05-28 08:56:38 -07:00
committed by GitHub
parent a59bad83d4
commit 4ff17409ab
4 changed files with 85 additions and 39 deletions

View File

@@ -19,16 +19,17 @@ import copy
import json
from datetime import datetime
from typing import Any, Optional
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import types
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect
import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
@@ -39,7 +40,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
@@ -645,3 +646,46 @@ def test_get_default_catalog() -> None:
sqlalchemy_uri="trino://user:pass@localhost:8080/system/default",
)
assert TrinoEngineSpec.get_default_catalog(database) == "system"
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
@pytest.mark.parametrize(
["column_type", "column_value", "expected_value"],
[
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
(types.INT(), 1234, "1234"),
],
)
def test_where_latest_partition(
mock_latest_partition,
column_type: SQLType,
column_value: Any,
expected_value: str,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
mock_latest_partition.return_value = (["partition_key"], [column_value])
assert (
str(
TrinoEngineSpec.where_latest_partition( # type: ignore
database=MagicMock(),
table=Table("table"),
query=sql.select(text("* FROM table")),
columns=[
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
],
).compile(
dialect=TrinoDialect(),
compile_kwargs={"literal_binds": True},
)
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
)