mirror of
https://github.com/apache/superset.git
synced 2026-04-18 15:44:57 +00:00
fix: calls to _get_sqla_engine (#24953)
This commit is contained in:
@@ -86,9 +86,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||
}
|
||||
|
||||
if database.has_view_by_name(table_name, schema_name):
|
||||
metadata["view"] = database.inspector.get_view_definition(
|
||||
table_name, schema_name
|
||||
)
|
||||
with database.get_inspector_with_context() as inspector:
|
||||
metadata["view"] = inspector.get_view_definition(
|
||||
table_name, schema_name
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
@@ -563,7 +563,8 @@ class Database(
|
||||
mutator: Callable[[pd.DataFrame], None] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
sqls = self.db_engine_spec.parse_sql(sql)
|
||||
engine = self._get_sqla_engine(schema)
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
engine_url = engine.url
|
||||
mutate_after_split = config["MUTATE_AFTER_SPLIT"]
|
||||
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
|
||||
|
||||
@@ -577,7 +578,7 @@ class Database(
|
||||
def _log_query(sql: str) -> None:
|
||||
if log_query:
|
||||
log_query(
|
||||
engine.url,
|
||||
engine_url,
|
||||
sql,
|
||||
schema,
|
||||
__name__,
|
||||
@@ -624,13 +625,12 @@ class Database(
|
||||
return df
|
||||
|
||||
def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
|
||||
engine = self._get_sqla_engine(schema=schema)
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
|
||||
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if engine.dialect.identifier_preparer._double_percents: # noqa
|
||||
sql = sql.replace("%%", "%")
|
||||
# pylint: disable=protected-access
|
||||
if engine.dialect.identifier_preparer._double_percents: # noqa
|
||||
sql = sql.replace("%%", "%")
|
||||
|
||||
return sql
|
||||
|
||||
@@ -645,18 +645,18 @@ class Database(
|
||||
cols: list[ResultSetColumnType] | None = None,
|
||||
) -> str:
|
||||
"""Generates a ``select *`` statement in the proper dialect"""
|
||||
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
|
||||
return self.db_engine_spec.select_star(
|
||||
self,
|
||||
table_name,
|
||||
schema=schema,
|
||||
engine=eng,
|
||||
limit=limit,
|
||||
show_cols=show_cols,
|
||||
indent=indent,
|
||||
latest_partition=latest_partition,
|
||||
cols=cols,
|
||||
)
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
return self.db_engine_spec.select_star(
|
||||
self,
|
||||
table_name,
|
||||
schema=schema,
|
||||
engine=engine,
|
||||
limit=limit,
|
||||
show_cols=show_cols,
|
||||
indent=indent,
|
||||
latest_partition=latest_partition,
|
||||
cols=cols,
|
||||
)
|
||||
|
||||
def apply_limit_to_sql(
|
||||
self, sql: str, limit: int = 1000, force: bool = False
|
||||
@@ -668,11 +668,6 @@ class Database(
|
||||
def safe_sqlalchemy_uri(self) -> str:
|
||||
return self.sqlalchemy_uri
|
||||
|
||||
@property
|
||||
def inspector(self) -> Inspector:
|
||||
engine = self._get_sqla_engine()
|
||||
return sqla.inspect(engine)
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key="db:{self.id}:schema:{schema}:table_list",
|
||||
cache=cache_manager.cache,
|
||||
@@ -955,8 +950,10 @@ class Database(
|
||||
return view_name in view_names
|
||||
|
||||
def has_view(self, view_name: str, schema: str | None = None) -> bool:
|
||||
engine = self._get_sqla_engine()
|
||||
return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
|
||||
with self.get_sqla_engine_with_context(schema) as engine:
|
||||
return engine.run_callable(
|
||||
self._has_view, engine.dialect, view_name, schema
|
||||
)
|
||||
|
||||
def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
|
||||
return self.has_view(view_name=view_name, schema=schema)
|
||||
|
||||
@@ -120,9 +120,8 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
|
||||
def quote_f(value: Optional[str]):
|
||||
if not value:
|
||||
return value
|
||||
return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
|
||||
value
|
||||
)
|
||||
with get_example_database().get_inspector_with_context() as inspector:
|
||||
return inspector.engine.dialect.identifier_preparer.quote_identifier(value)
|
||||
|
||||
|
||||
def cta_result(ctas_method: CtasMethod):
|
||||
|
||||
@@ -113,9 +113,10 @@ class BaseTestChartDataApi(SupersetTestCase):
|
||||
|
||||
def quote_name(self, name: str):
|
||||
if get_main_database().backend in {"presto", "hive"}:
|
||||
return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
|
||||
name
|
||||
)
|
||||
with get_example_database().get_inspector_with_context() as inspector: # E: Ne
|
||||
return inspector.engine.dialect.identifier_preparer.quote_identifier(
|
||||
name
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
|
||||
@@ -296,7 +296,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
db = get_example_database()
|
||||
table_name = "energy_usage"
|
||||
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
|
||||
quote = db.inspector.engine.dialect.identifier_preparer.quote_identifier
|
||||
with db.get_sqla_engine_with_context() as engine:
|
||||
quote = engine.dialect.identifier_preparer.quote_identifier
|
||||
expected = (
|
||||
textwrap.dedent(
|
||||
f"""\
|
||||
|
||||
Reference in New Issue
Block a user