fix: move oauth2 capture to get_sqla_engine (#32137)

This commit is contained in:
Beto Dealmeida
2025-02-04 18:24:05 -05:00
committed by GitHub
parent c64018d421
commit c7c3b1b0e9
3 changed files with 123 additions and 24 deletions

View File

@@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None:
}
def test_raw_connection_oauth(mocker: MockerFixture) -> None:
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
Some databases that use OAuth2 need to trigger the flow when the connection is
created, rather than when the query runs. This happens when the SQLAlchemy engine
URI cannot be built without the user personal token.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This test verifies that the exception is captured and raised correctly so that the
frontend can trigger the OAuth2 dance.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the engine is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42
database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
@@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None:
assert str(excinfo.value) == "You don't have permission to access the data."
def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42
database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
OAuth2Error("OAuth2 required")
)
with pytest.raises(OAuth2RedirectError) as excinfo: # noqa: PT012
with database.get_raw_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
assert str(excinfo.value) == "You don't have permission to access the data."
def test_get_schema_access_for_file_upload() -> None:
"""
Test the `get_schema_access_for_file_upload` method.
@@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
)
def test_engine_oauth2(mocker: MockerFixture) -> None:
"""
Test that we handle OAuth2 when `create_engine` fails.
"""
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
start_oauth2_dance = mocker.patch.object(
database.db_engine_spec,
"start_oauth2_dance",
side_effect=OAuth2Error("OAuth2 required"),
)
with pytest.raises(OAuth2Error):
with database.get_sqla_engine("catalog", "schema"):
pass
start_oauth2_dance.assert_called_with(database)
def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.