fix: Refactor SQL username logic (#19914)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2022-05-12 21:03:05 -07:00
committed by GitHub
parent fff9ad05d4
commit 449d08b25e
22 changed files with 388 additions and 340 deletions

View File

@@ -20,8 +20,10 @@ import textwrap
import unittest
from unittest import mock
from superset import security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException
from superset.utils.core import override_user
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
@@ -112,21 +114,22 @@ class TestDatabaseModel(SupersetTestCase):
)
def test_database_impersonate_user(self):
uri = "mysql://root@localhost"
example_user = "giuseppe"
example_user = security_manager.find_user(username="gamma")
model = Database(database_name="test_database", sqlalchemy_uri=uri)
model.impersonate_user = True
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertEqual(example_user, user_name)
with override_user(example_user):
model.impersonate_user = True
username = make_url(model.get_sqla_engine().url).username
self.assertEqual(example_user.username, username)
model.impersonate_user = False
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertNotEqual(example_user, user_name)
model.impersonate_user = False
username = make_url(model.get_sqla_engine().url).username
self.assertNotEqual(example_user.username, username)
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_presto(self, mocked_create_engine):
uri = "presto://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
extra = """
{
"metadata_params": {},
@@ -142,64 +145,66 @@ class TestDatabaseModel(SupersetTestCase):
}
"""
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://gamma@localhost"
assert str(call_args[0][0]) == "presto://logged_in_user@localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"principal_username": "gamma",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"principal_username": "logged_in_user",
}
model.impersonate_user = False
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://localhost"
assert str(call_args[0][0]) == "presto://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
uri = "trino://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
model = Database(database_name="test_database", sqlalchemy_uri=uri)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri="trino://localhost"
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://localhost"
assert call_args[1]["connect_args"] == {"user": "gamma"}
assert str(call_args[0][0]) == "trino://localhost"
model = Database(
database_name="test_database",
sqlalchemy_uri="trino://original_user:original_user_password@localhost",
)
assert call_args[1]["connect_args"] == {
"user": "logged_in_user",
}
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
uri = "trino://original_user:original_user_password@localhost"
model = Database(database_name="test_database", sqlalchemy_uri=uri)
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://original_user@localhost"
assert call_args[1]["connect_args"] == {"user": "logged_in_user"}
assert str(call_args[0][0]) == "trino://original_user@localhost"
assert call_args[1]["connect_args"] == {"user": "gamma"}
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_hive(self, mocked_create_engine):
uri = "hive://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
extra = """
{
"metadata_params": {},
@@ -215,32 +220,34 @@ class TestDatabaseModel(SupersetTestCase):
}
"""
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
assert str(call_args[0][0]) == "hive://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"configuration": {"hive.server2.proxy.user": "gamma"},
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"configuration": {"hive.server2.proxy.user": "logged_in_user"},
}
model.impersonate_user = False
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
assert str(call_args[0][0]) == "hive://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_select_star(self):
@@ -345,19 +352,6 @@ class TestDatabaseModel(SupersetTestCase):
df = main_db.get_df("USE superset; SELECT ';';", None)
self.assertEqual(df.iat[0, 0], ";")
@mock.patch("superset.models.core.Database.get_sqla_engine")
def test_username_param(self, mocked_get_sqla_engine):
main_db = get_example_database()
main_db.impersonate_user = True
test_username = "test_username_param"
if main_db.backend == "mysql":
main_db.get_df("USE superset; SELECT 1", username=test_username)
mocked_get_sqla_engine.assert_called_with(
schema=None,
user_name="test_username_param",
)
@mock.patch("superset.models.core.create_engine")
def test_get_sqla_engine(self, mocked_create_engine):
model = Database(