feat: add name, description and non null tables to RLS (#20432)

* feat: add name, description and non null tables to RLS

* add validation

* add and fix tests

* fix sqlite migration

* improve default value for name
This commit is contained in:
Daniel Vaz Gaspar
2022-06-20 13:52:05 +01:00
committed by GitHub
parent 8b0bee5e8b
commit 60eb1094a4
5 changed files with 215 additions and 8 deletions

View File

@@ -25,7 +25,6 @@ from flask import g
from superset import db, security_manager
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.security.guest_token import (
GuestTokenRlsRule,
GuestTokenResourceType,
GuestUser,
)
@@ -82,6 +81,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
self.rls_entry1 = RowLevelSecurityFilter()
self.rls_entry1.name = "rls_entry1"
self.rls_entry1.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
@@ -96,6 +96,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create regular RowLevelSecurityFilter (birth_names name starts with A or B)
self.rls_entry2 = RowLevelSecurityFilter()
self.rls_entry2.name = "rls_entry2"
self.rls_entry2.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
@@ -109,6 +110,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create Regular RowLevelSecurityFilter (birth_names name starts with Q)
self.rls_entry3 = RowLevelSecurityFilter()
self.rls_entry3.name = "rls_entry3"
self.rls_entry3.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
@@ -122,6 +124,7 @@ class TestRowLevelSecurity(SupersetTestCase):
# Create Base RowLevelSecurityFilter (birth_names boys)
self.rls_entry4 = RowLevelSecurityFilter()
self.rls_entry4.name = "rls_entry4"
self.rls_entry4.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
@@ -146,6 +149,94 @@ class TestRowLevelSecurity(SupersetTestCase):
session.delete(self.get_user("NoRlsRoleUser"))
session.commit()
@pytest.fixture()
def create_dataset(self):
with self.create_app().app_context():
dataset = SqlaTable(database_id=1, schema=None, table_name="table1")
db.session.add(dataset)
db.session.flush()
db.session.commit()
yield dataset
# rollback changes (assuming cascade delete)
db.session.delete(dataset)
db.session.commit()
def _get_test_dataset(self):
return (
db.session.query(SqlaTable).filter(SqlaTable.table_name == "table1")
).one_or_none()
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_success(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
rls1 = (
db.session.query(RowLevelSecurityFilter).filter_by(name="rls1")
).one_or_none()
assert rls1 is not None
# Revert data changes
db.session.delete(rls1)
db.session.commit()
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_name_unique(self):
self.login(username="admin")
test_dataset = self._get_test_dataset()
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls_entry1",
description="Some description",
filter_type="Regular",
tables=[test_dataset.id],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "Already exists." in data
@pytest.mark.usefixtures("create_dataset")
def test_model_view_rls_add_tables_required(self):
self.login(username="admin")
rv = self.client.post(
"/rowlevelsecurityfiltersmodelview/add",
data=dict(
name="rls1",
description="Some description",
filter_type="Regular",
tables=[],
roles=[security_manager.find_role("Alpha").id],
group_key="group_key_1",
clause="client_id=1",
),
follow_redirects=True,
)
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
assert "This field is required." in data
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_alters_energy_query(self):
g.user = self.get_user(username="alpha")