mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
fix: improve get_db_engine_spec_for_backend (#21171)
* fix: improve get_db_engine_spec_for_backend * Fix tests * Fix docs * fix lint * fix fallback * Fix engine validation * Fix test
This commit is contained in:
@@ -1425,7 +1425,7 @@ class TestDatabaseApi(SupersetTestCase):
|
||||
expected_response = {
|
||||
"errors": [
|
||||
{
|
||||
"message": "Could not load database driver: AzureSynapseSpec",
|
||||
"message": "Could not load database driver: MssqlEngineSpec",
|
||||
"error_type": "GENERIC_COMMAND_ERROR",
|
||||
"level": "warning",
|
||||
"extra": {
|
||||
|
||||
@@ -20,7 +20,7 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
from superset.connectors.sqla.models import TableColumn
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs import load_engine_specs
|
||||
from superset.db_engine_specs.base import (
|
||||
BaseEngineSpec,
|
||||
BasicParametersMixin,
|
||||
@@ -195,7 +195,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
||||
def test_engine_time_grain_validity(self):
|
||||
time_grains = set(builtin_time_grains.keys())
|
||||
# loop over all subclasses of BaseEngineSpec
|
||||
for engine in get_engine_specs().values():
|
||||
for engine in load_engine_specs():
|
||||
if engine is not BaseEngineSpec:
|
||||
# make sure time grain functions have been defined
|
||||
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
|
||||
|
||||
@@ -20,7 +20,7 @@ from unittest import mock
|
||||
from sqlalchemy import column, literal_column
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs import load_engine_specs
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
@@ -137,7 +137,11 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
||||
"""
|
||||
DB Eng Specs (postgres): Test "postgres" in engine spec
|
||||
"""
|
||||
self.assertIn("postgres", get_engine_specs())
|
||||
backends = set()
|
||||
for engine in load_engine_specs():
|
||||
backends.add(engine.engine)
|
||||
backends.update(engine.engine_aliases)
|
||||
assert "postgres" in backends
|
||||
|
||||
def test_extras_without_ssl(self):
|
||||
db = mock.Mock()
|
||||
|
||||
@@ -15,31 +15,59 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest import mock
|
||||
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from marshmallow import fields, Schema, ValidationError
|
||||
from pytest_mock import MockFixture
|
||||
|
||||
from superset.databases.schemas import DatabaseParametersSchemaMixin
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
|
||||
class DummySchema(Schema, DatabaseParametersSchemaMixin):
|
||||
sqlalchemy_uri = fields.String()
|
||||
|
||||
|
||||
class DummyEngine(BasicParametersMixin):
|
||||
engine = "dummy"
|
||||
default_driver = "dummy"
|
||||
if TYPE_CHECKING:
|
||||
from superset.databases.schemas import DatabaseParametersSchemaMixin
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class InvalidEngine:
|
||||
pass
|
||||
"""
|
||||
An invalid DB engine spec.
|
||||
"""
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin(get_engine_specs):
|
||||
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
|
||||
@pytest.fixture
|
||||
def dummy_schema() -> "DatabaseParametersSchemaMixin":
|
||||
"""
|
||||
Fixture providing a dummy schema.
|
||||
"""
|
||||
from superset.databases.schemas import DatabaseParametersSchemaMixin
|
||||
|
||||
class DummySchema(Schema, DatabaseParametersSchemaMixin):
|
||||
sqlalchemy_uri = fields.String()
|
||||
|
||||
return DummySchema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_engine(mocker: MockFixture) -> None:
|
||||
"""
|
||||
Fixture proving a dummy DB engine spec.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
|
||||
class DummyEngine(BasicParametersMixin):
|
||||
engine = "dummy"
|
||||
default_driver = "dummy"
|
||||
|
||||
mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine)
|
||||
|
||||
|
||||
def test_database_parameters_schema_mixin(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "dummy_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
@@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs):
|
||||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
result = schema.load(payload)
|
||||
result = dummy_schema.load(payload)
|
||||
assert result == {
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
"sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname",
|
||||
}
|
||||
|
||||
|
||||
def test_database_parameters_schema_mixin_no_engine():
|
||||
def test_database_parameters_schema_mixin_no_engine(
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
"parameters": {
|
||||
@@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine():
|
||||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"dbname": "dbname",
|
||||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {
|
||||
"_schema": [
|
||||
"An engine must be specified when passing individual parameters to a database."
|
||||
(
|
||||
"An engine must be specified when passing individual parameters to "
|
||||
"a database."
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
|
||||
get_engine_specs.return_value = {}
|
||||
def test_database_parameters_schema_mixin_invalid_engine(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "dummy_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
@@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
|
||||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"dbname": "dbname",
|
||||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
print(err.messages)
|
||||
assert err.messages == {
|
||||
"_schema": ['Engine "dummy_engine" is not a valid engine.']
|
||||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_no_mixin(get_engine_specs):
|
||||
get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
|
||||
def test_database_parameters_schema_no_mixin(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "invalid_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
@@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
|
||||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {
|
||||
"_schema": [
|
||||
@@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
|
||||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
|
||||
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
|
||||
def test_database_parameters_schema_mixin_invalid_type(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "dummy_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
@@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
|
||||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {"port": ["Not a valid integer."]}
|
||||
@@ -59,7 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
|
||||
},
|
||||
]
|
||||
|
||||
database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore
|
||||
database.get_db_engine_spec = mocker.MagicMock( # type: ignore
|
||||
return_value=CustomSqliteEngineSpec
|
||||
)
|
||||
assert database.get_metrics("table") == [
|
||||
@@ -70,3 +70,78 @@ def test_get_metrics(mocker: MockFixture) -> None:
|
||||
"verbose_name": "COUNT(DISTINCT user_id)",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_get_db_engine_spec(mocker: MockFixture) -> None:
|
||||
"""
|
||||
Tests for ``get_db_engine_spec``.
|
||||
"""
|
||||
from superset.db_engine_specs import BaseEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
class PostgresDBEngineSpec(BaseEngineSpec):
|
||||
"""
|
||||
A DB engine spec with drivers and a default driver.
|
||||
"""
|
||||
|
||||
engine = "postgresql"
|
||||
engine_aliases = {"postgres"}
|
||||
drivers = {
|
||||
"psycopg2": "The default Postgres driver",
|
||||
"asyncpg": "An async Postgres driver",
|
||||
}
|
||||
default_driver = "psycopg2"
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
class OldDBEngineSpec(BaseEngineSpec):
|
||||
"""
|
||||
And old DB engine spec without drivers nor a default driver.
|
||||
"""
|
||||
|
||||
engine = "mysql"
|
||||
|
||||
load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs")
|
||||
load_engine_specs.return_value = [
|
||||
PostgresDBEngineSpec,
|
||||
OldDBEngineSpec,
|
||||
]
|
||||
|
||||
assert (
|
||||
Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="postgresql+psycopg2://"
|
||||
).db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="postgresql+asyncpg://"
|
||||
).db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://"
|
||||
).db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec
|
||||
== OldDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="mysql+mysqlconnector://"
|
||||
).db_engine_spec
|
||||
== OldDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="mysql+fancynewdriver://"
|
||||
).db_engine_spec
|
||||
== OldDBEngineSpec
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user