mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
feat(presto): add support for user impersonation (#13214)
* changes to support presto impersionation with ldap * renamed method to match 30 char limit * import spell check * added presto impersonation test * refactored impersionation code to generalize for extension * moving config_args mutation to the update_connect_args_for_impersonation * moving config_args mutation to the update_connect_args_for_impersonation * nits * refactored update_impersonation_config method name to match lint rule * reduced comment line length * black reformats Co-authored-by: rijojoseph01 <rijo.joseph@myntra.com>
This commit is contained in:
@@ -909,19 +909,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
url.username = username
|
||||
|
||||
@classmethod
|
||||
def get_configuration_for_impersonation( # pylint: disable=invalid-name
|
||||
cls, uri: str, impersonate_user: bool, username: Optional[str]
|
||||
) -> Dict[str, str]:
|
||||
def update_impersonation_config(
|
||||
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Return a configuration dictionary that can be merged with other configs
|
||||
Update a configuration dictionary
|
||||
that can set the correct properties for impersonating users
|
||||
|
||||
:param connect_args: config to be updated
|
||||
:param uri: URI
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
:param username: Effective username
|
||||
:return: Configs required for impersonation
|
||||
:return: None
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None:
|
||||
|
||||
@@ -487,26 +487,28 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
# the configuraiton dictionary. See get_configuration_for_impersonation
|
||||
|
||||
@classmethod
|
||||
def get_configuration_for_impersonation(
|
||||
cls, uri: str, impersonate_user: bool, username: Optional[str]
|
||||
) -> Dict[str, str]:
|
||||
def update_impersonation_config(
|
||||
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Return a configuration dictionary that can be merged with other configs
|
||||
Update a configuration dictionary
|
||||
that can set the correct properties for impersonating users
|
||||
:param connect_args:
|
||||
:param uri: URI string
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
:param username: Effective username
|
||||
:return: Configs required for impersonation
|
||||
:return: None
|
||||
"""
|
||||
configuration = {}
|
||||
url = make_url(uri)
|
||||
backend_name = url.get_backend_name()
|
||||
|
||||
# Must be Hive connection, enable impersonation, and set optional param
|
||||
# auth=LDAP|KERBEROS
|
||||
if backend_name == "hive" and impersonate_user and username is not None:
|
||||
# this will set hive.server2.proxy.user=$effective_username on connect_args['configuration']
|
||||
if backend_name == "hive" and username is not None:
|
||||
configuration = connect_args.get("configuration", {})
|
||||
configuration["hive.server2.proxy.user"] = username
|
||||
return configuration
|
||||
connect_args["configuration"] = configuration
|
||||
|
||||
@staticmethod
|
||||
def execute( # type: ignore
|
||||
|
||||
@@ -33,7 +33,7 @@ from sqlalchemy import Column, literal_column, types
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
|
||||
@@ -136,6 +136,28 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
|
||||
version = extra.get("version")
|
||||
return version is not None and StrictVersion(version) >= StrictVersion("0.319")
|
||||
|
||||
@classmethod
|
||||
def update_impersonation_config(
|
||||
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Update a configuration dictionary
|
||||
that can set the correct properties for impersonating users
|
||||
:param connect_args: config to be updated
|
||||
:param uri: URI string
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
:param username: Effective username
|
||||
:return: None
|
||||
"""
|
||||
url = make_url(uri)
|
||||
backend_name = url.get_backend_name()
|
||||
|
||||
# Must be Presto connection, enable impersonation, and set optional param
|
||||
# auth=LDAP|KERBEROS
|
||||
# Set principal_username=$effective_username
|
||||
if backend_name == "presto" and username is not None:
|
||||
connect_args["principal_username"] = username
|
||||
|
||||
@classmethod
|
||||
def get_table_names(
|
||||
cls, database: "Database", inspector: Inspector, schema: Optional[str]
|
||||
|
||||
@@ -325,16 +325,11 @@ class Database(
|
||||
params["poolclass"] = NullPool
|
||||
|
||||
connect_args = params.get("connect_args", {})
|
||||
configuration = connect_args.get("configuration", {})
|
||||
|
||||
# If using Hive, this will set hive.server2.proxy.user=$effective_username
|
||||
configuration.update(
|
||||
self.db_engine_spec.get_configuration_for_impersonation(
|
||||
str(sqlalchemy_url), self.impersonate_user, effective_username
|
||||
if self.impersonate_user:
|
||||
self.db_engine_spec.update_impersonation_config(
|
||||
connect_args, str(sqlalchemy_url), effective_username
|
||||
)
|
||||
)
|
||||
if configuration:
|
||||
connect_args["configuration"] = configuration
|
||||
|
||||
if connect_args:
|
||||
params["connect_args"] = connect_args
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# isort:skip_file
|
||||
import textwrap
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
|
||||
|
||||
import pandas
|
||||
@@ -110,6 +111,98 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
|
||||
self.assertNotEqual(example_user, user_name)
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_presto(self, mocked_create_engine):
|
||||
uri = "presto://localhost"
|
||||
principal_user = "logged_in_user"
|
||||
extra = """
|
||||
{
|
||||
"metadata_params": {},
|
||||
"engine_params": {
|
||||
"connect_args":{
|
||||
"protocol": "https",
|
||||
"username":"original_user",
|
||||
"password":"original_user_password"
|
||||
}
|
||||
},
|
||||
"metadata_cache_timeout": {},
|
||||
"schemas_allowed_for_csv_upload": []
|
||||
}
|
||||
"""
|
||||
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
|
||||
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://logged_in_user@localhost"
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
"principal_username": "logged_in_user",
|
||||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://localhost"
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
}
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
def test_impersonate_user_hive(self, mocked_create_engine):
|
||||
uri = "hive://localhost"
|
||||
principal_user = "logged_in_user"
|
||||
extra = """
|
||||
{
|
||||
"metadata_params": {},
|
||||
"engine_params": {
|
||||
"connect_args":{
|
||||
"protocol": "https",
|
||||
"username":"original_user",
|
||||
"password":"original_user_password"
|
||||
}
|
||||
},
|
||||
"metadata_cache_timeout": {},
|
||||
"schemas_allowed_for_csv_upload": []
|
||||
}
|
||||
"""
|
||||
|
||||
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
|
||||
|
||||
model.impersonate_user = True
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
"configuration": {"hive.server2.proxy.user": "logged_in_user"},
|
||||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model.get_sqla_engine(user_name=principal_user)
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
|
||||
assert call_args[1]["connect_args"] == {
|
||||
"protocol": "https",
|
||||
"username": "original_user",
|
||||
"password": "original_user_password",
|
||||
}
|
||||
|
||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||
def test_select_star(self):
|
||||
db = get_example_database()
|
||||
|
||||
Reference in New Issue
Block a user