mirror of
https://github.com/apache/superset.git
synced 2026-06-28 19:05:31 +00:00
Compare commits
7 Commits
chore/ci/s
...
snowflake-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33f9ccdeb2 | ||
|
|
dd082e887d | ||
|
|
97ddb27391 | ||
|
|
7a210c66b8 | ||
|
|
cd75612aeb | ||
|
|
829b7738c3 | ||
|
|
6793919262 |
@@ -498,7 +498,7 @@ const ExtraOptions = ({
|
||||
onChange={onInputChange}
|
||||
>
|
||||
{t(
|
||||
'Impersonate logged in user (Presto, Trino, Drill, Hive, and Google Sheets)',
|
||||
'Impersonate logged in user (Presto, Trino, Drill, Hive, Snowflake and Google Sheets)',
|
||||
)}
|
||||
</Checkbox>
|
||||
<InfoTooltip
|
||||
|
||||
@@ -20,21 +20,24 @@ import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from re import Pattern
|
||||
from typing import Any, Optional, TYPE_CHECKING, TypedDict
|
||||
from typing import Any, cast, Optional, TYPE_CHECKING, TypedDict
|
||||
from urllib import parse
|
||||
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from flask import current_app as app
|
||||
from flask import current_app as app, has_request_context
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import DatabaseError as SqlalchemyDatabaseError
|
||||
|
||||
from superset import is_feature_enabled, security_manager
|
||||
from superset.constants import TimeGrain
|
||||
from superset.exceptions import SupersetErrorException
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.base import (
|
||||
BaseEngineSpec,
|
||||
@@ -44,12 +47,48 @@ from superset.db_engine_specs.base import (
|
||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
OAuth2State,
|
||||
)
|
||||
from superset.utils import json
|
||||
from superset.utils.core import get_user_agent, QuerySource
|
||||
from superset.utils.oauth2 import encode_oauth2_state
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
try:
|
||||
from snowflake.connector.errors import DatabaseError
|
||||
except ImportError:
|
||||
# Use a distinct sentinel type when snowflake is not installed to avoid
|
||||
# matching unrelated exception types (using `Exception` would be too broad).
|
||||
class _SnowflakeDatabaseError(Exception):
|
||||
"""Sentinel type to stand in for snowflake.connector.errors.DatabaseError."""
|
||||
|
||||
pass
|
||||
|
||||
DatabaseError = _SnowflakeDatabaseError
|
||||
|
||||
|
||||
class CustomSnowflakeAuthErrorMeta(type):
|
||||
def __instancecheck__(cls, instance: object) -> bool:
|
||||
if not isinstance(instance, SqlalchemyDatabaseError):
|
||||
return False
|
||||
orig = cast(SqlalchemyDatabaseError, instance).orig
|
||||
return isinstance(orig, DatabaseError) and (
|
||||
"Invalid OAuth access token" in str(instance)
|
||||
or "User is empty" in str(instance)
|
||||
)
|
||||
|
||||
|
||||
class CustomSnowflakeAuthError(DatabaseError, metaclass=CustomSnowflakeAuthErrorMeta):
|
||||
pass
|
||||
|
||||
|
||||
# Snowflake error code: "This session does not have a current database."
|
||||
_MISSING_DB_CONTEXT_CODE = "090105"
|
||||
|
||||
# Regular expressions to catch custom errors
|
||||
OBJECT_DOES_NOT_EXIST_REGEX = re.compile(
|
||||
r"Object (?P<object>.*?) does not exist or not authorized."
|
||||
@@ -191,6 +230,168 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
),
|
||||
}
|
||||
|
||||
# OAuth 2.0 support
|
||||
supports_oauth2 = True
|
||||
oauth2_exception = CustomSnowflakeAuthError
|
||||
|
||||
@classmethod
|
||||
def is_oauth2_enabled(cls) -> bool:
|
||||
"""
|
||||
Return whether OAuth2 authentication is enabled.
|
||||
"""
|
||||
|
||||
# When alerts or reports connect to the database in the background,
|
||||
# OAuth2 authentication fails; therefore, OAuth2 authentication is disabled
|
||||
# for background execution.
|
||||
if not has_request_context():
|
||||
return False
|
||||
|
||||
return (
|
||||
cls.supports_oauth2
|
||||
and cls.engine_name in app.config["DATABASE_OAUTH2_CLIENTS"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_config(cls) -> OAuth2ClientConfig | None:
|
||||
"""
|
||||
Build the DB engine spec level OAuth2 client config.
|
||||
"""
|
||||
if not cls.is_oauth2_enabled():
|
||||
return None
|
||||
|
||||
return super().get_oauth2_config()
|
||||
|
||||
@classmethod
|
||||
def impersonate_user(
|
||||
cls,
|
||||
database: Database,
|
||||
username: str | None,
|
||||
user_token: str | None,
|
||||
url: URL,
|
||||
engine_kwargs: dict[str, Any],
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
"""
|
||||
Modify URL and/or engine kwargs to impersonate a different user.
|
||||
"""
|
||||
connect_args = engine_kwargs.setdefault("connect_args", {})
|
||||
|
||||
# When test_connection is executed (i.e., when validate_default_parameters is
|
||||
# set to True in connect_args), authentication via OAuth is not performed.
|
||||
if (
|
||||
not connect_args.get("validate_default_parameters", False)
|
||||
and cls.is_oauth2_enabled()
|
||||
):
|
||||
url = url.update_query_dict({"authenticator": "oauth"})
|
||||
connect_args["authenticator"] = "oauth"
|
||||
# Let OAuth token determine the role rather than forcing hardcoded one.
|
||||
url = url.difference_update_query(["role"])
|
||||
|
||||
if user_token and cls.is_oauth2_enabled():
|
||||
if username is not None:
|
||||
user = security_manager.find_user(username=username)
|
||||
if user and user.email:
|
||||
if is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
|
||||
url = url.set(username=user.email.split("@")[0])
|
||||
else:
|
||||
url = url.set(username=user.email)
|
||||
|
||||
url = url.update_query_dict({"token": user_token})
|
||||
|
||||
return url, engine_kwargs
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_authorization_uri(
|
||||
cls,
|
||||
config: OAuth2ClientConfig,
|
||||
state: OAuth2State,
|
||||
) -> str:
|
||||
"""
|
||||
Return URI for initial OAuth2 request.
|
||||
"""
|
||||
uri = config["authorization_request_uri"]
|
||||
# When calling the Snowflake OAuth authorization endpoint for a custom client,
|
||||
# specify only the query parameters documented in the URL below.
|
||||
# Adding unsupported parameters
|
||||
# (e.g., `prompt` as used in BaseEngineSpec.get_oauth2_authorization_uri)
|
||||
# will cause an error.
|
||||
# https://docs.snowflake.com/user-guide/oauth-custom#query-parameters
|
||||
params = {
|
||||
"scope": config["scope"],
|
||||
"response_type": "code",
|
||||
"state": encode_oauth2_state(state),
|
||||
"redirect_uri": config["redirect_uri"],
|
||||
"client_id": config["id"],
|
||||
}
|
||||
return parse.urljoin(uri, "?" + parse.urlencode(params))
|
||||
|
||||
@classmethod
|
||||
def _restore_session_context(cls, cursor: Any, database: "Database") -> None:
|
||||
"""
|
||||
Re-issue USE statements to restore session context lost after an OAuth
|
||||
reconnect or connection-pool recycle (Snowflake error 090105).
|
||||
"""
|
||||
url = database.url_object
|
||||
db_name, _, schema_name = (url.database or "").partition("/")
|
||||
query_params = url.query
|
||||
|
||||
def _val(key: str) -> str | None:
|
||||
v = query_params.get(key)
|
||||
if isinstance(v, (list, tuple)):
|
||||
return v[-1] if v else None
|
||||
return v
|
||||
|
||||
warehouse = _val("warehouse")
|
||||
role = _val("role")
|
||||
|
||||
def _use(stmt: str, name: str) -> None:
|
||||
name = name.replace('"', '""')
|
||||
cursor.execute(f'{stmt} "{name}"')
|
||||
|
||||
if db_name:
|
||||
_use("USE DATABASE", db_name)
|
||||
if schema_name:
|
||||
_use("USE SCHEMA", schema_name)
|
||||
if warehouse:
|
||||
_use("USE WAREHOUSE", warehouse)
|
||||
# OAuth token determines the role; only restore for non-OAuth connections.
|
||||
if role and not database.is_oauth2_enabled():
|
||||
_use("USE ROLE", role)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
cursor: Any,
|
||||
query: str,
|
||||
database: "Database",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
try:
|
||||
cursor.execute(query)
|
||||
except Exception as ex:
|
||||
if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
|
||||
cls.start_oauth2_dance(database)
|
||||
if _MISSING_DB_CONTEXT_CODE in str(ex):
|
||||
try:
|
||||
cls._restore_session_context(cursor, database)
|
||||
cursor.execute(query)
|
||||
return
|
||||
except Exception as retry_ex:
|
||||
if _MISSING_DB_CONTEXT_CODE in str(retry_ex):
|
||||
raise SupersetErrorException(
|
||||
SupersetError(
|
||||
message=(
|
||||
"Snowflake connection is missing a default "
|
||||
"database. Ensure the database is specified "
|
||||
"in the connection URI "
|
||||
"(e.g. snowflake://user@account/MY_DATABASE)."
|
||||
),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
) from retry_ex
|
||||
raise cls.get_dbapi_mapped_exception(retry_ex) from retry_ex
|
||||
raise cls.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(
|
||||
database: Database, source: QuerySource | None = None
|
||||
@@ -438,6 +639,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
database: "Database",
|
||||
params: dict[str, Any],
|
||||
) -> None:
|
||||
# To use OAuth authentication, a database connection must first be created using
|
||||
# another authenticator (typically key-pair authentication)
|
||||
# with “Impersonate logged in user” enabled.
|
||||
# Key-pair authentication is used for connection tests,
|
||||
# while OAuth authentication is used when executing actual queries,
|
||||
# such as in SQL Lab or dashboards.
|
||||
# Therefore, when using OAuth authentication, the key-pair authentication
|
||||
# settings are not loaded, and the connection is established using OAuth only.
|
||||
connect_args = params.get("connect_args") or {}
|
||||
if connect_args.get("authenticator") == "oauth":
|
||||
return
|
||||
|
||||
if not database.encrypted_extra:
|
||||
return
|
||||
try:
|
||||
|
||||
@@ -26,6 +26,7 @@ from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils import json
|
||||
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from tests.unit_tests.fixtures.common import dttm # noqa: F401
|
||||
@@ -438,3 +439,221 @@ def test_unmask_encrypted_extra() -> None:
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Snowflake OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "snowflake-oauth2-client-id",
|
||||
"secret": "snowflake-oauth2-client-secret",
|
||||
"scope": "refresh_token",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://snowflake.oauth2.example/oauth/authorize",
|
||||
"token_request_uri": "https://snowflake.oauth2.example/oauth/token-request",
|
||||
"request_content_type": "data",
|
||||
}
|
||||
|
||||
|
||||
def test_get_oauth2_token(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token`.
|
||||
"""
|
||||
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
||||
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "scope",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert SnowflakeEngineSpec.get_oauth2_token(oauth2_config, "code") == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "scope",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://snowflake.oauth2.example/oauth/token-request",
|
||||
data={
|
||||
"code": "code",
|
||||
"client_id": "snowflake-oauth2-client-id",
|
||||
"client_secret": "snowflake-oauth2-client-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
|
||||
def test_impersonate_user(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that Snowflake supports user impersonation.
|
||||
"""
|
||||
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(sqlalchemy_uri="snowflake://abc")
|
||||
|
||||
mocker.patch(
|
||||
"superset.db_engine_specs.snowflake.SnowflakeEngineSpec.is_oauth2_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
assert SnowflakeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username=None,
|
||||
user_token=None,
|
||||
url=make_url("snowflake://user:pass@account/database_name/default"),
|
||||
engine_kwargs={
|
||||
"connect_args": {
|
||||
"validate_default_parameters": True,
|
||||
},
|
||||
},
|
||||
) == (
|
||||
make_url("snowflake://user:pass@account/database_name/default"),
|
||||
{"connect_args": {"validate_default_parameters": True}},
|
||||
)
|
||||
|
||||
assert SnowflakeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username=None,
|
||||
user_token=None,
|
||||
url=make_url("snowflake://user:pass@account/database_name/default"),
|
||||
engine_kwargs={},
|
||||
) == (
|
||||
make_url(
|
||||
"snowflake://user:pass@account/database_name/default?authenticator=oauth"
|
||||
),
|
||||
{"connect_args": {"authenticator": "oauth"}},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"superset.db_engine_specs.snowflake.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"superset.security_manager.find_user",
|
||||
return_value=mocker.MagicMock(email="impersonated_user@example.com"),
|
||||
)
|
||||
assert SnowflakeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username="impersonated_user",
|
||||
user_token="test_token", # noqa: S106
|
||||
url=make_url("snowflake://user:pass@account/database_name/default"),
|
||||
engine_kwargs={},
|
||||
) == (
|
||||
make_url(
|
||||
"snowflake://impersonated_user:pass@account/database_name/default?authenticator=oauth&token=test_token"
|
||||
),
|
||||
{"connect_args": {"authenticator": "oauth"}},
|
||||
)
|
||||
|
||||
# Role in the URL should be stripped when OAuth is active so the token's role wins.
|
||||
url_with_role = make_url(
|
||||
"snowflake://user:pass@account/database_name/default?role=MY_ROLE&warehouse=MY_WH"
|
||||
)
|
||||
result_url, result_kwargs = SnowflakeEngineSpec.impersonate_user(
|
||||
database=database,
|
||||
username=None,
|
||||
user_token=None,
|
||||
url=url_with_role,
|
||||
engine_kwargs={},
|
||||
)
|
||||
assert "role" not in result_url.query
|
||||
assert result_url.query.get("warehouse") == "MY_WH"
|
||||
|
||||
|
||||
def test_restore_session_context(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that _restore_session_context re-issues the correct USE statements.
|
||||
"""
|
||||
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
cursor = mocker.MagicMock()
|
||||
database = Database(
|
||||
sqlalchemy_uri="snowflake://user:pass@account/MY_DB/MY_SCHEMA?warehouse=MY_WH&role=MY_ROLE"
|
||||
)
|
||||
mocker.patch.object(database, "is_oauth2_enabled", return_value=False)
|
||||
|
||||
SnowflakeEngineSpec._restore_session_context(cursor, database)
|
||||
|
||||
cursor.execute.assert_any_call('USE DATABASE "MY_DB"')
|
||||
cursor.execute.assert_any_call('USE SCHEMA "MY_SCHEMA"')
|
||||
cursor.execute.assert_any_call('USE WAREHOUSE "MY_WH"')
|
||||
cursor.execute.assert_any_call('USE ROLE "MY_ROLE"')
|
||||
assert cursor.execute.call_count == 4
|
||||
|
||||
|
||||
def test_restore_session_context_oauth_skips_role(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that _restore_session_context skips USE ROLE when OAuth is enabled
|
||||
so the token's role takes precedence.
|
||||
"""
|
||||
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
cursor = mocker.MagicMock()
|
||||
database = Database(
|
||||
sqlalchemy_uri="snowflake://user:pass@account/MY_DB/MY_SCHEMA?warehouse=MY_WH&role=MY_ROLE"
|
||||
)
|
||||
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
|
||||
|
||||
SnowflakeEngineSpec._restore_session_context(cursor, database)
|
||||
|
||||
executed = [call.args[0] for call in cursor.execute.call_args_list]
|
||||
assert not any("ROLE" in stmt for stmt in executed)
|
||||
assert cursor.execute.call_count == 3
|
||||
|
||||
|
||||
def test_execute_retries_on_missing_db_context(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that execute() restores session context and retries on Snowflake error 090105.
|
||||
"""
|
||||
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(sqlalchemy_uri="snowflake://user:pass@account/MY_DB")
|
||||
mocker.patch.object(database, "is_oauth2_enabled", return_value=False)
|
||||
restore = mocker.patch.object(SnowflakeEngineSpec, "_restore_session_context")
|
||||
|
||||
cursor = mocker.MagicMock()
|
||||
cursor.execute.side_effect = [Exception("090105: missing db context"), None]
|
||||
|
||||
SnowflakeEngineSpec.execute(cursor, "SELECT 1", database)
|
||||
|
||||
assert cursor.execute.call_count == 2
|
||||
restore.assert_called_once_with(cursor, database)
|
||||
|
||||
|
||||
def test_execute_raises_friendly_error_when_retry_also_fails(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that execute() raises a user-friendly SupersetErrorException when the
|
||||
retry after context restoration still hits error 090105.
|
||||
"""
|
||||
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
|
||||
from superset.exceptions import SupersetErrorException
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(sqlalchemy_uri="snowflake://user:pass@account/MY_DB")
|
||||
mocker.patch.object(database, "is_oauth2_enabled", return_value=False)
|
||||
mocker.patch.object(SnowflakeEngineSpec, "_restore_session_context")
|
||||
|
||||
cursor = mocker.MagicMock()
|
||||
cursor.execute.side_effect = Exception("090105: still missing")
|
||||
|
||||
with pytest.raises(SupersetErrorException, match="missing a default database"):
|
||||
SnowflakeEngineSpec.execute(cursor, "SELECT 1", database)
|
||||
|
||||
Reference in New Issue
Block a user