# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
import re
from re import Pattern
from typing import Any, Optional, Union
from urllib import parse

from flask_babel import gettext as __
from sqlalchemy import Float, Integer, Numeric, types
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine

from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.utils.core import GenericDataType

# Regular expressions to catch custom errors.
#
# The named groups are required: the messages in
# ``StarRocksEngineSpec.custom_errors`` interpolate ``%(username)s`` and
# ``%(database)s`` from the match.
CONNECTION_ACCESS_DENIED_REGEX = re.compile(
    "Access denied for user '(?P<username>.*?)'"
)
CONNECTION_UNKNOWN_DATABASE_REGEX = re.compile("Unknown database '(?P<database>.*?)'")

logger = logging.getLogger(__name__)


# SQLAlchemy type stand-ins. Each one is paired with a type-name pattern in
# ``StarRocksEngineSpec.column_type_mappings`` below.


class TINYINT(Integer):
    __visit_name__ = "TINYINT"


class LARGEINT(Integer):
    __visit_name__ = "LARGEINT"


class DOUBLE(Float):
    __visit_name__ = "DOUBLE"


class HLL(Numeric):
    __visit_name__ = "HLL"


class BITMAP(Numeric):
    __visit_name__ = "BITMAP"


class PERCENTILE(Numeric):
    __visit_name__ = "PERCENTILE"


class ARRAY(TypeEngine):
    __visit_name__ = "ARRAY"

    @property
    def python_type(self) -> Optional[type[list[Any]]]:
        """Return ``list`` as the Python type for this column type."""
        return list


class MAP(TypeEngine):
    __visit_name__ = "MAP"

    @property
    def python_type(self) -> Optional[type[dict[Any, Any]]]:
        """Return ``dict`` as the Python type for this column type."""
        return dict


class STRUCT(TypeEngine):
    __visit_name__ = "STRUCT"

    @property
    def python_type(self) -> Optional[type[Any]]:
        """Return ``None``: no Python type is declared for this column type."""
        return None


class StarRocksEngineSpec(MySQLEngineSpec):
    """
    Engine spec for StarRocks, built on top of the MySQL engine spec.

    The database component of the SQLAlchemy URI has the form
    ``catalog.schema``; the helpers below read and rewrite that component.
    """

    engine = "starrocks"
    engine_name = "StarRocks"

    default_driver = "starrocks"
    sqlalchemy_uri_placeholder = (
        "starrocks://user:password@host:port/catalog.db[?key=value&key=value...]"
    )

    # (type-name pattern, SQLAlchemy type instance, generic data type)
    column_type_mappings = (  # type: ignore
        (
            re.compile(r"^tinyint", re.IGNORECASE),
            TINYINT(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^largeint", re.IGNORECASE),
            LARGEINT(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^decimal.*", re.IGNORECASE),
            types.DECIMAL(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^double", re.IGNORECASE),
            DOUBLE(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
            types.VARCHAR(),
            GenericDataType.STRING,
        ),
        (
            re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
            types.CHAR(),
            GenericDataType.STRING,
        ),
        (
            re.compile(r"^json", re.IGNORECASE),
            types.JSON(),
            GenericDataType.STRING,
        ),
        (
            re.compile(r"^binary.*", re.IGNORECASE),
            types.String(),
            GenericDataType.STRING,
        ),
        (
            re.compile(r"^percentile", re.IGNORECASE),
            PERCENTILE(),
            GenericDataType.STRING,
        ),
        (re.compile(r"^hll", re.IGNORECASE), HLL(), GenericDataType.STRING),
        (re.compile(r"^bitmap", re.IGNORECASE), BITMAP(), GenericDataType.STRING),
        (re.compile(r"^array.*", re.IGNORECASE), ARRAY(), GenericDataType.STRING),
        (re.compile(r"^map.*", re.IGNORECASE), MAP(), GenericDataType.STRING),
        (re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING),
    )

    # error-message pattern -> (user-facing message, error type, extra payload)
    custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
        CONNECTION_ACCESS_DENIED_REGEX: (
            __('Either the username "%(username)s" or the password is incorrect.'),
            SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
            {"invalid": ["username", "password"]},
        ),
        CONNECTION_UNKNOWN_DATABASE_REGEX: (
            __('Unable to connect to database "%(database)s".'),
            SupersetErrorType.CONNECTION_UNKNOWN_DATABASE_ERROR,
            {"invalid": ["database"]},
        ),
    }

    @classmethod
    def adjust_engine_params(
        cls,
        uri: URL,
        connect_args: dict[str, Any],
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
    ) -> tuple[URL, dict[str, Any]]:
        """
        Return the URI rewritten to point at ``schema``.

        The URI is only changed when both ``schema`` and the URI's database
        component are non-empty. The schema is percent-encoded and placed
        after the catalog: the existing catalog (text before the first ``.``)
        is kept when the database has the ``catalog.schema`` form, otherwise
        ``default_catalog`` is used.

        ``catalog`` is accepted but not used by this implementation, and
        ``connect_args`` is returned unchanged.
        """
        database = uri.database
        if schema and database:
            # Encode every reserved character, including "/" and ".".
            schema = parse.quote(schema, safe="")
            if "." in database:
                database = database.split(".")[0] + "." + schema
            else:
                database = "default_catalog." + schema
            uri = uri.set(database=database)

        return uri, connect_args

    @classmethod
    def get_schema_from_engine_params(
        cls,
        sqlalchemy_uri: URL,
        connect_args: dict[str, Any],
    ) -> Optional[str]:
        """
        Return the configured schema.

        For StarRocks the SQLAlchemy URI looks like this:

            starrocks://localhost:9030/catalog.schema

        ``None`` is returned when the URI has no database component or the
        database component does not contain a ``.`` separator.
        """
        # ``URL.database`` is None when the URI carries no database; treat
        # that the same as "no schema configured" instead of raising.
        database = (sqlalchemy_uri.database or "").strip("/")

        if "." not in database:
            return None

        return parse.unquote(database.split(".")[1])

    @classmethod
    def impersonate_user(
        cls,
        database: Database,
        username: Optional[str],
        user_token: Optional[str],
        url: URL,
        engine_kwargs: dict[str, Any],
    ) -> tuple[URL, dict[str, Any]]:
        """
        Impersonate the given user.

        User impersonation is actually achieved via `get_prequeries`, so this method
        needs to ensure that the username is not added to the URL when user
        impersonation is enabled (the behavior of the base class).
        """
        return url, engine_kwargs

    @classmethod
    def get_prequeries(
        cls,
        database: Database,
        catalog: Union[str, None] = None,
        schema: Union[str, None] = None,
    ) -> list[str]:
        """
        Return pre-session queries.

        When ``database.impersonate_user`` is set and an effective user can
        be resolved, a single ``EXECUTE AS "<user>" WITH NO REVERT;``
        statement is returned so that the session runs as that user.
        Otherwise an empty list is returned.

        ``catalog`` and ``schema`` are accepted but not used here.
        """
        if database.impersonate_user:
            username = database.get_effective_user(database.url_object)

            if username:
                # NOTE(review): the username is interpolated into the
                # statement unescaped -- confirm that usernames cannot
                # contain a double quote.
                return [f'EXECUTE AS "{username}" WITH NO REVERT;']

        return []