# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access from __future__ import annotations import copy from collections import namedtuple from datetime import datetime from typing import Any, Optional from unittest.mock import MagicMock, Mock, patch import pandas as pd import pytest from flask import g, has_app_context from pytest_mock import MockerFixture from requests.exceptions import ConnectionError as RequestsConnectionError from sqlalchemy import column, sql, text, types from sqlalchemy.dialects import sqlite from sqlalchemy.engine.url import make_url from sqlalchemy.exc import NoSuchTableError from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError from trino.sqlalchemy import datatype from trino.sqlalchemy.dialect import TrinoDialect import superset.config from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY from superset.db_engine_specs.exceptions import ( SupersetDBAPIConnectionError, SupersetDBAPIDatabaseError, SupersetDBAPIOperationalError, SupersetDBAPIProgrammingError, ) from superset.sql.parse import Table from superset.superset_typing import ( OAuth2ClientConfig, ResultSetColumnType, SQLAColumnType, SQLType, ) from 
superset.utils import json from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, assert_convert_dttm, ) from tests.unit_tests.fixtures.common import dttm # noqa: F401 def _assert_columns_equal(actual_cols, expected_cols) -> None: """ Assert equality of the given cols, bearing in mind sqlalchemy type instances can't be compared for equality, so will have to be converted to strings first. """ actual = copy.deepcopy(actual_cols) expected = copy.deepcopy(expected_cols) for col in actual: col["type"] = str(col["type"]) for col in expected: col["type"] = str(col["type"]) assert actual == expected @pytest.mark.parametrize( "extra,expected", [ ({}, {"engine_params": {"connect_args": {"source": "Apache Superset"}}}), ( { "first": 1, "engine_params": { "second": "two", "connect_args": {"source": "foobar", "third": "three"}, }, }, { "first": 1, "engine_params": { "second": "two", "connect_args": {"source": "foobar", "third": "three"}, }, }, ), ], ) def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() database.extra = json.dumps(extra) database.server_cert = None assert TrinoEngineSpec.get_extra_params(database) == expected @patch("superset.db_engine_specs.trino.create_ssl_cert_file") def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() database.extra = json.dumps({}) database.server_cert = "TEST_CERT" database.db_engine_spec = TrinoEngineSpec mock_create_ssl_cert_file.return_value = "/path/to/tls.crt" extra = TrinoEngineSpec.get_extra_params(database) connect_args = extra.get("engine_params", {}).get("connect_args", {}) assert connect_args.get("http_scheme") == "https" assert connect_args.get("verify") == "/path/to/tls.crt" mock_create_ssl_cert_file.assert_called_once_with(database.server_cert) 
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(mock_auth: Mock) -> None:
    """
    ``auth_method: basic`` builds ``BasicAuthentication`` from ``auth_params``.
    It also switches the connection to HTTPS.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()

    auth_params = {"username": "username", "password": "password"}
    database.encrypted_extra = json.dumps(
        {"auth_method": "basic", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
    connect_args = params.setdefault("connect_args", {})
    assert connect_args.get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)


@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(mock_auth: Mock) -> None:
    """
    ``auth_method: kerberos`` builds ``KerberosAuthentication`` from
    ``auth_params``. It also switches the connection to HTTPS.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()

    auth_params = {
        "service_name": "superset",
        "mutual_authentication": False,
        "delegate": True,
    }
    database.encrypted_extra = json.dumps(
        {"auth_method": "kerberos", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
    connect_args = params.setdefault("connect_args", {})
    assert connect_args.get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)


@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(mock_auth: Mock) -> None:
    """
    ``auth_method: certificate`` builds ``CertificateAuthentication`` from
    ``auth_params``. It also switches the connection to HTTPS.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()
    auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
    database.encrypted_extra = json.dumps(
        {"auth_method": "certificate", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
    connect_args = params.setdefault("connect_args", {})
    assert connect_args.get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)


@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(mock_auth: Mock) -> None:
    """
    ``auth_method: jwt`` builds ``JWTAuthentication`` from ``auth_params``.
    It also switches the connection to HTTPS.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()

    auth_params = {"token": "jwt-token-string"}
    database.encrypted_extra = json.dumps(
        {"auth_method": "jwt", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
    connect_args = params.setdefault("connect_args", {})
    assert connect_args.get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)


def test_auth_custom_auth() -> None:
    """
    A custom auth method listed in ``ALLOWED_EXTRA_AUTHENTICATIONS`` is
    instantiated with ``auth_params``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()
    auth_class = Mock()

    auth_method = "custom_auth"
    auth_params = {"params1": "params1", "params2": "params2"}
    database.encrypted_extra = json.dumps(
        {"auth_method": auth_method, "auth_params": auth_params}
    )

    with patch.dict(
        "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
        {"trino": {"custom_auth": auth_class}},
        clear=True,
    ):
        params: dict[str, Any] = {}
        TrinoEngineSpec.update_params_from_encrypted_extra(database, params)

        connect_args = params.setdefault("connect_args", {})
        assert connect_args.get("http_scheme") == "https"

        auth_class.assert_called_once_with(**auth_params)


def test_auth_custom_auth_denied() -> None:
    """
    A custom auth method that is missing from ``ALLOWED_EXTRA_AUTHENTICATIONS``
    is rejected with a ``ValueError``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()
    auth_method = "my.module:TrinoAuthClass"
    auth_params = {"params1": "params1", "params2": "params2"}
    database.encrypted_extra = json.dumps(
        {"auth_method": auth_method, "auth_params": auth_params}
    )

    # Empty the allow-list *in place* and restore it afterwards. Rebinding the
    # module attribute (``superset.config.X = {}``) would leak into later tests
    # and detach it from the dict that other ``patch.dict`` calls target.
    with patch.dict(superset.config.ALLOWED_EXTRA_AUTHENTICATIONS, {}, clear=True):
        with pytest.raises(ValueError) as excinfo:  # noqa: PT011
            TrinoEngineSpec.update_params_from_encrypted_extra(database, {})

    assert str(excinfo.value) == (
        f"For security reason, custom authentication '{auth_method}' "
        f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
    )


@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
        ("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False),
        ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
        ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
        ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
        ("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False),
        ("VARCHAR", types.String, None, GenericDataType.STRING, False),
        ("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False),
        ("CHAR", types.String, None, GenericDataType.STRING, False),
        ("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False),
        ("JSON", types.JSON, None, GenericDataType.STRING, False),
        ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        (
            "TIMESTAMP WITH TIME ZONE",
            types.TIMESTAMP,
            None,
            GenericDataType.TEMPORAL,
            True,
        ),
        (
            "TIMESTAMP(3) WITH TIME ZONE",
            types.TIMESTAMP,
            None,
            GenericDataType.TEMPORAL,
            True,
        ),
        ("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: type[types.TypeEngine],
    attrs: Optional[dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """Native Trino type names map to the expected SQLAlchemy and generic types."""
    from superset.db_engine_specs.trino import TrinoEngineSpec as spec  # noqa: N813

    assert_column_spec(
        spec,
        native_type,
        sqla_type,
        attrs,
        generic_type,
        is_dttm,
    )


@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("Date", "DATE '2019-01-02'"),
        ("Other", None),
    ],
)
def test_convert_dttm(
    target_type: str,
    expected_result: Optional[str],
    dttm: datetime,  # noqa: F811
) -> None:
    """Datetimes render as Trino literals for temporal types and None otherwise."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm)


def test_get_extra_table_metadata(mocker: MockerFixture) -> None:
    """Partition columns and their latest values appear in the extra metadata."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    db_mock = mocker.MagicMock()
    db_mock.get_indexes = Mock(
        return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
    )
    db_mock.get_extra = Mock(return_value={})
    db_mock.has_view = Mock(return_value=None)
    db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}))
    result = TrinoEngineSpec.get_extra_table_metadata(
        db_mock,
        Table("test_table", "test_schema"),
    )
    assert result["partitions"]["cols"] == ["ds", "hour"]
    assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}


@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: Mock) -> None:
    """``cancel_query`` returns True when the cursor calls succeed."""
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True


@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
    """``cancel_query`` returns False when executing on the cursor raises."""
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor_mock = engine_mock.return_value.__enter__.return_value
    # Use a real cursor mock whose execute() fails. The old chained assignment
    # passed an Exception *instance* as the cursor, so the test only passed
    # because of an AttributeError.
    cursor_mock.execute.side_effect = Exception()
    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False


@pytest.mark.parametrize(
    "initial_extra,final_extra",
    [
        ({}, {QUERY_EARLY_CANCEL_KEY: True}),
        ({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}),
    ],
)
def test_prepare_cancel_query(
    initial_extra: dict[str, Any],
    final_extra: dict[str, Any],
    mocker: MockerFixture,
) -> None:
    """
    Without a cancel key, the query is flagged for early cancellation.
    An existing cancel key is left as is.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query(extra_json=json.dumps(initial_extra))
    TrinoEngineSpec.prepare_cancel_query(query=query)
    assert query.extra == final_extra


@pytest.mark.parametrize("cancel_early", [True, False])
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("sqlalchemy.engine.Engine.connect")
def test_handle_cursor_early_cancel(
    engine_mock: Mock,
    cancel_query_mock: Mock,
    cancel_early: bool,
    mocker: MockerFixture,
) -> None:
    """
    ``handle_cursor`` cancels a query that was flagged for early cancellation,
    passing the cursor's query ID. Unflagged queries are not cancelled.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query_id = "myQueryId"

    cursor_mock = engine_mock.return_value.__enter__.return_value
    cursor_mock.query_id = query_id

    query = Query()

    if cancel_early:
        TrinoEngineSpec.prepare_cancel_query(query=query)

    TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    if cancel_early:
        assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
    else:
        assert cancel_query_mock.call_args is None


def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture) -> None:
    """Test that `execute_with_cursor` fetches query ID from the cursor"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    query_id = "myQueryId"

    mock_cursor = mocker.MagicMock()
    mock_cursor.query_id = None

    mock_query = mocker.MagicMock()

    def _mock_execute(*args, **kwargs):
        # Simulate Trino assigning a query ID once execution starts.
        mock_cursor.query_id = query_id

    with app.test_request_context("/some/place/"):
        mock_cursor.execute.side_effect = _mock_execute
        with patch.dict(
            "superset.config.DISALLOWED_SQL_FUNCTIONS",
            {},
            clear=True,
        ):
            TrinoEngineSpec.execute_with_cursor(
                cursor=mock_cursor,
                sql="SELECT 1 FROM foo",
                query=mock_query,
            )

            mock_query.set_extra_json_key.assert_called_once_with(
                key=QUERY_CANCEL_KEY, value=query_id
            )


def test_execute_with_cursor_app_context(app, mocker: MockerFixture) -> None:
    """Test that `execute_with_cursor` still contains the current app context"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mock_cursor = mocker.MagicMock()
    mock_cursor.query_id = None

    mock_query = mocker.MagicMock()

    def _mock_execute(*args, **kwargs):
        assert has_app_context()
        assert g.some_value == "some_value"

    with app.test_request_context("/some/place/"):
        g.some_value = "some_value"

        with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
            with patch.dict(
                "superset.config.DISALLOWED_SQL_FUNCTIONS",
                {},
                clear=True,
            ):
                TrinoEngineSpec.execute_with_cursor(
                    cursor=mock_cursor,
                    sql="SELECT 1 FROM foo",
                    query=mock_query,
                )


def test_get_columns(mocker: MockerFixture) -> None:
    """Test that ROW columns are not expanded without expand_rows"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    field3_type = datatype.parse_sqltype("int")

    sqla_columns = [
        SQLAColumnType(name="field1", type=field1_type, is_dttm=False),
        SQLAColumnType(name="field2", type=field2_type, is_dttm=False),
        SQLAColumnType(name="field3", type=field3_type, is_dttm=False),
    ]
    mock_inspector = mocker.MagicMock()
    mock_inspector.get_columns.return_value = sqla_columns

    actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
    expected = [
        ResultSetColumnType(
            name="field1", column_name="field1", type=field1_type, is_dttm=False
        ),
        ResultSetColumnType(
            name="field2", column_name="field2", type=field2_type, is_dttm=False
        ),
        ResultSetColumnType(
            name="field3", column_name="field3", type=field3_type, is_dttm=False
        ),
    ]

    _assert_columns_equal(actual, expected)


def test_get_columns_error(mocker: MockerFixture) -> None:
    """
    Test that we fallback to a `SHOW COLUMNS FROM ...` query.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    field3_type = datatype.parse_sqltype("int")

    mock_inspector = mocker.MagicMock()
    mock_inspector.engine.dialect = sqlite.dialect()
    mock_inspector.get_columns.side_effect = NoSuchTableError(
        "The specified table does not exist."
    )
    Row = namedtuple("Row", ["Column", "Type"])
    mock_inspector.bind.execute().fetchall.return_value = [
        Row("field1", "row(a varchar, b date)"),
        Row("field2", "row(r1 row(a varchar, b varchar))"),
        Row("field3", "int"),
    ]

    actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
    expected = [
        ResultSetColumnType(
            name="field1",
            column_name="field1",
            type=field1_type,
            is_dttm=None,
            type_generic=None,
            default=None,
            nullable=True,
        ),
        ResultSetColumnType(
            name="field2",
            column_name="field2",
            type=field2_type,
            is_dttm=None,
            type_generic=None,
            default=None,
            nullable=True,
        ),
        ResultSetColumnType(
            name="field3",
            column_name="field3",
            type=field3_type,
            is_dttm=None,
            type_generic=None,
            default=None,
            nullable=True,
        ),
    ]

    _assert_columns_equal(actual, expected)

    mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"')


def test_get_columns_expand_rows(mocker: MockerFixture) -> None:
    """Test that ROW columns are correctly expanded with expand_rows"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    field3_type = datatype.parse_sqltype("int")

    sqla_columns = [
        SQLAColumnType(name="field1", type=field1_type, is_dttm=False),
        SQLAColumnType(name="field2", type=field2_type, is_dttm=False),
        SQLAColumnType(name="field3", type=field3_type, is_dttm=False),
    ]
    mock_inspector = mocker.MagicMock()
    mock_inspector.get_columns.return_value = sqla_columns

    actual = TrinoEngineSpec.get_columns(
        mock_inspector,
        Table("table", "schema"),
        {"expand_rows": True},
    )
    expected = [
        ResultSetColumnType(
            name="field1", column_name="field1", type=field1_type, is_dttm=False
        ),
        ResultSetColumnType(
            name="field1.a",
            column_name="field1.a",
            type=types.VARCHAR(),
            is_dttm=False,
            query_as='"field1"."a" AS "field1.a"',
        ),
        ResultSetColumnType(
            name="field1.b",
            column_name="field1.b",
            type=types.DATE(),
            is_dttm=True,
            query_as='"field1"."b" AS "field1.b"',
        ),
        ResultSetColumnType(
            name="field2", column_name="field2", type=field2_type, is_dttm=False
        ),
        ResultSetColumnType(
            name="field2.r1",
            column_name="field2.r1",
            type=datatype.parse_sqltype("row(a varchar, b varchar)"),
            is_dttm=False,
            query_as='"field2"."r1" AS "field2.r1"',
        ),
        ResultSetColumnType(
            name="field2.r1.a",
            column_name="field2.r1.a",
            type=types.VARCHAR(),
            is_dttm=False,
            query_as='"field2"."r1"."a" AS "field2.r1.a"',
        ),
        ResultSetColumnType(
            name="field2.r1.b",
            column_name="field2.r1.b",
            type=types.VARCHAR(),
            is_dttm=False,
            query_as='"field2"."r1"."b" AS "field2.r1.b"',
        ),
        ResultSetColumnType(
            name="field3", column_name="field3", type=field3_type, is_dttm=False
        ),
    ]

    _assert_columns_equal(actual, expected)


def test_get_indexes_no_table() -> None:
    """``get_indexes`` returns an empty list when the inspector cannot find the table."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    db_mock = Mock()
    inspector_mock = Mock()
    inspector_mock.get_indexes = Mock(
        side_effect=NoSuchTableError("The specified table does not exist.")
    )
    result = TrinoEngineSpec.get_indexes(
        db_mock,
        inspector_mock,
        Table("test_table", "test_schema"),
    )
    assert result == []


def test_get_dbapi_exception_mapping() -> None:
    """Trino and ``requests`` errors map to the matching Superset DB-API errors."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mapping = TrinoEngineSpec.get_dbapi_exception_mapping()
    assert mapping.get(TrinoUserError) == SupersetDBAPIProgrammingError
    assert mapping.get(TrinoInternalError) == SupersetDBAPIDatabaseError
    assert mapping.get(TrinoExternalError) == SupersetDBAPIOperationalError
    assert mapping.get(RequestsConnectionError) == SupersetDBAPIConnectionError
    assert mapping.get(Exception) is None


def test_adjust_engine_params_fully_qualified() -> None:
    """
    Test the ``adjust_engine_params`` method when the URL has catalog and schema.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    url = make_url("trino://user:pass@localhost:8080/system/default")

    uri = TrinoEngineSpec.adjust_engine_params(url, {})[0]
    assert str(uri) == "trino://user:pass@localhost:8080/system/default"

    uri = TrinoEngineSpec.adjust_engine_params(
        url,
        {},
        schema="new_schema",
    )[0]
    assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema"

    uri = TrinoEngineSpec.adjust_engine_params(
        url,
        {},
        catalog="new_catalog",
    )[0]
    assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/default"

    uri = TrinoEngineSpec.adjust_engine_params(
        url,
        {},
        catalog="new_catalog",
        schema="new_schema",
    )[0]
    assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema"


def test_adjust_engine_params_catalog_only() -> None:
    """
    Test the ``adjust_engine_params`` method when the URL has only the catalog.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    url = make_url("trino://user:pass@localhost:8080/system")

    uri = TrinoEngineSpec.adjust_engine_params(url, {})[0]
    assert str(uri) == "trino://user:pass@localhost:8080/system"

    uri = TrinoEngineSpec.adjust_engine_params(
        url,
        {},
        schema="new_schema",
    )[0]
    assert str(uri) == "trino://user:pass@localhost:8080/system/new_schema"

    uri = TrinoEngineSpec.adjust_engine_params(
        url,
        {},
        catalog="new_catalog",
    )[0]
    assert str(uri) == "trino://user:pass@localhost:8080/new_catalog"

    uri = TrinoEngineSpec.adjust_engine_params(
        url,
        {},
        catalog="new_catalog",
        schema="new_schema",
    )[0]
    assert str(uri) == "trino://user:pass@localhost:8080/new_catalog/new_schema"


@pytest.mark.parametrize(
    "sqlalchemy_uri,result",
    [
        ("trino://user:pass@localhost:8080/system", "system"),
        ("trino://user:pass@localhost:8080/system/default", "system"),
        ("trino://trino@localhost:8081", None),
    ],
)
def test_get_default_catalog(sqlalchemy_uri: str, result: str | None) -> None:
    """
    Test the ``get_default_catalog`` method.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.core import Database

    database = Database(
        database_name="my_db",
        sqlalchemy_uri=sqlalchemy_uri,
    )
    assert TrinoEngineSpec.get_default_catalog(database) == result


@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
@pytest.mark.parametrize(
    ["column_type", "column_value", "expected_value"],
    [
        (types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
        (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
        (types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
        (types.INT(), 1234, "1234"),
    ],
)
def test_where_latest_partition(
    mock_latest_partition,
    column_type: SQLType,
    column_value: Any,
    expected_value: str,
) -> None:
    """The latest-partition WHERE clause renders a literal suited to the column type."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mock_latest_partition.return_value = (["partition_key"], [column_value])

    assert (
        str(
            TrinoEngineSpec.where_latest_partition(  # type: ignore
                database=MagicMock(),
                table=Table("table"),
                query=sql.select(text("* FROM table")),
                columns=[
                    {
                        "column_name": "partition_key",
                        "name": "partition_key",
                        "type": column_type,
                        "is_dttm": False,
                    }
                ],
            ).compile(
                dialect=TrinoDialect(),
                compile_kwargs={"literal_binds": True},
            )
        )
        == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""  # noqa: S608
    )


@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
    """
    Config for Trino OAuth2.
    """
    return {
        "id": "trino",
        "secret": "very-secret",
        "scope": "",
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
        "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
        "request_content_type": "data",
    }


def test_get_oauth2_token(
    mocker: MockerFixture,
    oauth2_config: OAuth2ClientConfig,
) -> None:
    """
    Test `get_oauth2_token`.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    requests = mocker.patch("superset.db_engine_specs.base.requests")
    requests.post().json.return_value = {
        "access_token": "access-token",
        "expires_in": 3600,
        "scope": "scope",
        "token_type": "Bearer",
        "refresh_token": "refresh-token",
    }

    assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
        "access_token": "access-token",
        "expires_in": 3600,
        "scope": "scope",
        "token_type": "Bearer",
        "refresh_token": "refresh-token",
    }
    requests.post.assert_called_with(
        "https://trino.auth.server.example/master/protocol/openid-connect/token",
        data={
            "code": "code",
            "client_id": "trino",
            "client_secret": "very-secret",
            "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
            "grant_type": "authorization_code",
        },
        timeout=30.0,
    )


@pytest.mark.parametrize(
    "time_grain,expected_result",
    [
        ("PT1S", "date_trunc('second', CAST(col AS TIMESTAMP))"),
        (
            "PT5S",
            "date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 5)",  # noqa: E501
        ),
        (
            "PT30S",
            "date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 30)",  # noqa: E501
        ),
        ("PT1M", "date_trunc('minute', CAST(col AS TIMESTAMP))"),
        (
            "PT5M",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 5)",  # noqa: E501
        ),
        (
            "PT10M",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 10)",  # noqa: E501
        ),
        (
            "PT15M",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 15)",  # noqa: E501
        ),
        (
            "PT0.5H",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 30)",  # noqa: E501
        ),
        ("PT1H", "date_trunc('hour', CAST(col AS TIMESTAMP))"),
        (
            "PT6H",
            "date_trunc('hour', CAST(col AS TIMESTAMP)) - interval '1' hour * (hour(CAST(col AS TIMESTAMP)) % 6)",  # noqa: E501
        ),
        ("P1D", "date_trunc('day', CAST(col AS TIMESTAMP))"),
        ("P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"),
        ("P1M", "date_trunc('month', CAST(col AS TIMESTAMP))"),
        ("P3M", "date_trunc('quarter', CAST(col AS TIMESTAMP))"),
        ("P1Y", "date_trunc('year', CAST(col AS TIMESTAMP))"),
        (
            "1969-12-28T00:00:00Z/P1W",
            "date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) - interval '1' day",  # noqa: E501
        ),
        ("1969-12-29T00:00:00Z/P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"),
        (
            "P1W/1970-01-03T00:00:00Z",
            "date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) + interval '5' day",  # noqa: E501
        ),
        (
            "P1W/1970-01-04T00:00:00Z",
            "date_trunc('week', CAST(col AS TIMESTAMP)) + interval '6' day",
        ),
    ],
)
def test_timegrain_expressions(time_grain: str, expected_result: str) -> None:
    """Each time grain yields the expected Trino truncation expression."""
    from superset.db_engine_specs.trino import TrinoEngineSpec as spec  # noqa: N813

    actual = str(
        spec.get_timestamp_expr(col=column("col"), pdf=None, time_grain=time_grain)
    )
    assert actual == expected_result