refactor(tests): decouple unittests from integration tests (#15473)

* refactor: move all tests to be under the integration_tests package

* refactor: decouple unit tests from integration tests - commands

* add unit_tests package

* fix celery_tests.py

* fix wrong FIXTURES_DIR value
Author: ofekisr
Date: 2021-07-01 18:03:07 +03:00
Committed by: GitHub
Parent: 55d0371b92
Commit: b5119b8dff
172 changed files with 456 additions and 262 deletions
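
To orient the reader, here is a minimal before/after sketch of the import-path change this refactor implies; the old import path is an assumption inferred from the new tests.integration_tests imports that appear throughout the diff below.

# Before the refactor (assumed layout): shared test helpers were imported
# straight from the top-level tests package.
# from tests.db_engine_specs.base_tests import TestDbEngineSpec

# After the refactor: integration tests live under tests/integration_tests,
# and a separate tests/unit_tests package is introduced for pure unit tests.
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec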

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.ascend import AscendEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestAscendDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
AscendEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
AscendEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
)

View File

@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.athena import AthenaEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestAthenaDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
AthenaEngineSpec.convert_dttm("DATE", dttm),
"from_iso8601_date('2019-01-02')",
)
self.assertEqual(
AthenaEngineSpec.convert_dttm("TIMESTAMP", dttm),
"from_iso8601_timestamp('2019-01-02T03:04:05.678900')",
)
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.
"""
msg = ": mismatched input 'fromm'. Expecting: "
result = AthenaEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Please check your query for syntax errors at or near "fromm". Then, try running your query again.',
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Amazon Athena",
"issue_codes": [
{
"code": 1030,
"message": "Issue 1030 - The query has a syntax error.",
}
],
},
)
]

View File

@@ -0,0 +1,468 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
from unittest import mock
import pytest
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
builtin_time_grains,
LimitMethod,
)
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.utils.core import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
from ..fixtures.energy_dashboard import load_energy_table_with_slice
from ..fixtures.pyodbcRow import Row
class TestDbEngineSpecs(TestDbEngineSpec):
def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
q0 = "select * from table"
q1 = "select * from mytable limit 10"
q2 = "select * from (select * from my_subquery limit 10) where col=1 limit 20"
q3 = "select * from (select * from my_subquery limit 10);"
q4 = "select * from (select * from my_subquery limit 10) where col=1 limit 20;"
q5 = "select * from mytable limit 20, 10"
q6 = "select * from mytable limit 10 offset 20"
q7 = "select * from mytable limit"
q8 = "select * from mytable limit 10.0"
q9 = "select * from mytable limit x"
q10 = "select * from mytable limit 20, x"
q11 = "select * from mytable limit x offset 20"
self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
def test_wrapped_semi_tabs(self):
self.sql_limit_regex(
"SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
)
def test_simple_limit_query(self):
self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")
def test_modify_limit_query(self):
self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")
def test_limit_query_with_limit_subquery(self): # pylint: disable=invalid-name
self.sql_limit_regex(
"SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
"SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
)
def test_limit_query_without_force(self):
self.sql_limit_regex(
"SELECT * FROM a LIMIT 10", "SELECT * FROM a LIMIT 10", limit=11,
)
def test_limit_query_with_force(self):
self.sql_limit_regex(
"SELECT * FROM a LIMIT 10",
"SELECT * FROM a LIMIT 11",
limit=11,
force=True,
)
def test_limit_with_expr(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
)
def test_limit_expr_and_semicolon(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
)
def test_get_datatype(self):
self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
def test_limit_with_implicit_offset(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 999999""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 1000""",
)
def test_limit_with_explicit_offset(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990
OFFSET 999999""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000
OFFSET 999999""",
)
def test_limit_with_non_token_limit(self):
self.sql_limit_regex(
"""SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
)
def test_limit_with_fetch_many(self):
class DummyEngineSpec(BaseEngineSpec):
limit_method = LimitMethod.FETCH_MANY
self.sql_limit_regex(
"SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec
)
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
for engine in get_engine_specs().values():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
# make sure all defined time grains are supported
defined_grains = {grain.duration for grain in engine.get_time_grains()}
intersection = time_grains.intersection(defined_grains)
self.assertSetEqual(defined_grains, intersection, engine)
def test_get_time_grain_expressions(self):
time_grains = MySQLEngineSpec.get_time_grain_expressions()
self.assertEqual(
list(time_grains.keys()),
[
None,
"PT1S",
"PT1M",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
],
)
def test_get_table_names(self):
inspector = mock.Mock()
inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
""" Make sure base engine spec removes schema name from table name
ie. when try_remove_schema_from_table_name == True. """
base_result_expected = ["table", "table_2"]
base_result = BaseEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(base_result_expected, base_result)
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):
example_db = get_example_database()
sqla_table = example_db.get_table("energy_usage")
dialect = example_db.get_dialect()
# TODO: fix column type conversion for presto.
if example_db.backend == "presto":
return
col_names = [
example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
for c in sqla_table.columns
]
if example_db.backend == "postgresql":
expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
elif example_db.backend == "hive":
expected = ["STRING", "STRING", "FLOAT"]
else:
expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
self.assertEqual(col_names, expected)
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm))
def test_pyodbc_rows_to_tuples(self):
# Test for case when pyodbc.Row is returned (odbc driver)
data = [
Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
]
expected = [
(1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, expected)
def test_pyodbc_rows_to_tuples_passthrough(self):
# Test for case when tuples are returned
data = [
(1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, data)
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return BaseEngineSpec.is_readonly_query(ParsedQuery(sql))
assert is_readonly("SHOW LOCKS test EXTENDED")
assert not is_readonly("SET hivevar:desc='Legislators'")
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
assert is_readonly("SHOW CATALOGS")
assert is_readonly("SHOW TABLES")
def test_time_grain_denylist():
config = app.config.copy()
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M"]
with app.app_context():
time_grain_functions = SqliteEngineSpec.get_time_grain_expressions()
assert not "PT1M" in time_grain_functions
app.config = config
def test_time_grain_addons():
config = app.config.copy()
app.config["TIME_GRAIN_ADDONS"] = {"PTXM": "x seconds"}
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {"sqlite": {"PTXM": "ABC({col})"}}
with app.app_context():
time_grains = SqliteEngineSpec.get_time_grains()
time_grain_addon = time_grains[-1]
assert "PTXM" == time_grain_addon.duration
assert "x seconds" == time_grain_addon.label
app.config = config
def test_get_time_grain_with_config():
""" Should concatenate from configs and then sort in the proper order """
config = app.config.copy()
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"mysql": {
"PT2H": "foo",
"PT4H": "foo",
"PT6H": "foo",
"PT8H": "foo",
"PT10H": "foo",
"PT12H": "foo",
"PT1S": "foo",
}
}
with app.app_context():
time_grains = MySQLEngineSpec.get_time_grain_expressions()
assert set(time_grains.keys()) == {
None,
"PT1S",
"PT1M",
"PT1H",
"PT2H",
"PT4H",
"PT6H",
"PT8H",
"PT10H",
"PT12H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
}
app.config = config
def test_get_time_grain_with_unkown_values():
"""Should concatenate from configs and then sort in the proper order
putting unknown patterns at the end"""
config = app.config.copy()
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
}
with app.app_context():
time_grains = MySQLEngineSpec.get_time_grain_expressions()
assert list(time_grains)[-1] == "weird"
app.config = config
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
@mock.patch("superset.db_engine_specs.base.is_port_open")
def test_validate(is_port_open, is_hostname_valid):
is_hostname_valid.return_value = True
is_port_open.return_value = True
parameters = {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
errors = BasicParametersMixin.validate_parameters(parameters)
assert errors == []
def test_validate_parameters_missing():
parameters = {
"host": "",
"port": None,
"username": "",
"password": "",
"database": "",
"query": {},
}
errors = BasicParametersMixin.validate_parameters(parameters)
assert errors == [
SupersetError(
message=(
"One or more parameters are missing: " "database, host, port, username"
),
error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
level=ErrorLevel.WARNING,
extra={"missing": ["database", "host", "port", "username"]},
),
]
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
def test_validate_parameters_invalid_host(is_hostname_valid):
is_hostname_valid.return_value = False
parameters = {
"host": "localhost",
"port": None,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
errors = BasicParametersMixin.validate_parameters(parameters)
assert errors == [
SupersetError(
message="One or more parameters are missing: port",
error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
level=ErrorLevel.WARNING,
extra={"missing": ["port"]},
),
SupersetError(
message="The hostname provided can't be resolved.",
error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
level=ErrorLevel.ERROR,
extra={"invalid": ["host"]},
),
]
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
@mock.patch("superset.db_engine_specs.base.is_port_open")
def test_validate_parameters_port_closed(is_port_open, is_hostname_valid):
is_hostname_valid.return_value = True
is_port_open.return_value = False
parameters = {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
errors = BasicParametersMixin.validate_parameters(parameters)
assert errors == [
SupersetError(
message="The port is closed.",
error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
level=ErrorLevel.ERROR,
extra={
"invalid": ["port"],
"issue_codes": [
{"code": 1008, "message": "Issue 1008 - The port is closed."}
],
},
)
]

View File

@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from datetime import datetime
from typing import Tuple, Type
from tests.integration_tests.test_app import app
from tests.integration_tests.base_tests import SupersetTestCase
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database
from superset.utils.core import GenericDataType
class TestDbEngineSpec(SupersetTestCase):
def sql_limit_regex(
self,
sql,
expected_sql,
engine_spec_class=BaseEngineSpec,
limit=1000,
force=False,
):
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
self.assertEqual(expected_sql, limited)
def assert_generic_types(
spec: Type[BaseEngineSpec],
type_expectations: Tuple[Tuple[str, GenericDataType], ...],
) -> None:
for type_str, expected_type in type_expectations:
column_spec = spec.get_column_spec(type_str)
assert column_spec is not None
actual_type = column_spec.generic_type
assert (
actual_type == expected_type
), f"{type_str} should be {expected_type.name} but is {actual_type.name}"

View File

@@ -0,0 +1,337 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
import unittest.mock as mock
from pandas import DataFrame
from sqlalchemy import column
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestBigQueryDbEngineSpec(TestDbEngineSpec):
def test_bigquery_sqla_column_label(self):
"""
DB Eng Specs (bigquery): Test column label
"""
test_cases = {
"Col": "Col",
"SUM(x)": "SUM_x__5f110",
"SUM[x]": "SUM_x__7ebe1",
"12345_col": "_12345_col_8d390",
}
for original, expected in test_cases.items():
actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
def test_convert_dttm(self):
"""
DB Eng Specs (bigquery): Test conversion to date time
"""
dttm = self.get_dttm()
test_cases = {
"DATE": "CAST('2019-01-02' AS DATE)",
"DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
"TIMESTAMP": "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
"TIME": "CAST('03:04:05.678900' AS TIME)",
"UNKNOWNTYPE": None,
}
for target_type, expected in test_cases.items():
actual = BigQueryEngineSpec.convert_dttm(target_type, dttm)
self.assertEqual(actual, expected)
def test_timegrain_expressions(self):
"""
DB Eng Specs (bigquery): Test time grain expressions
"""
col = column("temporal")
test_cases = {
"DATE": "DATE_TRUNC(temporal, HOUR)",
"TIME": "TIME_TRUNC(temporal, HOUR)",
"DATETIME": "DATETIME_TRUNC(temporal, HOUR)",
"TIMESTAMP": "TIMESTAMP_TRUNC(temporal, HOUR)",
}
for type_, expected in test_cases.items():
actual = BigQueryEngineSpec.get_timestamp_expr(
col=col, pdf=None, time_grain="PT1H", type_=type_
)
self.assertEqual(str(actual), expected)
def test_custom_minute_timegrain_expressions(self):
"""
DB Eng Specs (bigquery): Test time grain expressions
"""
col = column("temporal")
test_cases = {
"DATE": "CAST(TIMESTAMP_SECONDS("
"5*60 * DIV(UNIX_SECONDS(CAST(temporal AS TIMESTAMP)), 5*60)"
") AS DATE)",
"DATETIME": "CAST(TIMESTAMP_SECONDS("
"5*60 * DIV(UNIX_SECONDS(CAST(temporal AS TIMESTAMP)), 5*60)"
") AS DATETIME)",
"TIMESTAMP": "CAST(TIMESTAMP_SECONDS("
"5*60 * DIV(UNIX_SECONDS(CAST(temporal AS TIMESTAMP)), 5*60)"
") AS TIMESTAMP)",
}
for type_, expected in test_cases.items():
actual = BigQueryEngineSpec.get_timestamp_expr(
col=col, pdf=None, time_grain="PT5M", type_=type_
)
assert str(actual) == expected
def test_fetch_data(self):
"""
DB Eng Specs (bigquery): Test fetch data
"""
# Mock a google.cloud.bigquery.table.Row
class Row(object):
def __init__(self, value):
self._value = value
def values(self):
return self._value
data1 = [(1, "foo")]
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data1):
result = BigQueryEngineSpec.fetch_data(None, 0)
self.assertEqual(result, data1)
data2 = [Row(1), Row(2)]
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data2):
result = BigQueryEngineSpec.fetch_data(None, 0)
self.assertEqual(result, [1, 2])
def test_extra_table_metadata(self):
"""
DB Eng Specs (bigquery): Test extra table metadata
"""
database = mock.Mock()
# Test no indexes
database.get_indexes = mock.MagicMock(return_value=None)
result = BigQueryEngineSpec.extra_table_metadata(
database, "some_table", "some_schema"
)
self.assertEqual(result, {})
index_metadata = [
{"name": "clustering", "column_names": ["c_col1", "c_col2", "c_col3"],},
{"name": "partition", "column_names": ["p_col1", "p_col2", "p_col3"],},
]
expected_result = {
"partitions": {"cols": [["p_col1", "p_col2", "p_col3"]]},
"clustering": {"cols": [["c_col1", "c_col2", "c_col3"]]},
}
database.get_indexes = mock.MagicMock(return_value=index_metadata)
result = BigQueryEngineSpec.extra_table_metadata(
database, "some_table", "some_schema"
)
self.assertEqual(result, expected_result)
def test_normalize_indexes(self):
"""
DB Eng Specs (bigquery): Test extra table metadata
"""
indexes = [{"name": "partition", "column_names": [None], "unique": False}]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(normalized_idx, [])
indexes = [{"name": "partition", "column_names": ["dttm"], "unique": False}]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(normalized_idx, indexes)
indexes = [
{"name": "partition", "column_names": ["dttm", None], "unique": False}
]
normalized_idx = BigQueryEngineSpec.normalize_indexes(indexes)
self.assertEqual(
normalized_idx,
[{"name": "partition", "column_names": ["dttm"], "unique": False}],
)
@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
def test_df_to_sql(self, mock_get_engine):
"""
DB Eng Specs (bigquery): Test DataFrame to SQL contract
"""
# test missing google.oauth2 dependency
sys.modules["pandas_gbq"] = mock.MagicMock()
df = DataFrame()
database = mock.MagicMock()
self.assertRaisesRegexp(
Exception,
"Could not import libraries",
BigQueryEngineSpec.df_to_sql,
database=database,
table=Table(table="name", schema="schema"),
df=df,
to_sql_kwargs={},
)
invalid_kwargs = [
{"name": "some_name"},
{"schema": "some_schema"},
{"con": "some_con"},
{"name": "some_name", "con": "some_con"},
{"name": "some_name", "schema": "some_schema"},
{"con": "some_con", "schema": "some_schema"},
]
# Test check for missing schema.
sys.modules["google.oauth2"] = mock.MagicMock()
for invalid_kwarg in invalid_kwargs:
self.assertRaisesRegexp(
Exception,
"The table schema must be defined",
BigQueryEngineSpec.df_to_sql,
database=database,
table=Table(table="name"),
df=df,
to_sql_kwargs=invalid_kwarg,
)
import pandas_gbq
from google.oauth2 import service_account
pandas_gbq.to_gbq = mock.Mock()
service_account.Credentials.from_service_account_info = mock.MagicMock(
return_value="account_info"
)
mock_get_engine.return_value.url.host = "google-host"
mock_get_engine.return_value.dialect.credentials_info = "secrets"
BigQueryEngineSpec.df_to_sql(
database=database,
table=Table(table="name", schema="schema"),
df=df,
to_sql_kwargs={"if_exists": "extra_key"},
)
pandas_gbq.to_gbq.assert_called_with(
df,
project_id="google-host",
destination_table="schema.name",
credentials="account_info",
if_exists="extra_key",
)
def test_extract_errors(self):
msg = "403 POST https://bigquery.googleapis.com/bigquery/v2/projects/test-keel-310804/jobs?prettyPrint=false: Access Denied: Project User does not have bigquery.jobs.create permission in project profound-keel-310804"
result = BigQueryEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message="We were unable to connect to your database. Please confirm that your service account has the Viewer and Job User roles on the project.",
error_type=SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Google BigQuery",
"issue_codes": [{"code": 1017, "message": "",}],
},
)
]
msg = "bigquery error: 404 Not found: Dataset fakeDataset:bogusSchema was not found in location"
result = BigQueryEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='The schema "bogusSchema" does not exist. A valid schema must be used to run this query.',
error_type=SupersetErrorType.SCHEMA_DOES_NOT_EXIST_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Google BigQuery",
"issue_codes": [
{
"code": 1003,
"message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.",
},
{
"code": 1004,
"message": "Issue 1004 - The column was deleted or renamed in the database.",
},
],
},
)
]
msg = 'Table name "badtable" missing dataset while no default dataset is set in the request'
result = BigQueryEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='The table "badtable" does not exist. A valid table must be used to run this query.',
error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Google BigQuery",
"issue_codes": [
{
"code": 1003,
"message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.",
},
{
"code": 1005,
"message": "Issue 1005 - The table was deleted or renamed in the database.",
},
],
},
)
]
msg = "Unrecognized name: badColumn at [1:8]"
result = BigQueryEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='We can\'t seem to resolve column "badColumn" at line 1:8.',
error_type=SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Google BigQuery",
"issue_codes": [
{
"code": 1003,
"message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.",
},
{
"code": 1004,
"message": "Issue 1004 - The column was deleted or renamed in the database.",
},
],
},
)
]
msg = 'Syntax error: Expected end of input but got identifier "fromm"'
result = BigQueryEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Please check your query for syntax errors at or near "fromm". Then, try running your query again.',
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Google BigQuery",
"issue_codes": [
{
"code": 1030,
"message": "Issue 1030 - The query has a syntax error.",
}
],
},
)
]

View File

@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestClickHouseDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
ClickHouseEngineSpec.convert_dttm("DATE", dttm), "toDate('2019-01-02')"
)
self.assertEqual(
ClickHouseEngineSpec.convert_dttm("DATETIME", dttm),
"toDateTime('2019-01-02 03:04:05')",
)
def test_execute_connection_error(self):
from urllib3.exceptions import NewConnectionError
cursor = mock.Mock()
cursor.execute.side_effect = NewConnectionError(
"Dummypool", message="Exception with sensitive data"
)
with pytest.raises(SupersetDBAPIDatabaseError) as ex:
ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1")

View File

@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.crate import CrateEngineSpec
from superset.models.core import Database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestCrateDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
"""
DB Eng Specs (crate): Test conversion to date time
"""
dttm = self.get_dttm()
assert CrateEngineSpec.convert_dttm("TIMESTAMP", dttm) == str(
dttm.timestamp() * 1000
)
def test_epoch_to_dttm(self):
"""
DB Eng Specs (crate): Test epoch to dttm
"""
assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000"
def test_epoch_ms_to_dttm(self):
"""
DB Eng Specs (crate): Test epoch ms to dttm
"""
assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}"
def test_alter_new_orm_column(self):
"""
DB Eng Specs (crate): Test alter orm column
"""
database = Database(database_name="crate", sqlalchemy_uri="crate://db")
tbl = SqlaTable(table_name="druid_tbl", database=database)
col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
CrateEngineSpec.alter_new_orm_column(col)
assert col.python_date_format == "epoch_ms"

View File

@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.dremio import DremioEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestDremioDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DremioEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
DremioEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678', 'YYYY-MM-DD HH24:MI:SS.FFF')",
)

View File

@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.drill import DrillEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestDrillDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DrillEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'yyyy-MM-dd')",
)
self.assertEqual(
DrillEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
)

View File

@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from sqlalchemy import column
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
class TestDruidDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DruidEngineSpec.convert_dttm("DATETIME", dttm),
"TIME_PARSE('2019-01-02T03:04:05')",
)
self.assertEqual(
DruidEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TIME_PARSE('2019-01-02T03:04:05')",
)
self.assertEqual(
DruidEngineSpec.convert_dttm("DATE", dttm),
"CAST(TIME_PARSE('2019-01-02') AS DATE)",
)
def test_timegrain_expressions(self):
"""
DB Eng Specs (druid): Test time grain expressions
"""
col = "__time"
sqla_col = column(col)
test_cases = {
"PT1S": f"FLOOR({col} TO SECOND)",
"PT5M": f"TIME_FLOOR({col}, 'PT5M')",
}
for grain, expected in test_cases.items():
actual = DruidEngineSpec.get_timestamp_expr(
col=sqla_col, pdf=None, time_grain=grain
)
self.assertEqual(str(actual), expected)
def test_extras_without_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DruidEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args

View File

@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock
from sqlalchemy import column
from superset.db_engine_specs.elasticsearch import (
ElasticSearchEngineSpec,
OpenDistroEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestElasticSearchDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm),
"CAST('2019-01-02T03:04:05' AS DATETIME)",
)
def test_opendistro_convert_dttm(self):
"""
DB Eng Specs (opendistro): Test convert_dttm
"""
dttm = self.get_dttm()
self.assertEqual(
OpenDistroEngineSpec.convert_dttm("DATETIME", dttm),
"'2019-01-02T03:04:05'",
)
def test_opendistro_sqla_column_label(self):
"""
DB Eng Specs (opendistro): Test column label
"""
test_cases = {
"Col": "Col",
"Col.keyword": "Col_keyword",
}
for original, expected in test_cases.items():
actual = OpenDistroEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
def test_opendistro_strip_comments(self):
"""
DB Eng Specs (opendistro): Test execute sql strip comments
"""
mock_cursor = MagicMock()
mock_cursor.execute.return_value = []
OpenDistroEngineSpec.execute(
mock_cursor, "-- some comment \nSELECT 1\n --other comment"
)
mock_cursor.execute.assert_called_once_with("SELECT 1\n")

View File

@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from unittest import mock
import pytest
from superset.db_engine_specs.firebird import FirebirdEngineSpec
grain_expressions = {
None: "timestamp_column",
"PT1S": (
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':' "
"|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)"
),
"PT1M": (
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':00' AS TIMESTAMP)"
),
"P1D": "CAST(timestamp_column AS DATE)",
"P1M": (
"CAST(EXTRACT(YEAR FROM timestamp_column) "
"|| '-' "
"|| EXTRACT(MONTH FROM timestamp_column) "
"|| '-01' AS DATE)"
),
"P1Y": "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)",
}
@pytest.mark.parametrize("grain,expected", grain_expressions.items())
def test_time_grain_expressions(grain, expected):
assert (
FirebirdEngineSpec._time_grain_expressions[grain].format(col="timestamp_column")
== expected
)
def test_epoch_to_dttm():
assert (
FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column")
== "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))"
)
def test_convert_dttm():
dttm = datetime(2021, 1, 1)
assert (
FirebirdEngineSpec.convert_dttm("timestamp", dttm)
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
)
assert (
FirebirdEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
)
assert FirebirdEngineSpec.convert_dttm("TIME", dttm) == "CAST('00:00:00' AS TIME)"
assert FirebirdEngineSpec.convert_dttm("DATE", dttm) == "CAST('2021-01-01' AS DATE)"
assert FirebirdEngineSpec.convert_dttm("STRING", dttm) is None

View File

@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestGsheetsDbEngineSpec(TestDbEngineSpec):
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.
"""
msg = 'SQLError: near "fromm": syntax error'
result = GSheetsEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Please check your query for syntax errors near "fromm". Then, try running your query again.',
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Google Sheets",
"issue_codes": [
{
"code": 1030,
"message": "Issue 1030 - The query has a syntax error.",
}
],
},
)
]

View File

@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.hana import HanaEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestHanaDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
HanaEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
HanaEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD\"T\"HH24:MI:SS.ff6')",
)

View File

@@ -0,0 +1,381 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from datetime import datetime
from unittest import mock
import pytest
import pandas as pd
from sqlalchemy.sql import select
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import Table, ParsedQuery
from tests.integration_tests.test_app import app
def test_0_progress():
log = """
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 0
def test_number_of_jobs_progress():
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 0
def test_job_1_launched_progress():
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 0
def test_job_1_launched_stage_1():
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 0
def test_job_1_launched_stage_1_map_40_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 10
def test_job_1_launched_stage_1_map_80_reduce_40_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 30
def test_job_1_launched_stage_2_stages_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%, reduce = 40%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 12
def test_job_2_launched_stage_2_stages_progress(): # pylint: disable=invalid-name
log = """
17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%, reduce = 0%
17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%, reduce = 0%
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
""".split(
"\n"
)
assert HiveEngineSpec.progress(log) == 60
def test_hive_error_msg():
msg = (
'{...} errorMessage="Error while compiling statement: FAILED: '
"SemanticException [Error 10001]: Line 4"
":5 Table not found 'fact_ridesfdslakj'\", statusCode=3, "
"sqlState='42S02', errorCode=10001)){...}"
)
assert HiveEngineSpec.extract_error_message(Exception(msg)) == (
"hive error: Error while compiling statement: FAILED: "
"SemanticException [Error 10001]: Line 4:5 "
"Table not found 'fact_ridesfdslakj'"
)
e = Exception("Some string that doesn't match the regex")
assert HiveEngineSpec.extract_error_message(e) == f"hive error: {e}"
msg = (
"errorCode=10001, "
'errorMessage="Error while compiling statement"), operationHandle'
'=None)"'
)
assert (
HiveEngineSpec.extract_error_message(Exception(msg))
== "hive error: Error while compiling statement"
)
def test_hive_get_view_names_return_empty_list(): # pylint: disable=invalid-name
assert HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) == []
def test_convert_dttm():
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
assert (
HiveEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"
)
def test_df_to_csv() -> None:
with pytest.raises(SupersetException):
HiveEngineSpec.df_to_sql(
mock.MagicMock(), Table("foobar"), pd.DataFrame(), {"if_exists": "append"},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
def test_df_to_sql_if_exists_fail(mock_g):
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
with pytest.raises(SupersetException, match="Table already exists"):
HiveEngineSpec.df_to_sql(
mock_database, Table("foobar"), pd.DataFrame(), {"if_exists": "fail"}
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
def test_df_to_sql_if_exists_fail_with_schema(mock_g):
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
with pytest.raises(SupersetException, match="Table already exists"):
HiveEngineSpec.df_to_sql(
mock_database,
Table(table="foobar", schema="schema"),
pd.DataFrame(),
{"if_exists": "fail"},
)
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: ""
mock_upload_to_s3.return_value = "mock-location"
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine.return_value.execute = mock_execute
table_name = "foobar"
with app.app_context():
HiveEngineSpec.df_to_sql(
mock_database,
Table(table=table_name),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
app.config = config
@mock.patch("superset.db_engine_specs.hive.g", spec={})
@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]: lambda *args: ""
mock_upload_to_s3.return_value = "mock-location"
mock_g.user = True
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine.return_value.execute = mock_execute
table_name = "foobar"
schema = "schema"
with app.app_context():
HiveEngineSpec.df_to_sql(
mock_database,
Table(table=table_name, schema=schema),
pd.DataFrame(),
{"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"},
)
mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
app.config = config
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return HiveEngineSpec.is_readonly_query(ParsedQuery(sql))
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
assert is_readonly("SHOW LOCKS test EXTENDED")
assert is_readonly("SET hivevar:desc='Legislators'")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
@pytest.mark.parametrize(
"schema,upload_prefix",
[("foo", "EXTERNAL_HIVE_TABLES/1/foo/"), (None, "EXTERNAL_HIVE_TABLES/1/")],
)
def test_s3_upload_prefix(schema: str, upload_prefix: str) -> None:
mock_database = mock.MagicMock()
mock_database.id = 1
assert (
app.config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"](
database=mock_database, user=mock.MagicMock(), schema=schema
)
== upload_prefix
)
def test_upload_to_s3_no_bucket_path():
with app.app_context():
with pytest.raises(
Exception,
match="No upload bucket specified. You can specify one in the config file.",
):
upload_to_s3("filename", "prefix", Table("table"))
@mock.patch("boto3.client")
def test_upload_to_s3_client_error(client):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket"
from botocore.exceptions import ClientError
client.return_value.upload_file.side_effect = ClientError(
{"Error": {}}, "operation_name"
)
with app.app_context():
with pytest.raises(ClientError):
upload_to_s3("filename", "prefix", Table("table"))
app.config = config
@mock.patch("boto3.client")
def test_upload_to_s3_success(client):
config = app.config.copy()
app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"] = "bucket"
client.return_value.upload_file.return_value = True
with app.app_context():
location = upload_to_s3("filename", "prefix", Table("table"))
assert f"s3a://bucket/prefix/table" == location
app.config = config
def test_fetch_data_query_error():
from TCLIService import ttypes
err_msg = "error message"
cursor = mock.Mock()
cursor.poll.return_value.operationState = ttypes.TOperationState.ERROR_STATE
cursor.poll.return_value.errorMessage = err_msg
with pytest.raises(Exception, match=f"('Query error', '{err_msg})'"):
HiveEngineSpec.fetch_data(cursor)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.fetch_data")
def test_fetch_data_programming_error(fetch_data_mock):
from pyhive.exc import ProgrammingError
fetch_data_mock.side_effect = ProgrammingError
cursor = mock.Mock()
assert HiveEngineSpec.fetch_data(cursor) == []
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.fetch_data")
def test_fetch_data_success(fetch_data_mock):
return_value = ["a", "b"]
fetch_data_mock.return_value = return_value
cursor = mock.Mock()
assert HiveEngineSpec.fetch_data(cursor) == return_value
@mock.patch("superset.db_engine_specs.hive.HiveEngineSpec._latest_partition_from_df")
def test_where_latest_partition(mock_method):
mock_method.return_value = ("01-01-19", 1)
db = mock.Mock()
db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
db.get_extra = mock.Mock(return_value={})
db.get_df = mock.Mock()
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", db, select(), columns
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
def test_where_latest_partition_super_method_exception(mock_method):
mock_method.side_effect = Exception()
db = mock.Mock()
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", db, select(), columns
)
assert result is None
mock_method.assert_called()
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
def test_where_latest_partition_no_columns_no_values(mock_method):
mock_method.return_value = ("01-01-19", None)
db = mock.Mock()
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", db, select()
)
assert result is None

View File

@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.impala import ImpalaEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestImpalaDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
ImpalaEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
ImpalaEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
)

View File

@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.kylin import KylinEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestKylinDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
KylinEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
KylinEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02 03:04:05' AS TIMESTAMP)",
)

View File

@@ -0,0 +1,310 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from textwrap import dedent
from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select
from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestMssqlEngineSpec(TestDbEngineSpec):
def test_mssql_column_types(self):
def assert_type(type_string, type_expected, generic_type_expected):
if type_expected is None:
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
self.assertIsNone(type_assigned)
else:
column_spec = MssqlEngineSpec.get_column_spec(type_string)
                if column_spec is not None:
                    self.assertIsInstance(column_spec.sqla_type, type_expected)
                    self.assertEqual(column_spec.generic_type, generic_type_expected)
assert_type("STRING", String, GenericDataType.STRING)
assert_type("CHAR(10)", String, GenericDataType.STRING)
assert_type("VARCHAR(10)", String, GenericDataType.STRING)
assert_type("TEXT", String, GenericDataType.STRING)
assert_type("NCHAR(10)", UnicodeText, GenericDataType.STRING)
assert_type("NVARCHAR(10)", UnicodeText, GenericDataType.STRING)
assert_type("NTEXT", UnicodeText, GenericDataType.STRING)
def test_where_clause_n_prefix(self):
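        """Only Unicode column literals should get the N'...' prefix in MSSQL queries."""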
dialect = mssql.dialect()
spec = MssqlEngineSpec
type_, _ = spec.get_sqla_column_type("VARCHAR(10)")
str_col = column("col", type_=type_)
type_, _ = spec.get_sqla_column_type("NTEXT")
unicode_col = column("unicode_col", type_=type_)
tbl = table("tbl")
sel = (
select([str_col, unicode_col])
.select_from(tbl)
.where(str_col == "abc")
.where(unicode_col == "abc")
)
query = str(
sel.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
query_expected = (
"SELECT col, unicode_col \n"
"FROM tbl \n"
"WHERE col = 'abc' AND unicode_col = N'abc'"
)
self.assertEqual(query, query_expected)
def test_time_exp_mixd_case_col_1y(self):
col = column("MixedCase")
expr = MssqlEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=mssql.dialect()))
self.assertEqual(result, "DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)")
def test_convert_dttm(self):
dttm = self.get_dttm()
test_cases = (
(
MssqlEngineSpec.convert_dttm("DATE", dttm),
"CONVERT(DATE, '2019-01-02', 23)",
),
(
MssqlEngineSpec.convert_dttm("DATETIME", dttm),
"CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)",
),
(
MssqlEngineSpec.convert_dttm("SMALLDATETIME", dttm),
"CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",
),
)
for actual, expected in test_cases:
self.assertEqual(actual, expected)
def test_extract_error_message(self):
test_mssql_exception = Exception(
"(8155, b\"No column name was specified for column 1 of 'inner_qry'."
"DB-Lib error message 20018, severity 16:\\nGeneral SQL Server error: "
'Check messages from the SQL Server\\n")'
)
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
expected_message = (
"mssql error: All your SQL functions need to "
"have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
)
self.assertEqual(expected_message, error_message)
test_mssql_exception = Exception(
'(8200, b"A correlated expression is invalid because it is not in a '
"GROUP BY clause.\\n\")'"
)
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
expected_message = "mssql error: " + MssqlEngineSpec._extract_error_message(
test_mssql_exception
)
self.assertEqual(expected_message, error_message)
@mock.patch.object(
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
)
def test_fetch_data(self, mock_pyodbc_rows_to_tuples):
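        """fetch_data should pass the raw rows through pyodbc_rows_to_tuples."""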
data = [(1, "foo")]
with mock.patch.object(
BaseEngineSpec, "fetch_data", return_value=data
) as mock_fetch:
result = MssqlEngineSpec.fetch_data(None, 0)
mock_pyodbc_rows_to_tuples.assert_called_once_with(data)
self.assertEqual(result, "converted")
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
(VARCHAR(length=255), "VARCHAR(255)"),
(VARCHAR(length=255, collation="utf8_general_ci"), "VARCHAR(255)"),
(NVARCHAR(length=128), "NVARCHAR(128)"),
(TEXT(), "TEXT"),
(NTEXT(collation="utf8_general_ci"), "NTEXT"),
)
for original, expected in test_cases:
actual = MssqlEngineSpec.column_datatype_to_string(
original, mssql.dialect()
)
self.assertEqual(actual, expected)
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.
"""
msg = dedent(
"""
DB-Lib error message 20009, severity 9:
Unable to connect: Adaptive Server is unavailable or does not exist (locahost)
"""
)
result = MssqlEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
message='The hostname "locahost" cannot be resolved.',
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"issue_codes": [
{
"code": 1007,
"message": "Issue 1007 - The hostname provided can't be resolved.",
}
],
},
)
]
msg = dedent(
"""
DB-Lib error message 20009, severity 9:
Unable to connect: Adaptive Server is unavailable or does not exist (localhost)
Net-Lib error during Connection refused (61)
DB-Lib error message 20009, severity 9:
Unable to connect: Adaptive Server is unavailable or does not exist (localhost)
Net-Lib error during Connection refused (61)
"""
)
result = MssqlEngineSpec.extract_errors(
Exception(msg), context={"port": 12345, "hostname": "localhost"}
)
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
message='Port 12345 on hostname "localhost" refused the connection.',
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"issue_codes": [
{"code": 1008, "message": "Issue 1008 - The port is closed."}
],
},
)
]
msg = dedent(
"""
DB-Lib error message 20009, severity 9:
Unable to connect: Adaptive Server is unavailable or does not exist (example.com)
Net-Lib error during Operation timed out (60)
DB-Lib error message 20009, severity 9:
Unable to connect: Adaptive Server is unavailable or does not exist (example.com)
Net-Lib error during Operation timed out (60)
"""
)
result = MssqlEngineSpec.extract_errors(
Exception(msg), context={"port": 12345, "hostname": "example.com"}
)
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message=(
'The host "example.com" might be down, '
"and can't be reached on port 12345."
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"issue_codes": [
{
"code": 1009,
"message": "Issue 1009 - The host might be down, and can't be reached on the provided port.",
}
],
},
)
]
msg = dedent(
"""
DB-Lib error message 20009, severity 9:
Unable to connect: Adaptive Server is unavailable or does not exist (93.184.216.34)
Net-Lib error during Operation timed out (60)
DB-Lib error message 20009, severity 9:
Unable to connect: Adaptive Server is unavailable or does not exist (93.184.216.34)
Net-Lib error during Operation timed out (60)
"""
)
result = MssqlEngineSpec.extract_errors(
Exception(msg), context={"port": 12345, "hostname": "93.184.216.34"}
)
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message=(
'The host "93.184.216.34" might be down, '
"and can't be reached on port 12345."
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"issue_codes": [
{
"code": 1009,
"message": "Issue 1009 - The host might be down, and can't be reached on the provided port.",
}
],
},
)
]
msg = dedent(
"""
DB-Lib error message 20018, severity 14:
General SQL Server error: Check messages from the SQL Server
DB-Lib error message 20002, severity 9:
Adaptive Server connection failed (mssqldb.cxiotftzsypc.us-west-2.rds.amazonaws.com)
DB-Lib error message 20002, severity 9:
Adaptive Server connection failed (mssqldb.cxiotftzsypc.us-west-2.rds.amazonaws.com)
"""
)
result = MssqlEngineSpec.extract_errors(
Exception(msg), context={"username": "testuser", "database": "testdb"}
)
assert result == [
SupersetError(
message='Either the username "testuser", password, or database name "testdb" is incorrect.',
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Microsoft SQL",
"issue_codes": [
{
"code": 1014,
"message": "Issue 1014 - Either the username or "
"the password is wrong.",
},
{
"code": 1015,
"message": "Issue 1015 - Either the database is "
"spelled incorrectly or does not exist.",
},
],
},
)
]

View File

@@ -0,0 +1,240 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
@unittest.skipUnless(
TestDbEngineSpec.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
def test_get_datatype_mysql(self):
"""Tests related to datatype mapping for MySQL"""
self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
MySQLEngineSpec.convert_dttm("DATE", dttm),
"STR_TO_DATE('2019-01-02', '%Y-%m-%d')",
)
self.assertEqual(
MySQLEngineSpec.convert_dttm("DATETIME", dttm),
"STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')",
)
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
(VARCHAR(length=255), "VARCHAR(255)"),
(
VARCHAR(length=255, charset="latin1", collation="utf8mb4_general_ci"),
"VARCHAR(255)",
),
(NVARCHAR(length=128), "NATIONAL VARCHAR(128)"),
(TEXT(), "TEXT"),
)
for original, expected in test_cases:
actual = MySQLEngineSpec.column_datatype_to_string(
original, mysql.dialect()
)
self.assertEqual(actual, expected)
def test_generic_type(self):
type_expectations = (
# Numeric
("TINYINT", GenericDataType.NUMERIC),
("SMALLINT", GenericDataType.NUMERIC),
("MEDIUMINT", GenericDataType.NUMERIC),
("INT", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("FLOAT", GenericDataType.NUMERIC),
("DOUBLE", GenericDataType.NUMERIC),
("BIT", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TINYTEXT", GenericDataType.STRING),
("MEDIUMTEXT", GenericDataType.STRING),
("LONGTEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("DATETIME", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
)
assert_generic_types(MySQLEngineSpec, type_expectations)
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError
message = "Unknown table 'BIRTH_NAMES1' in information_schema"
exception = OperationalError(message)
extracted_message = MySQLEngineSpec._extract_error_message(exception)
assert extracted_message == message
exception = OperationalError(123, message)
extracted_message = MySQLEngineSpec._extract_error_message(exception)
assert extracted_message == message
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.
"""
msg = "mysql: Access denied for user 'test'@'testuser.com'"
result = MySQLEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
message='Either the username "test" or the password is incorrect.',
level=ErrorLevel.ERROR,
extra={
"invalid": ["username", "password"],
"engine_name": "MySQL",
"issue_codes": [
{
"code": 1014,
"message": "Issue 1014 - Either the"
" username or the password is wrong.",
},
{
"code": 1015,
"message": "Issue 1015 - Either the database is "
"spelled incorrectly or does not exist.",
},
],
},
)
]
msg = "mysql: Unknown MySQL server host 'badhostname.com'"
result = MySQLEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
message='Unknown MySQL server host "badhostname.com".',
level=ErrorLevel.ERROR,
extra={
"invalid": ["host"],
"engine_name": "MySQL",
"issue_codes": [
{
"code": 1007,
"message": "Issue 1007 - The hostname"
" provided can't be resolved.",
}
],
},
)
]
msg = "mysql: Can't connect to MySQL server on 'badconnection.com'"
result = MySQLEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message='The host "badconnection.com" might be '
"down and can't be reached.",
level=ErrorLevel.ERROR,
extra={
"invalid": ["host", "port"],
"engine_name": "MySQL",
"issue_codes": [
{
"code": 1007,
"message": "Issue 1007 - The hostname provided"
" can't be resolved.",
}
],
},
)
]
msg = "mysql: Can't connect to MySQL server on '93.184.216.34'"
result = MySQLEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message='The host "93.184.216.34" might be down and can\'t be reached.',
level=ErrorLevel.ERROR,
extra={
"invalid": ["host", "port"],
"engine_name": "MySQL",
"issue_codes": [
{
"code": 10007,
"message": "Issue 1007 - The hostname provided "
"can't be resolved.",
}
],
},
)
]
msg = "mysql: Unknown database 'badDB'"
result = MySQLEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Unable to connect to database "badDB".',
error_type=SupersetErrorType.CONNECTION_UNKNOWN_DATABASE_ERROR,
level=ErrorLevel.ERROR,
extra={
"invalid": ["database"],
"engine_name": "MySQL",
"issue_codes": [
{
"code": 1015,
"message": "Issue 1015 - Either the database is spelled incorrectly or does not exist.",
}
],
},
)
]
msg = "check the manual that corresponds to your MySQL server version for the right syntax to use near 'fromm"
result = MySQLEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Please check your query for syntax errors near "fromm". Then, try running your query again.',
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "MySQL",
"issue_codes": [
{
"code": 1030,
"message": "Issue 1030 - The query has a syntax error.",
}
],
},
)
]

View File

@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from sqlalchemy import column
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR
from superset.db_engine_specs.oracle import OracleEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestOracleDbEngineSpec(TestDbEngineSpec):
def test_oracle_sqla_column_name_length_exceeded(self):
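        """Column labels that exceed Oracle's identifier length limit should be replaced by a quoted, 30-character surrogate label."""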
col = column("This_Is_32_Character_Column_Name")
label = OracleEngineSpec.make_label_compatible(col.name)
self.assertEqual(label.quote, True)
label_expected = "3b26974078683be078219674eeb8f5"
self.assertEqual(label, label_expected)
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
col = column("decimal")
expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
result = str(expr.compile(dialect=oracle.dialect()))
self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
dttm = self.get_dttm()
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
(VARCHAR(length=255), "VARCHAR(255 CHAR)"),
(VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"),
(NVARCHAR(length=128), "NVARCHAR2(128)"),
)
for original, expected in test_cases:
actual = OracleEngineSpec.column_datatype_to_string(
original, oracle.dialect()
)
self.assertEqual(actual, expected)
def test_fetch_data_no_description(self):
cursor = mock.MagicMock()
cursor.description = []
assert OracleEngineSpec.fetch_data(cursor) == []
def test_fetch_data(self):
cursor = mock.MagicMock()
result = ["a", "b"]
cursor.fetchall.return_value = result
assert OracleEngineSpec.fetch_data(cursor) == result
@pytest.mark.parametrize(
"date_format,expected",
[
("DATE", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
("DATETIME", """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')"""),
(
"TIMESTAMP",
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
),
(
"timestamp",
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
),
("Other", None),
],
)
def test_convert_dttm(date_format, expected):
dttm = TestOracleDbEngineSpec.get_dttm()
assert OracleEngineSpec.convert_dttm(date_format, dttm) == expected

View File

@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from sqlalchemy import column
from superset.db_engine_specs.pinot import PinotEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestPinotDbEngineSpec(TestDbEngineSpec):
""" Tests pertaining to our Pinot database support """
def test_pinot_time_expression_sec_one_1d_grain(self):
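        """Epoch-second timestamps should be truncated to daily granularity with DATETIMECONVERT."""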
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "epoch_s", "P1D")
result = str(expr.compile())
self.assertEqual(
result,
"DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', '1:SECONDS:EPOCH', '1:DAYS')",
)
def test_pinot_time_expression_simple_date_format_1d_grain(self):
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1D")
result = str(expr.compile())
self.assertEqual(
result,
(
"DATETIMECONVERT(tstamp, "
+ "'1:SECONDS:SIMPLE_DATE_FORMAT:yyyy-MM-dd HH:mm:ss', "
+ "'1:SECONDS:SIMPLE_DATE_FORMAT:yyyy-MM-dd HH:mm:ss', '1:DAYS')"
),
)
def test_pinot_time_expression_simple_date_format_1w_grain(self):
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1W")
result = str(expr.compile())
self.assertEqual(
result,
(
"ToDateTime(DATETRUNC('week', FromDateTime(tstamp, "
+ "'yyyy-MM-dd HH:mm:ss'), 'MILLISECONDS'), 'yyyy-MM-dd HH:mm:ss')"
),
)
def test_pinot_time_expression_sec_one_1m_grain(self):
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "epoch_s", "P1M")
result = str(expr.compile())
self.assertEqual(
result, "DATETRUNC('month', tstamp, 'SECONDS')",
)
def test_invalid_get_time_expression_arguments(self):
with self.assertRaises(NotImplementedError):
PinotEngineSpec.get_timestamp_expr(column("tstamp"), None, "P1M")
with self.assertRaises(NotImplementedError):
PinotEngineSpec.get_timestamp_expr(
column("tstamp"), "epoch_s", "invalid_grain"
)

View File

@@ -0,0 +1,518 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from textwrap import dedent
from unittest import mock
from sqlalchemy import column, literal_column
from sqlalchemy.dialects import postgresql
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
class TestPostgresDbEngineSpec(TestDbEngineSpec):
def test_get_table_names(self):
"""
DB Eng Specs (postgres): Test get table names
"""
""" Make sure postgres doesn't try to remove schema name from table name
ie. when try_remove_schema_from_table_name == False. """
inspector = mock.Mock()
inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
pg_result_expected = ["schema.table", "table_2", "table_3"]
pg_result = PostgresEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(pg_result_expected, pg_result)
def test_time_exp_literal_no_grain(self):
"""
DB Eng Specs (postgres): Test no grain literal column
"""
col = literal_column("COALESCE(a, b)")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "COALESCE(a, b)")
def test_time_exp_literal_1y_grain(self):
"""
DB Eng Specs (postgres): Test grain literal column 1 YEAR
"""
col = literal_column("COALESCE(a, b)")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
def test_time_ex_lowr_col_no_grain(self):
"""
DB Eng Specs (postgres): Test no grain expr lower case
"""
col = column("lower_case")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "lower_case")
def test_time_exp_lowr_col_sec_1y(self):
"""
DB Eng Specs (postgres): Test grain expr lower case 1 YEAR
"""
col = column("lower_case")
expr = PostgresEngineSpec.get_timestamp_expr(col, "epoch_s", "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(
result,
"DATE_TRUNC('year', "
"(timestamp 'epoch' + lower_case * interval '1 second'))",
)
def test_time_exp_mixd_case_col_1y(self):
"""
DB Eng Specs (postgres): Test grain expr mixed case 1 YEAR
"""
col = column("MixedCase")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
def test_convert_dttm(self):
"""
DB Eng Specs (postgres): Test conversion to date time
"""
dttm = self.get_dttm()
self.assertEqual(
PostgresEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
PostgresEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
)
self.assertEqual(
PostgresEngineSpec.convert_dttm("DATETIME", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
)
self.assertEqual(PostgresEngineSpec.convert_dttm("TIME", dttm), None)
def test_empty_dbapi_cursor_description(self):
"""
DB Eng Specs (postgres): Test empty cursor description (no columns)
"""
cursor = mock.Mock()
# empty description mean no columns, this mocks the following SQL: "SELECT"
cursor.description = []
results = PostgresEngineSpec.fetch_data(cursor, 1000)
self.assertEqual(results, [])
def test_engine_alias_name(self):
"""
DB Eng Specs (postgres): Test "postgres" in engine spec
"""
self.assertIn("postgres", get_engine_specs())
def test_extras_without_ssl(self):
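        """Without a server certificate, no connect_args should be injected into engine_params."""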
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = PostgresEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl_default(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-full"
assert "sslrootcert" in connect_args
def test_extras_with_ssl_custom(self):
db = mock.Mock()
db.extra = default_db_extra.replace(
'"engine_params": {}',
'"engine_params": {"connect_args": {"sslmode": "verify-ca"}}',
)
db.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-ca"
assert "sslrootcert" in connect_args
def test_estimate_statement_cost_select_star(self):
"""
DB Eng Specs (postgres): Test estimate_statement_cost select star
"""
cursor = mock.Mock()
cursor.fetchone.return_value = (
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
)
sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
self.assertEqual(
results, {"Start-up cost": 0.00, "Total cost": 1537.91,},
)
def test_estimate_statement_invalid_syntax(self):
"""
DB Eng Specs (postgres): Test estimate_statement_cost invalid syntax
"""
from psycopg2 import errors
cursor = mock.Mock()
cursor.execute.side_effect = errors.SyntaxError(
"""
syntax error at or near "EXPLAIN"
LINE 1: EXPLAIN DROP TABLE birth_names
^
"""
)
sql = "DROP TABLE birth_names"
with self.assertRaises(errors.SyntaxError):
PostgresEngineSpec.estimate_statement_cost(sql, cursor)
def test_query_cost_formatter_example_costs(self):
"""
DB Eng Specs (postgres): Test test_query_cost_formatter example costs
"""
raw_cost = [
{"Start-up cost": 0.00, "Total cost": 1537.91,},
{"Start-up cost": 10.00, "Total cost": 1537.00,},
]
result = PostgresEngineSpec.query_cost_formatter(raw_cost)
self.assertEqual(
result,
[
{"Start-up cost": "0.0", "Total cost": "1537.91",},
{"Start-up cost": "10.0", "Total cost": "1537.0",},
],
)
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.
"""
msg = 'psql: error: FATAL: role "testuser" does not exist'
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR,
message='The username "testuser" does not exist.',
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1012,
"message": (
"Issue 1012 - The username provided when "
"connecting to a database is not valid."
),
},
],
"invalid": ["username"],
},
)
]
msg = (
'psql: error: could not translate host name "locahost" to address: '
"nodename nor servname provided, or not known"
)
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
message='The hostname "locahost" cannot be resolved.',
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1007,
"message": "Issue 1007 - The hostname provided "
"can't be resolved.",
}
],
"invalid": ["host"],
},
)
]
msg = dedent(
"""
psql: error: could not connect to server: Connection refused
Is the server running on host "localhost" (::1) and accepting
TCP/IP connections on port 12345?
could not connect to server: Connection refused
Is the server running on host "localhost" (127.0.0.1) and accepting
TCP/IP connections on port 12345?
"""
)
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
message='Port 12345 on hostname "localhost" refused the connection.',
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{"code": 1008, "message": "Issue 1008 - The port is closed."}
],
"invalid": ["host", "port"],
},
)
]
msg = dedent(
"""
psql: error: could not connect to server: Operation timed out
Is the server running on host "example.com" (93.184.216.34) and accepting
TCP/IP connections on port 12345?
"""
)
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message=(
'The host "example.com" might be down, '
"and can't be reached on port 12345."
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1009,
"message": "Issue 1009 - The host might be down, "
"and can't be reached on the provided port.",
}
],
"invalid": ["host", "port"],
},
)
]
# response with IP only
msg = dedent(
"""
psql: error: could not connect to server: Operation timed out
Is the server running on host "93.184.216.34" and accepting
TCP/IP connections on port 12345?
"""
)
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message=(
'The host "93.184.216.34" might be down, '
"and can't be reached on port 12345."
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1009,
"message": "Issue 1009 - The host might be down, "
"and can't be reached on the provided port.",
}
],
"invalid": ["host", "port"],
},
)
]
msg = 'FATAL: password authentication failed for user "postgres"'
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_INVALID_PASSWORD_ERROR,
message=('The password provided for username "postgres" is incorrect.'),
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1013,
"message": (
"Issue 1013 - The password provided when "
"connecting to a database is not valid."
),
},
],
"invalid": ["username", "password"],
},
)
]
msg = 'database "badDB" does not exist'
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Unable to connect to database "badDB".',
error_type=SupersetErrorType.CONNECTION_UNKNOWN_DATABASE_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1015,
"message": (
"Issue 1015 - Either the database is spelled "
"incorrectly or does not exist.",
),
}
],
"invalid": ["database"],
},
)
]
msg = "no password supplied"
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message="Please re-enter the password.",
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
level=ErrorLevel.ERROR,
extra={
"invalid": ["password"],
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1014,
"message": "Issue 1014 - Either the username or the password is wrong.",
},
{
"code": 1015,
"message": "Issue 1015 - Either the database is spelled incorrectly or does not exist.",
},
],
},
)
]
msg = 'syntax error at or near "fromm"'
result = PostgresEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Please check your query for syntax errors at or near "fromm". Then, try running your query again.',
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "PostgreSQL",
"issue_codes": [
{
"code": 1030,
"message": "Issue 1030 - The query has a syntax error.",
}
],
},
)
]
def test_base_parameters_mixin():
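    """build_sqlalchemy_uri and get_parameters_from_uri should round-trip the connection parameters, and parameters_json_schema should describe them."""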
parameters = {
"username": "username",
"password": "password",
"host": "localhost",
"port": 5432,
"database": "dbname",
"query": {"foo": "bar"},
"encryption": True,
}
encrypted_extra = None
sqlalchemy_uri = PostgresEngineSpec.build_sqlalchemy_uri(
parameters, encrypted_extra
)
assert sqlalchemy_uri == (
"postgresql+psycopg2://username:password@localhost:5432/dbname?"
"foo=bar&sslmode=verify-ca"
)
parameters_from_uri = PostgresEngineSpec.get_parameters_from_uri(sqlalchemy_uri)
assert parameters_from_uri == parameters
json_schema = PostgresEngineSpec.parameters_json_schema()
assert json_schema == {
"type": "object",
"properties": {
"port": {
"type": "integer",
"format": "int32",
"description": "Database port",
},
"password": {"type": "string", "nullable": True, "description": "Password"},
"host": {"type": "string", "description": "Hostname or IP address"},
"username": {"type": "string", "nullable": True, "description": "Username"},
"query": {
"type": "object",
"description": "Additional parameters",
"additionalProperties": {},
},
"database": {"type": "string", "description": "Database name"},
"encryption": {
"type": "boolean",
"description": "Use an encrypted connection to the database",
},
},
"required": ["database", "host", "port", "username"],
}
def test_generic_type():
type_expectations = (
# Numeric
("SMALLINT", GenericDataType.NUMERIC),
("INTEGER", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("NUMERIC", GenericDataType.NUMERIC),
("REAL", GenericDataType.NUMERIC),
("DOUBLE PRECISION", GenericDataType.NUMERIC),
("MONEY", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
# Boolean
("BOOLEAN", GenericDataType.BOOLEAN),
)
assert_generic_types(PostgresEngineSpec, type_expectations)

File diff suppressed because it is too large

View File

@@ -0,0 +1,185 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from textwrap import dedent
from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestRedshiftDbEngineSpec(TestDbEngineSpec):
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.
"""
msg = 'FATAL: password authentication failed for user "wronguser"'
result = RedshiftEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
message='Either the username "wronguser" or the password is incorrect.',
level=ErrorLevel.ERROR,
extra={
"invalid": ["username", "password"],
"engine_name": "Amazon Redshift",
"issue_codes": [
{
"code": 1014,
"message": "Issue 1014 - Either the username "
"or the password is wrong.",
},
{
"code": 1015,
"message": "Issue 1015 - Either the database is "
"spelled incorrectly or does not exist.",
},
],
},
)
]
msg = (
'redshift: error: could not translate host name "badhost" '
"to address: nodename nor servname provided, or not known"
)
result = RedshiftEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
message='The hostname "badhost" cannot be resolved.',
level=ErrorLevel.ERROR,
extra={
"invalid": ["host"],
"engine_name": "Amazon Redshift",
"issue_codes": [
{
"code": 1007,
"message": "Issue 1007 - The hostname provided "
"can't be resolved.",
}
],
},
)
]
msg = dedent(
"""
psql: error: could not connect to server: Connection refused
Is the server running on host "localhost" (::1) and accepting
TCP/IP connections on port 12345?
could not connect to server: Connection refused
Is the server running on host "localhost" (127.0.0.1) and accepting
TCP/IP connections on port 12345?
"""
)
result = RedshiftEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
message='Port 12345 on hostname "localhost" refused the connection.',
level=ErrorLevel.ERROR,
extra={
"invalid": ["host", "port"],
"engine_name": "Amazon Redshift",
"issue_codes": [
{"code": 1008, "message": "Issue 1008 - The port is closed."}
],
},
)
]
msg = dedent(
"""
psql: error: could not connect to server: Operation timed out
Is the server running on host "example.com" (93.184.216.34) and accepting
TCP/IP connections on port 12345?
"""
)
result = RedshiftEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message=(
'The host "example.com" might be down, '
"and can't be reached on port 12345."
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "Amazon Redshift",
"issue_codes": [
{
"code": 1009,
"message": "Issue 1009 - The host might be down, "
"and can't be reached on the provided port.",
}
],
"invalid": ["host", "port"],
},
)
]
# response with IP only
msg = dedent(
"""
psql: error: could not connect to server: Operation timed out
Is the server running on host "93.184.216.34" and accepting
TCP/IP connections on port 12345?
"""
)
result = RedshiftEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_HOST_DOWN_ERROR,
message=(
'The host "93.184.216.34" might be down, '
"and can't be reached on port 12345."
),
level=ErrorLevel.ERROR,
extra={
"engine_name": "Amazon Redshift",
"issue_codes": [
{
"code": 1009,
"message": "Issue 1009 - The host might be down, "
"and can't be reached on the provided port.",
}
],
"invalid": ["host", "port"],
},
)
]
msg = 'database "badDB" does not exist'
result = RedshiftEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
error_type=SupersetErrorType.CONNECTION_UNKNOWN_DATABASE_ERROR,
message='We were unable to connect to your database named "badDB".'
" Please verify your database name and try again.",
level=ErrorLevel.ERROR,
extra={
"engine_name": "Amazon Redshift",
"issue_codes": [
{
"code": 10015,
"message": "Issue 1015 - Either the database is "
"spelled incorrectly or does not exist.",
}
],
"invalid": ["database"],
},
)
]

View File

@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.core import Database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestSnowflakeDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
test_cases = {
"DATE": "TO_DATE('2019-01-02')",
"DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
"TIMESTAMP": "TO_TIMESTAMP('2019-01-02T03:04:05.678900')",
}
for type_, expected in test_cases.items():
self.assertEqual(SnowflakeEngineSpec.convert_dttm(type_, dttm), expected)
def test_database_connection_test_mutator(self):
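        """mutate_db_for_connection_test should enable validate_default_parameters in the engine's connect args."""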
database = Database(sqlalchemy_uri="snowflake://abc")
SnowflakeEngineSpec.mutate_db_for_connection_test(database)
engine_params = json.loads(database.extra or "{}")
self.assertDictEqual(
{"engine_params": {"connect_args": {"validate_default_parameters": True}}},
engine_params,
)
def test_extract_errors(self):
msg = "Object dumbBrick does not exist or not authorized."
result = SnowflakeEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message="dumbBrick does not exist in this database.",
error_type=SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Snowflake",
"issue_codes": [
{
"code": 1029,
"message": "Issue 1029 - The object does not exist in the given database.",
}
],
},
)
]
msg = "syntax error line 1 at position 10 unexpected 'limmmited'."
result = SnowflakeEngineSpec.extract_errors(Exception(msg))
assert result == [
SupersetError(
message='Please check your query for syntax errors at or near "limmmited". Then, try running your query again.',
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
extra={
"engine_name": "Snowflake",
"issue_codes": [
{
"code": 1030,
"message": "Issue 1030 - The query has a syntax error.",
}
],
},
)
]

View File

@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestSQliteDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
SqliteEngineSpec.convert_dttm("TEXT", dttm), "'2019-01-02 03:04:05.678900'"
)
def test_convert_dttm_lower(self):
dttm = self.get_dttm()
self.assertEqual(
SqliteEngineSpec.convert_dttm("text", dttm), "'2019-01-02 03:04:05.678900'"
)
def test_convert_dttm_invalid_type(self):
dttm = self.get_dttm()
self.assertEqual(SqliteEngineSpec.convert_dttm("other", dttm), None)
def test_get_all_datasource_names_table(self):
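        """get_all_datasource_names should collect table names from every schema, forwarding the cache settings."""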
database = mock.MagicMock()
database.get_all_schema_names.return_value = ["schema1"]
table_names = ["table1", "table2"]
get_tables = mock.MagicMock(return_value=table_names)
database.get_all_table_names_in_schema = get_tables
result = SqliteEngineSpec.get_all_datasource_names(database, "table")
assert result == table_names
get_tables.assert_called_once_with(
schema="schema1",
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
def test_get_all_datasource_names_view(self):
database = mock.MagicMock()
database.get_all_schema_names.return_value = ["schema1"]
views_names = ["view1", "view2"]
get_views = mock.MagicMock(return_value=views_names)
database.get_all_view_names_in_schema = get_views
result = SqliteEngineSpec.get_all_datasource_names(database, "view")
assert result == views_names
get_views.assert_called_once_with(
schema="schema1",
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
def test_get_all_datasource_names_invalid_type(self):
database = mock.MagicMock()
database.get_all_schema_names.return_value = ["schema1"]
invalid_type = "asdf"
with self.assertRaises(Exception):
SqliteEngineSpec.get_all_datasource_names(database, invalid_type)

View File

@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.trino import TrinoEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestTrinoDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
TrinoEngineSpec.convert_dttm("DATE", dttm),
"from_iso8601_date('2019-01-02')",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm),
"from_iso8601_timestamp('2019-01-02T03:04:05.678900')",
)
def test_adjust_database_uri(self):
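        """adjust_database_uri should append the selected schema to the URL's database component."""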
url = URL(drivername="trino", database="hive")
TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar")
self.assertEqual(url.database, "hive/foobar")
def test_adjust_database_uri_when_database_contain_schema(self):
url = URL(drivername="trino", database="hive/default")
TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar")
self.assertEqual(url.database, "hive/foobar")
def test_adjust_database_uri_when_selected_schema_is_none(self):
url = URL(drivername="trino", database="hive")
TrinoEngineSpec.adjust_database_uri(url, selected_schema=None)
self.assertEqual(url.database, "hive")
url.database = "hive/default"
TrinoEngineSpec.adjust_database_uri(url, selected_schema=None)
self.assertEqual(url.database, "hive/default")