Compare commits

...

1 Commits

Author SHA1 Message Date
Beto Dealmeida
1c0429a0de feat(AWS IAM): phase 1 2026-01-22 19:06:46 -05:00
6 changed files with 1017 additions and 0 deletions

View File

@@ -108,6 +108,8 @@ services:
extra_hosts:
- "host.docker.internal:host-gateway"
user: *superset-user
ports:
- "${SUPERSET_PORT:-8088}:8088"
depends_on:
superset-init-light:
condition: service_completed_successfully

View File

@@ -42,3 +42,16 @@ class AuroraPostgresDataAPI(PostgresEngineSpec):
"secret_arn={secret_arn}&"
"region_name={region_name}"
)
class AuroraPostgresEngineSpec(PostgresEngineSpec):
"""
Aurora PostgreSQL engine spec.
IAM authentication is handled by the parent PostgresEngineSpec via
the aws_iam config in encrypted_extra.
"""
engine = "postgresql"
engine_name = "Aurora PostgreSQL"
default_driver = "psycopg2"

View File

@@ -0,0 +1,343 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
AWS IAM Authentication Mixin for database engine specs.
This mixin provides cross-account IAM authentication support for AWS databases
(Aurora PostgreSQL, Aurora MySQL, Redshift). It handles:
- Assuming IAM roles via STS AssumeRole
- Generating RDS IAM auth tokens
- Configuring SSL (required for IAM auth)
"""
from __future__ import annotations
import logging
from typing import Any, TYPE_CHECKING, TypedDict
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
if TYPE_CHECKING:
from superset.models.core import Database
logger = logging.getLogger(__name__)
# Default session duration for STS AssumeRole (1 hour)
DEFAULT_SESSION_DURATION = 3600
# Default port for PostgreSQL
DEFAULT_POSTGRES_PORT = 5432
class AWSIAMConfig(TypedDict, total=False):
"""Configuration for AWS IAM authentication."""
enabled: bool
role_arn: str
external_id: str
region: str
db_username: str
session_duration: int
class AWSIAMAuthMixin:
"""
Mixin that provides AWS IAM authentication for database connections.
This mixin can be used with database engine specs that support IAM
authentication (Aurora PostgreSQL, Aurora MySQL, Redshift).
Configuration is provided via the database's encrypted_extra JSON:
{
"aws_iam": {
"enabled": true,
"role_arn": "arn:aws:iam::222222222222:role/SupersetDatabaseAccess",
"external_id": "superset-prod-12345", # optional
"region": "us-east-1",
"db_username": "superset_iam_user",
"session_duration": 3600 # optional, defaults to 3600
}
}
"""
supports_iam_authentication = True
# AWS error patterns for actionable error messages
aws_iam_custom_errors: dict[str, tuple[SupersetErrorType, str]] = {
"AccessDenied": (
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
"Unable to assume IAM role. Verify the role ARN and trust policy "
"allow access from Superset's IAM role.",
),
"InvalidIdentityToken": (
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
"Invalid IAM credentials. Ensure Superset has a valid IAM role "
"with permissions to assume the target role.",
),
"MalformedPolicyDocument": (
SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
"Invalid IAM role ARN format. Please verify the role ARN.",
),
"ExpiredTokenException": (
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
"AWS credentials have expired. Please refresh the connection.",
),
}
@classmethod
def get_iam_credentials(
cls,
role_arn: str,
region: str,
external_id: str | None = None,
session_duration: int = DEFAULT_SESSION_DURATION,
) -> dict[str, Any]:
"""
Assume cross-account IAM role via STS AssumeRole.
:param role_arn: The ARN of the IAM role to assume
:param region: AWS region for the STS client
:param external_id: External ID for the role assumption (optional)
:param session_duration: Duration of the session in seconds
:returns: Dictionary with AccessKeyId, SecretAccessKey, SessionToken
:raises SupersetSecurityException: If role assumption fails
"""
try:
# Lazy import to avoid errors when boto3 is not installed
import boto3
from botocore.exceptions import ClientError
except ImportError as ex:
raise SupersetSecurityException(
SupersetError(
message="boto3 is required for AWS IAM authentication. "
"Install it with: pip install boto3",
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
)
) from ex
try:
sts_client = boto3.client("sts", region_name=region)
assume_role_kwargs: dict[str, Any] = {
"RoleArn": role_arn,
"RoleSessionName": "superset-iam-session",
"DurationSeconds": session_duration,
}
if external_id:
assume_role_kwargs["ExternalId"] = external_id
response = sts_client.assume_role(**assume_role_kwargs)
return response["Credentials"]
except ClientError as ex:
error_code = ex.response.get("Error", {}).get("Code", "")
error_message = ex.response.get("Error", {}).get("Message", "")
# Handle ExternalId mismatch (shows as AccessDenied with specific message)
# Check this first before generic AccessDenied handling
if "external id" in error_message.lower():
raise SupersetSecurityException(
SupersetError(
message="External ID mismatch. Verify the external_id "
"configuration matches the trust policy.",
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
level=ErrorLevel.ERROR,
)
) from ex
if error_code in cls.aws_iam_custom_errors:
error_type, message = cls.aws_iam_custom_errors[error_code]
raise SupersetSecurityException(
SupersetError(
message=message,
error_type=error_type,
level=ErrorLevel.ERROR,
)
) from ex
raise SupersetSecurityException(
SupersetError(
message=f"Failed to assume IAM role: {ex}",
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
level=ErrorLevel.ERROR,
)
) from ex
@classmethod
def generate_rds_auth_token(
cls,
credentials: dict[str, Any],
hostname: str,
port: int,
username: str,
region: str,
) -> str:
"""
Generate RDS IAM auth token using temporary credentials.
:param credentials: STS credentials from assume_role
:param hostname: RDS/Aurora endpoint hostname
:param port: Database port
:param username: Database username configured for IAM auth
:param region: AWS region
:returns: IAM auth token to use as database password
:raises SupersetSecurityException: If token generation fails
"""
try:
import boto3
from botocore.exceptions import ClientError
except ImportError as ex:
raise SupersetSecurityException(
SupersetError(
message="boto3 is required for AWS IAM authentication.",
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
)
) from ex
try:
rds_client = boto3.client(
"rds",
region_name=region,
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
token = rds_client.generate_db_auth_token(
DBHostname=hostname,
Port=port,
DBUsername=username,
)
return token
except ClientError as ex:
raise SupersetSecurityException(
SupersetError(
message=f"Failed to generate RDS auth token: {ex}",
error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
level=ErrorLevel.ERROR,
)
) from ex
@classmethod
def _apply_iam_authentication(
cls,
database: Database,
params: dict[str, Any],
iam_config: AWSIAMConfig,
) -> None:
"""
Apply IAM authentication to the connection parameters.
Full flow: assume role -> generate token -> update connect_args -> enable SSL.
:param database: Database model instance
:param params: Engine parameters dict to modify
:param iam_config: IAM configuration from encrypted_extra
:raises SupersetSecurityException: If any step fails
"""
# Extract configuration
role_arn = iam_config.get("role_arn")
region = iam_config.get("region")
db_username = iam_config.get("db_username")
external_id = iam_config.get("external_id")
session_duration = iam_config.get("session_duration", DEFAULT_SESSION_DURATION)
# Validate required fields
missing_fields = []
if not role_arn:
missing_fields.append("role_arn")
if not region:
missing_fields.append("region")
if not db_username:
missing_fields.append("db_username")
if missing_fields:
raise SupersetSecurityException(
SupersetError(
message="AWS IAM configuration missing required fields: "
f"{', '.join(missing_fields)}",
error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
level=ErrorLevel.ERROR,
)
)
# Type assertions after validation (mypy doesn't narrow types from list check)
assert role_arn is not None
assert region is not None
assert db_username is not None
# Get hostname and port from the database URI
uri = make_url_safe(database.sqlalchemy_uri_decrypted)
hostname = uri.host
port = uri.port or DEFAULT_POSTGRES_PORT
if not hostname:
raise SupersetSecurityException(
SupersetError(
message=(
"Database URI must include a hostname for IAM authentication"
),
error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
level=ErrorLevel.ERROR,
)
)
logger.debug(
"Applying IAM authentication for %s:%d as user %s",
hostname,
port,
db_username,
)
# Step 1: Assume the IAM role
credentials = cls.get_iam_credentials(
role_arn=role_arn,
region=region,
external_id=external_id,
session_duration=session_duration,
)
# Step 2: Generate the RDS auth token
token = cls.generate_rds_auth_token(
credentials=credentials,
hostname=hostname,
port=port,
username=db_username,
region=region,
)
# Step 3: Update connection parameters
connect_args = params.setdefault("connect_args", {})
# Set the IAM token as the password
connect_args["password"] = token
# Override username if different from URI
connect_args["user"] = db_username
# Step 4: Enable SSL (required for IAM authentication)
# sslmode=require ensures encrypted connection without cert verification
# For production, consider sslmode=verify-full with RDS CA bundle
connect_args["sslmode"] = "require"
logger.debug("IAM authentication configured successfully")

View File

@@ -218,6 +218,12 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
max_column_name_length = 63
try_remove_schema_from_table_name = False # pylint: disable=invalid-name
# Sensitive fields that should be masked in encrypted_extra
encrypted_extra_sensitive_fields = {
"$.aws_iam.external_id",
"$.aws_iam.role_arn",
}
column_type_mappings = (
(
re.compile(r"^double precision", re.IGNORECASE),
@@ -320,6 +326,37 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
return uri, connect_args
@staticmethod
def update_params_from_encrypted_extra(
database: Database,
params: dict[str, Any],
) -> None:
"""
Extract sensitive parameters from encrypted_extra.
Handles AWS IAM authentication if configured, then merges any
remaining encrypted_extra keys into params (standard behavior).
"""
if not database.encrypted_extra:
return
try:
encrypted_extra = json.loads(database.encrypted_extra)
except json.JSONDecodeError as ex:
logger.error(ex, exc_info=True)
raise
# Handle AWS IAM auth: pop the key so it doesn't reach create_engine()
iam_config = encrypted_extra.pop("aws_iam", None)
if iam_config and iam_config.get("enabled"):
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
AWSIAMAuthMixin._apply_iam_authentication(database, params, iam_config)
# Standard behavior: merge remaining keys into params
if encrypted_extra:
params.update(encrypted_extra)
@classmethod
def get_default_catalog(cls, database: Database) -> str:
"""

View File

@@ -0,0 +1,239 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from superset.utils import json
def test_aurora_postgres_engine_spec_properties() -> None:
from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec
assert AuroraPostgresEngineSpec.engine == "postgresql"
assert AuroraPostgresEngineSpec.engine_name == "Aurora PostgreSQL"
assert AuroraPostgresEngineSpec.default_driver == "psycopg2"
def test_update_params_from_encrypted_extra_without_iam() -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec
database = MagicMock()
database.encrypted_extra = json.dumps({})
database.sqlalchemy_uri_decrypted = (
"postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb"
)
params: dict[str, Any] = {}
PostgresEngineSpec.update_params_from_encrypted_extra(database, params)
# No modifications should be made
assert params == {}
def test_update_params_from_encrypted_extra_iam_disabled() -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec
database = MagicMock()
database.encrypted_extra = json.dumps(
{
"aws_iam": {
"enabled": False,
"role_arn": "arn:aws:iam::123456789012:role/TestRole",
"region": "us-east-1",
"db_username": "superset_user",
}
}
)
database.sqlalchemy_uri_decrypted = (
"postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb"
)
params: dict[str, Any] = {}
PostgresEngineSpec.update_params_from_encrypted_extra(database, params)
# No modifications should be made when IAM is disabled
assert params == {}
def test_update_params_from_encrypted_extra_with_iam() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
from superset.db_engine_specs.postgres import PostgresEngineSpec
database = MagicMock()
database.encrypted_extra = json.dumps(
{
"aws_iam": {
"enabled": True,
"role_arn": "arn:aws:iam::123456789012:role/TestRole",
"region": "us-east-1",
"db_username": "superset_iam_user",
}
}
)
database.sqlalchemy_uri_decrypted = (
"postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb"
)
params: dict[str, Any] = {}
with (
patch.object(
AWSIAMAuthMixin,
"get_iam_credentials",
return_value={
"AccessKeyId": "ASIA...",
"SecretAccessKey": "secret...",
"SessionToken": "token...",
},
),
patch.object(
AWSIAMAuthMixin,
"generate_rds_auth_token",
return_value="iam-auth-token",
),
):
PostgresEngineSpec.update_params_from_encrypted_extra(database, params)
assert "connect_args" in params
assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105
assert params["connect_args"]["user"] == "superset_iam_user"
assert params["connect_args"]["sslmode"] == "require"
def test_update_params_merges_remaining_encrypted_extra() -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec
database = MagicMock()
database.encrypted_extra = json.dumps(
{
"aws_iam": {"enabled": False},
"pool_size": 10,
}
)
database.sqlalchemy_uri_decrypted = (
"postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb"
)
params: dict[str, Any] = {}
PostgresEngineSpec.update_params_from_encrypted_extra(database, params)
# aws_iam should be consumed, pool_size should be merged
assert "aws_iam" not in params
assert params["pool_size"] == 10
def test_update_params_from_encrypted_extra_no_encrypted_extra() -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec
database = MagicMock()
database.encrypted_extra = None
params: dict[str, Any] = {}
PostgresEngineSpec.update_params_from_encrypted_extra(database, params)
# No modifications should be made
assert params == {}
def test_update_params_from_encrypted_extra_invalid_json() -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec
database = MagicMock()
database.encrypted_extra = "not-valid-json"
params: dict[str, Any] = {}
with pytest.raises(json.JSONDecodeError):
PostgresEngineSpec.update_params_from_encrypted_extra(database, params)
def test_encrypted_extra_sensitive_fields() -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec
# Verify sensitive fields are properly defined
assert (
"$.aws_iam.external_id" in PostgresEngineSpec.encrypted_extra_sensitive_fields
)
assert "$.aws_iam.role_arn" in PostgresEngineSpec.encrypted_extra_sensitive_fields
def test_mask_encrypted_extra() -> None:
from superset.db_engine_specs.postgres import PostgresEngineSpec
encrypted_extra = json.dumps(
{
"aws_iam": {
"enabled": True,
"role_arn": "arn:aws:iam::123456789012:role/SecretRole",
"external_id": "secret-external-id-12345",
"region": "us-east-1",
"db_username": "superset_user",
}
}
)
masked = PostgresEngineSpec.mask_encrypted_extra(encrypted_extra)
assert masked is not None
masked_config = json.loads(masked)
# role_arn and external_id should be masked
assert (
masked_config["aws_iam"]["role_arn"]
!= "arn:aws:iam::123456789012:role/SecretRole"
)
assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345"
# Non-sensitive fields should remain unchanged
assert masked_config["aws_iam"]["enabled"] is True
assert masked_config["aws_iam"]["region"] == "us-east-1"
assert masked_config["aws_iam"]["db_username"] == "superset_user"
def test_aurora_postgres_inherits_from_postgres() -> None:
from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
# Verify inheritance
assert issubclass(AuroraPostgresEngineSpec, PostgresEngineSpec)
# Verify it inherits PostgreSQL capabilities
assert AuroraPostgresEngineSpec.supports_dynamic_schema is True
assert AuroraPostgresEngineSpec.supports_catalog is True
def test_aurora_data_api_classes_unchanged() -> None:
from superset.db_engine_specs.aurora import (
AuroraMySQLDataAPI,
AuroraPostgresDataAPI,
)
# Verify Data API classes are still available and unchanged
assert AuroraMySQLDataAPI.engine == "mysql"
assert AuroraMySQLDataAPI.default_driver == "auroradataapi"
assert AuroraMySQLDataAPI.engine_name == "Aurora MySQL (Data API)"
assert AuroraPostgresDataAPI.engine == "postgresql"
assert AuroraPostgresDataAPI.default_driver == "auroradataapi"
assert AuroraPostgresDataAPI.engine_name == "Aurora PostgreSQL (Data API)"

View File

@@ -0,0 +1,383 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, protected-access
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from superset.exceptions import SupersetSecurityException
def test_get_iam_credentials_success() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
mock_credentials = {
"AccessKeyId": "ASIA...",
"SecretAccessKey": "secret...",
"SessionToken": "token...",
"Expiration": "2025-01-01T00:00:00Z",
}
with patch("boto3.client") as mock_boto3_client:
mock_sts = MagicMock()
mock_sts.assume_role.return_value = {"Credentials": mock_credentials}
mock_boto3_client.return_value = mock_sts
credentials = AWSIAMAuthMixin.get_iam_credentials(
role_arn="arn:aws:iam::123456789012:role/TestRole",
region="us-east-1",
)
assert credentials == mock_credentials
mock_boto3_client.assert_called_once_with("sts", region_name="us-east-1")
mock_sts.assume_role.assert_called_once_with(
RoleArn="arn:aws:iam::123456789012:role/TestRole",
RoleSessionName="superset-iam-session",
DurationSeconds=3600,
)
def test_get_iam_credentials_with_external_id() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
mock_credentials = {
"AccessKeyId": "ASIA...",
"SecretAccessKey": "secret...",
"SessionToken": "token...",
}
with patch("boto3.client") as mock_boto3_client:
mock_sts = MagicMock()
mock_sts.assume_role.return_value = {"Credentials": mock_credentials}
mock_boto3_client.return_value = mock_sts
credentials = AWSIAMAuthMixin.get_iam_credentials(
role_arn="arn:aws:iam::123456789012:role/TestRole",
region="us-west-2",
external_id="external-id-12345",
session_duration=900,
)
assert credentials == mock_credentials
mock_sts.assume_role.assert_called_once_with(
RoleArn="arn:aws:iam::123456789012:role/TestRole",
RoleSessionName="superset-iam-session",
DurationSeconds=900,
ExternalId="external-id-12345",
)
def test_get_iam_credentials_access_denied() -> None:
from botocore.exceptions import ClientError
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
with patch("boto3.client") as mock_boto3_client:
mock_sts = MagicMock()
mock_sts.assume_role.side_effect = ClientError(
{"Error": {"Code": "AccessDenied", "Message": "Access Denied"}},
"AssumeRole",
)
mock_boto3_client.return_value = mock_sts
with pytest.raises(SupersetSecurityException) as exc_info:
AWSIAMAuthMixin.get_iam_credentials(
role_arn="arn:aws:iam::123456789012:role/TestRole",
region="us-east-1",
)
assert "Unable to assume IAM role" in str(exc_info.value)
def test_get_iam_credentials_external_id_mismatch() -> None:
from botocore.exceptions import ClientError
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
with patch("boto3.client") as mock_boto3_client:
mock_sts = MagicMock()
mock_sts.assume_role.side_effect = ClientError(
{
"Error": {
"Code": "AccessDenied",
"Message": "The external id does not match",
}
},
"AssumeRole",
)
mock_boto3_client.return_value = mock_sts
with pytest.raises(SupersetSecurityException) as exc_info:
AWSIAMAuthMixin.get_iam_credentials(
role_arn="arn:aws:iam::123456789012:role/TestRole",
region="us-east-1",
external_id="wrong-id",
)
assert "External ID mismatch" in str(exc_info.value)
def test_generate_rds_auth_token() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
credentials = {
"AccessKeyId": "ASIA...",
"SecretAccessKey": "secret...",
"SessionToken": "token...",
}
with patch("boto3.client") as mock_boto3_client:
mock_rds = MagicMock()
mock_rds.generate_db_auth_token.return_value = "iam-token-12345"
mock_boto3_client.return_value = mock_rds
token = AWSIAMAuthMixin.generate_rds_auth_token(
credentials=credentials,
hostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com",
port=5432,
username="superset_user",
region="us-east-1",
)
assert token == "iam-token-12345" # noqa: S105
mock_boto3_client.assert_called_once_with(
"rds",
region_name="us-east-1",
aws_access_key_id="ASIA...",
aws_secret_access_key="secret...", # noqa: S106
aws_session_token="token...", # noqa: S106
)
mock_rds.generate_db_auth_token.assert_called_once_with(
DBHostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com",
Port=5432,
DBUsername="superset_user",
)
def test_apply_iam_authentication() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig
mock_database = MagicMock()
mock_database.sqlalchemy_uri_decrypted = (
"postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb"
)
iam_config: AWSIAMConfig = {
"enabled": True,
"role_arn": "arn:aws:iam::123456789012:role/TestRole",
"region": "us-east-1",
"db_username": "superset_iam_user",
}
params: dict[str, Any] = {}
with (
patch.object(
AWSIAMAuthMixin,
"get_iam_credentials",
return_value={
"AccessKeyId": "ASIA...",
"SecretAccessKey": "secret...",
"SessionToken": "token...",
},
) as mock_get_creds,
patch.object(
AWSIAMAuthMixin,
"generate_rds_auth_token",
return_value="iam-auth-token",
) as mock_gen_token,
):
AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config)
mock_get_creds.assert_called_once_with(
role_arn="arn:aws:iam::123456789012:role/TestRole",
region="us-east-1",
external_id=None,
session_duration=3600,
)
mock_gen_token.assert_called_once()
token_call_kwargs = mock_gen_token.call_args[1]
assert (
token_call_kwargs["hostname"] == "mydb.cluster-xyz.us-east-1.rds.amazonaws.com"
)
assert token_call_kwargs["port"] == 5432
assert token_call_kwargs["username"] == "superset_iam_user"
assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105
assert params["connect_args"]["user"] == "superset_iam_user"
assert params["connect_args"]["sslmode"] == "require"
def test_apply_iam_authentication_with_external_id() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig
mock_database = MagicMock()
mock_database.sqlalchemy_uri_decrypted = (
"postgresql://user@mydb.us-west-2.rds.amazonaws.com:5432/mydb"
)
iam_config: AWSIAMConfig = {
"enabled": True,
"role_arn": "arn:aws:iam::222222222222:role/CrossAccountRole",
"external_id": "superset-prod-12345",
"region": "us-west-2",
"db_username": "iam_user",
"session_duration": 1800,
}
params: dict[str, Any] = {}
with (
patch.object(
AWSIAMAuthMixin,
"get_iam_credentials",
return_value={
"AccessKeyId": "ASIA...",
"SecretAccessKey": "secret...",
"SessionToken": "token...",
},
) as mock_get_creds,
patch.object(
AWSIAMAuthMixin,
"generate_rds_auth_token",
return_value="iam-auth-token",
),
):
AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config)
mock_get_creds.assert_called_once_with(
role_arn="arn:aws:iam::222222222222:role/CrossAccountRole",
region="us-west-2",
external_id="superset-prod-12345",
session_duration=1800,
)
def test_apply_iam_authentication_missing_role_arn() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig
mock_database = MagicMock()
mock_database.sqlalchemy_uri_decrypted = (
"postgresql://user@mydb.us-east-1.rds.amazonaws.com:5432/mydb"
)
iam_config: AWSIAMConfig = {
"enabled": True,
"region": "us-east-1",
"db_username": "superset_iam_user",
}
params: dict[str, Any] = {}
with pytest.raises(SupersetSecurityException) as exc_info:
AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config)
assert "role_arn" in str(exc_info.value)
def test_apply_iam_authentication_missing_db_username() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig
mock_database = MagicMock()
mock_database.sqlalchemy_uri_decrypted = (
"postgresql://user@mydb.us-east-1.rds.amazonaws.com:5432/mydb"
)
iam_config: AWSIAMConfig = {
"enabled": True,
"role_arn": "arn:aws:iam::123456789012:role/TestRole",
"region": "us-east-1",
}
params: dict[str, Any] = {}
with pytest.raises(SupersetSecurityException) as exc_info:
AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config)
assert "db_username" in str(exc_info.value)
def test_apply_iam_authentication_default_port() -> None:
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig
mock_database = MagicMock()
# URI without explicit port
mock_database.sqlalchemy_uri_decrypted = (
"postgresql://user@mydb.us-east-1.rds.amazonaws.com/mydb"
)
iam_config: AWSIAMConfig = {
"enabled": True,
"role_arn": "arn:aws:iam::123456789012:role/TestRole",
"region": "us-east-1",
"db_username": "superset_iam_user",
}
params: dict[str, Any] = {}
with (
patch.object(
AWSIAMAuthMixin,
"get_iam_credentials",
return_value={
"AccessKeyId": "ASIA...",
"SecretAccessKey": "secret...",
"SessionToken": "token...",
},
),
patch.object(
AWSIAMAuthMixin,
"generate_rds_auth_token",
return_value="iam-auth-token",
) as mock_gen_token,
):
AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config)
# Should use default port 5432
token_call_kwargs = mock_gen_token.call_args[1]
assert token_call_kwargs["port"] == 5432
def test_get_iam_credentials_boto3_not_installed() -> None:
import sys
from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin
# Temporarily hide boto3
boto3_module = sys.modules.get("boto3")
sys.modules["boto3"] = None # type: ignore
try:
with pytest.raises(SupersetSecurityException) as exc_info:
AWSIAMAuthMixin.get_iam_credentials(
role_arn="arn:aws:iam::123456789012:role/TestRole",
region="us-east-1",
)
assert "boto3 is required" in str(exc_info.value)
finally:
# Restore boto3
if boto3_module is not None:
sys.modules["boto3"] = boto3_module
else:
del sys.modules["boto3"]