refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)

This commit is contained in:
John Bodley
2024-02-14 06:20:15 +13:00
committed by GitHub
parent 827864b939
commit 847ed3f5b0
96 changed files with 656 additions and 730 deletions

View File

@@ -33,10 +33,10 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec):
assert get_engine_spec("databricks", "pyhive").engine == "databricks"
def test_extras_without_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DatabricksNativeEngineSpec.get_extra_params(db)
database = mock.Mock()
database.extra = default_db_extra
database.server_cert = None
extras = DatabricksNativeEngineSpec.get_extra_params(database)
assert extras == {
"engine_params": {
"connect_args": {
@@ -50,12 +50,12 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec):
}
def test_extras_with_ssl_custom(self):
db = mock.Mock()
db.extra = default_db_extra.replace(
database = mock.Mock()
database.extra = default_db_extra.replace(
'"engine_params": {}',
'"engine_params": {"connect_args": {"ssl": "1"}}',
)
db.server_cert = ssl_certificate
extras = DatabricksNativeEngineSpec.get_extra_params(db)
database.server_cert = ssl_certificate
extras = DatabricksNativeEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["ssl"] == "1"

View File

@@ -337,14 +337,14 @@ def test_fetch_data_success(fetch_data_mock):
@mock.patch("superset.db_engine_specs.hive.HiveEngineSpec._latest_partition_from_df")
def test_where_latest_partition(mock_method):
mock_method.return_value = ("01-01-19", 1)
db = mock.Mock()
db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
db.get_extra = mock.Mock(return_value={})
db.get_df = mock.Mock()
database = mock.Mock()
database.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
database.get_extra = mock.Mock(return_value={})
database.get_df = mock.Mock()
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", db, select(), columns
"test_table", "test_schema", database, select(), columns
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@@ -353,11 +353,11 @@ def test_where_latest_partition(mock_method):
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
def test_where_latest_partition_super_method_exception(mock_method):
mock_method.side_effect = Exception()
db = mock.Mock()
database = mock.Mock()
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", db, select(), columns
"test_table", "test_schema", database, select(), columns
)
assert result is None
mock_method.assert_called()

View File

@@ -119,29 +119,29 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
assert "postgres" in backends
def test_extras_without_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = PostgresEngineSpec.get_extra_params(db)
database = mock.Mock()
database.extra = default_db_extra
database.server_cert = None
extras = PostgresEngineSpec.get_extra_params(database)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl_default(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(db)
database = mock.Mock()
database.extra = default_db_extra
database.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-full"
assert "sslrootcert" in connect_args
def test_extras_with_ssl_custom(self):
db = mock.Mock()
db.extra = default_db_extra.replace(
database = mock.Mock()
database.extra = default_db_extra.replace(
'"engine_params": {}',
'"engine_params": {"connect_args": {"sslmode": "verify-ca"}}',
)
db.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(db)
database.server_cert = ssl_certificate
extras = PostgresEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-ca"
assert "sslrootcert" in connect_args

View File

@@ -550,13 +550,17 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
def test_presto_extra_table_metadata(self):
db = mock.Mock()
db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
db.get_extra = mock.Mock(return_value={})
database = mock.Mock()
database.get_indexes = mock.Mock(
return_value=[{"column_names": ["ds", "hour"]}]
)
database.get_extra = mock.Mock(return_value={})
df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
db.get_df = mock.Mock(return_value=df)
database.get_df = mock.Mock(return_value=df)
PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
result = PrestoEngineSpec.extra_table_metadata(
database, "test_table", "test_schema"
)
assert result["partitions"]["cols"] == ["ds", "hour"]
assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}