chore(dao/command): Add transaction decorator to try to enforce "unit of work" (#24969)

This commit is contained in:
John Bodley
2024-06-28 12:33:56 -07:00
committed by GitHub
parent a3f0d00714
commit 8fb8199a55
151 changed files with 681 additions and 916 deletions

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask import current_app
@@ -39,11 +40,11 @@ from superset.commands.database.ssh_tunnel.exceptions import (
)
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
from superset.extensions import event_logger, security_manager
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
@@ -53,6 +54,7 @@ class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
def run(self) -> Model:
self.validate()
@@ -96,8 +98,6 @@ class CreateDatabaseCommand(BaseCommand):
database, ssh_tunnel_properties
).run()
db.session.commit()
# add catalog/schema permissions
if database.db_engine_spec.supports_catalog:
catalogs = database.get_all_catalog_names(
@@ -121,14 +121,12 @@ class CreateDatabaseCommand(BaseCommand):
except Exception: # pylint: disable=broad-except
logger.warning("Error processing catalog '%s'", catalog)
continue
except (
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
@@ -136,11 +134,9 @@ class CreateDatabaseCommand(BaseCommand):
# So we can show the original message
raise
except (
DAOCreateFailedError,
DatabaseInvalidError,
Exception,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
@@ -198,6 +194,6 @@ class CreateDatabaseCommand(BaseCommand):
raise exception
def _create_database(self) -> Database:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database = DatabaseDAO.create(attributes=self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from flask_babel import lazy_gettext as _
@@ -27,9 +28,9 @@ from superset.commands.database.exceptions import (
DatabaseNotFoundError,
)
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -39,15 +40,11 @@ class DeleteDatabaseCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[Database] = None
@transaction(on_error=partial(on_error, reraise=DatabaseDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._model
try:
DatabaseDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatabaseDeleteFailedError() from ex
DatabaseDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@@ -28,10 +29,10 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelRequiredFieldValidationError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.utils import make_url_safe
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ class CreateSSHTunnelCommand(BaseCommand):
self._properties["database"] = database
self._database = database
@transaction(on_error=partial(on_error, reraise=SSHTunnelCreateFailedError))
def run(self) -> Model:
"""
Create an SSH tunnel.
@@ -53,11 +55,8 @@ class CreateSSHTunnelCommand(BaseCommand):
:raises SSHTunnelInvalidError: If the configuration are invalid
"""
try:
self.validate()
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
except DAOCreateFailedError as ex:
raise SSHTunnelCreateFailedError() from ex
self.validate()
return SSHTunnelDAO.create(attributes=self._properties)
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset import is_feature_enabled
@@ -25,8 +26,8 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelNotFoundError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -36,16 +37,13 @@ class DeleteSSHTunnelCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
@transaction(on_error=partial(on_error, reraise=SSHTunnelDeleteFailedError))
def run(self) -> None:
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
self.validate()
assert self._model
try:
SSHTunnelDAO.delete([self._model])
except DAODeleteFailedError as ex:
raise SSHTunnelDeleteFailedError() from ex
SSHTunnelDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@@ -28,9 +29,9 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelUpdateFailedError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -41,25 +42,23 @@ class UpdateSSHTunnelCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
@transaction(on_error=partial(on_error, reraise=SSHTunnelUpdateFailedError))
def run(self) -> Optional[Model]:
self.validate()
try:
if self._model is None:
return None
# unset password if private key is provided
if self._properties.get("private_key"):
self._properties["password"] = None
if self._model is None:
return None
# unset private key and password if password is provided
if self._properties.get("password"):
self._properties["private_key"] = None
self._properties["private_key_password"] = None
# unset password if private key is provided
if self._properties.get("private_key"):
self._properties["password"] = None
tunnel = SSHTunnelDAO.update(self._model, self._properties)
return tunnel
except DAOUpdateFailedError as ex:
raise SSHTunnelUpdateFailedError() from ex
# unset private key and password if password is provided
if self._properties.get("password"):
self._properties["private_key"] = None
self._properties["private_key_password"] = None
return SSHTunnelDAO.update(self._model, self._properties)
def validate(self) -> None:
# Validate/populate model exists

View File

@@ -18,6 +18,7 @@
from __future__ import annotations
import logging
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
@@ -34,16 +35,14 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelError,
SSHTunnelingNotEnabledError,
)
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import db
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -56,6 +55,7 @@ class UpdateDatabaseCommand(BaseCommand):
self._model_id = model_id
self._model: Database | None = None
@transaction(on_error=partial(on_error, reraise=DatabaseUpdateFailedError))
def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id)
@@ -76,21 +76,10 @@ class UpdateDatabaseCommand(BaseCommand):
# since they're name based
original_database_name = self._model.database_name
try:
database = DatabaseDAO.update(
self._model,
self._properties,
commit=False,
)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except SSHTunnelError: # pylint: disable=try-except-raise
# allow exception to bubble for debugbing information
raise
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
raise DatabaseUpdateFailedError() from ex
database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
return database
def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
@@ -101,7 +90,6 @@ class UpdateDatabaseCommand(BaseCommand):
return None
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
@@ -131,13 +119,13 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
def _get_schema_names(
@@ -152,6 +140,7 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_schema_names(
force=True,
@@ -159,7 +148,6 @@ class UpdateDatabaseCommand(BaseCommand):
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
def _refresh_catalogs(
@@ -225,8 +213,6 @@ class UpdateDatabaseCommand(BaseCommand):
schemas,
)
db.session.commit()
def _refresh_schemas(
self,
database: Database,

View File

@@ -16,11 +16,11 @@
# under the License.
import logging
from abc import abstractmethod
from functools import partial
from typing import Any, Optional, TypedDict
import pandas as pd
from flask_babel import lazy_gettext as _
from sqlalchemy.exc import SQLAlchemyError
from werkzeug.datastructures import FileStorage
from superset import db
@@ -37,6 +37,7 @@ from superset.daos.database import DatabaseDAO
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import get_user
from superset.utils.decorators import on_error, transaction
from superset.views.database.validators import schema_allows_file_upload
logger = logging.getLogger(__name__)
@@ -144,6 +145,7 @@ class UploadCommand(BaseCommand):
self._file = file
self._reader = reader
@transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
def run(self) -> None:
self.validate()
if not self._model:
@@ -172,12 +174,6 @@ class UploadCommand(BaseCommand):
sqla_table.fetch_metadata()
try:
db.session.commit()
except SQLAlchemyError as ex:
db.session.rollback()
raise DatabaseUploadSaveMetadataFailed() from ex
def validate(self) -> None:
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model: