fix: improve get_db_engine_spec_for_backend (#21171)

* fix: improve get_db_engine_spec_for_backend

* Fix tests

* Fix docs

* fix lint

* fix fallback

* Fix engine validation

* Fix test
This commit is contained in:
Beto Dealmeida
2022-08-29 13:42:42 -05:00
committed by GitHub
parent 710a8ce5c0
commit 8772e2cdb3
13 changed files with 309 additions and 130 deletions

View File

@@ -1083,8 +1083,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"preferred": engine_spec.engine_name in preferred_databases,
}
if hasattr(engine_spec, "default_driver"):
payload["default_driver"] = engine_spec.default_driver # type: ignore
if engine_spec.default_driver:
payload["default_driver"] = engine_spec.default_driver
# show configuration parameters for DBs that support it
if (

View File

@@ -29,8 +29,7 @@ from superset.databases.commands.exceptions import (
)
from superset.databases.dao import DatabaseDAO
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.base import BasicParametersMixin
from superset.db_engine_specs import get_engine_spec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
@@ -45,25 +44,13 @@ class ValidateDatabaseParametersCommand(BaseCommand):
def run(self) -> None:
engine = self._properties["engine"]
engine_specs = get_engine_specs()
driver = self._properties.get("driver")
if engine in BYPASS_VALIDATION_ENGINES:
# Skip engines that are only validated onCreate
return
if engine not in engine_specs:
raise InvalidEngineError(
SupersetError(
message=__(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={"allowed": list(engine_specs), "provided": engine},
),
)
engine_spec = engine_specs[engine]
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "parameters_schema"):
raise InvalidEngineError(
SupersetError(
@@ -73,14 +60,6 @@ class ValidateDatabaseParametersCommand(BaseCommand):
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={
"allowed": [
name
for name, engine_spec in engine_specs.items()
if issubclass(engine_spec, BasicParametersMixin)
],
"provided": engine,
},
),
)

View File

@@ -16,7 +16,7 @@
# under the License.
import inspect
import json
from typing import Any, Dict, Optional, Type
from typing import Any, Dict
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -28,7 +28,7 @@ from sqlalchemy import MetaData
from superset import db
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
from superset.security.analytics_db_safety import check_sqlalchemy_uri
@@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
[
_(
"Invalid connection string, a valid string usually follows: "
"driver://user:password@database-host/database-name"
"backend+driver://user:password@database-host/database-name"
)
]
) from ex
@@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
"""
engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(),
@@ -262,10 +263,20 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
or parameters.pop("engine", None)
or data.pop("backend", None)
)
driver = data.pop("driver", None)
configuration_method = data.get("configuration_method")
if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
engine_spec = get_engine_spec(engine)
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
engine_spec, "parameters_schema"
@@ -295,34 +306,12 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
return data
def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_specs = get_engine_specs()
if engine not in engine_specs:
raise ValidationError(
[
_(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
)
]
)
return engine_specs[engine]
class DatabaseValidateParametersSchema(Schema):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE
engine = fields.String(required=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(allow_none=True),