# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import annotations import logging import re from datetime import datetime from typing import Any, cast, TYPE_CHECKING from urllib import parse from flask import current_app as app from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.validate import Range from sqlalchemy import types from sqlalchemy.engine.url import URL from urllib3.exceptions import NewConnectionError from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import ( BaseEngineSpec, BasicParametersMixin, BasicParametersType, BasicPropertiesType, ) from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.extensions import cache_manager from superset.utils.core import GenericDataType from superset.utils.hashing import hash_from_str from superset.utils.network import is_hostname_valid, is_port_open if TYPE_CHECKING: from superset.models.core import Database logger = logging.getLogger(__name__) class ClickHouseBaseEngineSpec(BaseEngineSpec): """Shared engine spec for ClickHouse.""" time_groupby_inline = True supports_multivalues_insert = True _time_grain_expressions = { None: "{col}", "PT1M": "toStartOfMinute(toDateTime({col}))", "PT5M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 300)*300)", "PT10M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 600)*600)", "PT15M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 900)*900)", "PT30M": "toDateTime(intDiv(toUInt32(toDateTime({col})), 1800)*1800)", "PT1H": "toStartOfHour(toDateTime({col}))", "P1D": "toStartOfDay(toDateTime({col}))", "P1W": "toMonday(toDateTime({col}))", "P1M": "toStartOfMonth(toDateTime({col}))", "P3M": "toStartOfQuarter(toDateTime({col}))", "P1Y": "toStartOfYear(toDateTime({col}))", } column_type_mappings = ( ( re.compile(r".*Enum.*", re.IGNORECASE), types.String(), GenericDataType.STRING, ), ( re.compile(r".*Array.*", re.IGNORECASE), types.String(), GenericDataType.STRING, ), ( re.compile(r".*UUID.*", re.IGNORECASE), types.String(), GenericDataType.STRING, ), ( re.compile(r".*Bool.*", re.IGNORECASE), types.Boolean(), GenericDataType.BOOLEAN, ), ( re.compile(r".*String.*", re.IGNORECASE), types.String(), GenericDataType.STRING, ), ( re.compile(r".*Int\d+.*", re.IGNORECASE), types.INTEGER(), GenericDataType.NUMERIC, ), ( re.compile(r".*Decimal.*", re.IGNORECASE), types.DECIMAL(), GenericDataType.NUMERIC, ), ( re.compile(r".*DateTime.*", re.IGNORECASE), types.DateTime(), GenericDataType.TEMPORAL, ), ( re.compile(r".*Date.*", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL, ), ) @classmethod def epoch_to_dttm(cls) -> str: return "{col}" @classmethod def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): return f"toDate('{dttm.date().isoformat()}')" if isinstance(sqla_type, types.DateTime): return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')""" return None class ClickHouseEngineSpec(ClickHouseBaseEngineSpec): """Engine spec for clickhouse_sqlalchemy connector""" engine = "clickhouse" engine_name = "ClickHouse" _show_functions_column = "name" supports_file_upload = False @classmethod def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return {NewConnectionError: SupersetDBAPIDatabaseError} @classmethod def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: new_exception = cls.get_dbapi_exception_mapping().get(type(exception)) if new_exception == SupersetDBAPIDatabaseError: return SupersetDBAPIDatabaseError("Connection failed") if not new_exception: return exception return new_exception(str(exception)) @classmethod @cache_manager.cache.memoize() def get_function_names(cls, database: Database) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. :param database: The database to get functions for :return: A list of function names usable in the database """ system_functions_sql = "SELECT name FROM system.functions" try: df = database.get_df(system_functions_sql) if cls._show_functions_column in df: return df[cls._show_functions_column].tolist() columns = df.columns.values.tolist() logger.error( "Payload from `%s` has the incorrect format. " "Expected column `%s`, found: %s.", system_functions_sql, cls._show_functions_column, ", ".join(columns), exc_info=True, ) # if the results have a single column, use that if len(columns) == 1: return df[columns[0]].tolist() except Exception as ex: # pylint: disable=broad-except logger.error( "Query `%s` fire error %s. ", system_functions_sql, str(ex), exc_info=True, ) return [] # otherwise, return no function names to prevent errors return [] class ClickHouseParametersSchema(Schema): username = fields.String(allow_none=True, metadata={"description": __("Username")}) password = fields.String(allow_none=True, metadata={"description": __("Password")}) host = fields.String( required=True, metadata={"description": __("Hostname or IP address")} ) port = fields.Integer( allow_none=True, metadata={"description": __("Database port")}, validate=Range(min=0, max=65535), ) database = fields.String( allow_none=True, metadata={"description": __("Database name")} ) encryption = fields.Boolean( dump_default=True, metadata={"description": __("Use an encrypted connection to the database")}, ) query = fields.Dict( keys=fields.Str(), values=fields.Raw(), metadata={"description": __("Additional parameters")}, ) ssh = fields.Boolean( required=False, metadata={"description": __("Use an ssh tunnel connection to the database")}, ) try: from clickhouse_connect.common import set_setting from clickhouse_connect.datatypes.format import set_default_formats # override default formats for compatibility set_default_formats( "FixedString", "string", "IPv*", "string", "UInt64", "signed", "UUID", "string", "*Int256", "string", "*Int128", "string", ) set_setting( "product_name", f"superset/{app.config.get('VERSION_STRING', 'dev')}", ) except ImportError: # ClickHouse Connect not installed, do nothing pass class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec): """Engine spec for clickhouse-connect connector""" engine = "clickhousedb" engine_name = "ClickHouse Connect (Superset)" default_driver = "connect" _function_names: list[str] = [] sqlalchemy_uri_placeholder = ( "clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]" ) parameters_schema = ClickHouseParametersSchema() encryption_parameters = {"secure": "true"} supports_dynamic_schema = True @classmethod def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return {} @classmethod def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: new_exception = cls.get_dbapi_exception_mapping().get(type(exception)) if new_exception == SupersetDBAPIDatabaseError: return SupersetDBAPIDatabaseError("Connection failed") if not new_exception: return exception return new_exception(str(exception)) @classmethod def get_function_names(cls, database: Database) -> list[str]: # pylint: disable=import-outside-toplevel, import-error from clickhouse_connect.driver.exceptions import ClickHouseError if cls._function_names: return cls._function_names try: names = database.get_df( "SELECT name FROM system.functions UNION ALL " # noqa: S608 + "SELECT name FROM system.table_functions LIMIT 10000" )["name"].tolist() cls._function_names = names return names except ClickHouseError: logger.exception("Error retrieving system.functions") return [] @classmethod def get_datatype(cls, type_code: str) -> str: # keep it lowercase, as ClickHouse types aren't typical SHOUTCASE ANSI SQL return type_code @classmethod def build_sqlalchemy_uri( cls, parameters: BasicParametersType, encrypted_extra: dict[str, str] | None = None, ) -> str: url_params = parameters.copy() if url_params.get("encryption"): query = parameters.get("query", {}).copy() query.update(cls.encryption_parameters) url_params["query"] = query if not url_params.get("database"): url_params["database"] = "__default__" return str( URL.create( f"{cls.engine}+{cls.default_driver}", username=url_params.get("username"), password=url_params.get("password"), host=url_params.get("host"), port=url_params.get("port"), database=url_params.get("database"), query=url_params.get("query"), ) ) @classmethod def get_parameters_from_uri( cls, uri: str, encrypted_extra: dict[str, Any] | None = None ) -> BasicParametersType: url = make_url_safe(uri) query = dict(url.query) if "secure" in query: encryption = query.get("secure") == "true" query.pop("secure") else: encryption = False return BasicParametersType( username=url.username, password=url.password, host=url.host, port=url.port, database="" if url.database == "__default__" else cast(str, url.database), query=query, encryption=encryption, ) @classmethod def validate_parameters( cls, properties: BasicPropertiesType ) -> list[SupersetError]: # pylint: disable=import-outside-toplevel, import-error from clickhouse_connect.driver import default_port parameters = properties.get("parameters", {}) host = parameters.get("host", None) if not host: return [ SupersetError( "Hostname is required", SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, ErrorLevel.WARNING, {"missing": ["host"]}, ) ] if not is_hostname_valid(host): return [ SupersetError( "The hostname provided can't be resolved.", SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR, ErrorLevel.ERROR, {"invalid": ["host"]}, ) ] port = parameters.get("port") if port is None: port = default_port("http", parameters.get("encryption", False)) try: port = int(port) except (ValueError, TypeError): port = -1 if port <= 0 or port >= 65535: return [ SupersetError( "Port must be a valid integer between 0 and 65535 (inclusive).", SupersetErrorType.CONNECTION_INVALID_PORT_ERROR, ErrorLevel.ERROR, {"invalid": ["port"]}, ) ] if not is_port_open(host, port): return [ SupersetError( "The port is closed.", SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR, ErrorLevel.ERROR, {"invalid": ["port"]}, ) ] return [] @staticmethod def _mutate_label(label: str) -> str: """ Suffix with the first six characters from the md5 of the label to avoid collisions with original column names :param label: Expected expression label :return: Conditionally mutated label """ return f"{label}_{hash_from_str(label)[:6]}" @classmethod def adjust_engine_params( cls, uri: URL, connect_args: dict[str, Any], catalog: str | None = None, schema: str | None = None, ) -> tuple[URL, dict[str, Any]]: if schema: uri = uri.set(database=parse.quote(schema, safe="")) return uri, connect_args