Compare commits

...

7 Commits

Author SHA1 Message Date
Beto Dealmeida
33f9ccdeb2 fix(Snowflake): fix OAuth2 session 2026-06-25 19:38:14 -04:00
Evan Rusackas
dd082e887d Merge branch 'master' into snowflake-oauth 2026-06-21 22:06:08 -07:00
htamakos
97ddb27391 chore: run ruff format 2026-01-09 08:48:59 +09:00
htamakos
7a210c66b8 feat(snowflake): guard get_oauth2_config with is_oauth2_enabled 2025-12-31 00:49:26 +09:00
htamakos
cd75612aeb fix(snowflake): disable OAuth2 in background without request context 2025-12-30 23:04:41 +09:00
htamakos
829b7738c3 feat(snowflake): Update 'Impersonate logged in user' checkbox label to include snowflake 2025-12-30 00:47:26 +09:00
htamakos
6793919262 feat(snowflake): Add support for OAuth 2.0 authentication 2025-12-30 00:39:30 +09:00
3 changed files with 435 additions and 3 deletions

View File

@@ -498,7 +498,7 @@ const ExtraOptions = ({
onChange={onInputChange}
>
{t(
'Impersonate logged in user (Presto, Trino, Drill, Hive, and Google Sheets)',
'Impersonate logged in user (Presto, Trino, Drill, Hive, Snowflake and Google Sheets)',
)}
</Checkbox>
<InfoTooltip

View File

@@ -20,21 +20,24 @@ import logging
import re
from datetime import datetime
from re import Pattern
from typing import Any, Optional, TYPE_CHECKING, TypedDict
from typing import Any, cast, Optional, TYPE_CHECKING, TypedDict
from urllib import parse
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from flask import current_app as app
from flask import current_app as app, has_request_context
from flask_babel import gettext as __
from marshmallow import fields, Schema
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import DatabaseError as SqlalchemyDatabaseError
from superset import is_feature_enabled, security_manager
from superset.constants import TimeGrain
from superset.exceptions import SupersetErrorException
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import (
BaseEngineSpec,
@@ -44,12 +47,48 @@ from superset.db_engine_specs.base import (
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
)
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource
from superset.utils.oauth2 import encode_oauth2_state
if TYPE_CHECKING:
from superset.models.core import Database
try:
    from snowflake.connector.errors import DatabaseError
except ImportError:
    # The Snowflake connector is not installed. Fall back to a dedicated
    # placeholder class rather than plain `Exception`, so that isinstance
    # checks against `DatabaseError` never match unrelated exception types.
    class _SnowflakeDatabaseError(Exception):
        """Placeholder for snowflake.connector.errors.DatabaseError."""

    DatabaseError = _SnowflakeDatabaseError
class CustomSnowflakeAuthErrorMeta(type):
    """
    Metaclass that customizes `isinstance` checks for OAuth failures.

    An object is considered an instance when it is a SQLAlchemy
    `DatabaseError` whose wrapped original error is a Snowflake
    `DatabaseError` and whose message contains one of the known
    authentication failure markers.
    """

    def __instancecheck__(cls, instance: object) -> bool:
        # Only SQLAlchemy-wrapped database errors can qualify.
        if not isinstance(instance, SqlalchemyDatabaseError):
            return False
        # The underlying driver error must come from Snowflake.
        if not isinstance(instance.orig, DatabaseError):
            return False
        # Match on message text; evaluation stops at the first marker found.
        return any(
            marker in str(instance)
            for marker in ("Invalid OAuth access token", "User is empty")
        )
class CustomSnowflakeAuthError(DatabaseError, metaclass=CustomSnowflakeAuthErrorMeta):
    """
    Marker exception type for Snowflake OAuth authentication failures.

    The class has no behavior of its own; `isinstance` checks against it are
    answered by `CustomSnowflakeAuthErrorMeta.__instancecheck__`.
    """
# Snowflake error code: "This session does not have a current database."
# The code is looked up as a substring of the stringified exception raised by
# the cursor to detect a session that has lost its database context.
_MISSING_DB_CONTEXT_CODE = "090105"
# Regular expressions to catch custom errors
OBJECT_DOES_NOT_EXIST_REGEX = re.compile(
r"Object (?P<object>.*?) does not exist or not authorized."
@@ -191,6 +230,168 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
),
}
# OAuth 2.0 support
# Capability flag: `is_oauth2_enabled` only returns True when this is set.
supports_oauth2 = True
# Exception type identifying an OAuth2 authentication failure for Snowflake
# (presumably consulted by `needs_oauth2` in the base spec — TODO confirm).
oauth2_exception = CustomSnowflakeAuthError
@classmethod
def is_oauth2_enabled(cls) -> bool:
    """
    Tell whether OAuth2 authentication should be used for this engine.

    OAuth2 is reported as enabled only inside a Flask request context, when
    the spec supports OAuth2 and its engine name has an entry in the
    `DATABASE_OAUTH2_CLIENTS` application config.
    """
    if has_request_context():
        return (
            cls.supports_oauth2
            and cls.engine_name in app.config["DATABASE_OAUTH2_CLIENTS"]
        )

    # Background execution (e.g. alerts and reports) has no request context;
    # OAuth2 authentication fails there, so it is switched off.
    return False
@classmethod
def get_oauth2_config(cls) -> OAuth2ClientConfig | None:
    """
    Return the engine-spec level OAuth2 client config.

    The parent implementation is consulted only when `is_oauth2_enabled`
    reports True; otherwise `None` is returned.
    """
    return super().get_oauth2_config() if cls.is_oauth2_enabled() else None
@classmethod
def impersonate_user(
    cls,
    database: Database,
    username: str | None,
    user_token: str | None,
    url: URL,
    engine_kwargs: dict[str, Any],
) -> tuple[URL, dict[str, Any]]:
    """
    Adjust the URL and engine kwargs so the connection uses the user's OAuth2 token.

    :param database: The database being connected to (not read here).
    :param username: Superset username to impersonate, if any.
    :param user_token: OAuth2 access token of the user, if any.
    :param url: SQLAlchemy URL to derive the new URL from.
    :param engine_kwargs: Engine kwargs; `connect_args` is created in place
        when missing and may be updated in place.
    :returns: The (possibly modified) URL and the same `engine_kwargs` dict.
    """
    connect_args = engine_kwargs.setdefault("connect_args", {})

    # `validate_default_parameters` is set in connect_args when
    # test_connection runs; OAuth authentication is not applied in that case.
    is_connection_test = connect_args.get("validate_default_parameters", False)
    if not is_connection_test and cls.is_oauth2_enabled():
        connect_args["authenticator"] = "oauth"
        # The role is dropped from the URL so the OAuth token decides it
        # instead of a hardcoded value.
        url = url.update_query_dict(
            {"authenticator": "oauth"}
        ).difference_update_query(["role"])

    if not (user_token and cls.is_oauth2_enabled()):
        return url, engine_kwargs

    user = (
        security_manager.find_user(username=username)
        if username is not None
        else None
    )
    if user and user.email:
        # Log in as the user's email, or only its local part when the
        # IMPERSONATE_WITH_EMAIL_PREFIX feature flag is on.
        login = (
            user.email.split("@")[0]
            if is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX")
            else user.email
        )
        url = url.set(username=login)

    return url.update_query_dict({"token": user_token}), engine_kwargs
@classmethod
def get_oauth2_authorization_uri(
    cls,
    config: OAuth2ClientConfig,
    state: OAuth2State,
) -> str:
    """
    Build the URI that starts the OAuth2 authorization flow.

    :param config: OAuth2 client config; `authorization_request_uri`,
        `scope`, `redirect_uri` and `id` are read from it.
    :param state: OAuth2 state, encoded into the `state` query parameter.
    :returns: The authorization endpoint with the query string attached.
    """
    base_uri = config["authorization_request_uri"]

    # The Snowflake OAuth authorization endpoint for a custom client only
    # accepts the query parameters documented at the URL below. Unsupported
    # ones (e.g. `prompt`, as used in
    # BaseEngineSpec.get_oauth2_authorization_uri) will cause an error.
    # https://docs.snowflake.com/user-guide/oauth-custom#query-parameters
    query_string = parse.urlencode(
        {
            "scope": config["scope"],
            "response_type": "code",
            "state": encode_oauth2_state(state),
            "redirect_uri": config["redirect_uri"],
            "client_id": config["id"],
        }
    )
    return parse.urljoin(base_uri, "?" + query_string)
@classmethod
def _restore_session_context(cls, cursor: Any, database: "Database") -> None:
    """
    Re-issue USE statements built from the database URL on the given cursor.

    Intended to restore session context lost after an OAuth reconnect or a
    connection-pool recycle (Snowflake error 090105).

    :param cursor: DB-API cursor the USE statements are executed on.
    :param database: Database whose `url_object` supplies the database and
        schema (URL path as `db/schema`) and the `warehouse` / `role` query
        parameters.
    """
    url = database.url_object
    db_name, _, schema_name = (url.database or "").partition("/")
    params = url.query

    def _last_value(key: str) -> str | None:
        # A query parameter may hold a sequence; the last entry is used.
        raw = params.get(key)
        if isinstance(raw, (list, tuple)):
            return raw[-1] if raw else None
        return raw

    def _quoted(identifier: str) -> str:
        # Embedded double quotes are escaped by doubling them.
        # NOTE(review): double-quoted identifiers are case-sensitive in
        # Snowflake; confirm names in the URI match the stored case.
        escaped = identifier.replace('"', '""')
        return f'"{escaped}"'

    warehouse = _last_value("warehouse")
    role = _last_value("role")

    targets = (
        ("USE DATABASE", db_name),
        ("USE SCHEMA", schema_name),
        ("USE WAREHOUSE", warehouse),
    )
    for statement, identifier in targets:
        if identifier:
            cursor.execute(f"{statement} {_quoted(identifier)}")

    # OAuth token determines the role; only restore for non-OAuth connections.
    if role and not database.is_oauth2_enabled():
        cursor.execute(f"USE ROLE {_quoted(role)}")
@classmethod
def execute(
    cls,
    cursor: Any,
    query: str,
    database: "Database",
    **kwargs: Any,
) -> None:
    """
    Run `query` on `cursor`, retrying once when the session lost its context.

    On failure the OAuth2 dance is started if the database has OAuth2 enabled
    and the error calls for it. If the error text contains the Snowflake
    "missing current database" code, the session context is restored and the
    query is executed a second time.

    NOTE(review): this override never calls the parent `execute`; confirm
    that no pre-execution handling from the base spec is required here.

    :param cursor: DB-API cursor used to run the query.
    :param query: SQL text to execute.
    :param database: Database the query runs against.
    :param kwargs: Accepted for interface compatibility; not used.
    :raises SupersetErrorException: When the retry fails again with the
        missing-database-context code.
    :raises Exception: The mapped DB-API exception for any other failure.
    """
    try:
        cursor.execute(query)
    except Exception as ex:
        if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
            cls.start_oauth2_dance(database)

        # Anything other than a lost session context is re-raised, mapped.
        if _MISSING_DB_CONTEXT_CODE not in str(ex):
            raise cls.get_dbapi_mapped_exception(ex) from ex

        try:
            cls._restore_session_context(cursor, database)
            cursor.execute(query)
        except Exception as retry_ex:
            if _MISSING_DB_CONTEXT_CODE not in str(retry_ex):
                raise cls.get_dbapi_mapped_exception(retry_ex) from retry_ex

            # Still no database context after restoring: surface a message
            # that tells the user how to fix the connection URI.
            friendly_error = SupersetError(
                message=(
                    "Snowflake connection is missing a default "
                    "database. Ensure the database is specified "
                    "in the connection URI "
                    "(e.g. snowflake://user@account/MY_DATABASE)."
                ),
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                level=ErrorLevel.ERROR,
            )
            raise SupersetErrorException(friendly_error) from retry_ex
@staticmethod
def get_extra_params(
database: Database, source: QuerySource | None = None
@@ -438,6 +639,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
database: "Database",
params: dict[str, Any],
) -> None:
# To use OAuth authentication, a database connection must first be created using
# another authenticator (typically key-pair authentication)
# with “Impersonate logged in user” enabled.
# Key-pair authentication is used for connection tests,
# while OAuth authentication is used when executing actual queries,
# such as in SQL Lab or dashboards.
# Therefore, when using OAuth authentication, the key-pair authentication
# settings are not loaded, and the connection is established using OAuth only.
connect_args = params.get("connect_args") or {}
if connect_args.get("authenticator") == "oauth":
return
if not database.encrypted_extra:
return
try:

View File

@@ -26,6 +26,7 @@ from pytest_mock import MockerFixture
from sqlalchemy.engine.url import make_url
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.superset_typing import OAuth2ClientConfig
from superset.utils import json
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@@ -438,3 +439,221 @@ def test_unmask_encrypted_extra() -> None:
},
}
)
@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
    """
    Provide a Snowflake OAuth2 client config for the tests.
    """
    config: OAuth2ClientConfig = {
        "id": "snowflake-oauth2-client-id",
        "secret": "snowflake-oauth2-client-secret",
        "scope": "refresh_token",
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "authorization_request_uri": "https://snowflake.oauth2.example/oauth/authorize",
        "token_request_uri": "https://snowflake.oauth2.example/oauth/token-request",
        "request_content_type": "data",
    }
    return config
def test_get_oauth2_token(
    mocker: MockerFixture,
    oauth2_config: OAuth2ClientConfig,
) -> None:
    """
    Check that `get_oauth2_token` posts the code and returns the token payload.
    """
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec

    expected_token = {
        "access_token": "access-token",
        "expires_in": 3600,
        "scope": "scope",
        "token_type": "Bearer",
        "refresh_token": "refresh-token",
    }

    requests = mocker.patch("superset.db_engine_specs.base.requests")
    # The mocked response gets its own copy so the comparison below is not
    # made against the very same object.
    requests.post().json.return_value = dict(expected_token)

    token = SnowflakeEngineSpec.get_oauth2_token(oauth2_config, "code")
    assert token == expected_token

    requests.post.assert_called_with(
        "https://snowflake.oauth2.example/oauth/token-request",
        data={
            "code": "code",
            "client_id": "snowflake-oauth2-client-id",
            "client_secret": "snowflake-oauth2-client-secret",
            "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
            "grant_type": "authorization_code",
        },
        timeout=30.0,
    )
def test_impersonate_user(mocker: MockerFixture) -> None:
    """
    Check how `impersonate_user` rewrites the URL and engine kwargs for OAuth2.
    """
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.core import Database

    database = Database(sqlalchemy_uri="snowflake://abc")

    mocker.patch(
        "superset.db_engine_specs.snowflake.SnowflakeEngineSpec.is_oauth2_enabled",
        return_value=True,
    )

    # Connection test (validate_default_parameters=True): nothing is changed.
    connection_test_result = SnowflakeEngineSpec.impersonate_user(
        database=database,
        username=None,
        user_token=None,
        url=make_url("snowflake://user:pass@account/database_name/default"),
        engine_kwargs={
            "connect_args": {
                "validate_default_parameters": True,
            },
        },
    )
    assert connection_test_result == (
        make_url("snowflake://user:pass@account/database_name/default"),
        {"connect_args": {"validate_default_parameters": True}},
    )

    # Regular connection without a token: only the authenticator is set.
    no_token_result = SnowflakeEngineSpec.impersonate_user(
        database=database,
        username=None,
        user_token=None,
        url=make_url("snowflake://user:pass@account/database_name/default"),
        engine_kwargs={},
    )
    assert no_token_result == (
        make_url(
            "snowflake://user:pass@account/database_name/default?authenticator=oauth"
        ),
        {"connect_args": {"authenticator": "oauth"}},
    )

    mocker.patch(
        "superset.db_engine_specs.snowflake.is_feature_enabled",
        return_value=True,
    )
    mocker.patch(
        "superset.security_manager.find_user",
        return_value=mocker.MagicMock(email="impersonated_user@example.com"),
    )

    # With a token and a known user: username becomes the email prefix and
    # the token is appended to the query string.
    with_token_result = SnowflakeEngineSpec.impersonate_user(
        database=database,
        username="impersonated_user",
        user_token="test_token",  # noqa: S106
        url=make_url("snowflake://user:pass@account/database_name/default"),
        engine_kwargs={},
    )
    assert with_token_result == (
        make_url(
            "snowflake://impersonated_user:pass@account/database_name/default?authenticator=oauth&token=test_token"
        ),
        {"connect_args": {"authenticator": "oauth"}},
    )

    # Role in the URL should be stripped when OAuth is active so the token's
    # role wins; other query parameters are kept.
    role_stripped_url, _ = SnowflakeEngineSpec.impersonate_user(
        database=database,
        username=None,
        user_token=None,
        url=make_url(
            "snowflake://user:pass@account/database_name/default?role=MY_ROLE&warehouse=MY_WH"
        ),
        engine_kwargs={},
    )
    assert "role" not in role_stripped_url.query
    assert role_stripped_url.query.get("warehouse") == "MY_WH"
def test_restore_session_context(mocker: MockerFixture) -> None:
    """
    Check that _restore_session_context issues a USE statement for the
    database, schema, warehouse and role taken from the URI.
    """
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.core import Database

    database = Database(
        sqlalchemy_uri="snowflake://user:pass@account/MY_DB/MY_SCHEMA?warehouse=MY_WH&role=MY_ROLE"
    )
    mocker.patch.object(database, "is_oauth2_enabled", return_value=False)
    cursor = mocker.MagicMock()

    SnowflakeEngineSpec._restore_session_context(cursor, database)

    expected_statements = (
        'USE DATABASE "MY_DB"',
        'USE SCHEMA "MY_SCHEMA"',
        'USE WAREHOUSE "MY_WH"',
        'USE ROLE "MY_ROLE"',
    )
    for statement in expected_statements:
        cursor.execute.assert_any_call(statement)
    assert cursor.execute.call_count == 4
def test_restore_session_context_oauth_skips_role(mocker: MockerFixture) -> None:
    """
    Check that no USE ROLE statement is issued when OAuth is enabled, so the
    role carried by the token takes precedence.
    """
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.core import Database

    database = Database(
        sqlalchemy_uri="snowflake://user:pass@account/MY_DB/MY_SCHEMA?warehouse=MY_WH&role=MY_ROLE"
    )
    mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
    cursor = mocker.MagicMock()

    SnowflakeEngineSpec._restore_session_context(cursor, database)

    statements = [invocation.args[0] for invocation in cursor.execute.call_args_list]
    assert all("ROLE" not in statement for statement in statements)
    assert cursor.execute.call_count == 3
def test_execute_retries_on_missing_db_context(mocker: MockerFixture) -> None:
    """
    Check that execute() restores the session context and runs the query a
    second time after Snowflake error 090105.
    """
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.models.core import Database

    database = Database(sqlalchemy_uri="snowflake://user:pass@account/MY_DB")
    mocker.patch.object(database, "is_oauth2_enabled", return_value=False)
    restore_mock = mocker.patch.object(
        SnowflakeEngineSpec, "_restore_session_context"
    )

    # First attempt fails with the missing-context code, the retry succeeds.
    first_failure = Exception("090105: missing db context")
    cursor = mocker.MagicMock()
    cursor.execute.side_effect = [first_failure, None]

    SnowflakeEngineSpec.execute(cursor, "SELECT 1", database)

    assert cursor.execute.call_count == 2
    restore_mock.assert_called_once_with(cursor, database)
def test_execute_raises_friendly_error_when_retry_also_fails(
    mocker: MockerFixture,
) -> None:
    """
    Check that execute() raises a user-friendly SupersetErrorException when
    the retry after context restoration hits error 090105 again.
    """
    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
    from superset.exceptions import SupersetErrorException
    from superset.models.core import Database

    database = Database(sqlalchemy_uri="snowflake://user:pass@account/MY_DB")
    mocker.patch.object(database, "is_oauth2_enabled", return_value=False)
    mocker.patch.object(SnowflakeEngineSpec, "_restore_session_context")

    # Every execute call fails with the missing-context code.
    cursor = mocker.MagicMock()
    cursor.execute.side_effect = Exception("090105: still missing")

    expected_message = "missing a default database"
    with pytest.raises(SupersetErrorException, match=expected_message):
        SnowflakeEngineSpec.execute(cursor, "SELECT 1", database)