refactor: test connection raises only command exceptions (#12307)

* refactor: test connection raises only command exceptions

* fix tests

* fix tests

* fix tests

* lint fix
This commit is contained in:
Daniel Vaz Gaspar
2021-01-08 13:10:11 +00:00
committed by GitHub
parent fecfc34cd3
commit c685c9ea8f
5 changed files with 52 additions and 52 deletions

View File

@@ -24,16 +24,8 @@ from zipfile import ZipFile
from flask import g, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _
from marshmallow import ValidationError
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
DBAPIError,
NoSuchModuleError,
NoSuchTableError,
OperationalError,
SQLAlchemyError,
)
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
from superset import event_logger
from superset.commands.exceptions import CommandInvalidError
@@ -49,7 +41,7 @@ from superset.databases.commands.exceptions import (
DatabaseImportError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseSecurityUnsafeError,
DatabaseTestConnectionFailedError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.export import ExportDatabasesCommand
@@ -589,29 +581,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
try:
TestConnectionDatabaseCommand(g.user, item).run()
return self.response(200, message="OK")
except (NoSuchModuleError, ModuleNotFoundError):
logger.info("Invalid driver")
driver_name = make_url(item.get("sqlalchemy_uri")).drivername
return self.response(
400,
message=_("Could not load database driver: {}").format(driver_name),
driver_name=driver_name,
)
except DatabaseSecurityUnsafeError as ex:
return self.response_422(message=ex)
except DBAPIError:
logger.warning("Connection failed")
return self.response(
500,
message=_("Connection failed, please check your connection settings"),
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
return self.response_400(
message=_(
"Unexpected error occurred, please check your logs for details"
)
)
except DatabaseTestConnectionFailedError as ex:
return self.response_422(message=str(ex))
@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()

View File

@@ -25,7 +25,6 @@ from superset.commands.exceptions import (
ImportFailedError,
UpdateFailedError,
)
from superset.security.analytics_db_safety import DBSecurityException
class DatabaseInvalidError(CommandInvalidError):
@@ -102,7 +101,7 @@ class DatabaseUpdateFailedError(UpdateFailedError):
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
DatabaseCreateFailedError, DatabaseUpdateFailedError,
):
message = _("Could not connect to database.")
message = _("Connection failed, please check your connection settings")
class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
@@ -117,9 +116,21 @@ class DatabaseDeleteFailedReportsExistError(DatabaseDeleteFailedError):
message = _("There are associated alerts or reports")
class DatabaseSecurityUnsafeError(DBSecurityException):
class DatabaseTestConnectionFailedError(CommandException):
message = _("Connection failed, please check your connection settings")
class DatabaseSecurityUnsafeError(DatabaseTestConnectionFailedError):
message = _("Stopped an unsafe database connection")
class DatabaseTestConnectionDriverError(DatabaseTestConnectionFailedError):
message = _("Could not load database driver")
class DatabaseTestConnectionUnexpectedError(DatabaseTestConnectionFailedError):
message = _("Unexpected error occurred, please check your logs for details")
class DatabaseImportError(ImportFailedError):
message = _("Import database failed for an unknown reason")

View File

@@ -19,10 +19,17 @@ from contextlib import closing
from typing import Any, Dict, Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import DBAPIError
from flask_babel import gettext as _
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
from superset.databases.commands.exceptions import (
DatabaseSecurityUnsafeError,
DatabaseTestConnectionDriverError,
DatabaseTestConnectionFailedError,
DatabaseTestConnectionUnexpectedError,
)
from superset.databases.dao import DatabaseDAO
from superset.models.core import Database
from superset.security.analytics_db_safety import DBSecurityException
@@ -38,11 +45,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
def run(self) -> None:
self.validate()
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
try:
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=self._properties.get("extra", "{}"),
@@ -57,9 +63,17 @@ class TestConnectionDatabaseCommand(BaseCommand):
with closing(engine.raw_connection()) as conn:
if not engine.dialect.do_ping(conn):
raise DBAPIError(None, None, None)
except (NoSuchModuleError, ModuleNotFoundError):
driver_name = make_url(uri).drivername
raise DatabaseTestConnectionDriverError(
message=_("Could not load database driver: {}").format(driver_name),
)
except DBAPIError:
raise DatabaseTestConnectionFailedError()
except DBSecurityException as ex:
logger.warning(ex)
raise DatabaseSecurityUnsafeError()
raise DatabaseSecurityUnsafeError(message=str(ex))
except Exception:
raise DatabaseTestConnectionUnexpectedError()
def validate(self) -> None:
database_name = self._properties.get("database_name")

View File

@@ -16,9 +16,11 @@
# under the License.
from sqlalchemy.engine.url import URL
from superset.exceptions import SupersetException
class DBSecurityException(Exception):
""" Exception to prevent a security issue with connecting a DB """
class DBSecurityException(SupersetException):
""" Exception to prevent a security issue with connecting to a DB """
status = 400

View File

@@ -389,7 +389,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(username="admin")
response = self.client.post(uri, json=database_data)
response_data = json.loads(response.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(response.status_code, 422)
self.assertEqual(response_data, expected_response)
@@ -431,7 +433,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(username="admin")
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
# Cleanup
@@ -787,11 +791,10 @@ class TestDatabaseApi(SupersetTestCase):
}
url = "api/v1/database/test_connection"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "broken",
"message": "Could not load database driver: broken",
}
self.assertEqual(response, expected_response)
@@ -803,11 +806,10 @@ class TestDatabaseApi(SupersetTestCase):
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "mssql+pymssql",
"message": "Could not load database driver: mssql+pymssql",
}
self.assertEqual(response, expected_response)