mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat: Update database permissions in async mode (#32231)
This commit is contained in:
@@ -88,11 +88,21 @@ class DatabaseExtraValidationError(ValidationError):
|
||||
)
|
||||
|
||||
|
||||
class DatabaseConnectionSyncPermissionsError(CommandException):
|
||||
status = 500
|
||||
message = _("Unable to sync permissions for this database connection.")
|
||||
|
||||
|
||||
class DatabaseNotFoundError(CommandException):
|
||||
status = 404
|
||||
message = _("Database not found.")
|
||||
|
||||
|
||||
class UserNotFoundInSessionError(CommandException):
|
||||
status = 500
|
||||
message = _("Could not validate the user in the current session.")
|
||||
|
||||
|
||||
class DatabaseSchemaUploadNotAllowed(CommandException):
|
||||
status = 403
|
||||
message = _("Database schema is not allowed for csv uploads.")
|
||||
|
||||
344
superset/commands/database/sync_permissions.py
Normal file
344
superset/commands/database/sync_permissions.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Iterable
|
||||
|
||||
from flask import current_app, g
|
||||
|
||||
from superset import app, security_manager
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.database.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseConnectionSyncPermissionsError,
|
||||
DatabaseNotFoundError,
|
||||
UserNotFoundInSessionError,
|
||||
)
|
||||
from superset.commands.database.utils import (
|
||||
add_pvm,
|
||||
add_vm,
|
||||
ping,
|
||||
)
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
from superset.extensions import celery_app, db
|
||||
from superset.models.core import Database
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncPermissionsCommand(BaseCommand):
|
||||
"""
|
||||
Command to sync database permissions.
|
||||
|
||||
This command can be called either via its dedicated endpoint, or as part of
|
||||
another command. If async mode is enabled, the command is executed through
|
||||
a celery task, otherwise it's executed synchronously.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: int,
|
||||
username: str | None,
|
||||
old_db_connection_name: str | None = None,
|
||||
db_connection: Database | None = None,
|
||||
ssh_tunnel: SSHTunnel | None = None,
|
||||
):
|
||||
"""
|
||||
Constructor method.
|
||||
"""
|
||||
self.db_connection_id = model_id
|
||||
self.username = username
|
||||
self._old_db_connection_name: str | None = old_db_connection_name
|
||||
self._db_connection: Database | None = db_connection
|
||||
self.db_connection_ssh_tunnel: SSHTunnel | None = ssh_tunnel
|
||||
|
||||
self.async_mode: bool = app.config["SYNC_DB_PERMISSIONS_IN_ASYNC_MODE"]
|
||||
|
||||
@property
|
||||
def db_connection(self) -> Database:
|
||||
if not self._db_connection:
|
||||
raise DatabaseNotFoundError()
|
||||
return self._db_connection
|
||||
|
||||
@property
|
||||
def old_db_connection_name(self) -> str:
|
||||
return (
|
||||
self._old_db_connection_name
|
||||
if self._old_db_connection_name is not None
|
||||
else self.db_connection.database_name
|
||||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
self._db_connection = (
|
||||
self._db_connection
|
||||
if self._db_connection
|
||||
else DatabaseDAO.find_by_id(self.db_connection_id)
|
||||
)
|
||||
if not self._db_connection:
|
||||
raise DatabaseNotFoundError()
|
||||
|
||||
if not self.db_connection_ssh_tunnel:
|
||||
self.db_connection_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(
|
||||
self.db_connection_id
|
||||
)
|
||||
|
||||
# Need user info to impersonate for OAuth2 connections
|
||||
if not self.username or not security_manager.get_user_by_username(
|
||||
self.username
|
||||
):
|
||||
raise UserNotFoundInSessionError()
|
||||
|
||||
with self.db_connection.get_sqla_engine(
|
||||
override_ssh_tunnel=self.db_connection_ssh_tunnel
|
||||
) as engine:
|
||||
try:
|
||||
alive = ping(engine)
|
||||
except Exception as err:
|
||||
raise DatabaseConnectionFailedError() from err
|
||||
|
||||
if not alive:
|
||||
raise DatabaseConnectionFailedError()
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Triggers the perm sync in sync or async mode.
|
||||
"""
|
||||
self.validate()
|
||||
if self.async_mode:
|
||||
sync_database_permissions_task.delay(
|
||||
self.db_connection_id, self.username, self.old_db_connection_name
|
||||
)
|
||||
return
|
||||
|
||||
self.sync_database_permissions()
|
||||
|
||||
@transaction(
|
||||
on_error=partial(on_error, reraise=DatabaseConnectionSyncPermissionsError)
|
||||
)
|
||||
def sync_database_permissions(self) -> None:
|
||||
"""
|
||||
Syncs the permissions for a DB connection.
|
||||
"""
|
||||
catalogs = (
|
||||
self._get_catalog_names()
|
||||
if self.db_connection.db_engine_spec.supports_catalog
|
||||
else [None]
|
||||
)
|
||||
|
||||
for catalog in catalogs:
|
||||
try:
|
||||
schemas = self._get_schema_names(catalog)
|
||||
|
||||
if catalog:
|
||||
perm = security_manager.get_catalog_perm(
|
||||
self.old_db_connection_name,
|
||||
catalog,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"catalog_access",
|
||||
perm,
|
||||
)
|
||||
if not existing_pvm:
|
||||
# new catalog
|
||||
add_pvm(
|
||||
db.session,
|
||||
security_manager,
|
||||
"catalog_access",
|
||||
security_manager.get_catalog_perm(
|
||||
self.db_connection.database_name,
|
||||
catalog,
|
||||
),
|
||||
)
|
||||
for schema in schemas:
|
||||
add_pvm(
|
||||
db.session,
|
||||
security_manager,
|
||||
"schema_access",
|
||||
security_manager.get_schema_perm(
|
||||
self.db_connection.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
),
|
||||
)
|
||||
continue
|
||||
except DatabaseConnectionFailedError:
|
||||
logger.warning("Error processing catalog %s", catalog or "(default)")
|
||||
continue
|
||||
|
||||
# add possible new schemas in catalog
|
||||
self._refresh_schemas(catalog, schemas)
|
||||
|
||||
if self.old_db_connection_name != self.db_connection.database_name:
|
||||
self._rename_database_in_permissions(catalog, schemas)
|
||||
|
||||
def _get_catalog_names(self) -> set[str]:
|
||||
"""
|
||||
Helper method to load catalogs.
|
||||
"""
|
||||
try:
|
||||
return self.db_connection.get_all_catalog_names(
|
||||
force=True,
|
||||
ssh_tunnel=self.db_connection_ssh_tunnel,
|
||||
)
|
||||
except OAuth2RedirectError:
|
||||
# raise OAuth2 exceptions as-is
|
||||
raise
|
||||
except GenericDBException as ex:
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
def _get_schema_names(self, catalog: str | None) -> set[str]:
|
||||
"""
|
||||
Helper method to load schemas.
|
||||
"""
|
||||
try:
|
||||
return self.db_connection.get_all_schema_names(
|
||||
force=True,
|
||||
catalog=catalog,
|
||||
ssh_tunnel=self.db_connection_ssh_tunnel,
|
||||
)
|
||||
except OAuth2RedirectError:
|
||||
# raise OAuth2 exceptions as-is
|
||||
raise
|
||||
except GenericDBException as ex:
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
def _refresh_schemas(self, catalog: str | None, schemas: Iterable[str]) -> None:
|
||||
"""
|
||||
Add new schemas that don't have permissions yet.
|
||||
"""
|
||||
for schema in schemas:
|
||||
perm = security_manager.get_schema_perm(
|
||||
self.old_db_connection_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
if not existing_pvm:
|
||||
new_name = security_manager.get_schema_perm(
|
||||
self.db_connection.name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
add_pvm(db.session, security_manager, "schema_access", new_name)
|
||||
|
||||
def _rename_database_in_permissions(
|
||||
self, catalog: str | None, schemas: Iterable[str]
|
||||
) -> None:
|
||||
# rename existing catalog permission
|
||||
if catalog:
|
||||
new_catalog_perm_name = security_manager.get_catalog_perm(
|
||||
self.db_connection.name,
|
||||
catalog,
|
||||
)
|
||||
new_catalog_vm = add_vm(db.session, security_manager, new_catalog_perm_name)
|
||||
perm = security_manager.get_catalog_perm(
|
||||
self.old_db_connection_name,
|
||||
catalog,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"catalog_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu = new_catalog_vm
|
||||
|
||||
for schema in schemas:
|
||||
new_schema_perm_name = security_manager.get_schema_perm(
|
||||
self.db_connection.name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
|
||||
# rename existing schema permission
|
||||
perm = security_manager.get_schema_perm(
|
||||
self.old_db_connection_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = new_schema_perm_name
|
||||
|
||||
# rename permissions on datasets and charts
|
||||
for dataset in DatabaseDAO.get_datasets(
|
||||
self.db_connection_id,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
):
|
||||
dataset.catalog_perm = new_catalog_perm_name
|
||||
dataset.schema_perm = new_schema_perm_name
|
||||
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
|
||||
chart.catalog_perm = new_catalog_perm_name
|
||||
chart.schema_perm = new_schema_perm_name
|
||||
|
||||
|
||||
@celery_app.task(name="sync_database_permissions", soft_time_limit=600)
|
||||
def sync_database_permissions_task(
|
||||
database_id: int, username: str, old_db_connection_name: str
|
||||
) -> None:
|
||||
"""
|
||||
Celery task that triggers the SyncPermissionsCommand in async mode.
|
||||
"""
|
||||
with current_app.test_request_context():
|
||||
try:
|
||||
user = security_manager.get_user_by_username(username)
|
||||
if not user:
|
||||
raise UserNotFoundInSessionError()
|
||||
g.user = user
|
||||
logger.info(
|
||||
"Syncing permissions for DB connection %s while impersonating user %s",
|
||||
database_id,
|
||||
user.id,
|
||||
)
|
||||
|
||||
db_connection = DatabaseDAO.find_by_id(database_id)
|
||||
if not db_connection:
|
||||
raise DatabaseNotFoundError()
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database_id)
|
||||
|
||||
SyncPermissionsCommand(
|
||||
database_id,
|
||||
username,
|
||||
old_db_connection_name=old_db_connection_name,
|
||||
db_connection=db_connection,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
).sync_database_permissions()
|
||||
|
||||
logger.info(
|
||||
"Successfully synced permissions for DB connection %s",
|
||||
database_id,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.error(
|
||||
"An error occurred while syncing permissions for DB connection ID %s",
|
||||
database_id,
|
||||
exc_info=True,
|
||||
)
|
||||
@@ -15,13 +15,9 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app as app
|
||||
from flask_babel import gettext as _
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
@@ -35,6 +31,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.commands.database.utils import ping
|
||||
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
@@ -47,7 +44,6 @@ from superset.exceptions import (
|
||||
)
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.ssh_tunnel import unmask_password_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -136,19 +132,9 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
|
||||
def ping(engine: Engine) -> bool:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
return engine.dialect.do_ping(conn)
|
||||
|
||||
with database.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
|
||||
try:
|
||||
time_delta = app.config["TEST_DATABASE_CONNECTION_TIMEOUT"]
|
||||
with utils.timeout(int(time_delta.total_seconds())):
|
||||
alive = ping(engine)
|
||||
except (sqlite3.ProgrammingError, RuntimeError):
|
||||
# SQLite can't run on a separate thread, so ``utils.timeout`` fails
|
||||
# RuntimeError catches the equivalent error from duckdb.
|
||||
alive = engine.dialect.do_ping(engine)
|
||||
alive = ping(engine)
|
||||
except SupersetTimeoutException as ex:
|
||||
raise SupersetTimeoutException(
|
||||
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
|
||||
|
||||
@@ -23,10 +23,9 @@ from typing import Any
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
||||
from superset import is_feature_enabled, security_manager
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.database.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseExistsValidationError,
|
||||
DatabaseInvalidError,
|
||||
DatabaseNotFoundError,
|
||||
@@ -38,13 +37,13 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||
from superset.commands.database.sync_permissions import SyncPermissionsCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
from superset.models.core import Database
|
||||
from superset.utils import json
|
||||
from superset.utils.core import get_username
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -88,7 +87,14 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||
try:
|
||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||
current_username = get_username()
|
||||
SyncPermissionsCommand(
|
||||
self._model_id,
|
||||
current_username,
|
||||
old_db_connection_name=original_database_name,
|
||||
db_connection=database,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
).run()
|
||||
except OAuth2RedirectError:
|
||||
pass
|
||||
|
||||
@@ -153,201 +159,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
ssh_tunnel_properties,
|
||||
).run()
|
||||
|
||||
def _get_catalog_names(
|
||||
self,
|
||||
database: Database,
|
||||
ssh_tunnel: SSHTunnel | None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Helper method to load catalogs.
|
||||
"""
|
||||
try:
|
||||
return database.get_all_catalog_names(
|
||||
force=True,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except OAuth2RedirectError:
|
||||
# raise OAuth2 exceptions as-is
|
||||
raise
|
||||
except GenericDBException as ex:
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
def _get_schema_names(
|
||||
self,
|
||||
database: Database,
|
||||
catalog: str | None,
|
||||
ssh_tunnel: SSHTunnel | None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Helper method to load schemas.
|
||||
"""
|
||||
try:
|
||||
return database.get_all_schema_names(
|
||||
force=True,
|
||||
catalog=catalog,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except OAuth2RedirectError:
|
||||
# raise OAuth2 exceptions as-is
|
||||
raise
|
||||
except GenericDBException as ex:
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
def _refresh_catalogs(
|
||||
self,
|
||||
database: Database,
|
||||
original_database_name: str,
|
||||
ssh_tunnel: SSHTunnel | None,
|
||||
) -> None:
|
||||
"""
|
||||
Add permissions for any new catalogs and schemas.
|
||||
"""
|
||||
catalogs = (
|
||||
self._get_catalog_names(database, ssh_tunnel)
|
||||
if database.db_engine_spec.supports_catalog
|
||||
else [None]
|
||||
)
|
||||
|
||||
for catalog in catalogs:
|
||||
try:
|
||||
schemas = self._get_schema_names(database, catalog, ssh_tunnel)
|
||||
|
||||
if catalog:
|
||||
perm = security_manager.get_catalog_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"catalog_access",
|
||||
perm,
|
||||
)
|
||||
if not existing_pvm:
|
||||
# new catalog
|
||||
security_manager.add_permission_view_menu(
|
||||
"catalog_access",
|
||||
security_manager.get_catalog_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
),
|
||||
)
|
||||
for schema in schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access",
|
||||
security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
),
|
||||
)
|
||||
continue
|
||||
except DatabaseConnectionFailedError:
|
||||
# more than one catalog, move to next
|
||||
if catalog:
|
||||
logger.warning("Error processing catalog %s", catalog)
|
||||
continue
|
||||
raise
|
||||
|
||||
# add possible new schemas in catalog
|
||||
self._refresh_schemas(
|
||||
database,
|
||||
original_database_name,
|
||||
catalog,
|
||||
schemas,
|
||||
)
|
||||
|
||||
if original_database_name != database.database_name:
|
||||
self._rename_database_in_permissions(
|
||||
database,
|
||||
original_database_name,
|
||||
catalog,
|
||||
schemas,
|
||||
)
|
||||
|
||||
def _refresh_schemas(
|
||||
self,
|
||||
database: Database,
|
||||
original_database_name: str,
|
||||
catalog: str | None,
|
||||
schemas: set[str],
|
||||
) -> None:
|
||||
"""
|
||||
Add new schemas that don't have permissions yet.
|
||||
"""
|
||||
for schema in schemas:
|
||||
perm = security_manager.get_schema_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
if not existing_pvm:
|
||||
new_name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
security_manager.add_permission_view_menu("schema_access", new_name)
|
||||
|
||||
def _rename_database_in_permissions(
|
||||
self,
|
||||
database: Database,
|
||||
original_database_name: str,
|
||||
catalog: str | None,
|
||||
schemas: set[str],
|
||||
) -> None:
|
||||
new_catalog_perm_name = security_manager.get_catalog_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
)
|
||||
|
||||
# rename existing catalog permission
|
||||
if catalog:
|
||||
perm = security_manager.get_catalog_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"catalog_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = new_catalog_perm_name
|
||||
|
||||
for schema in schemas:
|
||||
new_schema_perm_name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
|
||||
# rename existing schema permission
|
||||
perm = security_manager.get_schema_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = new_schema_perm_name
|
||||
|
||||
# rename permissions on datasets and charts
|
||||
for dataset in DatabaseDAO.get_datasets(
|
||||
database.id,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
):
|
||||
dataset.catalog_perm = new_catalog_perm_name
|
||||
dataset.schema_perm = new_schema_perm_name
|
||||
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
|
||||
chart.catalog_perm = new_catalog_perm_name
|
||||
chart.schema_perm = new_schema_perm_name
|
||||
|
||||
def validate(self) -> None:
|
||||
if database_name := self._properties.get("database_name"):
|
||||
if not DatabaseDAO.validate_update_uniqueness(
|
||||
|
||||
@@ -17,19 +17,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
|
||||
from flask import current_app as app
|
||||
from flask_appbuilder.security.sqla.models import (
|
||||
Permission,
|
||||
PermissionView,
|
||||
ViewMenu,
|
||||
)
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import security_manager
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
from superset.models.core import Database
|
||||
from superset.security.manager import SupersetSecurityManager
|
||||
from superset.utils.core import timeout
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ping(engine: Engine) -> bool:
|
||||
try:
|
||||
time_delta = app.config["TEST_DATABASE_CONNECTION_TIMEOUT"]
|
||||
with timeout(int(time_delta.total_seconds())):
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
return engine.dialect.do_ping(conn)
|
||||
except (sqlite3.ProgrammingError, RuntimeError):
|
||||
# SQLite can't run on a separate thread, so ``utils.timeout`` fails
|
||||
# RuntimeError catches the equivalent error from duckdb.
|
||||
return engine.dialect.do_ping(engine)
|
||||
|
||||
|
||||
def add_permissions(database: Database, ssh_tunnel: SSHTunnel | None) -> None:
|
||||
"""
|
||||
Add DAR for catalogs and schemas.
|
||||
"""
|
||||
# TODO: Migrate this to use the non-commiting add_pvm helper instead
|
||||
if database.db_engine_spec.supports_catalog:
|
||||
catalogs = database.get_all_catalog_names(
|
||||
cache=False,
|
||||
@@ -65,3 +91,69 @@ def add_permissions(database: Database, ssh_tunnel: SSHTunnel | None) -> None:
|
||||
except GenericDBException: # pylint: disable=broad-except
|
||||
logger.warning("Error processing catalog '%s'", catalog)
|
||||
continue
|
||||
|
||||
|
||||
def add_vm(
|
||||
session: Session,
|
||||
security_manager: SupersetSecurityManager,
|
||||
view_menu_name: str | None,
|
||||
) -> ViewMenu:
|
||||
"""
|
||||
Similar to security_manager.add_view_menu, but without commit.
|
||||
|
||||
This ensures an atomic operation.
|
||||
"""
|
||||
if view_menu := security_manager.find_view_menu(view_menu_name):
|
||||
return view_menu
|
||||
|
||||
view_menu = security_manager.viewmenu_model()
|
||||
view_menu.name = view_menu_name
|
||||
session.add(view_menu)
|
||||
return view_menu
|
||||
|
||||
|
||||
def add_perm(
|
||||
session: Session,
|
||||
security_manager: SupersetSecurityManager,
|
||||
permission_name: str | None,
|
||||
) -> Permission:
|
||||
"""
|
||||
Similar to security_manager.add_permission, but without commit.
|
||||
|
||||
This ensures an atomic operation.
|
||||
"""
|
||||
if perm := security_manager.find_permission(permission_name):
|
||||
return perm
|
||||
|
||||
perm = security_manager.permission_model()
|
||||
perm.name = permission_name
|
||||
session.add(perm)
|
||||
return perm
|
||||
|
||||
|
||||
def add_pvm(
|
||||
session: Session,
|
||||
security_manager: SupersetSecurityManager,
|
||||
permission_name: str | None,
|
||||
view_menu_name: str | None,
|
||||
) -> PermissionView | None:
|
||||
"""
|
||||
Similar to security_manager.add_permission_view_menu, but without commit.
|
||||
|
||||
This ensures an atomic operation.
|
||||
"""
|
||||
if not (permission_name and view_menu_name):
|
||||
return None
|
||||
|
||||
if pv := security_manager.find_permission_view_menu(
|
||||
permission_name, view_menu_name
|
||||
):
|
||||
return pv
|
||||
|
||||
vm = add_vm(session, security_manager, view_menu_name)
|
||||
perm = add_perm(session, security_manager, permission_name)
|
||||
pv = security_manager.permissionview_model()
|
||||
pv.view_menu, pv.permission = vm, perm
|
||||
session.add(pv)
|
||||
|
||||
return pv
|
||||
|
||||
Reference in New Issue
Block a user