mirror of
https://github.com/apache/superset.git
synced 2026-04-23 10:04:45 +00:00
chore(dao/command): Add transaction decorator to try to enforce "unit of work" (#24969)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
@@ -39,11 +40,11 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
)
|
||||
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.exceptions import SupersetErrorsException
|
||||
from superset.extensions import db, event_logger, security_manager
|
||||
from superset.extensions import event_logger, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = current_app.config["STATS_LOGGER"]
|
||||
@@ -53,6 +54,7 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
|
||||
@@ -96,8 +98,6 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
database, ssh_tunnel_properties
|
||||
).run()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# add catalog/schema permissions
|
||||
if database.db_engine_spec.supports_catalog:
|
||||
catalogs = database.get_all_catalog_names(
|
||||
@@ -121,14 +121,12 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.warning("Error processing catalog '%s'", catalog)
|
||||
continue
|
||||
|
||||
except (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelDatabasePortError,
|
||||
) as ex:
|
||||
db.session.rollback()
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
@@ -136,11 +134,9 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
# So we can show the original message
|
||||
raise
|
||||
except (
|
||||
DAOCreateFailedError,
|
||||
DatabaseInvalidError,
|
||||
Exception,
|
||||
) as ex:
|
||||
db.session.rollback()
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=database.db_engine_spec.__name__,
|
||||
@@ -198,6 +194,6 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
raise exception
|
||||
|
||||
def _create_database(self) -> Database:
|
||||
database = DatabaseDAO.create(attributes=self._properties, commit=False)
|
||||
database = DatabaseDAO.create(attributes=self._properties)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
return database
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from flask_babel import lazy_gettext as _
|
||||
@@ -27,9 +28,9 @@ from superset.commands.database.exceptions import (
|
||||
DatabaseNotFoundError,
|
||||
)
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.exceptions import DAODeleteFailedError
|
||||
from superset.daos.report import ReportScheduleDAO
|
||||
from superset.models.core import Database
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,15 +40,11 @@ class DeleteDatabaseCommand(BaseCommand):
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=DatabaseDeleteFailedError))
|
||||
def run(self) -> None:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
DatabaseDAO.delete([self._model])
|
||||
except DAODeleteFailedError as ex:
|
||||
logger.exception(ex.exception)
|
||||
raise DatabaseDeleteFailedError() from ex
|
||||
DatabaseDAO.delete([self._model])
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
@@ -28,10 +29,10 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelRequiredFieldValidationError,
|
||||
)
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,6 +45,7 @@ class CreateSSHTunnelCommand(BaseCommand):
|
||||
self._properties["database"] = database
|
||||
self._database = database
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=SSHTunnelCreateFailedError))
|
||||
def run(self) -> Model:
|
||||
"""
|
||||
Create an SSH tunnel.
|
||||
@@ -53,11 +55,8 @@ class CreateSSHTunnelCommand(BaseCommand):
|
||||
:raises SSHTunnelInvalidError: If the configuration are invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
self.validate()
|
||||
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
|
||||
except DAOCreateFailedError as ex:
|
||||
raise SSHTunnelCreateFailedError() from ex
|
||||
self.validate()
|
||||
return SSHTunnelDAO.create(attributes=self._properties)
|
||||
|
||||
def validate(self) -> None:
|
||||
# TODO(hughhh): check to make sure the server port is not localhost
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from superset import is_feature_enabled
|
||||
@@ -25,8 +26,8 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAODeleteFailedError
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,16 +37,13 @@ class DeleteSSHTunnelCommand(BaseCommand):
|
||||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=SSHTunnelDeleteFailedError))
|
||||
def run(self) -> None:
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
SSHTunnelDAO.delete([self._model])
|
||||
except DAODeleteFailedError as ex:
|
||||
raise SSHTunnelDeleteFailedError() from ex
|
||||
SSHTunnelDAO.delete([self._model])
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
@@ -28,9 +29,9 @@ from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelUpdateFailedError,
|
||||
)
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAOUpdateFailedError
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,25 +42,23 @@ class UpdateSSHTunnelCommand(BaseCommand):
|
||||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=SSHTunnelUpdateFailedError))
|
||||
def run(self) -> Optional[Model]:
|
||||
self.validate()
|
||||
try:
|
||||
if self._model is None:
|
||||
return None
|
||||
|
||||
# unset password if private key is provided
|
||||
if self._properties.get("private_key"):
|
||||
self._properties["password"] = None
|
||||
if self._model is None:
|
||||
return None
|
||||
|
||||
# unset private key and password if password is provided
|
||||
if self._properties.get("password"):
|
||||
self._properties["private_key"] = None
|
||||
self._properties["private_key_password"] = None
|
||||
# unset password if private key is provided
|
||||
if self._properties.get("private_key"):
|
||||
self._properties["password"] = None
|
||||
|
||||
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
||||
return tunnel
|
||||
except DAOUpdateFailedError as ex:
|
||||
raise SSHTunnelUpdateFailedError() from ex
|
||||
# unset private key and password if password is provided
|
||||
if self._properties.get("password"):
|
||||
self._properties["private_key"] = None
|
||||
self._properties["private_key_password"] = None
|
||||
|
||||
return SSHTunnelDAO.update(self._model, self._properties)
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
@@ -34,16 +35,14 @@ from superset.commands.database.exceptions import (
|
||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.extensions import db
|
||||
from superset.models.core import Database
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -56,6 +55,7 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
self._model_id = model_id
|
||||
self._model: Database | None = None
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=DatabaseUpdateFailedError))
|
||||
def run(self) -> Model:
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
|
||||
@@ -76,21 +76,10 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
# since they're name based
|
||||
original_database_name = self._model.database_name
|
||||
|
||||
try:
|
||||
database = DatabaseDAO.update(
|
||||
self._model,
|
||||
self._properties,
|
||||
commit=False,
|
||||
)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||
except SSHTunnelError: # pylint: disable=try-except-raise
|
||||
# allow exception to bubble for debugbing information
|
||||
raise
|
||||
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
|
||||
database = DatabaseDAO.update(self._model, self._properties)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||
return database
|
||||
|
||||
def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
|
||||
@@ -101,7 +90,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
return None
|
||||
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
db.session.rollback()
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
|
||||
current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
@@ -131,13 +119,13 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
This method captures a generic exception, since errors could potentially come
|
||||
from any of the 50+ database drivers we support.
|
||||
"""
|
||||
|
||||
try:
|
||||
return database.get_all_catalog_names(
|
||||
force=True,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
def _get_schema_names(
|
||||
@@ -152,6 +140,7 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
This method captures a generic exception, since errors could potentially come
|
||||
from any of the 50+ database drivers we support.
|
||||
"""
|
||||
|
||||
try:
|
||||
return database.get_all_schema_names(
|
||||
force=True,
|
||||
@@ -159,7 +148,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
def _refresh_catalogs(
|
||||
@@ -225,8 +213,6 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
schemas,
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def _refresh_schemas(
|
||||
self,
|
||||
database: Database,
|
||||
|
||||
@@ -16,11 +16,11 @@
|
||||
# under the License.
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Optional, TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from flask_babel import lazy_gettext as _
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from superset import db
|
||||
@@ -37,6 +37,7 @@ from superset.daos.database import DatabaseDAO
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils.core import get_user
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
from superset.views.database.validators import schema_allows_file_upload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -144,6 +145,7 @@ class UploadCommand(BaseCommand):
|
||||
self._file = file
|
||||
self._reader = reader
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
|
||||
def run(self) -> None:
|
||||
self.validate()
|
||||
if not self._model:
|
||||
@@ -172,12 +174,6 @@ class UploadCommand(BaseCommand):
|
||||
|
||||
sqla_table.fetch_metadata()
|
||||
|
||||
try:
|
||||
db.session.commit()
|
||||
except SQLAlchemyError as ex:
|
||||
db.session.rollback()
|
||||
raise DatabaseUploadSaveMetadataFailed() from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
|
||||
Reference in New Issue
Block a user