mirror of
https://github.com/apache/superset.git
synced 2026-04-20 08:34:37 +00:00
fix: more DB OAuth2 fixes (#37398)
This commit is contained in:
@@ -30,9 +30,11 @@ from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from requests import Session
|
||||
from requests.exceptions import HTTPError
|
||||
from shillelagh.adapters.api.gsheets.lib import SCOPES
|
||||
from shillelagh.exceptions import UnauthenticatedError
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
from superset import db, security_manager
|
||||
@@ -41,7 +43,9 @@ from superset.db_engine_specs.base import DatabaseCategory
|
||||
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.superset_typing import OAuth2TokenResponse
|
||||
from superset.utils import json
|
||||
from superset.utils.oauth2 import get_oauth2_access_token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
@@ -83,14 +87,16 @@ class GSheetsParametersSchema(Schema):
|
||||
)
|
||||
|
||||
|
||||
class GSheetsParametersType(TypedDict):
|
||||
class GSheetsParametersType(TypedDict, total=False):
|
||||
service_account_info: str
|
||||
catalog: dict[str, str] | None
|
||||
oauth2_client_info: dict[str, str] | None
|
||||
|
||||
|
||||
class GSheetsPropertiesType(TypedDict):
|
||||
class GSheetsPropertiesType(TypedDict, total=False):
|
||||
parameters: GSheetsParametersType
|
||||
catalog: dict[str, str]
|
||||
masked_encrypted_extra: str
|
||||
|
||||
|
||||
class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
@@ -123,7 +129,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
|
||||
# when editing the database, mask this field in `encrypted_extra`
|
||||
# pylint: disable=invalid-name
|
||||
encrypted_extra_sensitive_fields = {"$.service_account_info.private_key"}
|
||||
encrypted_extra_sensitive_fields = {
|
||||
"$.service_account_info.private_key",
|
||||
"$.oauth2_client_info.secret",
|
||||
}
|
||||
|
||||
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
|
||||
SYNTAX_ERROR_REGEX: (
|
||||
@@ -179,6 +188,47 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
}
|
||||
return urljoin(uri, "?" + urlencode(params))
|
||||
|
||||
@classmethod
|
||||
def needs_oauth2(cls, ex: Exception) -> bool:
|
||||
"""
|
||||
Check if the exception is one that indicates OAuth2 is needed.
|
||||
|
||||
In case the token was manually revoked on Google side, `google-auth` will
|
||||
try to automatically refresh credentials, but it fails since it only has the
|
||||
access token. This override catches this scenario as well.
|
||||
"""
|
||||
return (
|
||||
g
|
||||
and hasattr(g, "user")
|
||||
and (
|
||||
isinstance(ex, cls.oauth2_exception)
|
||||
or "credentials do not contain the necessary fields" in str(ex)
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_fresh_token(
|
||||
cls,
|
||||
config: OAuth2ClientConfig,
|
||||
refresh_token: str,
|
||||
) -> OAuth2TokenResponse:
|
||||
"""
|
||||
Refresh an OAuth2 access token that has expired.
|
||||
|
||||
When trying to refresh an expired token that was revoked on Google side,
|
||||
the request fails with 400 status code.
|
||||
"""
|
||||
try:
|
||||
return super().get_oauth2_fresh_token(config, refresh_token)
|
||||
except HTTPError as ex:
|
||||
if ex.response is not None and ex.response.status_code == 400:
|
||||
error_data = ex.response.json()
|
||||
if error_data.get("error") == "invalid_grant":
|
||||
raise UnauthenticatedError(
|
||||
error_data.get("error_description", "Token has been revoked")
|
||||
) from ex
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def impersonate_user(
|
||||
cls,
|
||||
@@ -198,6 +248,28 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
|
||||
return url, engine_kwargs
|
||||
|
||||
@classmethod
|
||||
def get_table_names(
|
||||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
schema: str | None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all sheets added to the connection.
|
||||
|
||||
For OAuth2 connections, force the OAuth2 dance in case the user
|
||||
doesn't have a token yet to avoid showing table names berofe auth.
|
||||
"""
|
||||
if database.is_oauth2_enabled() and not get_oauth2_access_token(
|
||||
database.get_oauth2_config(),
|
||||
database.id,
|
||||
g.user.id,
|
||||
database.db_engine_spec,
|
||||
):
|
||||
database.start_oauth2_dance()
|
||||
return super().get_table_names(database, inspector, schema)
|
||||
|
||||
@classmethod
|
||||
def get_extra_table_metadata(
|
||||
cls,
|
||||
@@ -311,6 +383,14 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
conn = engine.connect()
|
||||
idx = 0
|
||||
|
||||
# Check for OAuth2 config. Skip URL access for OAuth2 connections (user
|
||||
# might not have a token, or admin adding a sheet they don't have access to)
|
||||
oauth2_config_in_params = parameters.get("oauth2_client_info")
|
||||
oauth2_config_in_secure_extra = json.loads(
|
||||
properties.get("masked_encrypted_extra", "{}")
|
||||
).get("oauth2_client_info")
|
||||
is_oauth2_conn = bool(oauth2_config_in_params or oauth2_config_in_secure_extra)
|
||||
|
||||
for name, url in table_catalog.items():
|
||||
if not name:
|
||||
errors.append(
|
||||
@@ -334,7 +414,11 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
)
|
||||
return errors
|
||||
|
||||
if is_oauth2_conn:
|
||||
continue
|
||||
|
||||
try:
|
||||
url = url.replace('"', '""')
|
||||
results = conn.execute(f'SELECT * FROM "{url}" LIMIT 1') # noqa: S608
|
||||
results.fetchall()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
|
||||
Reference in New Issue
Block a user