feat: allow create/update OAuth2 DB (#30071)

This commit is contained in:
Beto Dealmeida
2024-09-03 19:22:38 -04:00
committed by GitHub
parent c929f5ed7a
commit 0415ed34ce
16 changed files with 620 additions and 29 deletions

View File

@@ -25,6 +25,7 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.models.core import Database
from superset.sql_parse import Table
@@ -38,7 +39,7 @@ oauth2_client_info = {
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:SYSADMIN",
"scope": "refresh_token session:role:USERADMIN",
}
}
@@ -306,6 +307,80 @@ def test_get_all_catalog_names(mocker: MockerFixture) -> None:
get_inspector.assert_called_with(ssh_tunnel=None)
def test_get_all_schema_names_needs_oauth2(mocker: MockerFixture) -> None:
"""
Test the `get_all_schema_names` method when OAuth2 is needed.
"""
database = Database(
database_name="db",
sqlalchemy_uri="snowflake://:@abcd1234.snowflakecomputing.com/db",
encrypted_extra=json.dumps(oauth2_client_info),
)
class DriverSpecificError(Exception):
"""
A custom exception that is raised by the Snowflake driver.
"""
mocker.patch.object(
database.db_engine_spec,
"oauth2_exception",
DriverSpecificError,
)
mocker.patch.object(
database.db_engine_spec,
"get_schema_names",
side_effect=DriverSpecificError("User needs to authenticate"),
)
mocker.patch.object(database, "get_inspector")
user = mocker.MagicMock()
user.id = 42
mocker.patch("superset.db_engine_specs.base.g", user=user)
with pytest.raises(OAuth2RedirectError) as excinfo:
database.get_all_schema_names()
assert excinfo.value.message == "You don't have permission to access the data."
assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
def test_get_all_catalog_names_needs_oauth2(mocker: MockerFixture) -> None:
"""
Test the `get_all_catalog_names` method when OAuth2 is needed.
"""
database = Database(
database_name="db",
sqlalchemy_uri="snowflake://:@abcd1234.snowflakecomputing.com/db",
encrypted_extra=json.dumps(oauth2_client_info),
)
class DriverSpecificError(Exception):
"""
A custom exception that is raised by the Snowflake driver.
"""
mocker.patch.object(
database.db_engine_spec,
"oauth2_exception",
DriverSpecificError,
)
mocker.patch.object(
database.db_engine_spec,
"get_catalog_names",
side_effect=DriverSpecificError("User needs to authenticate"),
)
mocker.patch.object(database, "get_inspector")
user = mocker.MagicMock()
user.id = 42
mocker.patch("superset.db_engine_specs.base.g", user=user)
with pytest.raises(OAuth2RedirectError) as excinfo:
database.get_all_catalog_names()
assert excinfo.value.message == "You don't have permission to access the data."
assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
def test_get_sqla_engine(mocker: MockerFixture) -> None:
"""
Test `_get_sqla_engine`.
@@ -425,7 +500,7 @@ def test_get_oauth2_config(app_context: None) -> None:
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:SYSADMIN",
"scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
}