diff --git a/superset-frontend/src/pages/DatabaseList/index.tsx b/superset-frontend/src/pages/DatabaseList/index.tsx index 776dbbe817a..3e9471be74b 100644 --- a/superset-frontend/src/pages/DatabaseList/index.tsx +++ b/superset-frontend/src/pages/DatabaseList/index.tsx @@ -71,6 +71,7 @@ interface DatabaseDeleteObject extends DatabaseObject { interface DatabaseListProps { addDangerToast: (msg: string) => void; addSuccessToast: (msg: string) => void; + addInfoToast: (msg: string) => void; user: { userId: string | number; firstName: string; @@ -101,6 +102,7 @@ function BooleanDisplay({ value }: { value: Boolean }) { function DatabaseList({ addDangerToast, + addInfoToast, addSuccessToast, user, }: DatabaseListProps) { @@ -121,6 +123,9 @@ function DatabaseList({ const fullUser = useSelector( state => state.user, ); + const shouldSyncPermsInAsyncMode = useSelector( + state => state.common?.conf.SYNC_DB_PERMISSIONS_IN_ASYNC_MODE, + ); const showDatabaseModal = getUrlParam(URL_PARAMS.showDatabaseModal); const [query, setQuery] = useQueryParams({ @@ -335,6 +340,44 @@ function DatabaseList({ setPreparingExport(true); } + function handleDatabasePermSync(database: DatabaseObject) { + if (shouldSyncPermsInAsyncMode) { + addInfoToast(t('Validating connectivity for %s', database.database_name)); + } else { + addInfoToast(t('Syncing permissions for %s', database.database_name)); + } + SupersetClient.post({ + endpoint: `/api/v1/database/${database.id}/sync_permissions/`, + }).then( + ({ response }) => { + // Sync request + if (response.status === 200) { + addSuccessToast( + t('Permissions successfully synced for %s', database.database_name), + ); + } + // Async request + else { + addInfoToast( + t( + 'Syncing permissions for %s in the background', + database.database_name, + ), + ); + } + }, + createErrorHandler(errMsg => + addDangerToast( + t( + 'An error occurred while syncing permissions for %s: %s', + database.database_name, + errMsg, + ), + ), + ), + ); + } + const initialSort 
= [{ id: 'changed_on_delta_humanized', desc: true }]; const columns = useMemo( @@ -426,6 +469,7 @@ function DatabaseList({ handleDatabaseEditModal({ database: original, modalOpen: true }); const handleDelete = () => openDatabaseDeleteModal(original); const handleExport = () => handleDatabaseExport(original); + const handleSync = () => handleDatabasePermSync(original); if (!canEdit && !canDelete && !canExport) { return null; } @@ -481,6 +525,23 @@ function DatabaseList({ )} + {canEdit && ( + + + + + + )} ); }, diff --git a/superset/commands/database/exceptions.py b/superset/commands/database/exceptions.py index 5285deb0f9d..6dfdadacbd7 100644 --- a/superset/commands/database/exceptions.py +++ b/superset/commands/database/exceptions.py @@ -88,11 +88,21 @@ class DatabaseExtraValidationError(ValidationError): ) +class DatabaseConnectionSyncPermissionsError(CommandException): + status = 500 + message = _("Unable to sync permissions for this database connection.") + + class DatabaseNotFoundError(CommandException): status = 404 message = _("Database not found.") +class UserNotFoundInSessionError(CommandException): + status = 500 + message = _("Could not validate the user in the current session.") + + class DatabaseSchemaUploadNotAllowed(CommandException): status = 403 message = _("Database schema is not allowed for csv uploads.") diff --git a/superset/commands/database/sync_permissions.py b/superset/commands/database/sync_permissions.py new file mode 100644 index 00000000000..3f2bf36a0e7 --- /dev/null +++ b/superset/commands/database/sync_permissions.py @@ -0,0 +1,344 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from functools import partial +from typing import Iterable + +from flask import current_app, g + +from superset import app, security_manager +from superset.commands.base import BaseCommand +from superset.commands.database.exceptions import ( + DatabaseConnectionFailedError, + DatabaseConnectionSyncPermissionsError, + DatabaseNotFoundError, + UserNotFoundInSessionError, +) +from superset.commands.database.utils import ( + add_pvm, + add_vm, + ping, +) +from superset.daos.database import DatabaseDAO +from superset.daos.dataset import DatasetDAO +from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.db_engine_specs.base import GenericDBException +from superset.exceptions import OAuth2RedirectError +from superset.extensions import celery_app, db +from superset.models.core import Database +from superset.utils.decorators import on_error, transaction + +logger = logging.getLogger(__name__) + + +class SyncPermissionsCommand(BaseCommand): + """ + Command to sync database permissions. + + This command can be called either via its dedicated endpoint, or as part of + another command. If async mode is enabled, the command is executed through + a celery task, otherwise it's executed synchronously. + """ + + def __init__( + self, + model_id: int, + username: str | None, + old_db_connection_name: str | None = None, + db_connection: Database | None = None, + ssh_tunnel: SSHTunnel | None = None, + ): + """ + Constructor method. 
+ """ + self.db_connection_id = model_id + self.username = username + self._old_db_connection_name: str | None = old_db_connection_name + self._db_connection: Database | None = db_connection + self.db_connection_ssh_tunnel: SSHTunnel | None = ssh_tunnel + + self.async_mode: bool = app.config["SYNC_DB_PERMISSIONS_IN_ASYNC_MODE"] + + @property + def db_connection(self) -> Database: + if not self._db_connection: + raise DatabaseNotFoundError() + return self._db_connection + + @property + def old_db_connection_name(self) -> str: + return ( + self._old_db_connection_name + if self._old_db_connection_name is not None + else self.db_connection.database_name + ) + + def validate(self) -> None: + self._db_connection = ( + self._db_connection + if self._db_connection + else DatabaseDAO.find_by_id(self.db_connection_id) + ) + if not self._db_connection: + raise DatabaseNotFoundError() + + if not self.db_connection_ssh_tunnel: + self.db_connection_ssh_tunnel = DatabaseDAO.get_ssh_tunnel( + self.db_connection_id + ) + + # Need user info to impersonate for OAuth2 connections + if not self.username or not security_manager.get_user_by_username( + self.username + ): + raise UserNotFoundInSessionError() + + with self.db_connection.get_sqla_engine( + override_ssh_tunnel=self.db_connection_ssh_tunnel + ) as engine: + try: + alive = ping(engine) + except Exception as err: + raise DatabaseConnectionFailedError() from err + + if not alive: + raise DatabaseConnectionFailedError() + + def run(self) -> None: + """ + Triggers the perm sync in sync or async mode. + """ + self.validate() + if self.async_mode: + sync_database_permissions_task.delay( + self.db_connection_id, self.username, self.old_db_connection_name + ) + return + + self.sync_database_permissions() + + @transaction( + on_error=partial(on_error, reraise=DatabaseConnectionSyncPermissionsError) + ) + def sync_database_permissions(self) -> None: + """ + Syncs the permissions for a DB connection. 
+ """ + catalogs = ( + self._get_catalog_names() + if self.db_connection.db_engine_spec.supports_catalog + else [None] + ) + + for catalog in catalogs: + try: + schemas = self._get_schema_names(catalog) + + if catalog: + perm = security_manager.get_catalog_perm( + self.old_db_connection_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + if not existing_pvm: + # new catalog + add_pvm( + db.session, + security_manager, + "catalog_access", + security_manager.get_catalog_perm( + self.db_connection.database_name, + catalog, + ), + ) + for schema in schemas: + add_pvm( + db.session, + security_manager, + "schema_access", + security_manager.get_schema_perm( + self.db_connection.database_name, + catalog, + schema, + ), + ) + continue + except DatabaseConnectionFailedError: + logger.warning("Error processing catalog %s", catalog or "(default)") + continue + + # add possible new schemas in catalog + self._refresh_schemas(catalog, schemas) + + if self.old_db_connection_name != self.db_connection.database_name: + self._rename_database_in_permissions(catalog, schemas) + + def _get_catalog_names(self) -> set[str]: + """ + Helper method to load catalogs. + """ + try: + return self.db_connection.get_all_catalog_names( + force=True, + ssh_tunnel=self.db_connection_ssh_tunnel, + ) + except OAuth2RedirectError: + # raise OAuth2 exceptions as-is + raise + except GenericDBException as ex: + raise DatabaseConnectionFailedError() from ex + + def _get_schema_names(self, catalog: str | None) -> set[str]: + """ + Helper method to load schemas. 
+ """ + try: + return self.db_connection.get_all_schema_names( + force=True, + catalog=catalog, + ssh_tunnel=self.db_connection_ssh_tunnel, + ) + except OAuth2RedirectError: + # raise OAuth2 exceptions as-is + raise + except GenericDBException as ex: + raise DatabaseConnectionFailedError() from ex + + def _refresh_schemas(self, catalog: str | None, schemas: Iterable[str]) -> None: + """ + Add new schemas that don't have permissions yet. + """ + for schema in schemas: + perm = security_manager.get_schema_perm( + self.old_db_connection_name, + catalog, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + if not existing_pvm: + new_name = security_manager.get_schema_perm( + self.db_connection.name, + catalog, + schema, + ) + add_pvm(db.session, security_manager, "schema_access", new_name) + + def _rename_database_in_permissions( + self, catalog: str | None, schemas: Iterable[str] + ) -> None: + # rename existing catalog permission + if catalog: + new_catalog_perm_name = security_manager.get_catalog_perm( + self.db_connection.name, + catalog, + ) + new_catalog_vm = add_vm(db.session, security_manager, new_catalog_perm_name) + perm = security_manager.get_catalog_perm( + self.old_db_connection_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu = new_catalog_vm + + for schema in schemas: + new_schema_perm_name = security_manager.get_schema_perm( + self.db_connection.name, + catalog, + schema, + ) + + # rename existing schema permission + perm = security_manager.get_schema_perm( + self.old_db_connection_name, + catalog, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu.name = new_schema_perm_name + + # rename permissions on datasets and charts + for dataset in DatabaseDAO.get_datasets( + self.db_connection_id, + 
catalog=catalog, + schema=schema, + ): + dataset.catalog_perm = new_catalog_perm_name + dataset.schema_perm = new_schema_perm_name + for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: + chart.catalog_perm = new_catalog_perm_name + chart.schema_perm = new_schema_perm_name + + +@celery_app.task(name="sync_database_permissions", soft_time_limit=600) +def sync_database_permissions_task( + database_id: int, username: str, old_db_connection_name: str +) -> None: + """ + Celery task that triggers the SyncPermissionsCommand in async mode. + """ + with current_app.test_request_context(): + try: + user = security_manager.get_user_by_username(username) + if not user: + raise UserNotFoundInSessionError() + g.user = user + logger.info( + "Syncing permissions for DB connection %s while impersonating user %s", + database_id, + user.id, + ) + + db_connection = DatabaseDAO.find_by_id(database_id) + if not db_connection: + raise DatabaseNotFoundError() + ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database_id) + + SyncPermissionsCommand( + database_id, + username, + old_db_connection_name=old_db_connection_name, + db_connection=db_connection, + ssh_tunnel=ssh_tunnel, + ).sync_database_permissions() + + logger.info( + "Successfully synced permissions for DB connection %s", + database_id, + ) + + except Exception: + logger.error( + "An error occurred while syncing permissions for DB connection ID %s", + database_id, + exc_info=True, + ) diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 6d3219253ea..3c16730d00b 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -15,13 +15,9 @@ # specific language governing permissions and limitations # under the License. 
import logging -import sqlite3 -from contextlib import closing from typing import Any, Optional -from flask import current_app as app from flask_babel import gettext as _ -from sqlalchemy.engine import Engine from sqlalchemy.exc import DBAPIError, NoSuchModuleError from superset import is_feature_enabled @@ -35,6 +31,7 @@ from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelDatabasePortError, SSHTunnelingNotEnabledError, ) +from superset.commands.database.utils import ping from superset.daos.database import DatabaseDAO, SSHTunnelDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe @@ -47,7 +44,6 @@ from superset.exceptions import ( ) from superset.extensions import event_logger from superset.models.core import Database -from superset.utils import core as utils from superset.utils.ssh_tunnel import unmask_password_info logger = logging.getLogger(__name__) @@ -136,19 +132,9 @@ class TestConnectionDatabaseCommand(BaseCommand): engine=database.db_engine_spec.__name__, ) - def ping(engine: Engine) -> bool: - with closing(engine.raw_connection()) as conn: - return engine.dialect.do_ping(conn) - with database.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine: try: - time_delta = app.config["TEST_DATABASE_CONNECTION_TIMEOUT"] - with utils.timeout(int(time_delta.total_seconds())): - alive = ping(engine) - except (sqlite3.ProgrammingError, RuntimeError): - # SQLite can't run on a separate thread, so ``utils.timeout`` fails - # RuntimeError catches the equivalent error from duckdb. 
- alive = engine.dialect.do_ping(engine) + alive = ping(engine) except SupersetTimeoutException as ex: raise SupersetTimeoutException( error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index fbf90694f48..a1accf4df10 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -23,10 +23,9 @@ from typing import Any from flask_appbuilder.models.sqla import Model -from superset import is_feature_enabled, security_manager +from superset import is_feature_enabled from superset.commands.base import BaseCommand from superset.commands.database.exceptions import ( - DatabaseConnectionFailedError, DatabaseExistsValidationError, DatabaseInvalidError, DatabaseNotFoundError, @@ -38,13 +37,13 @@ from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelingNotEnabledError, ) from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand +from superset.commands.database.sync_permissions import SyncPermissionsCommand from superset.daos.database import DatabaseDAO -from superset.daos.dataset import DatasetDAO from superset.databases.ssh_tunnel.models import SSHTunnel -from superset.db_engine_specs.base import GenericDBException from superset.exceptions import OAuth2RedirectError from superset.models.core import Database from superset.utils import json +from superset.utils.core import get_username from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -88,7 +87,14 @@ class UpdateDatabaseCommand(BaseCommand): database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) try: - self._refresh_catalogs(database, original_database_name, ssh_tunnel) + current_username = get_username() + SyncPermissionsCommand( + self._model_id, + current_username, + old_db_connection_name=original_database_name, + db_connection=database, + ssh_tunnel=ssh_tunnel, + 
).run() except OAuth2RedirectError: pass @@ -153,201 +159,6 @@ class UpdateDatabaseCommand(BaseCommand): ssh_tunnel_properties, ).run() - def _get_catalog_names( - self, - database: Database, - ssh_tunnel: SSHTunnel | None, - ) -> set[str]: - """ - Helper method to load catalogs. - """ - try: - return database.get_all_catalog_names( - force=True, - ssh_tunnel=ssh_tunnel, - ) - except OAuth2RedirectError: - # raise OAuth2 exceptions as-is - raise - except GenericDBException as ex: - raise DatabaseConnectionFailedError() from ex - - def _get_schema_names( - self, - database: Database, - catalog: str | None, - ssh_tunnel: SSHTunnel | None, - ) -> set[str]: - """ - Helper method to load schemas. - """ - try: - return database.get_all_schema_names( - force=True, - catalog=catalog, - ssh_tunnel=ssh_tunnel, - ) - except OAuth2RedirectError: - # raise OAuth2 exceptions as-is - raise - except GenericDBException as ex: - raise DatabaseConnectionFailedError() from ex - - def _refresh_catalogs( - self, - database: Database, - original_database_name: str, - ssh_tunnel: SSHTunnel | None, - ) -> None: - """ - Add permissions for any new catalogs and schemas. 
- """ - catalogs = ( - self._get_catalog_names(database, ssh_tunnel) - if database.db_engine_spec.supports_catalog - else [None] - ) - - for catalog in catalogs: - try: - schemas = self._get_schema_names(database, catalog, ssh_tunnel) - - if catalog: - perm = security_manager.get_catalog_perm( - original_database_name, - catalog, - ) - existing_pvm = security_manager.find_permission_view_menu( - "catalog_access", - perm, - ) - if not existing_pvm: - # new catalog - security_manager.add_permission_view_menu( - "catalog_access", - security_manager.get_catalog_perm( - database.database_name, - catalog, - ), - ) - for schema in schemas: - security_manager.add_permission_view_menu( - "schema_access", - security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ), - ) - continue - except DatabaseConnectionFailedError: - # more than one catalog, move to next - if catalog: - logger.warning("Error processing catalog %s", catalog) - continue - raise - - # add possible new schemas in catalog - self._refresh_schemas( - database, - original_database_name, - catalog, - schemas, - ) - - if original_database_name != database.database_name: - self._rename_database_in_permissions( - database, - original_database_name, - catalog, - schemas, - ) - - def _refresh_schemas( - self, - database: Database, - original_database_name: str, - catalog: str | None, - schemas: set[str], - ) -> None: - """ - Add new schemas that don't have permissions yet. 
- """ - for schema in schemas: - perm = security_manager.get_schema_perm( - original_database_name, - catalog, - schema, - ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) - if not existing_pvm: - new_name = security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ) - security_manager.add_permission_view_menu("schema_access", new_name) - - def _rename_database_in_permissions( - self, - database: Database, - original_database_name: str, - catalog: str | None, - schemas: set[str], - ) -> None: - new_catalog_perm_name = security_manager.get_catalog_perm( - database.database_name, - catalog, - ) - - # rename existing catalog permission - if catalog: - perm = security_manager.get_catalog_perm( - original_database_name, - catalog, - ) - existing_pvm = security_manager.find_permission_view_menu( - "catalog_access", - perm, - ) - if existing_pvm: - existing_pvm.view_menu.name = new_catalog_perm_name - - for schema in schemas: - new_schema_perm_name = security_manager.get_schema_perm( - database.database_name, - catalog, - schema, - ) - - # rename existing schema permission - perm = security_manager.get_schema_perm( - original_database_name, - catalog, - schema, - ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) - if existing_pvm: - existing_pvm.view_menu.name = new_schema_perm_name - - # rename permissions on datasets and charts - for dataset in DatabaseDAO.get_datasets( - database.id, - catalog=catalog, - schema=schema, - ): - dataset.catalog_perm = new_catalog_perm_name - dataset.schema_perm = new_schema_perm_name - for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: - chart.catalog_perm = new_catalog_perm_name - chart.schema_perm = new_schema_perm_name - def validate(self) -> None: if database_name := self._properties.get("database_name"): if not DatabaseDAO.validate_update_uniqueness( diff --git a/superset/commands/database/utils.py 
b/superset/commands/database/utils.py index ea0ce1a27e2..2c113173fbf 100644 --- a/superset/commands/database/utils.py +++ b/superset/commands/database/utils.py @@ -17,19 +17,45 @@ from __future__ import annotations import logging +import sqlite3 +from contextlib import closing + +from flask import current_app as app +from flask_appbuilder.security.sqla.models import ( + Permission, + PermissionView, + ViewMenu, +) +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session from superset import security_manager from superset.databases.ssh_tunnel.models import SSHTunnel from superset.db_engine_specs.base import GenericDBException from superset.models.core import Database +from superset.security.manager import SupersetSecurityManager +from superset.utils.core import timeout logger = logging.getLogger(__name__) +def ping(engine: Engine) -> bool: + try: + time_delta = app.config["TEST_DATABASE_CONNECTION_TIMEOUT"] + with timeout(int(time_delta.total_seconds())): + with closing(engine.raw_connection()) as conn: + return engine.dialect.do_ping(conn) + except (sqlite3.ProgrammingError, RuntimeError): + # SQLite can't run on a separate thread, so ``utils.timeout`` fails + # RuntimeError catches the equivalent error from duckdb. + return engine.dialect.do_ping(engine) + + def add_permissions(database: Database, ssh_tunnel: SSHTunnel | None) -> None: """ Add DAR for catalogs and schemas. 
""" + # TODO: Migrate this to use the non-commiting add_pvm helper instead if database.db_engine_spec.supports_catalog: catalogs = database.get_all_catalog_names( cache=False, @@ -65,3 +91,69 @@ def add_permissions(database: Database, ssh_tunnel: SSHTunnel | None) -> None: except GenericDBException: # pylint: disable=broad-except logger.warning("Error processing catalog '%s'", catalog) continue + + +def add_vm( + session: Session, + security_manager: SupersetSecurityManager, + view_menu_name: str | None, +) -> ViewMenu: + """ + Similar to security_manager.add_view_menu, but without commit. + + This ensures an atomic operation. + """ + if view_menu := security_manager.find_view_menu(view_menu_name): + return view_menu + + view_menu = security_manager.viewmenu_model() + view_menu.name = view_menu_name + session.add(view_menu) + return view_menu + + +def add_perm( + session: Session, + security_manager: SupersetSecurityManager, + permission_name: str | None, +) -> Permission: + """ + Similar to security_manager.add_permission, but without commit. + + This ensures an atomic operation. + """ + if perm := security_manager.find_permission(permission_name): + return perm + + perm = security_manager.permission_model() + perm.name = permission_name + session.add(perm) + return perm + + +def add_pvm( + session: Session, + security_manager: SupersetSecurityManager, + permission_name: str | None, + view_menu_name: str | None, +) -> PermissionView | None: + """ + Similar to security_manager.add_permission_view_menu, but without commit. + + This ensures an atomic operation. 
+ """ + if not (permission_name and view_menu_name): + return None + + if pv := security_manager.find_permission_view_menu( + permission_name, view_menu_name + ): + return pv + + vm = add_vm(session, security_manager, view_menu_name) + perm = add_perm(session, security_manager, permission_name) + pv = security_manager.permissionview_model() + pv.view_menu, pv.permission = vm, perm + session.add(pv) + + return pv diff --git a/superset/config.py b/superset/config.py index 6d71ccd3dfd..0ce9e3017c5 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1917,6 +1917,15 @@ EXTRA_DYNAMIC_QUERY_FILTERS: ExtraDynamicQueryFilters = {} CATALOGS_SIMPLIFIED_MIGRATION: bool = False +# When updating a DB connection or manually triggering a perm sync, the command +# happens in sync mode. If you have a celery worker configured, it's recommended +# to change below config to ``True`` to run this process in async mode. A DB +# connection might have hundreds of catalogs with thousands of schemas each, which +# considerably increases the time to process it. Running it in async mode prevents +# keeping a web API call open for this long. 
+SYNC_DB_PERMISSIONS_IN_ASYNC_MODE: bool = False + + # ------------------------------------------------------------------- # * WARNING: STOP EDITING HERE * # ------------------------------------------------------------------- diff --git a/superset/constants.py b/superset/constants.py index 3374b2bd90b..b55048463ea 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -173,6 +173,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = { "slack_channels": "write", "put_filters": "write", "put_colors": "write", + "sync_permissions": "write", } EXTRA_FORM_DATA_APPEND_KEYS = { diff --git a/superset/databases/api.py b/superset/databases/api.py index 22433612a8d..96b6f190b9f 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -52,6 +52,7 @@ from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelDeleteFailedError, SSHTunnelingNotEnabledError, ) +from superset.commands.database.sync_permissions import SyncPermissionsCommand from superset.commands.database.tables import TablesDatabaseCommand from superset.commands.database.test_connection import TestConnectionDatabaseCommand from superset.commands.database.update import UpdateDatabaseCommand @@ -120,7 +121,11 @@ from superset.models.core import Database from superset.sql_parse import Table from superset.superset_typing import FlaskResponse from superset.utils import json -from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item +from superset.utils.core import ( + error_msg_from_exception, + get_username, + parse_js_uri_path_item, +) from superset.utils.decorators import transaction from superset.utils.oauth2 import decode_oauth2_state from superset.utils.ssh_tunnel import mask_password_info @@ -165,6 +170,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "upload_metadata", "upload", "oauth2", + "sync_permissions", } resource_name = "database" @@ -613,6 +619,53 @@ class DatabaseRestApi(BaseSupersetModelRestApi): ) return self.response_422(message=str(ex)) + 
@expose("//sync_permissions/", methods=("POST",)) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".sync-permissions", + log_to_statsd=False, + ) + def sync_permissions(self, pk: int, **kwargs: Any) -> FlaskResponse: + """Sync all permissions for a database connection. + --- + post: + summary: Re-sync all permissions for a database connection + parameters: + - in: path + schema: + type: integer + name: pk + description: The database connection ID + responses: + 200: + description: Task created to sync permissions. + content: + application/json: + schema: + type: object + properties: + message: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + current_username = get_username() + SyncPermissionsCommand( + pk, + current_username, + ).run() + if app.config["SYNC_DB_PERMISSIONS_IN_ASYNC_MODE"]: + return self.response(202, message="Async task created to sync permissions") + return self.response(200, message="Permissions successfully synced") + @expose("//catalogs/") @protect() @rison(database_catalogs_query_schema) diff --git a/superset/views/base.py b/superset/views/base.py index bc1b5720895..fe4ec0e0ab8 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -107,6 +107,7 @@ FRONTEND_CONF_KEYS = ( "PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET", "JWT_ACCESS_CSRF_COOKIE_NAME", "SQLLAB_QUERY_RESULT_TIMEOUT", + "SYNC_DB_PERMISSIONS_IN_ASYNC_MODE", ) logger = logging.getLogger(__name__) diff --git a/tests/conftest.py b/tests/conftest.py index 7becf7d7ae2..09a16f1871f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,8 @@ # under the License. 
from __future__ import annotations -from typing import Callable, TYPE_CHECKING +import functools +from typing import Any, Callable, TYPE_CHECKING from unittest.mock import MagicMock, Mock, PropertyMock from pytest import fixture # noqa: PT013 @@ -106,3 +107,35 @@ def data_loader( return PandasDataLoader( example_db_engine, pandas_loader_configuration, table_to_df_convertor ) + + +def with_config(override_config: dict[str, Any]): + """ + Use this decorator to mock specific config keys. + + Usage: + + class TestYourFeature(SupersetTestCase): + + @with_config({"SOME_CONFIG": True}) + def test_your_config(self): + self.assertEqual(current_app.config["SOME_CONFIG"], True) + + """ + + def decorate(test_fn): + config_backup = {} + + def wrapper(*args, **kwargs): + from flask import current_app + + for key, value in override_config.items(): + config_backup[key] = current_app.config[key] + current_app.config[key] = value + test_fn(*args, **kwargs) + for key, value in config_backup.items(): + current_app.config[key] = value + + return functools.update_wrapper(wrapper, test_fn) + + return decorate diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py index 6c10b7cf26f..37cb6781997 100644 --- a/tests/integration_tests/base_api_tests.py +++ b/tests/integration_tests/base_api_tests.py @@ -33,8 +33,8 @@ from superset.models.dashboard import Dashboard from superset.views.base_api import BaseSupersetModelRestApi, requires_json # noqa: F401 from superset.utils import json +from tests.conftest import with_config from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.conftest import with_config from tests.integration_tests.constants import ADMIN_USERNAME diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 5a2af7f04da..4f6ce10b0f9 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -256,38 +256,6 @@ def
with_feature_flags(**mock_feature_flags): return decorate -def with_config(override_config: dict[str, Any]): - """ - Use this decorator to mock specific config keys. - - Usage: - - class TestYourFeature(SupersetTestCase): - - @with_config({"SOME_CONFIG": True}) - def test_your_config(self): - self.assertEqual(curren_app.config["SOME_CONFIG"), True) - - """ - - def decorate(test_fn): - config_backup = {} - - def wrapper(*args, **kwargs): - from flask import current_app - - for key, value in override_config.items(): - config_backup[key] = current_app.config[key] - current_app.config[key] = value - test_fn(*args, **kwargs) - for key, value in config_backup.items(): - current_app.config[key] = value - - return functools.update_wrapper(wrapper, test_fn) - - return decorate - - @pytest.fixture def virtual_dataset(): from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index a2f7de1909d..7db82ee95bf 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -49,8 +49,10 @@ from superset.models.core import Database, ConfigurationMethod from superset.reports.models import ReportSchedule, ReportScheduleType from superset.utils.database import get_example_database, get_main_database from superset.utils import json +from tests.conftest import with_config from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_USERNAME +from tests.integration_tests.conftest import with_feature_flags from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 load_birth_names_data, # noqa: F401 @@ -437,6 +439,9 @@ class TestDatabaseApi(SupersetTestCase): == "A database port is required when connecting via SSH Tunnel." 
) + @mock.patch( + "superset.commands.database.sync_permissions.SyncPermissionsCommand.run", + ) @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @@ -451,6 +456,7 @@ class TestDatabaseApi(SupersetTestCase): mock_update_is_feature_enabled, mock_create_is_feature_enabled, mock_test_connection_database_command_run, + mock_sync_perms_command, ): """ Database API: Test update Database with SSH Tunnel @@ -498,6 +504,9 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.commands.database.sync_permissions.SyncPermissionsCommand.run", + ) @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @@ -512,6 +521,7 @@ class TestDatabaseApi(SupersetTestCase): mock_update_is_feature_enabled, mock_create_is_feature_enabled, mock_test_connection_database_command_run, + mock_sync_perms_cmmd_run, ): """ Database API: Test update Database with SSH Tunnel @@ -626,6 +636,9 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.commands.database.sync_permissions.SyncPermissionsCommand.run", + ) @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @@ -642,6 +655,7 @@ class TestDatabaseApi(SupersetTestCase): mock_update_is_feature_enabled, mock_create_is_feature_enabled, mock_test_connection_database_command_run, + mock_sync_perms_command, ): """ Database API: Test deleting a SSH tunnel via Database update @@ -710,6 +724,9 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.commands.database.sync_permissions.SyncPermissionsCommand.run", + ) @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @@ -724,6 +741,7 @@ class TestDatabaseApi(SupersetTestCase): mock_update_is_feature_enabled, mock_create_is_feature_enabled, 
mock_test_connection_database_command_run, + mock_sync_perms_command, ): """ Database API: Test update SSH Tunnel via Database API @@ -945,6 +963,7 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @with_feature_flags(SSH_TUNNELING=False) @mock.patch("superset.models.core.Database.get_all_catalog_names") @mock.patch("superset.models.core.Database.get_all_schema_names") def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception( @@ -2098,7 +2117,7 @@ class TestDatabaseApi(SupersetTestCase): ) assert rv.status_code == 400 - @patch("superset.utils.log.logger") + @mock.patch("superset.utils.log.logger") @mock.patch("superset.security.manager.SupersetSecurityManager.can_access_database") @mock.patch("superset.models.core.Database.get_all_table_names_in_schema") def test_database_tables_unexpected_error( @@ -2870,6 +2889,7 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(database) db.session.commit() + @with_feature_flags(SSH_TUNNELING=False) @mock.patch("superset.commands.database.importers.v1.utils.add_permissions") def test_import_database_masked_ssh_tunnel_feature_flag_disabled( self, @@ -4171,3 +4191,176 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(first_model) db.session.delete(second_model) db.session.commit() + + @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) + def test_sync_db_perms_sync(self): + """ + Database API: Test sync permissions in sync mode. 
+ """ + self.login(ADMIN_USERNAME) + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + db_conn_id = test_database.id + + uri = f"api/v1/database/{db_conn_id}/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 200 + response = json.loads(rv.data.decode("utf-8")) + assert response == {"message": "Permissions successfully synced"} + + # Cleanup + model = db.session.query(Database).get(db_conn_id) + db.session.delete(model) + db.session.commit() + + @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) + @mock.patch("superset.commands.database.sync_permissions.DatabaseDAO.find_by_id") + def test_sync_db_perms_sync_db_not_found(self, mock_find_db): + """ + Database API: Test sync permissions in sync mode when the DB connection + is not found. + """ + self.login(ADMIN_USERNAME) + mock_find_db.return_value = None + + uri = "api/v1/database/10/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 404 + + @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) + @mock.patch("superset.commands.database.sync_permissions.ping") + def test_sync_db_perms_sync_db_connection_failed(self, mock_ping): + """ + Database API: Test sync permissions in sync mode when the DB connection + is not working. 
+ """ + self.login(ADMIN_USERNAME) + mock_ping.return_value = False + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 500 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch( + "superset.commands.database.sync_permissions.sync_database_permissions_task.delay" + ) + def test_sync_db_perms_async(self, mock_task): + """ + Database API: Test sync permissions in async mode. + """ + self.login(ADMIN_USERNAME) + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + db_conn_id = test_database.id + + uri = f"api/v1/database/{db_conn_id}/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 202 + response = json.loads(rv.data.decode("utf-8")) + assert response == {"message": "Async task created to sync permissions"} + mock_task.assert_called_once_with( + test_database.id, ADMIN_USERNAME, test_database.database_name + ) + + # Cleanup + model = db.session.query(Database).get(db_conn_id) + db.session.delete(model) + db.session.commit() + + @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch("superset.commands.database.sync_permissions.DatabaseDAO.find_by_id") + def test_sync_db_perms_async_db_not_found(self, mock_find_db): + """ + Database API: Test sync permissions in async mode when the DB connection + is not found. 
+ """ + self.login(ADMIN_USERNAME) + mock_find_db.return_value = None + + uri = "api/v1/database/10/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 404 + + @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch("superset.commands.database.sync_permissions.ping") + def test_sync_db_perms_async_db_connection_failed(self, mock_ping): + """ + Database API: Test sync permissions in async mode when the DB connection + is not working. + """ + self.login(ADMIN_USERNAME) + mock_ping.return_value = False + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 500 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) + @mock.patch( + "superset.commands.database.sync_permissions.security_manager.get_user_by_username" + ) + def test_sync_db_perms_async_user_not_found(self, mock_get_user): + """ + Database API: Test sync permissions in async mode when the user to be + impersonated can't be found. + """ + self.login(ADMIN_USERNAME) + mock_get_user.return_value = False + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 500 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.commands.database.sync_permissions.SyncPermissionsCommand.run" + ) + def test_sync_db_perms_no_access(self, mock_cmmd): + """ + Database API: Test sync permissions with a user without permission to do so. 
+ """ + self.login(GAMMA_USERNAME) + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + uri = f"api/v1/database/{test_database.id}/sync_permissions/" + rv = self.client.post(uri) + assert rv.status_code == 403 + + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 4efe2dcbdc9..f6ab6b11604 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -887,7 +887,7 @@ class TestImportDatabasesCommand(SupersetTestCase): class TestTestConnectionDatabaseCommand(SupersetTestCase): - @patch("superset.daos.database.Database._get_sqla_engine") + @patch("superset.models.core.Database._get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_db_exception( @@ -908,7 +908,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): ) mock_event_logger.assert_called() - @patch("superset.daos.database.Database._get_sqla_engine") + @patch("superset.models.core.Database._get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_do_ping_exception( @@ -931,7 +931,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): == SupersetErrorType.GENERIC_DB_ENGINE_ERROR ) - @patch("superset.utils.core.timeout") + @patch("superset.commands.database.utils.timeout") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_do_ping_timeout( @@ -957,7 +957,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): == SupersetErrorType.CONNECTION_DATABASE_TIMEOUT ) - 
@patch("superset.daos.database.Database._get_sqla_engine") + @patch("superset.models.core.Database._get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_superset_security_connection( @@ -980,7 +980,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): mock_event_logger.assert_called() - @patch("superset.daos.database.Database._get_sqla_engine") + @patch("superset.models.core.Database._get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_db_api_exc( diff --git a/tests/integration_tests/security/api_tests.py b/tests/integration_tests/security/api_tests.py index 49c2064e5db..3bb85c4cad4 100644 --- a/tests/integration_tests/security/api_tests.py +++ b/tests/integration_tests/security/api_tests.py @@ -26,7 +26,7 @@ from superset.daos.dashboard import EmbeddedDashboardDAO from superset.models.dashboard import Dashboard from superset.utils.urls import get_url_host from superset.utils import json -from tests.integration_tests.conftest import with_config +from tests.conftest import with_config from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_USERNAME from tests.integration_tests.fixtures.birth_names_dashboard import ( diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 0916124a6f1..b31c4fa1367 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -33,7 +33,6 @@ from superset.utils import json from flask_babel import lazy_gettext as _ # noqa: F401 from flask_appbuilder.models.sqla import filters from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.conftest import 
with_config # noqa: F401 from tests.integration_tests.constants import ADMIN_USERNAME from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 diff --git a/tests/integration_tests/users/api_tests.py b/tests/integration_tests/users/api_tests.py index e16a957d1fb..9f7423bb13a 100644 --- a/tests/integration_tests/users/api_tests.py +++ b/tests/integration_tests/users/api_tests.py @@ -21,8 +21,9 @@ from unittest.mock import patch from superset import security_manager from superset.utils import json, slack # noqa: F401 +from tests.conftest import with_config from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.conftest import with_config, with_feature_flags +from tests.integration_tests.conftest import with_feature_flags from tests.integration_tests.constants import ADMIN_USERNAME meUri = "/api/v1/me/" # noqa: N816 diff --git a/tests/unit_tests/commands/databases/conftest.py b/tests/unit_tests/commands/databases/conftest.py new file mode 100644 index 00000000000..49da52daf91 --- /dev/null +++ b/tests/unit_tests/commands/databases/conftest.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from superset.db_engine_specs.base import BaseEngineSpec +from superset.exceptions import OAuth2RedirectError +from superset.utils import json + +oauth2_client_info = { + "id": "client_id", + "secret": "client_secret", + "scope": "scope-a", + "redirect_uri": "redirect_uri", + "authorization_request_uri": "auth_uri", + "token_request_uri": "token_uri", + "request_content_type": "json", +} + + +@pytest.fixture +def database_with_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database with catalogs and schemas. + """ + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = True + database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] + database.get_all_schema_names.side_effect = [ + ["schema1", "schema2"], + ["schema3", "schema4"], + ] + database.get_default_catalog.return_value = "catalog2" + + return database + + +@pytest.fixture +def database_without_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database without catalogs. + """ + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = False + database.get_all_schema_names.return_value = ["schema1", "schema2"] + + return database + + +@pytest.fixture +def database_needs_oauth2(mocker: MockerFixture) -> MagicMock: + """ + Mock a database without catalogs that needs OAuth2. 
+ """ + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = False + database.get_all_schema_names.side_effect = OAuth2RedirectError( + "url", + "tab_id", + "redirect_uri", + ) + database.encrypted_extra = json.dumps({"oauth2_client_info": oauth2_client_info}) + database.db_engine_spec.unmask_encrypted_extra = ( + BaseEngineSpec.unmask_encrypted_extra + ) + + return database diff --git a/tests/unit_tests/commands/databases/sync_permissions_test.py b/tests/unit_tests/commands/databases/sync_permissions_test.py new file mode 100644 index 00000000000..10d0723f504 --- /dev/null +++ b/tests/unit_tests/commands/databases/sync_permissions_test.py @@ -0,0 +1,389 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from superset import db +from superset.commands.database.exceptions import ( + DatabaseConnectionFailedError, + DatabaseNotFoundError, + UserNotFoundInSessionError, +) +from superset.commands.database.sync_permissions import SyncPermissionsCommand +from superset.db_engine_specs.base import GenericDBException +from superset.exceptions import OAuth2RedirectError +from superset.extensions import security_manager +from tests.conftest import with_config + + +@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) +def test_sync_permissions_command_sync_mode( + mocker: MockerFixture, + database_with_catalog: MagicMock, +): + """ + Test ``SyncPermissionsCommand`` in sync mode. + """ + mock_ssh = mocker.MagicMock() + user_mock = mocker.patch( + "superset.commands.database.sync_permissions.security_manager.get_user_by_username" + ) + mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True) + find_pvm_mock = mocker.patch( + "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu" + ) + find_pvm_mock.side_effect = [mocker.MagicMock(), None] + add_pvm_mock = mocker.patch("superset.commands.database.sync_permissions.add_pvm") + + cmmd = SyncPermissionsCommand( + 1, "admin", db_connection=database_with_catalog, ssh_tunnel=mock_ssh + ) + mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas") + mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions") + + cmmd.run() + + assert cmmd.db_connection == database_with_catalog + assert cmmd.old_db_connection_name == "my_db" + assert cmmd.db_connection_ssh_tunnel == mock_ssh + user_mock.assert_called_once_with("admin") + add_pvm_mock.assert_has_calls( + [ + mocker.call( + db.session, security_manager, "catalog_access", "[my_db].[catalog2]" + ), + mocker.call( + db.session, + security_manager, + 
"schema_access", + "[my_db].[catalog2].[schema3]", + ), + mocker.call( + db.session, + security_manager, + "schema_access", + "[my_db].[catalog2].[schema4]", + ), + ] + ) + mock_refresh_schemas.assert_called_once_with("catalog1", ["schema1", "schema2"]) + mock_rename_db_perm.assert_not_called() + + +@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) +def test_sync_permissions_command_async_mode( + mocker: MockerFixture, database_with_catalog: MagicMock +) -> None: + """ + Test ``SyncPermissionsCommand`` in async mode. + """ + mock_database_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + mock_database_dao.find_by_id.return_value = database_with_catalog + mocker.patch( + "superset.commands.database.sync_permissions.security_manager.get_user_by_username" + ) + async_task_mock = mocker.patch( + "superset.commands.database.sync_permissions.sync_database_permissions_task" + ) + mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True) + + cmmd = SyncPermissionsCommand(1, "admin") + cmmd.run() + async_task_mock.delay.assert_called_once_with(1, "admin", "my_db") + + +@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) +def test_sync_permissions_command_passing_all_values( + mocker: MockerFixture, database_with_catalog: MagicMock +): + """ + Test ``SyncPermissionsCommand`` when providing all arguments to the constructor. 
+ """ + mock_ssh = mocker.MagicMock() + mock_database_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + mocker.patch( + "superset.commands.database.sync_permissions.security_manager.get_user_by_username" + ) + mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True) + + cmmd = SyncPermissionsCommand( + 1, + "admin", + old_db_connection_name="old name", + db_connection=database_with_catalog, + ssh_tunnel=mock_ssh, + ) + mocker.patch.object(cmmd, "sync_database_permissions") + cmmd.run() + + assert cmmd.db_connection == database_with_catalog + assert cmmd.old_db_connection_name == "old name" + assert cmmd.db_connection_ssh_tunnel == mock_ssh + mock_database_dao.find_by_id.assert_not_called() + mock_database_dao.get_ssh_tunnel.assert_not_called() + + +@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) +def test_sync_permissions_command_raise(mocker: MockerFixture): + """ + Test ``SyncPermissionsCommand`` when an exception is raised. 
+ """ + mock_database_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + mock_database_dao.find_by_id.return_value = mocker.MagicMock() + mock_database_dao.get_ssh_tunnel.return_value = mocker.MagicMock() + mock_user = mocker.patch( + "superset.commands.database.sync_permissions.security_manager.get_user_by_username" + ) + + # Connection issues + mock_ping = mocker.patch( + "superset.commands.database.sync_permissions.ping", return_value=False + ) + with pytest.raises(DatabaseConnectionFailedError): + SyncPermissionsCommand(1, "admin").run() + mock_ping.reset_mock() + mock_ping.side_effect = Exception + with pytest.raises(DatabaseConnectionFailedError): + SyncPermissionsCommand(1, "admin").run() + + # User not found in session + mock_user.reset_mock() + mock_user.return_value = None + with pytest.raises(UserNotFoundInSessionError): + SyncPermissionsCommand(1, "admin").run() + mock_user.reset_mock() + mock_user.return_value = mocker.MagicMock() + + # DB connection not found + mock_database_dao.reset_mock() + mock_database_dao.find_by_id.return_value = None + with pytest.raises(DatabaseNotFoundError): + SyncPermissionsCommand(1, "admin").run() + + +@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) +def test_sync_permissions_command_new_db_name( + mocker: MockerFixture, database_with_catalog: MagicMock +): + """ + Test ``SyncPermissionsCommand`` when the database name changed. 
+ """ + mocker.patch( + "superset.commands.database.sync_permissions.security_manager.get_user_by_username" + ) + cmmd = SyncPermissionsCommand( + 1, + "admin", + old_db_connection_name="Old Name", + db_connection=database_with_catalog, + ) + cmmd.run() + + assert cmmd.old_db_connection_name == "Old Name" + + +@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) +def test_sync_permissions_command_async_mode_new_db_name( + mocker: MockerFixture, database_with_catalog: MagicMock +): + """ + Test ``SyncPermissionsCommand`` in async mode when the + database name changed. + """ + mocker.patch( + "superset.commands.database.sync_permissions.security_manager.get_user_by_username" + ) + async_task_mock = mocker.patch( + "superset.commands.database.sync_permissions.sync_database_permissions_task" + ) + cmmd = SyncPermissionsCommand( + 1, + "admin", + old_db_connection_name="Old Name", + db_connection=database_with_catalog, + ) + cmmd.run() + + async_task_mock.delay.assert_called_once_with(1, "admin", "Old Name") + + +def test_resync_permissions_command_get_catalogs(database_with_catalog: MagicMock): + """ + Test the ``_get_catalog_names`` method. + """ + cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) + assert cmmd._get_catalog_names() == ["catalog1", "catalog2"] + + +@pytest.mark.parametrize( + ("inner_exception, outer_exception"), + [ + ( + OAuth2RedirectError("Missing token", "mock_tab", "mock_url"), + OAuth2RedirectError, + ), + (GenericDBException, DatabaseConnectionFailedError), + ], +) +def test_resync_permissions_command_raise_on_getting_catalogs( + inner_exception: Exception, + outer_exception: Exception, + database_with_catalog: MagicMock, +): + """ + Test the ``_get_catalog_names`` method when raising an exception. 
+ """ + database_with_catalog.get_all_catalog_names.side_effect = inner_exception + cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) + with pytest.raises(outer_exception): + cmmd._get_catalog_names() + + +def test_resync_permissions_command_get_schemas(database_with_catalog: MagicMock): + """ + Test the ``_get_schema_names`` method. + """ + cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) + assert cmmd._get_schema_names("catalog1") == ["schema1", "schema2"] + assert cmmd._get_schema_names("catalog2") == ["schema3", "schema4"] + + +@pytest.mark.parametrize( + ("inner_exception, outer_exception"), + [ + ( + OAuth2RedirectError("Missing token", "mock_tab", "mock_url"), + OAuth2RedirectError, + ), + (GenericDBException, DatabaseConnectionFailedError), + ], +) +def test_resync_permissions_command_raise_on_getting_schemas( + inner_exception: Exception, + outer_exception: Exception, + database_with_catalog: MagicMock, +): + """ + Test the ``_get_schema_names`` method when raising an exception. + """ + database_with_catalog.get_all_schema_names.side_effect = inner_exception + cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) + with pytest.raises(outer_exception): + cmmd._get_schema_names("blah") + + +def test_resync_permissions_command_refresh_schemas( + mocker: MockerFixture, database_with_catalog: MagicMock +): + """ + Test the ``_refresh_schemas`` method. 
+ """ + find_pvm_mock = mocker.patch( + "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu" + ) + find_pvm_mock.side_effect = [mocker.MagicMock(), None] + add_pvm_mock = mocker.patch("superset.commands.database.sync_permissions.add_pvm") + + cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) + cmmd._refresh_schemas("catalog1", ["schema1", "schema2"]) + + add_pvm_mock.assert_called_once_with( + db.session, + security_manager, + "schema_access", + f"[{database_with_catalog.name}].[catalog1].[schema2]", + ) + + +def test_resync_permissions_command_rename_db_in_perms( + mocker: MockerFixture, database_with_catalog: MagicMock +): + """ + Test the ``_rename_database_in_permissions`` method. + """ + find_pvm_mock = mocker.patch( + "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu" + ) + mock_catalog_perm = mocker.MagicMock() + mock_catalog_perm.view_menu.name = "[old_name].[catalog]" + mock_schema_perm = mocker.MagicMock() + mock_schema_perm.view_menu.name = "[old_name].[catalog].[schema1]" + find_pvm_mock.side_effect = [ + mock_catalog_perm, + mock_schema_perm, + None, + ] + + mock_dataset = mocker.MagicMock() + mock_dataset.id = 1 + mock_dataset.catalog_perm = "[old_name].[catalog1]" + mock_dataset.schema_perm = "[old_name].[catalog1].[schema1]" + mock_chart = mocker.MagicMock() + mock_chart.catalog_perm = "[old_name].[catalog1]" + mock_chart.schema_perm = "[old_name].[catalog1].[schema1]" + + mock_database_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + mock_database_dao.get_datasets.side_effect = [ + [mock_dataset], + [], + ] + mock_dataset_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatasetDAO" + ) + mock_dataset_dao.get_related_objects.return_value = {"charts": [mock_chart]} + + cmmd = SyncPermissionsCommand( + 1, None, old_db_connection_name="old_name", db_connection=database_with_catalog + ) + 
cmmd._rename_database_in_permissions("catalog1", ["schema1", "schema2"]) + + find_pvm_mock.assert_has_calls( + [ + mocker.call("catalog_access", "[old_name].[catalog1]"), + mocker.call("schema_access", "[old_name].[catalog1].[schema1]"), + mocker.call("schema_access", "[old_name].[catalog1].[schema2]"), + ] + ) + + assert ( + mock_catalog_perm.view_menu.name == f"[{database_with_catalog.name}].[catalog1]" + ) + assert ( + mock_schema_perm.view_menu.name + == f"[{database_with_catalog.name}].[catalog1].[schema1]" + ) + assert mock_dataset.catalog_perm == f"[{database_with_catalog.name}].[catalog1]" + assert ( + mock_dataset.schema_perm + == f"[{database_with_catalog.name}].[catalog1].[schema1]" + ) + assert mock_chart.catalog_perm == f"[{database_with_catalog.name}].[catalog1]" + assert ( + mock_chart.schema_perm == f"[{database_with_catalog.name}].[catalog1].[schema1]" + ) diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index daf41b75068..2d25953714e 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -17,84 +17,19 @@ from unittest.mock import MagicMock -import pytest from pytest_mock import MockerFixture +from superset import db from superset.commands.database.update import UpdateDatabaseCommand -from superset.db_engine_specs.base import BaseEngineSpec -from superset.exceptions import OAuth2RedirectError from superset.extensions import security_manager from superset.utils import json - -oauth2_client_info = { - "id": "client_id", - "secret": "client_secret", - "scope": "scope-a", - "redirect_uri": "redirect_uri", - "authorization_request_uri": "auth_uri", - "token_request_uri": "token_uri", - "request_content_type": "json", -} - - -@pytest.fixture -def database_with_catalog(mocker: MockerFixture) -> MagicMock: - """ - Mock a database with catalogs and schemas. 
- """ - database = mocker.MagicMock() - database.database_name = "my_db" - database.db_engine_spec.__name__ = "test_engine" - database.db_engine_spec.supports_catalog = True - database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] - database.get_all_schema_names.side_effect = [ - ["schema1", "schema2"], - ["schema3", "schema4"], - ] - database.get_default_catalog.return_value = "catalog2" - - return database - - -@pytest.fixture -def database_without_catalog(mocker: MockerFixture) -> MagicMock: - """ - Mock a database without catalogs. - """ - database = mocker.MagicMock() - database.database_name = "my_db" - database.db_engine_spec.__name__ = "test_engine" - database.db_engine_spec.supports_catalog = False - database.get_all_schema_names.return_value = ["schema1", "schema2"] - - return database - - -@pytest.fixture -def database_needs_oauth2(mocker: MockerFixture) -> MagicMock: - """ - Mock a database without catalogs that needs OAuth2. - """ - database = mocker.MagicMock() - database.database_name = "my_db" - database.db_engine_spec.__name__ = "test_engine" - database.db_engine_spec.supports_catalog = False - database.get_all_schema_names.side_effect = OAuth2RedirectError( - "url", - "tab_id", - "redirect_uri", - ) - database.encrypted_extra = json.dumps({"oauth2_client_info": oauth2_client_info}) - database.db_engine_spec.unmask_encrypted_extra = ( - BaseEngineSpec.unmask_encrypted_extra - ) - - return database +from tests.conftest import with_config +from tests.unit_tests.commands.databases.conftest import oauth2_client_info def test_update_with_catalog( mocker: MockerFixture, - database_with_catalog: MockerFixture, + database_with_catalog: MagicMock, ) -> None: """ Test that permissions are updated correctly. @@ -111,9 +46,15 @@ def test_update_with_catalog( When update is called, only `catalog2.schema3` has permissions associated with it, so `catalog1.*` and `catalog2.schema4` are added. 
""" - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 - DatabaseDAO.find_by_id.return_value = database_with_catalog - DatabaseDAO.update.return_value = database_with_catalog + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database_with_catalog + database_dao.update.return_value = database_with_catalog + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_with_catalog + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, @@ -128,25 +69,66 @@ def test_update_with_catalog( None, None, ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") UpdateDatabaseCommand(1, {}).run() - add_permission_view_menu.assert_has_calls( + add_pvm.assert_has_calls( [ # first catalog is added with all schemas - mocker.call("catalog_access", "[my_db].[catalog1]"), - mocker.call("schema_access", "[my_db].[catalog1].[schema1]"), - mocker.call("schema_access", "[my_db].[catalog1].[schema2]"), + mocker.call( + db.session, security_manager, "catalog_access", "[my_db].[catalog1]" + ), + mocker.call( + db.session, + security_manager, + "schema_access", + "[my_db].[catalog1].[schema1]", + ), + mocker.call( + db.session, + security_manager, + "schema_access", + "[my_db].[catalog1].[schema2]", + ), # second catalog already exists, only `schema4` is added - mocker.call("schema_access", "[my_db].[catalog2].[schema4]"), + mocker.call( + db.session, + security_manager, + "schema_access", + f"[{database_with_catalog.name}].[catalog2].[schema4]", + ), ], ) +@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True}) +def 
test_update_sync_perms_in_async_mode(
+    mocker: MockerFixture,
+    database_with_catalog: MagicMock,
+) -> None:
+    """
+    Test that updating a DB connection with async mode enabled
+    triggers the celery task to sync perms.
+    """
+    database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO")
+    database_dao.find_by_id.return_value = database_with_catalog
+    database_dao.update.return_value = database_with_catalog
+    sync_db_perms_dao = mocker.patch(
+        "superset.commands.database.sync_permissions.DatabaseDAO"
+    )
+    sync_db_perms_dao.find_by_id.return_value = database_with_catalog
+    sync_task = mocker.patch(
+        "superset.commands.database.sync_permissions.sync_database_permissions_task.delay"
+    )
+    mocker.patch("superset.commands.database.update.get_username", return_value="admin")
+    mocker.patch("superset.security_manager.get_user_by_username")
+
+    UpdateDatabaseCommand(1, {}).run()
+
+    sync_task.assert_called_once_with(1, "admin", "my_db")
+
+
 def test_update_without_catalog(
     mocker: MockerFixture,
     database_without_catalog: MockerFixture,
@@ -162,9 +144,15 @@ def test_update_without_catalog(
     When update is called, only `schema2` has permissions associated with it, so
     `schema1` is added.
""" # noqa: E501 - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 - DatabaseDAO.find_by_id.return_value = database_without_catalog - DatabaseDAO.update.return_value = database_without_catalog + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database_without_catalog + database_dao.update.return_value = database_without_catalog + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_without_catalog + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, @@ -174,22 +162,21 @@ def test_update_without_catalog( None, # schema1 has no permissions "[my_db].[schema2]", # second schema already exists ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") UpdateDatabaseCommand(1, {}).run() - add_permission_view_menu.assert_called_with( + add_pvm.assert_called_with( + db.session, + security_manager, "schema_access", - "[my_db].[schema1]", + f"[{database_without_catalog.name}].[schema1]", ) def test_rename_with_catalog( mocker: MockerFixture, - database_with_catalog: MockerFixture, + database_with_catalog: MagicMock, ) -> None: """ Test that permissions are renamed correctly. @@ -207,25 +194,33 @@ def test_rename_with_catalog( so `catalog1.*` and `catalog2.schema4` are added. Additionally, the database has been renamed from `my_db` to `my_other_db`. 
""" - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 original_database = mocker.MagicMock() original_database.database_name = "my_db" - DatabaseDAO.find_by_id.return_value = original_database + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = original_database database_with_catalog.database_name = "my_other_db" - DatabaseDAO.update.return_value = database_with_catalog + database_dao.update.return_value = database_with_catalog + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_with_catalog + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") dataset = mocker.MagicMock() chart = mocker.MagicMock() - DatabaseDAO.get_datasets.return_value = [dataset] - DatasetDAO = mocker.patch("superset.commands.database.update.DatasetDAO") # noqa: N806 - DatasetDAO.get_related_objects.return_value = {"charts": [chart]} + sync_db_perms_dao.get_datasets.return_value = [dataset] + dataset_dao = mocker.patch("superset.commands.database.sync_permissions.DatasetDAO") + dataset_dao.get_related_objects.return_value = {"charts": [chart]} find_permission_view_menu = mocker.patch.object( security_manager, "find_permission_view_menu", ) catalog2_pvm = mocker.MagicMock() + catalog2_pvm.view_menu.name = "[my_db].[catalog2]" catalog2_schema3_pvm = mocker.MagicMock() + catalog2_schema3_pvm.view_menu.name = "[my_db].[catalog2].[schema3]" find_permission_view_menu.side_effect = [ # these are called when adding the permissions: None, # first catalog is new @@ -237,31 +232,52 @@ def test_rename_with_catalog( catalog2_schema3_pvm, # old [my_db].[catalog2].[schema3] None, # [my_db].[catalog2].[schema4] doesn't exist ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - 
) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") + add_vm = mocker.patch("superset.commands.database.sync_permissions.add_vm") UpdateDatabaseCommand(1, {}).run() - add_permission_view_menu.assert_has_calls( + add_pvm.assert_has_calls( [ # first catalog is added with all schemas with the new DB name - mocker.call("catalog_access", "[my_other_db].[catalog1]"), - mocker.call("schema_access", "[my_other_db].[catalog1].[schema1]"), - mocker.call("schema_access", "[my_other_db].[catalog1].[schema2]"), + mocker.call( + db.session, + security_manager, + "catalog_access", + "[my_other_db].[catalog1]", + ), + mocker.call( + db.session, + security_manager, + "schema_access", + "[my_other_db].[catalog1].[schema1]", + ), + mocker.call( + db.session, + security_manager, + "schema_access", + "[my_other_db].[catalog1].[schema2]", + ), # second catalog already exists, only `schema4` is added - mocker.call("schema_access", "[my_other_db].[catalog2].[schema4]"), + mocker.call( + db.session, + security_manager, + "schema_access", + f"[{database_with_catalog.name}].[catalog2].[schema4]", + ), ], ) - assert catalog2_pvm.view_menu.name == "[my_other_db].[catalog2]" - assert catalog2_schema3_pvm.view_menu.name == "[my_other_db].[catalog2].[schema3]" + assert catalog2_pvm.view_menu == add_vm.return_value + assert ( + catalog2_schema3_pvm.view_menu.name + == f"[{database_with_catalog.name}].[catalog2].[schema3]" + ) - assert dataset.catalog_perm == "[my_other_db].[catalog2]" - assert dataset.schema_perm == "[my_other_db].[catalog2].[schema4]" - assert chart.catalog_perm == "[my_other_db].[catalog2]" - assert chart.schema_perm == "[my_other_db].[catalog2].[schema4]" + assert dataset.catalog_perm == f"[{database_with_catalog.name}].[catalog2]" + assert dataset.schema_perm == f"[{database_with_catalog.name}].[catalog2].[schema4]" + assert chart.catalog_perm == f"[{database_with_catalog.name}].[catalog2]" + assert chart.schema_perm == 
f"[{database_with_catalog.name}].[catalog2].[schema4]" def test_rename_without_catalog( @@ -279,38 +295,44 @@ def test_rename_without_catalog( When update is called, only `schema2` has permissions associated with it, so `schema1` is added. Additionally, the database has been renamed from `my_db` to `my_other_db`. """ # noqa: E501 - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") original_database = mocker.MagicMock() original_database.database_name = "my_db" - DatabaseDAO.find_by_id.return_value = original_database database_without_catalog.database_name = "my_other_db" - DatabaseDAO.update.return_value = database_without_catalog - DatabaseDAO.get_datasets.return_value = [] + database_dao.update.return_value = database_without_catalog + database_dao.find_by_id.return_value = original_database + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_without_catalog + sync_db_perms_dao.get_datasets.return_value = [] + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, "find_permission_view_menu", ) schema2_pvm = mocker.MagicMock() + schema2_pvm.view_menu.name = "[my_db].[schema2]" find_permission_view_menu.side_effect = [ None, # schema1 has no permissions "[my_db].[schema2]", # second schema already exists None, # [my_db].[schema1] doesn't exist schema2_pvm, # old [my_db].[schema2] ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") UpdateDatabaseCommand(1, {}).run() - add_permission_view_menu.assert_called_with( + add_pvm.assert_called_with( + db.session, + security_manager, 
"schema_access", - "[my_other_db].[schema1]", + f"[{database_without_catalog.name}].[schema1]", ) - assert schema2_pvm.view_menu.name == "[my_other_db].[schema2]" + assert schema2_pvm.view_menu.name == f"[{database_without_catalog.name}].[schema2]" def test_update_with_oauth2( @@ -320,9 +342,15 @@ def test_update_with_oauth2( """ Test that the database can be updated even if OAuth2 is needed to connect. """ - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 - DatabaseDAO.find_by_id.return_value = database_needs_oauth2 - DatabaseDAO.update.return_value = database_needs_oauth2 + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database_needs_oauth2 + database_dao.update.return_value = database_needs_oauth2 + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_needs_oauth2 + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, @@ -332,14 +360,11 @@ def test_update_with_oauth2( None, # schema1 has no permissions "[my_db].[schema2]", # second schema already exists ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") UpdateDatabaseCommand(1, {}).run() - add_permission_view_menu.assert_not_called() + add_pvm.assert_not_called() database_needs_oauth2.purge_oauth2_tokens.assert_not_called() @@ -350,9 +375,15 @@ def test_update_with_oauth2_changed( """ Test that the database can be updated even if OAuth2 is needed to connect. 
""" - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 - DatabaseDAO.find_by_id.return_value = database_needs_oauth2 - DatabaseDAO.update.return_value = database_needs_oauth2 + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database_needs_oauth2 + database_dao.update.return_value = database_needs_oauth2 + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_needs_oauth2 + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, @@ -362,10 +393,7 @@ def test_update_with_oauth2_changed( None, # schema1 has no permissions "[my_db].[schema2]", # second schema already exists ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") modified_oauth2_client_info = oauth2_client_info.copy() modified_oauth2_client_info["scope"] = "scope-b" @@ -379,7 +407,7 @@ def test_update_with_oauth2_changed( }, ).run() - add_permission_view_menu.assert_not_called() + add_pvm.assert_not_called() database_needs_oauth2.purge_oauth2_tokens.assert_called() @@ -390,9 +418,15 @@ def test_remove_oauth_config_purges_tokens( """ Test that removing the OAuth config from a database purges existing tokens. 
""" - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 - DatabaseDAO.find_by_id.return_value = database_needs_oauth2 - DatabaseDAO.update.return_value = database_needs_oauth2 + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database_needs_oauth2 + database_dao.update.return_value = database_needs_oauth2 + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_needs_oauth2 + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, @@ -402,19 +436,16 @@ def test_remove_oauth_config_purges_tokens( None, "[my_db].[schema2]", ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") UpdateDatabaseCommand(1, {"masked_encrypted_extra": None}).run() - add_permission_view_menu.assert_not_called() + add_pvm.assert_not_called() database_needs_oauth2.purge_oauth2_tokens.assert_called() UpdateDatabaseCommand(1, {"masked_encrypted_extra": "{}"}).run() - add_permission_view_menu.assert_not_called() + add_pvm.assert_not_called() database_needs_oauth2.purge_oauth2_tokens.assert_called() @@ -425,9 +456,15 @@ def test_update_oauth2_removes_masked_encrypted_extra_key( """ Test that the ``masked_encrypted_extra`` key is properly purged from the properties. 
""" - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 - DatabaseDAO.find_by_id.return_value = database_needs_oauth2 - DatabaseDAO.update.return_value = database_needs_oauth2 + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database_needs_oauth2 + database_dao.update.return_value = database_needs_oauth2 + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_needs_oauth2 + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, @@ -437,10 +474,7 @@ def test_update_oauth2_removes_masked_encrypted_extra_key( None, "[my_db].[schema2]", ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") modified_oauth2_client_info = oauth2_client_info.copy() modified_oauth2_client_info["scope"] = "scope-b" @@ -454,9 +488,9 @@ def test_update_oauth2_removes_masked_encrypted_extra_key( }, ).run() - add_permission_view_menu.assert_not_called() + add_pvm.assert_not_called() database_needs_oauth2.purge_oauth2_tokens.assert_called() - DatabaseDAO.update.assert_called_with( + database_dao.update.assert_called_with( database_needs_oauth2, { "encrypted_extra": json.dumps( @@ -474,9 +508,15 @@ def test_update_other_fields_dont_affect_oauth( Test that not including ``masked_encrypted_extra`` in the payload does not touch the OAuth config. 
""" - DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 - DatabaseDAO.find_by_id.return_value = database_needs_oauth2 - DatabaseDAO.update.return_value = database_needs_oauth2 + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database_needs_oauth2 + database_dao.update.return_value = database_needs_oauth2 + sync_db_perms_dao = mocker.patch( + "superset.commands.database.sync_permissions.DatabaseDAO" + ) + sync_db_perms_dao.find_by_id.return_value = database_needs_oauth2 + mocker.patch("superset.commands.database.update.get_username") + mocker.patch("superset.security_manager.get_user_by_username") find_permission_view_menu = mocker.patch.object( security_manager, @@ -486,12 +526,9 @@ def test_update_other_fields_dont_affect_oauth( None, "[my_db].[schema2]", ] - add_permission_view_menu = mocker.patch.object( - security_manager, - "add_permission_view_menu", - ) + add_pvm = mocker.patch("superset.commands.database.sync_permissions.add_pvm") UpdateDatabaseCommand(1, {"database_name": "New DB name"}).run() - add_permission_view_menu.assert_not_called() + add_pvm.assert_not_called() database_needs_oauth2.purge_oauth2_tokens.assert_not_called() diff --git a/tests/unit_tests/commands/databases/utils_test.py b/tests/unit_tests/commands/databases/utils_test.py new file mode 100644 index 00000000000..04512e47f4f --- /dev/null +++ b/tests/unit_tests/commands/databases/utils_test.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import sqlite3 +from unittest.mock import MagicMock + +import pytest +from flask_appbuilder.security.sqla.models import ( + Permission, + PermissionView, + ViewMenu, +) +from pytest_mock import MockerFixture +from sqlalchemy.orm import Session + +from superset.commands.database.utils import ( + add_perm, + add_pvm, + add_vm, + ping, +) +from tests.conftest import with_config + + +@pytest.fixture +def mock_engine(mocker: MockerFixture) -> tuple[MagicMock, MagicMock, MagicMock]: + mock_connection = mocker.MagicMock() + mock_engine = mocker.MagicMock() + mock_dialect = mocker.MagicMock() + mock_engine.raw_connection.return_value = mock_connection + mock_engine.dialect = mock_dialect + return mock_engine, mock_connection, mock_dialect + + +@with_config({"TEST_DATABASE_CONNECTION_TIMEOUT": datetime.timedelta(seconds=10)}) +def test_ping_success(mock_engine: MockerFixture): + """ + Test the ``ping`` method. + """ + mock_engine, mock_connection, mock_dialect = mock_engine + mock_dialect.do_ping.return_value = True + + result = ping(mock_engine) + + assert result is True + + mock_engine.raw_connection.assert_called_once() + mock_dialect.do_ping.assert_called_once_with(mock_connection) + + +@with_config({"TEST_DATABASE_CONNECTION_TIMEOUT": datetime.timedelta(seconds=10)}) +def test_ping_sqlite_exception(mocker: MockerFixture, mock_engine: MockerFixture): + """ + Test the ``ping`` method when a sqlite3.ProgrammingError is raised. 
+    """
+    mock_engine, mock_connection, mock_dialect = mock_engine
+    mock_dialect.do_ping.side_effect = [sqlite3.ProgrammingError, True]
+
+    result = ping(mock_engine)
+
+    assert result is True
+
+    mock_dialect.do_ping.assert_has_calls(
+        [mocker.call(mock_connection), mocker.call(mock_engine)]
+    )
+
+
+def test_ping_runtime_exception(mocker: MockerFixture, mock_engine: MockerFixture):
+    """
+    Test the ``ping`` method when a RuntimeError is raised.
+    """
+    mock_engine, _, mock_dialect = mock_engine
+    mock_timeout = mocker.patch("superset.commands.database.utils.timeout")
+    mock_timeout.side_effect = RuntimeError("timeout")
+    mock_dialect.do_ping.return_value = True
+
+    result = ping(mock_engine)
+
+    assert result is True
+    mock_dialect.do_ping.assert_called_once_with(mock_engine)
+
+
+@pytest.fixture
+def db_session(mocker: MockerFixture) -> Session:
+    return mocker.MagicMock(spec=Session)
+
+
+def test_add_vm(db_session: Session, mocker: MockerFixture):
+    """
+    Test ``add_vm`` when the ViewMenu does not exist.
+    """
+    sm = mocker.MagicMock()
+    sm.find_view_menu.return_value = None
+    sm.viewmenu_model = ViewMenu
+
+    result = add_vm(db_session, sm, "new_view_menu")
+
+    assert result.name == "new_view_menu"
+    sm.find_view_menu.assert_called_once_with("new_view_menu")
+    db_session.add.assert_called_once_with(result)
+
+
+def test_add_vm_existing(db_session: Session, mocker: MockerFixture):
+    """
+    Test ``add_vm`` when the ViewMenu already exists.
+    """
+    mock_vm = mocker.MagicMock()
+    sm = mocker.MagicMock()
+    sm.find_view_menu.return_value = mock_vm
+
+    result = add_vm(db_session, sm, "existing_view_menu")
+
+    assert result == mock_vm
+    sm.find_view_menu.assert_called_once_with("existing_view_menu")
+    db_session.add.assert_not_called()
+
+
+def test_add_perm(db_session: Session, mocker: MockerFixture):
+    """
+    Test ``add_perm`` when the Permission does not exist.
+    """
+    sm = mocker.MagicMock()
+    sm.find_permission.return_value = None
+    sm.permission_model = Permission
+
+    result = add_perm(db_session, sm, "new_perm")
+
+    assert result.name == "new_perm"
+    sm.find_permission.assert_called_once_with("new_perm")
+    db_session.add.assert_called_once_with(result)
+
+
+def test_add_perm_existing(db_session: Session, mocker: MockerFixture):
+    """
+    Test ``add_perm`` when the Permission already exists.
+    """
+    mock_perm = mocker.MagicMock()
+    sm = mocker.MagicMock()
+    sm.find_permission.return_value = mock_perm
+
+    result = add_perm(db_session, sm, "existing_perm")
+
+    assert result == mock_perm
+    sm.find_permission.assert_called_once_with("existing_perm")
+    db_session.add.assert_not_called()
+
+
+def test_add_pvm(db_session: Session, mocker: MockerFixture):
+    """
+    Test ``add_pvm`` when the PermissionView does not exist.
+    """
+    sm = mocker.MagicMock()
+    sm.find_permission_view_menu.return_value = None
+    sm.permissionview_model = PermissionView
+
+    mock_vm = mocker.MagicMock()
+    mock_perm = mocker.MagicMock()
+    mock_add_vm = mocker.patch("superset.commands.database.utils.add_vm")
+    mock_add_vm.return_value = mock_vm
+    mock_add_perm = mocker.patch("superset.commands.database.utils.add_perm")
+    mock_add_perm.return_value = mock_perm
+
+    result = add_pvm(db_session, sm, "new_perm", "new_view_menu")
+
+    assert result is not None
+    assert result.view_menu == mock_vm
+    assert result.permission == mock_perm
+    sm.find_permission_view_menu.assert_called_once_with("new_perm", "new_view_menu")
+    mock_add_vm.assert_called_once_with(db_session, sm, "new_view_menu")
+    mock_add_perm.assert_called_once_with(db_session, sm, "new_perm")
+    db_session.add.assert_called_once_with(result)
+
+
+def test_add_pvm_missing_data(db_session: Session, mocker: MockerFixture):
+    """
+    Test ``add_pvm`` when permission_name and view_menu_name are empty.
+    """
+    sm = mocker.MagicMock()
+    result = add_pvm(db_session, sm, None, None)
+
+    assert result is None
+
+
+def test_add_pvm_existing(db_session: Session, mocker: MockerFixture):
+    """
+    Test ``add_pvm`` when the PermissionView already exists.
+    """
+    mock_pvm = mocker.MagicMock()
+    sm = mocker.MagicMock()
+    sm.find_permission_view_menu.return_value = mock_pvm
+
+    result = add_pvm(db_session, sm, "existinf_perm", "existing_vm")
+
+    assert result == mock_pvm
+    sm.find_permission_view_menu.assert_called_once_with("existinf_perm", "existing_vm")
+    db_session.add.assert_not_called()