refactor: Moving get_user_datasources to security manager (#15467)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2021-06-30 09:51:11 -07:00
committed by GitHub
parent cad5ba828c
commit ffa51753e3
6 changed files with 140 additions and 121 deletions

View File

@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import defaultdict
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from flask_babel import _
@@ -100,41 +99,6 @@ class ConnectorRegistry:
pass
raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id))
@classmethod
def get_user_datasources(cls, session: Session) -> List["BaseDatasource"]:
from superset import security_manager
# collect datasources which the user has explicit permissions to
user_perms = security_manager.user_view_menu_names("datasource_access")
schema_perms = security_manager.user_view_menu_names("schema_access")
user_datasources = set()
for datasource_class in ConnectorRegistry.sources.values():
user_datasources.update(
session.query(datasource_class)
.filter(
or_(
datasource_class.perm.in_(user_perms),
datasource_class.schema_perm.in_(schema_perms),
)
)
.all()
)
# group all datasources by database
all_datasources = cls.get_all_datasources(session)
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
set
)
for datasource in all_datasources:
datasources_by_database[datasource.database].add(datasource)
# add datasources with implicit permission (eg, database access)
for database, datasources in datasources_by_database.items():
if security_manager.can_access_database(database):
user_datasources.update(datasources)
return list(user_datasources)
@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,

View File

@@ -18,7 +18,19 @@
"""A set of constants and methods to manage permissions and security"""
import logging
import re
from typing import Any, Callable, cast, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from collections import defaultdict
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from flask import current_app, g
from flask_appbuilder import Model
@@ -419,6 +431,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
def get_user_datasources(self) -> List["BaseDatasource"]:
"""
Collect datasources which the user has explicit permissions to.
:returns: The list of datasources
"""
user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = set()
for datasource_class in ConnectorRegistry.sources.values():
user_datasources.update(
self.get_session.query(datasource_class)
.filter(
or_(
datasource_class.perm.in_(user_perms),
datasource_class.schema_perm.in_(schema_perms),
)
)
.all()
)
# group all datasources by database
all_datasources = ConnectorRegistry.get_all_datasources(self.get_session)
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
set
)
for datasource in all_datasources:
datasources_by_database[datasource.database].add(datasource)
# add datasources with implicit permission (eg, database access)
for database, datasources in datasources_by_database.items():
if self.can_access_database(database):
user_datasources.update(datasources)
return list(user_datasources)
def can_access_table(self, database: "Database", table: "Table") -> bool:
"""
Return True if the user can access the SQL table, False otherwise.

View File

@@ -21,8 +21,7 @@ from flask_appbuilder import expose, has_access
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
from superset import db, is_feature_enabled
from superset.connectors.connector_registry import ConnectorRegistry
from superset import is_feature_enabled, security_manager
from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.models.slice import Slice
from superset.typing import FlaskResponse
@@ -65,7 +64,7 @@ class SliceModelView(
def add(self) -> FlaskResponse:
datasources = [
{"value": str(d.id) + "__" + d.type, "label": repr(d)}
for d in ConnectorRegistry.get_user_datasources(db.session)
for d in security_manager.get_user_datasources()
]
payload = {
"datasources": sorted(

View File

@@ -186,7 +186,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
sorted(
[
datasource.short_data
for datasource in ConnectorRegistry.get_user_datasources(db.session)
for datasource in security_manager.get_user_datasources()
if datasource.short_data.get("name")
],
key=lambda datasource: datasource["name"],

View File

@@ -18,7 +18,6 @@
"""Unit tests for Superset"""
import json
import unittest
from collections import namedtuple
from unittest import mock
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
@@ -627,83 +626,5 @@ class TestRequestAccess(SupersetTestCase):
session.commit()
class TestDatasources(SupersetTestCase):
def test_get_user_datasources_admin(self):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
mock_session = mock.MagicMock()
mock_session.query.return_value.filter.return_value.all.return_value = []
with mock.patch("superset.security_manager") as mock_security_manager:
mock_security_manager.can_access_database.return_value = True
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
datasources = ConnectorRegistry.get_user_datasources(mock_session)
assert sorted(datasources) == [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
def test_get_user_datasources_gamma(self):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
mock_session = mock.MagicMock()
mock_session.query.return_value.filter.return_value.all.return_value = []
with mock.patch("superset.security_manager") as mock_security_manager:
mock_security_manager.can_access_database.return_value = False
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
datasources = ConnectorRegistry.get_user_datasources(mock_session)
assert datasources == []
def test_get_user_datasources_gamma_with_schema(self):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
mock_session = mock.MagicMock()
mock_session.query.return_value.filter.return_value.all.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
]
with mock.patch("superset.security_manager") as mock_security_manager:
mock_security_manager.can_access_database.return_value = False
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
datasources = ConnectorRegistry.get_user_datasources(mock_session)
assert sorted(datasources) == [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
]
if __name__ == "__main__":
unittest.main()

View File

@@ -18,7 +18,8 @@
import inspect
import re
import unittest
from collections import namedtuple
from unittest import mock
from unittest.mock import Mock, patch
from typing import Any, Dict
@@ -1220,3 +1221,88 @@ class TestAccessRequestEndpoints(SupersetTestCase):
uri = "/accessrequestsmodelview/list/"
rv = self.client.get(uri)
self.assertLess(rv.status_code, 400)
class TestDatasources(SupersetTestCase):
@patch("superset.security.manager.g")
@patch("superset.security.SupersetSecurityManager.can_access_database")
@patch("superset.security.SupersetSecurityManager.get_session")
def test_get_user_datasources_admin(
self, mock_get_session, mock_can_access_database, mock_g
):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
mock_g.user = security_manager.find_user("admin")
mock_can_access_database.return_value = True
mock_get_session.query.return_value.filter.return_value.all.return_value = []
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
datasources = security_manager.get_user_datasources()
assert sorted(datasources) == [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
@patch("superset.security.manager.g")
@patch("superset.security.SupersetSecurityManager.can_access_database")
@patch("superset.security.SupersetSecurityManager.get_session")
def test_get_user_datasources_gamma(
self, mock_get_session, mock_can_access_database, mock_g
):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
mock_g.user = security_manager.find_user("gamma")
mock_can_access_database.return_value = False
mock_get_session.query.return_value.filter.return_value.all.return_value = []
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
datasources = security_manager.get_user_datasources()
assert datasources == []
@patch("superset.security.manager.g")
@patch("superset.security.SupersetSecurityManager.can_access_database")
@patch("superset.security.SupersetSecurityManager.get_session")
def test_get_user_datasources_gamma_with_schema(
self, mock_get_session, mock_can_access_database, mock_g
):
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
mock_g.user = security_manager.find_user("gamma")
mock_can_access_database.return_value = False
mock_get_session.query.return_value.filter.return_value.all.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
]
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
Datasource("database2", None, "table1"),
]
datasources = security_manager.get_user_datasources()
assert sorted(datasources) == [
Datasource("database1", "schema1", "table1"),
Datasource("database1", "schema1", "table2"),
]