fix: improve get_db_engine_spec_for_backend (#21171)

* fix: improve get_db_engine_spec_for_backend

* Fix tests

* Fix docs

* fix lint

* fix fallback

* Fix engine validation

* Fix test
This commit is contained in:
Beto Dealmeida
2022-08-29 13:42:42 -05:00
committed by GitHub
parent 710a8ce5c0
commit 8772e2cdb3
13 changed files with 309 additions and 130 deletions

View File

@@ -1425,7 +1425,7 @@ class TestDatabaseApi(SupersetTestCase):
expected_response = {
"errors": [
{
"message": "Could not load database driver: AzureSynapseSpec",
"message": "Could not load database driver: MssqlEngineSpec",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {

View File

@@ -20,7 +20,7 @@ from unittest import mock
import pytest
from superset.connectors.sqla.models import TableColumn
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
@@ -195,7 +195,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
for engine in get_engine_specs().values():
for engine in load_engine_specs():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)

View File

@@ -20,7 +20,7 @@ from unittest import mock
from sqlalchemy import column, literal_column
from sqlalchemy.dialects import postgresql
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
@@ -137,7 +137,11 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
"""
DB Eng Specs (postgres): Test "postgres" in engine spec
"""
self.assertIn("postgres", get_engine_specs())
backends = set()
for engine in load_engine_specs():
backends.add(engine.engine)
backends.update(engine.engine_aliases)
assert "postgres" in backends
def test_extras_without_ssl(self):
db = mock.Mock()

View File

@@ -15,31 +15,59 @@
# specific language governing permissions and limitations
# under the License.
from unittest import mock
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name
from typing import TYPE_CHECKING
import pytest
from marshmallow import fields, Schema, ValidationError
from pytest_mock import MockFixture
from superset.databases.schemas import DatabaseParametersSchemaMixin
from superset.db_engine_specs.base import BasicParametersMixin
from superset.models.core import ConfigurationMethod
class DummySchema(Schema, DatabaseParametersSchemaMixin):
sqlalchemy_uri = fields.String()
class DummyEngine(BasicParametersMixin):
engine = "dummy"
default_driver = "dummy"
if TYPE_CHECKING:
from superset.databases.schemas import DatabaseParametersSchemaMixin
from superset.db_engine_specs.base import BasicParametersMixin
# pylint: disable=too-few-public-methods
class InvalidEngine:
pass
"""
An invalid DB engine spec.
"""
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin(get_engine_specs):
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
@pytest.fixture
def dummy_schema() -> "DatabaseParametersSchemaMixin":
"""
Fixture providing a dummy schema.
"""
from superset.databases.schemas import DatabaseParametersSchemaMixin
class DummySchema(Schema, DatabaseParametersSchemaMixin):
sqlalchemy_uri = fields.String()
return DummySchema()
@pytest.fixture
def dummy_engine(mocker: MockFixture) -> None:
"""
Fixture proving a dummy DB engine spec.
"""
from superset.db_engine_specs.base import BasicParametersMixin
class DummyEngine(BasicParametersMixin):
engine = "dummy"
default_driver = "dummy"
mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine)
def test_database_parameters_schema_mixin(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
result = schema.load(payload)
result = dummy_schema.load(payload)
assert result == {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname",
}
def test_database_parameters_schema_mixin_no_engine():
def test_database_parameters_schema_mixin_no_engine(
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"parameters": {
@@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine():
"password": "password",
"host": "localhost",
"port": 12345,
"dbname": "dbname",
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
"An engine must be specified when passing individual parameters to a database."
(
"An engine must be specified when passing individual parameters to "
"a database."
),
]
}
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
get_engine_specs.return_value = {}
def test_database_parameters_schema_mixin_invalid_engine(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
"password": "password",
"host": "localhost",
"port": 12345,
"dbname": "dbname",
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
print(err.messages)
assert err.messages == {
"_schema": ['Engine "dummy_engine" is not a valid engine.']
}
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_no_mixin(get_engine_specs):
get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
def test_database_parameters_schema_no_mixin(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "invalid_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
@@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
}
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
def test_database_parameters_schema_mixin_invalid_type(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {"port": ["Not a valid integer."]}

View File

@@ -59,7 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
},
]
database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore
database.get_db_engine_spec = mocker.MagicMock( # type: ignore
return_value=CustomSqliteEngineSpec
)
assert database.get_metrics("table") == [
@@ -70,3 +70,78 @@ def test_get_metrics(mocker: MockFixture) -> None:
"verbose_name": "COUNT(DISTINCT user_id)",
},
]
def test_get_db_engine_spec(mocker: MockFixture) -> None:
"""
Tests for ``get_db_engine_spec``.
"""
from superset.db_engine_specs import BaseEngineSpec
from superset.models.core import Database
# pylint: disable=abstract-method
class PostgresDBEngineSpec(BaseEngineSpec):
"""
A DB engine spec with drivers and a default driver.
"""
engine = "postgresql"
engine_aliases = {"postgres"}
drivers = {
"psycopg2": "The default Postgres driver",
"asyncpg": "An async Postgres driver",
}
default_driver = "psycopg2"
# pylint: disable=abstract-method
class OldDBEngineSpec(BaseEngineSpec):
"""
And old DB engine spec without drivers nor a default driver.
"""
engine = "mysql"
load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs")
load_engine_specs.return_value = [
PostgresDBEngineSpec,
OldDBEngineSpec,
]
assert (
Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec
== PostgresDBEngineSpec
)
assert (
Database(
database_name="db", sqlalchemy_uri="postgresql+psycopg2://"
).db_engine_spec
== PostgresDBEngineSpec
)
assert (
Database(
database_name="db", sqlalchemy_uri="postgresql+asyncpg://"
).db_engine_spec
== PostgresDBEngineSpec
)
assert (
Database(
database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://"
).db_engine_spec
== PostgresDBEngineSpec
)
assert (
Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec
== OldDBEngineSpec
)
assert (
Database(
database_name="db", sqlalchemy_uri="mysql+mysqlconnector://"
).db_engine_spec
== OldDBEngineSpec
)
assert (
Database(
database_name="db", sqlalchemy_uri="mysql+fancynewdriver://"
).db_engine_spec
== OldDBEngineSpec
)