fix(presto/trino): Ensure get_table_names only returns real tables (#21794)

This commit is contained in:
John Bodley
2022-11-09 14:30:49 -08:00
committed by GitHub
parent 53ed8f2d5a
commit 9f7bd1e63f
10 changed files with 125 additions and 116 deletions

View File

@@ -150,10 +150,6 @@ def test_hive_error_msg():
)
def test_hive_get_view_names_return_empty_list(): # pylint: disable=invalid-name
assert HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY) == []
def test_convert_dttm():
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from collections import namedtuple
from textwrap import dedent
from unittest import mock, skipUnless
import pandas as pd
@@ -33,52 +34,50 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_datatype_presto(self):
self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))
def test_presto_get_view_names_return_empty_list(
self,
): # pylint: disable=invalid-name
self.assertEqual(
[], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
)
@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
def test_get_view_names(self, mock_is_feature_enabled):
mock_is_feature_enabled.return_value = True
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
mock_execute.assert_called_once_with(
"SELECT table_name FROM information_schema.views", {}
)
assert result == ["a", "d"]
@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
def test_get_view_names_with_schema(self, mock_is_feature_enabled):
mock_is_feature_enabled.return_value = True
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
schema = "schema"
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
mock_execute.assert_called_once_with(
"SELECT table_name FROM information_schema.views "
"WHERE table_schema=%(schema)s",
dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_schema = %(schema)s
AND table_type = 'VIEW'
"""
).strip(),
{"schema": schema},
)
assert result == ["a", "d"]
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
mock_execute.assert_called_once_with(
dedent(
"""
SELECT table_name FROM information_schema.tables
WHERE table_type = 'VIEW'
"""
).strip(),
{},
)
assert result == ["a", "d"]
def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
@@ -663,50 +662,17 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None
@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_no_split_views_from_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
def test_get_table_names(
self,
mock_get_view_names,
mock_get_table_names,
):
mock_get_view_names.return_value = ["view1", "view2"]
table_names = ["table1", "table2", "view1", "view2"]
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = False
mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"]
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == table_names
@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_split_views_from_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
):
mock_get_view_names.return_value = ["view1", "view2"]
table_names = ["table1", "table2", "view1", "view2"]
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = True
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert sorted(tables) == sorted(table_names)
@mock.patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled"
)
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names_split_views_from_tables_no_tables(
self, mock_get_view_names, mock_get_table_names, mock_is_feature_enabled
):
mock_get_view_names.return_value = []
table_names = []
mock_get_table_names.return_value = table_names
mock_is_feature_enabled.return_value = True
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == []
assert tables == ["table1", "table2"]
def test_get_full_name(self):
names = [