feat: Add multiple table filters for Row Level Security (#9751)

* Add multiple table filters for Row Level Security

* Set ENABLE_ROW_LEVEL_SECURITY back to False (default)

* Merge DB migrations

* Drop table_id column and foreign key on PostgreSQL, MySQL, SQLite

* Support db records migration also

* Support downgrading from the new-fashioned formatted records

* Straighten up migrations

* Update migration's down_revision to comply master branch
This commit is contained in:
Aliaksei Kushniarevich
2020-06-22 12:51:08 +03:00
committed by GitHub
parent dbc43d7c7b
commit 550e78ff7c
5 changed files with 178 additions and 18 deletions

View File

@@ -1332,6 +1332,14 @@ RLSFilterRoles = Table(
Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
)
RLSFilterTables = Table(
"rls_filter_tables",
metadata,
Column("id", Integer, primary_key=True),
Column("table_id", Integer, ForeignKey("tables.id")),
Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
)
class RowLevelSecurityFilter(Model, AuditMixinNullable):
"""
@@ -1345,7 +1353,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
secondary=RLSFilterRoles,
backref="row_level_security_filters",
)
tables = relationship(
SqlaTable, secondary=RLSFilterTables, backref="row_level_security_filters"
)
table_id = Column(Integer, ForeignKey("tables.id"), nullable=False)
table = relationship(SqlaTable, backref="row_level_security_filters")
clause = Column(Text, nullable=False)

View File

@@ -236,15 +236,15 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
add_title = _("Add Row level security filter")
edit_title = _("Edit Row level security filter")
list_columns = ["table.table_name", "roles", "clause", "creator", "modified"]
order_columns = ["table.table_name", "clause", "modified"]
edit_columns = ["table", "roles", "clause"]
list_columns = ["tables", "roles", "clause", "creator", "modified"]
order_columns = ["tables", "clause", "modified"]
edit_columns = ["tables", "roles", "clause"]
show_columns = edit_columns
search_columns = ("table", "roles", "clause")
search_columns = ("tables", "roles", "clause")
add_columns = edit_columns
base_order = ("changed_on", "desc")
description_columns = {
"table": _("This is the table this filter will be applied to."),
"tables": _("These are the tables this filter will be applied to."),
"roles": _("These are the roles this filter will be applied to."),
"clause": _(
"This is the condition that will be added to the WHERE clause. "
@@ -252,7 +252,7 @@ class RowLevelSecurityFiltersModelView(SupersetModelView, DeleteMixin):
),
}
label_columns = {
"table": _("Table"),
"tables": _("Tables"),
"roles": _("Roles"),
"clause": _("Clause"),
"creator": _("Creator"),

View File

@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_tables_relation_to_row_level_security
Revision ID: e557699a813e
Revises: 743a117f0d98
Create Date: 2020-04-24 10:46:24.119363
"""
# revision identifiers, used by Alembic.
revision = "e557699a813e"
down_revision = "743a117f0d98"
import sqlalchemy as sa
from alembic import op
from superset.utils.core import generic_find_fk_constraint_name
def upgrade():
bind = op.get_bind()
metadata = sa.MetaData(bind=bind)
insp = sa.engine.reflection.Inspector.from_engine(bind)
rls_filter_tables = op.create_table(
"rls_filter_tables",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("table_id", sa.Integer(), nullable=True),
sa.Column("rls_filter_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["rls_filter_id"], ["row_level_security_filters.id"]),
sa.ForeignKeyConstraint(["table_id"], ["tables.id"]),
sa.PrimaryKeyConstraint("id"),
)
rlsf = sa.Table("row_level_security_filters", metadata, autoload=True)
filter_ids = sa.select([rlsf.c.id, rlsf.c.table_id])
for row in bind.execute(filter_ids):
move_table_id = rls_filter_tables.insert().values(
rls_filter_id=row["id"], table_id=row["table_id"]
)
bind.execute(move_table_id)
with op.batch_alter_table("row_level_security_filters") as batch_op:
fk_constraint_name = generic_find_fk_constraint_name(
"row_level_security_filters", {"id"}, "tables", insp
)
if fk_constraint_name:
batch_op.drop_constraint(fk_constraint_name, type_="foreignkey")
batch_op.drop_column("table_id")
def downgrade():
bind = op.get_bind()
metadata = sa.MetaData(bind=bind)
op.add_column(
"row_level_security_filters",
sa.Column(
"table_id",
sa.INTEGER(),
sa.ForeignKey("tables.id"),
autoincrement=False,
nullable=True,
),
)
rlsf = sa.Table("row_level_security_filters", metadata, autoload=True)
rls_filter_tables = sa.Table("rls_filter_tables", metadata, autoload=True)
rls_filter_roles = sa.Table("rls_filter_roles", metadata, autoload=True)
filter_tables = sa.select([rls_filter_tables.c.rls_filter_id]).group_by(
rls_filter_tables.c.rls_filter_id
)
for row in bind.execute(filter_tables):
filters_copy_ids = []
filter_query = rlsf.select().where(rlsf.c.id == row["rls_filter_id"])
filter_params = dict(bind.execute(filter_query).fetchone())
origin_id = filter_params.pop("id", None)
table_ids = bind.execute(
sa.select([rls_filter_tables.c.table_id]).where(
rls_filter_tables.c.rls_filter_id == row["rls_filter_id"]
)
).fetchall()
filter_params["table_id"] = table_ids.pop(0)[0]
move_table_id = (
rlsf.update().where(rlsf.c.id == origin_id).values(filter_params)
)
bind.execute(move_table_id)
for table_id in table_ids:
filter_params["table_id"] = table_id[0]
copy_filter = rlsf.insert().values(filter_params)
copy_id = bind.execute(copy_filter).inserted_primary_key[0]
filters_copy_ids.append(copy_id)
roles_query = rls_filter_roles.select().where(
rls_filter_roles.c.rls_filter_id == origin_id
)
for role in bind.execute(roles_query):
for copy_id in filters_copy_ids:
role_filter = rls_filter_roles.insert().values(
role_id=role["role_id"], rls_filter_id=copy_id
)
bind.execute(role_filter)
filters_copy_ids.clear()
op.alter_column("row_level_security_filters", "table_id", nullable=False)
op.drop_table("rls_filter_tables")

View File

@@ -904,6 +904,7 @@ class SupersetSecurityManager(SecurityManager):
from superset import db
from superset.connectors.sqla.models import (
RLSFilterRoles,
RLSFilterTables,
RowLevelSecurityFilter,
)
@@ -917,11 +918,16 @@ class SupersetSecurityManager(SecurityManager):
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
)
filter_tables = (
db.session.query(RLSFilterTables.c.rls_filter_id)
.filter(RLSFilterTables.c.table_id == table.id)
.subquery()
)
query = (
db.session.query(
RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause
)
.filter(RowLevelSecurityFilter.table_id == table.id)
.filter(RowLevelSecurityFilter.id.in_(filter_tables))
.filter(RowLevelSecurityFilter.id.in_(filter_roles))
)
return query.all()

View File

@@ -830,10 +830,12 @@ class RowLevelSecurityTests(SupersetTestCase):
# Create the RowLevelSecurityFilter
self.rls_entry = RowLevelSecurityFilter()
self.rls_entry.table = (
session.query(SqlaTable).filter_by(table_name="birth_names").first()
self.rls_entry.tables.extend(
session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
.all()
)
self.rls_entry.clause = "gender = 'boy'"
self.rls_entry.clause = "value > 1"
self.rls_entry.roles.append(
security_manager.find_role("Gamma")
) # db.session.query(Role).filter_by(name="Gamma").first())
@@ -852,36 +854,55 @@ class RowLevelSecurityTests(SupersetTestCase):
g.user = self.get_user(
username="alpha"
) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table_by_name("energy_usage")
query_obj = dict(
groupby=[],
metrics=[],
filter=[],
is_timeseries=False,
columns=["name"],
columns=["value"],
granularity=None,
from_dttm=None,
to_dttm=None,
extras={},
)
sql = tbl.get_query_str(query_obj)
self.assertIn("gender = 'boy'", sql)
self.assertIn("value > 1", sql)
def test_rls_filter_doesnt_alter_query(self):
g.user = self.get_user(
username="admin"
) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table_by_name("energy_usage")
query_obj = dict(
groupby=[],
metrics=[],
filter=[],
is_timeseries=False,
columns=["name"],
columns=["value"],
granularity=None,
from_dttm=None,
to_dttm=None,
extras={},
)
sql = tbl.get_query_str(query_obj)
self.assertNotIn("gender = 'boy'", sql)
self.assertNotIn("value > 1", sql)
def test_multiple_table_filter_alters_another_tables_query(self):
g.user = self.get_user(
username="alpha"
) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("unicode_test")
query_obj = dict(
groupby=[],
metrics=[],
filter=[],
is_timeseries=False,
columns=["value"],
granularity=None,
from_dttm=None,
to_dttm=None,
extras={},
)
sql = tbl.get_query_str(query_obj)
self.assertIn("value > 1", sql)