mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
chore(db_engine_specs): clean up column spec logic and add tests (#22871)
This commit is contained in:
@@ -22,7 +22,6 @@ from tests.integration_tests.test_app import app
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
|
||||
class TestDbEngineSpec(SupersetTestCase):
|
||||
@@ -37,16 +36,3 @@ class TestDbEngineSpec(SupersetTestCase):
|
||||
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
|
||||
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
|
||||
self.assertEqual(expected_sql, limited)
|
||||
|
||||
|
||||
def assert_generic_types(
|
||||
spec: Type[BaseEngineSpec],
|
||||
type_expectations: Tuple[Tuple[str, GenericDataType], ...],
|
||||
) -> None:
|
||||
for type_str, expected_type in type_expectations:
|
||||
column_spec = spec.get_column_spec(type_str)
|
||||
assert column_spec is not None
|
||||
actual_type = column_spec.generic_type
|
||||
assert (
|
||||
actual_type == expected_type
|
||||
), f"{type_str} should be {expected_type.name} but is {actual_type.name}"
|
||||
|
||||
@@ -48,23 +48,6 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
||||
actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_convert_dttm(self):
|
||||
"""
|
||||
DB Eng Specs (bigquery): Test conversion to date time
|
||||
"""
|
||||
dttm = self.get_dttm()
|
||||
test_cases = {
|
||||
"DATE": "CAST('2019-01-02' AS DATE)",
|
||||
"DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
|
||||
"TIMESTAMP": "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
|
||||
"TIME": "CAST('03:04:05.678900' AS TIME)",
|
||||
"UNKNOWNTYPE": None,
|
||||
}
|
||||
|
||||
for target_type, expected in test_cases.items():
|
||||
actual = BigQueryEngineSpec.convert_dttm(target_type, dttm)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_timegrain_expressions(self):
|
||||
"""
|
||||
DB Eng Specs (bigquery): Test time grain expressions
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.db_engine_specs.crate import CrateEngineSpec
|
||||
from superset.models.core import Database
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestCrateDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
"""
|
||||
DB Eng Specs (crate): Test conversion to date time
|
||||
"""
|
||||
dttm = self.get_dttm()
|
||||
assert CrateEngineSpec.convert_dttm("TIMESTAMP", dttm) == str(
|
||||
dttm.timestamp() * 1000
|
||||
)
|
||||
|
||||
def test_epoch_to_dttm(self):
|
||||
"""
|
||||
DB Eng Specs (crate): Test epoch to dttm
|
||||
"""
|
||||
assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000"
|
||||
|
||||
def test_epoch_ms_to_dttm(self):
|
||||
"""
|
||||
DB Eng Specs (crate): Test epoch ms to dttm
|
||||
"""
|
||||
assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}"
|
||||
|
||||
def test_alter_new_orm_column(self):
|
||||
"""
|
||||
DB Eng Specs (crate): Test alter orm column
|
||||
"""
|
||||
database = Database(database_name="crate", sqlalchemy_uri="crate://db")
|
||||
tbl = SqlaTable(table_name="druid_tbl", database=database)
|
||||
col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
|
||||
CrateEngineSpec.alter_new_orm_column(col)
|
||||
assert col.python_date_format == "epoch_ms"
|
||||
@@ -14,18 +14,11 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from textwrap import dedent
|
||||
from unittest import mock
|
||||
|
||||
from sqlalchemy import column, literal_column
|
||||
|
||||
from superset.constants import USER_AGENT
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import (
|
||||
assert_generic_types,
|
||||
TestDbEngineSpec,
|
||||
)
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
from tests.integration_tests.fixtures.certificates import ssl_certificate
|
||||
from tests.integration_tests.fixtures.database import default_db_extra
|
||||
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.db_engine_specs.dremio import DremioEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestDremioDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
DremioEngineSpec.convert_dttm("DATE", dttm),
|
||||
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
DremioEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"TO_TIMESTAMP('2019-01-02 03:04:05.678', 'YYYY-MM-DD HH24:MI:SS.FFF')",
|
||||
)
|
||||
@@ -1,33 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.db_engine_specs.drill import DrillEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestDrillDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
DrillEngineSpec.convert_dttm("DATE", dttm),
|
||||
"TO_DATE('2019-01-02', 'yyyy-MM-dd')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
DrillEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
|
||||
)
|
||||
@@ -1,78 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
from sqlalchemy import column
|
||||
|
||||
from superset.db_engine_specs.druid import DruidEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
from tests.integration_tests.fixtures.certificates import ssl_certificate
|
||||
from tests.integration_tests.fixtures.database import default_db_extra
|
||||
|
||||
|
||||
class TestDruidDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
DruidEngineSpec.convert_dttm("DATETIME", dttm),
|
||||
"TIME_PARSE('2019-01-02T03:04:05')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
DruidEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"TIME_PARSE('2019-01-02T03:04:05')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
DruidEngineSpec.convert_dttm("DATE", dttm),
|
||||
"CAST(TIME_PARSE('2019-01-02') AS DATE)",
|
||||
)
|
||||
|
||||
def test_timegrain_expressions(self):
|
||||
"""
|
||||
DB Eng Specs (druid): Test time grain expressions
|
||||
"""
|
||||
col = "__time"
|
||||
sqla_col = column(col)
|
||||
test_cases = {
|
||||
"PT1S": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT1S')",
|
||||
"PT5M": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT5M')",
|
||||
"P1W/1970-01-03T00:00:00Z": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', 5)",
|
||||
"1969-12-28T00:00:00Z/P1W": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', -1)",
|
||||
}
|
||||
for grain, expected in test_cases.items():
|
||||
actual = DruidEngineSpec.get_timestamp_expr(
|
||||
col=sqla_col, pdf=None, time_grain=grain
|
||||
)
|
||||
self.assertEqual(str(actual), expected)
|
||||
|
||||
def test_extras_without_ssl(self):
|
||||
db = mock.Mock()
|
||||
db.extra = default_db_extra
|
||||
db.server_cert = None
|
||||
extras = DruidEngineSpec.get_extra_params(db)
|
||||
assert "connect_args" not in extras["engine_params"]
|
||||
|
||||
def test_extras_with_ssl(self):
|
||||
db = mock.Mock()
|
||||
db.extra = default_db_extra
|
||||
db.server_cert = ssl_certificate
|
||||
extras = DruidEngineSpec.get_extra_params(db)
|
||||
connect_args = extras["engine_params"]["connect_args"]
|
||||
assert connect_args["scheme"] == "https"
|
||||
assert "ssl_verify_cert" in connect_args
|
||||
@@ -1,104 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import column
|
||||
|
||||
from superset.db_engine_specs.elasticsearch import (
|
||||
ElasticSearchEngineSpec,
|
||||
OpenDistroEngineSpec,
|
||||
)
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestElasticSearchDbEngineSpec(TestDbEngineSpec):
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None),
|
||||
"CAST('2019-01-02T03:04:05' AS DATETIME)",
|
||||
)
|
||||
|
||||
def test_convert_dttm2(self):
|
||||
"""
|
||||
ES 7.8 and above versions need to use the DATETIME_PARSE function to
|
||||
solve the time zone problem
|
||||
"""
|
||||
dttm = self.get_dttm()
|
||||
db_extra = {"version": "7.8"}
|
||||
|
||||
self.assertEqual(
|
||||
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
|
||||
"DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
|
||||
)
|
||||
|
||||
def test_convert_dttm3(self):
|
||||
dttm = self.get_dttm()
|
||||
db_extra = {"version": 7.8}
|
||||
|
||||
self.assertEqual(
|
||||
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
|
||||
"CAST('2019-01-02T03:04:05' AS DATETIME)",
|
||||
)
|
||||
|
||||
self.assertNotEqual(
|
||||
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
|
||||
"DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
|
||||
)
|
||||
|
||||
self.assertIn("Unexpected error while convert es_version", self._caplog.text)
|
||||
|
||||
def test_opendistro_convert_dttm(self):
|
||||
"""
|
||||
DB Eng Specs (opendistro): Test convert_dttm
|
||||
"""
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
OpenDistroEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None),
|
||||
"'2019-01-02T03:04:05'",
|
||||
)
|
||||
|
||||
def test_opendistro_sqla_column_label(self):
|
||||
"""
|
||||
DB Eng Specs (opendistro): Test column label
|
||||
"""
|
||||
test_cases = {
|
||||
"Col": "Col",
|
||||
"Col.keyword": "Col_keyword",
|
||||
}
|
||||
for original, expected in test_cases.items():
|
||||
actual = OpenDistroEngineSpec.make_label_compatible(column(original).name)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_opendistro_strip_comments(self):
|
||||
"""
|
||||
DB Eng Specs (opendistro): Test execute sql strip comments
|
||||
"""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.execute.return_value = []
|
||||
|
||||
OpenDistroEngineSpec.execute(
|
||||
mock_cursor, "-- some comment \nSELECT 1\n --other comment"
|
||||
)
|
||||
mock_cursor.execute.assert_called_once_with("SELECT 1\n")
|
||||
@@ -1,81 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.db_engine_specs.firebird import FirebirdEngineSpec
|
||||
|
||||
grain_expressions = {
|
||||
None: "timestamp_column",
|
||||
"PT1S": (
|
||||
"CAST(CAST(timestamp_column AS DATE) "
|
||||
"|| ' ' "
|
||||
"|| EXTRACT(HOUR FROM timestamp_column) "
|
||||
"|| ':' "
|
||||
"|| EXTRACT(MINUTE FROM timestamp_column) "
|
||||
"|| ':' "
|
||||
"|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)"
|
||||
),
|
||||
"PT1M": (
|
||||
"CAST(CAST(timestamp_column AS DATE) "
|
||||
"|| ' ' "
|
||||
"|| EXTRACT(HOUR FROM timestamp_column) "
|
||||
"|| ':' "
|
||||
"|| EXTRACT(MINUTE FROM timestamp_column) "
|
||||
"|| ':00' AS TIMESTAMP)"
|
||||
),
|
||||
"P1D": "CAST(timestamp_column AS DATE)",
|
||||
"P1M": (
|
||||
"CAST(EXTRACT(YEAR FROM timestamp_column) "
|
||||
"|| '-' "
|
||||
"|| EXTRACT(MONTH FROM timestamp_column) "
|
||||
"|| '-01' AS DATE)"
|
||||
),
|
||||
"P1Y": "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("grain,expected", grain_expressions.items())
|
||||
def test_time_grain_expressions(grain, expected):
|
||||
assert (
|
||||
FirebirdEngineSpec._time_grain_expressions[grain].format(col="timestamp_column")
|
||||
== expected
|
||||
)
|
||||
|
||||
|
||||
def test_epoch_to_dttm():
|
||||
assert (
|
||||
FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column")
|
||||
== "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))"
|
||||
)
|
||||
|
||||
|
||||
def test_convert_dttm():
|
||||
dttm = datetime(2021, 1, 1)
|
||||
assert (
|
||||
FirebirdEngineSpec.convert_dttm("timestamp", dttm)
|
||||
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
|
||||
)
|
||||
assert (
|
||||
FirebirdEngineSpec.convert_dttm("TIMESTAMP", dttm)
|
||||
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
|
||||
)
|
||||
assert FirebirdEngineSpec.convert_dttm("TIME", dttm) == "CAST('00:00:00' AS TIME)"
|
||||
assert FirebirdEngineSpec.convert_dttm("DATE", dttm) == "CAST('2021-01-01' AS DATE)"
|
||||
assert FirebirdEngineSpec.convert_dttm("STRING", dttm) is None
|
||||
@@ -1,39 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.db_engine_specs.firebolt import FireboltEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestFireboltDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
test_cases = {
|
||||
"DATE": "CAST('2019-01-02' AS DATE)",
|
||||
"DATETIME": "CAST('2019-01-02T03:04:05' AS DATETIME)",
|
||||
"TIMESTAMP": "CAST('2019-01-02T03:04:05' AS TIMESTAMP)",
|
||||
"UNKNOWNTYPE": None,
|
||||
}
|
||||
|
||||
for target_type, expected in test_cases.items():
|
||||
actual = FireboltEngineSpec.convert_dttm(target_type, dttm)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_epoch_to_dttm(self):
|
||||
assert (
|
||||
FireboltEngineSpec.epoch_to_dttm().format(col="timestamp_column")
|
||||
== "from_unixtime(timestamp_column)"
|
||||
)
|
||||
@@ -1,33 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.db_engine_specs.hana import HanaEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestHanaDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
HanaEngineSpec.convert_dttm("DATE", dttm),
|
||||
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
HanaEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD\"T\"HH24:MI:SS.ff6')",
|
||||
)
|
||||
@@ -150,15 +150,6 @@ def test_hive_error_msg():
|
||||
)
|
||||
|
||||
|
||||
def test_convert_dttm():
|
||||
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
|
||||
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
|
||||
assert (
|
||||
HiveEngineSpec.convert_dttm("TIMESTAMP", dttm)
|
||||
== "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"
|
||||
)
|
||||
|
||||
|
||||
def test_df_to_csv() -> None:
|
||||
with pytest.raises(SupersetException):
|
||||
HiveEngineSpec.df_to_sql(
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.db_engine_specs.impala import ImpalaEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestImpalaDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
ImpalaEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
ImpalaEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
|
||||
)
|
||||
@@ -1,32 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.db_engine_specs.kylin import KylinEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestKylinDbEngineSpec(TestDbEngineSpec):
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
KylinEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
KylinEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"CAST('2019-01-02 03:04:05' AS TIMESTAMP)",
|
||||
)
|
||||
@@ -21,12 +21,7 @@ from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
|
||||
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.integration_tests.db_engine_specs.base_tests import (
|
||||
assert_generic_types,
|
||||
TestDbEngineSpec,
|
||||
)
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
|
||||
@@ -38,19 +33,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
|
||||
self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
|
||||
self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
|
||||
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
MySQLEngineSpec.convert_dttm("DATE", dttm),
|
||||
"STR_TO_DATE('2019-01-02', '%Y-%m-%d')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
MySQLEngineSpec.convert_dttm("DATETIME", dttm),
|
||||
"STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')",
|
||||
)
|
||||
|
||||
def test_column_datatype_to_string(self):
|
||||
test_cases = (
|
||||
(DATE(), "DATE"),
|
||||
@@ -69,32 +51,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
|
||||
)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_generic_type(self):
|
||||
type_expectations = (
|
||||
# Numeric
|
||||
("TINYINT", GenericDataType.NUMERIC),
|
||||
("SMALLINT", GenericDataType.NUMERIC),
|
||||
("MEDIUMINT", GenericDataType.NUMERIC),
|
||||
("INT", GenericDataType.NUMERIC),
|
||||
("BIGINT", GenericDataType.NUMERIC),
|
||||
("DECIMAL", GenericDataType.NUMERIC),
|
||||
("FLOAT", GenericDataType.NUMERIC),
|
||||
("DOUBLE", GenericDataType.NUMERIC),
|
||||
("BIT", GenericDataType.NUMERIC),
|
||||
# String
|
||||
("CHAR", GenericDataType.STRING),
|
||||
("VARCHAR", GenericDataType.STRING),
|
||||
("TINYTEXT", GenericDataType.STRING),
|
||||
("MEDIUMTEXT", GenericDataType.STRING),
|
||||
("LONGTEXT", GenericDataType.STRING),
|
||||
# Temporal
|
||||
("DATE", GenericDataType.TEMPORAL),
|
||||
("DATETIME", GenericDataType.TEMPORAL),
|
||||
("TIMESTAMP", GenericDataType.TEMPORAL),
|
||||
("TIME", GenericDataType.TEMPORAL),
|
||||
)
|
||||
assert_generic_types(MySQLEngineSpec, type_expectations)
|
||||
|
||||
def test_extract_error_message(self):
|
||||
from MySQLdb._exceptions import OperationalError
|
||||
|
||||
@@ -239,22 +195,3 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
|
||||
def test_get_cancel_query_id(self, engine_mock):
|
||||
query = Query()
|
||||
cursor_mock = engine_mock.return_value.__enter__.return_value
|
||||
cursor_mock.fetchone.return_value = [123]
|
||||
assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == 123
|
||||
|
||||
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
|
||||
def test_cancel_query(self, engine_mock):
|
||||
query = Query()
|
||||
cursor_mock = engine_mock.return_value.__enter__.return_value
|
||||
assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is True
|
||||
|
||||
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
|
||||
def test_cancel_query_failed(self, engine_mock):
|
||||
query = Query()
|
||||
cursor_mock = engine_mock.raiseError.side_effect = Exception()
|
||||
assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is False
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy.dialects import oracle
|
||||
from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR
|
||||
|
||||
from superset.db_engine_specs.oracle import OracleEngineSpec
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestOracleDbEngineSpec(TestDbEngineSpec):
|
||||
def test_oracle_sqla_column_name_length_exceeded(self):
|
||||
col = column("This_Is_32_Character_Column_Name")
|
||||
label = OracleEngineSpec.make_label_compatible(col.name)
|
||||
self.assertEqual(label.quote, True)
|
||||
label_expected = "3b26974078683be078219674eeb8f5"
|
||||
self.assertEqual(label, label_expected)
|
||||
|
||||
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
|
||||
col = column("decimal")
|
||||
expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
|
||||
result = str(expr.compile(dialect=oracle.dialect()))
|
||||
self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
|
||||
dttm = self.get_dttm()
|
||||
|
||||
def test_column_datatype_to_string(self):
|
||||
test_cases = (
|
||||
(DATE(), "DATE"),
|
||||
(VARCHAR(length=255), "VARCHAR(255 CHAR)"),
|
||||
(VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"),
|
||||
(NVARCHAR(length=128), "NVARCHAR2(128)"),
|
||||
)
|
||||
|
||||
for original, expected in test_cases:
|
||||
actual = OracleEngineSpec.column_datatype_to_string(
|
||||
original, oracle.dialect()
|
||||
)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_fetch_data_no_description(self):
|
||||
cursor = mock.MagicMock()
|
||||
cursor.description = []
|
||||
assert OracleEngineSpec.fetch_data(cursor) == []
|
||||
|
||||
def test_fetch_data(self):
|
||||
cursor = mock.MagicMock()
|
||||
result = ["a", "b"]
|
||||
cursor.fetchall.return_value = result
|
||||
assert OracleEngineSpec.fetch_data(cursor) == result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"date_format,expected",
|
||||
[
|
||||
("DATE", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
|
||||
("DATETIME", """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')"""),
|
||||
(
|
||||
"TIMESTAMP",
|
||||
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
|
||||
),
|
||||
(
|
||||
"timestamp",
|
||||
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
|
||||
),
|
||||
("Other", None),
|
||||
],
|
||||
)
|
||||
def test_convert_dttm(date_format, expected):
|
||||
dttm = TestOracleDbEngineSpec.get_dttm()
|
||||
assert OracleEngineSpec.convert_dttm(date_format, dttm) == expected
|
||||
@@ -24,11 +24,7 @@ from superset.db_engine_specs import load_engine_specs
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.integration_tests.db_engine_specs.base_tests import (
|
||||
assert_generic_types,
|
||||
TestDbEngineSpec,
|
||||
)
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
from tests.integration_tests.fixtures.certificates import ssl_certificate
|
||||
from tests.integration_tests.fixtures.database import default_db_extra
|
||||
|
||||
@@ -100,29 +96,6 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
||||
result = str(expr.compile(None, dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
|
||||
|
||||
def test_convert_dttm(self):
|
||||
"""
|
||||
DB Eng Specs (postgres): Test conversion to date time
|
||||
"""
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
PostgresEngineSpec.convert_dttm("DATE", dttm),
|
||||
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
PostgresEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
PostgresEngineSpec.convert_dttm("DATETIME", dttm),
|
||||
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
|
||||
)
|
||||
|
||||
self.assertEqual(PostgresEngineSpec.convert_dttm("TIME", dttm), None)
|
||||
|
||||
def test_empty_dbapi_cursor_description(self):
|
||||
"""
|
||||
DB Eng Specs (postgres): Test empty cursor description (no columns)
|
||||
@@ -541,28 +514,3 @@ def test_base_parameters_mixin():
|
||||
},
|
||||
"required": ["database", "host", "port", "username"],
|
||||
}
|
||||
|
||||
|
||||
def test_generic_type():
|
||||
type_expectations = (
|
||||
# Numeric
|
||||
("SMALLINT", GenericDataType.NUMERIC),
|
||||
("INTEGER", GenericDataType.NUMERIC),
|
||||
("BIGINT", GenericDataType.NUMERIC),
|
||||
("DECIMAL", GenericDataType.NUMERIC),
|
||||
("NUMERIC", GenericDataType.NUMERIC),
|
||||
("REAL", GenericDataType.NUMERIC),
|
||||
("DOUBLE PRECISION", GenericDataType.NUMERIC),
|
||||
("MONEY", GenericDataType.NUMERIC),
|
||||
# String
|
||||
("CHAR", GenericDataType.STRING),
|
||||
("VARCHAR", GenericDataType.STRING),
|
||||
("TEXT", GenericDataType.STRING),
|
||||
# Temporal
|
||||
("DATE", GenericDataType.TEMPORAL),
|
||||
("TIMESTAMP", GenericDataType.TEMPORAL),
|
||||
("TIME", GenericDataType.TEMPORAL),
|
||||
# Boolean
|
||||
("BOOLEAN", GenericDataType.BOOLEAN),
|
||||
)
|
||||
assert_generic_types(PostgresEngineSpec, type_expectations)
|
||||
|
||||
@@ -25,7 +25,6 @@ from sqlalchemy.sql import select
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.utils.core import DatasourceName, GenericDataType
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
@@ -624,42 +623,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
self.assertEqual(actual_data, expected_data)
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
|
||||
def test_get_sqla_column_type(self):
|
||||
column_spec = PrestoEngineSpec.get_column_spec("varchar(255)")
|
||||
assert isinstance(column_spec.sqla_type, types.VARCHAR)
|
||||
assert column_spec.sqla_type.length == 255
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("varchar")
|
||||
assert isinstance(column_spec.sqla_type, types.String)
|
||||
assert column_spec.sqla_type.length is None
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("char(10)")
|
||||
assert isinstance(column_spec.sqla_type, types.CHAR)
|
||||
assert column_spec.sqla_type.length == 10
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("char")
|
||||
assert isinstance(column_spec.sqla_type, types.CHAR)
|
||||
assert column_spec.sqla_type.length is None
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("integer")
|
||||
assert isinstance(column_spec.sqla_type, types.Integer)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("time")
|
||||
assert isinstance(column_spec.sqla_type, types.Time)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
|
||||
|
||||
column_spec = PrestoEngineSpec.get_column_spec("timestamp")
|
||||
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
|
||||
assert sqla_type is None
|
||||
|
||||
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
|
||||
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
|
||||
def test_get_table_names(
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from sqlalchemy import types
|
||||
|
||||
import superset.config
|
||||
from superset.constants import USER_AGENT
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestTrinoDbEngineSpec(TestDbEngineSpec):
|
||||
def test_get_extra_params(self):
|
||||
database = Mock()
|
||||
|
||||
database.extra = json.dumps({})
|
||||
database.server_cert = None
|
||||
extra = TrinoEngineSpec.get_extra_params(database)
|
||||
expected = {"engine_params": {"connect_args": {"source": USER_AGENT}}}
|
||||
self.assertEqual(extra, expected)
|
||||
|
||||
expected = {
|
||||
"first": 1,
|
||||
"engine_params": {
|
||||
"second": "two",
|
||||
"connect_args": {"source": "foobar", "third": "three"},
|
||||
},
|
||||
}
|
||||
database.extra = json.dumps(expected)
|
||||
database.server_cert = None
|
||||
extra = TrinoEngineSpec.get_extra_params(database)
|
||||
self.assertEqual(extra, expected)
|
||||
|
||||
@patch("superset.utils.core.create_ssl_cert_file")
|
||||
def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock):
|
||||
database = Mock()
|
||||
|
||||
database.extra = json.dumps({})
|
||||
database.server_cert = "TEST_CERT"
|
||||
create_ssl_cert_file_func.return_value = "/path/to/tls.crt"
|
||||
extra = TrinoEngineSpec.get_extra_params(database)
|
||||
|
||||
connect_args = extra.get("engine_params", {}).get("connect_args", {})
|
||||
self.assertEqual(connect_args.get("http_scheme"), "https")
|
||||
self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt")
|
||||
create_ssl_cert_file_func.assert_called_once_with(database.server_cert)
|
||||
|
||||
@patch("trino.auth.BasicAuthentication")
|
||||
def test_auth_basic(self, auth: Mock):
|
||||
database = Mock()
|
||||
|
||||
auth_params = {"username": "username", "password": "password"}
|
||||
database.encrypted_extra = json.dumps(
|
||||
{"auth_method": "basic", "auth_params": auth_params}
|
||||
)
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
|
||||
connect_args = params.setdefault("connect_args", {})
|
||||
self.assertEqual(connect_args.get("http_scheme"), "https")
|
||||
auth.assert_called_once_with(**auth_params)
|
||||
|
||||
@patch("trino.auth.KerberosAuthentication")
|
||||
def test_auth_kerberos(self, auth: Mock):
|
||||
database = Mock()
|
||||
|
||||
auth_params = {
|
||||
"service_name": "superset",
|
||||
"mutual_authentication": False,
|
||||
"delegate": True,
|
||||
}
|
||||
database.encrypted_extra = json.dumps(
|
||||
{"auth_method": "kerberos", "auth_params": auth_params}
|
||||
)
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
|
||||
connect_args = params.setdefault("connect_args", {})
|
||||
self.assertEqual(connect_args.get("http_scheme"), "https")
|
||||
auth.assert_called_once_with(**auth_params)
|
||||
|
||||
@patch("trino.auth.CertificateAuthentication")
|
||||
def test_auth_certificate(self, auth: Mock):
|
||||
database = Mock()
|
||||
|
||||
auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
|
||||
database.encrypted_extra = json.dumps(
|
||||
{"auth_method": "certificate", "auth_params": auth_params}
|
||||
)
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
|
||||
connect_args = params.setdefault("connect_args", {})
|
||||
self.assertEqual(connect_args.get("http_scheme"), "https")
|
||||
auth.assert_called_once_with(**auth_params)
|
||||
|
||||
@patch("trino.auth.JWTAuthentication")
|
||||
def test_auth_jwt(self, auth: Mock):
|
||||
database = Mock()
|
||||
|
||||
auth_params = {"token": "jwt-token-string"}
|
||||
database.encrypted_extra = json.dumps(
|
||||
{"auth_method": "jwt", "auth_params": auth_params}
|
||||
)
|
||||
|
||||
params: Dict[str, Any] = {}
|
||||
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
|
||||
connect_args = params.setdefault("connect_args", {})
|
||||
self.assertEqual(connect_args.get("http_scheme"), "https")
|
||||
auth.assert_called_once_with(**auth_params)
|
||||
|
||||
def test_auth_custom_auth(self):
|
||||
database = Mock()
|
||||
auth_class = Mock()
|
||||
|
||||
auth_method = "custom_auth"
|
||||
auth_params = {"params1": "params1", "params2": "params2"}
|
||||
database.encrypted_extra = json.dumps(
|
||||
{"auth_method": auth_method, "auth_params": auth_params}
|
||||
)
|
||||
|
||||
with patch.dict(
|
||||
"superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
|
||||
{"trino": {"custom_auth": auth_class}},
|
||||
clear=True,
|
||||
):
|
||||
params: Dict[str, Any] = {}
|
||||
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
|
||||
|
||||
connect_args = params.setdefault("connect_args", {})
|
||||
self.assertEqual(connect_args.get("http_scheme"), "https")
|
||||
|
||||
auth_class.assert_called_once_with(**auth_params)
|
||||
|
||||
def test_auth_custom_auth_denied(self):
|
||||
database = Mock()
|
||||
auth_method = "my.module:TrinoAuthClass"
|
||||
auth_params = {"params1": "params1", "params2": "params2"}
|
||||
database.encrypted_extra = json.dumps(
|
||||
{"auth_method": auth_method, "auth_params": auth_params}
|
||||
)
|
||||
|
||||
superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {}
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
TrinoEngineSpec.update_params_from_encrypted_extra(database, {})
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
f"For security reason, custom authentication '{auth_method}' "
|
||||
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
|
||||
)
|
||||
|
||||
def test_convert_dttm(self):
|
||||
dttm = self.get_dttm()
|
||||
|
||||
self.assertEqual(
|
||||
TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm),
|
||||
"TIMESTAMP '2019-01-02 03:04:05.678900'",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
TrinoEngineSpec.convert_dttm("TIMESTAMP(3)", dttm),
|
||||
"TIMESTAMP '2019-01-02 03:04:05.678900'",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
TrinoEngineSpec.convert_dttm("TIMESTAMP WITH TIME ZONE", dttm),
|
||||
"TIMESTAMP '2019-01-02 03:04:05.678900'",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
TrinoEngineSpec.convert_dttm("TIMESTAMP(3) WITH TIME ZONE", dttm),
|
||||
"TIMESTAMP '2019-01-02 03:04:05.678900'",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
TrinoEngineSpec.convert_dttm("DATE", dttm),
|
||||
"DATE '2019-01-02'",
|
||||
)
|
||||
|
||||
def test_extra_table_metadata(self):
|
||||
db = mock.Mock()
|
||||
db.get_indexes = mock.Mock(
|
||||
return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
|
||||
)
|
||||
db.get_extra = mock.Mock(return_value={})
|
||||
db.has_view_by_name = mock.Mock(return_value=None)
|
||||
db.get_df = mock.Mock(
|
||||
return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
|
||||
)
|
||||
result = TrinoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
|
||||
assert result["partitions"]["cols"] == ["ds", "hour"]
|
||||
assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
|
||||
@@ -39,7 +39,6 @@ from superset.utils.core import (
|
||||
AdhocMetricExpressionType,
|
||||
FilterOperator,
|
||||
GenericDataType,
|
||||
TemporalType,
|
||||
)
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
@@ -805,7 +804,7 @@ def test__normalize_prequery_result_type(
|
||||
def _convert_dttm(
|
||||
target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
if target_type.upper() == TemporalType.TIMESTAMP:
|
||||
if target_type.upper() == "TIMESTAMP":
|
||||
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
|
||||
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user