mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
fix: move oauth2 capture to get_sqla_engine (#32137)
This commit is contained in:
@@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_raw_connection_oauth(mocker: MockerFixture) -> None:
|
||||
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||
|
||||
Some databases that use OAuth2 need to trigger the flow when the connection is
|
||||
created, rather than when the query runs. This happens when the SQLAlchemy engine
|
||||
URI cannot be built without the user personal token.
|
||||
With OAuth2, some databases will raise an exception when the engine is first created
|
||||
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
|
||||
finally, GSheets will raise an exception when the query is executed.
|
||||
|
||||
This test verifies that the exception is captured and raised correctly so that the
|
||||
frontend can trigger the OAuth2 dance.
|
||||
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
|
||||
triggered when the engine is created.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user = mocker.MagicMock()
|
||||
g.user.id = 42
|
||||
|
||||
database = Database(
|
||||
id=1,
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
encrypted_extra=json.dumps(oauth2_client_info),
|
||||
)
|
||||
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
|
||||
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
|
||||
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
|
||||
|
||||
with pytest.raises(OAuth2RedirectError) as excinfo:
|
||||
with database.get_raw_connection() as conn:
|
||||
conn.cursor()
|
||||
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||
|
||||
|
||||
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||
|
||||
With OAuth2, some databases will raise an exception when the engine is first created
|
||||
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
|
||||
finally, GSheets will raise an exception when the query is executed.
|
||||
|
||||
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
|
||||
triggered when the connection is created.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user = mocker.MagicMock()
|
||||
@@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None:
|
||||
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||
|
||||
|
||||
def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||
|
||||
With OAuth2, some databases will raise an exception when the engine is first created
|
||||
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
|
||||
finally, GSheets will raise an exception when the query is executed.
|
||||
|
||||
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
|
||||
triggered when the connection is created.
|
||||
"""
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user = mocker.MagicMock()
|
||||
g.user.id = 42
|
||||
|
||||
database = Database(
|
||||
id=1,
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
encrypted_extra=json.dumps(oauth2_client_info),
|
||||
)
|
||||
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
|
||||
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
|
||||
get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
|
||||
OAuth2Error("OAuth2 required")
|
||||
)
|
||||
|
||||
with pytest.raises(OAuth2RedirectError) as excinfo: # noqa: PT012
|
||||
with database.get_raw_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||
|
||||
|
||||
def test_get_schema_access_for_file_upload() -> None:
|
||||
"""
|
||||
Test the `get_schema_access_for_file_upload` method.
|
||||
@@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_engine_oauth2(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we handle OAuth2 when `create_engine` fails.
|
||||
"""
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
|
||||
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
|
||||
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
|
||||
start_oauth2_dance = mocker.patch.object(
|
||||
database.db_engine_spec,
|
||||
"start_oauth2_dance",
|
||||
side_effect=OAuth2Error("OAuth2 required"),
|
||||
)
|
||||
|
||||
with pytest.raises(OAuth2Error):
|
||||
with database.get_sqla_engine("catalog", "schema"):
|
||||
pass
|
||||
|
||||
start_oauth2_dance.assert_called_with(database)
|
||||
|
||||
|
||||
def test_purge_oauth2_tokens(session: Session) -> None:
|
||||
"""
|
||||
Test the `purge_oauth2_tokens` method.
|
||||
|
||||
Reference in New Issue
Block a user