feat: improve Doris catalog support (#34140)

This commit is contained in:
Beto Dealmeida
2025-07-14 12:01:08 -04:00
committed by GitHub
parent 0aa48b6564
commit 68b84acd93
2 changed files with 101 additions and 28 deletions

View File

@@ -31,6 +31,9 @@ from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.utils.core import GenericDataType
DEFAULT_CATALOG = "internal"
DEFAULT_SCHEMA = "information_schema"
# Regular expressions to catch custom errors
CONNECTION_ACCESS_DENIED_REGEX = re.compile(
"Access denied for user '(?P<username>.*?)'"
@@ -248,29 +251,39 @@ class DorisEngineSpec(MySQLEngineSpec):
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> tuple[URL, dict[str, Any]]:
if catalog:
pass
elif uri.database and "." in uri.database:
catalog, _ = uri.database.split(".", 1)
if not uri.database:
raise ValueError("Doris requires a database to be specified in the URI.")
elif "." not in uri.database:
current_catalog, current_schema = None, uri.database
else:
catalog = "internal"
current_catalog, current_schema = uri.database.split(".", 1)
# In Apache Doris, each catalog has an information_schema for BI tool
# compatibility. See: https://github.com/apache/doris/pull/28919
schema = schema or "information_schema"
database = ".".join([catalog or "", schema])
# and possibly override them
catalog = catalog or current_catalog
schema = schema or current_schema
database = ".".join(part for part in (catalog, schema) if part)
uri = uri.set(database=database)
return uri, connect_args
@classmethod
def get_default_catalog(cls, database: Database) -> Optional[str]:
def get_default_catalog(cls, database: Database) -> str:
"""
Return the default catalog.
"""
if database.url_object.database is None:
return None
# first check the URI to see if a default catalog is set
if database.url_object.database and "." in database.url_object.database:
return database.url_object.database.split(".")[0]
return database.url_object.database.split(".")[0]
# if not, iterate over existing catalogs and find the current one
with database.get_sqla_engine() as engine:
for catalog in engine.execute("SHOW CATALOGS"):
if catalog.IsCurrent:
return catalog.CatalogName
# fallback to "internal"
return DEFAULT_CATALOG
@classmethod
def get_catalog_names(
@@ -301,9 +314,8 @@ class DorisEngineSpec(MySQLEngineSpec):
doris://localhost:9030/catalog.database
"""
database = sqlalchemy_uri.database.strip("/")
if "." not in database:
if not sqlalchemy_uri.database:
return None
return parse.unquote(database.split(".")[1])
schema = sqlalchemy_uri.database.split(".")[-1].strip("/")
return parse.unquote(schema)

View File

@@ -19,6 +19,7 @@ from typing import Any, Optional
from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import JSON, types
from sqlalchemy.engine.url import make_url
@@ -81,30 +82,62 @@ def test_get_column_spec(
@pytest.mark.parametrize(
"sqlalchemy_uri,connect_args,return_schema,return_connect_args",
"sqlalchemy_uri, connect_args, catalog, schema, return_schema,return_connect_args",
[
(
"doris://user:password@host/db1",
{"param1": "some_value"},
"internal.information_schema",
None,
None,
"db1",
{"param1": "some_value"},
),
(
"pydoris://user:password@host/db1",
{"param1": "some_value"},
"internal.information_schema",
None,
None,
"db1",
{"param1": "some_value"},
),
(
"doris://user:password@host/catalog1.db1",
{"param1": "some_value"},
"catalog1.information_schema",
None,
None,
"catalog1.db1",
{"param1": "some_value"},
),
(
"pydoris://user:password@host/catalog1.db1",
{"param1": "some_value"},
"catalog1.information_schema",
None,
None,
"catalog1.db1",
{"param1": "some_value"},
),
(
"pydoris://user:password@host/catalog1.db1",
{"param1": "some_value"},
"catalog2",
None,
"catalog2.db1",
{"param1": "some_value"},
),
(
"pydoris://user:password@host/catalog1.db1",
{"param1": "some_value"},
None,
"db2",
"catalog1.db2",
{"param1": "some_value"},
),
(
"pydoris://user:password@host/catalog1.db1",
{"param1": "some_value"},
"catalog2",
"db2",
"catalog2.db2",
{"param1": "some_value"},
),
],
@@ -112,6 +145,8 @@ def test_get_column_spec(
def test_adjust_engine_params(
sqlalchemy_uri: str,
connect_args: dict[str, Any],
catalog: str | None,
schema: str | None,
return_schema: str,
return_connect_args: dict[str, Any],
) -> None:
@@ -119,18 +154,36 @@ def test_adjust_engine_params(
url = make_url(sqlalchemy_uri)
returned_url, returned_connect_args = DorisEngineSpec.adjust_engine_params(
url, connect_args
url,
connect_args,
catalog,
schema,
)
assert returned_url.database == return_schema
assert returned_connect_args == return_connect_args
def test_adjust_engine_params_no_database() -> None:
"""
Test that we raise an exception when the database is not specified.
"""
from superset.db_engine_specs.doris import DorisEngineSpec
url = make_url("doris://user:password@host")
with pytest.raises(
ValueError,
match="Doris requires a database to be specified in the URI.",
):
DorisEngineSpec.adjust_engine_params(url, {})
@pytest.mark.parametrize(
"url,expected_schema",
[
("doris://localhost:9030/hive.test", "test"),
("doris://localhost:9030/hive", None),
("doris://localhost:9030/test", "test"),
("doris://localhost:9030/", None),
],
)
def test_get_schema_from_engine_params(
@@ -154,12 +207,14 @@ def test_get_schema_from_engine_params(
"database_value,expected_catalog",
[
("catalog1.schema1", "catalog1"),
("catalog1", "catalog1"),
(None, None),
("schema1", "catalog2"),
("", "catalog2"),
],
)
def test_get_default_catalog(
database_value: Optional[str], expected_catalog: Optional[str]
mocker: MockerFixture,
database_value: Optional[str],
expected_catalog: Optional[str],
) -> None:
"""
Test the ``get_default_catalog`` method.
@@ -167,8 +222,14 @@ def test_get_default_catalog(
from superset.db_engine_specs.doris import DorisEngineSpec
from superset.models.core import Database
database = Mock(spec=Database)
database = mocker.MagicMock(spec=Database)
database.url_object.database = database_value
rows = [
mocker.MagicMock(IsCurrent=False, CatalogName="catalog1"),
mocker.MagicMock(IsCurrent=True, CatalogName="catalog2"),
]
with database.get_sqla_engine() as engine:
engine.execute.return_value = rows
assert DorisEngineSpec.get_default_catalog(database) == expected_catalog