feat: catalog support for Databricks native (#28394)

This commit is contained in:
Beto Dealmeida
2024-05-09 17:41:15 -04:00
committed by GitHub
parent e516bba8fc
commit f29e1e4c29
10 changed files with 443 additions and 36 deletions

View File

@@ -39,7 +39,6 @@ if TYPE_CHECKING:
from superset.models.core import Database
#
class DatabricksBaseSchema(Schema):
"""
Fields that are required for both Databricks drivers that uses a
@@ -371,6 +370,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
"extra",
}
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksNativeParametersType, *_
@@ -428,6 +429,35 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
spec.components.schema(cls.__name__, schema=cls.properties_schema)
return spec.to_dict()["components"]["schemas"][cls.__name__]
@classmethod
def get_default_catalog(
cls,
database: Database,
) -> str | None:
with database.get_inspector() as inspector:
return inspector.bind.execute("SELECT current_catalog()").scalar()
@classmethod
def get_prequeries(
cls,
catalog: str | None = None,
schema: str | None = None,
) -> list[str]:
prequeries = []
if catalog:
prequeries.append(f"USE CATALOG {catalog}")
if schema:
prequeries.append(f"USE SCHEMA {schema}")
return prequeries
@classmethod
def get_catalog_names(
cls,
database: Database,
inspector: Inspector,
) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
engine = "databricks"
@@ -455,6 +485,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
"http_path_field",
}
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksPythonConnectorParametersType, *_
@@ -502,3 +534,34 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
"default_schema": query["schema"],
"encryption": encryption,
}
@classmethod
def get_default_catalog(
cls,
database: Database,
) -> str | None:
return database.url_object.query.get("catalog")
@classmethod
def get_catalog_names(
cls,
database: Database,
inspector: Inspector,
) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
@classmethod
def adjust_engine_params(
cls,
uri: URL,
connect_args: dict[str, Any],
catalog: str | None = None,
schema: str | None = None,
) -> tuple[URL, dict[str, Any]]:
if catalog:
uri = uri.update_query_dict({"catalog": catalog})
if schema:
uri = uri.update_query_dict({"schema": schema})
return uri, connect_args