feat(oauth2): add support for trino (#30081)

This commit is contained in:
João Ferrão
2024-11-04 17:54:47 +01:00
committed by GitHub
parent 64f8140731
commit 305b6df6e3
7 changed files with 156 additions and 57 deletions

View File

@@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
oauth2_scope = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None
oauth2_token_request_type = "data"
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
@@ -525,6 +526,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"token_request_uri",
cls.oauth2_token_request_uri,
),
"request_content_type": db_engine_spec_config.get(
"request_content_type", cls.oauth2_token_request_type
),
}
return config
@@ -562,18 +566,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
},
timeout=timeout,
)
return response.json()
req_body = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_oauth2_fresh_token(
@@ -586,17 +588,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
timeout=timeout,
)
return response.json()
req_body = {
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_allows_alias_in_select(