Files
superset2/superset/db_engine_specs/databricks.py
2026-01-21 10:54:01 -08:00

770 lines
26 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import types
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
DatabaseCategory,
)
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
from superset.models.core import Database
# The real ParamEscaper ships with databricks-sql-connector; when that package
# is not installed, fall back to an inert placeholder so this module can still
# be imported (only DatabricksStringType below needs the escaper at runtime).
try:
    from databricks.sql.utils import ParamEscaper
except ImportError:

    class ParamEscaper:  # type: ignore
        """Dummy class."""
class DatabricksStringType(types.TypeDecorator):
    """
    String type whose literal rendering escapes single quotes the way
    Databricks expects (via the driver's ParamEscaper) rather than the
    SQL-standard quote doubling used by the stock Hive dialect.
    """

    impl = types.String
    cache_ok = True
    pe = ParamEscaper()

    def process_literal_param(self, value: Any, dialect: Any) -> str:
        # Delegate the actual quoting to the Databricks driver's escaper.
        return self.pe.escape_string(value)

    def literal_processor(self, dialect: Any) -> Callable[[Any], str]:
        def process(value: Any) -> str:
            escaped = self.process_literal_param(value, dialect="databricks")
            # Mirror the dialect's percent-sign doubling (used when the
            # paramstyle treats "%" as special) on the escaped literal.
            if dialect.identifier_preparer._double_percents:
                escaped = escaped.replace("%", "%%")
            return "%s" % escaped

        return process
def monkeypatch_dialect() -> None:
    """
    Monkeypatch dialect to correctly escape single quotes for Databricks.

    The Databricks SQLAlchemy dialect (<3.0) incorrectly escapes single quotes by
    doubling them ('O''Hara') instead of using backslash escaping ('O\'Hara'). The
    fixed version requires SQLAlchemy>=2.0, which is not yet compatible with Superset.

    Since the DatabricksDialect.colspecs points to the base class (HiveDialect.colspecs)
    we can't patch it without affecting other Hive-based dialects. The solution is to
    introduce a dialect-aware string type so that the change applies only to Databricks.
    """
    try:
        from pyhive.sqlalchemy_hive import HiveDialect
    except ImportError:
        # pyhive isn't installed; there is nothing to patch.
        return

    class ContextAwareStringType(types.TypeDecorator):
        """String type that only alters literal escaping for Databricks."""

        impl = types.String
        cache_ok = True

        def literal_processor(
            self, dialect: DefaultDialect
        ) -> Callable[[Any], str]:
            # Only the Databricks dialect gets backslash escaping; every
            # other Hive-based dialect keeps the inherited behavior.
            if dialect.__class__.__name__ == "DatabricksDialect":
                return DatabricksStringType().literal_processor(dialect)
            return super().literal_processor(dialect)

    HiveDialect.colspecs[types.String] = ContextAwareStringType
class DatabricksBaseSchema(Schema):
    """
    Fields that are required for both Databricks drivers that use a
    dynamic form.
    """

    # Personal access token; used as the password in the SQLAlchemy URI.
    access_token = fields.Str(required=True)
    # Workspace/server hostname.
    host = fields.Str(required=True)
    port = fields.Integer(
        required=True,
        metadata={"description": __("Database port")},
        # Valid TCP range: 0 <= port < 65536.
        validate=Range(min=0, max=2**16, max_inclusive=False),
    )
    encryption = fields.Boolean(
        required=False,
        metadata={"description": __("Use an encrypted connection to the database")},
    )
class DatabricksBaseParametersType(TypedDict):
    """
    The parameters are all the keys that do not exist on the Database model.
    These are used to build the sqlalchemy uri.
    """

    access_token: str  # personal access token (URI password)
    host: str  # workspace/server hostname
    port: int  # TCP port, 0 <= port < 65536
    encryption: bool  # when True, the encryption query args are appended
class DatabricksNativeSchema(DatabricksBaseSchema):
    """
    Additional fields required only for the DatabricksNativeEngineSpec.
    """

    # Default database; becomes the path component of the SQLAlchemy URI.
    database = fields.Str(required=True)
class DatabricksNativePropertiesSchema(DatabricksNativeSchema):
    """
    Properties required only for the DatabricksNativeEngineSpec.
    """

    # HTTP path of the cluster/warehouse; passed through connect_args.
    http_path = fields.Str(required=True)
class DatabricksNativeParametersType(DatabricksBaseParametersType):
    """
    Additional parameters required only for the DatabricksNativeEngineSpec.
    """

    database: str  # default database (URI path component)
class DatabricksNativePropertiesType(TypedDict):
    """
    All properties that need to be available to the DatabricksNativeEngineSpec
    in order to create a connection if the dynamic form is used.
    """

    parameters: DatabricksNativeParametersType
    extra: str  # JSON-encoded string carrying engine_params/connect_args
class DatabricksPythonConnectorSchema(DatabricksBaseSchema):
    """
    Additional fields required only for the DatabricksPythonConnectorEngineSpec.
    """

    # HTTP path of the SQL warehouse/cluster; becomes the `http_path` query arg.
    http_path_field = fields.Str(required=True)
    # Default catalog and schema; become `catalog`/`schema` query args.
    default_catalog = fields.Str(required=True)
    default_schema = fields.Str(required=True)
class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType):
    """
    Additional parameters required only for the DatabricksPythonConnectorEngineSpec.
    """

    http_path_field: str  # `http_path` URI query argument
    default_catalog: str  # `catalog` URI query argument
    default_schema: str  # `schema` URI query argument
class DatabricksPythonConnectorPropertiesType(TypedDict):
    """
    All properties that need to be available to the DatabricksPythonConnectorEngineSpec
    in order to create a connection if the dynamic form is used.
    """

    parameters: DatabricksPythonConnectorParametersType
    extra: str  # JSON-encoded string carrying engine_params/connect_args
# Time-grain SQL templates shared by every Databricks engine spec below.
# Standard grains map straight onto date_trunc; the week-ending/-starting
# variants shift the column by a day around the truncation to move the
# week boundary off the default (Monday-start) alignment.
time_grain_expressions: dict[str | None, str] = {
    None: "{col}",
    TimeGrain.SECOND: "date_trunc('second', {col})",
    TimeGrain.MINUTE: "date_trunc('minute', {col})",
    TimeGrain.HOUR: "date_trunc('hour', {col})",
    TimeGrain.DAY: "date_trunc('day', {col})",
    TimeGrain.WEEK: "date_trunc('week', {col})",
    TimeGrain.MONTH: "date_trunc('month', {col})",
    TimeGrain.QUARTER: "date_trunc('quarter', {col})",
    TimeGrain.YEAR: "date_trunc('year', {col})",
    TimeGrain.WEEK_ENDING_SATURDAY: (
        "date_trunc('week', {col} + interval '1 day') + interval '5 days'"
    ),
    TimeGrain.WEEK_STARTING_SUNDAY: (
        "date_trunc('week', {col} + interval '1 day') - interval '1 day'"
    ),
}
class DatabricksHiveEngineSpec(HiveEngineSpec):
    """Databricks engine spec using Hive connector for Interactive Clusters."""

    engine_name = "Databricks Interactive Cluster"
    engine = "databricks"
    drivers = {"pyhive": "Hive driver for Interactive Cluster"}
    default_driver = "pyhive"
    # Note: Primary metadata is in DatabricksPythonConnectorEngineSpec which
    # consolidates all Databricks connection methods. This spec exists for
    # backwards compatibility with Interactive Cluster connections.
    # Column name in SHOW FUNCTIONS output (differs from stock Hive).
    _show_functions_column = "function"
    _time_grain_expressions = time_grain_expressions
class DatabricksBaseEngineSpec(BaseEngineSpec):
    """Common base for the non-Hive Databricks specs; delegates temporal
    conversions to HiveEngineSpec, whose SQL syntax Databricks shares."""

    _time_grain_expressions = time_grain_expressions

    @classmethod
    def convert_dttm(
        cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
    ) -> str | None:
        # Databricks uses Hive-compatible temporal literal syntax.
        return HiveEngineSpec.convert_dttm(target_type, dttm, db_extra=db_extra)

    @classmethod
    def epoch_to_dttm(cls) -> str:
        # Same epoch-seconds-to-timestamp expression as Hive.
        return HiveEngineSpec.epoch_to_dttm()
class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec):
    """Databricks engine spec using ODBC driver for SQL Endpoints."""

    engine_name = "Databricks SQL Endpoint"
    engine = "databricks"
    drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
    default_driver = "pyodbc"
    # Note: Primary metadata is in DatabricksPythonConnectorEngineSpec which
    # consolidates all Databricks connection methods. This spec exists for
    # backwards compatibility with ODBC connections to SQL Endpoints.
class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngineSpec):
    """
    Shared implementation for the Databricks specs configured through a
    dynamic (form-based) UI: connection-argument handling, error extraction
    and parameter validation.
    """

    default_driver = ""
    # Query-string args appended to the URI when encryption is enabled.
    encryption_parameters = {"ssl": "1"}
    # Form parameters that must be present before attempting a connection.
    required_parameters = {"access_token", "host", "port"}
    # Maps form-field names to the context keys used in error messages.
    context_key_mapping = {
        "access_token": "password",
        "host": "hostname",
        "port": "port",
    }

    @staticmethod
    def get_extra_params(
        database: Database, source: QuerySource | None = None
    ) -> dict[str, Any]:
        """
        Add a user agent to be used in the requests.
        Trim whitespace from connect_args to avoid databricks driver errors
        """
        extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database, source)
        engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
        connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
        user_agent = get_user_agent(database, source)
        # setdefault so explicit user-configured values win over ours.
        connect_args.setdefault("http_headers", [("User-Agent", user_agent)])
        connect_args.setdefault("_user_agent_entry", user_agent)
        # trim whitespace from http_path to avoid databricks errors on connecting
        if http_path := connect_args.get("http_path"):
            connect_args["http_path"] = http_path.strip()
        return extra

    @classmethod
    def get_table_names(
        cls,
        database: Database,
        inspector: Inspector,
        schema: str | None,
    ) -> set[str]:
        """Return table names only, excluding names that are actually views."""
        return super().get_table_names(
            database, inspector, schema
        ) - cls.get_view_names(database, inspector, schema)

    @classmethod
    def extract_errors(
        cls,
        ex: Exception,
        context: dict[str, Any] | None = None,
        database_name: str | None = None,
    ) -> list[SupersetError]:
        """
        Map a raw driver exception onto SupersetErrors using the custom error
        regexes registered for this engine (and, if any, for the database).
        Falls back to a generic DB engine error when nothing matches.
        """
        raw_message = cls._extract_error_message(ex)
        context = context or {}
        # access_token isn't currently parseable from the
        # databricks error response, but adding it in here
        # for reference if their error message changes
        for key, value in cls.context_key_mapping.items():
            context[key] = context.get(value)
        db_engine_custom_errors = cls.get_database_custom_errors(database_name)
        if not isinstance(db_engine_custom_errors, dict):
            db_engine_custom_errors = {}
        # Database-specific errors take precedence over the class-level ones.
        for regex, (message, error_type, extra) in [
            *db_engine_custom_errors.items(),
            *cls.custom_errors.items(),
        ]:
            match = regex.search(raw_message)
            if match:
                params = {**context, **match.groupdict()}
                # Copy `extra` before adding the engine name: the original
                # dict is shared class-level state (cls.custom_errors), and
                # mutating it in place would leak across engines/requests.
                extra = {**extra, "engine_name": cls.engine_name}
                return [
                    SupersetError(
                        error_type=error_type,
                        message=message % params,
                        level=ErrorLevel.ERROR,
                        extra=extra,
                    )
                ]
        return [
            SupersetError(
                error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                message=cls._extract_error_message(ex),
                level=ErrorLevel.ERROR,
                extra={"engine_name": cls.engine_name},
            )
        ]

    @classmethod
    def validate_parameters(  # type: ignore
        cls,
        properties: Union[
            DatabricksNativePropertiesType,
            DatabricksPythonConnectorPropertiesType,
        ],
    ) -> list[SupersetError]:
        """
        Validate the dynamic-form properties and return a (possibly empty)
        list of errors: missing required parameters, an unresolvable
        hostname, or an invalid/closed port.
        """
        errors: list[SupersetError] = []
        # `extra` may be missing or empty: default to "{}" so json.loads
        # doesn't raise on None, and always bind engine_params/connect_args
        # so the http_path lookup below can't hit an unbound name.
        extra = json.loads(properties.get("extra") or "{}")  # type: ignore
        engine_params = extra.get("engine_params", {})
        connect_args = engine_params.get("connect_args", {})
        parameters = {
            **properties,
            **properties.get("parameters", {}),
        }
        if connect_args.get("http_path"):
            parameters["http_path"] = connect_args.get("http_path")
        # A parameter counts as present only when it has a truthy value.
        present = {key for key in parameters if parameters.get(key, ())}
        if missing := sorted(cls.required_parameters - present):
            errors.append(
                SupersetError(
                    message=f"One or more parameters are missing: {', '.join(missing)}",
                    error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
                    level=ErrorLevel.WARNING,
                    extra={"missing": missing},
                ),
            )
        host = parameters.get("host", None)
        if not host:
            return errors
        if not is_hostname_valid(host):  # type: ignore
            errors.append(
                SupersetError(
                    message="The hostname provided can't be resolved.",
                    error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["host"]},
                ),
            )
            # No point checking the port against an unresolvable host.
            return errors
        port = parameters.get("port", None)
        if not port:
            return errors
        try:
            port = int(port)  # type: ignore
        except (ValueError, TypeError):
            errors.append(
                SupersetError(
                    message="Port must be a valid integer.",
                    error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["port"]},
                ),
            )
        if not (isinstance(port, int) and 0 <= port < 2**16):
            errors.append(
                SupersetError(
                    message=(
                        "The port must be an integer between 0 and 65535 (inclusive)."
                    ),
                    error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["port"]},
                ),
            )
        elif not is_port_open(host, port):  # type: ignore
            errors.append(
                SupersetError(
                    message="The port is closed.",
                    error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={"invalid": ["port"]},
                ),
            )
        return errors
class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
    """Legacy Databricks connector using databricks-dbapi."""

    engine = "databricks"
    engine_name = "Databricks (legacy)"
    drivers = {"connector": "Native all-purpose driver"}
    default_driver = "connector"
    parameters_schema = DatabricksNativeSchema()
    properties_schema = DatabricksNativePropertiesSchema()
    sqlalchemy_uri_placeholder = (
        "databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
    )
    # Note: Primary metadata is in DatabricksPythonConnectorEngineSpec which
    # consolidates all Databricks connection methods. This spec exists for
    # backwards compatibility with legacy databricks-dbapi connections.
    context_key_mapping = {
        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
        "database": "database",
        "username": "username",
    }
    # The legacy form also requires a database name and the JSON `extra`
    # blob (which carries http_path via connect_args).
    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
        "database",
        "extra",
    }
    supports_dynamic_schema = True
    supports_catalog = True
    supports_dynamic_catalog = True
    supports_cross_catalog_queries = True

    @classmethod
    def build_sqlalchemy_uri(  # type: ignore
        cls, parameters: DatabricksNativeParametersType, *_
    ) -> str:
        """
        Build the `databricks+connector://` URI from the dynamic-form
        parameters; raises when encryption is requested but this spec
        defines no encryption query arguments.
        """
        query = {}
        if parameters.get("encryption"):
            if not cls.encryption_parameters:
                raise Exception(  # pylint: disable=broad-exception-raised
                    "Unable to build a URL with encryption enabled"
                )
            query.update(cls.encryption_parameters)
        return str(
            URL.create(
                f"{cls.engine}+{cls.default_driver}".rstrip("+"),
                username="token",
                password=parameters.get("access_token"),
                host=parameters["host"],
                port=parameters["port"],
                database=parameters["database"],
                query=query,
            )
        )

    @classmethod
    def get_parameters_from_uri(  # type: ignore
        cls, uri: str, *_, **__
    ) -> DatabricksNativeParametersType:
        """Inverse of build_sqlalchemy_uri: split a URI back into parameters."""
        url = make_url_safe(uri)
        # Encryption counts as enabled only if every encryption query arg
        # is present in the URI.
        encryption = all(
            item in url.query.items() for item in cls.encryption_parameters.items()
        )
        return {
            "access_token": url.password,
            "host": url.host,
            "port": url.port,
            "database": url.database,
            "encryption": encryption,
        }

    @classmethod
    def parameters_json_schema(cls) -> Any:
        """
        Return configuration parameters as OpenAPI.
        """
        if not cls.properties_schema:
            return None
        spec = APISpec(
            title="Database Parameters",
            version="1.0.0",
            openapi_version="3.0.2",
            plugins=[MarshmallowPlugin()],
        )
        spec.components.schema(cls.__name__, schema=cls.properties_schema)
        return spec.to_dict()["components"]["schemas"][cls.__name__]

    @classmethod
    def get_default_catalog(cls, database: Database) -> str:
        """
        Return the default catalog.

        It's optionally specified in `connect_args.catalog`. If not:

        The default behavior for Databricks is confusing. When Unity Catalog is not
        enabled we have (the DB engine spec hasn't been tested with it enabled):

            > SHOW CATALOGS;
            spark_catalog
            > SELECT current_catalog();
            hive_metastore

        To handle permissions correctly we use the result of `SHOW CATALOGS` when a
        single catalog is returned.
        """
        connect_args = cls.get_extra_params(database)["engine_params"]["connect_args"]
        if default_catalog := connect_args.get("catalog"):
            return default_catalog
        with database.get_sqla_engine() as engine:
            catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")}
            if len(catalogs) == 1:
                return catalogs.pop()
            return engine.execute("SELECT current_catalog()").scalar()

    @classmethod
    def get_prequeries(
        cls,
        database: Database,
        catalog: str | None = None,
        schema: str | None = None,
    ) -> list[str]:
        """Emit USE CATALOG/USE SCHEMA statements, backtick-quoting names
        that aren't already quoted."""
        prequeries = []
        if catalog:
            catalog = f"`{catalog}`" if not catalog.startswith("`") else catalog
            prequeries.append(f"USE CATALOG {catalog}")
        if schema:
            schema = f"`{schema}`" if not schema.startswith("`") else schema
            prequeries.append(f"USE SCHEMA {schema}")
        return prequeries

    @classmethod
    def get_catalog_names(
        cls,
        database: Database,
        inspector: Inspector,
    ) -> set[str]:
        """List all catalogs via SHOW CATALOGS."""
        return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
    """Recommended Databricks spec, backed by the official
    databricks-sql-connector SQLAlchemy dialect."""

    engine = "databricks"
    engine_name = "Databricks"
    default_driver = "databricks-sql-python"
    drivers = {"databricks-sql-python": "Databricks SQL Python"}
    parameters_schema = DatabricksPythonConnectorSchema()
    sqlalchemy_uri_placeholder = (
        "databricks://token:{access_token}@{host}:{port}?http_path={http_path}"
        "&catalog={default_catalog}&schema={default_schema}"
    )
    # UI metadata for the database-connection modal; also documents the
    # alternative Databricks drivers handled by the other specs in this file.
    metadata = {
        "description": (
            "Databricks is a unified analytics platform built on Apache "
            "Spark, providing data engineering, data science, and machine "
            "learning capabilities in the cloud. Use the Python Connector "
            "for SQL warehouses and clusters."
        ),
        "logo": "databricks.png",
        "homepage_url": "https://www.databricks.com/",
        "categories": [
            DatabaseCategory.CLOUD_DATA_WAREHOUSES,
            DatabaseCategory.ANALYTICAL_DATABASES,
            DatabaseCategory.HOSTED_OPEN_SOURCE,
        ],
        "pypi_packages": ["apache-superset[databricks]"],
        "install_instructions": "pip install apache-superset[databricks]",
        "connection_string": (
            "databricks://token:{access_token}@{host}:{port}"
            "?http_path={http_path}&catalog={catalog}&schema={schema}"
        ),
        "parameters": {
            "access_token": "Personal access token from Settings > User Settings",
            "host": "Server hostname from cluster JDBC/ODBC settings",
            "port": "Port (default 443)",
            "http_path": "HTTP path from cluster JDBC/ODBC settings",
        },
        "drivers": [
            {
                "name": "Databricks Python Connector (Recommended)",
                "pypi_package": "databricks-sql-connector",
                "connection_string": (
                    "databricks://token:{access_token}@{host}:{port}"
                    "?http_path={http_path}&catalog={catalog}&schema={schema}"
                ),
                "is_recommended": True,
                "notes": (
                    "Official Databricks connector. Best for SQL warehouses "
                    "and clusters."
                ),
            },
            {
                "name": "Hive Connector (Interactive Clusters)",
                "pypi_package": "databricks-dbapi[sqlalchemy]",
                "connection_string": (
                    "databricks+pyhive://token:{access_token}@{host}:{port}/{database}"
                ),
                "is_recommended": False,
                "notes": (
                    "For Interactive Clusters. Requires http_path in engine parameters."
                ),
            },
            {
                "name": "ODBC (SQL Endpoints)",
                "pypi_package": "pyodbc",
                "connection_string": (
                    "databricks+pyodbc://token:{access_token}@{host}:{port}/{database}"
                ),
                "is_recommended": False,
                "notes": "Requires ODBC driver. For serverless SQL warehouses.",
            },
            {
                "name": "databricks-dbapi (Legacy)",
                "pypi_package": "databricks-dbapi[sqlalchemy]",
                "connection_string": (
                    "databricks+connector://token:{access_token}@{host}:{port}/{database}"
                ),
                "is_recommended": False,
                "notes": "Legacy connector. Use Python Connector for new deployments.",
            },
        ],
    }
    # Maps the extra form fields onto their error-context/query-arg names.
    context_key_mapping = {
        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
        "default_catalog": "catalog",
        "default_schema": "schema",
        "http_path_field": "http_path",
    }
    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | {
        "default_catalog",
        "default_schema",
        "http_path_field",
    }
    supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True

    @classmethod
    def build_sqlalchemy_uri(  # type: ignore
        cls, parameters: DatabricksPythonConnectorParametersType, *_
    ) -> str:
        """Build the `databricks://` URI; optional values (http_path,
        catalog, schema, encryption) become query-string arguments."""
        query = {}
        if http_path := parameters.get("http_path_field"):
            query["http_path"] = http_path
        if catalog := parameters.get("default_catalog"):
            query["catalog"] = catalog
        if schema := parameters.get("default_schema"):
            query["schema"] = schema
        if parameters.get("encryption"):
            query.update(cls.encryption_parameters)
        return str(
            URL.create(
                cls.engine,
                username="token",
                password=parameters.get("access_token"),
                host=parameters["host"],
                port=parameters["port"],
                query=query,
            )
        )

    @classmethod
    def get_parameters_from_uri(  # type: ignore
        cls, uri: str, *_: Any, **__: Any
    ) -> DatabricksPythonConnectorParametersType:
        """
        Inverse of build_sqlalchemy_uri: split a URI back into parameters.

        NOTE(review): raises KeyError when http_path/catalog/schema are
        absent from the query string — presumably URIs built by this spec
        always carry them; confirm against callers.
        """
        url = make_url_safe(uri)
        # Drop the encryption args so they aren't reported as parameters.
        query = {
            key: value
            for (key, value) in url.query.items()
            if (key, value) not in cls.encryption_parameters.items()
        }
        encryption = all(
            item in url.query.items() for item in cls.encryption_parameters.items()
        )
        return {
            "access_token": url.password,
            "host": url.host,
            "port": url.port,
            "http_path_field": query["http_path"],
            "default_catalog": query["catalog"],
            "default_schema": query["schema"],
            "encryption": encryption,
        }

    @classmethod
    def get_default_catalog(
        cls,
        database: Database,
    ) -> str | None:
        """Return the `catalog` URI query argument, if present."""
        return database.url_object.query.get("catalog")

    @classmethod
    def get_catalog_names(
        cls,
        database: Database,
        inspector: Inspector,
    ) -> set[str]:
        """List all catalogs via SHOW CATALOGS."""
        return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}

    @classmethod
    def adjust_engine_params(
        cls,
        uri: URL,
        connect_args: dict[str, Any],
        catalog: str | None = None,
        schema: str | None = None,
    ) -> tuple[URL, dict[str, Any]]:
        """Point the URI at the requested catalog/schema via query args;
        connect_args pass through unchanged."""
        if catalog:
            uri = uri.update_query_dict({"catalog": catalog})
        if schema:
            uri = uri.update_query_dict({"schema": schema})
        return uri, connect_args
# Apply the single-quote-escaping patch at import time (see monkeypatch_dialect).
# TODO: remove once we've upgraded to SQLAlchemy>=2.0 and databricks-sql-python>=3.x
monkeypatch_dialect()