mirror of
https://github.com/apache/superset.git
synced 2026-05-07 00:44:26 +00:00
Update existing tests
This commit is contained in:
@@ -70,12 +70,6 @@ class EngineManagerExtension:
|
||||
def shutdown_engine_manager() -> None:
|
||||
if self.engine_manager:
|
||||
self.engine_manager.stop_cleanup_thread()
|
||||
# Use a try-except to handle closed log file handlers during tests
|
||||
try:
|
||||
logger.info("Stopped EngineManager cleanup thread")
|
||||
except ValueError:
|
||||
# Ignore logging errors during test shutdown when file handles are closed
|
||||
pass
|
||||
|
||||
app.teardown_appcontext_funcs.append(lambda exc: None)
|
||||
|
||||
|
||||
@@ -170,7 +170,6 @@ def example_db_provider() -> Callable[[], Database]:
|
||||
return self._db
|
||||
|
||||
def _load_lazy_data_to_decouple_from_session(self) -> None:
|
||||
self._db._get_sqla_engine() # type: ignore
|
||||
self._db.backend # type: ignore # noqa: B018
|
||||
|
||||
def remove(self) -> None:
|
||||
|
||||
@@ -897,7 +897,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
|
||||
|
||||
|
||||
class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_db_exception(
|
||||
@@ -906,19 +906,19 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
"""Test to make sure event_logger is called when an exception is raised"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
|
||||
mock_get_sqla_engine.__enter__.side_effect = Exception("An error has occurred!")
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
|
||||
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012
|
||||
command_without_db_name.run()
|
||||
assert str(excinfo.value) == (
|
||||
"Unexpected error occurred, please check your logs for details"
|
||||
)
|
||||
assert str(excinfo.value) == (
|
||||
"Unexpected error occurred, please check your logs for details"
|
||||
)
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_do_ping_exception(
|
||||
@@ -927,7 +927,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
"""Test to make sure do_ping exceptions gets captured"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception(
|
||||
mock_get_sqla_engine.__enter__().dialect.do_ping.side_effect = Exception(
|
||||
"An error has occurred!"
|
||||
)
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
@@ -967,7 +967,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
|
||||
)
|
||||
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_superset_security_connection(
|
||||
@@ -977,7 +977,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
connection exc is raised"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.side_effect = SupersetSecurityException(
|
||||
mock_get_sqla_engine.__enter__.side_effect = SupersetSecurityException(
|
||||
SupersetError(error_type=500, message="test", level="info")
|
||||
)
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
@@ -990,7 +990,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_db_api_exc(
|
||||
@@ -999,7 +999,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
"""Test to make sure event_logger is called when DBAPIError is raised"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.side_effect = DBAPIError(
|
||||
mock_get_sqla_engine.__enter__.side_effect = DBAPIError(
|
||||
statement="error", params={}, orig={}
|
||||
)
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
|
||||
@@ -96,7 +96,9 @@ class TestEngineManager:
|
||||
@pytest.fixture
|
||||
def engine_manager(self):
|
||||
"""Create a mock EngineManager instance."""
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def dummy_context_manager(
|
||||
database: MagicMock, catalog: str | None, schema: str | None
|
||||
) -> Iterator[None]:
|
||||
@@ -293,3 +295,233 @@ class TestEngineManager:
|
||||
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
assert result2 is active_tunnel
|
||||
assert mock_tunnel_class.call_count == 2
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_basic(
|
||||
self, mock_make_url, mock_create_engine, engine_manager
|
||||
):
|
||||
"""Test _get_engine_args returns correct URI and kwargs."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.engines.manager import EngineModes
|
||||
|
||||
engine_manager.mode = EngineModes.NEW
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
database = MagicMock()
|
||||
database.id = 1
|
||||
database.sqlalchemy_uri_decrypted = "trino://"
|
||||
database.get_extra.return_value = {
|
||||
"engine_params": {},
|
||||
"connect_args": {"source": "Apache Superset"},
|
||||
}
|
||||
database.get_effective_user.return_value = "alice"
|
||||
database.impersonate_user = False
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (
|
||||
mock_uri,
|
||||
{"source": "Apache Superset"},
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
|
||||
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
|
||||
|
||||
assert str(uri) == "trino://"
|
||||
assert "connect_args" in database.get_extra.return_value
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_user_impersonation(
|
||||
self, mock_make_url, mock_create_engine, engine_manager
|
||||
):
|
||||
"""Test user impersonation in _get_engine_args."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.engines.manager import EngineModes
|
||||
|
||||
engine_manager.mode = EngineModes.NEW
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
database = MagicMock()
|
||||
database.id = 1
|
||||
database.sqlalchemy_uri_decrypted = "trino://"
|
||||
database.get_extra.return_value = {
|
||||
"engine_params": {},
|
||||
"connect_args": {"source": "Apache Superset"},
|
||||
}
|
||||
database.get_effective_user.return_value = "alice"
|
||||
database.impersonate_user = True
|
||||
database.get_oauth2_config.return_value = None
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (
|
||||
mock_uri,
|
||||
{"source": "Apache Superset"},
|
||||
)
|
||||
database.db_engine_spec.impersonate_user.return_value = (
|
||||
mock_uri,
|
||||
{"connect_args": {"user": "alice", "source": "Apache Superset"}},
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
|
||||
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
|
||||
|
||||
# Verify impersonate_user was called
|
||||
database.db_engine_spec.impersonate_user.assert_called_once()
|
||||
call_args = database.db_engine_spec.impersonate_user.call_args
|
||||
assert call_args[0][0] is database # database
|
||||
assert call_args[0][1] == "alice" # username
|
||||
assert call_args[0][2] is None # access_token (no OAuth2)
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_user_impersonation_email_prefix(
|
||||
self,
|
||||
mock_make_url,
|
||||
mock_create_engine,
|
||||
engine_manager,
|
||||
):
|
||||
"""Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.engines.manager import EngineModes
|
||||
|
||||
engine_manager.mode = EngineModes.NEW
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
# Mock user with email
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "alice.doe@example.org"
|
||||
|
||||
database = MagicMock()
|
||||
database.id = 1
|
||||
database.sqlalchemy_uri_decrypted = "trino://"
|
||||
database.get_extra.return_value = {
|
||||
"engine_params": {},
|
||||
"connect_args": {"source": "Apache Superset"},
|
||||
}
|
||||
database.get_effective_user.return_value = "alice"
|
||||
database.impersonate_user = True
|
||||
database.get_oauth2_config.return_value = None
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (
|
||||
mock_uri,
|
||||
{"source": "Apache Superset"},
|
||||
)
|
||||
database.db_engine_spec.impersonate_user.return_value = (
|
||||
mock_uri,
|
||||
{"connect_args": {"user": "alice.doe", "source": "Apache Superset"}},
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"superset.extensions.security_manager.find_user",
|
||||
return_value=mock_user,
|
||||
),
|
||||
):
|
||||
uri, kwargs = engine_manager._get_engine_args(
|
||||
database, None, None, None, None
|
||||
)
|
||||
|
||||
# Verify impersonate_user was called with the email prefix
|
||||
database.db_engine_spec.impersonate_user.assert_called_once()
|
||||
call_args = database.db_engine_spec.impersonate_user.call_args
|
||||
assert call_args[0][1] == "alice.doe" # username from email prefix
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_engine_context_manager_called(
|
||||
self, mock_make_url, mock_create_engine, engine_manager, mock_database
|
||||
):
|
||||
"""Test that the engine context manager is properly called."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
mock_engine = MagicMock()
|
||||
mock_create_engine.return_value = mock_engine
|
||||
|
||||
# Track context manager calls
|
||||
context_manager_calls = []
|
||||
|
||||
def tracking_context_manager(database, catalog, schema):
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def inner():
|
||||
context_manager_calls.append(("enter", database, catalog, schema))
|
||||
yield
|
||||
context_manager_calls.append(("exit", database, catalog, schema))
|
||||
|
||||
return inner()
|
||||
|
||||
engine_manager.engine_context_manager = tracking_context_manager
|
||||
|
||||
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
|
||||
pass
|
||||
|
||||
assert len(context_manager_calls) == 2
|
||||
assert context_manager_calls[0][0] == "enter"
|
||||
assert context_manager_calls[0][1] is mock_database
|
||||
assert context_manager_calls[0][2] == "catalog1"
|
||||
assert context_manager_calls[0][3] == "schema1"
|
||||
assert context_manager_calls[1][0] == "exit"
|
||||
|
||||
@patch("superset.utils.oauth2.check_for_oauth2")
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_engine_oauth2_error_handling(
|
||||
self,
|
||||
mock_make_url,
|
||||
mock_create_engine,
|
||||
mock_check_for_oauth2,
|
||||
engine_manager,
|
||||
mock_database,
|
||||
):
|
||||
"""Test that OAuth2 errors are properly propagated from get_engine."""
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
# Simulate OAuth2 error during engine creation
|
||||
class OAuth2TestError(Exception):
|
||||
pass
|
||||
|
||||
oauth_error = OAuth2TestError("OAuth2 required")
|
||||
mock_create_engine.side_effect = oauth_error
|
||||
|
||||
# Make get_dbapi_mapped_exception return the original exception
|
||||
mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = (
|
||||
oauth_error
|
||||
)
|
||||
|
||||
# Mock check_for_oauth2 to re-raise the exception
|
||||
@contextmanager
|
||||
def mock_oauth2_context(database):
|
||||
try:
|
||||
yield
|
||||
except OAuth2TestError:
|
||||
raise
|
||||
|
||||
mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database)
|
||||
|
||||
with pytest.raises(OAuth2TestError, match="OAuth2 required"):
|
||||
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
|
||||
pass
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestSupersetAppInitializer:
|
||||
patch.object(app_initializer, "configure_data_sources"),
|
||||
patch.object(app_initializer, "configure_auth_provider"),
|
||||
patch.object(app_initializer, "configure_async_queries"),
|
||||
patch.object(app_initializer, "configure_ssh_manager"),
|
||||
patch.object(app_initializer, "configure_engine_manager"),
|
||||
patch.object(app_initializer, "configure_stats_manager"),
|
||||
patch.object(app_initializer, "init_views"),
|
||||
):
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
@@ -29,7 +28,6 @@ from sqlalchemy import (
|
||||
Table as SqlalchemyTable,
|
||||
)
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
@@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2(
|
||||
assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
|
||||
|
||||
|
||||
def test_get_sqla_engine(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `_get_sqla_engine`.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
user = mocker.MagicMock()
|
||||
user.email = "alice.doe@example.org"
|
||||
mocker.patch(
|
||||
"superset.models.core.security_manager.find_user",
|
||||
return_value=user,
|
||||
)
|
||||
mocker.patch("superset.models.core.get_username", return_value="alice")
|
||||
|
||||
create_engine = mocker.patch("superset.models.core.create_engine")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
database._get_sqla_engine(nullpool=False)
|
||||
|
||||
create_engine.assert_called_with(
|
||||
make_url("trino:///"),
|
||||
connect_args={"source": "Apache Superset"},
|
||||
)
|
||||
|
||||
|
||||
def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test user impersonation in `_get_sqla_engine`.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
user = mocker.MagicMock()
|
||||
user.email = "alice.doe@example.org"
|
||||
mocker.patch(
|
||||
"superset.models.core.security_manager.find_user",
|
||||
return_value=user,
|
||||
)
|
||||
mocker.patch("superset.models.core.get_username", return_value="alice")
|
||||
|
||||
create_engine = mocker.patch("superset.models.core.create_engine")
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="trino://",
|
||||
impersonate_user=True,
|
||||
)
|
||||
database._get_sqla_engine(nullpool=False)
|
||||
|
||||
create_engine.assert_called_with(
|
||||
make_url("trino:///"),
|
||||
connect_args={"user": "alice", "source": "Apache Superset"},
|
||||
)
|
||||
|
||||
|
||||
def test_add_database_to_signature():
|
||||
args = ["param1", "param2"]
|
||||
|
||||
@@ -604,36 +548,6 @@ def test_add_database_to_signature():
|
||||
assert args3 == ["param1", "param2", database]
|
||||
|
||||
|
||||
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
|
||||
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test user impersonation in `_get_sqla_engine` with `username_from_email`.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
user = mocker.MagicMock()
|
||||
user.email = "alice.doe@example.org"
|
||||
mocker.patch(
|
||||
"superset.models.core.security_manager.find_user",
|
||||
return_value=user,
|
||||
)
|
||||
mocker.patch("superset.models.core.get_username", return_value="alice")
|
||||
|
||||
create_engine = mocker.patch("superset.models.core.create_engine")
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="trino://",
|
||||
impersonate_user=True,
|
||||
)
|
||||
database._get_sqla_engine(nullpool=False)
|
||||
|
||||
create_engine.assert_called_with(
|
||||
make_url("trino:///"),
|
||||
connect_args={"user": "alice.doe", "source": "Apache Superset"},
|
||||
)
|
||||
|
||||
|
||||
def test_is_oauth2_enabled() -> None:
|
||||
"""
|
||||
Test the `is_oauth2_enabled` method.
|
||||
@@ -753,37 +667,6 @@ def test_get_oauth2_config_redirect_uri_from_config(
|
||||
assert config["redirect_uri"] == custom_redirect_uri
|
||||
|
||||
|
||||
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||
|
||||
With OAuth2, some databases will raise an exception when the engine is first created
|
||||
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
|
||||
finally, GSheets will raise an exception when the query is executed.
|
||||
|
||||
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
|
||||
triggered when the engine is created.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user = mocker.MagicMock()
|
||||
g.user.id = 42
|
||||
|
||||
database = Database(
|
||||
id=1,
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
encrypted_extra=json.dumps(oauth2_client_info),
|
||||
)
|
||||
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
|
||||
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
|
||||
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
|
||||
|
||||
with pytest.raises(OAuth2RedirectError) as excinfo:
|
||||
with database.get_raw_connection() as conn:
|
||||
conn.cursor()
|
||||
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||
|
||||
|
||||
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||
@@ -879,56 +762,6 @@ def test_get_schema_access_for_file_upload() -> None:
|
||||
assert database.get_schema_access_for_file_upload() == {"public"}
|
||||
|
||||
|
||||
def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None:
|
||||
"""
|
||||
Test the engine context manager.
|
||||
"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
engine_context_manager = MagicMock()
|
||||
mocker.patch.dict(
|
||||
current_app.config,
|
||||
{"ENGINE_CONTEXT_MANAGER": engine_context_manager},
|
||||
)
|
||||
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
with database.get_sqla_engine("catalog", "schema"):
|
||||
pass
|
||||
|
||||
engine_context_manager.assert_called_once_with(database, "catalog", "schema")
|
||||
engine_context_manager().__enter__.assert_called_once()
|
||||
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
|
||||
_get_sqla_engine.assert_called_once_with(
|
||||
catalog="catalog",
|
||||
schema="schema",
|
||||
nullpool=True,
|
||||
source=None,
|
||||
sqlalchemy_uri="trino://",
|
||||
)
|
||||
|
||||
|
||||
def test_engine_oauth2(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we handle OAuth2 when `create_engine` fails.
|
||||
"""
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
|
||||
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
|
||||
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
|
||||
start_oauth2_dance = mocker.patch.object(
|
||||
database.db_engine_spec,
|
||||
"start_oauth2_dance",
|
||||
side_effect=OAuth2Error("OAuth2 required"),
|
||||
)
|
||||
|
||||
with pytest.raises(OAuth2Error):
|
||||
with database.get_sqla_engine("catalog", "schema"):
|
||||
pass
|
||||
|
||||
start_oauth2_dance.assert_called_with(database)
|
||||
|
||||
|
||||
def test_purge_oauth2_tokens(session: Session) -> None:
|
||||
"""
|
||||
Test the `purge_oauth2_tokens` method.
|
||||
|
||||
Reference in New Issue
Block a user