mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
feat: improve engine spec discoverability (#14204)
* feat: improve engine spec discoverability * Address comments * Fix tests
This commit is contained in:
@@ -28,28 +28,57 @@ at all. The classes here will use a common interface to specify all this.
|
||||
The general idea is to use static classes and an inheritance scheme.
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Dict, Type
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from pkg_resources import iter_entry_points
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
engines: Dict[str, Type[BaseEngineSpec]] = {}
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for (_, name, _) in pkgutil.iter_modules([Path(__file__).parent]): # type: ignore
|
||||
imported_module = import_module("." + name, package=__name__)
|
||||
|
||||
for i in dir(imported_module):
|
||||
attribute = getattr(imported_module, i)
|
||||
def is_engine_spec(attr: Any) -> bool:
|
||||
return (
|
||||
inspect.isclass(attr)
|
||||
and issubclass(attr, BaseEngineSpec)
|
||||
and attr != BaseEngineSpec
|
||||
)
|
||||
|
||||
if (
|
||||
inspect.isclass(attribute)
|
||||
and issubclass(attribute, BaseEngineSpec)
|
||||
and attribute.engine != ""
|
||||
):
|
||||
engines[attribute.engine] = attribute
|
||||
|
||||
# populate engine alias name to engine dictionary
|
||||
for engine_alias in attribute.engine_aliases or []:
|
||||
engines[engine_alias] = attribute
|
||||
def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
|
||||
engine_specs: List[Type[BaseEngineSpec]] = []
|
||||
|
||||
# load standard engines
|
||||
db_engine_spec_dir = str(Path(__file__).parent)
|
||||
for module_info in pkgutil.iter_modules([db_engine_spec_dir], prefix="."):
|
||||
module = import_module(module_info.name, package=__name__)
|
||||
engine_specs.extend(
|
||||
getattr(module, attr)
|
||||
for attr in module.__dict__
|
||||
if is_engine_spec(getattr(module, attr))
|
||||
)
|
||||
|
||||
# load additional engines from external modules
|
||||
for ep in iter_entry_points("superset.db_engine_specs"):
|
||||
try:
|
||||
engine_spec = ep.load()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.warning("Unable to load engine spec: %s", engine_spec)
|
||||
continue
|
||||
engine_specs.append(engine_spec)
|
||||
|
||||
# build map from name/alias -> spec
|
||||
engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
|
||||
for engine_spec in engine_specs:
|
||||
names = [engine_spec.engine]
|
||||
if engine_spec.engine_aliases:
|
||||
names.extend(engine_spec.engine_aliases)
|
||||
|
||||
for name in names:
|
||||
engine_specs_map[name] = engine_spec
|
||||
|
||||
return engine_specs_map
|
||||
|
||||
@@ -564,13 +564,15 @@ class Database(
|
||||
|
||||
@property
|
||||
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
return db_engine_specs.engines.get(self.backend, db_engine_specs.BaseEngineSpec)
|
||||
engines = db_engine_specs.get_engine_specs()
|
||||
return engines.get(self.backend, db_engine_specs.BaseEngineSpec)
|
||||
|
||||
@classmethod
|
||||
def get_db_engine_spec_for_backend(
|
||||
cls, backend: str
|
||||
) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)
|
||||
engines = db_engine_specs.get_engine_specs()
|
||||
return engines.get(backend, db_engine_specs.BaseEngineSpec)
|
||||
|
||||
def grains(self) -> Tuple[TimeGrain, ...]:
|
||||
"""Defines time granularity database-specific expressions.
|
||||
|
||||
@@ -19,7 +19,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.db_engine_specs import engines
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs.base import (
|
||||
BaseEngineSpec,
|
||||
builtin_time_grains,
|
||||
@@ -189,7 +189,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
||||
def test_engine_time_grain_validity(self):
|
||||
time_grains = set(builtin_time_grains.keys())
|
||||
# loop over all subclasses of BaseEngineSpec
|
||||
for engine in engines.values():
|
||||
for engine in get_engine_specs().values():
|
||||
if engine is not BaseEngineSpec:
|
||||
# make sure time grain functions have been defined
|
||||
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
|
||||
@@ -254,8 +254,8 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
||||
)
|
||||
|
||||
def test_get_time_grain_with_unkown_values(self):
|
||||
""" Should concatenate from configs and then sort in the proper order
|
||||
putting unknown patterns at the end """
|
||||
"""Should concatenate from configs and then sort in the proper order
|
||||
putting unknown patterns at the end"""
|
||||
app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = {
|
||||
"mysql": {"PT2H": "foo", "weird": "foo", "PT12H": "foo",}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,9 @@ import pytest
|
||||
import pandas as pd
|
||||
from sqlalchemy.sql import select
|
||||
from tests.test_app import app
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
|
||||
|
||||
with app.app_context():
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.sql_parse import Table, ParsedQuery
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from unittest import mock
|
||||
from sqlalchemy import column, literal_column
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from superset.db_engine_specs import engines
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
@@ -132,7 +132,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
||||
"""
|
||||
DB Eng Specs (postgres): Test "postgres" in engine spec
|
||||
"""
|
||||
self.assertIn("postgres", engines)
|
||||
self.assertIn("postgres", get_engine_specs())
|
||||
|
||||
def test_extras_without_ssl(self):
|
||||
db = mock.Mock()
|
||||
|
||||
Reference in New Issue
Block a user