# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access from __future__ import annotations import copy from collections import namedtuple from datetime import datetime from typing import Any, Optional from unittest.mock import MagicMock, Mock, patch import pandas as pd import pytest from flask import g, has_app_context from pytest_mock import MockerFixture from requests.exceptions import ConnectionError as RequestsConnectionError from sqlalchemy import column, sql, text, types from sqlalchemy.dialects import sqlite from sqlalchemy.engine.url import make_url from sqlalchemy.exc import NoSuchTableError from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError from trino.sqlalchemy import datatype from trino.sqlalchemy.dialect import TrinoDialect import superset.config from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY from superset.db_engine_specs.exceptions import ( SupersetDBAPIConnectionError, SupersetDBAPIDatabaseError, SupersetDBAPIOperationalError, SupersetDBAPIProgrammingError, ) from superset.sql.parse import Table from superset.superset_typing import ( OAuth2ClientConfig, ResultSetColumnType, SQLAColumnType, SQLType, ) from superset.utils import json from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, assert_convert_dttm, ) from tests.unit_tests.fixtures.common import dttm # noqa: F401 def _assert_columns_equal(actual_cols, expected_cols) -> None: """ Assert equality of the given cols, bearing in mind sqlalchemy type instances can't be compared for equality, so will have to be converted to strings first. """ actual = copy.deepcopy(actual_cols) expected = copy.deepcopy(expected_cols) for col in actual: col["type"] = str(col["type"]) for col in expected: col["type"] = str(col["type"]) assert actual == expected @pytest.mark.parametrize( "extra,expected", [ ({}, {"engine_params": {"connect_args": {"source": "Apache Superset"}}}), ( { "first": 1, "engine_params": { "second": "two", "connect_args": {"source": "foobar", "third": "three"}, }, }, { "first": 1, "engine_params": { "second": "two", "connect_args": {"source": "foobar", "third": "three"}, }, }, ), ], ) def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() database.extra = json.dumps(extra) database.server_cert = None assert TrinoEngineSpec.get_extra_params(database) == expected @patch("superset.db_engine_specs.trino.create_ssl_cert_file") def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() database.extra = json.dumps({}) database.server_cert = "TEST_CERT" database.db_engine_spec = TrinoEngineSpec mock_create_ssl_cert_file.return_value = "/path/to/tls.crt" extra = TrinoEngineSpec.get_extra_params(database) connect_args = extra.get("engine_params", {}).get("connect_args", {}) assert connect_args.get("http_scheme") == "https" assert connect_args.get("verify") == "/path/to/tls.crt" mock_create_ssl_cert_file.assert_called_once_with(database.server_cert) @patch("trino.auth.BasicAuthentication") def test_auth_basic(mock_auth: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() auth_params = {"username": "username", "password": "password"} database.encrypted_extra = json.dumps( {"auth_method": "basic", "auth_params": auth_params} ) params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" mock_auth.assert_called_once_with(**auth_params) @patch("trino.auth.KerberosAuthentication") def test_auth_kerberos(mock_auth: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() auth_params = { "service_name": "superset", "mutual_authentication": False, "delegate": True, } database.encrypted_extra = json.dumps( {"auth_method": "kerberos", "auth_params": auth_params} ) params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" mock_auth.assert_called_once_with(**auth_params) @patch("trino.auth.CertificateAuthentication") def test_auth_certificate(mock_auth: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"} database.encrypted_extra = json.dumps( {"auth_method": "certificate", "auth_params": auth_params} ) params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" mock_auth.assert_called_once_with(**auth_params) @patch("trino.auth.JWTAuthentication") def test_auth_jwt(mock_auth: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() auth_params = {"token": "jwt-token-string"} database.encrypted_extra = json.dumps( {"auth_method": "jwt", "auth_params": auth_params} ) params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" mock_auth.assert_called_once_with(**auth_params) def test_auth_custom_auth() -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() auth_class = Mock() auth_method = "custom_auth" auth_params = {"params1": "params1", "params2": "params2"} database.encrypted_extra = json.dumps( {"auth_method": auth_method, "auth_params": auth_params} ) with patch.dict( "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS", {"trino": {"custom_auth": auth_class}}, clear=True, ): params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" auth_class.assert_called_once_with(**auth_params) def test_auth_custom_auth_denied() -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() auth_method = "my.module:TrinoAuthClass" auth_params = {"params1": "params1", "params2": "params2"} database.encrypted_extra = json.dumps( {"auth_method": auth_method, "auth_params": auth_params} ) superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {} with pytest.raises(ValueError) as excinfo: # noqa: PT011 TrinoEngineSpec.update_params_from_encrypted_extra(database, {}) assert str(excinfo.value) == ( f"For security reason, custom authentication '{auth_method}' " f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" ) @pytest.mark.parametrize( "native_type,sqla_type,attrs,generic_type,is_dttm", [ ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False), ("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False), ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False), ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False), ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False), ("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False), ("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False), ("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False), ("VARCHAR", types.String, None, GenericDataType.STRING, False), ("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False), ("CHAR", types.String, None, GenericDataType.STRING, False), ("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False), ("JSON", types.JSON, None, GenericDataType.STRING, False), ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), ("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), ( "TIMESTAMP WITH TIME ZONE", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True, ), ( "TIMESTAMP(3) WITH TIME ZONE", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True, ), ("DATE", types.Date, None, GenericDataType.TEMPORAL, True), ], ) def test_get_column_spec( native_type: str, sqla_type: type[types.TypeEngine], attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec as spec # noqa: N813 assert_column_spec( spec, native_type, sqla_type, attrs, generic_type, is_dttm, ) @pytest.mark.parametrize( "target_type,expected_result", [ ("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678900'"), ("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"), ("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"), ("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"), ("Date", "DATE '2019-01-02'"), ("Other", None), ], ) def test_convert_dttm( target_type: str, expected_result: Optional[str], dttm: datetime, # noqa: F811 ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm) def test_get_extra_table_metadata(mocker: MockerFixture) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec db_mock = mocker.MagicMock() db_mock.get_indexes = Mock( return_value=[{"column_names": ["ds", "hour"], "name": "partition"}] ) db_mock.get_extra = Mock(return_value={}) db_mock.has_view = Mock(return_value=None) db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})) result = TrinoEngineSpec.get_extra_table_metadata( db_mock, Table("test_table", "test_schema"), ) assert result["partitions"]["cols"] == ["ds", "hour"] assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1} @patch("sqlalchemy.engine.Engine.connect") def test_cancel_query_success(engine_mock: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query query = Query() cursor_mock = engine_mock.return_value.__enter__.return_value assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True @patch("sqlalchemy.engine.Engine.connect") def test_cancel_query_failed(engine_mock: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query query = Query() cursor_mock = engine_mock.raiseError.side_effect = Exception() assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False @pytest.mark.parametrize( "initial_extra,final_extra", [ ({}, {QUERY_EARLY_CANCEL_KEY: True}), ({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}), ], ) def test_prepare_cancel_query( initial_extra: dict[str, Any], final_extra: dict[str, Any], mocker: MockerFixture, ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query query = Query(extra_json=json.dumps(initial_extra)) TrinoEngineSpec.prepare_cancel_query(query=query) assert query.extra == final_extra @pytest.mark.parametrize("cancel_early", [True, False]) @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_early_cancel( mock_app: Mock, mock_db: Mock, cancel_query_mock: Mock, mock_presto_handle_cursor: Mock, cancel_early: bool, mocker: MockerFixture, ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} query_id = "myQueryId" # Use spec to prevent MagicMock from creating attributes automatically cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = query_id # Set stats to FINISHED so the progress loop exits immediately cursor_mock.stats = {"state": "FINISHED", "completedSplits": 0, "totalSplits": 0} query = Query() if cancel_early: TrinoEngineSpec.prepare_cancel_query(query=query) TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) if cancel_early: assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id else: assert cancel_query_mock.call_args is None def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture): """Test that `execute_with_cursor` fetches query ID from the cursor""" from superset.db_engine_specs.trino import TrinoEngineSpec query_id = "myQueryId" mock_cursor = mocker.MagicMock() mock_cursor.query_id = None mock_query = mocker.MagicMock() def _mock_execute(*args, **kwargs): mock_cursor.query_id = query_id with app.test_request_context("/some/place/"): mock_cursor.execute.side_effect = _mock_execute with patch.dict( "superset.config.DISALLOWED_SQL_FUNCTIONS", {}, clear=True, ): TrinoEngineSpec.execute_with_cursor( cursor=mock_cursor, sql="SELECT 1 FROM foo", query=mock_query, ) mock_query.set_extra_json_key.assert_called_once_with( key=QUERY_CANCEL_KEY, value=query_id ) def test_execute_with_cursor_app_context(app, mocker: MockerFixture): """Test that `execute_with_cursor` still contains the current app context""" from superset.db_engine_specs.trino import TrinoEngineSpec mock_cursor = mocker.MagicMock() mock_cursor.query_id = None mock_query = mocker.MagicMock() def _mock_execute(*args, **kwargs): assert has_app_context() assert g.some_value == "some_value" with app.test_request_context("/some/place/"): g.some_value = "some_value" with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute): with patch.dict( "superset.config.DISALLOWED_SQL_FUNCTIONS", {}, clear=True, ): TrinoEngineSpec.execute_with_cursor( cursor=mock_cursor, sql="SELECT 1 FROM foo", query=mock_query, ) def test_get_columns(mocker: MockerFixture): """Test that ROW columns are not expanded without expand_rows""" from superset.db_engine_specs.trino import TrinoEngineSpec field1_type = datatype.parse_sqltype("row(a varchar, b date)") field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") field3_type = datatype.parse_sqltype("int") sqla_columns = [ SQLAColumnType(name="field1", type=field1_type, is_dttm=False), SQLAColumnType(name="field2", type=field2_type, is_dttm=False), SQLAColumnType(name="field3", type=field3_type, is_dttm=False), ] mock_inspector = mocker.MagicMock() mock_inspector.get_columns.return_value = sqla_columns actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema")) expected = [ ResultSetColumnType( name="field1", column_name="field1", type=field1_type, is_dttm=False ), ResultSetColumnType( name="field2", column_name="field2", type=field2_type, is_dttm=False ), ResultSetColumnType( name="field3", column_name="field3", type=field3_type, is_dttm=False ), ] _assert_columns_equal(actual, expected) def test_get_columns_error(mocker: MockerFixture): """ Test that we fallback to a `SHOW COLUMNS FROM ...` query. """ from superset.db_engine_specs.trino import TrinoEngineSpec field1_type = datatype.parse_sqltype("row(a varchar, b date)") field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") field3_type = datatype.parse_sqltype("int") mock_inspector = mocker.MagicMock() mock_inspector.engine.dialect = sqlite.dialect() mock_inspector.get_columns.side_effect = NoSuchTableError( "The specified table does not exist." ) Row = namedtuple("Row", ["Column", "Type"]) mock_inspector.bind.execute().fetchall.return_value = [ Row("field1", "row(a varchar, b date)"), Row("field2", "row(r1 row(a varchar, b varchar))"), Row("field3", "int"), ] actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema")) expected = [ ResultSetColumnType( name="field1", column_name="field1", type=field1_type, is_dttm=None, type_generic=None, default=None, nullable=True, ), ResultSetColumnType( name="field2", column_name="field2", type=field2_type, is_dttm=None, type_generic=None, default=None, nullable=True, ), ResultSetColumnType( name="field3", column_name="field3", type=field3_type, is_dttm=None, type_generic=None, default=None, nullable=True, ), ] _assert_columns_equal(actual, expected) mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"') def test_get_columns_expand_rows(mocker: MockerFixture): """Test that ROW columns are correctly expanded with expand_rows""" from superset.db_engine_specs.trino import TrinoEngineSpec field1_type = datatype.parse_sqltype("row(a varchar, b date)") field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") field3_type = datatype.parse_sqltype("int") sqla_columns = [ SQLAColumnType(name="field1", type=field1_type, is_dttm=False), SQLAColumnType(name="field2", type=field2_type, is_dttm=False), SQLAColumnType(name="field3", type=field3_type, is_dttm=False), ] mock_inspector = mocker.MagicMock() mock_inspector.get_columns.return_value = sqla_columns actual = TrinoEngineSpec.get_columns( mock_inspector, Table("table", "schema"), {"expand_rows": True}, ) expected = [ ResultSetColumnType( name="field1", column_name="field1", type=field1_type, is_dttm=False ), ResultSetColumnType( name="field1.a", column_name="field1.a", type=types.VARCHAR(), is_dttm=False, query_as='"field1"."a" AS "field1.a"', ), ResultSetColumnType( name="field1.b", column_name="field1.b", type=types.DATE(), is_dttm=True, query_as='"field1"."b" AS "field1.b"', ), ResultSetColumnType( name="field2", column_name="field2", type=field2_type, is_dttm=False ), ResultSetColumnType( name="field2.r1", column_name="field2.r1", type=datatype.parse_sqltype("row(a varchar, b varchar)"), is_dttm=False, query_as='"field2"."r1" AS "field2.r1"', ), ResultSetColumnType( name="field2.r1.a", column_name="field2.r1.a", type=types.VARCHAR(), is_dttm=False, query_as='"field2"."r1"."a" AS "field2.r1.a"', ), ResultSetColumnType( name="field2.r1.b", column_name="field2.r1.b", type=types.VARCHAR(), is_dttm=False, query_as='"field2"."r1"."b" AS "field2.r1.b"', ), ResultSetColumnType( name="field3", column_name="field3", type=field3_type, is_dttm=False ), ] _assert_columns_equal(actual, expected) def test_get_indexes_no_table(): from superset.db_engine_specs.trino import TrinoEngineSpec db_mock = Mock() inspector_mock = Mock() inspector_mock.get_indexes = Mock( side_effect=NoSuchTableError("The specified table does not exist.") ) result = TrinoEngineSpec.get_indexes( db_mock, inspector_mock, Table("test_table", "test_schema"), ) assert result == [] def test_get_dbapi_exception_mapping(): from superset.db_engine_specs.trino import TrinoEngineSpec mapping = TrinoEngineSpec.get_dbapi_exception_mapping() assert mapping.get(TrinoUserError) == SupersetDBAPIProgrammingError assert mapping.get(TrinoInternalError) == SupersetDBAPIDatabaseError assert mapping.get(TrinoExternalError) == SupersetDBAPIOperationalError assert mapping.get(RequestsConnectionError) == SupersetDBAPIConnectionError assert mapping.get(Exception) is None def test_adjust_engine_params_fully_qualified() -> None: """ Test the ``adjust_engine_params`` method when the URL has catalog and schema. """ from superset.db_engine_specs.trino import TrinoEngineSpec url = make_url("trino://user:pass@localhost:8080/system/default") uri = TrinoEngineSpec.adjust_engine_params(url, {})[0] assert str(uri) == "trino://user:pass@localhost:8080/system/default" uri = TrinoEngineSpec.adjust_engine_params( url, {}, schema="new_schema", )[0] assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema" uri = TrinoEngineSpec.adjust_engine_params( url, {}, catalog="new_catalog", )[0] assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/default" uri = TrinoEngineSpec.adjust_engine_params( url, {}, catalog="new_catalog", schema="new_schema", )[0] assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema" def test_adjust_engine_params_catalog_only() -> None: """ Test the ``adjust_engine_params`` method when the URL has only the catalog. """ from superset.db_engine_specs.trino import TrinoEngineSpec url = make_url("trino://user:pass@localhost:8080/system") uri = TrinoEngineSpec.adjust_engine_params(url, {})[0] assert str(uri) == "trino://user:pass@localhost:8080/system" uri = TrinoEngineSpec.adjust_engine_params( url, {}, schema="new_schema", )[0] assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema" uri = TrinoEngineSpec.adjust_engine_params( url, {}, catalog="new_catalog", )[0] assert str(uri) == "trino://user:pass@localhost:8080/new_catalog" uri = TrinoEngineSpec.adjust_engine_params( url, {}, catalog="new_catalog", schema="new_schema", )[0] assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema" @pytest.mark.parametrize( "sqlalchemy_uri,result", [ ("trino://user:pass@localhost:8080/system", "system"), ("trino://user:pass@localhost:8080/system/default", "system"), ("trino://trino@localhost:8081", None), ], ) def test_get_default_catalog(sqlalchemy_uri: str, result: str | None) -> None: """ Test the ``get_default_catalog`` method. """ from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.core import Database database = Database( database_name="my_db", sqlalchemy_uri=sqlalchemy_uri, ) assert TrinoEngineSpec.get_default_catalog(database) == result @patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition") @pytest.mark.parametrize( ["column_type", "column_value", "expected_value"], [ (types.DATE(), "2023-05-01", "DATE '2023-05-01'"), (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"), (types.VARCHAR(), "2023-05-01", "'2023-05-01'"), (types.INT(), 1234, "1234"), ], ) def test_where_latest_partition( mock_latest_partition, column_type: SQLType, column_value: Any, expected_value: str, ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec mock_latest_partition.return_value = (["partition_key"], [column_value]) assert ( str( TrinoEngineSpec.where_latest_partition( # type: ignore database=MagicMock(), table=Table("table"), query=sql.select(text("* FROM table")), columns=[ { "column_name": "partition_key", "name": "partition_key", "type": column_type, "is_dttm": False, } ], ).compile( dialect=TrinoDialect(), compile_kwargs={"literal_binds": True}, ) ) == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}""" # noqa: S608 ) @pytest.fixture def oauth2_config() -> OAuth2ClientConfig: """ Config for Trino OAuth2. """ return { "id": "trino", "secret": "very-secret", "scope": "", "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth", "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token", "request_content_type": "data", } def test_get_oauth2_token( mocker: MockerFixture, oauth2_config: OAuth2ClientConfig, ) -> None: """ Test `get_oauth2_token`. """ from superset.db_engine_specs.trino import TrinoEngineSpec requests = mocker.patch("superset.db_engine_specs.base.requests") requests.post().json.return_value = { "access_token": "access-token", "expires_in": 3600, "scope": "scope", "token_type": "Bearer", "refresh_token": "refresh-token", } assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == { "access_token": "access-token", "expires_in": 3600, "scope": "scope", "token_type": "Bearer", "refresh_token": "refresh-token", } requests.post.assert_called_with( "https://trino.auth.server.example/master/protocol/openid-connect/token", data={ "code": "code", "client_id": "trino", "client_secret": "very-secret", "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", "grant_type": "authorization_code", }, timeout=30.0, ) @pytest.mark.parametrize( "time_grain,expected_result", [ ("PT1S", "date_trunc('second', CAST(col AS TIMESTAMP))"), ( "PT5S", "date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 5)", # noqa: E501 ), ( "PT30S", "date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 30)", # noqa: E501 ), ("PT1M", "date_trunc('minute', CAST(col AS TIMESTAMP))"), ( "PT5M", "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 5)", # noqa: E501 ), ( "PT10M", "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 10)", # noqa: E501 ), ( "PT15M", "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 15)", # noqa: E501 ), ( "PT0.5H", "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 30)", # noqa: E501 ), ("PT1H", "date_trunc('hour', CAST(col AS TIMESTAMP))"), ( "PT6H", "date_trunc('hour', CAST(col AS TIMESTAMP)) - interval '1' hour * (hour(CAST(col AS TIMESTAMP)) % 6)", # noqa: E501 ), ("P1D", "date_trunc('day', CAST(col AS TIMESTAMP))"), ("P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"), ("P1M", "date_trunc('month', CAST(col AS TIMESTAMP))"), ("P3M", "date_trunc('quarter', CAST(col AS TIMESTAMP))"), ("P1Y", "date_trunc('year', CAST(col AS TIMESTAMP))"), ( "1969-12-28T00:00:00Z/P1W", "date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) - interval '1' day", # noqa: E501 ), ("1969-12-29T00:00:00Z/P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"), ( "P1W/1970-01-03T00:00:00Z", "date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) + interval '5' day", # noqa: E501 ), ( "P1W/1970-01-04T00:00:00Z", "date_trunc('week', CAST(col AS TIMESTAMP)) + interval '6' day", ), ], ) def test_timegrain_expressions(time_grain: str, expected_result: str) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec as spec # noqa: N813 actual = str( spec.get_timestamp_expr(col=column("col"), pdf=None, time_grain=time_grain) ) assert actual == expected_result @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_progress_updates( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor updates query progress based on cursor stats.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" # Simulate progress: 0/10 -> 5/10 -> 10/10 (FINISHED) cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10} call_count = 0 def update_stats(*args, **kwargs): nonlocal call_count call_count += 1 if call_count == 1: cursor_mock.stats = { "state": "RUNNING", "completedSplits": 5, "totalSplits": 10, } elif call_count >= 2: cursor_mock.stats = { "state": "FINISHED", "completedSplits": 10, "totalSplits": 10, } with patch("superset.db_engine_specs.trino.time.sleep", side_effect=update_stats): query = Query() query.status = "running" TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) assert query.progress == 100.0 assert mock_db.session.commit.called @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_cancels_on_stopped_status( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor cancels query when status is STOPPED.""" from superset.common.db_query_status import QueryStatus from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10} query = Query() query.status = QueryStatus.STOPPED with patch("superset.db_engine_specs.trino.time.sleep"): TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) mock_cancel_query.assert_called_once_with( cursor=cursor_mock, query=query, cancel_query_id="test-query-id", ) @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_cancels_on_timed_out_status( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor cancels query when status is TIMED_OUT.""" from superset.common.db_query_status import QueryStatus from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10} query = Query() query.status = QueryStatus.TIMED_OUT with patch("superset.db_engine_specs.trino.time.sleep"): TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) mock_cancel_query.assert_called_once_with( cursor=cursor_mock, query=query, cancel_query_id="test-query-id", ) @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_breaks_on_execute_error( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor breaks the loop when execute_result has an error.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock( spec=["query_id", "stats", "info_uri", "_execute_result", "_execute_event"] ) cursor_mock.query_id = "test-query-id" cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10} cursor_mock._execute_result = {"error": Exception("Test error")} cursor_mock._execute_event = None query = Query() query.status = "running" # Should break immediately due to error in execute_result TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # cancel_query should not be called since we broke due to error, not status mock_cancel_query.assert_not_called() @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_breaks_on_execute_event_set( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor breaks the loop when execute_event is set.""" import threading from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock( spec=["query_id", "stats", "info_uri", "_execute_result", "_execute_event"] ) cursor_mock.query_id = "test-query-id" cursor_mock.stats = {"state": "RUNNING", "completedSplits": 5, "totalSplits": 10} execute_event = threading.Event() execute_event.set() # Simulate thread completion cursor_mock._execute_result = {} cursor_mock._execute_event = execute_event query = Query() query.status = "running" # Should break immediately since execute_event is set TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # cancel_query should not be called since we broke due to event being set mock_cancel_query.assert_not_called() @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_handles_zero_total_splits( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor handles zero totalSplits without division error.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" call_count = 0 def update_stats(*args, **kwargs): nonlocal call_count call_count += 1 if call_count >= 1: cursor_mock.stats = { "state": "FINISHED", "completedSplits": 0, "totalSplits": 0, } # Start with zero splits cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 0} with patch("superset.db_engine_specs.trino.time.sleep", side_effect=update_stats): query = Query() query.status = "running" # Should not raise ZeroDivisionError TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # Progress should be 0/1 = 0 when totalSplits is 0 assert query.progress == 0.0 @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_only_commits_on_progress_change( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor only commits when progress changes.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" call_count = 0 def update_stats(*args, **kwargs): nonlocal call_count call_count += 1 # Keep same progress for first two iterations, then finish if call_count < 3: cursor_mock.stats = { "state": "RUNNING", "completedSplits": 5, "totalSplits": 10, } else: cursor_mock.stats = { "state": "FINISHED", "completedSplits": 10, "totalSplits": 10, } cursor_mock.stats = {"state": "RUNNING", "completedSplits": 5, "totalSplits": 10} with patch("superset.db_engine_specs.trino.time.sleep", side_effect=update_stats): query = Query() query.status = "running" TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # Initial commit from set_extra_json_key, then commits only when progress changes # Progress changes: None->0.5, 0.5->0.5 (no commit), 0.5->1.0 # So we expect: 1 (initial) + 1 (0.5) + 1 (1.0) = 3 commits total commit_calls = mock_db.session.commit.call_count assert commit_calls >= 2 # At least initial commit and one progress update @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_sets_progress_text_for_planning_state( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor sets progress_text to 'Scheduled' for PLANNING state.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" call_count = 0 def update_stats(*args, **kwargs): nonlocal call_count call_count += 1 if call_count >= 1: cursor_mock.stats = { "state": "FINISHED", "completedSplits": 10, "totalSplits": 10, } # Start with PLANNING state cursor_mock.stats = {"state": "PLANNING", "completedSplits": 0, "totalSplits": 10} with patch("superset.db_engine_specs.trino.time.sleep", side_effect=update_stats): query = Query() query.status = "running" TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # Check that progress_text was set to "Scheduled" for PLANNING state assert query.extra.get("progress_text") is not None @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_sets_progress_text_for_queued_state( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor sets progress_text to 'Queued' for QUEUED state.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" call_count = 0 def update_stats(*args, **kwargs): nonlocal call_count call_count += 1 if call_count >= 1: cursor_mock.stats = { "state": "FINISHED", "completedSplits": 10, "totalSplits": 10, } # Start with QUEUED state cursor_mock.stats = {"state": "QUEUED", "completedSplits": 0, "totalSplits": 0} with patch("superset.db_engine_specs.trino.time.sleep", side_effect=update_stats): query = Query() query.status = "running" TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # Check that progress_text was set assert query.extra.get("progress_text") is not None @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_sets_progress_text_to_state_for_unmapped_states( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor sets progress_text to raw state for unmapped states.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" # Start directly with FINISHED state (not in the mapping) cursor_mock.stats = {"state": "FINISHED", "completedSplits": 10, "totalSplits": 10} query = Query() query.status = "running" TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # Check that progress_text was set to the raw state "FINISHED" # since FINISHED is not in the mapping (only PLANNING and QUEUED are mapped) assert query.extra.get("progress_text") == "FINISHED" @patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor") @patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") @patch("superset.db_engine_specs.trino.db") @patch("superset.db_engine_specs.trino.app") def test_handle_cursor_commits_on_progress_text_change( mock_app: Mock, mock_db: Mock, mock_cancel_query: Mock, mock_presto_handle_cursor: Mock, mocker: MockerFixture, ) -> None: """Test that handle_cursor commits only when progress_text changes.""" from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}} cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"]) cursor_mock.query_id = "test-query-id" call_count = 0 def update_stats(*args, **kwargs): nonlocal call_count call_count += 1 if call_count == 1: # State changes from QUEUED to RUNNING but progress stays at 0 cursor_mock.stats = { "state": "RUNNING", "completedSplits": 0, "totalSplits": 10, } elif call_count >= 2: cursor_mock.stats = { "state": "FINISHED", "completedSplits": 10, "totalSplits": 10, } # Start with QUEUED state and 0 progress cursor_mock.stats = {"state": "QUEUED", "completedSplits": 0, "totalSplits": 10} with patch("superset.db_engine_specs.trino.time.sleep", side_effect=update_stats): query = Query() query.status = "running" TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query) # There should be commits for progress_text changes assert mock_db.session.commit.call_count >= 2