mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
chore(command): Organize Commands according to SIP-92 (#25850)
This commit is contained in:
@@ -29,16 +29,9 @@ from marshmallow import ValidationError
|
||||
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
|
||||
|
||||
from superset import app, event_logger
|
||||
from superset.commands.importers.exceptions import (
|
||||
IncorrectFormatError,
|
||||
NoValidFilesFoundError,
|
||||
)
|
||||
from superset.commands.importers.v1.utils import get_contents_from_bundle
|
||||
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.commands.create import CreateDatabaseCommand
|
||||
from superset.databases.commands.delete import DeleteDatabaseCommand
|
||||
from superset.databases.commands.exceptions import (
|
||||
from superset.commands.database.create import CreateDatabaseCommand
|
||||
from superset.commands.database.delete import DeleteDatabaseCommand
|
||||
from superset.commands.database.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseCreateFailedError,
|
||||
DatabaseDeleteDatasetsExistFailedError,
|
||||
@@ -49,13 +42,26 @@ from superset.databases.commands.exceptions import (
|
||||
DatabaseUpdateFailedError,
|
||||
InvalidParametersError,
|
||||
)
|
||||
from superset.databases.commands.export import ExportDatabasesCommand
|
||||
from superset.databases.commands.importers.dispatcher import ImportDatabasesCommand
|
||||
from superset.databases.commands.tables import TablesDatabaseCommand
|
||||
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
|
||||
from superset.databases.commands.update import UpdateDatabaseCommand
|
||||
from superset.databases.commands.validate import ValidateDatabaseParametersCommand
|
||||
from superset.databases.commands.validate_sql import ValidateSQLCommand
|
||||
from superset.commands.database.export import ExportDatabasesCommand
|
||||
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
|
||||
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.commands.database.tables import TablesDatabaseCommand
|
||||
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
|
||||
from superset.commands.database.update import UpdateDatabaseCommand
|
||||
from superset.commands.database.validate import ValidateDatabaseParametersCommand
|
||||
from superset.commands.database.validate_sql import ValidateSQLCommand
|
||||
from superset.commands.importers.exceptions import (
|
||||
IncorrectFormatError,
|
||||
NoValidFilesFoundError,
|
||||
)
|
||||
from superset.commands.importers.v1.utils import get_contents_from_bundle
|
||||
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.decorators import check_datasource_access
|
||||
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
|
||||
from superset.databases.schemas import (
|
||||
@@ -79,12 +85,6 @@ from superset.databases.schemas import (
|
||||
ValidateSQLRequest,
|
||||
ValidateSQLResponse,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.utils import get_table_metadata
|
||||
from superset.db_engine_specs import get_available_engine_specs
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -1,152 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseCreateFailedError,
|
||||
DatabaseExistsValidationError,
|
||||
DatabaseInvalidError,
|
||||
DatabaseRequiredFieldValidationError,
|
||||
)
|
||||
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
from superset.exceptions import SupersetErrorsException
|
||||
from superset.extensions import db, event_logger, security_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = current_app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
class CreateDatabaseCommand(BaseCommand):
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
|
||||
try:
|
||||
# Test connection before starting create transaction
|
||||
TestConnectionDatabaseCommand(self._properties).run()
|
||||
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
)
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
)
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
# when creating a new database we don't need to unmask encrypted extra
|
||||
self._properties["encrypted_extra"] = self._properties.pop(
|
||||
"masked_encrypted_extra",
|
||||
"{}",
|
||||
)
|
||||
|
||||
try:
|
||||
database = DatabaseDAO.create(attributes=self._properties, commit=False)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
|
||||
ssh_tunnel = None
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
db.session.rollback()
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
try:
|
||||
# So database.id is not None
|
||||
db.session.flush()
|
||||
ssh_tunnel = CreateSSHTunnelCommand(
|
||||
database.id, ssh_tunnel_properties
|
||||
).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
)
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
)
|
||||
raise DatabaseCreateFailedError() from ex
|
||||
|
||||
# adding a new database we always want to force refresh schema list
|
||||
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
|
||||
for schema in schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
except DAOCreateFailedError as ex:
|
||||
db.session.rollback()
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
raise DatabaseCreateFailedError() from ex
|
||||
|
||||
if ssh_tunnel:
|
||||
stats_logger.incr("db_creation_success.ssh_tunnel")
|
||||
|
||||
return database
|
||||
|
||||
def validate(self) -> None:
|
||||
exceptions: list[ValidationError] = []
|
||||
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
|
||||
database_name: Optional[str] = self._properties.get("database_name")
|
||||
if not sqlalchemy_uri:
|
||||
exceptions.append(DatabaseRequiredFieldValidationError("sqlalchemy_uri"))
|
||||
if not database_name:
|
||||
exceptions.append(DatabaseRequiredFieldValidationError("database_name"))
|
||||
else:
|
||||
# Check database_name uniqueness
|
||||
if not DatabaseDAO.validate_uniqueness(database_name):
|
||||
exceptions.append(DatabaseExistsValidationError())
|
||||
if exceptions:
|
||||
exception = DatabaseInvalidError()
|
||||
exception.extend(exceptions)
|
||||
event_logger.log_with_context(
|
||||
# pylint: disable=consider-using-f-string
|
||||
action="db_connection_failed.{}.{}".format(
|
||||
exception.__class__.__name__,
|
||||
".".join(exception.get_list_classnames()),
|
||||
)
|
||||
)
|
||||
raise exception
|
||||
@@ -1,66 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.exceptions import DAODeleteFailedError
|
||||
from superset.daos.report import ReportScheduleDAO
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseDeleteDatasetsExistFailedError,
|
||||
DatabaseDeleteFailedError,
|
||||
DatabaseDeleteFailedReportsExistError,
|
||||
DatabaseNotFoundError,
|
||||
)
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeleteDatabaseCommand(BaseCommand):
|
||||
def __init__(self, model_id: int):
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> None:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
DatabaseDAO.delete([self._model])
|
||||
except DAODeleteFailedError as ex:
|
||||
logger.exception(ex.exception)
|
||||
raise DatabaseDeleteFailedError() from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
# Check there are no associated ReportSchedules
|
||||
|
||||
if reports := ReportScheduleDAO.find_by_database_id(self._model_id):
|
||||
report_names = [report.name for report in reports]
|
||||
raise DatabaseDeleteFailedReportsExistError(
|
||||
_(f"There are associated alerts or reports: {','.join(report_names)}")
|
||||
)
|
||||
# Check if there are datasets for this database
|
||||
if self._model.tables:
|
||||
raise DatabaseDeleteDatasetsExistFailedError()
|
||||
@@ -1,183 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from flask_babel import lazy_gettext as _
|
||||
from marshmallow.validate import ValidationError
|
||||
|
||||
from superset.commands.exceptions import (
|
||||
CommandException,
|
||||
CommandInvalidError,
|
||||
CreateFailedError,
|
||||
DeleteFailedError,
|
||||
ImportFailedError,
|
||||
UpdateFailedError,
|
||||
)
|
||||
from superset.exceptions import SupersetErrorException, SupersetErrorsException
|
||||
|
||||
|
||||
class DatabaseInvalidError(CommandInvalidError):
|
||||
message = _("Database parameters are invalid.")
|
||||
|
||||
|
||||
class DatabaseExistsValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for dataset already exists
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
_("A database with the same name already exists."),
|
||||
field_name="database_name",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseRequiredFieldValidationError(ValidationError):
|
||||
def __init__(self, field_name: str) -> None:
|
||||
super().__init__(
|
||||
[_("Field is required")],
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
|
||||
class DatabaseExtraJSONValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for database encrypted extra must be a valid JSON
|
||||
"""
|
||||
|
||||
def __init__(self, json_error: str = "") -> None:
|
||||
super().__init__(
|
||||
[
|
||||
_(
|
||||
"Field cannot be decoded by JSON. %(json_error)s",
|
||||
json_error=json_error,
|
||||
)
|
||||
],
|
||||
field_name="extra",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseExtraValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for database encrypted extra must be a valid JSON
|
||||
"""
|
||||
|
||||
def __init__(self, key: str = "") -> None:
|
||||
super().__init__(
|
||||
[
|
||||
_(
|
||||
"The metadata_params in Extra field "
|
||||
"is not configured correctly. The key "
|
||||
"%{key}s is invalid.",
|
||||
key=key,
|
||||
)
|
||||
],
|
||||
field_name="extra",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseNotFoundError(CommandException):
|
||||
message = _("Database not found.")
|
||||
|
||||
|
||||
class DatabaseCreateFailedError(CreateFailedError):
|
||||
message = _("Database could not be created.")
|
||||
|
||||
|
||||
class DatabaseUpdateFailedError(UpdateFailedError):
|
||||
message = _("Database could not be updated.")
|
||||
|
||||
|
||||
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
|
||||
DatabaseCreateFailedError,
|
||||
DatabaseUpdateFailedError,
|
||||
):
|
||||
message = _("Connection failed, please check your connection settings")
|
||||
|
||||
|
||||
class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
|
||||
message = _("Cannot delete a database that has datasets attached")
|
||||
|
||||
|
||||
class DatabaseDeleteFailedError(DeleteFailedError):
|
||||
message = _("Database could not be deleted.")
|
||||
|
||||
|
||||
class DatabaseDeleteFailedReportsExistError(DatabaseDeleteFailedError):
|
||||
message = _("There are associated alerts or reports")
|
||||
|
||||
|
||||
class DatabaseTestConnectionFailedError(SupersetErrorsException):
|
||||
status = 422
|
||||
message = _("Connection failed, please check your connection settings")
|
||||
|
||||
|
||||
class DatabaseSecurityUnsafeError(CommandInvalidError):
|
||||
message = _("Stopped an unsafe database connection")
|
||||
|
||||
|
||||
class DatabaseTestConnectionDriverError(CommandInvalidError):
|
||||
message = _("Could not load database driver")
|
||||
|
||||
|
||||
class DatabaseTestConnectionUnexpectedError(SupersetErrorsException):
|
||||
status = 422
|
||||
message = _("Unexpected error occurred, please check your logs for details")
|
||||
|
||||
|
||||
class DatabaseTablesUnexpectedError(Exception):
|
||||
status = 422
|
||||
message = _("Unexpected error occurred, please check your logs for details")
|
||||
|
||||
|
||||
class NoValidatorConfigFoundError(SupersetErrorException):
|
||||
status = 422
|
||||
message = _("no SQL validator is configured")
|
||||
|
||||
|
||||
class NoValidatorFoundError(SupersetErrorException):
|
||||
status = 422
|
||||
message = _("No validator found (configured for the engine)")
|
||||
|
||||
|
||||
class ValidatorSQLError(SupersetErrorException):
|
||||
status = 422
|
||||
message = _("Was unable to check your query")
|
||||
|
||||
|
||||
class ValidatorSQLUnexpectedError(CommandException):
|
||||
status = 422
|
||||
message = _("An unexpected error occurred")
|
||||
|
||||
|
||||
class ValidatorSQL400Error(SupersetErrorException):
|
||||
status = 400
|
||||
message = _("Was unable to check your query")
|
||||
|
||||
|
||||
class DatabaseImportError(ImportFailedError):
|
||||
message = _("Import database failed for an unknown reason")
|
||||
|
||||
|
||||
class InvalidEngineError(SupersetErrorException):
|
||||
status = 422
|
||||
|
||||
|
||||
class DatabaseOfflineError(SupersetErrorException):
|
||||
status = 422
|
||||
|
||||
|
||||
class InvalidParametersError(SupersetErrorsException):
|
||||
status = 422
|
||||
@@ -1,122 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from collections.abc import Iterator
|
||||
|
||||
import yaml
|
||||
|
||||
from superset.databases.commands.exceptions import DatabaseNotFoundError
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.commands.export.models import ExportModelsCommand
|
||||
from superset.models.core import Database
|
||||
from superset.utils.dict_import_export import EXPORT_VERSION
|
||||
from superset.utils.file import get_filename
|
||||
from superset.utils.ssh_tunnel import mask_password_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_extra(extra_payload: str) -> dict[str, Any]:
|
||||
try:
|
||||
extra = json.loads(extra_payload)
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.info("Unable to decode `extra` field: %s", extra_payload)
|
||||
return {}
|
||||
|
||||
# Fix for DBs saved with an invalid ``schemas_allowed_for_csv_upload``
|
||||
schemas_allowed_for_csv_upload = extra.get("schemas_allowed_for_csv_upload")
|
||||
if isinstance(schemas_allowed_for_csv_upload, str):
|
||||
extra["schemas_allowed_for_csv_upload"] = json.loads(
|
||||
schemas_allowed_for_csv_upload
|
||||
)
|
||||
|
||||
return extra
|
||||
|
||||
|
||||
class ExportDatabasesCommand(ExportModelsCommand):
|
||||
dao = DatabaseDAO
|
||||
not_found = DatabaseNotFoundError
|
||||
|
||||
@staticmethod
|
||||
def _export(
|
||||
model: Database, export_related: bool = True
|
||||
) -> Iterator[tuple[str, str]]:
|
||||
db_file_name = get_filename(model.database_name, model.id, skip_id=True)
|
||||
file_path = f"databases/{db_file_name}.yaml"
|
||||
|
||||
payload = model.export_to_dict(
|
||||
recursive=False,
|
||||
include_parent_ref=False,
|
||||
include_defaults=True,
|
||||
export_uuids=True,
|
||||
)
|
||||
|
||||
# https://github.com/apache/superset/pull/16756 renamed ``allow_csv_upload``
|
||||
# to ``allow_file_upload`, but we can't change the V1 schema
|
||||
replacements = {"allow_file_upload": "allow_csv_upload"}
|
||||
# this preserves key order, which is important
|
||||
payload = {replacements.get(key, key): value for key, value in payload.items()}
|
||||
|
||||
# TODO (betodealmeida): move this logic to export_to_dict once this
|
||||
# becomes the default export endpoint
|
||||
if payload.get("extra"):
|
||||
extra = payload["extra"] = parse_extra(payload["extra"])
|
||||
|
||||
# ``schemas_allowed_for_csv_upload`` was also renamed to
|
||||
# ``schemas_allowed_for_file_upload``, we need to change to preserve the
|
||||
# V1 schema
|
||||
if "schemas_allowed_for_file_upload" in extra:
|
||||
extra["schemas_allowed_for_csv_upload"] = extra.pop(
|
||||
"schemas_allowed_for_file_upload"
|
||||
)
|
||||
|
||||
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.id):
|
||||
ssh_tunnel_payload = ssh_tunnel.export_to_dict(
|
||||
recursive=False,
|
||||
include_parent_ref=False,
|
||||
include_defaults=True,
|
||||
export_uuids=False,
|
||||
)
|
||||
payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
|
||||
|
||||
payload["version"] = EXPORT_VERSION
|
||||
|
||||
file_content = yaml.safe_dump(payload, sort_keys=False)
|
||||
yield file_path, file_content
|
||||
|
||||
if export_related:
|
||||
for dataset in model.tables:
|
||||
ds_file_name = get_filename(
|
||||
dataset.table_name, dataset.id, skip_id=True
|
||||
)
|
||||
file_path = f"datasets/{db_file_name}/{ds_file_name}.yaml"
|
||||
|
||||
payload = dataset.export_to_dict(
|
||||
recursive=True,
|
||||
include_parent_ref=False,
|
||||
include_defaults=True,
|
||||
export_uuids=True,
|
||||
)
|
||||
payload["version"] = EXPORT_VERSION
|
||||
payload["database_uuid"] = str(model.uuid)
|
||||
|
||||
file_content = yaml.safe_dump(payload, sort_keys=False)
|
||||
yield file_path, file_content
|
||||
@@ -1,16 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -1,68 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from marshmallow.exceptions import ValidationError
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.exceptions import CommandInvalidError
|
||||
from superset.commands.importers.exceptions import IncorrectVersionError
|
||||
from superset.databases.commands.importers import v1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
command_versions = [v1.ImportDatabasesCommand]
|
||||
|
||||
|
||||
class ImportDatabasesCommand(BaseCommand):
|
||||
"""
|
||||
Import databases.
|
||||
|
||||
This command dispatches the import to different versions of the command
|
||||
until it finds one that matches.
|
||||
"""
|
||||
|
||||
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
|
||||
self.contents = contents
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def run(self) -> None:
|
||||
# iterate over all commands until we find a version that can
|
||||
# handle the contents
|
||||
for version in command_versions:
|
||||
command = version(self.contents, *self.args, **self.kwargs)
|
||||
try:
|
||||
command.run()
|
||||
return
|
||||
except IncorrectVersionError:
|
||||
logger.debug("File not handled by command, skipping")
|
||||
except (CommandInvalidError, ValidationError) as exc:
|
||||
# found right version, but file is invalid
|
||||
logger.info("Command failed validation")
|
||||
raise exc
|
||||
except Exception as exc:
|
||||
# validation succeeded but something went wrong
|
||||
logger.exception("Error running import command")
|
||||
raise exc
|
||||
|
||||
raise CommandInvalidError("Could not find a valid command to import file")
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
@@ -1,64 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from marshmallow import Schema
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.commands.importers.v1 import ImportModelsCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.commands.exceptions import DatabaseImportError
|
||||
from superset.databases.commands.importers.v1.utils import import_database
|
||||
from superset.databases.schemas import ImportV1DatabaseSchema
|
||||
from superset.datasets.commands.importers.v1.utils import import_dataset
|
||||
from superset.datasets.schemas import ImportV1DatasetSchema
|
||||
|
||||
|
||||
class ImportDatabasesCommand(ImportModelsCommand):
|
||||
|
||||
"""Import databases"""
|
||||
|
||||
dao = DatabaseDAO
|
||||
model_name = "database"
|
||||
prefix = "databases/"
|
||||
schemas: dict[str, Schema] = {
|
||||
"databases/": ImportV1DatabaseSchema(),
|
||||
"datasets/": ImportV1DatasetSchema(),
|
||||
}
|
||||
import_error = DatabaseImportError
|
||||
|
||||
@staticmethod
|
||||
def _import(
|
||||
session: Session, configs: dict[str, Any], overwrite: bool = False
|
||||
) -> None:
|
||||
# first import databases
|
||||
database_ids: dict[str, int] = {}
|
||||
for file_name, config in configs.items():
|
||||
if file_name.startswith("databases/"):
|
||||
database = import_database(session, config, overwrite=overwrite)
|
||||
database_ids[str(database.uuid)] = database.id
|
||||
|
||||
# import related datasets
|
||||
for file_name, config in configs.items():
|
||||
if (
|
||||
file_name.startswith("datasets/")
|
||||
and config["database_uuid"] in database_ids
|
||||
):
|
||||
config["database_id"] = database_ids[config["database_uuid"]]
|
||||
# overwrite=False prevents deleting any non-imported columns/metrics
|
||||
import_dataset(session, config, overwrite=False)
|
||||
@@ -1,78 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import app, security_manager
|
||||
from superset.commands.exceptions import ImportFailedError
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.models.core import Database
|
||||
from superset.security.analytics_db_safety import check_sqlalchemy_uri
|
||||
|
||||
|
||||
def import_database(
|
||||
session: Session,
|
||||
config: dict[str, Any],
|
||||
overwrite: bool = False,
|
||||
ignore_permissions: bool = False,
|
||||
) -> Database:
|
||||
can_write = ignore_permissions or security_manager.can_access(
|
||||
"can_write",
|
||||
"Database",
|
||||
)
|
||||
existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
|
||||
if existing:
|
||||
if not overwrite or not can_write:
|
||||
return existing
|
||||
config["id"] = existing.id
|
||||
elif not can_write:
|
||||
raise ImportFailedError(
|
||||
"Database doesn't exist and user doesn't have permission to create databases"
|
||||
)
|
||||
# Check if this URI is allowed
|
||||
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]:
|
||||
try:
|
||||
check_sqlalchemy_uri(make_url_safe(config["sqlalchemy_uri"]))
|
||||
except SupersetSecurityException as exc:
|
||||
raise ImportFailedError(exc.message) from exc
|
||||
# https://github.com/apache/superset/pull/16756 renamed ``csv`` to ``file``.
|
||||
config["allow_file_upload"] = config.pop("allow_csv_upload")
|
||||
if "schemas_allowed_for_csv_upload" in config["extra"]:
|
||||
config["extra"]["schemas_allowed_for_file_upload"] = config["extra"].pop(
|
||||
"schemas_allowed_for_csv_upload"
|
||||
)
|
||||
|
||||
# TODO (betodealmeida): move this logic to import_from_dict
|
||||
config["extra"] = json.dumps(config["extra"])
|
||||
|
||||
# Before it gets removed in import_from_dict
|
||||
ssh_tunnel = config.pop("ssh_tunnel", None)
|
||||
|
||||
database = Database.import_from_dict(session, config, recursive=False)
|
||||
if database.id is None:
|
||||
session.flush()
|
||||
|
||||
if ssh_tunnel:
|
||||
ssh_tunnel["database_id"] = database.id
|
||||
SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
|
||||
|
||||
return database
|
||||
@@ -1,123 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.orm import lazyload, load_only
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseNotFoundError,
|
||||
DatabaseTablesUnexpectedError,
|
||||
)
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import DatasourceName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TablesDatabaseCommand(BaseCommand):
|
||||
_model: Database
|
||||
|
||||
def __init__(self, db_id: int, schema_name: str, force: bool):
|
||||
self._db_id = db_id
|
||||
self._schema_name = schema_name
|
||||
self._force = force
|
||||
|
||||
def run(self) -> dict[str, Any]:
|
||||
self.validate()
|
||||
try:
|
||||
tables = security_manager.get_datasources_accessible_by_user(
|
||||
database=self._model,
|
||||
schema=self._schema_name,
|
||||
datasource_names=sorted(
|
||||
DatasourceName(*datasource_name)
|
||||
for datasource_name in self._model.get_all_table_names_in_schema(
|
||||
schema=self._schema_name,
|
||||
force=self._force,
|
||||
cache=self._model.table_cache_enabled,
|
||||
cache_timeout=self._model.table_cache_timeout,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
views = security_manager.get_datasources_accessible_by_user(
|
||||
database=self._model,
|
||||
schema=self._schema_name,
|
||||
datasource_names=sorted(
|
||||
DatasourceName(*datasource_name)
|
||||
for datasource_name in self._model.get_all_view_names_in_schema(
|
||||
schema=self._schema_name,
|
||||
force=self._force,
|
||||
cache=self._model.table_cache_enabled,
|
||||
cache_timeout=self._model.table_cache_timeout,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
extra_dict_by_name = {
|
||||
table.name: table.extra_dict
|
||||
for table in (
|
||||
db.session.query(SqlaTable)
|
||||
.filter(
|
||||
SqlaTable.database_id == self._model.id,
|
||||
SqlaTable.schema == self._schema_name,
|
||||
)
|
||||
.options(
|
||||
load_only(
|
||||
SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra
|
||||
),
|
||||
lazyload(SqlaTable.columns),
|
||||
lazyload(SqlaTable.metrics),
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
options = sorted(
|
||||
[
|
||||
{
|
||||
"value": table.table,
|
||||
"type": "table",
|
||||
"extra": extra_dict_by_name.get(table.table, None),
|
||||
}
|
||||
for table in tables
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"value": view.table,
|
||||
"type": "view",
|
||||
}
|
||||
for view in views
|
||||
],
|
||||
key=lambda item: item["value"],
|
||||
)
|
||||
|
||||
payload = {"count": len(tables) + len(views), "result": options}
|
||||
return payload
|
||||
except SupersetException as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseTablesUnexpectedError(ex) from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
self._model = cast(Database, DatabaseDAO.find_by_id(self._db_id))
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
@@ -1,231 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app as app
|
||||
from flask_babel import gettext as _
|
||||
from func_timeout import func_timeout, FunctionTimedOut
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseSecurityUnsafeError,
|
||||
DatabaseTestConnectionDriverError,
|
||||
DatabaseTestConnectionUnexpectedError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.errors import ErrorLevel, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
SupersetErrorsException,
|
||||
SupersetSecurityException,
|
||||
SupersetTimeoutException,
|
||||
)
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
from superset.utils.ssh_tunnel import unmask_password_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_log_connection_action(
|
||||
action: str, ssh_tunnel: Optional[Any], exc: Optional[Exception] = None
|
||||
) -> str:
|
||||
action_modified = action
|
||||
if exc:
|
||||
action_modified += f".{exc.__class__.__name__}"
|
||||
if ssh_tunnel:
|
||||
action_modified += ".ssh_tunnel"
|
||||
return action_modified
|
||||
|
||||
|
||||
class TestConnectionDatabaseCommand(BaseCommand):
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
|
||||
self.validate()
|
||||
ex_str = ""
|
||||
uri = self._properties.get("sqlalchemy_uri", "")
|
||||
if self._model and uri == self._model.safe_sqlalchemy_uri():
|
||||
uri = self._model.sqlalchemy_uri_decrypted
|
||||
ssh_tunnel = self._properties.get("ssh_tunnel")
|
||||
|
||||
# context for error messages
|
||||
url = make_url_safe(uri)
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
"password": url.password,
|
||||
"port": url.port,
|
||||
"username": url.username,
|
||||
"database": url.database,
|
||||
}
|
||||
|
||||
serialized_encrypted_extra = self._properties.get(
|
||||
"masked_encrypted_extra",
|
||||
"{}",
|
||||
)
|
||||
if self._model:
|
||||
serialized_encrypted_extra = (
|
||||
self._model.db_engine_spec.unmask_encrypted_extra(
|
||||
self._model.encrypted_extra,
|
||||
serialized_encrypted_extra,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
database = DatabaseDAO.build_db_for_connection_test(
|
||||
server_cert=self._properties.get("server_cert", ""),
|
||||
extra=self._properties.get("extra", "{}"),
|
||||
impersonate_user=self._properties.get("impersonate_user", False),
|
||||
encrypted_extra=serialized_encrypted_extra,
|
||||
)
|
||||
|
||||
database.set_sqlalchemy_uri(uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
# Generate tunnel if present in the properties
|
||||
if ssh_tunnel:
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
# If there's an existing tunnel for that DB we need to use the stored
|
||||
# password, private_key and private_key_password instead
|
||||
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
|
||||
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
|
||||
ssh_tunnel = unmask_password_info(
|
||||
ssh_tunnel, existing_ssh_tunnel
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(**ssh_tunnel)
|
||||
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action("test_connection_attempt", ssh_tunnel),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
|
||||
def ping(engine: Engine) -> bool:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
return engine.dialect.do_ping(conn)
|
||||
|
||||
with database.get_sqla_engine_with_context(
|
||||
override_ssh_tunnel=ssh_tunnel
|
||||
) as engine:
|
||||
try:
|
||||
alive = func_timeout(
|
||||
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
|
||||
ping,
|
||||
args=(engine,),
|
||||
)
|
||||
except (sqlite3.ProgrammingError, RuntimeError):
|
||||
# SQLite can't run on a separate thread, so ``func_timeout`` fails
|
||||
# RuntimeError catches the equivalent error from duckdb.
|
||||
alive = engine.dialect.do_ping(engine)
|
||||
except FunctionTimedOut as ex:
|
||||
raise SupersetTimeoutException(
|
||||
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
|
||||
message=(
|
||||
"Please check your connection details and database settings, "
|
||||
"and ensure that your database is accepting connections, "
|
||||
"then try connecting again."
|
||||
),
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
||||
) from ex
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
alive = False
|
||||
# So we stop losing the original message if any
|
||||
ex_str = str(ex)
|
||||
|
||||
if not alive:
|
||||
raise DBAPIError(ex_str or None, None, None)
|
||||
|
||||
# Log succesful connection test with engine
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action("test_connection_success", ssh_tunnel),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
|
||||
except (NoSuchModuleError, ModuleNotFoundError) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
raise DatabaseTestConnectionDriverError(
|
||||
message=_("Could not load database driver: {}").format(
|
||||
database.db_engine_spec.__name__
|
||||
),
|
||||
) from ex
|
||||
except DBAPIError as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
# check for custom errors (wrong username, wrong password, etc)
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
raise SupersetErrorsException(errors) from ex
|
||||
except SupersetSecurityException as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
|
||||
except SupersetTimeoutException as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
# bubble up the exception to return a 408
|
||||
raise ex
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
# bubble up the exception to return a 400
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
raise DatabaseTestConnectionUnexpectedError(errors) from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
if (database_name := self._properties.get("database_name")) is not None:
|
||||
self._model = DatabaseDAO.get_database_by_name(database_name)
|
||||
@@ -1,182 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseExistsValidationError,
|
||||
DatabaseInvalidError,
|
||||
DatabaseNotFoundError,
|
||||
DatabaseUpdateFailedError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelUpdateFailedError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateDatabaseCommand(BaseCommand):
|
||||
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
old_database_name = self._model.database_name
|
||||
|
||||
# unmask ``encrypted_extra``
|
||||
self._properties[
|
||||
"encrypted_extra"
|
||||
] = self._model.db_engine_spec.unmask_encrypted_extra(
|
||||
self._model.encrypted_extra,
|
||||
self._properties.pop("masked_encrypted_extra", "{}"),
|
||||
)
|
||||
|
||||
try:
|
||||
database = DatabaseDAO.update(self._model, self._properties, commit=False)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
db.session.rollback()
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
if existing_ssh_tunnel_model is None:
|
||||
# We couldn't found an existing tunnel so we need to create one
|
||||
try:
|
||||
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
else:
|
||||
# We found an existing tunnel so we need to update it
|
||||
try:
|
||||
UpdateSSHTunnelCommand(
|
||||
existing_ssh_tunnel_model.id, ssh_tunnel_properties
|
||||
).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
|
||||
# adding a new database we always want to force refresh schema list
|
||||
# TODO Improve this simplistic implementation for catching DB conn fails
|
||||
try:
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
# Update database schema permissions
|
||||
new_schemas: list[str] = []
|
||||
|
||||
for schema in schemas:
|
||||
old_view_menu_name = security_manager.get_schema_perm(
|
||||
old_database_name, schema
|
||||
)
|
||||
new_view_menu_name = security_manager.get_schema_perm(
|
||||
database.database_name, schema
|
||||
)
|
||||
schema_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access", old_view_menu_name
|
||||
)
|
||||
# Update the schema permission if the database name changed
|
||||
if schema_pvm and old_database_name != database.database_name:
|
||||
schema_pvm.view_menu.name = new_view_menu_name
|
||||
|
||||
self._propagate_schema_permissions(
|
||||
old_view_menu_name, new_view_menu_name
|
||||
)
|
||||
else:
|
||||
new_schemas.append(schema)
|
||||
for schema in new_schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
return database
|
||||
|
||||
@staticmethod
|
||||
def _propagate_schema_permissions(
|
||||
old_view_menu_name: str, new_view_menu_name: str
|
||||
) -> None:
|
||||
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
|
||||
SqlaTable,
|
||||
)
|
||||
from superset.models.slice import ( # pylint: disable=import-outside-toplevel
|
||||
Slice,
|
||||
)
|
||||
|
||||
# Update schema_perm on all datasets
|
||||
datasets = (
|
||||
db.session.query(SqlaTable)
|
||||
.filter(SqlaTable.schema_perm == old_view_menu_name)
|
||||
.all()
|
||||
)
|
||||
for dataset in datasets:
|
||||
dataset.schema_perm = new_view_menu_name
|
||||
charts = db.session.query(Slice).filter(
|
||||
Slice.datasource_type == DatasourceType.TABLE,
|
||||
Slice.datasource_id == dataset.id,
|
||||
)
|
||||
# Update schema_perm on all charts
|
||||
for chart in charts:
|
||||
chart.schema_perm = new_view_menu_name
|
||||
|
||||
def validate(self) -> None:
|
||||
exceptions: list[ValidationError] = []
|
||||
# Validate/populate model exists
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
database_name: Optional[str] = self._properties.get("database_name")
|
||||
if database_name:
|
||||
# Check database_name uniqueness
|
||||
if not DatabaseDAO.validate_update_uniqueness(
|
||||
self._model_id, database_name
|
||||
):
|
||||
exceptions.append(DatabaseExistsValidationError())
|
||||
if exceptions:
|
||||
raise DatabaseInvalidError(exceptions=exceptions)
|
||||
@@ -1,132 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
from contextlib import closing
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_babel import gettext as __
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseOfflineError,
|
||||
DatabaseTestConnectionFailedError,
|
||||
InvalidEngineError,
|
||||
InvalidParametersError,
|
||||
)
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
|
||||
BYPASS_VALIDATION_ENGINES = {"bigquery"}
|
||||
|
||||
|
||||
class ValidateDatabaseParametersCommand(BaseCommand):
|
||||
def __init__(self, properties: dict[str, Any]):
|
||||
self._properties = properties.copy()
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> None:
|
||||
self.validate()
|
||||
|
||||
engine = self._properties["engine"]
|
||||
driver = self._properties.get("driver")
|
||||
|
||||
if engine in BYPASS_VALIDATION_ENGINES:
|
||||
# Skip engines that are only validated onCreate
|
||||
return
|
||||
|
||||
engine_spec = get_engine_spec(engine, driver)
|
||||
if not hasattr(engine_spec, "parameters_schema"):
|
||||
raise InvalidEngineError(
|
||||
SupersetError(
|
||||
message=__(
|
||||
'Engine "%(engine)s" cannot be configured through parameters.',
|
||||
engine=engine,
|
||||
),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
),
|
||||
)
|
||||
|
||||
# perform initial validation
|
||||
errors = engine_spec.validate_parameters(self._properties) # type: ignore
|
||||
if errors:
|
||||
event_logger.log_with_context(action="validation_error", engine=engine)
|
||||
raise InvalidParametersError(errors)
|
||||
|
||||
serialized_encrypted_extra = self._properties.get(
|
||||
"masked_encrypted_extra",
|
||||
"{}",
|
||||
)
|
||||
if self._model:
|
||||
serialized_encrypted_extra = engine_spec.unmask_encrypted_extra(
|
||||
self._model.encrypted_extra,
|
||||
serialized_encrypted_extra,
|
||||
)
|
||||
try:
|
||||
encrypted_extra = json.loads(serialized_encrypted_extra)
|
||||
except json.decoder.JSONDecodeError:
|
||||
encrypted_extra = {}
|
||||
|
||||
# try to connect
|
||||
sqlalchemy_uri = engine_spec.build_sqlalchemy_uri( # type: ignore
|
||||
self._properties.get("parameters"),
|
||||
encrypted_extra,
|
||||
)
|
||||
if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():
|
||||
sqlalchemy_uri = self._model.sqlalchemy_uri_decrypted
|
||||
database = DatabaseDAO.build_db_for_connection_test(
|
||||
server_cert=self._properties.get("server_cert", ""),
|
||||
extra=self._properties.get("extra", "{}"),
|
||||
impersonate_user=self._properties.get("impersonate_user", False),
|
||||
encrypted_extra=serialized_encrypted_extra,
|
||||
)
|
||||
database.set_sqlalchemy_uri(sqlalchemy_uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
alive = False
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
try:
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except Exception as ex:
|
||||
url = make_url_safe(sqlalchemy_uri)
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
"password": url.password,
|
||||
"port": url.port,
|
||||
"username": url.username,
|
||||
"database": url.database,
|
||||
}
|
||||
errors = database.db_engine_spec.extract_errors(ex, context)
|
||||
raise DatabaseTestConnectionFailedError(errors) from ex
|
||||
|
||||
if not alive:
|
||||
raise DatabaseOfflineError(
|
||||
SupersetError(
|
||||
message=__("Database is offline."),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
),
|
||||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
if (database_id := self._properties.get("id")) is not None:
|
||||
self._model = DatabaseDAO.find_by_id(database_id)
|
||||
@@ -1,117 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from flask_babel import gettext as __
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseNotFoundError,
|
||||
NoValidatorConfigFoundError,
|
||||
NoValidatorFoundError,
|
||||
ValidatorSQL400Error,
|
||||
ValidatorSQLError,
|
||||
ValidatorSQLUnexpectedError,
|
||||
)
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.core import Database
|
||||
from superset.sql_validators import get_validator_by_name
|
||||
from superset.sql_validators.base import BaseSQLValidator
|
||||
from superset.utils import core as utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidateSQLCommand(BaseCommand):
|
||||
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
self._validator: Optional[type[BaseSQLValidator]] = None
|
||||
|
||||
def run(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Validates a SQL statement
|
||||
|
||||
:return: A List of SQLValidationAnnotation
|
||||
:raises: DatabaseNotFoundError, NoValidatorConfigFoundError
|
||||
NoValidatorFoundError, ValidatorSQLUnexpectedError, ValidatorSQLError
|
||||
ValidatorSQL400Error
|
||||
"""
|
||||
self.validate()
|
||||
if not self._validator or not self._model:
|
||||
raise ValidatorSQLUnexpectedError()
|
||||
sql = self._properties["sql"]
|
||||
schema = self._properties.get("schema")
|
||||
try:
|
||||
timeout = current_app.config["SQLLAB_VALIDATION_TIMEOUT"]
|
||||
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
|
||||
with utils.timeout(seconds=timeout, error_message=timeout_msg):
|
||||
errors = self._validator.validate(sql, schema, self._model)
|
||||
return [err.to_dict() for err in errors]
|
||||
except Exception as ex:
|
||||
logger.exception(ex)
|
||||
superset_error = SupersetError(
|
||||
message=__(
|
||||
"%(validator)s was unable to check your query.\n"
|
||||
"Please recheck your query.\n"
|
||||
"Exception: %(ex)s",
|
||||
validator=self._validator.name,
|
||||
ex=ex,
|
||||
),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
|
||||
# Return as a 400 if the database error message says we got a 4xx error
|
||||
if re.search(r"([\W]|^)4\d{2}([\W]|$)", str(ex)):
|
||||
raise ValidatorSQL400Error(superset_error) from ex
|
||||
raise ValidatorSQLError(superset_error) from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
|
||||
spec = self._model.db_engine_spec
|
||||
validators_by_engine = current_app.config["SQL_VALIDATORS_BY_ENGINE"]
|
||||
if not validators_by_engine or spec.engine not in validators_by_engine:
|
||||
raise NoValidatorConfigFoundError(
|
||||
SupersetError(
|
||||
message=__(f"no SQL validator is configured for {spec.engine}"),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
),
|
||||
)
|
||||
validator_name = validators_by_engine[spec.engine]
|
||||
self._validator = get_validator_by_name(validator_name)
|
||||
if not self._validator:
|
||||
raise NoValidatorFoundError(
|
||||
SupersetError(
|
||||
message=__(
|
||||
f"No validator named {validator_name} found "
|
||||
f"(configured for the {spec.engine} engine)"
|
||||
),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
),
|
||||
)
|
||||
@@ -28,13 +28,13 @@ from marshmallow.validate import Length, ValidationError
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
from superset import db, is_feature_enabled
|
||||
from superset.constants import PASSWORD_MASK
|
||||
from superset.databases.commands.exceptions import DatabaseInvalidError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
from superset.commands.database.exceptions import DatabaseInvalidError
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidCredentials,
|
||||
SSHTunnelMissingCredentials,
|
||||
)
|
||||
from superset.constants import PASSWORD_MASK
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.exceptions import CertificateException, SupersetSecurityException
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -1,91 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAOCreateFailedError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelRequiredFieldValidationError,
|
||||
)
|
||||
from superset.extensions import db, event_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateSSHTunnelCommand(BaseCommand):
|
||||
def __init__(self, database_id: int, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._properties["database_id"] = database_id
|
||||
|
||||
def run(self) -> Model:
|
||||
try:
|
||||
# Start nested transaction since we are always creating the tunnel
|
||||
# through a DB command (Create or Update). Without this, we cannot
|
||||
# safely rollback changes to databases if any, i.e, things like
|
||||
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
|
||||
db.session.begin_nested()
|
||||
self.validate()
|
||||
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
|
||||
except DAOCreateFailedError as ex:
|
||||
# Rollback nested transaction
|
||||
db.session.rollback()
|
||||
raise SSHTunnelCreateFailedError() from ex
|
||||
except SSHTunnelInvalidError as ex:
|
||||
# Rollback nested transaction
|
||||
db.session.rollback()
|
||||
raise ex
|
||||
|
||||
def validate(self) -> None:
|
||||
# TODO(hughhh): check to make sure the server port is not localhost
|
||||
# using the config.SSH_TUNNEL_MANAGER
|
||||
exceptions: list[ValidationError] = []
|
||||
database_id: Optional[int] = self._properties.get("database_id")
|
||||
server_address: Optional[str] = self._properties.get("server_address")
|
||||
server_port: Optional[int] = self._properties.get("server_port")
|
||||
username: Optional[str] = self._properties.get("username")
|
||||
private_key: Optional[str] = self._properties.get("private_key")
|
||||
private_key_password: Optional[str] = self._properties.get(
|
||||
"private_key_password"
|
||||
)
|
||||
if not database_id:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
|
||||
if not server_address:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
|
||||
if not server_port:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
|
||||
if not username:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
|
||||
if private_key_password and private_key is None:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
|
||||
if exceptions:
|
||||
exception = SSHTunnelInvalidError()
|
||||
exception.extend(exceptions)
|
||||
event_logger.log_with_context(
|
||||
# pylint: disable=consider-using-f-string
|
||||
action="ssh_tunnel_creation_failed.{}.{}".format(
|
||||
exception.__class__.__name__,
|
||||
".".join(exception.get_list_classnames()),
|
||||
)
|
||||
)
|
||||
raise exception
|
||||
@@ -1,54 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAODeleteFailedError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeleteSSHTunnelCommand(BaseCommand):
|
||||
def __init__(self, model_id: int):
|
||||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
def run(self) -> None:
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
SSHTunnelDAO.delete([self._model])
|
||||
except DAODeleteFailedError as ex:
|
||||
raise SSHTunnelDeleteFailedError() from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise SSHTunnelNotFoundError()
|
||||
@@ -1,67 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from flask_babel import lazy_gettext as _
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.commands.exceptions import (
|
||||
CommandException,
|
||||
CommandInvalidError,
|
||||
DeleteFailedError,
|
||||
UpdateFailedError,
|
||||
)
|
||||
|
||||
|
||||
class SSHTunnelDeleteFailedError(DeleteFailedError):
|
||||
message = _("SSH Tunnel could not be deleted.")
|
||||
|
||||
|
||||
class SSHTunnelNotFoundError(CommandException):
|
||||
status = 404
|
||||
message = _("SSH Tunnel not found.")
|
||||
|
||||
|
||||
class SSHTunnelInvalidError(CommandInvalidError):
|
||||
message = _("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
class SSHTunnelUpdateFailedError(UpdateFailedError):
|
||||
message = _("SSH Tunnel could not be updated.")
|
||||
|
||||
|
||||
class SSHTunnelCreateFailedError(CommandException):
|
||||
message = _("Creating SSH Tunnel failed for an unknown reason")
|
||||
|
||||
|
||||
class SSHTunnelingNotEnabledError(CommandException):
|
||||
status = 400
|
||||
message = _("SSH Tunneling is not enabled")
|
||||
|
||||
|
||||
class SSHTunnelRequiredFieldValidationError(ValidationError):
|
||||
def __init__(self, field_name: str) -> None:
|
||||
super().__init__(
|
||||
[_("Field is required")],
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
|
||||
class SSHTunnelMissingCredentials(CommandInvalidError):
|
||||
message = _("Must provide credentials for the SSH Tunnel")
|
||||
|
||||
|
||||
class SSHTunnelInvalidCredentials(CommandInvalidError):
|
||||
message = _("Cannot have multiple credentials for the SSH Tunnel")
|
||||
@@ -1,63 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.daos.exceptions import DAOUpdateFailedError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelNotFoundError,
|
||||
SSHTunnelRequiredFieldValidationError,
|
||||
SSHTunnelUpdateFailedError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateSSHTunnelCommand(BaseCommand):
|
||||
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
try:
|
||||
if self._model is not None: # So we dont get incompatible types error
|
||||
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
||||
except DAOUpdateFailedError as ex:
|
||||
raise SSHTunnelUpdateFailedError() from ex
|
||||
return tunnel
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise SSHTunnelNotFoundError()
|
||||
private_key: Optional[str] = self._properties.get("private_key")
|
||||
private_key_password: Optional[str] = self._properties.get(
|
||||
"private_key_password"
|
||||
)
|
||||
if private_key_password and private_key is None:
|
||||
raise SSHTunnelInvalidError(
|
||||
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
|
||||
)
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
|
||||
from superset.databases.commands.exceptions import DatabaseInvalidError
|
||||
from superset.commands.database.exceptions import DatabaseInvalidError
|
||||
|
||||
|
||||
def get_foreign_keys_metadata(
|
||||
|
||||
Reference in New Issue
Block a user