Update existing tests

This commit is contained in:
Beto Dealmeida
2025-12-03 20:00:29 -05:00
parent be31abeb7e
commit de8c250f86
6 changed files with 244 additions and 186 deletions

View File

@@ -19,7 +19,6 @@
from datetime import datetime
import pytest
from flask import current_app
from pytest_mock import MockerFixture
from sqlalchemy import (
Column,
@@ -29,7 +28,6 @@ from sqlalchemy import (
Table as SqlalchemyTable,
)
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import Select
@@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2(
assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
def test_get_sqla_engine(mocker: MockerFixture) -> None:
"""
Test `_get_sqla_engine`.
"""
from superset.models.core import Database
user = mocker.MagicMock()
user.email = "alice.doe@example.org"
mocker.patch(
"superset.models.core.security_manager.find_user",
return_value=user,
)
mocker.patch("superset.models.core.get_username", return_value="alice")
create_engine = mocker.patch("superset.models.core.create_engine")
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
database._get_sqla_engine(nullpool=False)
create_engine.assert_called_with(
make_url("trino:///"),
connect_args={"source": "Apache Superset"},
)
def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
"""
Test user impersonation in `_get_sqla_engine`.
"""
from superset.models.core import Database
user = mocker.MagicMock()
user.email = "alice.doe@example.org"
mocker.patch(
"superset.models.core.security_manager.find_user",
return_value=user,
)
mocker.patch("superset.models.core.get_username", return_value="alice")
create_engine = mocker.patch("superset.models.core.create_engine")
database = Database(
database_name="my_db",
sqlalchemy_uri="trino://",
impersonate_user=True,
)
database._get_sqla_engine(nullpool=False)
create_engine.assert_called_with(
make_url("trino:///"),
connect_args={"user": "alice", "source": "Apache Superset"},
)
def test_add_database_to_signature():
args = ["param1", "param2"]
@@ -604,36 +548,6 @@ def test_add_database_to_signature():
assert args3 == ["param1", "param2", database]
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
"""
Test user impersonation in `_get_sqla_engine` with `username_from_email`.
"""
from superset.models.core import Database
user = mocker.MagicMock()
user.email = "alice.doe@example.org"
mocker.patch(
"superset.models.core.security_manager.find_user",
return_value=user,
)
mocker.patch("superset.models.core.get_username", return_value="alice")
create_engine = mocker.patch("superset.models.core.create_engine")
database = Database(
database_name="my_db",
sqlalchemy_uri="trino://",
impersonate_user=True,
)
database._get_sqla_engine(nullpool=False)
create_engine.assert_called_with(
make_url("trino:///"),
connect_args={"user": "alice.doe", "source": "Apache Superset"},
)
def test_is_oauth2_enabled() -> None:
"""
Test the `is_oauth2_enabled` method.
@@ -753,37 +667,6 @@ def test_get_oauth2_config_redirect_uri_from_config(
assert config["redirect_uri"] == custom_redirect_uri
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the engine is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42
database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
@@ -879,56 +762,6 @@ def test_get_schema_access_for_file_upload() -> None:
assert database.get_schema_access_for_file_upload() == {"public"}
def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None:
"""
Test the engine context manager.
"""
from unittest.mock import MagicMock
engine_context_manager = MagicMock()
mocker.patch.dict(
current_app.config,
{"ENGINE_CONTEXT_MANAGER": engine_context_manager},
)
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
with database.get_sqla_engine("catalog", "schema"):
pass
engine_context_manager.assert_called_once_with(database, "catalog", "schema")
engine_context_manager().__enter__.assert_called_once()
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
_get_sqla_engine.assert_called_once_with(
catalog="catalog",
schema="schema",
nullpool=True,
source=None,
sqlalchemy_uri="trino://",
)
def test_engine_oauth2(mocker: MockerFixture) -> None:
"""
Test that we handle OAuth2 when `create_engine` fails.
"""
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
start_oauth2_dance = mocker.patch.object(
database.db_engine_spec,
"start_oauth2_dance",
side_effect=OAuth2Error("OAuth2 required"),
)
with pytest.raises(OAuth2Error):
with database.get_sqla_engine("catalog", "schema"):
pass
start_oauth2_dance.assert_called_with(database)
def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.