fix: more DB OAuth2 fixes (#37398)

This commit is contained in:
Vitor Avila
2026-01-30 21:11:26 -03:00
committed by GitHub
parent 1ee14c5993
commit 6043e7e7e3
6 changed files with 543 additions and 76 deletions

View File

@@ -30,9 +30,11 @@ from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from requests import Session
from requests.exceptions import HTTPError
from shillelagh.adapters.api.gsheets.lib import SCOPES
from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from superset import db, security_manager
@@ -41,7 +43,9 @@ from superset.db_engine_specs.base import DatabaseCategory
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.superset_typing import OAuth2TokenResponse
from superset.utils import json
from superset.utils.oauth2 import get_oauth2_access_token
if TYPE_CHECKING:
from superset.models.core import Database
@@ -83,14 +87,16 @@ class GSheetsParametersSchema(Schema):
)
class GSheetsParametersType(TypedDict):
class GSheetsParametersType(TypedDict, total=False):
# Typed shape of the GSheets connection parameters. ``total=False`` makes
# every key below optional, so callers must not assume any of them is present.
# Service account credentials, annotated as a plain string
# (presumably serialized JSON -- TODO confirm against the schema).
service_account_info: str
# String-to-string mapping, or None (presumably sheet name -> sheet URL;
# verify against the code that iterates the catalog).
catalog: dict[str, str] | None
# OAuth2 client settings as a string-to-string mapping, or None.
oauth2_client_info: dict[str, str] | None
class GSheetsPropertiesType(TypedDict):
class GSheetsPropertiesType(TypedDict, total=False):
# Typed shape of the database properties payload. ``total=False`` makes
# every key below optional.
# Connection parameters (see GSheetsParametersType).
parameters: GSheetsParametersType
# String-to-string mapping (presumably sheet name -> sheet URL -- TODO confirm).
catalog: dict[str, str]
# JSON-encoded string; elsewhere in this file it is parsed with ``json.loads``
# and inspected for an "oauth2_client_info" key.
masked_encrypted_extra: str
class GSheetsEngineSpec(ShillelaghEngineSpec):
@@ -123,7 +129,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
# when editing the database, mask this field in `encrypted_extra`
# pylint: disable=invalid-name
encrypted_extra_sensitive_fields = {"$.service_account_info.private_key"}
encrypted_extra_sensitive_fields = {
"$.service_account_info.private_key",
"$.oauth2_client_info.secret",
}
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
SYNTAX_ERROR_REGEX: (
@@ -179,6 +188,47 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
}
return urljoin(uri, "?" + urlencode(params))
@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
    """
    Check if the exception is one that indicates OAuth2 is needed.

    In case the token was manually revoked on Google side, `google-auth` will
    try to automatically refresh credentials, but it fails since it only has the
    access token. This override catches this scenario as well.

    :param ex: The exception raised while using the connection.
    :returns: ``True`` when ``g`` is truthy and carries a ``user`` attribute,
        and the exception is either an instance of ``cls.oauth2_exception`` or
        its message contains the failed-refresh marker; ``False`` otherwise.
        The result is always a real ``bool`` (never the ``g`` object itself).
    """
    # Without a user attached to `g` the check is always negative.
    if not g or not hasattr(g, "user"):
        return False

    if isinstance(ex, cls.oauth2_exception):
        return True

    # Message seen when the automatic credential refresh fails because only
    # the access token is available.
    return "credentials do not contain the necessary fields" in str(ex)
@classmethod
def get_oauth2_fresh_token(
    cls,
    config: OAuth2ClientConfig,
    refresh_token: str,
) -> OAuth2TokenResponse:
    """
    Refresh an OAuth2 access token that has expired.

    When trying to refresh an expired token that was revoked on Google side,
    the request fails with 400 status code.

    :param config: OAuth2 client configuration, forwarded to the parent class.
    :param refresh_token: Refresh token, forwarded to the parent class.
    :returns: Whatever the parent implementation returns.
    :raises UnauthenticatedError: When the request fails with status 400 and
        the JSON body reports ``"error": "invalid_grant"``.
    :raises HTTPError: For every other HTTP failure, including a 400 whose
        body is not a JSON object; the original error is re-raised unchanged.
    """
    try:
        return super().get_oauth2_fresh_token(config, refresh_token)
    except HTTPError as ex:
        response = ex.response
        if response is None or response.status_code != 400:
            raise

        try:
            error_data = response.json()
        except ValueError:
            # The body is not valid JSON; do not let the decode error mask
            # the original HTTP failure.
            error_data = None

        if (
            isinstance(error_data, dict)
            and error_data.get("error") == "invalid_grant"
        ):
            raise UnauthenticatedError(
                error_data.get("error_description", "Token has been revoked")
            ) from ex

        # Any other 400 response: propagate the original HTTPError.
        raise
@classmethod
def impersonate_user(
cls,
@@ -198,6 +248,28 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
return url, engine_kwargs
@classmethod
def get_table_names(
    cls,
    database: Database,
    inspector: Inspector,
    schema: str | None,
) -> set[str]:
    """
    Get all sheets added to the connection.

    For OAuth2 connections, the OAuth2 dance is started when no access token
    is found for the current user, so that table names are not shown before
    authentication. The listing itself is delegated to the parent class.

    :param database: Database whose sheets are being listed.
    :param inspector: Inspector forwarded to the parent implementation.
    :param schema: Schema forwarded to the parent implementation.
    :returns: The set of table names produced by the parent implementation.
    """
    if database.is_oauth2_enabled():
        access_token = get_oauth2_access_token(
            database.get_oauth2_config(),
            database.id,
            g.user.id,
            database.db_engine_spec,
        )
        if not access_token:
            database.start_oauth2_dance()

    return super().get_table_names(database, inspector, schema)
@classmethod
def get_extra_table_metadata(
cls,
@@ -311,6 +383,14 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
conn = engine.connect()
idx = 0
# Check for OAuth2 config. Skip URL access for OAuth2 connections (user
# might not have a token, or admin adding a sheet they don't have access to)
oauth2_config_in_params = parameters.get("oauth2_client_info")
oauth2_config_in_secure_extra = json.loads(
properties.get("masked_encrypted_extra", "{}")
).get("oauth2_client_info")
is_oauth2_conn = bool(oauth2_config_in_params or oauth2_config_in_secure_extra)
for name, url in table_catalog.items():
if not name:
errors.append(
@@ -334,7 +414,11 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
)
return errors
if is_oauth2_conn:
continue
try:
url = url.replace('"', '""')
results = conn.execute(f'SELECT * FROM "{url}" LIMIT 1') # noqa: S608
results.fetchall()
except Exception: # pylint: disable=broad-except