Update existing tests

This commit is contained in:
Beto Dealmeida
2025-12-03 20:00:29 -05:00
parent be31abeb7e
commit de8c250f86
6 changed files with 244 additions and 186 deletions

View File

@@ -96,7 +96,9 @@ class TestEngineManager:
@pytest.fixture
def engine_manager(self):
"""Create a mock EngineManager instance."""
from contextlib import contextmanager
@contextmanager
def dummy_context_manager(
database: MagicMock, catalog: str | None, schema: str | None
) -> Iterator[None]:
@@ -293,3 +295,233 @@ class TestEngineManager:
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
assert result2 is active_tunnel
assert mock_tunnel_class.call_count == 2
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_basic(
self, mock_make_url, mock_create_engine, engine_manager
):
"""Test _get_engine_args returns correct URI and kwargs."""
from sqlalchemy.engine.url import make_url
from superset.engines.manager import EngineModes
engine_manager.mode = EngineModes.NEW
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {
"engine_params": {},
"connect_args": {"source": "Apache Superset"},
}
database.get_effective_user.return_value = "alice"
database.impersonate_user = False
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
mock_uri,
{"source": "Apache Superset"},
)
database.db_engine_spec.validate_database_uri = MagicMock()
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
assert str(uri) == "trino://"
assert "connect_args" in database.get_extra.return_value
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_user_impersonation(
self, mock_make_url, mock_create_engine, engine_manager
):
"""Test user impersonation in _get_engine_args."""
from sqlalchemy.engine.url import make_url
from superset.engines.manager import EngineModes
engine_manager.mode = EngineModes.NEW
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {
"engine_params": {},
"connect_args": {"source": "Apache Superset"},
}
database.get_effective_user.return_value = "alice"
database.impersonate_user = True
database.get_oauth2_config.return_value = None
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
mock_uri,
{"source": "Apache Superset"},
)
database.db_engine_spec.impersonate_user.return_value = (
mock_uri,
{"connect_args": {"user": "alice", "source": "Apache Superset"}},
)
database.db_engine_spec.validate_database_uri = MagicMock()
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
# Verify impersonate_user was called
database.db_engine_spec.impersonate_user.assert_called_once()
call_args = database.db_engine_spec.impersonate_user.call_args
assert call_args[0][0] is database # database
assert call_args[0][1] == "alice" # username
assert call_args[0][2] is None # access_token (no OAuth2)
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_user_impersonation_email_prefix(
self,
mock_make_url,
mock_create_engine,
engine_manager,
):
"""Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag."""
from sqlalchemy.engine.url import make_url
from superset.engines.manager import EngineModes
engine_manager.mode = EngineModes.NEW
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
# Mock user with email
mock_user = MagicMock()
mock_user.email = "alice.doe@example.org"
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {
"engine_params": {},
"connect_args": {"source": "Apache Superset"},
}
database.get_effective_user.return_value = "alice"
database.impersonate_user = True
database.get_oauth2_config.return_value = None
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
mock_uri,
{"source": "Apache Superset"},
)
database.db_engine_spec.impersonate_user.return_value = (
mock_uri,
{"connect_args": {"user": "alice.doe", "source": "Apache Superset"}},
)
database.db_engine_spec.validate_database_uri = MagicMock()
with (
patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled",
return_value=True,
),
patch(
"superset.extensions.security_manager.find_user",
return_value=mock_user,
),
):
uri, kwargs = engine_manager._get_engine_args(
database, None, None, None, None
)
# Verify impersonate_user was called with the email prefix
database.db_engine_spec.impersonate_user.assert_called_once()
call_args = database.db_engine_spec.impersonate_user.call_args
assert call_args[0][1] == "alice.doe" # username from email prefix
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_engine_context_manager_called(
self, mock_make_url, mock_create_engine, engine_manager, mock_database
):
"""Test that the engine context manager is properly called."""
from sqlalchemy.engine.url import make_url
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
# Track context manager calls
context_manager_calls = []
def tracking_context_manager(database, catalog, schema):
from contextlib import contextmanager
@contextmanager
def inner():
context_manager_calls.append(("enter", database, catalog, schema))
yield
context_manager_calls.append(("exit", database, catalog, schema))
return inner()
engine_manager.engine_context_manager = tracking_context_manager
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
pass
assert len(context_manager_calls) == 2
assert context_manager_calls[0][0] == "enter"
assert context_manager_calls[0][1] is mock_database
assert context_manager_calls[0][2] == "catalog1"
assert context_manager_calls[0][3] == "schema1"
assert context_manager_calls[1][0] == "exit"
@patch("superset.utils.oauth2.check_for_oauth2")
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_engine_oauth2_error_handling(
self,
mock_make_url,
mock_create_engine,
mock_check_for_oauth2,
engine_manager,
mock_database,
):
"""Test that OAuth2 errors are properly propagated from get_engine."""
from contextlib import contextmanager
from sqlalchemy.engine.url import make_url
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
# Simulate OAuth2 error during engine creation
class OAuth2TestError(Exception):
pass
oauth_error = OAuth2TestError("OAuth2 required")
mock_create_engine.side_effect = oauth_error
# Make get_dbapi_mapped_exception return the original exception
mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = (
oauth_error
)
# Mock check_for_oauth2 to re-raise the exception
@contextmanager
def mock_oauth2_context(database):
try:
yield
except OAuth2TestError:
raise
mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database)
with pytest.raises(OAuth2TestError, match="OAuth2 required"):
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
pass