feat: refactor all get_sqla_engine to use contextmanager in codebase (#21943)

This commit is contained in:
Hugh A. Miles II
2022-11-15 13:45:14 -05:00
committed by GitHub
parent 06f87e1467
commit e23efefc46
41 changed files with 635 additions and 595 deletions

View File

@@ -171,7 +171,7 @@ def example_db_provider() -> Callable[[], Database]: # type: ignore
return self._db
def _load_lazy_data_to_decouple_from_session(self) -> None:
self._db.get_sqla_engine() # type: ignore
self._db._get_sqla_engine() # type: ignore
self._db.backend # type: ignore
def remove(self) -> None:
@@ -336,37 +336,38 @@ def physical_dataset():
from superset.connectors.sqla.utils import get_identifier_quoter
example_database = get_example_database()
engine = example_database.get_sqla_engine()
quoter = get_identifier_quoter(engine.name)
# sqlite can only execute one statement at a time
engine.execute(
f"""
CREATE TABLE IF NOT EXISTS physical_dataset(
col1 INTEGER,
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
);
with example_database.get_sqla_engine_with_context() as engine:
quoter = get_identifier_quoter(engine.name)
# sqlite can only execute one statement at a time
engine.execute(
f"""
CREATE TABLE IF NOT EXISTS physical_dataset(
col1 INTEGER,
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
);
"""
)
engine.execute(
"""
INSERT INTO physical_dataset values
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
"""
)
engine.execute(
"""
INSERT INTO physical_dataset values
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
"""
)
)
dataset = SqlaTable(
table_name="physical_dataset",

View File

@@ -641,7 +641,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
@@ -664,7 +664,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
)
mock_event_logger.assert_called()
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
@@ -713,7 +713,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
)
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
@@ -738,7 +738,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
mock_event_logger.assert_called()
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)

View File

@@ -227,8 +227,10 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
return_value="account_info"
)
mock_get_engine.return_value.url.host = "google-host"
mock_get_engine.return_value.dialect.credentials_info = "secrets"
mock_get_engine.return_value.__enter__.return_value.url.host = "google-host"
mock_get_engine.return_value.__enter__.return_value.dialect.credentials_info = (
"secrets"
)
BigQueryEngineSpec.df_to_sql(
database=database,

View File

@@ -204,7 +204,9 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine.return_value.execute = mock_execute
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
mock_execute
)
table_name = "foobar"
with app.app_context():
@@ -229,7 +231,9 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
mock_execute = mock.MagicMock(return_value=True)
mock_database.get_sqla_engine.return_value.execute = mock_execute
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
mock_execute
)
table_name = "foobar"
schema = "schema"

View File

@@ -37,12 +37,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
schema = "schema"
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
mock_execute.assert_called_once_with(
@@ -60,10 +61,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
@@ -821,13 +822,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
mock_fetchall
)
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
False
)
schema = "schema"
@@ -839,7 +840,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_create_view_exception(self):
mock_execute = mock.MagicMock(side_effect=Exception())
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
schema = "schema"
@@ -852,7 +853,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
mock_execute = mock.MagicMock(side_effect=DatabaseError())
database = mock.MagicMock()
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
mock_execute
)
schema = "schema"

View File

@@ -51,8 +51,8 @@ def load_unicode_data():
yield
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS unicode_test")
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS unicode_test")
@pytest.fixture()

View File

@@ -64,8 +64,8 @@ def load_world_bank_data():
yield
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS wb_health_population")
with get_example_database().get_sqla_engine_with_context() as engine:
engine.execute("DROP TABLE IF EXISTS wb_health_population")
@pytest.fixture()

View File

@@ -164,7 +164,7 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://gamma@localhost"
@@ -177,7 +177,7 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://localhost"
@@ -197,7 +197,7 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri="trino://localhost"
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://localhost"
@@ -209,7 +209,7 @@ class TestDatabaseModel(SupersetTestCase):
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert (
@@ -242,7 +242,7 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@@ -255,7 +255,7 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model.get_sqla_engine()
model._get_sqla_engine()
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@@ -380,21 +380,7 @@ class TestDatabaseModel(SupersetTestCase):
)
mocked_create_engine.side_effect = Exception()
with self.assertRaises(SupersetException):
model.get_sqla_engine()
# todo(hughhh): update this test
# @mock.patch("superset.models.core.create_engine")
# def test_get_sqla_engine_with_context(self, mocked_create_engine):
# model = Database(
# database_name="test_database",
# sqlalchemy_uri="mysql://root@localhost",
# )
# model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
# return_value={Exception: SupersetException}
# )
# mocked_create_engine.side_effect = Exception()
# with self.assertRaises(SupersetException):
# model.get_sqla_engine()
model._get_sqla_engine()
class TestSqlaTableModel(SupersetTestCase):

View File

@@ -174,7 +174,9 @@ class TestPrestoValidator(SupersetTestCase):
def setUp(self):
self.validator = PrestoDBSQLValidator
self.database = MagicMock()
self.database_engine = self.database.get_sqla_engine.return_value
self.database_engine = (
self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
)
self.database_conn = self.database_engine.raw_connection.return_value
self.database_cursor = self.database_conn.cursor.return_value
self.database_cursor.poll.return_value = None

View File

@@ -733,7 +733,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -786,7 +786,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = True
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -836,7 +836,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False