mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
chore(Databricks): New Databricks driver (#28393)
This commit is contained in:
@@ -116,6 +116,69 @@ export const databaseField = ({
|
|||||||
helpText={t('Copy the name of the database you are trying to connect to.')}
|
helpText={t('Copy the name of the database you are trying to connect to.')}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
/**
 * Validated text input bound to the `default_catalog` connection parameter.
 *
 * The current value is read from `db.parameters.default_catalog`, edits are
 * forwarded to `changeMethods.onParametersChange`, and validation is
 * triggered when the input loses focus.
 */
export const defaultCatalogField = (props: FieldPropTypes) => {
  const { required, changeMethods, getValidation, validationErrors, db } =
    props;
  const currentCatalog = db?.parameters?.default_catalog;
  const catalogError = validationErrors?.default_catalog;

  return (
    <ValidatedInput
      id="default_catalog"
      name="default_catalog"
      required={required}
      value={currentCatalog}
      validationMethods={{ onBlur: getValidation }}
      errorMessage={catalogError}
      placeholder={t('e.g. hive_metastore')}
      label={t('Default Catalog')}
      onChange={changeMethods.onParametersChange}
      helpText={t('The default catalog that should be used for the connection.')}
    />
  );
};
|
||||||
|
/**
 * Validated text input bound to the `default_schema` connection parameter.
 *
 * The current value is read from `db.parameters.default_schema`, edits are
 * forwarded to `changeMethods.onParametersChange`, and validation is
 * triggered when the input loses focus.
 */
export const defaultSchemaField = (props: FieldPropTypes) => {
  const { required, changeMethods, getValidation, validationErrors, db } =
    props;
  const currentSchema = db?.parameters?.default_schema;
  const schemaError = validationErrors?.default_schema;

  return (
    <ValidatedInput
      id="default_schema"
      name="default_schema"
      required={required}
      value={currentSchema}
      validationMethods={{ onBlur: getValidation }}
      errorMessage={schemaError}
      placeholder={t('e.g. default')}
      label={t('Default Schema')}
      onChange={changeMethods.onParametersChange}
      helpText={t('The default schema that should be used for the connection.')}
    />
  );
};
|
||||||
|
/**
 * Validated text input bound to the `http_path_field` connection parameter.
 *
 * The current value is read from `db.parameters.http_path_field`, edits are
 * forwarded to `changeMethods.onParametersChange`, and validation is
 * triggered when the input loses focus.
 */
export const httpPathField = ({
  required,
  changeMethods,
  getValidation,
  validationErrors,
  db,
}: FieldPropTypes) => (
  // The render body is intentionally free of side effects: the previous
  // `console.error(db)` debug statement logged the whole database object
  // (including its parameters) on every render.
  <ValidatedInput
    id="http_path_field"
    name="http_path_field"
    required={required}
    value={db?.parameters?.http_path_field}
    validationMethods={{ onBlur: getValidation }}
    // Prefer the error keyed by this field's own name; keep the legacy
    // `http_path` key as a fallback.
    // NOTE(review): assumes validation errors are keyed by parameter name,
    // as they are for the sibling fields - confirm against the error mapper.
    errorMessage={
      validationErrors?.http_path_field || validationErrors?.http_path
    }
    placeholder={t('e.g. sql/protocolv1/o/12345')}
    label={t('HTTP Path')}
    onChange={changeMethods.onParametersChange}
    helpText={t('Copy the name of the HTTP Path of your cluster.')}
  />
);
|
||||||
export const usernameField = ({
|
export const usernameField = ({
|
||||||
required,
|
required,
|
||||||
changeMethods,
|
changeMethods,
|
||||||
|
|||||||
@@ -27,10 +27,13 @@ import { Form } from 'src/components/Form';
|
|||||||
import {
|
import {
|
||||||
accessTokenField,
|
accessTokenField,
|
||||||
databaseField,
|
databaseField,
|
||||||
|
defaultCatalogField,
|
||||||
|
defaultSchemaField,
|
||||||
displayField,
|
displayField,
|
||||||
forceSSLField,
|
forceSSLField,
|
||||||
hostField,
|
hostField,
|
||||||
httpPath,
|
httpPath,
|
||||||
|
httpPathField,
|
||||||
passwordField,
|
passwordField,
|
||||||
portField,
|
portField,
|
||||||
queryField,
|
queryField,
|
||||||
@@ -47,10 +50,13 @@ export const FormFieldOrder = [
|
|||||||
'host',
|
'host',
|
||||||
'port',
|
'port',
|
||||||
'database',
|
'database',
|
||||||
|
'default_catalog',
|
||||||
|
'default_schema',
|
||||||
'username',
|
'username',
|
||||||
'password',
|
'password',
|
||||||
'access_token',
|
'access_token',
|
||||||
'http_path',
|
'http_path',
|
||||||
|
'http_path_field',
|
||||||
'database_name',
|
'database_name',
|
||||||
'credentials_info',
|
'credentials_info',
|
||||||
'service_account_info',
|
'service_account_info',
|
||||||
@@ -71,8 +77,11 @@ const SSHTunnelSwitchComponent =
|
|||||||
const FORM_FIELD_MAP = {
|
const FORM_FIELD_MAP = {
|
||||||
host: hostField,
|
host: hostField,
|
||||||
http_path: httpPath,
|
http_path: httpPath,
|
||||||
|
http_path_field: httpPathField,
|
||||||
port: portField,
|
port: portField,
|
||||||
database: databaseField,
|
database: databaseField,
|
||||||
|
default_catalog: defaultCatalogField,
|
||||||
|
default_schema: defaultSchemaField,
|
||||||
username: usernameField,
|
username: usernameField,
|
||||||
password: passwordField,
|
password: passwordField,
|
||||||
access_token: accessTokenField,
|
access_token: accessTokenField,
|
||||||
|
|||||||
@@ -633,11 +633,23 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||||||
const history = useHistory();
|
const history = useHistory();
|
||||||
|
|
||||||
const dbModel: DatabaseForm =
|
const dbModel: DatabaseForm =
|
||||||
|
// TODO: we need a centralized engine in one place
|
||||||
|
|
||||||
|
// first try to match both engine and driver
|
||||||
|
availableDbs?.databases?.find(
|
||||||
|
(available: {
|
||||||
|
engine: string | undefined;
|
||||||
|
default_driver: string | undefined;
|
||||||
|
}) =>
|
||||||
|
available.engine === (isEditMode ? db?.backend : db?.engine) &&
|
||||||
|
available.default_driver === db?.driver,
|
||||||
|
) ||
|
||||||
|
// alternatively try to match only engine
|
||||||
availableDbs?.databases?.find(
|
availableDbs?.databases?.find(
|
||||||
(available: { engine: string | undefined }) =>
|
(available: { engine: string | undefined }) =>
|
||||||
// TODO: we need a centralized engine in one place
|
|
||||||
available.engine === (isEditMode ? db?.backend : db?.engine),
|
available.engine === (isEditMode ? db?.backend : db?.engine),
|
||||||
) || {};
|
) ||
|
||||||
|
{};
|
||||||
|
|
||||||
// Test Connection logic
|
// Test Connection logic
|
||||||
const testConnection = () => {
|
const testConnection = () => {
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ export type DatabaseObject = {
|
|||||||
host?: string;
|
host?: string;
|
||||||
port?: number;
|
port?: number;
|
||||||
database?: string;
|
database?: string;
|
||||||
|
default_catalog?: string;
|
||||||
|
default_schema?: string;
|
||||||
|
http_path_field?: string;
|
||||||
username?: string;
|
username?: string;
|
||||||
password?: string;
|
password?: string;
|
||||||
encryption?: boolean;
|
encryption?: boolean;
|
||||||
@@ -126,6 +129,18 @@ export type DatabaseForm = {
|
|||||||
description: string;
|
description: string;
|
||||||
type: string;
|
type: string;
|
||||||
};
|
};
|
||||||
|
default_catalog: {
|
||||||
|
description: string;
|
||||||
|
type: string;
|
||||||
|
};
|
||||||
|
default_schema: {
|
||||||
|
description: string;
|
||||||
|
type: string;
|
||||||
|
};
|
||||||
|
http_path_field: {
|
||||||
|
description: string;
|
||||||
|
type: string;
|
||||||
|
};
|
||||||
host: {
|
host: {
|
||||||
description: string;
|
description: string;
|
||||||
type: string;
|
type: string;
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, TYPE_CHECKING, TypedDict
|
from typing import Any, TYPE_CHECKING, TypedDict, Union
|
||||||
|
|
||||||
from apispec import APISpec
|
from apispec import APISpec
|
||||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||||
@@ -40,10 +40,10 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
class DatabricksParametersSchema(Schema):
|
class DatabricksBaseSchema(Schema):
|
||||||
"""
|
"""
|
||||||
This is the list of fields that are expected
|
Fields that are required for both Databricks drivers that use a
|
||||||
from the client in order to build the sqlalchemy string
|
dynamic form.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
access_token = fields.Str(required=True)
|
access_token = fields.Str(required=True)
|
||||||
@@ -53,44 +53,85 @@ class DatabricksParametersSchema(Schema):
|
|||||||
metadata={"description": __("Database port")},
|
metadata={"description": __("Database port")},
|
||||||
validate=Range(min=0, max=2**16, max_inclusive=False),
|
validate=Range(min=0, max=2**16, max_inclusive=False),
|
||||||
)
|
)
|
||||||
database = fields.Str(required=True)
|
|
||||||
encryption = fields.Boolean(
|
encryption = fields.Boolean(
|
||||||
required=False,
|
required=False,
|
||||||
metadata={"description": __("Use an encrypted connection to the database")},
|
metadata={"description": __("Use an encrypted connection to the database")},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DatabricksPropertiesSchema(DatabricksParametersSchema):
|
class DatabricksBaseParametersType(TypedDict):
|
||||||
"""
|
"""
|
||||||
This is the list of fields expected
|
The parameters are all the keys that do not exist on the Database model.
|
||||||
for successful database creation execution
|
These are used to build the sqlalchemy uri.
|
||||||
"""
|
|
||||||
|
|
||||||
http_path = fields.Str(required=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DatabricksParametersType(TypedDict):
|
|
||||||
"""
|
|
||||||
The parameters are all the keys that do
|
|
||||||
not exist on the Database model.
|
|
||||||
These are used to build the sqlalchemy uri
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
access_token: str
|
access_token: str
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
database: str
|
|
||||||
encryption: bool
|
encryption: bool
|
||||||
|
|
||||||
|
|
||||||
class DatabricksPropertiesType(TypedDict):
|
class DatabricksNativeSchema(DatabricksBaseSchema):
|
||||||
"""
|
"""
|
||||||
All properties that need to be available to
|
Additional fields required only for the DatabricksNativeEngineSpec.
|
||||||
this engine in order to create a connection
|
|
||||||
if the dynamic form is used
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parameters: DatabricksParametersType
|
database = fields.Str(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksNativePropertiesSchema(DatabricksNativeSchema):
    """
    Schema of the properties that only the DatabricksNativeEngineSpec requires.
    """

    # Mandatory string field added on top of DatabricksNativeSchema.
    http_path = fields.Str(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksNativeParametersType(DatabricksBaseParametersType):
    """
    Parameters needed only by the DatabricksNativeEngineSpec, in addition to
    those declared on DatabricksBaseParametersType.
    """

    database: str
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksNativePropertiesType(TypedDict):
    """
    All properties that need to be available to the DatabricksNativeEngineSpec
    in order to create a connection if the dynamic form is used.
    """

    # Form parameters specific to the native driver.
    parameters: DatabricksNativeParametersType
    # JSON-encoded string; it is decoded with ``json.loads`` during validation.
    extra: str
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksPythonConnectorSchema(DatabricksBaseSchema):
    """
    Form fields that only the DatabricksPythonConnectorEngineSpec requires,
    on top of the ones inherited from DatabricksBaseSchema.
    """

    # All three are mandatory strings.
    http_path_field = fields.Str(required=True)
    default_catalog = fields.Str(required=True)
    default_schema = fields.Str(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType):
    """
    Parameters needed only by the DatabricksPythonConnectorEngineSpec, in
    addition to those declared on DatabricksBaseParametersType.
    """

    http_path_field: str
    default_catalog: str
    default_schema: str
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksPythonConnectorPropertiesType(TypedDict):
    """
    Everything the DatabricksPythonConnectorEngineSpec must be given to create
    a connection when the dynamic form is used.
    """

    # Form parameters specific to the Python connector driver.
    parameters: DatabricksPythonConnectorParametersType
    # JSON-encoded string; it is decoded with ``json.loads`` during validation.
    extra: str
|
||||||
|
|
||||||
|
|
||||||
@@ -125,13 +166,7 @@ class DatabricksHiveEngineSpec(HiveEngineSpec):
|
|||||||
_time_grain_expressions = time_grain_expressions
|
_time_grain_expressions = time_grain_expressions
|
||||||
|
|
||||||
|
|
||||||
class DatabricksODBCEngineSpec(BaseEngineSpec):
|
class DatabricksBaseEngineSpec(BaseEngineSpec):
|
||||||
engine_name = "Databricks SQL Endpoint"
|
|
||||||
|
|
||||||
engine = "databricks"
|
|
||||||
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
|
|
||||||
default_driver = "pyodbc"
|
|
||||||
|
|
||||||
_time_grain_expressions = time_grain_expressions
|
_time_grain_expressions = time_grain_expressions
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -145,20 +180,23 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
|
|||||||
return HiveEngineSpec.epoch_to_dttm()
|
return HiveEngineSpec.epoch_to_dttm()
|
||||||
|
|
||||||
|
|
||||||
class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec):
|
class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec):
|
||||||
engine_name = "Databricks"
|
engine_name = "Databricks SQL Endpoint"
|
||||||
|
|
||||||
engine = "databricks"
|
engine = "databricks"
|
||||||
drivers = {"connector": "Native all-purpose driver"}
|
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
|
||||||
default_driver = "connector"
|
default_driver = "pyodbc"
|
||||||
|
|
||||||
parameters_schema = DatabricksParametersSchema()
|
|
||||||
properties_schema = DatabricksPropertiesSchema()
|
|
||||||
|
|
||||||
sqlalchemy_uri_placeholder = (
|
class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngineSpec):
|
||||||
"databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
|
default_driver = ""
|
||||||
)
|
|
||||||
encryption_parameters = {"ssl": "1"}
|
encryption_parameters = {"ssl": "1"}
|
||||||
|
required_parameters = {"access_token", "host", "port"}
|
||||||
|
context_key_mapping = {
|
||||||
|
"access_token": "password",
|
||||||
|
"host": "hostname",
|
||||||
|
"port": "port",
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_extra_params(database: Database) -> dict[str, Any]:
|
def get_extra_params(database: Database) -> dict[str, Any]:
|
||||||
@@ -190,30 +228,6 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
|
|||||||
database, inspector, schema
|
database, inspector, schema
|
||||||
) - cls.get_view_names(database, inspector, schema)
|
) - cls.get_view_names(database, inspector, schema)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build_sqlalchemy_uri( # type: ignore
|
|
||||||
cls, parameters: DatabricksParametersType, *_
|
|
||||||
) -> str:
|
|
||||||
query = {}
|
|
||||||
if parameters.get("encryption"):
|
|
||||||
if not cls.encryption_parameters:
|
|
||||||
raise Exception( # pylint: disable=broad-exception-raised
|
|
||||||
"Unable to build a URL with encryption enabled"
|
|
||||||
)
|
|
||||||
query.update(cls.encryption_parameters)
|
|
||||||
|
|
||||||
return str(
|
|
||||||
URL.create(
|
|
||||||
f"{cls.engine}+{cls.default_driver}".rstrip("+"),
|
|
||||||
username="token",
|
|
||||||
password=parameters.get("access_token"),
|
|
||||||
host=parameters["host"],
|
|
||||||
port=parameters["port"],
|
|
||||||
database=parameters["database"],
|
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_errors(
|
def extract_errors(
|
||||||
cls, ex: Exception, context: dict[str, Any] | None = None
|
cls, ex: Exception, context: dict[str, Any] | None = None
|
||||||
@@ -224,13 +238,10 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
|
|||||||
# access_token isn't currently parseable from the
|
# access_token isn't currently parseable from the
|
||||||
# databricks error response, but adding it in here
|
# databricks error response, but adding it in here
|
||||||
# for reference if their error message changes
|
# for reference if their error message changes
|
||||||
context = {
|
|
||||||
"host": context.get("hostname"),
|
for key, value in cls.context_key_mapping.items():
|
||||||
"access_token": context.get("password"),
|
context[key] = context.get(value)
|
||||||
"port": context.get("port"),
|
|
||||||
"username": context.get("username"),
|
|
||||||
"database": context.get("database"),
|
|
||||||
}
|
|
||||||
for regex, (message, error_type, extra) in cls.custom_errors.items():
|
for regex, (message, error_type, extra) in cls.custom_errors.items():
|
||||||
match = regex.search(raw_message)
|
match = regex.search(raw_message)
|
||||||
if match:
|
if match:
|
||||||
@@ -254,32 +265,18 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_parameters_from_uri( # type: ignore
|
|
||||||
cls, uri: str, *_, **__
|
|
||||||
) -> DatabricksParametersType:
|
|
||||||
url = make_url_safe(uri)
|
|
||||||
encryption = all(
|
|
||||||
item in url.query.items() for item in cls.encryption_parameters.items()
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"access_token": url.password,
|
|
||||||
"host": url.host,
|
|
||||||
"port": url.port,
|
|
||||||
"database": url.database,
|
|
||||||
"encryption": encryption,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_parameters( # type: ignore
|
def validate_parameters( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
properties: DatabricksPropertiesType,
|
properties: Union[
|
||||||
|
DatabricksNativePropertiesType,
|
||||||
|
DatabricksPythonConnectorPropertiesType,
|
||||||
|
],
|
||||||
) -> list[SupersetError]:
|
) -> list[SupersetError]:
|
||||||
errors: list[SupersetError] = []
|
errors: list[SupersetError] = []
|
||||||
required = {"access_token", "host", "port", "database", "extra"}
|
if extra := json.loads(properties.get("extra")): # type: ignore
|
||||||
extra = json.loads(properties.get("extra", "{}"))
|
engine_params = extra.get("engine_params", {})
|
||||||
engine_params = extra.get("engine_params", {})
|
connect_args = engine_params.get("connect_args", {})
|
||||||
connect_args = engine_params.get("connect_args", {})
|
|
||||||
parameters = {
|
parameters = {
|
||||||
**properties,
|
**properties,
|
||||||
**properties.get("parameters", {}),
|
**properties.get("parameters", {}),
|
||||||
@@ -289,7 +286,7 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
|
|||||||
|
|
||||||
present = {key for key in parameters if parameters.get(key, ())}
|
present = {key for key in parameters if parameters.get(key, ())}
|
||||||
|
|
||||||
if missing := sorted(required - present):
|
if missing := sorted(cls.required_parameters - present):
|
||||||
errors.append(
|
errors.append(
|
||||||
SupersetError(
|
SupersetError(
|
||||||
message=f'One or more parameters are missing: {", ".join(missing)}',
|
message=f'One or more parameters are missing: {", ".join(missing)}',
|
||||||
@@ -351,6 +348,69 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
|
|||||||
)
|
)
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||||
|
engine = "databricks"
|
||||||
|
engine_name = "Databricks"
|
||||||
|
drivers = {"connector": "Native all-purpose driver"}
|
||||||
|
default_driver = "connector"
|
||||||
|
|
||||||
|
parameters_schema = DatabricksNativeSchema()
|
||||||
|
properties_schema = DatabricksNativePropertiesSchema()
|
||||||
|
|
||||||
|
sqlalchemy_uri_placeholder = (
|
||||||
|
"databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
|
||||||
|
)
|
||||||
|
context_key_mapping = {
|
||||||
|
**DatabricksDynamicBaseEngineSpec.context_key_mapping,
|
||||||
|
"database": "database",
|
||||||
|
"username": "username",
|
||||||
|
}
|
||||||
|
required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
|
||||||
|
"database",
|
||||||
|
"extra",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_sqlalchemy_uri( # type: ignore
|
||||||
|
cls, parameters: DatabricksNativeParametersType, *_
|
||||||
|
) -> str:
|
||||||
|
query = {}
|
||||||
|
if parameters.get("encryption"):
|
||||||
|
if not cls.encryption_parameters:
|
||||||
|
raise Exception( # pylint: disable=broad-exception-raised
|
||||||
|
"Unable to build a URL with encryption enabled"
|
||||||
|
)
|
||||||
|
query.update(cls.encryption_parameters)
|
||||||
|
|
||||||
|
return str(
|
||||||
|
URL.create(
|
||||||
|
f"{cls.engine}+{cls.default_driver}".rstrip("+"),
|
||||||
|
username="token",
|
||||||
|
password=parameters.get("access_token"),
|
||||||
|
host=parameters["host"],
|
||||||
|
port=parameters["port"],
|
||||||
|
database=parameters["database"],
|
||||||
|
query=query,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_parameters_from_uri( # type: ignore
|
||||||
|
cls, uri: str, *_, **__
|
||||||
|
) -> DatabricksNativeParametersType:
|
||||||
|
url = make_url_safe(uri)
|
||||||
|
encryption = all(
|
||||||
|
item in url.query.items() for item in cls.encryption_parameters.items()
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"access_token": url.password,
|
||||||
|
"host": url.host,
|
||||||
|
"port": url.port,
|
||||||
|
"database": url.database,
|
||||||
|
"encryption": encryption,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parameters_json_schema(cls) -> Any:
|
def parameters_json_schema(cls) -> Any:
|
||||||
"""
|
"""
|
||||||
@@ -367,3 +427,78 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, DatabricksODBCEngineSpec)
|
|||||||
)
|
)
|
||||||
spec.components.schema(cls.__name__, schema=cls.properties_schema)
|
spec.components.schema(cls.__name__, schema=cls.properties_schema)
|
||||||
return spec.to_dict()["components"]["schemas"][cls.__name__]
|
return spec.to_dict()["components"]["schemas"][cls.__name__]
|
||||||
|
|
||||||
|
|
||||||
|
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
    """
    Dynamic-form engine spec for Databricks using the ``databricks-sql-python``
    driver.

    The HTTP path, default catalog and default schema travel in the query
    string of the SQLAlchemy URI rather than in its path.
    """

    engine = "databricks"
    engine_name = "Databricks Python Connector"
    default_driver = "databricks-sql-python"
    drivers = {"databricks-sql-python": "Databricks SQL Python"}

    parameters_schema = DatabricksPythonConnectorSchema()

    sqlalchemy_uri_placeholder = (
        "databricks://token:{access_token}@{host}:{port}?http_path={http_path}"
        "&catalog={default_catalog}&schema={default_schema}"
    )

    # Maps each form field name to the key it is read from in the error context.
    context_key_mapping = {
        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
        "default_catalog": "catalog",
        "default_schema": "schema",
        "http_path_field": "http_path",
    }

    # Field names that parameter validation reports as missing when absent.
    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
        "default_catalog",
        "default_schema",
        "http_path_field",
    }

    @classmethod
    def build_sqlalchemy_uri(  # type: ignore
        cls, parameters: DatabricksPythonConnectorParametersType, *_
    ) -> str:
        """
        Assemble the SQLAlchemy URI from the dynamic-form parameters.

        Each of ``http_path_field``, ``default_catalog`` and ``default_schema``
        is added to the query string (as ``http_path``, ``catalog`` and
        ``schema``) only when it has a truthy value. When ``encryption`` is
        truthy, ``cls.encryption_parameters`` is merged into the query as well.
        The username is always ``"token"`` and the password is the access
        token.

        :param parameters: values collected from the dynamic form
        :returns: the URI rendered as a string
        :raises KeyError: if ``host`` or ``port`` is absent from ``parameters``
        """
        # (form field name, query-string key), in the order they are emitted.
        form_to_query = (
            ("http_path_field", "http_path"),
            ("default_catalog", "catalog"),
            ("default_schema", "schema"),
        )

        query = {}
        for form_key, query_key in form_to_query:
            if value := parameters.get(form_key):
                query[query_key] = value
        if parameters.get("encryption"):
            query.update(cls.encryption_parameters)

        url = URL.create(
            cls.engine,
            username="token",
            password=parameters.get("access_token"),
            host=parameters["host"],
            port=parameters["port"],
            query=query,
        )
        return str(url)

    @classmethod
    def get_parameters_from_uri(  # type: ignore
        cls, uri: str, *_: Any, **__: Any
    ) -> DatabricksPythonConnectorParametersType:
        """
        Recover the dynamic-form parameters from a SQLAlchemy URI.

        ``encryption`` is reported as true only when every pair in
        ``cls.encryption_parameters`` is present in the URI's query string;
        those pairs are excluded when reading the remaining query values.

        :param uri: SQLAlchemy URI, parsed with ``make_url_safe``
        :returns: the form parameters encoded in ``uri``
        :raises KeyError: if ``http_path``, ``catalog`` or ``schema`` is missing
            from the query string
        """
        url = make_url_safe(uri)
        uri_pairs = url.query.items()
        encryption_pairs = cls.encryption_parameters.items()

        # Query arguments that are not part of the encryption settings.
        query = {
            name: value
            for (name, value) in uri_pairs
            if (name, value) not in encryption_pairs
        }
        encryption = all(pair in uri_pairs for pair in encryption_pairs)

        return {
            "access_token": url.password,
            "host": url.host,
            "port": url.port,
            "http_path_field": query["http_path"],
            "default_catalog": query["catalog"],
            "default_schema": query["schema"],
            "encryption": encryption,
        }
|
||||||
|
|||||||
@@ -35,13 +35,13 @@ def test_get_parameters_from_uri() -> None:
|
|||||||
"""
|
"""
|
||||||
from superset.db_engine_specs.databricks import (
|
from superset.db_engine_specs.databricks import (
|
||||||
DatabricksNativeEngineSpec,
|
DatabricksNativeEngineSpec,
|
||||||
DatabricksParametersType,
|
DatabricksNativeParametersType,
|
||||||
)
|
)
|
||||||
|
|
||||||
parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(
|
parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(
|
||||||
"databricks+connector://token:abc12345@my_hostname:1234/test"
|
"databricks+connector://token:abc12345@my_hostname:1234/test"
|
||||||
)
|
)
|
||||||
assert parameters == DatabricksParametersType(
|
assert parameters == DatabricksNativeParametersType(
|
||||||
{
|
{
|
||||||
"access_token": "abc12345",
|
"access_token": "abc12345",
|
||||||
"host": "my_hostname",
|
"host": "my_hostname",
|
||||||
@@ -60,10 +60,10 @@ def test_build_sqlalchemy_uri() -> None:
|
|||||||
"""
|
"""
|
||||||
from superset.db_engine_specs.databricks import (
|
from superset.db_engine_specs.databricks import (
|
||||||
DatabricksNativeEngineSpec,
|
DatabricksNativeEngineSpec,
|
||||||
DatabricksParametersType,
|
DatabricksNativeParametersType,
|
||||||
)
|
)
|
||||||
|
|
||||||
parameters = DatabricksParametersType(
|
parameters = DatabricksNativeParametersType(
|
||||||
{
|
{
|
||||||
"access_token": "abc12345",
|
"access_token": "abc12345",
|
||||||
"host": "my_hostname",
|
"host": "my_hostname",
|
||||||
|
|||||||
Reference in New Issue
Block a user