diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py index a0def8c087a..f74e8f397a9 100644 --- a/superset/commands/database/ssh_tunnel/exceptions.py +++ b/superset/commands/database/ssh_tunnel/exceptions.py @@ -25,47 +25,53 @@ from superset.commands.exceptions import ( ) -class SSHTunnelDeleteFailedError(DeleteFailedError): +class SSHTunnelError(Exception): + """ + Base class. + """ + + +class SSHTunnelDeleteFailedError(DeleteFailedError, SSHTunnelError): message = _("SSH Tunnel could not be deleted.") -class SSHTunnelNotFoundError(CommandException): +class SSHTunnelNotFoundError(CommandException, SSHTunnelError): status = 404 message = _("SSH Tunnel not found.") -class SSHTunnelInvalidError(CommandInvalidError): +class SSHTunnelInvalidError(CommandInvalidError, SSHTunnelError): message = _("SSH Tunnel parameters are invalid.") -class SSHTunnelDatabasePortError(CommandInvalidError): +class SSHTunnelDatabasePortError(CommandInvalidError, SSHTunnelError): message = _("A database port is required when connecting via SSH Tunnel.") -class SSHTunnelUpdateFailedError(UpdateFailedError): +class SSHTunnelUpdateFailedError(UpdateFailedError, SSHTunnelError): message = _("SSH Tunnel could not be updated.") -class SSHTunnelCreateFailedError(CommandException): +class SSHTunnelCreateFailedError(CommandException, SSHTunnelError): message = _("Creating SSH Tunnel failed for an unknown reason") -class SSHTunnelingNotEnabledError(CommandException): +class SSHTunnelingNotEnabledError(CommandException, SSHTunnelError): status = 400 message = _("SSH Tunneling is not enabled") -class SSHTunnelRequiredFieldValidationError(ValidationError): +class SSHTunnelRequiredFieldValidationError(ValidationError, SSHTunnelError): def __init__(self, field_name: str) -> None: super().__init__( - [_("Field is required")], + [_("Field is required")], # type: ignore field_name=field_name, ) -class SSHTunnelMissingCredentials(CommandInvalidError): +class SSHTunnelMissingCredentials(CommandInvalidError, SSHTunnelError): message = _("Must provide credentials for the SSH Tunnel") -class SSHTunnelInvalidCredentials(CommandInvalidError): +class SSHTunnelInvalidCredentials(CommandInvalidError, SSHTunnelError): message = _("Cannot have multiple credentials for the SSH Tunnel") diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index b057cb300e4..5e0968954cc 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -14,13 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import logging from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError -from superset import is_feature_enabled +from superset import is_feature_enabled, security_manager from superset.commands.base import BaseCommand from superset.commands.database.exceptions import ( DatabaseConnectionFailedError, @@ -32,19 +35,16 @@ from superset.commands.database.exceptions import ( from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( - SSHTunnelCreateFailedError, - SSHTunnelDatabasePortError, - SSHTunnelDeleteFailedError, + SSHTunnelError, SSHTunnelingNotEnabledError, - SSHTunnelInvalidError, - SSHTunnelUpdateFailedError, ) from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from superset.daos.database import DatabaseDAO +from superset.daos.dataset import DatasetDAO from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError -from superset.extensions import db, security_manager +from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.extensions import db from superset.models.core import Database -from superset.utils.core import DatasourceType logger = logging.getLogger(__name__) @@ -57,7 +57,7 @@ class UpdateDatabaseCommand(BaseCommand): self._model_id = model_id self._model: Optional[Database] = None - def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches + def run(self) -> Model: self._model = DatabaseDAO.find_by_id(self._model_id) if not self._model: @@ -65,8 +65,6 @@ class UpdateDatabaseCommand(BaseCommand): self.validate() - old_database_name = self._model.database_name - # unmask ``encrypted_extra`` self._properties["encrypted_extra"] = ( self._model.db_engine_spec.unmask_encrypted_extra( @@ -75,127 +73,106 @@ class UpdateDatabaseCommand(BaseCommand): ) ) + # if the database name changed we need to update any existing permissions, + # since they're name based + original_database_name = self._model.database_name + try: - database = DatabaseDAO.update(self._model, self._properties, commit=False) + database = DatabaseDAO.update( + self._model, + self._properties, + commit=False, + ) database.set_sqlalchemy_uri(database.sqlalchemy_uri) - - ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) - - if "ssh_tunnel" in self._properties: - if not is_feature_enabled("SSH_TUNNELING"): - db.session.rollback() - raise SSHTunnelingNotEnabledError() - - if self._properties.get("ssh_tunnel") is None and ssh_tunnel: - # We need to remove the existing tunnel - try: - DeleteSSHTunnelCommand(ssh_tunnel.id).run() - ssh_tunnel = None - except SSHTunnelDeleteFailedError as ex: - raise ex - except Exception as ex: - raise DatabaseUpdateFailedError() from ex - - if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): - if ssh_tunnel is None: - # We couldn't found an existing tunnel so we need to create one - try: - ssh_tunnel = CreateSSHTunnelCommand( - database, ssh_tunnel_properties - ).run() - except ( - SSHTunnelInvalidError, - SSHTunnelCreateFailedError, - SSHTunnelDatabasePortError, - ) as ex: - # So we can show the original message - raise ex - except Exception as ex: - raise DatabaseUpdateFailedError() from ex - else: - # We found an existing tunnel so we need to update it - try: - ssh_tunnel_id = ssh_tunnel.id - ssh_tunnel = UpdateSSHTunnelCommand( - ssh_tunnel_id, ssh_tunnel_properties - ).run() - except ( - SSHTunnelInvalidError, - SSHTunnelUpdateFailedError, - SSHTunnelDatabasePortError, - ) as ex: - # So we can show the original message - raise ex - except Exception as ex: - raise DatabaseUpdateFailedError() from ex - - # adding a new database we always want to force refresh schema list - # TODO Improve this simplistic implementation for catching DB conn fails - try: - schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) - except Exception as ex: - db.session.rollback() - raise DatabaseConnectionFailedError() from ex - - # Update database schema permissions - new_schemas: list[str] = [] - - for schema in schemas: - old_view_menu_name = security_manager.get_schema_perm( - old_database_name, schema - ) - new_view_menu_name = security_manager.get_schema_perm( - database.database_name, schema - ) - schema_pvm = security_manager.find_permission_view_menu( - "schema_access", old_view_menu_name - ) - # Update the schema permission if the database name changed - if schema_pvm and old_database_name != database.database_name: - schema_pvm.view_menu.name = new_view_menu_name - - self._propagate_schema_permissions( - old_view_menu_name, new_view_menu_name - ) - else: - new_schemas.append(schema) - for schema in new_schemas: - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) - ) - - db.session.commit() - + ssh_tunnel = self._handle_ssh_tunnel(database) + self._refresh_schemas(database, original_database_name, ssh_tunnel) + except SSHTunnelError as ex: + # allow exception to bubble for debugbing information + raise ex except (DAOUpdateFailedError, DAOCreateFailedError) as ex: raise DatabaseUpdateFailedError() from ex + return database - @staticmethod - def _propagate_schema_permissions( - old_view_menu_name: str, new_view_menu_name: str - ) -> None: - from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel - SqlaTable, - ) - from superset.models.slice import ( # pylint: disable=import-outside-toplevel - Slice, - ) + def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None: + """ + Delete, create, or update an SSH tunnel. + """ + if "ssh_tunnel" not in self._properties: + return None - # Update schema_perm on all datasets - datasets = ( - db.session.query(SqlaTable) - .filter(SqlaTable.schema_perm == old_view_menu_name) - .all() - ) - for dataset in datasets: - dataset.schema_perm = new_view_menu_name - charts = db.session.query(Slice).filter( - Slice.datasource_type == DatasourceType.TABLE, - Slice.datasource_id == dataset.id, + if not is_feature_enabled("SSH_TUNNELING"): + db.session.rollback() + raise SSHTunnelingNotEnabledError() + + current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) + ssh_tunnel_properties = self._properties["ssh_tunnel"] + + if ssh_tunnel_properties is None: + if current_ssh_tunnel: + DeleteSSHTunnelCommand(current_ssh_tunnel.id).run() + return None + + if current_ssh_tunnel is None: + return CreateSSHTunnelCommand(database, ssh_tunnel_properties).run() + + return UpdateSSHTunnelCommand( + current_ssh_tunnel.id, + ssh_tunnel_properties, + ).run() + + def _refresh_schemas( + self, + database: Database, + original_database_name: str, + ssh_tunnel: Optional[SSHTunnel], + ) -> None: + """ + Add permissions for any new schemas. + """ + try: + schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) + except Exception as ex: + db.session.rollback() + raise DatabaseConnectionFailedError() from ex + + for schema in schemas: + original_vm = security_manager.get_schema_perm( + original_database_name, + schema, ) - # Update schema_perm on all charts - for chart in charts: - chart.schema_perm = new_view_menu_name + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + original_vm, + ) + if not existing_pvm: + # new schema + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm(database.database_name, schema), + ) + continue + + if original_database_name == database.database_name: + continue + + # rename existing schema permission + existing_pvm.view_menu.name = security_manager.get_schema_perm( + database.database_name, + schema, + ) + + # rename permissions on datasets and charts + for dataset in DatabaseDAO.get_datasets( + database.id, + catalog=None, + schema=schema, + ): + dataset.schema_perm = existing_pvm.view_menu.name + for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: + chart.schema_perm = existing_pvm.view_menu.name + + db.session.commit() def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/daos/database.py b/superset/daos/database.py index 4470ab8f2e5..15fc03710aa 100644 --- a/superset/daos/database.py +++ b/superset/daos/database.py @@ -19,6 +19,7 @@ from __future__ import annotations import logging from typing import Any +from superset.connectors.sqla.models import SqlaTable from superset.daos.base import BaseDAO from superset.databases.filters import DatabaseFilter from superset.databases.ssh_tunnel.models import SSHTunnel @@ -131,6 +132,31 @@ class DatabaseDAO(BaseDAO[Database]): "sqllab_tab_states": sqllab_tab_states, } + @classmethod + def get_datasets( + cls, + database_id: int, + catalog: str | None, + schema: str | None, + ) -> list[SqlaTable]: + """ + Return all datasets, optionally filtered by catalog/schema. + + :param database_id: The database ID + :param catalog: The catalog name + :param schema: The schema name + :return: A list of SqlaTable objects + """ + return ( + db.session.query(SqlaTable) + .filter( + SqlaTable.database_id == database_id, + SqlaTable.catalog == catalog, + SqlaTable.schema == schema, + ) + .all() + ) + @classmethod def get_ssh_tunnel(cls, database_id: int) -> SSHTunnel | None: ssh_tunnel = (