feat: Update database permissions in async mode (#32231)

This commit is contained in:
Vitor Avila
2025-02-28 21:25:47 -03:00
committed by GitHub
parent 84b52b2323
commit d79f7b28c2
22 changed files with 1715 additions and 425 deletions

View File

@@ -23,10 +23,9 @@ from typing import Any
from flask_appbuilder.models.sqla import Model
from superset import is_feature_enabled, security_manager
from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.commands.database.exceptions import (
DatabaseConnectionFailedError,
DatabaseExistsValidationError,
DatabaseInvalidError,
DatabaseNotFoundError,
@@ -38,13 +37,13 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelingNotEnabledError,
)
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.commands.database.sync_permissions import SyncPermissionsCommand
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import OAuth2RedirectError
from superset.models.core import Database
from superset.utils import json
from superset.utils.core import get_username
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -88,7 +87,14 @@ class UpdateDatabaseCommand(BaseCommand):
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
try:
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
current_username = get_username()
SyncPermissionsCommand(
self._model_id,
current_username,
old_db_connection_name=original_database_name,
db_connection=database,
ssh_tunnel=ssh_tunnel,
).run()
except OAuth2RedirectError:
pass
@@ -153,201 +159,6 @@ class UpdateDatabaseCommand(BaseCommand):
ssh_tunnel_properties,
).run()
def _get_catalog_names(
self,
database: Database,
ssh_tunnel: SSHTunnel | None,
) -> set[str]:
"""
Helper method to load catalogs.
"""
try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except OAuth2RedirectError:
# raise OAuth2 exceptions as-is
raise
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex
def _get_schema_names(
self,
database: Database,
catalog: str | None,
ssh_tunnel: SSHTunnel | None,
) -> set[str]:
"""
Helper method to load schemas.
"""
try:
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except OAuth2RedirectError:
# raise OAuth2 exceptions as-is
raise
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex
def _refresh_catalogs(
self,
database: Database,
original_database_name: str,
ssh_tunnel: SSHTunnel | None,
) -> None:
"""
Add permissions for any new catalogs and schemas.
"""
catalogs = (
self._get_catalog_names(database, ssh_tunnel)
if database.db_engine_spec.supports_catalog
else [None]
)
for catalog in catalogs:
try:
schemas = self._get_schema_names(database, catalog, ssh_tunnel)
if catalog:
perm = security_manager.get_catalog_perm(
original_database_name,
catalog,
)
existing_pvm = security_manager.find_permission_view_menu(
"catalog_access",
perm,
)
if not existing_pvm:
# new catalog
security_manager.add_permission_view_menu(
"catalog_access",
security_manager.get_catalog_perm(
database.database_name,
catalog,
),
)
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
),
)
continue
except DatabaseConnectionFailedError:
# more than one catalog, move to next
if catalog:
logger.warning("Error processing catalog %s", catalog)
continue
raise
# add possible new schemas in catalog
self._refresh_schemas(
database,
original_database_name,
catalog,
schemas,
)
if original_database_name != database.database_name:
self._rename_database_in_permissions(
database,
original_database_name,
catalog,
schemas,
)
def _refresh_schemas(
self,
database: Database,
original_database_name: str,
catalog: str | None,
schemas: set[str],
) -> None:
"""
Add new schemas that don't have permissions yet.
"""
for schema in schemas:
perm = security_manager.get_schema_perm(
original_database_name,
catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
perm,
)
if not existing_pvm:
new_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
security_manager.add_permission_view_menu("schema_access", new_name)
def _rename_database_in_permissions(
self,
database: Database,
original_database_name: str,
catalog: str | None,
schemas: set[str],
) -> None:
new_catalog_perm_name = security_manager.get_catalog_perm(
database.database_name,
catalog,
)
# rename existing catalog permission
if catalog:
perm = security_manager.get_catalog_perm(
original_database_name,
catalog,
)
existing_pvm = security_manager.find_permission_view_menu(
"catalog_access",
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_catalog_perm_name
for schema in schemas:
new_schema_perm_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
# rename existing schema permission
perm = security_manager.get_schema_perm(
original_database_name,
catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_schema_perm_name
# rename permissions on datasets and charts
for dataset in DatabaseDAO.get_datasets(
database.id,
catalog=catalog,
schema=schema,
):
dataset.catalog_perm = new_catalog_perm_name
dataset.schema_perm = new_schema_perm_name
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
chart.catalog_perm = new_catalog_perm_name
chart.schema_perm = new_schema_perm_name
def validate(self) -> None:
if database_name := self._properties.get("database_name"):
if not DatabaseDAO.validate_update_uniqueness(