Files
superset2/tests/unit_tests/db_engine_specs/test_trino.py

1388 lines
47 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from __future__ import annotations
import copy
from collections import namedtuple
from datetime import datetime
from typing import Any, Optional
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from flask import g, has_app_context
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import column, sql, text, types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import NoSuchTableError
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect
import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
from superset.db_engine_specs.exceptions import (
SupersetDBAPIConnectionError,
SupersetDBAPIDatabaseError,
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import Table
from superset.superset_typing import (
OAuth2ClientConfig,
ResultSetColumnType,
SQLAColumnType,
SQLType,
)
from superset.utils import json
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm # noqa: F401
def _assert_columns_equal(actual_cols, expected_cols) -> None:
    """
    Assert that two column lists are equal once their ``type`` entries have
    been rendered as strings.

    SQLAlchemy type instances can't be compared for equality directly, so
    each column's ``type`` is replaced by ``str(type)`` on a deep copy before
    comparing. The lists passed in by the caller are left untouched.
    """

    def _with_stringified_types(cols):
        # Work on a deep copy so the caller's column dicts are not mutated.
        clones = copy.deepcopy(cols)
        for entry in clones:
            entry["type"] = str(entry["type"])
        return clones

    assert _with_stringified_types(actual_cols) == _with_stringified_types(
        expected_cols
    )
@pytest.mark.parametrize(
    "extra,expected",
    [
        ({}, {"engine_params": {"connect_args": {"source": "Apache Superset"}}}),
        (
            {
                "first": 1,
                "engine_params": {
                    "second": "two",
                    "connect_args": {"source": "foobar", "third": "three"},
                },
            },
            {
                "first": 1,
                "engine_params": {
                    "second": "two",
                    "connect_args": {"source": "foobar", "third": "three"},
                },
            },
        ),
    ],
)
def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None:
    """
    Test ``get_extra_params`` for a database without a server certificate.

    With an empty ``extra`` the result carries a ``source`` connect arg set to
    "Apache Superset"; an ``extra`` that already defines ``source`` is
    returned unchanged.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock(extra=json.dumps(extra), server_cert=None)

    actual = TrinoEngineSpec.get_extra_params(database)

    assert actual == expected
@patch("superset.db_engine_specs.trino.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None:
    """
    Test ``get_extra_params`` when the database has a server certificate.

    The connect args are expected to use the ``https`` scheme and to verify
    against the path returned by the (mocked) ``create_ssl_cert_file``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    cert_path = "/path/to/tls.crt"
    mock_create_ssl_cert_file.return_value = cert_path

    database = Mock()
    database.extra = json.dumps({})
    database.server_cert = "TEST_CERT"
    database.db_engine_spec = TrinoEngineSpec

    engine_params = TrinoEngineSpec.get_extra_params(database).get("engine_params", {})
    connect_args = engine_params.get("connect_args", {})

    assert connect_args.get("http_scheme") == "https"
    assert connect_args.get("verify") == cert_path
    mock_create_ssl_cert_file.assert_called_once_with(database.server_cert)
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(mock_auth: Mock) -> None:
    """
    Test the ``basic`` auth method in ``update_params_from_encrypted_extra``.

    The connect args must switch to ``https`` and the (mocked)
    ``BasicAuthentication`` class must be built from the auth params.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    credentials = {"username": "username", "password": "password"}
    encrypted_extra = {"auth_method": "basic", "auth_params": credentials}
    database = Mock()
    database.encrypted_extra = json.dumps(encrypted_extra)

    engine_params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, engine_params)

    assert engine_params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**credentials)
@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(mock_auth: Mock) -> None:
    """
    Test the ``kerberos`` auth method in ``update_params_from_encrypted_extra``.

    The connect args must switch to ``https`` and the (mocked)
    ``KerberosAuthentication`` class must be built from the auth params.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    kerberos_options = {
        "service_name": "superset",
        "mutual_authentication": False,
        "delegate": True,
    }
    encrypted_extra = {"auth_method": "kerberos", "auth_params": kerberos_options}
    database = Mock()
    database.encrypted_extra = json.dumps(encrypted_extra)

    engine_params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, engine_params)

    assert engine_params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**kerberos_options)
@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(mock_auth: Mock) -> None:
    """
    Test the ``certificate`` auth method in
    ``update_params_from_encrypted_extra``.

    The connect args must switch to ``https`` and the (mocked)
    ``CertificateAuthentication`` class must be built from the auth params.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    cert_files = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
    encrypted_extra = {"auth_method": "certificate", "auth_params": cert_files}
    database = Mock()
    database.encrypted_extra = json.dumps(encrypted_extra)

    engine_params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, engine_params)

    assert engine_params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**cert_files)
@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(mock_auth: Mock) -> None:
    """
    Test the ``jwt`` auth method in ``update_params_from_encrypted_extra``.

    The connect args must switch to ``https`` and the (mocked)
    ``JWTAuthentication`` class must be built from the auth params.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    token_params = {"token": "jwt-token-string"}
    encrypted_extra = {"auth_method": "jwt", "auth_params": token_params}
    database = Mock()
    database.encrypted_extra = json.dumps(encrypted_extra)

    engine_params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, engine_params)

    assert engine_params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**token_params)
def test_auth_custom_auth() -> None:
    """
    Test a custom auth method that is listed in
    ``ALLOWED_EXTRA_AUTHENTICATIONS``.

    While the allow-list maps ``trino`` / ``custom_auth`` to a mock class, the
    connect args must switch to ``https`` and the mock class must be built
    from the auth params.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    custom_auth_cls = Mock()
    method_name = "custom_auth"
    method_params = {"params1": "params1", "params2": "params2"}

    database = Mock()
    database.encrypted_extra = json.dumps(
        {"auth_method": method_name, "auth_params": method_params}
    )

    allow_list = patch.dict(
        "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
        {"trino": {"custom_auth": custom_auth_cls}},
        clear=True,
    )
    with allow_list:
        engine_params: dict[str, Any] = {}
        TrinoEngineSpec.update_params_from_encrypted_extra(database, engine_params)

        assert engine_params.get("connect_args", {}).get("http_scheme") == "https"
        custom_auth_cls.assert_called_once_with(**method_params)
def test_auth_custom_auth_denied() -> None:
    """
    Test that a custom auth method missing from
    ``ALLOWED_EXTRA_AUTHENTICATIONS`` is rejected with a ``ValueError``
    carrying the expected message.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()
    auth_method = "my.module:TrinoAuthClass"
    auth_params = {"params1": "params1", "params2": "params2"}
    database.encrypted_extra = json.dumps(
        {"auth_method": auth_method, "auth_params": auth_params}
    )

    # Empty the allow-list with ``patch.dict`` rather than rebinding the
    # module attribute: the original contents are restored when the block
    # exits, so this test cannot leak configuration into other tests.
    with patch.dict(
        "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
        {},
        clear=True,
    ):
        with pytest.raises(ValueError) as excinfo:  # noqa: PT011
            TrinoEngineSpec.update_params_from_encrypted_extra(database, {})

    assert str(excinfo.value) == (
        f"For security reason, custom authentication '{auth_method}' "
        f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
    )
@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
        ("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False),
        ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
        ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
        ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
        ("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False),
        ("VARCHAR", types.String, None, GenericDataType.STRING, False),
        ("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False),
        ("CHAR", types.String, None, GenericDataType.STRING, False),
        ("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False),
        ("JSON", types.JSON, None, GenericDataType.STRING, False),
        ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        (
            "TIMESTAMP WITH TIME ZONE",
            types.TIMESTAMP,
            None,
            GenericDataType.TEMPORAL,
            True,
        ),
        (
            "TIMESTAMP(3) WITH TIME ZONE",
            types.TIMESTAMP,
            None,
            GenericDataType.TEMPORAL,
            True,
        ),
        ("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: type[types.TypeEngine],
    attrs: Optional[dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """
    Test the column spec derived for each Trino native type.

    Every case is delegated to the shared ``assert_column_spec`` helper with
    the expected SQLAlchemy type, type attributes, generic type and temporal
    flag.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    assert_column_spec(
        TrinoEngineSpec, native_type, sqla_type, attrs, generic_type, is_dttm
    )
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("Date", "DATE '2019-01-02'"),
        ("Other", None),
    ],
)
def test_convert_dttm(
    target_type: str,
    expected_result: Optional[str],
    dttm: datetime,  # noqa: F811
) -> None:
    """
    Test datetime literal conversion for each target type.

    The check itself is delegated to the shared ``assert_convert_dttm``
    helper, using the ``dttm`` fixture as the input value.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec as spec  # noqa: N813

    assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_get_extra_table_metadata(mocker: MockerFixture) -> None:
    """
    Test that ``get_extra_table_metadata`` reports the partition columns and
    the latest partition values taken from the mocked database.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    partition_index = {"column_names": ["ds", "hour"], "name": "partition"}
    latest_partition_df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})

    db_mock = mocker.MagicMock()
    db_mock.get_indexes = Mock(return_value=[partition_index])
    db_mock.get_extra = Mock(return_value={})
    db_mock.has_view = Mock(return_value=None)
    db_mock.get_df = Mock(return_value=latest_partition_df)

    table = Table("test_table", "test_schema")
    partitions = TrinoEngineSpec.get_extra_table_metadata(db_mock, table)["partitions"]

    assert partitions["cols"] == ["ds", "hour"]
    assert partitions["latest"] == {"ds": "01-01-19", "hour": 1}
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: Mock) -> None:
    """
    Test that ``cancel_query`` returns ``True`` when given a mock cursor.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    # The cursor is the object yielded by entering the mocked connection.
    cursor_mock = engine_mock.return_value.__enter__.return_value
    cancelled = TrinoEngineSpec.cancel_query(cursor_mock, Query(), "123")

    assert cancelled is True
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
    """
    Test that ``cancel_query`` returns ``False`` when the cursor raises while
    executing the cancel statement.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    # Use a real mock cursor (same shape as in the success test) and make its
    # ``execute`` fail. Previously a chained assignment bound ``cursor_mock``
    # to a bare ``Exception`` instance, so the failure path was only reached
    # through an incidental ``AttributeError``.
    cursor_mock = engine_mock.return_value.__enter__.return_value
    cursor_mock.execute.side_effect = Exception()

    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
@pytest.mark.parametrize(
    "initial_extra,final_extra",
    [
        ({}, {QUERY_EARLY_CANCEL_KEY: True}),
        ({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}),
    ],
)
def test_prepare_cancel_query(
    initial_extra: dict[str, Any],
    final_extra: dict[str, Any],
    mocker: MockerFixture,
) -> None:
    """
    Test ``prepare_cancel_query`` against a query's extra JSON.

    A query without a cancel key gets the early-cancel flag; a query that
    already has a cancel key keeps its extra unchanged.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    serialized_extra = json.dumps(initial_extra)
    query = Query(extra_json=serialized_extra)

    TrinoEngineSpec.prepare_cancel_query(query=query)

    assert query.extra == final_extra
@pytest.mark.parametrize("cancel_early", [True, False])
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_early_cancel(
    mock_app: Mock,
    mock_db: Mock,
    cancel_query_mock: Mock,
    mock_presto_handle_cursor: Mock,
    cancel_early: bool,
    mocker: MockerFixture,
) -> None:
    """
    Test that ``handle_cursor`` cancels the query only when an early cancel
    was prepared via ``prepare_cancel_query``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}
    query_id = "myQueryId"

    # Restrict the mock to these attributes so MagicMock does not create
    # any others on access.
    cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"])
    cursor_mock.query_id = query_id
    # A FINISHED state lets the progress loop exit immediately.
    cursor_mock.stats = {"state": "FINISHED", "completedSplits": 0, "totalSplits": 0}

    query = Query()
    if cancel_early:
        TrinoEngineSpec.prepare_cancel_query(query=query)

    TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    if not cancel_early:
        cancel_query_mock.assert_not_called()
        return

    assert cancel_query_mock.call_args.kwargs["cancel_query_id"] == query_id
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
    """
    Test that ``execute_with_cursor`` fetches the query ID from the cursor
    and stores it on the query under ``QUERY_CANCEL_KEY``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    expected_query_id = "myQueryId"

    mock_query = mocker.MagicMock()
    mock_cursor = mocker.MagicMock()
    mock_cursor.query_id = None

    def _mock_execute(*args, **kwargs):
        # The query ID only becomes available once execution has started.
        mock_cursor.query_id = expected_query_id

    no_disallowed_functions = patch.dict(
        "superset.config.DISALLOWED_SQL_FUNCTIONS",
        {},
        clear=True,
    )
    with app.test_request_context("/some/place/"):
        mock_cursor.execute.side_effect = _mock_execute

        with no_disallowed_functions:
            TrinoEngineSpec.execute_with_cursor(
                cursor=mock_cursor,
                sql="SELECT 1 FROM foo",
                query=mock_query,
            )

            mock_query.set_extra_json_key.assert_called_once_with(
                key=QUERY_CANCEL_KEY, value=expected_query_id
            )
def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
    """
    Test that the code run by ``execute_with_cursor`` still sees the current
    app context, including values stored on ``g``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mock_query = mocker.MagicMock()
    mock_cursor = mocker.MagicMock()
    mock_cursor.query_id = None

    def _mock_execute(*args, **kwargs):
        # Runs in place of ``TrinoEngineSpec.execute``; the context set up by
        # the test must be visible from here.
        assert has_app_context()
        assert g.some_value == "some_value"

    with app.test_request_context("/some/place/"):
        g.some_value = "some_value"

        execute_patch = patch.object(
            TrinoEngineSpec, "execute", side_effect=_mock_execute
        )
        config_patch = patch.dict(
            "superset.config.DISALLOWED_SQL_FUNCTIONS",
            {},
            clear=True,
        )
        with execute_patch, config_patch:
            TrinoEngineSpec.execute_with_cursor(
                cursor=mock_cursor,
                sql="SELECT 1 FROM foo",
                query=mock_query,
            )
def test_get_columns(mocker: MockerFixture):
    """
    Test that ROW columns are returned as-is (not expanded) when
    ``expand_rows`` is not requested.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    column_types = {
        "field1": datatype.parse_sqltype("row(a varchar, b date)"),
        "field2": datatype.parse_sqltype("row(r1 row(a varchar, b varchar))"),
        "field3": datatype.parse_sqltype("int"),
    }

    mock_inspector = mocker.MagicMock()
    mock_inspector.get_columns.return_value = [
        SQLAColumnType(name=col_name, type=col_type, is_dttm=False)
        for col_name, col_type in column_types.items()
    ]

    actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))

    expected = [
        ResultSetColumnType(
            name=col_name, column_name=col_name, type=col_type, is_dttm=False
        )
        for col_name, col_type in column_types.items()
    ]
    _assert_columns_equal(actual, expected)
def test_get_columns_error(mocker: MockerFixture):
    """
    Test that we fall back to a ``SHOW COLUMNS FROM ...`` query when the
    inspector raises ``NoSuchTableError``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    raw_types = {
        "field1": "row(a varchar, b date)",
        "field2": "row(r1 row(a varchar, b varchar))",
        "field3": "int",
    }

    mock_inspector = mocker.MagicMock()
    mock_inspector.engine.dialect = sqlite.dialect()
    mock_inspector.get_columns.side_effect = NoSuchTableError(
        "The specified table does not exist."
    )

    # Rows handed back by the mocked fallback query.
    Row = namedtuple("Row", ["Column", "Type"])
    mock_inspector.bind.execute().fetchall.return_value = [
        Row(col_name, type_str) for col_name, type_str in raw_types.items()
    ]

    actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))

    expected = [
        ResultSetColumnType(
            name=col_name,
            column_name=col_name,
            type=datatype.parse_sqltype(type_str),
            is_dttm=None,
            type_generic=None,
            default=None,
            nullable=True,
        )
        for col_name, type_str in raw_types.items()
    ]
    _assert_columns_equal(actual, expected)

    mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"')
def test_get_columns_expand_rows(mocker: MockerFixture):
    """
    Test that ROW columns are expanded into dotted sub-columns when
    ``expand_rows`` is requested.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    def _col(name, sqla_type, is_dttm=False, **extra):
        # Expected entry whose ``name`` and ``column_name`` are identical.
        return ResultSetColumnType(
            name=name,
            column_name=name,
            type=sqla_type,
            is_dttm=is_dttm,
            **extra,
        )

    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    field3_type = datatype.parse_sqltype("int")

    mock_inspector = mocker.MagicMock()
    mock_inspector.get_columns.return_value = [
        SQLAColumnType(name="field1", type=field1_type, is_dttm=False),
        SQLAColumnType(name="field2", type=field2_type, is_dttm=False),
        SQLAColumnType(name="field3", type=field3_type, is_dttm=False),
    ]

    actual = TrinoEngineSpec.get_columns(
        mock_inspector,
        Table("table", "schema"),
        {"expand_rows": True},
    )

    expected = [
        _col("field1", field1_type),
        _col("field1.a", types.VARCHAR(), query_as='"field1"."a" AS "field1.a"'),
        _col(
            "field1.b",
            types.DATE(),
            is_dttm=True,
            query_as='"field1"."b" AS "field1.b"',
        ),
        _col("field2", field2_type),
        _col(
            "field2.r1",
            datatype.parse_sqltype("row(a varchar, b varchar)"),
            query_as='"field2"."r1" AS "field2.r1"',
        ),
        _col(
            "field2.r1.a",
            types.VARCHAR(),
            query_as='"field2"."r1"."a" AS "field2.r1.a"',
        ),
        _col(
            "field2.r1.b",
            types.VARCHAR(),
            query_as='"field2"."r1"."b" AS "field2.r1.b"',
        ),
        _col("field3", field3_type),
    ]
    _assert_columns_equal(actual, expected)
def test_get_indexes_no_table():
    """
    Test that ``get_indexes`` returns an empty list when the inspector raises
    ``NoSuchTableError``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    missing_table = NoSuchTableError("The specified table does not exist.")
    inspector_mock = Mock()
    inspector_mock.get_indexes = Mock(side_effect=missing_table)

    indexes = TrinoEngineSpec.get_indexes(
        Mock(),
        inspector_mock,
        Table("test_table", "test_schema"),
    )

    assert indexes == []
def test_get_dbapi_exception_mapping():
    """
    Test that each Trino/requests exception class maps to the expected
    Superset DB-API exception, and that a plain ``Exception`` is unmapped.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    expected_mapping = {
        TrinoUserError: SupersetDBAPIProgrammingError,
        TrinoInternalError: SupersetDBAPIDatabaseError,
        TrinoExternalError: SupersetDBAPIOperationalError,
        RequestsConnectionError: SupersetDBAPIConnectionError,
    }

    mapping = TrinoEngineSpec.get_dbapi_exception_mapping()

    for source_exc, superset_exc in expected_mapping.items():
        assert mapping.get(source_exc) == superset_exc
    assert mapping.get(Exception) is None
def test_adjust_engine_params_fully_qualified() -> None:
    """
    Test the ``adjust_engine_params`` method when the URL has catalog and schema.

    Covers no override, a schema override, a catalog override, and both
    overrides together.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    url = make_url("trino://user:pass@localhost:8080/system/default")

    # (keyword overrides, expected resulting URI)
    cases = [
        ({}, "trino://user:pass@localhost:8080/system/default"),
        (
            {"schema": "new_schema"},
            "trino://user:pass@localhost:8080/system/new_schema",
        ),
        (
            {"catalog": "new_catalog"},
            "trino://user:pass@localhost:8080/new_catalog/default",
        ),
        (
            {"catalog": "new_catalog", "schema": "new_schema"},
            "trino://user:pass@localhost:8080/new_catalog/new_schema",
        ),
    ]

    for overrides, expected_uri in cases:
        adjusted = TrinoEngineSpec.adjust_engine_params(url, {}, **overrides)
        assert str(adjusted[0]) == expected_uri
def test_adjust_engine_params_catalog_only() -> None:
    """
    Test the ``adjust_engine_params`` method when the URL has only the catalog.

    Covers no override, a schema override, a catalog override, and both
    overrides together.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    url = make_url("trino://user:pass@localhost:8080/system")

    # (keyword overrides, expected resulting URI)
    cases = [
        ({}, "trino://user:pass@localhost:8080/system"),
        (
            {"schema": "new_schema"},
            "trino://user:pass@localhost:8080/system/new_schema",
        ),
        (
            {"catalog": "new_catalog"},
            "trino://user:pass@localhost:8080/new_catalog",
        ),
        (
            {"catalog": "new_catalog", "schema": "new_schema"},
            "trino://user:pass@localhost:8080/new_catalog/new_schema",
        ),
    ]

    for overrides, expected_uri in cases:
        adjusted = TrinoEngineSpec.adjust_engine_params(url, {}, **overrides)
        assert str(adjusted[0]) == expected_uri
@pytest.mark.parametrize(
    "sqlalchemy_uri,result",
    [
        ("trino://user:pass@localhost:8080/system", "system"),
        ("trino://user:pass@localhost:8080/system/default", "system"),
        ("trino://trino@localhost:8081", None),
    ],
)
def test_get_default_catalog(sqlalchemy_uri: str, result: str | None) -> None:
    """
    Test the ``get_default_catalog`` method.

    The catalog is the first path segment of the URI, or ``None`` when the
    URI has no path.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.core import Database

    database = Database(database_name="my_db", sqlalchemy_uri=sqlalchemy_uri)
    default_catalog = TrinoEngineSpec.get_default_catalog(database)

    assert default_catalog == result
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
@pytest.mark.parametrize(
    ["column_type", "column_value", "expected_value"],
    [
        (types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
        (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
        (types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
        (types.INT(), 1234, "1234"),
    ],
)
def test_where_latest_partition(
    mock_latest_partition,
    column_type: SQLType,
    column_value: Any,
    expected_value: str,
) -> None:
    """
    Test that ``where_latest_partition`` renders the latest partition value
    as a literal matching the partition column's type.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mock_latest_partition.return_value = (["partition_key"], [column_value])

    partition_column = {
        "column_name": "partition_key",
        "name": "partition_key",
        "type": column_type,
        "is_dttm": False,
    }
    filtered_query = TrinoEngineSpec.where_latest_partition(  # type: ignore
        database=MagicMock(),
        table=Table("table"),
        query=sql.select(text("* FROM table")),
        columns=[partition_column],
    )
    # Inline the bound value so the rendered SQL can be compared as text.
    compiled = filtered_query.compile(
        dialect=TrinoDialect(),
        compile_kwargs={"literal_binds": True},
    )

    expected_sql = f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""  # noqa: S608
    assert str(compiled) == expected_sql
@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
    """
    Fixture providing the client config for Trino OAuth2.
    """
    config: OAuth2ClientConfig = {
        "id": "trino",
        "secret": "very-secret",
        "scope": "",
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
        "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
        "request_content_type": "data",
    }
    return config
def test_get_oauth2_token(
    mocker: MockerFixture,
    oauth2_config: OAuth2ClientConfig,
) -> None:
    """
    Test `get_oauth2_token`.

    The token endpoint response is returned as-is, and the request is posted
    to the configured token URI with the authorization-code payload.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    token_response = {
        "access_token": "access-token",
        "expires_in": 3600,
        "scope": "scope",
        "token_type": "Bearer",
        "refresh_token": "refresh-token",
    }
    # Compare against an independent copy so the assertion does not depend
    # on the object handed to the mock.
    expected_token = dict(token_response)

    requests = mocker.patch("superset.db_engine_specs.base.requests")
    requests.post().json.return_value = token_response

    token = TrinoEngineSpec.get_oauth2_token(oauth2_config, "code")

    assert token == expected_token
    requests.post.assert_called_with(
        "https://trino.auth.server.example/master/protocol/openid-connect/token",
        data={
            "code": "code",
            "client_id": "trino",
            "client_secret": "very-secret",
            "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
            "grant_type": "authorization_code",
        },
        timeout=30.0,
    )
@pytest.mark.parametrize(
    "time_grain,expected_result",
    [
        ("PT1S", "date_trunc('second', CAST(col AS TIMESTAMP))"),
        (
            "PT5S",
            "date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 5)",  # noqa: E501
        ),
        (
            "PT30S",
            "date_trunc('second', CAST(col AS TIMESTAMP)) - interval '1' second * (second(CAST(col AS TIMESTAMP)) % 30)",  # noqa: E501
        ),
        ("PT1M", "date_trunc('minute', CAST(col AS TIMESTAMP))"),
        (
            "PT5M",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 5)",  # noqa: E501
        ),
        (
            "PT10M",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 10)",  # noqa: E501
        ),
        (
            "PT15M",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 15)",  # noqa: E501
        ),
        (
            "PT0.5H",
            "date_trunc('minute', CAST(col AS TIMESTAMP)) - interval '1' minute * (minute(CAST(col AS TIMESTAMP)) % 30)",  # noqa: E501
        ),
        ("PT1H", "date_trunc('hour', CAST(col AS TIMESTAMP))"),
        (
            "PT6H",
            "date_trunc('hour', CAST(col AS TIMESTAMP)) - interval '1' hour * (hour(CAST(col AS TIMESTAMP)) % 6)",  # noqa: E501
        ),
        ("P1D", "date_trunc('day', CAST(col AS TIMESTAMP))"),
        ("P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"),
        ("P1M", "date_trunc('month', CAST(col AS TIMESTAMP))"),
        ("P3M", "date_trunc('quarter', CAST(col AS TIMESTAMP))"),
        ("P1Y", "date_trunc('year', CAST(col AS TIMESTAMP))"),
        (
            "1969-12-28T00:00:00Z/P1W",
            "date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) - interval '1' day",  # noqa: E501
        ),
        ("1969-12-29T00:00:00Z/P1W", "date_trunc('week', CAST(col AS TIMESTAMP))"),
        (
            "P1W/1970-01-03T00:00:00Z",
            "date_trunc('week', CAST(col AS TIMESTAMP) + interval '1' day) + interval '5' day",  # noqa: E501
        ),
        (
            "P1W/1970-01-04T00:00:00Z",
            "date_trunc('week', CAST(col AS TIMESTAMP)) + interval '6' day",
        ),
    ],
)
def test_timegrain_expressions(time_grain: str, expected_result: str) -> None:
    """
    Test the SQL expression produced by ``get_timestamp_expr`` for each
    supported time grain on a column named ``col``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    timestamp_expr = TrinoEngineSpec.get_timestamp_expr(
        col=column("col"),
        pdf=None,
        time_grain=time_grain,
    )

    assert str(timestamp_expr) == expected_result
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_progress_updates(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that ``handle_cursor`` updates the query progress from the cursor
    stats and commits it through the database session.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"])
    cursor_mock.query_id = "test-query-id"
    # Progress goes 0/10 -> 5/10 -> 10/10 (FINISHED); each step is applied by
    # the patched ``time.sleep`` below.
    cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10}

    sleep_calls = 0

    def update_stats(*args, **kwargs):
        nonlocal sleep_calls
        sleep_calls += 1
        halfway = sleep_calls == 1
        cursor_mock.stats = {
            "state": "RUNNING" if halfway else "FINISHED",
            "completedSplits": 5 if halfway else 10,
            "totalSplits": 10,
        }

    query = Query()
    query.status = "running"
    with patch("superset.db_engine_specs.trino.time.sleep", side_effect=update_stats):
        TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    assert query.progress == 100.0
    assert mock_db.session.commit.called
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_cancels_on_stopped_status(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that ``handle_cursor`` cancels the running Trino query when the
    Superset query status is ``STOPPED``.
    """
    from superset.common.db_query_status import QueryStatus
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    trino_query_id = "test-query-id"
    cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"])
    cursor_mock.query_id = trino_query_id
    cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10}

    query = Query()
    query.status = QueryStatus.STOPPED

    with patch("superset.db_engine_specs.trino.time.sleep"):
        TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    mock_cancel_query.assert_called_once_with(
        cursor=cursor_mock,
        query=query,
        cancel_query_id=trino_query_id,
    )
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_cancels_on_timed_out_status(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that ``handle_cursor`` cancels the running Trino query when the
    Superset query status is ``TIMED_OUT``.
    """
    from superset.common.db_query_status import QueryStatus
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    trino_query_id = "test-query-id"
    cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"])
    cursor_mock.query_id = trino_query_id
    cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10}

    query = Query()
    query.status = QueryStatus.TIMED_OUT

    with patch("superset.db_engine_specs.trino.time.sleep"):
        TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    mock_cancel_query.assert_called_once_with(
        cursor=cursor_mock,
        query=query,
        cancel_query_id=trino_query_id,
    )
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_breaks_on_execute_error(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that ``handle_cursor`` leaves its loop without cancelling the query
    when the cursor's ``_execute_result`` carries an error.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    cursor_attributes = [
        "query_id",
        "stats",
        "info_uri",
        "_execute_result",
        "_execute_event",
    ]
    cursor_mock = mocker.MagicMock(spec=cursor_attributes)
    cursor_mock.query_id = "test-query-id"
    cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 10}
    cursor_mock._execute_result = {"error": Exception("Test error")}
    cursor_mock._execute_event = None

    query = Query()
    query.status = "running"

    # The recorded error should end the loop right away.
    TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    # The loop ended because of the error, not the query status, so no
    # cancellation is expected.
    mock_cancel_query.assert_not_called()
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_breaks_on_execute_event_set(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that ``handle_cursor`` leaves its loop without cancelling the query
    when the cursor's ``_execute_event`` is already set.
    """
    import threading

    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    # An event that is already set simulates a finished execution thread.
    finished_event = threading.Event()
    finished_event.set()

    cursor_attributes = [
        "query_id",
        "stats",
        "info_uri",
        "_execute_result",
        "_execute_event",
    ]
    cursor_mock = mocker.MagicMock(spec=cursor_attributes)
    cursor_mock.query_id = "test-query-id"
    cursor_mock.stats = {"state": "RUNNING", "completedSplits": 5, "totalSplits": 10}
    cursor_mock._execute_result = {}
    cursor_mock._execute_event = finished_event

    query = Query()
    query.status = "running"

    # The set event should end the loop right away.
    TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    # The loop ended because the event was set, so no cancellation is
    # expected.
    mock_cancel_query.assert_not_called()
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_handles_zero_total_splits(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that handle_cursor handles zero totalSplits without division error.

    The cursor reports ``totalSplits == 0`` both while RUNNING and once
    FINISHED; the call must complete and leave ``query.progress`` at ``0.0``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    # Zero poll interval keeps the test from waiting between polls.
    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"])
    cursor_mock.query_id = "test-query-id"
    # Start with zero splits
    cursor_mock.stats = {"state": "RUNNING", "completedSplits": 0, "totalSplits": 0}

    def finish_query(*args: Any, **kwargs: Any) -> None:
        """Replacement for ``time.sleep``: any poll wait completes the query."""
        cursor_mock.stats = {
            "state": "FINISHED",
            "completedSplits": 0,
            "totalSplits": 0,
        }

    query = Query()
    query.status = "running"

    with patch("superset.db_engine_specs.trino.time.sleep", side_effect=finish_query):
        # Should not raise ZeroDivisionError
        TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    # With no splits reported at all, progress must stay at zero.
    assert query.progress == 0.0
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_only_commits_on_progress_change(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Drive ``handle_cursor`` through repeated polls at the same progress and
    then a final jump to completion, and check how often the session commits.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    # Zero poll interval keeps the test from waiting between polls.
    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    polled_cursor = mocker.MagicMock(
        spec=["query_id", "stats", "info_uri"],
        query_id="test-query-id",
        stats={"state": "RUNNING", "completedSplits": 5, "totalSplits": 10},
    )

    sleeps_seen = 0

    def advance(*args, **kwargs) -> None:
        """Replacement for ``time.sleep`` that scripts the cursor's stats."""
        nonlocal sleeps_seen
        sleeps_seen += 1
        # The first two waits leave progress at 5/10; the third finishes.
        finished = sleeps_seen >= 3
        polled_cursor.stats = {
            "state": "FINISHED" if finished else "RUNNING",
            "completedSplits": 10 if finished else 5,
            "totalSplits": 10,
        }

    sql_lab_query = Query()
    sql_lab_query.status = "running"

    with patch("superset.db_engine_specs.trino.time.sleep", side_effect=advance):
        TrinoEngineSpec.handle_cursor(cursor=polled_cursor, query=sql_lab_query)

    # At least two commits are required (e.g. an initial one plus one
    # progress update).
    # NOTE(review): this is only a lower bound, so it would not catch extra
    # commits on unchanged progress -- confirm the exact count against the
    # implementation before tightening.
    assert mock_db.session.commit.call_count >= 2
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_sets_progress_text_for_planning_state(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that handle_cursor populates progress_text for a PLANNING query.

    The cursor starts in PLANNING and flips to FINISHED on the first poll
    wait; afterwards ``progress_text`` must be present in ``query.extra``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    # Zero poll interval keeps the test from waiting between polls.
    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"])
    cursor_mock.query_id = "test-query-id"
    # Start with PLANNING state
    cursor_mock.stats = {"state": "PLANNING", "completedSplits": 0, "totalSplits": 10}

    def finish_query(*args: Any, **kwargs: Any) -> None:
        """Replacement for ``time.sleep``: any poll wait completes the query."""
        cursor_mock.stats = {
            "state": "FINISHED",
            "completedSplits": 10,
            "totalSplits": 10,
        }

    query = Query()
    query.status = "running"

    with patch("superset.db_engine_specs.trino.time.sleep", side_effect=finish_query):
        TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    # Only presence is checked: by the time the call returns the cursor has
    # moved on to FINISHED.
    # NOTE(review): the PLANNING label itself (presumably "Scheduled") is not
    # verified here -- confirm against the implementation.
    assert query.extra.get("progress_text") is not None
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_sets_progress_text_for_queued_state(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Test that handle_cursor populates progress_text for a QUEUED query.

    The cursor starts in QUEUED with no splits and flips to FINISHED on the
    first poll wait; afterwards ``progress_text`` must be present in
    ``query.extra``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    # Zero poll interval keeps the test from waiting between polls.
    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    cursor_mock = mocker.MagicMock(spec=["query_id", "stats", "info_uri"])
    cursor_mock.query_id = "test-query-id"
    # Start with QUEUED state
    cursor_mock.stats = {"state": "QUEUED", "completedSplits": 0, "totalSplits": 0}

    def finish_query(*args: Any, **kwargs: Any) -> None:
        """Replacement for ``time.sleep``: any poll wait completes the query."""
        cursor_mock.stats = {
            "state": "FINISHED",
            "completedSplits": 10,
            "totalSplits": 10,
        }

    query = Query()
    query.status = "running"

    with patch("superset.db_engine_specs.trino.time.sleep", side_effect=finish_query):
        TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    # Only presence is checked: by the time the call returns the cursor has
    # moved on to FINISHED.
    # NOTE(review): the QUEUED label itself (presumably "Queued") is not
    # verified here -- confirm against the implementation.
    assert query.extra.get("progress_text") is not None
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_sets_progress_text_to_state_for_unmapped_states(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Check that a state without a friendly label is stored verbatim as
    ``progress_text``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    sql_lab_query = Query()
    sql_lab_query.status = "running"

    # Zero poll interval keeps the test from waiting between polls.
    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    # The cursor already reports FINISHED, so no sleep patching is set up.
    finished_cursor = mocker.MagicMock(
        spec=["query_id", "stats", "info_uri"],
        query_id="test-query-id",
        stats={"state": "FINISHED", "completedSplits": 10, "totalSplits": 10},
    )

    TrinoEngineSpec.handle_cursor(cursor=finished_cursor, query=sql_lab_query)

    # FINISHED is expected to pass through unchanged as the progress text.
    stored_text = sql_lab_query.extra.get("progress_text")
    assert stored_text == "FINISHED"
@patch("superset.db_engine_specs.presto.PrestoBaseEngineSpec.handle_cursor")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("superset.db_engine_specs.trino.db")
@patch("superset.db_engine_specs.trino.app")
def test_handle_cursor_commits_on_progress_text_change(
    mock_app: Mock,
    mock_db: Mock,
    mock_cancel_query: Mock,
    mock_presto_handle_cursor: Mock,
    mocker: MockerFixture,
) -> None:
    """
    Check that the session is committed when the reported state changes even
    though the split counts (0 of 10) do not.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    # Zero poll interval keeps the test from waiting between polls.
    mock_app.config = {"DB_POLL_INTERVAL_SECONDS": {"trino": 0}}

    # Begin QUEUED with nothing completed.
    polled_cursor = mocker.MagicMock(
        spec=["query_id", "stats", "info_uri"],
        query_id="test-query-id",
        stats={"state": "QUEUED", "completedSplits": 0, "totalSplits": 10},
    )

    sleeps_seen = 0

    def advance(*args, **kwargs) -> None:
        """Replacement for ``time.sleep`` that scripts the cursor's stats."""
        nonlocal sleeps_seen
        sleeps_seen += 1
        if sleeps_seen == 1:
            # First wait: QUEUED -> RUNNING, still 0 of 10 splits done.
            polled_cursor.stats = {
                "state": "RUNNING",
                "completedSplits": 0,
                "totalSplits": 10,
            }
            return
        # Every later wait reports the query as complete.
        polled_cursor.stats = {
            "state": "FINISHED",
            "completedSplits": 10,
            "totalSplits": 10,
        }

    sql_lab_query = Query()
    sql_lab_query.status = "running"

    with patch("superset.db_engine_specs.trino.time.sleep", side_effect=advance):
        TrinoEngineSpec.handle_cursor(cursor=polled_cursor, query=sql_lab_query)

    # NOTE(review): this is only a lower bound on the number of commits --
    # confirm the exact count against the implementation before tightening.
    assert mock_db.session.commit.call_count >= 2