feat(revert): Re-introduces the RLS page (#23777)

This commit is contained in:
Michael S. Molina
2023-04-24 13:10:58 -03:00
committed by GitHub
parent c536d92ade
commit f7810b6020
22 changed files with 2642 additions and 257 deletions

View File

@@ -21,13 +21,18 @@ from unittest import mock
import pytest
from flask import g
import json
import prison
from superset import db, security_manager
from superset import db, security_manager, app
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.security.guest_token import (
GuestTokenResourceType,
GuestUser,
)
from flask_babel import lazy_gettext as _
from flask_appbuilder.models.sqla import filters
from ..conftest import with_config
from ..base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -38,6 +43,7 @@ from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_data,
)
from tests.integration_tests.fixtures.unicode_dashboard import (
UNICODE_TBL_NAME,
load_unicode_dashboard_with_slice,
load_unicode_data,
)
@@ -173,19 +179,18 @@ class TestRowLevelSecurity(SupersetTestCase):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
"/api/v1/rowlevelsecurity/",
json={
"name": "rls1",
"description": "Some description",
"filter_type": "Regular",
"tables": [test_dataset.id],
"roles": [security_manager.find_role("Alpha").id],
"group_key": "group_key_1",
"clause": "client_id=1",
},
)
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.status_code, 201)
rls1 = (
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
).one_or_none()
@@ -200,41 +205,39 @@ class TestRowLevelSecurity(SupersetTestCase):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls_entry1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
"/api/v1/rowlevelsecurity/",
json={
"name": "rls_entry1",
"description": "Some description",
"filter_type": "Regular",
"tables": [test_dataset.id],
"roles": [security_manager.find_role("Alpha").id],
"group_key": "group_key_1",
"clause": "client_id=1",
},
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "Already exists." in data
self.assertEqual(rv.status_code, 422)
data = json.loads(rv.data.decode("utf-8"))
assert "Create failed" in data["message"]
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):
self.login(username="admin")
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
"/api/v1/rowlevelsecurity/",
json={
"name": "rls1",
"description": "Some description",
"filter_type": "Regular",
"tables": [],
"roles": [security_manager.find_role("Alpha").id],
"group_key": "group_key_1",
"clause": "client_id=1",
},
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "This field is required." in data
self.assertEqual(rv.status_code, 400)
data = json.loads(rv.data.decode("utf-8"))
assert data["message"] == {"tables": ["Shorter than minimum length 1."]}
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_alters_energy_query(self):
@@ -303,6 +306,340 @@ class TestRowLevelSecurity(SupersetTestCase):
assert not self.BASE_FILTER_REGEX.search(sql)
class TestRowLevelSecurityCreateAPI(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_invalid_role_failure(self):
self.login("Admin")
payload = {
"name": "rls 1",
"clause": "1=1",
"filter_type": "Base",
"tables": [1],
"roles": [999999],
}
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Some roles do not exist']")
def test_invalid_table_failure(self):
self.login("Admin")
payload = {
"name": "rls 1",
"clause": "1=1",
"filter_type": "Base",
"tables": [999999],
"roles": [1],
}
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Datasource does not exist']")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_post_success(self):
table = db.session.query(SqlaTable).first()
self.login("Admin")
payload = {
"name": "rls 1",
"clause": "1=1",
"filter_type": "Base",
"tables": [table.id],
"roles": [1],
}
rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 201)
rls = (
db.session.query(RowLevelSecurityFilter)
.filter(RowLevelSecurityFilter.id == data["id"])
.one_or_none()
)
assert rls
self.assertEqual(rls.name, "rls 1")
self.assertEqual(rls.clause, "1=1")
self.assertEqual(rls.filter_type, "Base")
self.assertEqual(rls.tables[0].id, table.id)
self.assertEqual(rls.roles[0].id, 1)
db.session.delete(rls)
db.session.commit()
class TestRowLevelSecurityUpdateAPI(SupersetTestCase):
def test_invalid_id_failure(self):
self.login("Admin")
payload = {
"name": "rls 1",
"clause": "1=1",
"filter_type": "Base",
"tables": [1],
"roles": [1],
}
rv = self.client.put("/api/v1/rowlevelsecurity/99999999", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 404)
self.assertEqual(data["message"], "Not found")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_invalid_role_failure(self):
table = db.session.query(SqlaTable).first()
rls = RowLevelSecurityFilter(
name="rls test invalid role",
clause="1=1",
filter_type="Regular",
tables=[table],
)
db.session.add(rls)
db.session.commit()
self.login("Admin")
payload = {
"roles": [999999],
}
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Some roles do not exist']")
db.session.delete(rls)
db.session.commit()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_invalid_table_failure(self):
table = db.session.query(SqlaTable).first()
rls = RowLevelSecurityFilter(
name="rls test invalid role",
clause="1=1",
filter_type="Regular",
tables=[table],
)
db.session.add(rls)
db.session.commit()
self.login("Admin")
payload = {
"name": "rls 1",
"clause": "1=1",
"filter_type": "Base",
"tables": [999999],
"roles": [1],
}
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 422)
self.assertEqual(data["message"], "[l'Datasource does not exist']")
db.session.delete(rls)
db.session.commit()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_put_success(self):
tables = db.session.query(SqlaTable).limit(2).all()
roles = db.session.query(security_manager.role_model).limit(2).all()
rls = RowLevelSecurityFilter(
name="rls 1",
clause="1=1",
filter_type="Regular",
tables=[tables[0]],
roles=[roles[0]],
)
db.session.add(rls)
db.session.commit()
self.login("Admin")
payload = {
"name": "rls put success",
"clause": "2=2",
"filter_type": "Base",
"tables": [tables[1].id],
"roles": [roles[1].id],
}
rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload)
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 201)
rls = (
db.session.query(RowLevelSecurityFilter)
.filter(RowLevelSecurityFilter.id == rls.id)
.one_or_none()
)
self.assertEqual(rls.name, "rls put success")
self.assertEqual(rls.clause, "2=2")
self.assertEqual(rls.filter_type, "Base")
self.assertEqual(rls.tables[0].id, tables[1].id)
self.assertEqual(rls.roles[0].id, roles[1].id)
db.session.delete(rls)
db.session.commit()
class TestRowLevelSecurityBulkDeleteAPI(SupersetTestCase):
def test_invalid_id_failure(self):
self.login("Admin")
ids_to_delete = prison.dumps([10000, 10001, 100002])
rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}")
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 404)
self.assertEqual(data["message"], "Not found")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_bulk_delete_success(self):
tables = db.session.query(SqlaTable).limit(2).all()
roles = db.session.query(security_manager.role_model).limit(2).all()
rls_1 = RowLevelSecurityFilter(
name="rls 1",
clause="1=1",
filter_type="Regular",
tables=[tables[0]],
roles=[roles[0]],
)
rls_2 = RowLevelSecurityFilter(
name="rls 2",
clause="2=2",
filter_type="Base",
tables=[tables[1]],
roles=[roles[1]],
)
db.session.add_all([rls_1, rls_2])
db.session.commit()
self.login("Admin")
ids_to_delete = prison.dumps([rls_1.id, rls_2.id])
rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}")
status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8"))
self.assertEqual(status_code, 200)
self.assertEqual(data["message"], "Deleted 2 rules")
class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_data")
@pytest.mark.usefixtures("load_energy_table_data")
def test_rls_tables_related_api(self):
self.login("Admin")
params = prison.dumps({"page": 0, "page_size": 100})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
db_tables = db.session.query(SqlaTable).all()
db_table_names = set([t.name for t in db_tables])
received_tables = set([table["text"] for table in result])
assert data["count"] == len(db_tables)
assert len(result) == len(db_tables)
assert db_table_names == received_tables
def test_rls_roles_related_api(self):
self.login("Admin")
params = prison.dumps({"page": 0, "page_size": 100})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/roles?q={params}")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
db_role_names = set([r.name for r in security_manager.get_all_roles()])
received_roles = set([role["text"] for role in result])
assert data["count"] == len(db_role_names)
assert len(result) == len(db_role_names)
assert db_role_names == received_roles
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures("load_energy_table_with_slice")
@mock.patch(
"superset.row_level_security.api.RLSRestApi.base_related_field_filters",
{"tables": [["table_name", filters.FilterStartsWith, "birth"]]},
)
def test_table_related_filter(self):
self.login("Admin")
params = prison.dumps({"page": 0, "page_size": 10})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
received_tables = set([table["text"].split(".")[-1] for table in result])
assert data["count"] == 1
assert len(result) == 1
assert {"birth_names"} == received_tables
@mock.patch(
"superset.row_level_security.api.RLSRestApi.base_related_field_filters",
{"roles": [["name", filters.FilterEqual, "Admin"]]},
)
def test_role_related_filter(self):
self.login("Admin")
params = prison.dumps({"page": 0, "page_size": 10})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/roles?q={params}")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
received_roles = set([role["text"] for role in result])
assert data["count"] == 1
assert len(result) == 1
assert {"Admin"} == received_roles
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures("load_energy_table_with_slice")
@mock.patch(
"superset.row_level_security.api.RLSRestApi.base_related_field_filters",
{
"tables": [["table_name", filters.FilterStartsWith, "birth"]],
"roles": [["name", filters.FilterEqual, "Admin"]],
},
)
def test_table_and_role_related_filter(self):
self.login("Admin")
params = prison.dumps({"page": 0, "page_size": 10})
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
received_tables = set([table["text"].split(".")[-1] for table in result])
assert data["count"] == 1
assert len(result) == 1
assert {"birth_names"} == received_tables
rv = self.client.get(f"/api/v1/rowlevelsecurity/related/roles?q={params}")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
received_roles = set([role["text"] for role in result])
assert data["count"] == 1
assert len(result) == 1
assert {"Admin"} == received_roles
RLS_ALICE_REGEX = re.compile(r"name = 'Alice'")
RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)")