fix: Refactor SQL username logic (#19914)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2022-05-12 21:03:05 -07:00
committed by GitHub
parent fff9ad05d4
commit 449d08b25e
22 changed files with 388 additions and 340 deletions

View File

@@ -21,6 +21,8 @@ import unittest
from unittest import mock
import pytest
from flask import g
from flask.ctx import AppContext
from sqlalchemy import inspect
from tests.integration_tests.fixtures.birth_names_dashboard import (
@@ -41,6 +43,7 @@ from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.utils.core import get_username, override_user
from superset.utils.database import get_example_database
from .base_tests import SupersetTestCase
@@ -86,7 +89,7 @@ DB_ACCESS_ROLE = "db_access_role"
SCHEMA_ACCESS_ROLE = "schema_access_role"
def create_access_request(session, ds_type, ds_name, role_name, user_name):
def create_access_request(session, ds_type, ds_name, role_name, username):
ds_class = ConnectorRegistry.sources[ds_type]
# TODO: generalize datasource names
if ds_type == "table":
@@ -102,7 +105,7 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name):
access_request = DatasourceAccessRequest(
datasource_id=ds.id,
datasource_type=ds_type,
created_by_fk=security_manager.find_user(username=user_name).id,
created_by_fk=security_manager.find_user(username=username).id,
)
session.add(access_request)
session.commit()
@@ -565,5 +568,46 @@ class TestRequestAccess(SupersetTestCase):
session.commit()
@pytest.mark.parametrize(
"username",
[
None,
"gamma",
],
)
def test_get_username(app_context: AppContext, username: str) -> None:
assert not hasattr(g, "user")
assert get_username() is None
g.user = security_manager.find_user(username)
assert get_username() == username
@pytest.mark.parametrize(
"username",
[
None,
"gamma",
],
)
def test_override_user(app_context: AppContext, username: str) -> None:
admin = security_manager.find_user(username="admin")
user = security_manager.find_user(username)
assert not hasattr(g, "user")
with override_user(user):
assert g.user == user
assert not hasattr(g, "user")
g.user = admin
with override_user(user):
assert g.user == user
assert g.user == admin
if __name__ == "__main__":
unittest.main()

View File

@@ -329,7 +329,7 @@ class SupersetTestCase(TestCase):
self,
sql,
client_id=None,
user_name=None,
username=None,
raise_on_error=False,
query_limit=None,
database_name="examples",
@@ -340,9 +340,9 @@ class SupersetTestCase(TestCase):
ctas_method=CtasMethod.TABLE,
template_params="{}",
):
if user_name:
if username:
self.logout()
self.login(username=(user_name or "admin"))
self.login(username=username)
dbid = SupersetTestCase.get_database_by_name(database_name).id
json_payload = {
"database_id": dbid,
@@ -427,14 +427,14 @@ class SupersetTestCase(TestCase):
self,
sql,
client_id=None,
user_name=None,
username=None,
raise_on_error=False,
database_name="examples",
template_params=None,
):
if user_name:
if username:
self.logout()
self.login(username=(user_name if user_name else "admin"))
self.login(username=username)
dbid = SupersetTestCase.get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/validate_sql_json/",

View File

@@ -1064,7 +1064,7 @@ class TestCore(SupersetTestCase):
LIMIT 10;
""",
client_id="client_id_1",
user_name="admin",
username="admin",
)
count_ds = []
count_name = []
@@ -1454,7 +1454,7 @@ class TestCore(SupersetTestCase):
self.run_sql(
"SELECT name FROM birth_names",
"client_id_1",
user_name=username,
username=username,
raise_on_error=True,
sql_editor_id=str(tab_state_id),
)
@@ -1462,7 +1462,7 @@ class TestCore(SupersetTestCase):
self.run_sql(
"SELECT name FROM birth_names",
"client_id_2",
user_name=username,
username=username,
raise_on_error=True,
)

View File

@@ -20,8 +20,10 @@ import textwrap
import unittest
from unittest import mock
from superset import security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import SupersetException
from superset.utils.core import override_user
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
@@ -112,21 +114,22 @@ class TestDatabaseModel(SupersetTestCase):
)
def test_database_impersonate_user(self):
uri = "mysql://root@localhost"
example_user = "giuseppe"
example_user = security_manager.find_user(username="gamma")
model = Database(database_name="test_database", sqlalchemy_uri=uri)
model.impersonate_user = True
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertEqual(example_user, user_name)
with override_user(example_user):
model.impersonate_user = True
username = make_url(model.get_sqla_engine().url).username
self.assertEqual(example_user.username, username)
model.impersonate_user = False
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertNotEqual(example_user, user_name)
model.impersonate_user = False
username = make_url(model.get_sqla_engine().url).username
self.assertNotEqual(example_user.username, username)
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_presto(self, mocked_create_engine):
uri = "presto://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
extra = """
{
"metadata_params": {},
@@ -142,64 +145,66 @@ class TestDatabaseModel(SupersetTestCase):
}
"""
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://gamma@localhost"
assert str(call_args[0][0]) == "presto://logged_in_user@localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"principal_username": "gamma",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"principal_username": "logged_in_user",
}
model.impersonate_user = False
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://localhost"
assert str(call_args[0][0]) == "presto://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
uri = "trino://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
model = Database(database_name="test_database", sqlalchemy_uri=uri)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri="trino://localhost"
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://localhost"
assert call_args[1]["connect_args"] == {"user": "gamma"}
assert str(call_args[0][0]) == "trino://localhost"
model = Database(
database_name="test_database",
sqlalchemy_uri="trino://original_user:original_user_password@localhost",
)
assert call_args[1]["connect_args"] == {
"user": "logged_in_user",
}
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
uri = "trino://original_user:original_user_password@localhost"
model = Database(database_name="test_database", sqlalchemy_uri=uri)
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://original_user@localhost"
assert call_args[1]["connect_args"] == {"user": "logged_in_user"}
assert str(call_args[0][0]) == "trino://original_user@localhost"
assert call_args[1]["connect_args"] == {"user": "gamma"}
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_hive(self, mocked_create_engine):
uri = "hive://localhost"
principal_user = "logged_in_user"
principal_user = security_manager.find_user(username="gamma")
extra = """
{
"metadata_params": {},
@@ -215,32 +220,34 @@ class TestDatabaseModel(SupersetTestCase):
}
"""
model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
with override_user(principal_user):
model = Database(
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
assert str(call_args[0][0]) == "hive://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"configuration": {"hive.server2.proxy.user": "gamma"},
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"configuration": {"hive.server2.proxy.user": "logged_in_user"},
}
model.impersonate_user = False
model.get_sqla_engine()
call_args = mocked_create_engine.call_args
model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
assert str(call_args[0][0]) == "hive://localhost"
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_select_star(self):
@@ -345,19 +352,6 @@ class TestDatabaseModel(SupersetTestCase):
df = main_db.get_df("USE superset; SELECT ';';", None)
self.assertEqual(df.iat[0, 0], ";")
@mock.patch("superset.models.core.Database.get_sqla_engine")
def test_username_param(self, mocked_get_sqla_engine):
main_db = get_example_database()
main_db.impersonate_user = True
test_username = "test_username_param"
if main_db.backend == "mysql":
main_db.get_df("USE superset; SELECT 1", username=test_username)
mocked_get_sqla_engine.assert_called_with(
schema=None,
user_name="test_username_param",
)
@mock.patch("superset.models.core.create_engine")
def test_get_sqla_engine(self, mocked_create_engine):
model = Database(

View File

@@ -187,7 +187,7 @@ class TestPrestoValidator(SupersetTestCase):
"message": "your query isn't how I like it",
}
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_success(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"
@@ -197,7 +197,7 @@ class TestPrestoValidator(SupersetTestCase):
self.assertEqual([], errors)
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_db_error(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"
@@ -209,7 +209,7 @@ class TestPrestoValidator(SupersetTestCase):
with self.assertRaises(PrestoSQLValidationError):
self.validator.validate(sql, schema, self.database)
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_unexpected_error(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"
@@ -221,7 +221,7 @@ class TestPrestoValidator(SupersetTestCase):
with self.assertRaises(Exception):
self.validator.validate(sql, schema, self.database)
@patch("superset.sql_validators.presto_db.g")
@patch("superset.utils.core.g")
def test_validator_query_error(self, flask_g):
flask_g.user.username = "nobody"
sql = "SELECT 1 FROM default.notarealtable"

View File

@@ -68,9 +68,9 @@ class TestSqlLab(SupersetTestCase):
def run_some_queries(self):
db.session.query(Query).delete()
db.session.commit()
self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin")
self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin")
self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab")
self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
self.run_sql(QUERY_2, client_id="client_id_3", username="admin")
self.run_sql(QUERY_3, client_id="client_id_2", username="gamma_sqllab")
self.logout()
def tearDown(self):
@@ -162,7 +162,7 @@ class TestSqlLab(SupersetTestCase):
db.session.commit()
with freeze_time(datetime.now().isoformat(timespec="seconds")):
self.run_sql(sql_statement, "1")
self.run_sql(sql_statement, "1", username="admin")
saved_query_ = (
db.session.query(SavedQuery)
.filter(
@@ -248,7 +248,7 @@ class TestSqlLab(SupersetTestCase):
# Gamma user, with sqllab and db permission
self.create_user_with_roles("Gagarin", ["ExampleDBAccess", "Gamma", "sql_lab"])
data = self.run_sql(QUERY_1, "1", user_name="Gagarin")
data = self.run_sql(QUERY_1, "1", username="Gagarin")
db.session.query(Query).delete()
db.session.commit()
self.assertLess(0, len(data["data"]))
@@ -278,14 +278,14 @@ class TestSqlLab(SupersetTestCase):
)
data = self.run_sql(
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", user_name="SchemaUser"
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser"
)
self.assertEqual(1, len(data["data"]))
data = self.run_sql(
f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table",
"4",
user_name="SchemaUser",
username="SchemaUser",
schema=CTAS_SCHEMA_NAME,
)
self.assertEqual(1, len(data["data"]))
@@ -295,7 +295,7 @@ class TestSqlLab(SupersetTestCase):
data = self.run_sql(
"SELECT * FROM test_table",
"5",
user_name="SchemaUser",
username="SchemaUser",
schema=CTAS_SCHEMA_NAME,
)
self.assertEqual(1, len(data["data"]))
@@ -441,7 +441,7 @@ class TestSqlLab(SupersetTestCase):
self.run_sql(
"SELECT name as col, gender as col FROM birth_names LIMIT 10",
client_id="2e2df3",
user_name="admin",
username="admin",
raise_on_error=True,
)
@@ -747,7 +747,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@@ -758,7 +757,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@@ -767,7 +765,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@@ -804,7 +801,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@@ -858,7 +854,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@@ -869,7 +864,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SET @value = 42",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@@ -878,7 +872,6 @@ class TestSqlLab(SupersetTestCase):
mock.call(
"SELECT @value AS foo",
mock_query,
"admin",
mock_session,
mock_cursor,
None,
@@ -895,7 +888,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,
@@ -929,7 +921,6 @@ class TestSqlLab(SupersetTestCase):
rendered_query=sql,
return_results=True,
store_results=False,
user_name="admin",
session=mock_session,
start_time=None,
expand_data=False,