From 75be3dd7b45ed98ade643d56b05a1ab10d8874b4 Mon Sep 17 00:00:00 2001 From: Rob Moore Date: Fri, 19 May 2023 21:29:42 +0100 Subject: [PATCH] fix: handle temporal columns in presto partitions (#24054) --- superset/db_engine_specs/base.py | 2 +- superset/db_engine_specs/hive.py | 2 +- superset/db_engine_specs/presto.py | 18 ++++---- .../unit_tests/db_engine_specs/test_presto.py | 43 ++++++++++++++++++- 4 files changed, 54 insertions(+), 11 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 27dd34a802d..b789bbe70ce 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1168,7 +1168,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods schema: Optional[str], database: Database, query: Select, - columns: Optional[List[Dict[str, str]]] = None, + columns: Optional[List[Dict[str, Any]]] = None, ) -> Optional[Select]: """ Add a where clause to a query to reference only the most recent partition diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index f07d53518c2..44dc435c2cb 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -404,7 +404,7 @@ class HiveEngineSpec(PrestoEngineSpec): schema: Optional[str], database: "Database", query: Select, - columns: Optional[List[Dict[str, str]]] = None, + columns: Optional[List[Dict[str, Any]]] = None, ) -> Optional[Select]: try: col_names, values = cls.latest_partition( diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 6bd556b79e3..87f362acc8f 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -462,7 +462,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): schema: Optional[str], database: Database, query: Select, - columns: Optional[List[Dict[str, str]]] = None, + columns: Optional[List[Dict[str, Any]]] = None, ) -> Optional[Select]: try: col_names, values = cls.latest_partition( @@ -480,13 +480,15 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): } for col_name, value in zip(col_names, values): - if col_name in column_type_by_name: - if column_type_by_name.get(col_name) == "TIMESTAMP": - query = query.where(Column(col_name, TimeStamp()) == value) - elif column_type_by_name.get(col_name) == "DATE": - query = query.where(Column(col_name, Date()) == value) - else: - query = query.where(Column(col_name) == value) + col_type = column_type_by_name.get(col_name) + + if isinstance(col_type, types.DATE): + col_type = Date() + elif isinstance(col_type, types.TIMESTAMP): + col_type = TimeStamp() + + query = query.where(Column(col_name, col_type) == value) + return query @classmethod diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index a30fab94c91..8f55b1c048d 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -16,10 +16,13 @@ # under the License. from datetime import datetime from typing import Any, Dict, Optional, Type +from unittest import mock import pytest import pytz -from sqlalchemy import types +from pyhive.sqlalchemy_presto import PrestoDialect +from sqlalchemy import sql, text, types +from sqlalchemy.engine.url import make_url from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -82,3 +85,41 @@ def test_get_column_spec( from superset.db_engine_specs.presto import PrestoEngineSpec as spec assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) + + +@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition") +@pytest.mark.parametrize( + ["column_type", "column_value", "expected_value"], + [ + (types.DATE(), "2023-05-01", "DATE '2023-05-01'"), + (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"), + (types.VARCHAR(), "2023-05-01", "'2023-05-01'"), + (types.INT(), 1234, "1234"), + ], +) +def test_where_latest_partition( + mock_latest_partition: Any, + column_type: Any, + column_value: str, + expected_value: str, +) -> None: + """ + Test the ``where_latest_partition`` method + """ + from superset.db_engine_specs.presto import PrestoEngineSpec as spec + + mock_latest_partition.return_value = (["partition_key"], [column_value]) + + query = sql.select(text("* FROM table")) + columns = [{"name": "partition_key", "type": column_type}] + + expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}""" + result = spec.where_latest_partition( + "table", mock.MagicMock(), mock.MagicMock(), query, columns + ) + assert result is not None + actual = result.compile( + dialect=PrestoDialect(), compile_kwargs={"literal_binds": True} + ) + + assert str(actual) == expected