diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index dcdfff6c3f2..2e555e32f1c 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -58,8 +58,8 @@ from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClau from sqlalchemy.types import TypeEngine from sqlparse.tokens import CTE -from superset import sql_parse -from superset.constants import TimeGrain as TimeGrainConstants +from superset import db, sql_parse +from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError @@ -437,6 +437,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Driver-specific exception that should be mapped to OAuth2RedirectError oauth2_exception = OAuth2RedirectError + # Is the query id tied to the connection? + # The default value is True, which means that the query id is determined when + # the connection is created. 
+ # When this is set to False in a DB engine spec it means the query id + # is determined only after the specific query is executed; it is then stored + # under QUERY_CANCEL_KEY in the extra JSON of the `query` object + has_query_id_before_execute = True + @classmethod def is_oauth2_enabled(cls) -> bool: return ( @@ -1316,6 +1324,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # TODO: Fix circular import error caused by importing sql_lab.Query @classmethod + # pylint: disable=consider-using-transaction def execute_with_cursor( cls, cursor: Any, @@ -1333,6 +1342,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ logger.debug("Query %d: Running query: %s", query.id, sql) cls.execute(cursor, sql, query.database, async_=True) + if not cls.has_query_id_before_execute: # id exists only after execute + cancel_query_id = query.database.db_engine_spec.get_cancel_query_id( + cursor, query + ) + if cancel_query_id is not None: + query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id) + db.session.commit() # persist the cancel id immediately logger.debug("Query %d: Handling cursor", query.id) cls.handle_cursor(cursor, query) diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index ea74df83164..ce34ae5648f 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -21,8 +21,9 @@ import logging import re import time from datetime import datetime -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING +import requests from flask import current_app from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector @@ -57,6 +58,8 @@ class ImpalaEngineSpec(BaseEngineSpec): TimeGrain.YEAR: "TRUNC({col}, 'YYYY')", } + has_query_id_before_execute = False # id comes from the executed cursor + @classmethod def epoch_to_dttm(cls) -> str: return "from_unixtime({col})" @@ -91,7 +94,7 @@ class ImpalaEngineSpec(BaseEngineSpec): :see: handle_cursor """ - return True + return False @classmethod def execute( @@ -160,3 +163,38 @@ 
class ImpalaEngineSpec(BaseEngineSpec): except Exception: # pylint: disable=broad-except logger.debug("Call to status() failed ") return + + @classmethod + def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + """ + Get the Impala query ID that will be used to cancel the running + query and release Impala resources. + + :param cursor: Cursor instance on which the query was executed + :param query: Query instance + :return: Impala query ID, or None if the cursor has no last operation + """ + last_operation = getattr(cursor, "_last_operation", None) + if not last_operation: + return None + guid = last_operation.handle.operationId.guid[::-1].hex() # byte-reversed hex + return f"{guid[-16:]}:{guid[:16]}" # <last 16 hex chars>:<first 16 hex chars> + + @classmethod + def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: + """ + Cancel the query via an HTTP POST to port 25000 of the Impala host. + + :param cursor: New cursor instance to the db of the query (unused here) + :param query: Query instance + :param cancel_query_id: Impala query ID passed as `query_id` in the request + :return: True if the endpoint answered HTTP 200, False otherwise + """ + try: + impala_host = query.database.url_object.host + url = f"http://{impala_host}:25000/cancel_query?query_id={cancel_query_id}" + response = requests.post(url, timeout=3) + except Exception: # pylint: disable=broad-except + return False + + return bool(response and response.status_code == 200) diff --git a/tests/unit_tests/db_engine_specs/test_impala.py b/tests/unit_tests/db_engine_specs/test_impala.py index efaed81cba7..543db243684 100644 --- a/tests/unit_tests/db_engine_specs/test_impala.py +++ b/tests/unit_tests/db_engine_specs/test_impala.py @@ -17,9 +17,13 @@ from datetime import datetime from typing import Optional +from unittest.mock import Mock, patch import pytest +from superset.db_engine_specs.impala import ImpalaEngineSpec as spec +from superset.models.core import Database +from superset.models.sql_lab import Query from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm # noqa: 
F401 @@ -37,6 +41,77 @@ def test_convert_dttm( expected_result: Optional[str], dttm: datetime, # noqa: F811 ) -> None: - from superset.db_engine_specs.impala import ImpalaEngineSpec as spec - assert_convert_dttm(spec, target_type, expected_result, dttm) + + +def test_get_cancel_query_id() -> None: + query = Query() + + cursor_mock = Mock() + last_operation_mock = Mock() + cursor_mock._last_operation = last_operation_mock + + guid = bytes(reversed(bytes.fromhex("9fbdba20000000006940643a2731718b"))) + last_operation_mock.handle.operationId.guid = guid # stored reversed; the spec undoes it + + assert ( + spec.get_cancel_query_id(cursor_mock, query) + == "6940643a2731718b:9fbdba2000000000" + ) + + +@patch("requests.post") +def test_cancel_query(post_mock: Mock) -> None: + query = Query() + database = Database( + database_name="test_impala", sqlalchemy_uri="impala://localhost:21050/default" + ) + query.database = database + + response_mock = Mock() + response_mock.status_code = 200 + post_mock.return_value = response_mock + + result = spec.cancel_query(None, query, "6940643a2731718b:9fbdba2000000000") + + post_mock.assert_called_once_with( + "http://localhost:25000/cancel_query?query_id=6940643a2731718b:9fbdba2000000000", + timeout=3, + ) + assert result is True + + +@patch("requests.post") +def test_cancel_query_failed(post_mock: Mock) -> None: + query = Query() + database = Database( + database_name="test_impala", sqlalchemy_uri="impala://localhost:21050/default" + ) + query.database = database + + response_mock = Mock() + response_mock.status_code = 500 # any non-200 status counts as a failed cancel + post_mock.return_value = response_mock + + result = spec.cancel_query(None, query, "6940643a2731718b:9fbdba2000000000") + + post_mock.assert_called_once_with( + "http://localhost:25000/cancel_query?query_id=6940643a2731718b:9fbdba2000000000", + timeout=3, + ) + assert result is False + + +@patch("requests.post") +def test_cancel_query_exception(post_mock: Mock) -> None: + query = Query() + database = Database( + database_name="test_impala", 
sqlalchemy_uri="impala://localhost:21050/default" + ) + query.database = database + + post_mock.side_effect = Exception("Network error") + + result = spec.cancel_query(None, query, "6940643a2731718b:9fbdba2000000000") + + assert result is False