diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 4d72be41d8a..4113a3e8fe5 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -717,9 +717,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "redirect_uri": config["redirect_uri"], "grant_type": "authorization_code", } - if config["request_content_type"] == "data": - return requests.post(uri, data=req_body, timeout=timeout).json() - return requests.post(uri, json=req_body, timeout=timeout).json() + response = ( + requests.post(uri, data=req_body, timeout=timeout) + if config["request_content_type"] == "data" + else requests.post(uri, json=req_body, timeout=timeout) + ) + response.raise_for_status() + return response.json() @classmethod def get_oauth2_fresh_token( @@ -738,9 +742,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "refresh_token": refresh_token, "grant_type": "refresh_token", } - if config["request_content_type"] == "data": - return requests.post(uri, data=req_body, timeout=timeout).json() - return requests.post(uri, json=req_body, timeout=timeout).json() + response = ( + requests.post(uri, data=req_body, timeout=timeout) + if config["request_content_type"] == "data" + else requests.post(uri, json=req_body, timeout=timeout) + ) + response.raise_for_status() + return response.json() @classmethod def get_allows_alias_in_select( diff --git a/superset/models/core.py b/superset/models/core.py index cb7bdf2d352..d13c14b65ab 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -896,9 +896,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: ) } except Exception as ex: - if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): - self.start_oauth2_dance() - + self._handle_oauth2_error(ex) raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @cache_util.memoized_func( @@ -933,9 +931,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: ) } except Exception as ex: - if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): - self.start_oauth2_dance() - + self._handle_oauth2_error(ex) raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @cache_util.memoized_func( @@ -972,9 +968,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: ) } except Exception as ex: - if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): - self.start_oauth2_dance() - + self._handle_oauth2_error(ex) raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex return set() @@ -1003,9 +997,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: with self.get_inspector(catalog=catalog) as inspector: return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: - if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): - self.start_oauth2_dance() - + self._handle_oauth2_error(ex) raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @cache_util.memoized_func( @@ -1022,9 +1014,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: with self.get_inspector() as inspector: return self.db_engine_spec.get_catalog_names(self, inspector) except Exception as ex: - if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): - self.start_oauth2_dance() - + self._handle_oauth2_error(ex) raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @property @@ -1261,6 +1251,10 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: if oauth2_client_info := encrypted_extra.get("oauth2_client_info"): schema = OAuth2ClientConfigSchema() client_config = schema.load(oauth2_client_info) + if "request_content_type" not in oauth2_client_info: + client_config["request_content_type"] = ( + self.db_engine_spec.oauth2_token_request_type + ) return cast(OAuth2ClientConfig, client_config) return self.db_engine_spec.get_oauth2_config() @@ -1275,6 +1269,16 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: """ return self.db_engine_spec.start_oauth2_dance(self) + def _handle_oauth2_error(self, ex: Exception) -> None: + """ + Handle exceptions that may require OAuth2 authentication. + + If OAuth2 is enabled and the exception indicates that OAuth2 is needed, + starts the OAuth2 dance. + """ + if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): + self.start_oauth2_dance() + def purge_oauth2_tokens(self) -> None: """ Delete all OAuth2 tokens associated with this database. diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index ebe1f4012eb..cd1a2a14d9e 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -189,7 +189,10 @@ class OAuth2ClientConfigSchema(Schema): scope = fields.String(required=True) redirect_uri = fields.String( required=False, - load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True), + load_default=lambda: app.config.get( + "DATABASE_OAUTH2_REDIRECT_URI", + url_for("DatabaseRestApi.oauth2", _external=True), + ), ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 998a1033bb0..7d7aa96ea19 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -660,6 +660,34 @@ def test_get_oauth2_config(app_context: None) -> None: assert database.get_oauth2_config() is None + database.encrypted_extra = json.dumps(oauth2_client_info) + assert database.get_oauth2_config() == { + "id": "my_client_id", + "secret": "my_client_secret", + "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", + "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", + "scope": "refresh_token session:role:USERADMIN", + "redirect_uri": "http://example.com/api/v1/database/oauth2/", + "request_content_type": "data", # Default value from BaseEngineSpec + } + + +def test_get_oauth2_config_token_request_type_from_db_engine_specs( + mocker: MockerFixture, app_context: None +) -> None: + """ + Test that DB Engine Spec overrides for ``oauth2_token_request_type`` are respected. + """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + mocker.patch.object( + database.db_engine_spec, + "oauth2_token_request_type", + "json", + ) + database.encrypted_extra = json.dumps(oauth2_client_info) assert database.get_oauth2_config() == { "id": "my_client_id", @@ -672,6 +700,59 @@ def test_get_oauth2_config(app_context: None) -> None: } +def test_get_oauth2_config_custom_token_request_type_extra(app_context: None) -> None: + """ + Test passing a custom ``token_request_type`` via ``encrypted_extra`` + takes precedence. + """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + custom_oauth2_client_info = { + "oauth2_client_info": { + **oauth2_client_info["oauth2_client_info"], + "request_content_type": "json", + } + } + + database.encrypted_extra = json.dumps(custom_oauth2_client_info) + assert database.get_oauth2_config() == { + "id": "my_client_id", + "secret": "my_client_secret", + "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", + "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", + "scope": "refresh_token session:role:USERADMIN", + "redirect_uri": "http://example.com/api/v1/database/oauth2/", + "request_content_type": "json", + } + + +def test_get_oauth2_config_redirect_uri_from_config( + mocker: MockerFixture, + app_context: None, +) -> None: + """ + Test that ``DATABASE_OAUTH2_REDIRECT_URI`` config takes precedence over + url_for default. + """ + custom_redirect_uri = "https://custom.example.com/oauth/callback" + mocker.patch.dict( + "superset.utils.oauth2.app.config", + {"DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri}, + ) + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + database.encrypted_extra = json.dumps(oauth2_client_info) + + config = database.get_oauth2_config() + + assert config is not None + assert config["redirect_uri"] == custom_redirect_uri + + def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None: """ Test that we can start OAuth2 from `raw_connection()` errors.