chore(Databricks): New Databricks driver (#28393)

This commit is contained in:
Vitor Avila
2024-05-09 15:58:03 -03:00
committed by GitHub
parent e6a85c5901
commit 307ebeaa19
6 changed files with 333 additions and 99 deletions

View File

@@ -116,6 +116,69 @@ export const databaseField = ({
helpText={t('Copy the name of the database you are trying to connect to.')} helpText={t('Copy the name of the database you are trying to connect to.')}
/> />
); );
export const defaultCatalogField = ({
required,
changeMethods,
getValidation,
validationErrors,
db,
}: FieldPropTypes) => (
<ValidatedInput
id="default_catalog"
name="default_catalog"
required={required}
value={db?.parameters?.default_catalog}
validationMethods={{ onBlur: getValidation }}
errorMessage={validationErrors?.default_catalog}
placeholder={t('e.g. hive_metastore')}
label={t('Default Catalog')}
onChange={changeMethods.onParametersChange}
helpText={t('The default catalog that should be used for the connection.')}
/>
);
export const defaultSchemaField = ({
required,
changeMethods,
getValidation,
validationErrors,
db,
}: FieldPropTypes) => (
<ValidatedInput
id="default_schema"
name="default_schema"
required={required}
value={db?.parameters?.default_schema}
validationMethods={{ onBlur: getValidation }}
errorMessage={validationErrors?.default_schema}
placeholder={t('e.g. default')}
label={t('Default Schema')}
onChange={changeMethods.onParametersChange}
helpText={t('The default schema that should be used for the connection.')}
/>
);
export const httpPathField = ({
required,
changeMethods,
getValidation,
validationErrors,
db,
}: FieldPropTypes) => {
console.error(db);
return (
<ValidatedInput
id="http_path_field"
name="http_path_field"
required={required}
value={db?.parameters?.http_path_field}
validationMethods={{ onBlur: getValidation }}
errorMessage={validationErrors?.http_path}
placeholder={t('e.g. sql/protocolv1/o/12345')}
label="HTTP Path"
onChange={changeMethods.onParametersChange}
helpText={t('Copy the name of the HTTP Path of your cluster.')}
/>
);
};
export const usernameField = ({ export const usernameField = ({
required, required,
changeMethods, changeMethods,

View File

@@ -27,10 +27,13 @@ import { Form } from 'src/components/Form';
import { import {
accessTokenField, accessTokenField,
databaseField, databaseField,
defaultCatalogField,
defaultSchemaField,
displayField, displayField,
forceSSLField, forceSSLField,
hostField, hostField,
httpPath, httpPath,
httpPathField,
passwordField, passwordField,
portField, portField,
queryField, queryField,
@@ -47,10 +50,13 @@ export const FormFieldOrder = [
'host', 'host',
'port', 'port',
'database', 'database',
'default_catalog',
'default_schema',
'username', 'username',
'password', 'password',
'access_token', 'access_token',
'http_path', 'http_path',
'http_path_field',
'database_name', 'database_name',
'credentials_info', 'credentials_info',
'service_account_info', 'service_account_info',
@@ -71,8 +77,11 @@ const SSHTunnelSwitchComponent =
const FORM_FIELD_MAP = { const FORM_FIELD_MAP = {
host: hostField, host: hostField,
http_path: httpPath, http_path: httpPath,
http_path_field: httpPathField,
port: portField, port: portField,
database: databaseField, database: databaseField,
default_catalog: defaultCatalogField,
default_schema: defaultSchemaField,
username: usernameField, username: usernameField,
password: passwordField, password: passwordField,
access_token: accessTokenField, access_token: accessTokenField,

View File

@@ -633,11 +633,23 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const history = useHistory(); const history = useHistory();
const dbModel: DatabaseForm = const dbModel: DatabaseForm =
// TODO: we need a centralized engine in one place
// first try to match both engine and driver
availableDbs?.databases?.find(
(available: {
engine: string | undefined;
default_driver: string | undefined;
}) =>
available.engine === (isEditMode ? db?.backend : db?.engine) &&
available.default_driver === db?.driver,
) ||
// alternatively try to match only engine
availableDbs?.databases?.find( availableDbs?.databases?.find(
(available: { engine: string | undefined }) => (available: { engine: string | undefined }) =>
// TODO: we need a centralized engine in one place
available.engine === (isEditMode ? db?.backend : db?.engine), available.engine === (isEditMode ? db?.backend : db?.engine),
) || {}; ) ||
{};
// Test Connection logic // Test Connection logic
const testConnection = () => { const testConnection = () => {

View File

@@ -63,6 +63,9 @@ export type DatabaseObject = {
host?: string; host?: string;
port?: number; port?: number;
database?: string; database?: string;
default_catalog?: string;
default_schema?: string;
http_path_field?: string;
username?: string; username?: string;
password?: string; password?: string;
encryption?: boolean; encryption?: boolean;
@@ -126,6 +129,18 @@ export type DatabaseForm = {
description: string; description: string;
type: string; type: string;
}; };
default_catalog: {
description: string;
type: string;
};
default_schema: {
description: string;
type: string;
};
http_path_field: {
description: string;
type: string;
};
host: { host: {
description: string; description: string;
type: string; type: string;

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
import json import json
from datetime import datetime from datetime import datetime
from typing import Any, TYPE_CHECKING, TypedDict from typing import Any, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow import MarshmallowPlugin
@@ -40,10 +40,10 @@ if TYPE_CHECKING:
# #
class DatabricksParametersSchema(Schema): class DatabricksBaseSchema(Schema):
""" """
This is the list of fields that are expected Fields that are required for both Databricks drivers that uses a
from the client in order to build the sqlalchemy string dynamic form.
""" """
access_token = fields.Str(required=True) access_token = fields.Str(required=True)
@@ -53,44 +53,85 @@ class DatabricksParametersSchema(Schema):
metadata={"description": __("Database port")}, metadata={"description": __("Database port")},
validate=Range(min=0, max=2**16, max_inclusive=False), validate=Range(min=0, max=2**16, max_inclusive=False),
) )
database = fields.Str(required=True)
encryption = fields.Boolean( encryption = fields.Boolean(
required=False, required=False,
metadata={"description": __("Use an encrypted connection to the database")}, metadata={"description": __("Use an encrypted connection to the database")},
) )
class DatabricksPropertiesSchema(DatabricksParametersSchema): class DatabricksBaseParametersType(TypedDict):
""" """
This is the list of fields expected The parameters are all the keys that do not exist on the Database model.
for successful database creation execution These are used to build the sqlalchemy uri.
"""
http_path = fields.Str(required=True)
class DatabricksParametersType(TypedDict):
"""
The parameters are all the keys that do
not exist on the Database model.
These are used to build the sqlalchemy uri
""" """
access_token: str access_token: str
host: str host: str
port: int port: int
database: str
encryption: bool encryption: bool
class DatabricksPropertiesType(TypedDict): class DatabricksNativeSchema(DatabricksBaseSchema):
""" """
All properties that need to be available to Additional fields required only for the DatabricksNativeEngineSpec.
this engine in order to create a connection
if the dynamic form is used
""" """
parameters: DatabricksParametersType database = fields.Str(required=True)
class DatabricksNativePropertiesSchema(DatabricksNativeSchema):
"""
Properties required only for the DatabricksNativeEngineSpec.
"""
http_path = fields.Str(required=True)
class DatabricksNativeParametersType(DatabricksBaseParametersType):
"""
Additional parameters required only for the DatabricksNativeEngineSpec.
"""
database: str
class DatabricksNativePropertiesType(TypedDict):
"""
All properties that need to be available to the DatabricksNativeEngineSpec
in order tocreate a connection if the dynamic form is used.
"""
parameters: DatabricksNativeParametersType
extra: str
class DatabricksPythonConnectorSchema(DatabricksBaseSchema):
"""
Additional fields required only for the DatabricksPythonConnectorEngineSpec.
"""
http_path_field = fields.Str(required=True)
default_catalog = fields.Str(required=True)
default_schema = fields.Str(required=True)
class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType):
"""
Additional parameters required only for the DatabricksPythonConnectorEngineSpec.
"""
http_path_field: str
default_catalog: str
default_schema: str
class DatabricksPythonConnectorPropertiesType(TypedDict):
"""
All properties that need to be available to the DatabricksPythonConnectorEngineSpec
in order to create a connection if the dynamic form is used.
"""
parameters: DatabricksPythonConnectorParametersType
extra: str extra: str
@@ -125,13 +166,7 @@ class DatabricksHiveEngineSpec(HiveEngineSpec):
_time_grain_expressions = time_grain_expressions _time_grain_expressions = time_grain_expressions
class DatabricksODBCEngineSpec(BaseEngineSpec): class DatabricksBaseEngineSpec(BaseEngineSpec):
engine_name = "Databricks SQL Endpoint"
engine = "databricks"
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "pyodbc"
_time_grain_expressions = time_grain_expressions _time_grain_expressions = time_grain_expressions
@classmethod @classmethod
@@ -145,20 +180,23 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
return HiveEngineSpec.epoch_to_dttm() return HiveEngineSpec.epoch_to_dttm()
class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec): class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec):
engine_name = "Databricks" engine_name = "Databricks SQL Endpoint"
engine = "databricks" engine = "databricks"
drivers = {"connector": "Native all-purpose driver"} drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "connector" default_driver = "pyodbc"
parameters_schema = DatabricksParametersSchema()
properties_schema = DatabricksPropertiesSchema()
sqlalchemy_uri_placeholder = ( class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngineSpec):
"databricks+connector://token:{access_token}@{host}:{port}/{database_name}" default_driver = ""
)
encryption_parameters = {"ssl": "1"} encryption_parameters = {"ssl": "1"}
required_parameters = {"access_token", "host", "port"}
context_key_mapping = {
"access_token": "password",
"host": "hostname",
"port": "port",
}
@staticmethod @staticmethod
def get_extra_params(database: Database) -> dict[str, Any]: def get_extra_params(database: Database) -> dict[str, Any]:
@@ -190,30 +228,6 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
database, inspector, schema database, inspector, schema
) - cls.get_view_names(database, inspector, schema) ) - cls.get_view_names(database, inspector, schema)
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksParametersType, *_
) -> str:
query = {}
if parameters.get("encryption"):
if not cls.encryption_parameters:
raise Exception( # pylint: disable=broad-exception-raised
"Unable to build a URL with encryption enabled"
)
query.update(cls.encryption_parameters)
return str(
URL.create(
f"{cls.engine}+{cls.default_driver}".rstrip("+"),
username="token",
password=parameters.get("access_token"),
host=parameters["host"],
port=parameters["port"],
database=parameters["database"],
query=query,
)
)
@classmethod @classmethod
def extract_errors( def extract_errors(
cls, ex: Exception, context: dict[str, Any] | None = None cls, ex: Exception, context: dict[str, Any] | None = None
@@ -224,13 +238,10 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
# access_token isn't currently parseable from the # access_token isn't currently parseable from the
# databricks error response, but adding it in here # databricks error response, but adding it in here
# for reference if their error message changes # for reference if their error message changes
context = {
"host": context.get("hostname"), for key, value in cls.context_key_mapping.items():
"access_token": context.get("password"), context[key] = context.get(value)
"port": context.get("port"),
"username": context.get("username"),
"database": context.get("database"),
}
for regex, (message, error_type, extra) in cls.custom_errors.items(): for regex, (message, error_type, extra) in cls.custom_errors.items():
match = regex.search(raw_message) match = regex.search(raw_message)
if match: if match:
@@ -254,32 +265,18 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
) )
] ]
@classmethod
def get_parameters_from_uri( # type: ignore
cls, uri: str, *_, **__
) -> DatabricksParametersType:
url = make_url_safe(uri)
encryption = all(
item in url.query.items() for item in cls.encryption_parameters.items()
)
return {
"access_token": url.password,
"host": url.host,
"port": url.port,
"database": url.database,
"encryption": encryption,
}
@classmethod @classmethod
def validate_parameters( # type: ignore def validate_parameters( # type: ignore
cls, cls,
properties: DatabricksPropertiesType, properties: Union[
DatabricksNativePropertiesType,
DatabricksPythonConnectorPropertiesType,
],
) -> list[SupersetError]: ) -> list[SupersetError]:
errors: list[SupersetError] = [] errors: list[SupersetError] = []
required = {"access_token", "host", "port", "database", "extra"} if extra := json.loads(properties.get("extra")): # type: ignore
extra = json.loads(properties.get("extra", "{}")) engine_params = extra.get("engine_params", {})
engine_params = extra.get("engine_params", {}) connect_args = engine_params.get("connect_args", {})
connect_args = engine_params.get("connect_args", {})
parameters = { parameters = {
**properties, **properties,
**properties.get("parameters", {}), **properties.get("parameters", {}),
@@ -289,7 +286,7 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
present = {key for key in parameters if parameters.get(key, ())} present = {key for key in parameters if parameters.get(key, ())}
if missing := sorted(required - present): if missing := sorted(cls.required_parameters - present):
errors.append( errors.append(
SupersetError( SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}', message=f'One or more parameters are missing: {", ".join(missing)}',
@@ -351,6 +348,69 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
) )
return errors return errors
class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
engine = "databricks"
engine_name = "Databricks"
drivers = {"connector": "Native all-purpose driver"}
default_driver = "connector"
parameters_schema = DatabricksNativeSchema()
properties_schema = DatabricksNativePropertiesSchema()
sqlalchemy_uri_placeholder = (
"databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
)
context_key_mapping = {
**DatabricksDynamicBaseEngineSpec.context_key_mapping,
"database": "database",
"username": "username",
}
required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
"database",
"extra",
}
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksNativeParametersType, *_
) -> str:
query = {}
if parameters.get("encryption"):
if not cls.encryption_parameters:
raise Exception( # pylint: disable=broad-exception-raised
"Unable to build a URL with encryption enabled"
)
query.update(cls.encryption_parameters)
return str(
URL.create(
f"{cls.engine}+{cls.default_driver}".rstrip("+"),
username="token",
password=parameters.get("access_token"),
host=parameters["host"],
port=parameters["port"],
database=parameters["database"],
query=query,
)
)
@classmethod
def get_parameters_from_uri( # type: ignore
cls, uri: str, *_, **__
) -> DatabricksNativeParametersType:
url = make_url_safe(uri)
encryption = all(
item in url.query.items() for item in cls.encryption_parameters.items()
)
return {
"access_token": url.password,
"host": url.host,
"port": url.port,
"database": url.database,
"encryption": encryption,
}
@classmethod @classmethod
def parameters_json_schema(cls) -> Any: def parameters_json_schema(cls) -> Any:
""" """
@@ -367,3 +427,78 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
) )
spec.components.schema(cls.__name__, schema=cls.properties_schema) spec.components.schema(cls.__name__, schema=cls.properties_schema)
return spec.to_dict()["components"]["schemas"][cls.__name__] return spec.to_dict()["components"]["schemas"][cls.__name__]
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
engine = "databricks"
engine_name = "Databricks Python Connector"
default_driver = "databricks-sql-python"
drivers = {"databricks-sql-python": "Databricks SQL Python"}
parameters_schema = DatabricksPythonConnectorSchema()
sqlalchemy_uri_placeholder = (
"databricks://token:{access_token}@{host}:{port}?http_path={http_path}"
"&catalog={default_catalog}&schema={default_schema}"
)
context_key_mapping = {
**DatabricksDynamicBaseEngineSpec.context_key_mapping,
"default_catalog": "catalog",
"default_schema": "schema",
"http_path_field": "http_path",
}
required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
"default_catalog",
"default_schema",
"http_path_field",
}
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksPythonConnectorParametersType, *_
) -> str:
query = {}
if http_path := parameters.get("http_path_field"):
query["http_path"] = http_path
if catalog := parameters.get("default_catalog"):
query["catalog"] = catalog
if schema := parameters.get("default_schema"):
query["schema"] = schema
if parameters.get("encryption"):
query.update(cls.encryption_parameters)
return str(
URL.create(
cls.engine,
username="token",
password=parameters.get("access_token"),
host=parameters["host"],
port=parameters["port"],
query=query,
)
)
@classmethod
def get_parameters_from_uri( # type: ignore
cls, uri: str, *_: Any, **__: Any
) -> DatabricksPythonConnectorParametersType:
url = make_url_safe(uri)
query = {
key: value
for (key, value) in url.query.items()
if (key, value) not in cls.encryption_parameters.items()
}
encryption = all(
item in url.query.items() for item in cls.encryption_parameters.items()
)
return {
"access_token": url.password,
"host": url.host,
"port": url.port,
"http_path_field": query["http_path"],
"default_catalog": query["catalog"],
"default_schema": query["schema"],
"encryption": encryption,
}

View File

@@ -35,13 +35,13 @@ def test_get_parameters_from_uri() -> None:
""" """
from superset.db_engine_specs.databricks import ( from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec, DatabricksNativeEngineSpec,
DatabricksParametersType, DatabricksNativeParametersType,
) )
parameters = DatabricksNativeEngineSpec.get_parameters_from_uri( parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(
"databricks+connector://token:abc12345@my_hostname:1234/test" "databricks+connector://token:abc12345@my_hostname:1234/test"
) )
assert parameters == DatabricksParametersType( assert parameters == DatabricksNativeParametersType(
{ {
"access_token": "abc12345", "access_token": "abc12345",
"host": "my_hostname", "host": "my_hostname",
@@ -60,10 +60,10 @@ def test_build_sqlalchemy_uri() -> None:
""" """
from superset.db_engine_specs.databricks import ( from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec, DatabricksNativeEngineSpec,
DatabricksParametersType, DatabricksNativeParametersType,
) )
parameters = DatabricksParametersType( parameters = DatabricksNativeParametersType(
{ {
"access_token": "abc12345", "access_token": "abc12345",
"host": "my_hostname", "host": "my_hostname",