feat(db_engine): Implement user impersonation support for StarRocks (#28110)

This commit is contained in:
Patrick Schmidt
2024-09-06 18:13:38 +02:00
committed by GitHub
parent d3f5c795ff
commit 6294e339e2
11 changed files with 120 additions and 13 deletions

View File

@@ -247,20 +247,24 @@ def test_convert_dttm(
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_get_prequeries() -> None:
def test_get_prequeries(mocker: MockerFixture) -> None:
"""
Test the ``get_prequeries`` method.
"""
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
assert DatabricksNativeEngineSpec.get_prequeries() == []
assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [
database = mocker.MagicMock()
assert DatabricksNativeEngineSpec.get_prequeries(database) == []
assert DatabricksNativeEngineSpec.get_prequeries(database, schema="test") == [
"USE SCHEMA test",
]
assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [
assert DatabricksNativeEngineSpec.get_prequeries(database, catalog="test") == [
"USE CATALOG test",
]
assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", schema="bar") == [
assert DatabricksNativeEngineSpec.get_prequeries(
database, catalog="foo", schema="bar"
) == [
"USE CATALOG foo",
"USE SCHEMA bar",
]

View File

@@ -66,13 +66,15 @@ def test_get_table_comment_empty(mocker: MockerFixture):
)
def test_get_prequeries() -> None:
def test_get_prequeries(mocker: MockerFixture) -> None:
"""
Test the ``get_prequeries`` method.
"""
from superset.db_engine_specs.db2 import Db2EngineSpec
assert Db2EngineSpec.get_prequeries() == []
assert Db2EngineSpec.get_prequeries(schema="my_schema") == [
database = mocker.MagicMock()
assert Db2EngineSpec.get_prequeries(database) == []
assert Db2EngineSpec.get_prequeries(database, schema="my_schema") == [
'set current_schema "my_schema"'
]

View File

@@ -137,14 +137,16 @@ def test_get_schema_from_engine_params() -> None:
)
def test_get_prequeries() -> None:
def test_get_prequeries(mocker: MockerFixture) -> None:
"""
Test the ``get_prequeries`` method.
"""
from superset.db_engine_specs.postgres import PostgresEngineSpec
assert PostgresEngineSpec.get_prequeries() == []
assert PostgresEngineSpec.get_prequeries(schema="test") == [
database = mocker.MagicMock()
assert PostgresEngineSpec.get_prequeries(database) == []
assert PostgresEngineSpec.get_prequeries(database, schema="test") == [
'set search_path = "test"'
]

View File

@@ -18,6 +18,7 @@
from typing import Any, Optional
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import JSON, types
from sqlalchemy.engine.url import make_url
@@ -124,3 +125,47 @@ def test_get_schema_from_engine_params() -> None:
)
is None
)
def test_impersonation_username(mocker: MockerFixture) -> None:
"""
Test impersonation and make sure that `get_url_for_impersonation` leaves the URL
unchanged and that `get_prequeries` returns the appropriate impersonation query.
"""
from superset.db_engine_specs.starrocks import StarRocksEngineSpec
database = mocker.MagicMock()
database.impersonate_user = True
database.get_effective_user.return_value = "alice"
assert StarRocksEngineSpec.get_url_for_impersonation(
url=make_url("starrocks://service_user@localhost:9030/hive.default"),
impersonate_user=True,
username="alice",
access_token=None,
) == make_url("starrocks://service_user@localhost:9030/hive.default")
assert StarRocksEngineSpec.get_prequeries(database) == [
'EXECUTE AS "alice" WITH NO REVERT;'
]
def test_impersonation_disabled(mocker: MockerFixture) -> None:
"""
Test that impersonation is not applied when the feature is disabled in
`get_url_for_impersonation` and `get_prequeries`.
"""
from superset.db_engine_specs.starrocks import StarRocksEngineSpec
database = mocker.MagicMock()
database.impersonate_user = False
database.get_effective_user.return_value = "alice"
assert StarRocksEngineSpec.get_url_for_impersonation(
url=make_url("starrocks://service_user@localhost:9030/hive.default"),
impersonate_user=False,
username="alice",
access_token=None,
) == make_url("starrocks://service_user@localhost:9030/hive.default")
assert StarRocksEngineSpec.get_prequeries(database) == []