fix: trino cursor (#25897)

(cherry picked from commit cdb18e04ff)
This commit is contained in:
Beto Dealmeida
2023-11-08 07:38:38 -05:00
committed by Michael S. Molina
parent 8c099a3f6f
commit d265bd2ffc

View File

@@ -187,7 +187,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
cls, cursor: Cursor, sql: str, query: Query, session: Session
) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
@@ -196,34 +196,40 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
in another thread and invoke `handle_cursor` to poll for the query ID
to appear on the cursor in parallel.
"""
# Fetch the query ID beforehand, since it might fail inside the thread due to
# how the SQLAlchemy session is handled.
query_id = query.id
execute_result: dict[str, Any] = {}
execute_event = threading.Event()
def _execute(results: dict[str, Any]) -> None:
logger.debug("Query %d: Running query: %s", query.id, sql)
def _execute(results: dict[str, Any], event: threading.Event) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)
# Pass result / exception information back to the parent thread
try:
cls.execute(cursor, sql)
results["complete"] = True
except Exception as ex: # pylint: disable=broad-except
results["complete"] = True
results["error"] = ex
finally:
event.set()
execute_thread = threading.Thread(target=_execute, args=(execute_result,))
execute_thread = threading.Thread(
target=_execute,
args=(execute_result, execute_event),
)
execute_thread.start()
# Wait for a query ID to be available before handling the cursor, as
# it's required by that method; it may never become available on error.
while not cursor.query_id and not execute_result.get("complete"):
while not cursor.query_id and not execute_event.is_set():
time.sleep(0.1)
logger.debug("Query %d: Handling cursor", query.id)
logger.debug("Query %d: Handling cursor", query_id)
cls.handle_cursor(cursor, query, session)
# Block until the query completes; same behaviour as the client itself
logger.debug("Query %d: Waiting for query to complete", query.id)
while not execute_result.get("complete"):
time.sleep(0.5)
logger.debug("Query %d: Waiting for query to complete", query_id)
execute_event.wait()
# Unfortunately we'll mangle the stack trace due to the thread, but
# throwing the original exception allows mapping database errors as normal
@@ -237,7 +243,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
session.commit()
@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool:
"""
Cancel query in the underlying database.