chore(db_engine_specs): clean up column spec logic and add tests (#22871)

This commit is contained in:
Ville Brofeldt
2023-01-31 15:54:07 +02:00
committed by GitHub
parent 8466eec228
commit cd6fc35f60
73 changed files with 1953 additions and 1463 deletions

View File

@@ -17,8 +17,12 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import re
from datetime import datetime
from typing import Optional
import pytest
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
SYNTAX_ERROR_REGEX = re.compile(
@@ -26,19 +30,20 @@ SYNTAX_ERROR_REGEX = re.compile(
)
def test_convert_dttm(dttm: datetime) -> None:
"""
Test that date objects are converted correctly.
"""
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "DATE '2019-01-02'"),
("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678'"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.athena import AthenaEngineSpec as spec
from superset.db_engine_specs.athena import AthenaEngineSpec
assert AthenaEngineSpec.convert_dttm("DATE", dttm) == "DATE '2019-01-02'"
assert (
AthenaEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "TIMESTAMP '2019-01-02 03:04:05.678'"
)
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_extract_errors() -> None:

View File

@@ -17,9 +17,13 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from textwrap import dedent
from typing import Any, Dict, Optional, Type
import pytest
from sqlalchemy.types import TypeEngine
from sqlalchemy import types
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
def test_get_text_clause_with_colon() -> None:
@@ -94,8 +98,43 @@ select 'USD' as cur
),
],
)
def test_cte_query_parsing(original: TypeEngine, expected: str) -> None:
def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None:
from superset.db_engine_specs.base import BaseEngineSpec
actual = BaseEngineSpec.get_cte_query(original)
assert actual == expected
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
("DECIMAL", types.Numeric, None, GenericDataType.NUMERIC, False),
("NUMERIC", types.Numeric, None, GenericDataType.NUMERIC, False),
("REAL", types.REAL, None, GenericDataType.NUMERIC, False),
("DOUBLE PRECISION", types.Float, None, GenericDataType.NUMERIC, False),
("MONEY", types.Numeric, None, GenericDataType.NUMERIC, False),
# String
("CHAR", types.String, None, GenericDataType.STRING, False),
("VARCHAR", types.String, None, GenericDataType.STRING, False),
("TEXT", types.String, None, GenericDataType.STRING, False),
# Temporal
("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
("TIME", types.Time, None, GenericDataType.TEMPORAL, True),
# Boolean
("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
],
)
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)

View File

@@ -18,12 +18,18 @@
# pylint: disable=line-too-long, import-outside-toplevel, protected-access, invalid-name
import json
from datetime import datetime
from typing import Optional
import pytest
from pytest_mock import MockFixture
from sqlalchemy import select
from sqlalchemy.sql import sqltypes
from sqlalchemy_bigquery import BigQueryDialect
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_get_fields() -> None:
"""
@@ -285,3 +291,24 @@ def test_parse_error_raises_exception() -> None:
== expected_result
)
assert str(BigQueryEngineSpec.parse_error_exception(Exception(message_2))) == "6"
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
("DateTime", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"),
("TimeStamp", "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"),
("Time", "CAST('03:04:05.678900' AS TIME)"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
"""
DB Eng Specs (bigquery): Test conversion to date time
"""
from superset.db_engine_specs.bigquery import BigQueryEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -14,22 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from unittest import mock
from typing import Optional
from unittest.mock import Mock
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_convert_dttm(dttm: datetime) -> None:
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "toDate('2019-01-02')"),
("DateTime", "toDateTime('2019-01-02 03:04:05')"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec as spec
assert ClickHouseEngineSpec.convert_dttm("DATE", dttm) == "toDate('2019-01-02')"
assert (
ClickHouseEngineSpec.convert_dttm("DATETIME", dttm)
== "toDateTime('2019-01-02 03:04:05')"
)
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_execute_connection_error() -> None:
@@ -38,7 +47,7 @@ def test_execute_connection_error() -> None:
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
cursor = mock.Mock()
cursor = Mock()
cursor.execute.side_effect = NewConnectionError(
"Dummypool", "Exception with sensitive data"
)

View File

@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_epoch_to_dttm() -> None:
"""
DB Eng Specs (crate): Test epoch to dttm
"""
from superset.db_engine_specs.crate import CrateEngineSpec
assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000"
def test_epoch_ms_to_dttm() -> None:
"""
DB Eng Specs (crate): Test epoch ms to dttm
"""
from superset.db_engine_specs.crate import CrateEngineSpec
assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}"
def test_alter_new_orm_column() -> None:
"""
DB Eng Specs (crate): Test alter orm column
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.crate import CrateEngineSpec
from superset.models.core import Database
database = Database(database_name="crate", sqlalchemy_uri="crate://db")
tbl = SqlaTable(table_name="tbl", database=database)
col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
CrateEngineSpec.alter_new_orm_column(col)
assert col.python_date_format == "epoch_ms"
@pytest.mark.parametrize(
"target_type,expected_result",
[
("TimeStamp", "1546398245678.9"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.crate import CrateEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -17,13 +17,16 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import json
from datetime import datetime
from typing import Optional
import pytest
from pytest_mock import MockerFixture
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_get_parameters_from_uri() -> None:
@@ -109,37 +112,6 @@ def test_parameters_json_schema() -> None:
}
def test_generic_type() -> None:
"""
assert that generic types match
"""
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types
type_expectations = (
# Numeric
("SMALLINT", GenericDataType.NUMERIC),
("INTEGER", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("NUMERIC", GenericDataType.NUMERIC),
("REAL", GenericDataType.NUMERIC),
("DOUBLE PRECISION", GenericDataType.NUMERIC),
("MONEY", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
# Boolean
("BOOLEAN", GenericDataType.BOOLEAN),
)
assert_generic_types(DatabricksNativeEngineSpec, type_expectations)
def test_get_extra_params(mocker: MockerFixture) -> None:
"""
Test the ``get_extra_params`` method.
@@ -253,3 +225,22 @@ def test_extract_errors_with_context() -> None:
},
)
]
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
(
"TimeStamp",
"CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
(
"TimeStamp",
"TO_TIMESTAMP('2019-01-02 03:04:05.678', 'YYYY-MM-DD HH24:MI:SS.FFF')",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.dremio import DremioEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -16,7 +16,13 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from pytest import raises
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_odbc_impersonation() -> None:
@@ -82,5 +88,21 @@ def test_invalid_impersonation() -> None:
url = URL("drill+foobar")
username = "DoAsUser"
with raises(SupersetDBAPIProgrammingError):
with pytest.raises(SupersetDBAPIProgrammingError):
DrillEngineSpec.get_url_for_impersonation(url, True, username)
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'yyyy-MM-dd')"),
("TimeStamp", "TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.drill import DrillEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
from unittest import mock
import pytest
from sqlalchemy import column
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST(TIME_PARSE('2019-01-02') AS DATE)"),
("DateTime", "TIME_PARSE('2019-01-02T03:04:05')"),
("TimeStamp", "TIME_PARSE('2019-01-02T03:04:05')"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.druid import DruidEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"time_grain,expected_result",
[
("PT1S", "TIME_FLOOR(CAST(col AS TIMESTAMP), 'PT1S')"),
("PT5M", "TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT5M')"),
(
"P1W/1970-01-03T00:00:00Z",
"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST(col AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', 5)",
),
(
"1969-12-28T00:00:00Z/P1W",
"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST(col AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', -1)",
),
],
)
def test_timegrain_expressions(time_grain: str, expected_result: str) -> None:
"""
DB Eng Specs (druid): Test time grain expressions
"""
from superset.db_engine_specs.druid import DruidEngineSpec
assert str(
DruidEngineSpec.get_timestamp_expr(
col=column("col"), pdf=None, time_grain=time_grain
)
)
def test_extras_without_ssl() -> None:
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.fixtures.database import default_db_extra
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DruidEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl() -> None:
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args

View File

@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Text", "'2019-01-02 03:04:05.678900'"),
("DateTime", "'2019-01-02 03:04:05.678900'"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.duckdb import DuckDBEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -14,24 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_convert_dttm(dttm: datetime) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec
@pytest.mark.parametrize(
"target_type,expected_result",
[
("text", "'2019-01-02 03:04:05'"),
("dateTime", "'2019-01-02 03:04:05'"),
("unknowntype", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec as spec
assert DynamoDBEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_lower(dttm: datetime) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec
assert DynamoDBEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_invalid_type(dttm: datetime) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec
assert DynamoDBEngineSpec.convert_dttm("other", dttm) is None
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional
from unittest.mock import MagicMock
import pytest
from sqlalchemy import column
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,db_extra,expected_result",
[
("DateTime", None, "CAST('2019-01-02T03:04:05' AS DATETIME)"),
(
"DateTime",
{"version": "7.7"},
"CAST('2019-01-02T03:04:05' AS DATETIME)",
),
(
"DateTime",
{"version": "7.8"},
"DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
),
(
"DateTime",
{"version": "unparseable semver version"},
"CAST('2019-01-02T03:04:05' AS DATETIME)",
),
("Unknown", None, None),
],
)
def test_elasticsearch_convert_dttm(
target_type: str,
db_extra: Optional[Dict[str, Any]],
expected_result: Optional[str],
dttm: datetime,
) -> None:
from superset.db_engine_specs.elasticsearch import ElasticSearchEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm, db_extra)
@pytest.mark.parametrize(
"target_type,expected_result",
[
("DateTime", "'2019-01-02T03:04:05'"),
("Unknown", None),
],
)
def test_opendistro_convert_dttm(
target_type: str,
expected_result: Optional[str],
dttm: datetime,
) -> None:
from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"original,expected",
[
("Col", "Col"),
("Col.keyword", "Col_keyword"),
],
)
def test_opendistro_sqla_column_label(original: str, expected: str) -> None:
"""
DB Eng Specs (opendistro): Test column label
"""
from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec
assert OpenDistroEngineSpec.make_label_compatible(original) == expected
def test_opendistro_strip_comments() -> None:
"""
DB Eng Specs (opendistro): Test execute sql strip comments
"""
from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec
mock_cursor = MagicMock()
mock_cursor.execute.return_value = []
OpenDistroEngineSpec.execute(
mock_cursor, "-- some comment \nSELECT 1\n --other comment"
)
mock_cursor.execute.assert_called_once_with("SELECT 1\n")

View File

@@ -0,0 +1,102 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"time_grain,expected",
[
(None, "timestamp_column"),
(
"PT1S",
(
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':' "
"|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)"
),
),
(
"PT1M",
(
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':00' AS TIMESTAMP)"
),
),
("P1D", "CAST(timestamp_column AS DATE)"),
(
"P1M",
(
"CAST(EXTRACT(YEAR FROM timestamp_column) "
"|| '-' "
"|| EXTRACT(MONTH FROM timestamp_column) "
"|| '-01' AS DATE)"
),
),
("P1Y", "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)"),
],
)
def test_time_grain_expressions(time_grain: Optional[str], expected: str) -> None:
from superset.db_engine_specs.firebird import FirebirdEngineSpec
assert (
FirebirdEngineSpec._time_grain_expressions[time_grain].format(
col="timestamp_column",
)
== expected
)
def test_epoch_to_dttm() -> None:
from superset.db_engine_specs.firebird import FirebirdEngineSpec
assert (
FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column")
== "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))"
)
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
("DateTime", "CAST('2019-01-02 03:04:05.6789' AS TIMESTAMP)"),
("TimeStamp", "CAST('2019-01-02 03:04:05.6789' AS TIMESTAMP)"),
("Time", "CAST('03:04:05.678900' AS TIME)"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.firebird import FirebirdEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
(
"DateTime",
"CAST('2019-01-02T03:04:05' AS DATETIME)",
),
(
"TimeStamp",
"CAST('2019-01-02T03:04:05' AS TIMESTAMP)",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.firebolt import FireboltEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_epoch_to_dttm() -> None:
from superset.db_engine_specs.firebolt import FireboltEngineSpec
assert (
FireboltEngineSpec.epoch_to_dttm().format(col="timestamp_column")
== "from_unixtime(timestamp_column)"
)

View File

@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
(
"TimeStamp",
"TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD\"T\"HH24:MI:SS.ff6')",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.hana import HanaEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
(
"TimeStamp",
"CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.hive import HiveEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
("TimeStamp", "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.impala import ImpalaEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -16,9 +16,11 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@@ -108,45 +110,35 @@ def test_kql_parse_sql() -> None:
@pytest.mark.parametrize(
"target_type,expected_dttm",
"target_type,expected_result",
[
("DATETIME", "datetime(2019-01-02T03:04:05.678900)"),
("TIMESTAMP", "datetime(2019-01-02T03:04:05.678900)"),
("DATE", "datetime(2019-01-02)"),
("DateTime", "datetime(2019-01-02T03:04:05.678900)"),
("TimeStamp", "datetime(2019-01-02T03:04:05.678900)"),
("Date", "datetime(2019-01-02)"),
("UnknownType", None),
],
)
def test_kql_convert_dttm(
target_type: str,
expected_dttm: str,
dttm: datetime,
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
"""
Test that date objects are converted correctly.
"""
from superset.db_engine_specs.kusto import KustoKqlEngineSpec as spec
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
assert expected_dttm == KustoKqlEngineSpec.convert_dttm(target_type, dttm)
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"target_type,expected_dttm",
"target_type,expected_result",
[
("DATETIME", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)"),
("DATE", "CONVERT(DATE, '2019-01-02', 23)"),
("SMALLDATETIME", "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)"),
("TIMESTAMP", "CONVERT(TIMESTAMP, '2019-01-02 03:04:05', 20)"),
("Date", "CONVERT(DATE, '2019-01-02', 23)"),
("DateTime", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)"),
("SmallDateTime", "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)"),
("TimeStamp", "CONVERT(TIMESTAMP, '2019-01-02 03:04:05', 20)"),
("UnknownType", None),
],
)
def test_sql_convert_dttm(
target_type: str,
expected_dttm: str,
dttm: datetime,
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
"""
Test that date objects are converted correctly.
"""
from superset.db_engine_specs.kusto import KustoSqlEngineSpec as spec
from superset.db_engine_specs.kusto import KustoSqlEngineSpec
assert expected_dttm == KustoSqlEngineSpec.convert_dttm(target_type, dttm)
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
("TimeStamp", "CAST('2019-01-02 03:04:05' AS TIMESTAMP)"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.kylin import KylinEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -17,9 +17,10 @@
import unittest.mock as mock
from datetime import datetime
from textwrap import dedent
from typing import Any, Dict, Optional, Type
import pytest
from sqlalchemy import column, table
from sqlalchemy import column, table, types
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select
@@ -27,36 +28,36 @@ from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"type_string,type_expected,generic_type_expected",
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("STRING", String, GenericDataType.STRING),
("CHAR(10)", String, GenericDataType.STRING),
("VARCHAR(10)", String, GenericDataType.STRING),
("TEXT", String, GenericDataType.STRING),
("NCHAR(10)", UnicodeText, GenericDataType.STRING),
("NVARCHAR(10)", UnicodeText, GenericDataType.STRING),
("NTEXT", UnicodeText, GenericDataType.STRING),
("CHAR", String, None, GenericDataType.STRING, False),
("CHAR(10)", String, None, GenericDataType.STRING, False),
("VARCHAR", String, None, GenericDataType.STRING, False),
("VARCHAR(10)", String, None, GenericDataType.STRING, False),
("TEXT", String, None, GenericDataType.STRING, False),
("NCHAR(10)", UnicodeText, None, GenericDataType.STRING, False),
("NVARCHAR(10)", UnicodeText, None, GenericDataType.STRING, False),
("NTEXT", UnicodeText, None, GenericDataType.STRING, False),
],
)
def test_mssql_column_types(
type_string: str,
type_expected: TypeEngine,
generic_type_expected: GenericDataType,
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec as spec
if type_expected is None:
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
assert type_assigned is None
else:
column_spec = MssqlEngineSpec.get_column_spec(type_string)
if column_spec is not None:
assert isinstance(column_spec.sqla_type, type_expected)
assert column_spec.generic_type == generic_type_expected
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
def test_where_clause_n_prefix() -> None:
@@ -65,13 +66,13 @@ def test_where_clause_n_prefix() -> None:
dialect = mssql.dialect()
# non-unicode col
sqla_column_type = MssqlEngineSpec.get_sqla_column_type("VARCHAR(10)")
sqla_column_type = MssqlEngineSpec.get_column_types("VARCHAR(10)")
assert sqla_column_type is not None
type_, _ = sqla_column_type
str_col = column("col", type_=type_)
# unicode col
sqla_column_type = MssqlEngineSpec.get_sqla_column_type("NTEXT")
sqla_column_type = MssqlEngineSpec.get_column_types("NTEXT")
assert sqla_column_type is not None
type_, _ = sqla_column_type
unicode_col = column("unicode_col", type_=type_)
@@ -103,30 +104,31 @@ def test_time_exp_mixd_case_col_1y() -> None:
@pytest.mark.parametrize(
"actual,expected",
"target_type,expected_result",
[
(
"DATE",
"date",
"CONVERT(DATE, '2019-01-02', 23)",
),
(
"DATETIME",
"datetime",
"CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)",
),
(
"SMALLDATETIME",
"smalldatetime",
"CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",
),
("Other", None),
],
)
def test_convert_dttm(
actual: str,
expected: str,
target_type: str,
expected_result: Optional[str],
dttm: datetime,
) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec as spec
assert MssqlEngineSpec.convert_dttm(actual, dttm) == expected
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_extract_error_message() -> None:

View File

@@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional, Type
from unittest.mock import Mock, patch
import pytest
from sqlalchemy import types
from sqlalchemy.dialects.mysql import (
BIT,
DECIMAL,
DOUBLE,
FLOAT,
INTEGER,
LONGTEXT,
MEDIUMINT,
MEDIUMTEXT,
TINYINT,
TINYTEXT,
)
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
# Numeric
("TINYINT", TINYINT, None, GenericDataType.NUMERIC, False),
("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
("MEDIUMINT", MEDIUMINT, None, GenericDataType.NUMERIC, False),
("INT", INTEGER, None, GenericDataType.NUMERIC, False),
("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
("DECIMAL", DECIMAL, None, GenericDataType.NUMERIC, False),
("FLOAT", FLOAT, None, GenericDataType.NUMERIC, False),
("DOUBLE", DOUBLE, None, GenericDataType.NUMERIC, False),
("BIT", BIT, None, GenericDataType.NUMERIC, False),
# String
("CHAR", types.String, None, GenericDataType.STRING, False),
("VARCHAR", types.String, None, GenericDataType.STRING, False),
("TINYTEXT", TINYTEXT, None, GenericDataType.STRING, False),
("MEDIUMTEXT", MEDIUMTEXT, None, GenericDataType.STRING, False),
("LONGTEXT", LONGTEXT, None, GenericDataType.STRING, False),
# Temporal
("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
("DATETIME", types.DateTime, None, GenericDataType.TEMPORAL, True),
("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
("TIME", types.Time, None, GenericDataType.TEMPORAL, True),
],
)
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec as spec
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "STR_TO_DATE('2019-01-02', '%Y-%m-%d')"),
(
"DateTime",
"STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.models.sql_lab import Query
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.fetchone.return_value = ["123"]
assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == "123"
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.models.sql_lab import Query
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
assert MySQLEngineSpec.cancel_query(cursor_mock, query, "123") is True
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.models.sql_lab import Query
query = Query()
cursor_mock = engine_mock.raiseError.side_effect = Exception()
assert MySQLEngineSpec.cancel_query(cursor_mock, query, "123") is False

View File

@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional, Union
from unittest import mock
import pytest
from sqlalchemy import column, types
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR
from sqlalchemy.sql import quoted_name
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"column_name,expected_result",
[
("This_Is_32_Character_Column_Name", "3b26974078683be078219674eeb8f5"),
("snake_label", "snake_label"),
("camelLabel", "camelLabel"),
],
)
def test_oracle_sqla_column_name_length_exceeded(
column_name: str, expected_result: Union[str, quoted_name]
) -> None:
from superset.db_engine_specs.oracle import OracleEngineSpec
label = OracleEngineSpec.make_label_compatible(column_name)
assert isinstance(label, quoted_name)
assert label.quote is True
assert label == expected_result
def test_oracle_time_expression_reserved_keyword_1m_grain() -> None:
from superset.db_engine_specs.oracle import OracleEngineSpec
col = column("decimal")
expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
result = str(expr.compile(dialect=oracle.dialect()))
assert result == "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')"
@pytest.mark.parametrize(
"sqla_type,expected_result",
[
(DATE(), "DATE"),
(VARCHAR(length=255), "VARCHAR(255 CHAR)"),
(VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"),
(NVARCHAR(length=128), "NVARCHAR2(128)"),
],
)
def test_column_datatype_to_string(
sqla_type: types.TypeEngine, expected_result: str
) -> None:
from superset.db_engine_specs.oracle import OracleEngineSpec
assert (
OracleEngineSpec.column_datatype_to_string(sqla_type, oracle.dialect())
== expected_result
)
def test_fetch_data_no_description() -> None:
from superset.db_engine_specs.oracle import OracleEngineSpec
cursor = mock.MagicMock()
cursor.description = []
assert OracleEngineSpec.fetch_data(cursor) == []
def test_fetch_data() -> None:
from superset.db_engine_specs.oracle import OracleEngineSpec
cursor = mock.MagicMock()
result = ["a", "b"]
cursor.fetchall.return_value = result
assert OracleEngineSpec.fetch_data(cursor) == result
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
("DateTime", """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')"""),
(
"TimeStamp",
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
),
("Other", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.oracle import OracleEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional, Type
import pytest
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
(
"DateTime",
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
),
(
"TimeStamp",
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
("DECIMAL", types.Numeric, None, GenericDataType.NUMERIC, False),
("NUMERIC", types.Numeric, None, GenericDataType.NUMERIC, False),
("REAL", types.REAL, None, GenericDataType.NUMERIC, False),
("DOUBLE PRECISION", DOUBLE_PRECISION, None, GenericDataType.NUMERIC, False),
("MONEY", types.Numeric, None, GenericDataType.NUMERIC, False),
# String
("CHAR", types.String, None, GenericDataType.STRING, False),
("VARCHAR", types.String, None, GenericDataType.STRING, False),
("TEXT", types.String, None, GenericDataType.STRING, False),
("ARRAY", types.String, None, GenericDataType.STRING, False),
("ENUM", ENUM, None, GenericDataType.STRING, False),
("JSON", JSON, None, GenericDataType.STRING, False),
# Temporal
("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
("TIME", types.Time, None, GenericDataType.TEMPORAL, True),
# Boolean
("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
],
)
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec as spec
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)

View File

@@ -15,14 +15,21 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
from typing import Any, Dict, Optional, Type
import pytest
import pytz
from sqlalchemy import types
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
@pytest.mark.parametrize(
"target_type,dttm,result",
"target_type,dttm,expected_result",
[
("VARCHAR", datetime(2022, 1, 1), None),
("DATE", datetime(2022, 1, 1), "DATE '2022-01-01'"),
@@ -46,9 +53,32 @@ import pytz
def test_convert_dttm(
target_type: str,
dttm: datetime,
result: Optional[str],
expected_result: Optional[str],
) -> None:
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec as spec
for case in (str.lower, str.upper):
assert PrestoEngineSpec.convert_dttm(case(target_type), dttm) == result
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("varchar(255)", types.VARCHAR, {"length": 255}, GenericDataType.STRING, False),
("varchar", types.String, None, GenericDataType.STRING, False),
("char(255)", types.CHAR, {"length": 255}, GenericDataType.STRING, False),
("char", types.String, None, GenericDataType.STRING, False),
("integer", types.Integer, None, GenericDataType.NUMERIC, False),
("time", types.Time, None, GenericDataType.TEMPORAL, True),
("timestamp", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
],
)
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.presto import PrestoEngineSpec as spec
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)

View File

@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "DATE '2019-01-02'"),
("DateTime", "DATETIME '2019-01-02 03:04:05.678900'"),
("Timestamp", "TIMESTAMP '2019-01-02T03:04:05.678900'"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.rockset import RocksetEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@@ -19,28 +19,43 @@
import json
from datetime import datetime
from typing import Optional
from unittest import mock
import pytest
from pytest_mock import MockerFixture
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"actual,expected",
"target_type,expected_result",
[
("DATE", "TO_DATE('2019-01-02')"),
("DATETIME", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"),
("TIMESTAMP", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("Date", "TO_DATE('2019-01-02')"),
("DateTime", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"),
("TimeStamp", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMP_NTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMP_LTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMP_TZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMPLTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMPNTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMPTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
(
"TIMESTAMP WITH LOCAL TIME ZONE",
"TO_TIMESTAMP('2019-01-02T03:04:05.678900')",
),
("TIMESTAMP WITHOUT TIME ZONE", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("UnknownType", None),
],
)
def test_convert_dttm(actual: str, expected: str, dttm: datetime) -> None:
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec as spec
assert SnowflakeEngineSpec.convert_dttm(actual, dttm) == expected
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_database_connection_test_mutator() -> None:

View File

@@ -16,30 +16,32 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, redefined-outer-name
from datetime import datetime
from unittest import mock
from typing import Optional
import pytest
from sqlalchemy.engine import create_engine
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_convert_dttm(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Text", "'2019-01-02 03:04:05'"),
("DateTime", "'2019-01-02 03:04:05'"),
("TimeStamp", "'2019-01-02 03:04:05'"),
("Other", None),
],
)
def test_convert_dttm(
target_type: str,
expected_result: Optional[str],
dttm: datetime,
) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec as spec
assert SqliteEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_lower(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
assert SqliteEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_invalid_type(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
assert SqliteEngineSpec.convert_dttm("other", dttm) is None
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(

View File

@@ -16,17 +16,288 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import json
from typing import Any, Dict
from unittest import mock
from datetime import datetime
from typing import Any, Dict, Optional, Type
from unittest.mock import Mock, patch
import pandas as pd
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import types
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
@pytest.mark.parametrize(
"extra,expected",
[
({}, {"engine_params": {"connect_args": {"source": USER_AGENT}}}),
(
{
"first": 1,
"engine_params": {
"second": "two",
"connect_args": {"source": "foobar", "third": "three"},
},
},
{
"first": 1,
"engine_params": {
"second": "two",
"connect_args": {"source": "foobar", "third": "three"},
},
},
),
],
)
def test_get_extra_params(extra: Dict[str, Any], expected: Dict[str, Any]) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
database.extra = json.dumps(extra)
database.server_cert = None
assert TrinoEngineSpec.get_extra_params(database) == expected
@patch("superset.utils.core.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
database.extra = json.dumps({})
database.server_cert = "TEST_CERT"
mock_create_ssl_cert_file.return_value = "/path/to/tls.crt"
extra = TrinoEngineSpec.get_extra_params(database)
connect_args = extra.get("engine_params", {}).get("connect_args", {})
assert connect_args.get("http_scheme") == "https"
assert connect_args.get("verify") == "/path/to/tls.crt"
mock_create_ssl_cert_file.assert_called_once_with(database.server_cert)
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {"username": "username", "password": "password"}
database.encrypted_extra = json.dumps(
{"auth_method": "basic", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {
"service_name": "superset",
"mutual_authentication": False,
"delegate": True,
}
database.encrypted_extra = json.dumps(
{"auth_method": "kerberos", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
database.encrypted_extra = json.dumps(
{"auth_method": "certificate", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(mock_auth: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_params = {"token": "jwt-token-string"}
database.encrypted_extra = json.dumps(
{"auth_method": "jwt", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
mock_auth.assert_called_once_with(**auth_params)
def test_auth_custom_auth() -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_class = Mock()
auth_method = "custom_auth"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
with patch.dict(
"superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
{"trino": {"custom_auth": auth_class}},
clear=True,
):
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
auth_class.assert_called_once_with(**auth_params)
def test_auth_custom_auth_denied() -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
auth_method = "my.module:TrinoAuthClass"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {}
with pytest.raises(ValueError) as excinfo:
TrinoEngineSpec.update_params_from_encrypted_extra(database, {})
assert str(excinfo.value) == (
f"For security reason, custom authentication '{auth_method}' "
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
)
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False),
("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False),
("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False),
("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False),
("VARCHAR", types.String, None, GenericDataType.STRING, False),
("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False),
("CHAR", types.String, None, GenericDataType.STRING, False),
("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False),
("JSON", types.JSON, None, GenericDataType.STRING, False),
("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
(
"TIMESTAMP WITH TIME ZONE",
types.TIMESTAMP,
None,
GenericDataType.TEMPORAL,
True,
),
(
"TIMESTAMP(3) WITH TIME ZONE",
types.TIMESTAMP,
None,
GenericDataType.TEMPORAL,
True,
),
("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
],
)
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec as spec
assert_column_spec(
spec,
native_type,
sqla_type,
attrs,
generic_type,
is_dttm,
)
@pytest.mark.parametrize(
"target_type,expected_result",
[
("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
("Date", "DATE '2019-01-02'"),
("Other", None),
],
)
def test_convert_dttm(
target_type: str,
expected_result: Optional[str],
dttm: datetime,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm)
def test_extra_table_metadata() -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
db_mock = Mock()
db_mock.get_indexes = Mock(
return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
)
db_mock.get_extra = Mock(return_value={})
db_mock.has_view_by_name = Mock(return_value=None)
db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}))
result = TrinoEngineSpec.extra_table_metadata(db_mock, "test_table", "test_schema")
assert result["partitions"]["cols"] == ["ds", "hour"]
assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
@@ -35,8 +306,8 @@ def test_cancel_query_success(engine_mock: mock.Mock) -> None:
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
@@ -67,11 +338,11 @@ def test_prepare_cancel_query(
@pytest.mark.parametrize("cancel_early", [True, False])
@mock.patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@mock.patch("sqlalchemy.engine.Engine.connect")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("sqlalchemy.engine.Engine.connect")
def test_handle_cursor_early_cancel(
engine_mock: mock.Mock,
cancel_query_mock: mock.Mock,
engine_mock: Mock,
cancel_query_mock: Mock,
cancel_early: bool,
mocker: MockerFixture,
) -> None:

View File

@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
from sqlalchemy import types
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec
def assert_convert_dttm(
db_engine_spec: Type[BaseEngineSpec],
target_type: str,
expected_result: Optional[str],
dttm: datetime,
db_extra: Optional[Dict[str, Any]] = None,
) -> None:
for target in (
target_type,
target_type.upper(),
target_type.lower(),
target_type.capitalize(),
):
assert (
result := db_engine_spec.convert_dttm(
target_type=target,
dttm=dttm,
db_extra=db_extra,
)
) == expected_result, result
def assert_column_spec(
db_engine_spec: Type[BaseEngineSpec],
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
assert (column_spec := db_engine_spec.get_column_spec(native_type)) is not None
assert isinstance(column_spec.sqla_type, sqla_type)
for key, value in (attrs or {}).items():
assert getattr(column_spec.sqla_type, key) == value
assert column_spec.generic_type == generic_type
assert column_spec.is_dttm == is_dttm