Make timestamp expression native SQLAlchemy element (#7131)

* Add native sqla component for time expressions

* Add unit tests and remove old tests

* Remove redundant _grains_dict method

* Clarify time_grain logic

* Add docstrings and typing

* Fix flake8 errors

* Add missing typings

* Rename to TimestampExpression

* Remove redundant tests

* Fix broken reference to db.database_name due to refactor
This commit is contained in:
Ville Brofeldt
2019-05-30 08:28:37 +03:00
committed by GitHub
parent fc3b043462
commit 34407e8962
5 changed files with 128 additions and 121 deletions

View File

@@ -17,15 +17,16 @@
import inspect
from unittest import mock
from sqlalchemy import column, select, table
from sqlalchemy.dialects.mssql import pymssql
from sqlalchemy import column, literal_column, select, table
from sqlalchemy.dialects import mssql, oracle, postgresql
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.types import String, UnicodeText
from superset import db_engine_specs
from superset.db_engine_specs import (
BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
MySQLEngineSpec, OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec,
PrestoEngineSpec,
)
from superset.models.core import Database
from .base_tests import SupersetTestCase
@@ -451,7 +452,7 @@ class DbEngineSpecsTestCase(SupersetTestCase):
assert_type('NTEXT', UnicodeText)
def test_mssql_where_clause_n_prefix(self):
dialect = pymssql.dialect()
dialect = mssql.dialect()
spec = MssqlEngineSpec
str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
@@ -462,7 +463,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
where(unicode_col == 'abc')
query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa
query_expected = 'SELECT col, unicode_col \n' \
'FROM tbl \n' \
"WHERE col = 'abc' AND unicode_col = N'abc'"
self.assertEqual(query, query_expected)
def test_get_table_names(self):
@@ -483,3 +486,51 @@ class DbEngineSpecsTestCase(SupersetTestCase):
pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
schema='schema', inspector=inspector)
self.assertListEqual(pg_result_expected, pg_result)
def test_pg_time_expression_literal_no_grain(self):
col = literal_column('COALESCE(a, b)')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, 'COALESCE(a, b)')
def test_pg_time_expression_literal_1y_grain(self):
col = literal_column('COALESCE(a, b)')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
def test_pg_time_expression_lower_column_no_grain(self):
col = column('lower_case')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, 'lower_case')
def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
col = column('lower_case')
expr = PostgresEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1Y')
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))") # noqa
def test_pg_time_expression_mixed_case_column_1y_grain(self):
col = column('MixedCase')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
def test_mssql_time_expression_mixed_case_column_1y_grain(self):
col = column('MixedCase')
expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y')
result = str(expr.compile(dialect=mssql.dialect()))
self.assertEqual(result, 'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)')
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
col = column('decimal')
expr = OracleEngineSpec.get_timestamp_expr(col, None, 'P1M')
result = str(expr.compile(dialect=oracle.dialect()))
self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
def test_pinot_time_expression_sec_1m_grain(self):
col = column('tstamp')
expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M')
result = str(expr.compile())
self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")') # noqa