chore(db_engine_specs): clean up column spec logic and add tests (#22871)

This commit is contained in:
Ville Brofeldt
2023-01-31 15:54:07 +02:00
committed by GitHub
parent 8466eec228
commit cd6fc35f60
73 changed files with 1953 additions and 1463 deletions

View File

@@ -22,7 +22,6 @@ from tests.integration_tests.test_app import app
from tests.integration_tests.base_tests import SupersetTestCase
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database
from superset.utils.core import GenericDataType
class TestDbEngineSpec(SupersetTestCase):
@@ -37,16 +36,3 @@ class TestDbEngineSpec(SupersetTestCase):
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
self.assertEqual(expected_sql, limited)
def assert_generic_types(
spec: Type[BaseEngineSpec],
type_expectations: Tuple[Tuple[str, GenericDataType], ...],
) -> None:
for type_str, expected_type in type_expectations:
column_spec = spec.get_column_spec(type_str)
assert column_spec is not None
actual_type = column_spec.generic_type
assert (
actual_type == expected_type
), f"{type_str} should be {expected_type.name} but is {actual_type.name}"

View File

@@ -48,23 +48,6 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
def test_convert_dttm(self):
"""
DB Eng Specs (bigquery): Test conversion to date time
"""
dttm = self.get_dttm()
test_cases = {
"DATE": "CAST('2019-01-02' AS DATE)",
"DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
"TIMESTAMP": "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
"TIME": "CAST('03:04:05.678900' AS TIME)",
"UNKNOWNTYPE": None,
}
for target_type, expected in test_cases.items():
actual = BigQueryEngineSpec.convert_dttm(target_type, dttm)
self.assertEqual(actual, expected)
def test_timegrain_expressions(self):
"""
DB Eng Specs (bigquery): Test time grain expressions

View File

@@ -1,53 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.crate import CrateEngineSpec
from superset.models.core import Database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestCrateDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
"""
DB Eng Specs (crate): Test conversion to date time
"""
dttm = self.get_dttm()
assert CrateEngineSpec.convert_dttm("TIMESTAMP", dttm) == str(
dttm.timestamp() * 1000
)
def test_epoch_to_dttm(self):
"""
DB Eng Specs (crate): Test epoch to dttm
"""
assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000"
def test_epoch_ms_to_dttm(self):
"""
DB Eng Specs (crate): Test epoch ms to dttm
"""
assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}"
def test_alter_new_orm_column(self):
"""
DB Eng Specs (crate): Test alter orm column
"""
database = Database(database_name="crate", sqlalchemy_uri="crate://db")
tbl = SqlaTable(table_name="druid_tbl", database=database)
col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
CrateEngineSpec.alter_new_orm_column(col)
assert col.python_date_format == "epoch_ms"

View File

@@ -14,18 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from textwrap import dedent
from unittest import mock
from sqlalchemy import column, literal_column
from superset.constants import USER_AGENT
from superset.db_engine_specs import get_engine_spec
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra

View File

@@ -1,33 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.dremio import DremioEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestDremioDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DremioEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
DremioEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678', 'YYYY-MM-DD HH24:MI:SS.FFF')",
)

View File

@@ -1,33 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.drill import DrillEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestDrillDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DrillEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'yyyy-MM-dd')",
)
self.assertEqual(
DrillEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
)

View File

@@ -1,78 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from sqlalchemy import column
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
class TestDruidDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DruidEngineSpec.convert_dttm("DATETIME", dttm),
"TIME_PARSE('2019-01-02T03:04:05')",
)
self.assertEqual(
DruidEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TIME_PARSE('2019-01-02T03:04:05')",
)
self.assertEqual(
DruidEngineSpec.convert_dttm("DATE", dttm),
"CAST(TIME_PARSE('2019-01-02') AS DATE)",
)
def test_timegrain_expressions(self):
"""
DB Eng Specs (druid): Test time grain expressions
"""
col = "__time"
sqla_col = column(col)
test_cases = {
"PT1S": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT1S')",
"PT5M": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT5M')",
"P1W/1970-01-03T00:00:00Z": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', 5)",
"1969-12-28T00:00:00Z/P1W": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', -1)",
}
for grain, expected in test_cases.items():
actual = DruidEngineSpec.get_timestamp_expr(
col=sqla_col, pdf=None, time_grain=grain
)
self.assertEqual(str(actual), expected)
def test_extras_without_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DruidEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args

View File

@@ -1,104 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock
import pytest
from sqlalchemy import column
from superset.db_engine_specs.elasticsearch import (
ElasticSearchEngineSpec,
OpenDistroEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestElasticSearchDbEngineSpec(TestDbEngineSpec):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None),
"CAST('2019-01-02T03:04:05' AS DATETIME)",
)
def test_convert_dttm2(self):
"""
ES 7.8 and above versions need to use the DATETIME_PARSE function to
solve the time zone problem
"""
dttm = self.get_dttm()
db_extra = {"version": "7.8"}
self.assertEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
"DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
)
def test_convert_dttm3(self):
dttm = self.get_dttm()
db_extra = {"version": 7.8}
self.assertEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
"CAST('2019-01-02T03:04:05' AS DATETIME)",
)
self.assertNotEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
"DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
)
self.assertIn("Unexpected error while convert es_version", self._caplog.text)
def test_opendistro_convert_dttm(self):
"""
DB Eng Specs (opendistro): Test convert_dttm
"""
dttm = self.get_dttm()
self.assertEqual(
OpenDistroEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None),
"'2019-01-02T03:04:05'",
)
def test_opendistro_sqla_column_label(self):
"""
DB Eng Specs (opendistro): Test column label
"""
test_cases = {
"Col": "Col",
"Col.keyword": "Col_keyword",
}
for original, expected in test_cases.items():
actual = OpenDistroEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
def test_opendistro_strip_comments(self):
"""
DB Eng Specs (opendistro): Test execute sql strip comments
"""
mock_cursor = MagicMock()
mock_cursor.execute.return_value = []
OpenDistroEngineSpec.execute(
mock_cursor, "-- some comment \nSELECT 1\n --other comment"
)
mock_cursor.execute.assert_called_once_with("SELECT 1\n")

View File

@@ -1,81 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from unittest import mock
import pytest
from superset.db_engine_specs.firebird import FirebirdEngineSpec
grain_expressions = {
None: "timestamp_column",
"PT1S": (
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':' "
"|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)"
),
"PT1M": (
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':00' AS TIMESTAMP)"
),
"P1D": "CAST(timestamp_column AS DATE)",
"P1M": (
"CAST(EXTRACT(YEAR FROM timestamp_column) "
"|| '-' "
"|| EXTRACT(MONTH FROM timestamp_column) "
"|| '-01' AS DATE)"
),
"P1Y": "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)",
}
@pytest.mark.parametrize("grain,expected", grain_expressions.items())
def test_time_grain_expressions(grain, expected):
assert (
FirebirdEngineSpec._time_grain_expressions[grain].format(col="timestamp_column")
== expected
)
def test_epoch_to_dttm():
assert (
FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column")
== "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))"
)
def test_convert_dttm():
dttm = datetime(2021, 1, 1)
assert (
FirebirdEngineSpec.convert_dttm("timestamp", dttm)
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
)
assert (
FirebirdEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
)
assert FirebirdEngineSpec.convert_dttm("TIME", dttm) == "CAST('00:00:00' AS TIME)"
assert FirebirdEngineSpec.convert_dttm("DATE", dttm) == "CAST('2021-01-01' AS DATE)"
assert FirebirdEngineSpec.convert_dttm("STRING", dttm) is None

View File

@@ -1,39 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.firebolt import FireboltEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestFireboltDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
test_cases = {
"DATE": "CAST('2019-01-02' AS DATE)",
"DATETIME": "CAST('2019-01-02T03:04:05' AS DATETIME)",
"TIMESTAMP": "CAST('2019-01-02T03:04:05' AS TIMESTAMP)",
"UNKNOWNTYPE": None,
}
for target_type, expected in test_cases.items():
actual = FireboltEngineSpec.convert_dttm(target_type, dttm)
self.assertEqual(actual, expected)
def test_epoch_to_dttm(self):
assert (
FireboltEngineSpec.epoch_to_dttm().format(col="timestamp_column")
== "from_unixtime(timestamp_column)"
)

View File

@@ -1,33 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.hana import HanaEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestHanaDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
HanaEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
HanaEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD\"T\"HH24:MI:SS.ff6')",
)

View File

@@ -150,15 +150,6 @@ def test_hive_error_msg():
)
def test_convert_dttm():
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
assert (
HiveEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"
)
def test_df_to_csv() -> None:
with pytest.raises(SupersetException):
HiveEngineSpec.df_to_sql(

View File

@@ -1,32 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.impala import ImpalaEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestImpalaDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
ImpalaEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
ImpalaEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
)

View File

@@ -1,32 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.kylin import KylinEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestKylinDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
KylinEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
KylinEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02 03:04:05' AS TIMESTAMP)",
)

View File

@@ -21,12 +21,7 @@ from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
@@ -38,19 +33,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
MySQLEngineSpec.convert_dttm("DATE", dttm),
"STR_TO_DATE('2019-01-02', '%Y-%m-%d')",
)
self.assertEqual(
MySQLEngineSpec.convert_dttm("DATETIME", dttm),
"STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')",
)
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
@@ -69,32 +51,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
)
self.assertEqual(actual, expected)
def test_generic_type(self):
type_expectations = (
# Numeric
("TINYINT", GenericDataType.NUMERIC),
("SMALLINT", GenericDataType.NUMERIC),
("MEDIUMINT", GenericDataType.NUMERIC),
("INT", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("FLOAT", GenericDataType.NUMERIC),
("DOUBLE", GenericDataType.NUMERIC),
("BIT", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TINYTEXT", GenericDataType.STRING),
("MEDIUMTEXT", GenericDataType.STRING),
("LONGTEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("DATETIME", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
)
assert_generic_types(MySQLEngineSpec, type_expectations)
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError
@@ -239,22 +195,3 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
},
)
]
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.fetchone.return_value = [123]
assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == 123
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is True
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(self, engine_mock):
query = Query()
cursor_mock = engine_mock.raiseError.side_effect = Exception()
assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is False

View File

@@ -1,87 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from sqlalchemy import column
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR
from superset.db_engine_specs.oracle import OracleEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestOracleDbEngineSpec(TestDbEngineSpec):
def test_oracle_sqla_column_name_length_exceeded(self):
col = column("This_Is_32_Character_Column_Name")
label = OracleEngineSpec.make_label_compatible(col.name)
self.assertEqual(label.quote, True)
label_expected = "3b26974078683be078219674eeb8f5"
self.assertEqual(label, label_expected)
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
col = column("decimal")
expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
result = str(expr.compile(dialect=oracle.dialect()))
self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
dttm = self.get_dttm()
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
(VARCHAR(length=255), "VARCHAR(255 CHAR)"),
(VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"),
(NVARCHAR(length=128), "NVARCHAR2(128)"),
)
for original, expected in test_cases:
actual = OracleEngineSpec.column_datatype_to_string(
original, oracle.dialect()
)
self.assertEqual(actual, expected)
def test_fetch_data_no_description(self):
cursor = mock.MagicMock()
cursor.description = []
assert OracleEngineSpec.fetch_data(cursor) == []
def test_fetch_data(self):
cursor = mock.MagicMock()
result = ["a", "b"]
cursor.fetchall.return_value = result
assert OracleEngineSpec.fetch_data(cursor) == result
@pytest.mark.parametrize(
"date_format,expected",
[
("DATE", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
("DATETIME", """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')"""),
(
"TIMESTAMP",
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
),
(
"timestamp",
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
),
("Other", None),
],
)
def test_convert_dttm(date_format, expected):
dttm = TestOracleDbEngineSpec.get_dttm()
assert OracleEngineSpec.convert_dttm(date_format, dttm) == expected

View File

@@ -24,11 +24,7 @@ from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
@@ -100,29 +96,6 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
def test_convert_dttm(self):
"""
DB Eng Specs (postgres): Test conversion to date time
"""
dttm = self.get_dttm()
self.assertEqual(
PostgresEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
PostgresEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
)
self.assertEqual(
PostgresEngineSpec.convert_dttm("DATETIME", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
)
self.assertEqual(PostgresEngineSpec.convert_dttm("TIME", dttm), None)
def test_empty_dbapi_cursor_description(self):
"""
DB Eng Specs (postgres): Test empty cursor description (no columns)
@@ -541,28 +514,3 @@ def test_base_parameters_mixin():
},
"required": ["database", "host", "port", "username"],
}
def test_generic_type():
type_expectations = (
# Numeric
("SMALLINT", GenericDataType.NUMERIC),
("INTEGER", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("NUMERIC", GenericDataType.NUMERIC),
("REAL", GenericDataType.NUMERIC),
("DOUBLE PRECISION", GenericDataType.NUMERIC),
("MONEY", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
# Boolean
("BOOLEAN", GenericDataType.BOOLEAN),
)
assert_generic_types(PostgresEngineSpec, type_expectations)

View File

@@ -25,7 +25,6 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.utils.core import DatasourceName, GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
@@ -624,42 +623,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
def test_get_sqla_column_type(self):
column_spec = PrestoEngineSpec.get_column_spec("varchar(255)")
assert isinstance(column_spec.sqla_type, types.VARCHAR)
assert column_spec.sqla_type.length == 255
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("varchar")
assert isinstance(column_spec.sqla_type, types.String)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("char(10)")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length == 10
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("char")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("integer")
assert isinstance(column_spec.sqla_type, types.Integer)
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)
column_spec = PrestoEngineSpec.get_column_spec("time")
assert isinstance(column_spec.sqla_type, types.Time)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
column_spec = PrestoEngineSpec.get_column_spec("timestamp")
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names(

View File

@@ -1,214 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock, patch
import pandas as pd
import pytest
from sqlalchemy import types
import superset.config
from superset.constants import USER_AGENT
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestTrinoDbEngineSpec(TestDbEngineSpec):
def test_get_extra_params(self):
database = Mock()
database.extra = json.dumps({})
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
expected = {"engine_params": {"connect_args": {"source": USER_AGENT}}}
self.assertEqual(extra, expected)
expected = {
"first": 1,
"engine_params": {
"second": "two",
"connect_args": {"source": "foobar", "third": "three"},
},
}
database.extra = json.dumps(expected)
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
self.assertEqual(extra, expected)
@patch("superset.utils.core.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock):
database = Mock()
database.extra = json.dumps({})
database.server_cert = "TEST_CERT"
create_ssl_cert_file_func.return_value = "/path/to/tls.crt"
extra = TrinoEngineSpec.get_extra_params(database)
connect_args = extra.get("engine_params", {}).get("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt")
create_ssl_cert_file_func.assert_called_once_with(database.server_cert)
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(self, auth: Mock):
database = Mock()
auth_params = {"username": "username", "password": "password"}
database.encrypted_extra = json.dumps(
{"auth_method": "basic", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(self, auth: Mock):
database = Mock()
auth_params = {
"service_name": "superset",
"mutual_authentication": False,
"delegate": True,
}
database.encrypted_extra = json.dumps(
{"auth_method": "kerberos", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(self, auth: Mock):
database = Mock()
auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
database.encrypted_extra = json.dumps(
{"auth_method": "certificate", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(self, auth: Mock):
database = Mock()
auth_params = {"token": "jwt-token-string"}
database.encrypted_extra = json.dumps(
{"auth_method": "jwt", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
def test_auth_custom_auth(self):
database = Mock()
auth_class = Mock()
auth_method = "custom_auth"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
with patch.dict(
"superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
{"trino": {"custom_auth": auth_class}},
clear=True,
):
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth_class.assert_called_once_with(**auth_params)
def test_auth_custom_auth_denied(self):
database = Mock()
auth_method = "my.module:TrinoAuthClass"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {}
with pytest.raises(ValueError) as excinfo:
TrinoEngineSpec.update_params_from_encrypted_extra(database, {})
assert str(excinfo.value) == (
f"For security reason, custom authentication '{auth_method}' "
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
)
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP(3)", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP WITH TIME ZONE", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP(3) WITH TIME ZONE", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("DATE", dttm),
"DATE '2019-01-02'",
)
def test_extra_table_metadata(self):
db = mock.Mock()
db.get_indexes = mock.Mock(
return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
)
db.get_extra = mock.Mock(return_value={})
db.has_view_by_name = mock.Mock(return_value=None)
db.get_df = mock.Mock(
return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
)
result = TrinoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
assert result["partitions"]["cols"] == ["ds", "hour"]
assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}

View File

@@ -39,7 +39,6 @@ from superset.utils.core import (
AdhocMetricExpressionType,
FilterOperator,
GenericDataType,
TemporalType,
)
from superset.utils.database import get_example_database
from tests.integration_tests.fixtures.birth_names_dashboard import (
@@ -805,7 +804,7 @@ def test__normalize_prequery_result_type(
def _convert_dttm(
target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
if target_type.upper() == TemporalType.TIMESTAMP:
if target_type.upper() == "TIMESTAMP":
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
return None