mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
refactor: Moving get_user_datasources to security manager (#15467)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
|
||||
|
||||
from flask_babel import _
|
||||
@@ -100,41 +99,6 @@ class ConnectorRegistry:
|
||||
pass
|
||||
raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id))
|
||||
|
||||
@classmethod
|
||||
def get_user_datasources(cls, session: Session) -> List["BaseDatasource"]:
|
||||
from superset import security_manager
|
||||
|
||||
# collect datasources which the user has explicit permissions to
|
||||
user_perms = security_manager.user_view_menu_names("datasource_access")
|
||||
schema_perms = security_manager.user_view_menu_names("schema_access")
|
||||
user_datasources = set()
|
||||
for datasource_class in ConnectorRegistry.sources.values():
|
||||
user_datasources.update(
|
||||
session.query(datasource_class)
|
||||
.filter(
|
||||
or_(
|
||||
datasource_class.perm.in_(user_perms),
|
||||
datasource_class.schema_perm.in_(schema_perms),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# group all datasources by database
|
||||
all_datasources = cls.get_all_datasources(session)
|
||||
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
|
||||
set
|
||||
)
|
||||
for datasource in all_datasources:
|
||||
datasources_by_database[datasource.database].add(datasource)
|
||||
|
||||
# add datasources with implicit permission (eg, database access)
|
||||
for database, datasources in datasources_by_database.items():
|
||||
if security_manager.can_access_database(database):
|
||||
user_datasources.update(datasources)
|
||||
|
||||
return list(user_datasources)
|
||||
|
||||
@classmethod
|
||||
def get_datasource_by_name( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
|
||||
@@ -18,7 +18,19 @@
|
||||
"""A set of constants and methods to manage permissions and security"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, cast, List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
from flask import current_app, g
|
||||
from flask_appbuilder import Model
|
||||
@@ -419,6 +431,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
|
||||
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
|
||||
|
||||
def get_user_datasources(self) -> List["BaseDatasource"]:
|
||||
"""
|
||||
Collect datasources which the user has explicit permissions to.
|
||||
|
||||
:returns: The list of datasources
|
||||
"""
|
||||
|
||||
user_perms = self.user_view_menu_names("datasource_access")
|
||||
schema_perms = self.user_view_menu_names("schema_access")
|
||||
user_datasources = set()
|
||||
for datasource_class in ConnectorRegistry.sources.values():
|
||||
user_datasources.update(
|
||||
self.get_session.query(datasource_class)
|
||||
.filter(
|
||||
or_(
|
||||
datasource_class.perm.in_(user_perms),
|
||||
datasource_class.schema_perm.in_(schema_perms),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# group all datasources by database
|
||||
all_datasources = ConnectorRegistry.get_all_datasources(self.get_session)
|
||||
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
|
||||
set
|
||||
)
|
||||
for datasource in all_datasources:
|
||||
datasources_by_database[datasource.database].add(datasource)
|
||||
|
||||
# add datasources with implicit permission (eg, database access)
|
||||
for database, datasources in datasources_by_database.items():
|
||||
if self.can_access_database(database):
|
||||
user_datasources.update(datasources)
|
||||
|
||||
return list(user_datasources)
|
||||
|
||||
def can_access_table(self, database: "Database", table: "Table") -> bool:
|
||||
"""
|
||||
Return True if the user can access the SQL table, False otherwise.
|
||||
|
||||
@@ -21,8 +21,7 @@ from flask_appbuilder import expose, has_access
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
||||
from superset import db, is_feature_enabled
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset import is_feature_enabled, security_manager
|
||||
from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||
from superset.models.slice import Slice
|
||||
from superset.typing import FlaskResponse
|
||||
@@ -65,7 +64,7 @@ class SliceModelView(
|
||||
def add(self) -> FlaskResponse:
|
||||
datasources = [
|
||||
{"value": str(d.id) + "__" + d.type, "label": repr(d)}
|
||||
for d in ConnectorRegistry.get_user_datasources(db.session)
|
||||
for d in security_manager.get_user_datasources()
|
||||
]
|
||||
payload = {
|
||||
"datasources": sorted(
|
||||
|
||||
@@ -186,7 +186,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
||||
sorted(
|
||||
[
|
||||
datasource.short_data
|
||||
for datasource in ConnectorRegistry.get_user_datasources(db.session)
|
||||
for datasource in security_manager.get_user_datasources()
|
||||
if datasource.short_data.get("name")
|
||||
],
|
||||
key=lambda datasource: datasource["name"],
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
"""Unit tests for Superset"""
|
||||
import json
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
from unittest import mock
|
||||
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
|
||||
|
||||
@@ -627,83 +626,5 @@ class TestRequestAccess(SupersetTestCase):
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestDatasources(SupersetTestCase):
|
||||
def test_get_user_datasources_admin(self):
|
||||
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
|
||||
|
||||
mock_session = mock.MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
with mock.patch("superset.security_manager") as mock_security_manager:
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
datasources = ConnectorRegistry.get_user_datasources(mock_session)
|
||||
|
||||
assert sorted(datasources) == [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
def test_get_user_datasources_gamma(self):
|
||||
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
|
||||
|
||||
mock_session = mock.MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
with mock.patch("superset.security_manager") as mock_security_manager:
|
||||
mock_security_manager.can_access_database.return_value = False
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
datasources = ConnectorRegistry.get_user_datasources(mock_session)
|
||||
|
||||
assert datasources == []
|
||||
|
||||
def test_get_user_datasources_gamma_with_schema(self):
|
||||
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
|
||||
|
||||
mock_session = mock.MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
]
|
||||
|
||||
with mock.patch("superset.security_manager") as mock_security_manager:
|
||||
mock_security_manager.can_access_database.return_value = False
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
datasources = ConnectorRegistry.get_user_datasources(mock_session)
|
||||
|
||||
assert sorted(datasources) == [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
import inspect
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from collections import namedtuple
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
from typing import Any, Dict
|
||||
|
||||
@@ -1220,3 +1221,88 @@ class TestAccessRequestEndpoints(SupersetTestCase):
|
||||
uri = "/accessrequestsmodelview/list/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertLess(rv.status_code, 400)
|
||||
|
||||
|
||||
class TestDatasources(SupersetTestCase):
|
||||
@patch("superset.security.manager.g")
|
||||
@patch("superset.security.SupersetSecurityManager.can_access_database")
|
||||
@patch("superset.security.SupersetSecurityManager.get_session")
|
||||
def test_get_user_datasources_admin(
|
||||
self, mock_get_session, mock_can_access_database, mock_g
|
||||
):
|
||||
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_can_access_database.return_value = True
|
||||
mock_get_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
datasources = security_manager.get_user_datasources()
|
||||
|
||||
assert sorted(datasources) == [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@patch("superset.security.SupersetSecurityManager.can_access_database")
|
||||
@patch("superset.security.SupersetSecurityManager.get_session")
|
||||
def test_get_user_datasources_gamma(
|
||||
self, mock_get_session, mock_can_access_database, mock_g
|
||||
):
|
||||
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
|
||||
mock_g.user = security_manager.find_user("gamma")
|
||||
mock_can_access_database.return_value = False
|
||||
mock_get_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
datasources = security_manager.get_user_datasources()
|
||||
|
||||
assert datasources == []
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@patch("superset.security.SupersetSecurityManager.can_access_database")
|
||||
@patch("superset.security.SupersetSecurityManager.get_session")
|
||||
def test_get_user_datasources_gamma_with_schema(
|
||||
self, mock_get_session, mock_can_access_database, mock_g
|
||||
):
|
||||
Datasource = namedtuple("Datasource", ["database", "schema", "name"])
|
||||
mock_g.user = security_manager.find_user("gamma")
|
||||
mock_can_access_database.return_value = False
|
||||
|
||||
mock_get_session.query.return_value.filter.return_value.all.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
]
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
Datasource("database2", None, "table1"),
|
||||
]
|
||||
|
||||
datasources = security_manager.get_user_datasources()
|
||||
|
||||
assert sorted(datasources) == [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
Datasource("database1", "schema1", "table2"),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user