mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat: refactor all get_sqla_engine to use contextmanager in codebase (#21943)
This commit is contained in:
@@ -171,7 +171,7 @@ def example_db_provider() -> Callable[[], Database]: # type: ignore
|
||||
return self._db
|
||||
|
||||
def _load_lazy_data_to_decouple_from_session(self) -> None:
|
||||
self._db.get_sqla_engine() # type: ignore
|
||||
self._db._get_sqla_engine() # type: ignore
|
||||
self._db.backend # type: ignore
|
||||
|
||||
def remove(self) -> None:
|
||||
@@ -336,37 +336,38 @@ def physical_dataset():
|
||||
from superset.connectors.sqla.utils import get_identifier_quoter
|
||||
|
||||
example_database = get_example_database()
|
||||
engine = example_database.get_sqla_engine()
|
||||
quoter = get_identifier_quoter(engine.name)
|
||||
# sqlite can only execute one statement at a time
|
||||
engine.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS physical_dataset(
|
||||
col1 INTEGER,
|
||||
col2 VARCHAR(255),
|
||||
col3 DECIMAL(4,2),
|
||||
col4 VARCHAR(255),
|
||||
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
|
||||
);
|
||||
|
||||
with example_database.get_sqla_engine_with_context() as engine:
|
||||
quoter = get_identifier_quoter(engine.name)
|
||||
# sqlite can only execute one statement at a time
|
||||
engine.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS physical_dataset(
|
||||
col1 INTEGER,
|
||||
col2 VARCHAR(255),
|
||||
col3 DECIMAL(4,2),
|
||||
col4 VARCHAR(255),
|
||||
col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
|
||||
{quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
|
||||
);
|
||||
"""
|
||||
)
|
||||
engine.execute(
|
||||
"""
|
||||
INSERT INTO physical_dataset values
|
||||
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
|
||||
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
|
||||
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
|
||||
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
|
||||
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
|
||||
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
|
||||
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
|
||||
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
|
||||
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
|
||||
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
|
||||
"""
|
||||
)
|
||||
engine.execute(
|
||||
"""
|
||||
INSERT INTO physical_dataset values
|
||||
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
|
||||
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
|
||||
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
|
||||
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
|
||||
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
|
||||
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
|
||||
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
|
||||
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
|
||||
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
|
||||
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="physical_dataset",
|
||||
|
||||
@@ -641,7 +641,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
|
||||
|
||||
|
||||
class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
@@ -664,7 +664,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
)
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
@@ -713,7 +713,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
|
||||
)
|
||||
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
@@ -738,7 +738,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
|
||||
@mock.patch("superset.databases.dao.Database._get_sqla_engine")
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.event_logger.log_with_context"
|
||||
)
|
||||
|
||||
@@ -227,8 +227,10 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
||||
return_value="account_info"
|
||||
)
|
||||
|
||||
mock_get_engine.return_value.url.host = "google-host"
|
||||
mock_get_engine.return_value.dialect.credentials_info = "secrets"
|
||||
mock_get_engine.return_value.__enter__.return_value.url.host = "google-host"
|
||||
mock_get_engine.return_value.__enter__.return_value.dialect.credentials_info = (
|
||||
"secrets"
|
||||
)
|
||||
|
||||
BigQueryEngineSpec.df_to_sql(
|
||||
database=database,
|
||||
|
||||
@@ -204,7 +204,9 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine.return_value.execute = mock_execute
|
||||
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
|
||||
with app.app_context():
|
||||
@@ -229,7 +231,9 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_database.get_df.return_value.empty = False
|
||||
mock_execute = mock.MagicMock(return_value=True)
|
||||
mock_database.get_sqla_engine.return_value.execute = mock_execute
|
||||
mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
table_name = "foobar"
|
||||
schema = "schema"
|
||||
|
||||
|
||||
@@ -37,12 +37,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
def test_get_view_names_with_schema(self):
|
||||
database = mock.MagicMock()
|
||||
mock_execute = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
return_value=[["a", "b,", "c"], ["d", "e"]]
|
||||
)
|
||||
|
||||
schema = "schema"
|
||||
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema)
|
||||
mock_execute.assert_called_once_with(
|
||||
@@ -60,10 +61,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
def test_get_view_names_without_schema(self):
|
||||
database = mock.MagicMock()
|
||||
mock_execute = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
|
||||
return_value=[["a", "b,", "c"], ["d", "e"]]
|
||||
)
|
||||
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
|
||||
@@ -821,13 +822,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
mock_execute = mock.MagicMock()
|
||||
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
|
||||
database = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
|
||||
mock_fetchall
|
||||
)
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
|
||||
False
|
||||
)
|
||||
schema = "schema"
|
||||
@@ -839,7 +840,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
def test_get_create_view_exception(self):
|
||||
mock_execute = mock.MagicMock(side_effect=Exception())
|
||||
database = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
schema = "schema"
|
||||
@@ -852,7 +853,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
|
||||
mock_execute = mock.MagicMock(side_effect=DatabaseError())
|
||||
database = mock.MagicMock()
|
||||
database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
|
||||
mock_execute
|
||||
)
|
||||
schema = "schema"
|
||||
|
||||
@@ -51,8 +51,8 @@ def load_unicode_data():
|
||||
|
||||
yield
|
||||
with app.app_context():
|
||||
engine = get_example_database().get_sqla_engine()
|
||||
engine.execute("DROP TABLE IF EXISTS unicode_test")
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS unicode_test")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -64,8 +64,8 @@ def load_world_bank_data():
|
||||
|
||||
yield
|
||||
with app.app_context():
|
||||
engine = get_example_database().get_sqla_engine()
|
||||
engine.execute("DROP TABLE IF EXISTS wb_health_population")
|
||||
with get_example_database().get_sqla_engine_with_context() as engine:
|
||||
engine.execute("DROP TABLE IF EXISTS wb_health_population")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -164,7 +164,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://gamma@localhost"
|
||||
@@ -177,7 +177,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://localhost"
|
||||
@@ -197,7 +197,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
database_name="test_database", sqlalchemy_uri="trino://localhost"
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "trino://localhost"
|
||||
@@ -209,7 +209,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
)
|
||||
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert (
|
||||
@@ -242,7 +242,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
@@ -255,7 +255,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
@@ -380,21 +380,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
)
|
||||
mocked_create_engine.side_effect = Exception()
|
||||
with self.assertRaises(SupersetException):
|
||||
model.get_sqla_engine()
|
||||
|
||||
# todo(hughhh): update this test
|
||||
# @mock.patch("superset.models.core.create_engine")
|
||||
# def test_get_sqla_engine_with_context(self, mocked_create_engine):
|
||||
# model = Database(
|
||||
# database_name="test_database",
|
||||
# sqlalchemy_uri="mysql://root@localhost",
|
||||
# )
|
||||
# model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
|
||||
# return_value={Exception: SupersetException}
|
||||
# )
|
||||
# mocked_create_engine.side_effect = Exception()
|
||||
# with self.assertRaises(SupersetException):
|
||||
# model.get_sqla_engine()
|
||||
model._get_sqla_engine()
|
||||
|
||||
|
||||
class TestSqlaTableModel(SupersetTestCase):
|
||||
|
||||
@@ -174,7 +174,9 @@ class TestPrestoValidator(SupersetTestCase):
|
||||
def setUp(self):
|
||||
self.validator = PrestoDBSQLValidator
|
||||
self.database = MagicMock()
|
||||
self.database_engine = self.database.get_sqla_engine.return_value
|
||||
self.database_engine = (
|
||||
self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
|
||||
)
|
||||
self.database_conn = self.database_engine.raw_connection.return_value
|
||||
self.database_cursor = self.database_conn.cursor.return_value
|
||||
self.database_cursor.poll.return_value = None
|
||||
|
||||
@@ -733,7 +733,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
mock_query = mock.MagicMock()
|
||||
mock_query.database.allow_run_async = False
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
|
||||
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
|
||||
@@ -786,7 +786,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
mock_query = mock.MagicMock()
|
||||
mock_query.database.allow_run_async = True
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
|
||||
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
|
||||
@@ -836,7 +836,7 @@ class TestSqlLab(SupersetTestCase):
|
||||
mock_query = mock.MagicMock()
|
||||
mock_query.database.allow_run_async = False
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
|
||||
mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
|
||||
mock_cursor
|
||||
)
|
||||
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
|
||||
|
||||
Reference in New Issue
Block a user