feat: improve engine spec discoverability (#14204)

* feat: improve engine spec discoverability

* Address comments

* Fix tests
This commit is contained in:
Beto Dealmeida
2021-04-19 17:24:22 -07:00
committed by GitHub
parent c2d11ac53e
commit 13bf023100
5 changed files with 57 additions and 24 deletions

View File

@@ -28,28 +28,57 @@ at all. The classes here will use a common interface to specify all this.
The general idea is to use static classes and an inheritance scheme.
"""
import inspect
import logging
import pkgutil
from importlib import import_module
from pathlib import Path
from typing import Dict, Type
from typing import Any, Dict, List, Type
from pkg_resources import iter_entry_points
from superset.db_engine_specs.base import BaseEngineSpec
engines: Dict[str, Type[BaseEngineSpec]] = {}
logger = logging.getLogger(__name__)
for (_, name, _) in pkgutil.iter_modules([Path(__file__).parent]): # type: ignore
imported_module = import_module("." + name, package=__name__)
for i in dir(imported_module):
attribute = getattr(imported_module, i)
def is_engine_spec(attr: Any) -> bool:
    """Tell whether ``attr`` is an engine spec class.

    :param attr: any object, typically a module attribute being inspected
    :returns: ``True`` only when ``attr`` is a class, is a subclass of
        ``BaseEngineSpec``, and is not ``BaseEngineSpec`` itself
    """
    # Non-class objects must be rejected first: ``issubclass`` raises
    # ``TypeError`` when its first argument is not a class.
    if not inspect.isclass(attr):
        return False
    if not issubclass(attr, BaseEngineSpec):
        return False
    return attr != BaseEngineSpec
if (
inspect.isclass(attribute)
and issubclass(attribute, BaseEngineSpec)
and attribute.engine != ""
):
engines[attribute.engine] = attribute
# populate engine alias name to engine dictionary
for engine_alias in attribute.engine_aliases or []:
engines[engine_alias] = attribute
def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
    """Build a mapping from engine name or alias to engine spec class.

    Specs are collected from every module found in this package's directory
    and from the ``superset.db_engine_specs`` entry point group. An entry
    point that fails to load is logged as a warning and skipped.

    :returns: dict keyed by each spec's ``engine`` and each of its
        ``engine_aliases``; when several specs share a name, the one
        collected last wins
    """
    engine_specs: List[Type[BaseEngineSpec]] = []

    # load standard engines
    db_engine_spec_dir = str(Path(__file__).parent)
    for module_info in pkgutil.iter_modules([db_engine_spec_dir], prefix="."):
        # names carry the "." prefix, so they resolve relative to this package
        module = import_module(module_info.name, package=__name__)
        engine_specs.extend(
            attr for attr in vars(module).values() if is_engine_spec(attr)
        )

    # load additional engines from external modules
    for ep in iter_entry_points("superset.db_engine_specs"):
        try:
            engine_spec = ep.load()
        except Exception:  # pylint: disable=broad-except
            # ``engine_spec`` is not (re)assigned when ``load()`` raises, so
            # report the entry point that failed rather than that variable.
            logger.warning("Unable to load engine spec: %s", ep.name)
            continue
        engine_specs.append(engine_spec)

    # build map from name/alias -> spec
    engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
    for engine_spec in engine_specs:
        names = [engine_spec.engine]
        if engine_spec.engine_aliases:
            names.extend(engine_spec.engine_aliases)
        for name in names:
            engine_specs_map[name] = engine_spec

    return engine_specs_map

View File

@@ -564,13 +564,15 @@ class Database(
@property
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
    """Return the engine spec class registered for ``self.backend``.

    The lookup goes through ``db_engine_specs.get_engine_specs()``;
    ``BaseEngineSpec`` is returned when no spec is registered under the
    backend name.
    """
    engines = db_engine_specs.get_engine_specs()
    return engines.get(self.backend, db_engine_specs.BaseEngineSpec)
@classmethod
def get_db_engine_spec_for_backend(
    cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
    """Return the engine spec class registered for ``backend``.

    :param backend: engine name or alias to look up
    :returns: the matching spec from ``db_engine_specs.get_engine_specs()``,
        or ``BaseEngineSpec`` when none is registered under that name
    """
    engines = db_engine_specs.get_engine_specs()
    return engines.get(backend, db_engine_specs.BaseEngineSpec)
def grains(self) -> Tuple[TimeGrain, ...]:
"""Defines time granularity database-specific expressions.

View File

@@ -19,7 +19,7 @@ from unittest import mock
import pytest
from superset.db_engine_specs import engines
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.base import (
BaseEngineSpec,
builtin_time_grains,
@@ -189,7 +189,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
for engine in engines.values():
for engine in get_engine_specs().values():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
@@ -254,8 +254,8 @@ class TestDbEngineSpecs(TestDbEngineSpec):
)
def test_get_time_grain_with_unkown_values(self):
""" Should concatenate from configs and then sort in the proper order
putting unknown patterns at the end """
"""Should concatenate from configs and then sort in the proper order
putting unknown patterns at the end"""
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
}

View File

@@ -22,7 +22,9 @@ import pytest
import pandas as pd
from sqlalchemy.sql import select
from tests.test_app import app
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
with app.app_context():
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import Table, ParsedQuery

View File

@@ -20,7 +20,7 @@ from unittest import mock
from sqlalchemy import column, literal_column
from sqlalchemy.dialects import postgresql
from superset.db_engine_specs import engines
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.db_engine_specs.base_tests import TestDbEngineSpec
@@ -132,7 +132,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
"""
DB Eng Specs (postgres): Test "postgres" in engine spec
"""
self.assertIn("postgres", engines)
self.assertIn("postgres", get_engine_specs())
def test_extras_without_ssl(self):
db = mock.Mock()