diff --git a/superset/models/core.py b/superset/models/core.py index 173bd5b5907..9e042eeab57 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -744,13 +744,14 @@ class Database( def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) - return Table( - table_name, - meta, - schema=schema or None, - autoload=True, - autoload_with=self._get_sqla_engine(), - ) + with self.get_sqla_engine_with_context() as engine: + return Table( + table_name, + meta, + schema=schema or None, + autoload=True, + autoload_with=engine, + ) def get_table_comment( self, table_name: str, schema: Optional[str] = None @@ -846,12 +847,12 @@ class Database( return self.perm # type: ignore def has_table(self, table: Table) -> bool: - engine = self._get_sqla_engine() - return engine.has_table(table.table_name, table.schema or None) + with self.get_sqla_engine_with_context() as engine: + return engine.has_table(table.table_name, table.schema or None) def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool: - engine = self._get_sqla_engine() - return engine.has_table(table_name, schema) + with self.get_sqla_engine_with_context() as engine: + return engine.has_table(table_name, schema) @classmethod def _has_view( diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 6e9f1a8d33c..86246084fb8 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -27,10 +27,12 @@ from typing import Dict, List from urllib.parse import quote import superset.utils.database +from superset.utils.core import backend from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, load_birth_names_data, ) +from sqlalchemy import Table import pytest import pytz @@ -79,6 +81,7 @@ from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, ) +from tests.integration_tests.conftest import CTAS_SCHEMA_NAME logger = logging.getLogger(__name__) @@ -1673,6 +1676,16 @@ class TestCore(SupersetTestCase): ) self.assertRedirects(rv, f"/explore/?form_data_key={random_key}") + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_has_table_by_name(self): + if backend() in ("sqlite", "mysql"): + return + example_db = superset.utils.database.get_example_database() + assert ( + example_db.has_table_by_name(table_name="birth_names", schema="public") + is True + ) + if __name__ == "__main__": unittest.main()