diff --git a/docs/static/img/databases/datastore.png b/docs/static/img/databases/datastore.png new file mode 100644 index 00000000000..a5fd9dab4a0 Binary files /dev/null and b/docs/static/img/databases/datastore.png differ diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx index 57c81cdae66..cd21f730a9a 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx @@ -43,6 +43,7 @@ enum CredentialInfoOptions { export const encryptedCredentialsMap = { gsheets: 'service_account_info', bigquery: 'credentials_info', + datastore: 'credentials_info', }; export const EncryptedField = ({ diff --git a/superset/commands/database/validate.py b/superset/commands/database/validate.py index 29c9497140b..4d9952ad89a 100644 --- a/superset/commands/database/validate.py +++ b/superset/commands/database/validate.py @@ -34,7 +34,7 @@ from superset.extensions import event_logger from superset.models.core import Database from superset.utils import json -BYPASS_VALIDATION_ENGINES = {"bigquery", "snowflake"} +BYPASS_VALIDATION_ENGINES = {"bigquery", "datastore", "snowflake"} class ValidateDatabaseParametersCommand(BaseCommand): diff --git a/superset/db_engine_specs/datastore.py b/superset/db_engine_specs/datastore.py new file mode 100644 index 00000000000..bc11bf9bca3 --- /dev/null +++ b/superset/db_engine_specs/datastore.py @@ -0,0 +1,610 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +import re +from datetime import datetime +from re import Pattern +from typing import Any, TYPE_CHECKING, TypedDict +from urllib import parse + +from apispec import APISpec +from apispec.ext.marshmallow import MarshmallowPlugin +from flask_babel import gettext as __ +from marshmallow import fields, Schema +from marshmallow.exceptions import ValidationError +from sqlalchemy import column, types +from sqlalchemy.engine.base import Engine +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.engine.url import URL +from sqlalchemy.sql import sqltypes + +from superset.constants import TimeGrain +from superset.databases.schemas import encrypted_field_properties, EncryptedString +from superset.databases.utils import make_url_safe +from superset.db_engine_specs.base import ( + BaseEngineSpec, + BasicPropertiesType, + DatabaseCategory, +) +from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError +from superset.errors import SupersetError, SupersetErrorType +from superset.exceptions import SupersetException +from superset.models.core import Database +from superset.sql.parse import LimitMethod, Table +from superset.superset_typing import ResultSetColumnType +from superset.utils import json +from superset.utils.hashing import hash_from_str + +logger = 
logging.getLogger(__name__)
+
+try:
+    import google.auth
+    from google.cloud import datastore
+    from google.oauth2 import service_account
+
+    dependencies_installed = True
+except ImportError:
+    dependencies_installed = False
+
+if TYPE_CHECKING:
+    from superset.models.sql_lab import Query  # pragma: no cover
+
+CONNECTION_DATABASE_PERMISSIONS_REGEX = re.compile(
+    "Access Denied: Project (?P<project_name>.+?): User does not have "
+    + "datastore.databases.create permission in project (?P<project>.+?)"
+)
+
+TABLE_DOES_NOT_EXIST_REGEX = re.compile(
+    'Table name "(?P<table>.*?)" missing dataset while no default '
+    "dataset is set in the request"
+)
+
+COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
+    r"Unrecognized name: (?P<column>.*?) at \[(?P<location>.+?)\]"
+)
+
+SCHEMA_DOES_NOT_EXIST_REGEX = re.compile(
+    r"datastore error: 404 Not found: Dataset (?P<dataset>.*?):"
+    r"(?P<schema>.*?) was not found in location"
+)
+
+SYNTAX_ERROR_REGEX = re.compile(
+    'Syntax error: Expected end of input but got identifier "(?P<syntax_error>.+?)"'
+)
+
+ma_plugin = MarshmallowPlugin()
+
+
+class DatastoreParametersSchema(Schema):
+    credentials_info = EncryptedString(
+        required=False,
+        metadata={"description": "Contents of Datastore JSON credentials."},
+    )
+    query = fields.Dict(required=False)
+
+
+class DatastoreParametersType(TypedDict):
+    credentials_info: dict[str, Any]
+    query: dict[str, Any]
+
+
+class DatastoreEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-methods
+    """Engine spec for Google's Datastore
+
+    As contributed by @hychang.1997.tw"""
+
+    engine = "datastore"
+    engine_name = "Google Datastore"
+    max_column_name_length = 128
+    disable_ssh_tunneling = True
+
+    parameters_schema = DatastoreParametersSchema()
+    default_driver = "datastore"
+    sqlalchemy_uri_placeholder = "datastore://{project_id}/?database={database_id}"
+
+    # Use FETCH_MANY to prevent Superset from injecting LIMIT via sqlglot AST
+    # manipulation.
GQL queries should not be modified by sqlglot since it + # uses BigQuery dialect which transforms GQL-incompatible syntax. + limit_method = LimitMethod.FETCH_MANY + + metadata = { + "description": ( + "Google Cloud Datastore is a highly scalable NoSQL database " + "for your applications." + ), + "logo": "datastore.png", + "homepage_url": "https://cloud.google.com/datastore/", + "categories": [ + DatabaseCategory.CLOUD_GCP, + DatabaseCategory.SEARCH_NOSQL, + DatabaseCategory.PROPRIETARY, + ], + "pypi_packages": ["python-datastore-sqlalchemy"], + "connection_string": "datastore://{project_id}/?database={database_id}", + "authentication_methods": [ + { + "name": "Service Account JSON", + "description": ( + "Upload service account credentials JSON or paste in Secure Extra" + ), + "secure_extra": { + "credentials_info": { + "type": "service_account", + "project_id": "...", + "private_key_id": "...", + "private_key": "...", + "client_email": "...", + "client_id": "...", + "auth_uri": "...", + "token_uri": "...", + } + }, + }, + ], + "notes": ( + "Create a Service Account via GCP console with access to " + "datastore datasets." + ), + "docs_url": "https://github.com/splasky/Python-datastore-sqlalchemy", + } + + # Datastore doesn't maintain context when running multiple statements in the + # same cursor, so we need to run all statements at once + run_multiple_statements_as_one = True + + allows_hidden_cc_in_orderby = True + + supports_dynamic_schema = True + supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True + + # when editing the database, mask this field in `encrypted_extra` + # pylint: disable=invalid-name + encrypted_extra_sensitive_fields = {"$.credentials_info.private_key"} + + """ + https://www.python.org/dev/peps/pep-0249/#arraysize + raw_connections bypass the sqlalchemy-datastore query execution context and deal + with raw dbapi connection directly. + If this value is not set, the default value is set to 1. 
+ """ + arraysize = 5000 + + _date_trunc_functions = { + "DATE": "DATE_TRUNC", + "DATETIME": "DATETIME_TRUNC", + "TIME": "TIME_TRUNC", + "TIMESTAMP": "TIMESTAMP_TRUNC", + } + + _time_grain_expressions = { + None: "{col}", + TimeGrain.SECOND: "CAST(TIMESTAMP_SECONDS(" + "UNIX_SECONDS(CAST({col} AS TIMESTAMP))" + ") AS {type})", + TimeGrain.MINUTE: "CAST(TIMESTAMP_SECONDS(" + "60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 60)" + ") AS {type})", + TimeGrain.FIVE_MINUTES: "CAST(TIMESTAMP_SECONDS(" + "5*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 5*60)" + ") AS {type})", + TimeGrain.TEN_MINUTES: "CAST(TIMESTAMP_SECONDS(" + "10*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 10*60)" + ") AS {type})", + TimeGrain.FIFTEEN_MINUTES: "CAST(TIMESTAMP_SECONDS(" + "15*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 15*60)" + ") AS {type})", + TimeGrain.THIRTY_MINUTES: "CAST(TIMESTAMP_SECONDS(" + "30*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 30*60)" + ") AS {type})", + TimeGrain.HOUR: "{func}({col}, HOUR)", + TimeGrain.DAY: "{func}({col}, DAY)", + TimeGrain.WEEK: "{func}({col}, WEEK)", + TimeGrain.WEEK_STARTING_MONDAY: "{func}({col}, ISOWEEK)", + TimeGrain.MONTH: "{func}({col}, MONTH)", + TimeGrain.QUARTER: "{func}({col}, QUARTER)", + TimeGrain.YEAR: "{func}({col}, YEAR)", + } + + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { + CONNECTION_DATABASE_PERMISSIONS_REGEX: ( + __( + "Unable to connect. Verify that the following roles are set " + 'on the service account: "Cloud Datastore Viewer", ' + '"Cloud Datastore User", "Cloud Datastore Creator"' + ), + SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR, + {}, + ), + TABLE_DOES_NOT_EXIST_REGEX: ( + __( + 'The table "%(table)s" does not exist. 
' + "A valid table must be used to run this query.", + ), + SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, + {}, + ), + COLUMN_DOES_NOT_EXIST_REGEX: ( + __('We can\'t seem to resolve column "%(column)s" at line %(location)s.'), + SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR, + {}, + ), + SCHEMA_DOES_NOT_EXIST_REGEX: ( + __( + 'The schema "%(schema)s" does not exist. ' + "A valid schema must be used to run this query." + ), + SupersetErrorType.SCHEMA_DOES_NOT_EXIST_ERROR, + {}, + ), + SYNTAX_ERROR_REGEX: ( + __( + "Please check your query for syntax errors at or near " + '"%(syntax_error)s". Then, try running your query again.' + ), + SupersetErrorType.SYNTAX_ERROR, + {}, + ), + } + + @staticmethod + def _mutate_label(label: str) -> str: + """ + Datastore field_name should start with a letter or underscore and contain + only alphanumeric characters. Labels that start with a number are prefixed + with an underscore. Any unsupported characters are replaced with underscores + and an md5 hash is added to the end of the label to avoid possible + collisions. + + :param label: Expected expression label + :return: Conditionally mutated label + """ + label_hashed = "_" + hash_from_str(label) + + # if label starts with number, add underscore as first character + label_mutated = "_" + label if re.match(r"^\d", label) else label + + # replace non-alphanumeric characters with underscores + label_mutated = re.sub(r"[^\w]+", "_", label_mutated) + if label_mutated != label: + # add first 5 chars from md5 hash to label to avoid possible collisions + label_mutated += label_hashed[:6] + + return label_mutated + + @classmethod + def _truncate_label(cls, label: str) -> str: + """Datastore requires column names start with either a letter or + underscore. To make sure this is always the case, an underscore is prefixed + to the md5 hash of the original label. 
+ + :param label: expected expression label + :return: truncated label + """ + return "_" + hash_from_str(label) + + @classmethod + def convert_dttm( + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: + sqla_type = cls.get_sqla_column_type(target_type) + if isinstance(sqla_type, types.Date): + return f"CAST('{dttm.date().isoformat()}' AS DATE)" + if isinstance(sqla_type, types.TIMESTAMP): + return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)""" + if isinstance(sqla_type, types.DateTime): + return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)""" + if isinstance(sqla_type, types.Time): + return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)""" + return None + + @classmethod + def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: + data = super().fetch_data(cursor, limit) + # Support google.cloud.datastore Row type which has a values() method + if data and hasattr(data[0], "values"): + data = [r.values() for r in data] # type: ignore + return data + + @classmethod + def _get_client(cls, engine: Engine, database: Database) -> datastore.Client: + """ + Return the Datastore client associated with an engine. + """ + if not dependencies_installed: + raise SupersetException( + "Could not import libraries needed to connect to Datastore." + ) + + database_id = engine.url.query.get("database") + + if credentials_info := engine.dialect.credentials_info: + credentials = service_account.Credentials.from_service_account_info( + credentials_info + ) + return datastore.Client(credentials=credentials, database=database_id) + + try: + credentials = google.auth.default()[0] + return datastore.Client(credentials=credentials, database=database_id) + except google.auth.exceptions.DefaultCredentialsError as ex: + raise SupersetDBAPIConnectionError( + "The database credentials could not be found." 
+ ) from ex + + @classmethod + def get_default_catalog(cls, database: Database) -> str: + """ + Get the default catalog. + """ + url = database.url_object + + # The SQLAlchemy driver accepts both `datastore://project` (where the project is + # technically a host) and `datastore:///project` (where it's a database). But + # both can be missing, and the project is inferred from the authentication + # credentials. + if project := url.host or url.database: + return project + + with database.get_sqla_engine() as engine: + client = cls._get_client(engine, database) + return client.project + + @classmethod + def get_catalog_names( + cls, + database: Database, + inspector: Inspector, + ) -> set[str]: + """ + Get all catalogs. + + In Datastore, a catalog is called a "project". + """ + return super().get_catalog_names(database, inspector) + + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: + if catalog: + uri = uri.set(host=catalog, database="") + + return uri, connect_args + + @classmethod + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: + return False + + @classmethod + def build_sqlalchemy_uri( + cls, + parameters: DatastoreParametersType, + encrypted_extra: dict[str, Any] | None = None, + ) -> str: + query = parameters.get("query", {}) + query_params = parse.urlencode(query) + + if not encrypted_extra: + raise ValidationError("Missing service credentials") + + credentials_info = encrypted_extra.get("credentials_info", {}) + if isinstance(credentials_info, str): + credentials_info = json.loads(credentials_info) + + if project_id := credentials_info.get("project_id"): + return f"{cls.default_driver}://{project_id}/?{query_params}" + + raise ValidationError("Invalid service credentials") + + @classmethod + def get_parameters_from_uri( + cls, + uri: str, + encrypted_extra: dict[str, Any] | None = None, + ) -> Any: + value 
= make_url_safe(uri) + + # Building parameters from encrypted_extra and uri + if encrypted_extra: + # ``value.query`` needs to be explicitly converted into a dict (from an + # ``immutabledict``) so that it can be JSON serialized + return {**encrypted_extra, "query": dict(value.query)} + + raise ValidationError("Invalid service credentials") + + @classmethod + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: + # pylint: disable=import-outside-toplevel + from google.auth.exceptions import DefaultCredentialsError + + return {DefaultCredentialsError: SupersetDBAPIConnectionError} + + @classmethod + def validate_parameters( + cls, + properties: BasicPropertiesType, # pylint: disable=unused-argument + ) -> list[SupersetError]: + return [] + + @classmethod + def parameters_json_schema(cls) -> Any: + """ + Return configuration parameters as OpenAPI. + """ + if not cls.parameters_schema: + return None + + spec = APISpec( + title="Database Parameters", + version="1.0.0", + openapi_version="3.0.0", + plugins=[ma_plugin], + ) + + ma_plugin.init_spec(spec) + ma_plugin.converter.add_attribute_function(encrypted_field_properties) + spec.components.schema(cls.__name__, schema=cls.parameters_schema) + return spec.to_dict()["components"]["schemas"][cls.__name__] + + @classmethod + def select_star( # pylint: disable=too-many-arguments + cls, + database: Database, + table: Table, + dialect: Dialect, + limit: int = 100, + show_cols: bool = False, + indent: bool = True, + latest_partition: bool = True, + cols: list[ResultSetColumnType] | None = None, + ) -> str: + """ + Remove array structures from ``SELECT *``. + + Datastore supports structures and arrays of structures. When loading + metadata for a table, each key in the struct is displayed as a separate + pseudo-column. When generating the ``SELECT *`` statement we want to + remove any keys from structs inside an array, since selecting them + results in an error. 
+
+        This method removes any array pseudo-columns.
+        """
+        if cols:
+            array_prefixes = {
+                col["column_name"]
+                for col in cols
+                if isinstance(col["type"], sqltypes.ARRAY)
+            }
+            cols = [
+                col
+                for col in cols
+                if "." not in col["column_name"]
+                or col["column_name"].split(".")[0] not in array_prefixes
+            ]
+
+        return super().select_star(
+            database,
+            table,
+            dialect,
+            limit,
+            show_cols,
+            indent,
+            latest_partition,
+            cols,
+        )
+
+    @classmethod
+    def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
+        """
+        Label columns using their fully qualified name.
+
+        Datastore supports columns of type `struct`, which are basically dictionaries.
+        When loading metadata for a table with struct columns, each key in the struct
+        is displayed as a separate pseudo-column, eg:
+
+            author STRUCT<name STRING, email STRING>
+
+        Will be shown as 3 columns:
+
+            - author
+            - author.name
+            - author.email
+
+        If we select those fields:
+
+            SELECT `author`, `author`.`name`, `author`.`email` FROM table
+
+        The resulting columns will be called "author", "name", and "email". This may
+        result in a clash with other columns. To prevent that, we explicitly label
+        the columns using their fully qualified name, so we end up with "author",
+        "author__name" and "author__email", respectively.
+        """
+        return [
+            column(c["column_name"]).label(c["column_name"].replace(".", "__"))
+            for c in cols
+        ]
+
+    @classmethod
+    def execute_with_cursor(
+        cls,
+        cursor: Any,
+        sql: str,
+        query: Query,
+    ) -> None:
+        """Execute query and capture any warnings from the cursor.
+
+        The Datastore DBAPI cursor collects warnings when a query falls
+        back to fetching all entities client-side (SELECT * mode) due to
+        missing indexes. These warnings are stored in the query's
+        extra_json so they can be surfaced to the user in the UI.
+ """ + super().execute_with_cursor(cursor, sql, query) + if hasattr(cursor, "warnings") and cursor.warnings: + query.set_extra_json_key("warnings", cursor.warnings) + + @classmethod + def parse_error_exception(cls, exception: Exception) -> Exception: + try: + return type(exception)(str(exception).splitlines()[0].strip()) + except Exception: # pylint: disable=broad-except + # If for some reason we get an exception, for example, no new line + # We will return the original exception + return exception + + @classmethod + def get_function_names( # pylint: disable=unused-argument + cls, + database: Database, + ) -> list[str]: + """ + Get a list of function names that are able to be called on the database. + Used for SQL Lab autocomplete. + + :param database: The database to get functions for + :return: A list of function names useable in the database + """ + return ["sum", "avg", "count", "count_up_to", "min", "max"] + + @classmethod + def get_view_names( # pylint: disable=unused-argument + cls, + database: Database, + inspector: Inspector, + schema: str | None, + ) -> set[str]: + """ + Get all the view names within the specified schema. + + Per the SQLAlchemy definition if the schema is omitted the database’s default + schema is used, however some dialects infer the request as schema agnostic. + + The Datastore doesn't have a view. Return an empty set. 
+ + :param database: The database to inspect + :param inspector: The SQLAlchemy inspector + :param schema: The schema to inspect + :returns: The view names + """ + return set() diff --git a/superset/sql/parse.py b/superset/sql/parse.py index fdbc524b149..4f9c92a2776 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -60,6 +60,7 @@ SQLGLOT_DIALECTS = { "ascend": Dialects.HIVE, "awsathena": Dialects.ATHENA, "bigquery": Dialects.BIGQUERY, + "datastore": Dialects.BIGQUERY, "clickhouse": Dialects.CLICKHOUSE, "clickhousedb": Dialects.CLICKHOUSE, "cockroachdb": Dialects.POSTGRES, diff --git a/tests/integration_tests/db_engine_specs/datastore_tests.py b/tests/integration_tests/db_engine_specs/datastore_tests.py new file mode 100644 index 00000000000..86cf1d8f7d2 --- /dev/null +++ b/tests/integration_tests/db_engine_specs/datastore_tests.py @@ -0,0 +1,521 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import unittest.mock as mock +from datetime import datetime +from typing import Any + +import pytest +from marshmallow.exceptions import ValidationError +from sqlalchemy import column + +pytest.importorskip("sqlalchemy_datastore") + +from sqlalchemy.engine.url import make_url + +from superset.connectors.sqla.models import TableColumn +from superset.db_engine_specs.base import BaseEngineSpec +from superset.db_engine_specs.datastore import DatastoreEngineSpec +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.superset_typing import ResultSetColumnType +from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, # noqa: F401 + load_birth_names_data, # noqa: F401 +) + + +class TestDatastoreDbEngineSpec(SupersetTestCase): + def test_datastore_sqla_column_label(self): + """ + DB Eng Specs (datastore): Test column label + """ + # Expected labels with SHA-256 hash suffix (first 5 chars prefixed with _) + test_cases = { + "Col": "Col", + "SUM(x)": "SUM_x__b681e", + "SUM[x]": "SUM_x__ceaf6", + "12345_col": "_12345_col_b1415", + } + for original, expected in test_cases.items(): + actual = DatastoreEngineSpec.make_label_compatible(column(original).name) + assert actual == expected + + def test_timegrain_expressions(self): + """ + DB Eng Specs (datastore): Test time grain expressions + """ + col = column("temporal") + test_cases = { + "DATE": "DATE_TRUNC(temporal, HOUR)", + "TIME": "TIME_TRUNC(temporal, HOUR)", + "DATETIME": "DATETIME_TRUNC(temporal, HOUR)", + "TIMESTAMP": "TIMESTAMP_TRUNC(temporal, HOUR)", + } + for type_, expected in test_cases.items(): + col.type = type_ + actual = DatastoreEngineSpec.get_timestamp_expr( + col=col, pdf=None, time_grain="PT1H" + ) + assert str(actual) == expected + + def test_custom_minute_timegrain_expressions(self): + """ + DB Eng Specs (datastore): Test time grain expressions + """ + col 
= column("temporal") + test_cases = { + "DATE": "CAST(TIMESTAMP_SECONDS(" + "5*60 * DIV(UNIX_SECONDS(CAST(temporal AS TIMESTAMP)), 5*60)" + ") AS DATE)", + "DATETIME": "CAST(TIMESTAMP_SECONDS(" + "5*60 * DIV(UNIX_SECONDS(CAST(temporal AS TIMESTAMP)), 5*60)" + ") AS DATETIME)", + "TIMESTAMP": "CAST(TIMESTAMP_SECONDS(" + "5*60 * DIV(UNIX_SECONDS(CAST(temporal AS TIMESTAMP)), 5*60)" + ") AS TIMESTAMP)", + } + for type_, expected in test_cases.items(): + col.type = type_ + actual = DatastoreEngineSpec.get_timestamp_expr( + col=col, pdf=None, time_grain="PT5M" + ) + assert str(actual) == expected + + def test_fetch_data(self): + """ + DB Eng Specs (datastore): Test fetch data + """ + + # Mock a google.cloud.datastore.table.Row + class Row: + def __init__(self, value): + self._value = value + + def values(self): + return (self._value,) + + data1 = [(1, "foo")] + with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data1): + result = DatastoreEngineSpec.fetch_data(None, 0) + assert result == data1 + + data2 = [Row(1), Row(2)] + with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data2): + result = DatastoreEngineSpec.fetch_data(None, 0) + assert result == [(1,), (2,)] + + def test_extract_errors(self): + msg = "403 POST https://datastore.googleapis.com/: Access Denied: Project my-project: User does not have datastore.databases.create permission in project my-project" # noqa: E501 + result = DatastoreEngineSpec.extract_errors(Exception(msg)) + assert result == [ + SupersetError( + message="Unable to connect. 
Verify that the following roles are set " + 'on the service account: "Cloud Datastore Viewer", ' + '"Cloud Datastore User", "Cloud Datastore Creator"', + error_type=SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR, + level=ErrorLevel.ERROR, + extra={ + "engine_name": "Google Datastore", + "issue_codes": [ + { + "code": 1017, + "message": "", + } + ], + }, + ) + ] + + msg = "datastore error: 404 Not found: Dataset fakeDataset:bogusSchema was not found in location" # noqa: E501 + result = DatastoreEngineSpec.extract_errors(Exception(msg)) + assert result == [ + SupersetError( + message='The schema "bogusSchema" does not exist. A valid schema must be used to run this query.', # noqa: E501 + error_type=SupersetErrorType.SCHEMA_DOES_NOT_EXIST_ERROR, + level=ErrorLevel.ERROR, + extra={ + "engine_name": "Google Datastore", + "issue_codes": [ + { + "code": 1003, + "message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.", # noqa: E501 + }, + { + "code": 1004, + "message": "Issue 1004 - The column was deleted or renamed in the database.", # noqa: E501 + }, + ], + }, + ) + ] + + msg = 'Table name "badtable" missing dataset while no default dataset is set in the request' # noqa: E501 + result = DatastoreEngineSpec.extract_errors(Exception(msg)) + assert result == [ + SupersetError( + message='The table "badtable" does not exist. A valid table must be used to run this query.', # noqa: E501 + error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, + level=ErrorLevel.ERROR, + extra={ + "engine_name": "Google Datastore", + "issue_codes": [ + { + "code": 1003, + "message": "Issue 1003 - There is a syntax error in the SQL query. 
Perhaps there was a misspelling or a typo.", # noqa: E501 + }, + { + "code": 1005, + "message": "Issue 1005 - The table was deleted or renamed in the database.", # noqa: E501 + }, + ], + }, + ) + ] + + msg = "Unrecognized name: badColumn at [1:8]" + result = DatastoreEngineSpec.extract_errors(Exception(msg)) + assert result == [ + SupersetError( + message='We can\'t seem to resolve column "badColumn" at line 1:8.', + error_type=SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR, + level=ErrorLevel.ERROR, + extra={ + "engine_name": "Google Datastore", + "issue_codes": [ + { + "code": 1003, + "message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.", # noqa: E501 + }, + { + "code": 1004, + "message": "Issue 1004 - The column was deleted or renamed in the database.", # noqa: E501 + }, + ], + }, + ) + ] + + msg = 'Syntax error: Expected end of input but got identifier "from_"' + result = DatastoreEngineSpec.extract_errors(Exception(msg)) + assert result == [ + SupersetError( + message='Please check your query for syntax errors at or near "from_". 
Then, try running your query again.', # noqa: E501 + error_type=SupersetErrorType.SYNTAX_ERROR, + level=ErrorLevel.ERROR, + extra={ + "engine_name": "Google Datastore", + "issue_codes": [ + { + "code": 1030, + "message": "Issue 1030 - The query has a syntax error.", + } + ], + }, + ) + ] + + @mock.patch("superset.models.core.Database.db_engine_spec", DatastoreEngineSpec) + @mock.patch( + "sqlalchemy_datastore.base.create_datastore_client", + mock.Mock(return_value=(mock.Mock(), mock.Mock())), + ) + @mock.patch( + "sqlalchemy_datastore._helpers.create_datastore_client", + mock.Mock(return_value=(mock.Mock(), mock.Mock())), + ) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_calculated_column_in_order_by(self): + table = self.get_table(name="birth_names") + TableColumn( + column_name="gender_cc", + type="VARCHAR(255)", + table=table, + expression=""" + case + when gender='boy' then 'male' + else 'female' + end + """, + ) + + table.database.sqlalchemy_uri = "datastore://" + query_obj = { + "groupby": ["gender_cc"], + "is_timeseries": False, + "filter": [], + "orderby": [["gender_cc", True]], + } + sql = table.get_query_str(query_obj) + assert "ORDER BY gender_cc ASC" in sql + + def test_build_sqlalchemy_uri(self): + """ + DB Eng Specs (datastore): Test building SQLAlchemy URI from parameters + """ + parameters: dict[str, Any] = {"query": {}} + encrypted_extra = { + "credentials_info": { + "project_id": "my-project", + "private_key": "SECRET", + } + } + result = DatastoreEngineSpec.build_sqlalchemy_uri(parameters, encrypted_extra) + assert result == "datastore://my-project/?" 
+ + # Test with query parameters + parameters_with_query: dict[str, Any] = {"query": {"location": "US"}} + result = DatastoreEngineSpec.build_sqlalchemy_uri( + parameters_with_query, encrypted_extra + ) + assert result == "datastore://my-project/?location=US" + + # Test missing encrypted_extra raises ValidationError + with pytest.raises(ValidationError, match="Missing service credentials"): + DatastoreEngineSpec.build_sqlalchemy_uri(parameters, None) + + # Test missing project_id raises ValidationError + bad_extra = {"credentials_info": {"private_key": "SECRET"}} + with pytest.raises(ValidationError, match="Invalid service credentials"): + DatastoreEngineSpec.build_sqlalchemy_uri(parameters, bad_extra) + + def test_get_function_names(self): + """ + DB Eng Specs (datastore): Test retrieving function names for autocomplete + """ + database = mock.MagicMock() + result = DatastoreEngineSpec.get_function_names(database) + assert result == ["sum", "avg", "count", "count_up_to", "min", "max"] + + def test_get_view_names(self): + """ + DB Eng Specs (datastore): Test that Datastore returns no view names + """ + database = mock.MagicMock() + inspector = mock.MagicMock() + result = DatastoreEngineSpec.get_view_names(database, inspector, "some_schema") + assert result == set() + + def test_validate_parameters(self): + """ + DB Eng Specs (datastore): Test parameter validation returns no errors + """ + result = DatastoreEngineSpec.validate_parameters( + { + "host": "localhost", + "port": 5432, + "username": "", + "password": "", + "database": "", + "query": {}, + } + ) + assert result == [] + + def test_get_allow_cost_estimate(self): + """ + DB Eng Specs (datastore): Test that cost estimate is not supported + """ + assert DatastoreEngineSpec.get_allow_cost_estimate({}) is False + + def test_parse_error_exception(self): + """ + DB Eng Specs (datastore): Test error message parsing extracts first line + """ + multiline_msg = ( + 'datastore error: 400 Syntax error: Table "t" must be 
qualified.\n' + "\n" + "(job ID: abc-123)\n" + "\n" + " -----Query Job SQL Follows-----\n" + " 1:select * from t\n" + ) + result = DatastoreEngineSpec.parse_error_exception(Exception(multiline_msg)) + assert str(result) == ( + 'datastore error: 400 Syntax error: Table "t" must be qualified.' + ) + + # Simple single-line messages pass through unchanged + simple_msg = "Some simple error" + result = DatastoreEngineSpec.parse_error_exception(Exception(simple_msg)) + assert str(result) == simple_msg + + def test_convert_dttm(self): + """ + DB Eng Specs (datastore): Test datetime conversion for all supported types + """ + dttm = datetime(2019, 1, 2, 3, 4, 5, 678900) + + assert ( + DatastoreEngineSpec.convert_dttm("DATE", dttm) + == "CAST('2019-01-02' AS DATE)" + ) + assert ( + DatastoreEngineSpec.convert_dttm("DATETIME", dttm) + == "CAST('2019-01-02T03:04:05.678900' AS DATETIME)" + ) + assert ( + DatastoreEngineSpec.convert_dttm("TIMESTAMP", dttm) + == "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)" + ) + assert ( + DatastoreEngineSpec.convert_dttm("TIME", dttm) + == "CAST('03:04:05.678900' AS TIME)" + ) + assert DatastoreEngineSpec.convert_dttm("UnknownType", dttm) is None + + def test_get_parameters_from_uri(self): + """ + DB Eng Specs (datastore): Test extracting parameters from URI + """ + encrypted_extra = { + "credentials_info": { + "project_id": "my-project", + "private_key": "SECRET", + } + } + result = DatastoreEngineSpec.get_parameters_from_uri( + "datastore://my-project/", encrypted_extra + ) + assert result == { + "credentials_info": { + "project_id": "my-project", + "private_key": "SECRET", + }, + "query": {}, + } + + # URI with query parameters + result = DatastoreEngineSpec.get_parameters_from_uri( + "datastore://my-project/?location=US", encrypted_extra + ) + assert result["query"] == {"location": "US"} + + # Missing encrypted_extra raises ValidationError + with pytest.raises(ValidationError, match="Invalid service credentials"): + 
DatastoreEngineSpec.get_parameters_from_uri("datastore://my-project/", None) + + def test_get_dbapi_exception_mapping(self): + """ + DB Eng Specs (datastore): Test DBAPI exception mapping includes + DefaultCredentialsError + """ + from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError + + pytest.importorskip("google.auth") + + mapping = DatastoreEngineSpec.get_dbapi_exception_mapping() + assert len(mapping) == 1 + + # Verify the mapping key is DefaultCredentialsError + exception_class = list(mapping.keys())[0] + assert exception_class.__name__ == "DefaultCredentialsError" + assert mapping[exception_class] is SupersetDBAPIConnectionError + + def test_extract_errors_unmatched(self): + """ + DB Eng Specs (datastore): Test that an unmatched error falls through + to the base class handling + """ + msg = "Some completely unknown error message" + result = DatastoreEngineSpec.extract_errors(Exception(msg)) + assert len(result) == 1 + assert result[0].error_type == SupersetErrorType.GENERIC_DB_ENGINE_ERROR + + def test_build_sqlalchemy_uri_string_credentials(self): + """ + DB Eng Specs (datastore): Test building URI when credentials_info is a + JSON string instead of a dict + """ + from superset.utils import json + + parameters: dict[str, Any] = {"query": {}} + encrypted_extra = { + "credentials_info": json.dumps( + { + "project_id": "string-project", + "private_key": "SECRET", + } + ) + } + result = DatastoreEngineSpec.build_sqlalchemy_uri(parameters, encrypted_extra) + assert result == "datastore://string-project/?" 
+ + def test_get_fields(self): + """ + DB Eng Specs (datastore): Test that _get_fields labels struct columns + with double-underscore separators + """ + cols: list[ResultSetColumnType] = [ + { + "column_name": "name", + "name": "name", + "type": "STRING", + "is_dttm": False, + }, + { + "column_name": "project.name", + "name": "project.name", + "type": "STRING", + "is_dttm": False, + }, + ] + fields = DatastoreEngineSpec._get_fields(cols) + assert len(fields) == 2 + # First column: simple name, label unchanged + assert fields[0].key == "name" + # Second column: struct field, dot replaced with double underscore + assert fields[1].key == "project__name" + + def test_adjust_engine_params(self): + """ + DB Eng Specs (datastore): Test engine parameter adjustment with catalog + """ + url = make_url("datastore://original-project") + + # Without catalog, URI is unchanged + uri, connect_args = DatastoreEngineSpec.adjust_engine_params(url, {}) + assert str(uri) == "datastore://original-project" + assert connect_args == {} + + # With catalog, host is replaced and database cleared + uri, _ = DatastoreEngineSpec.adjust_engine_params( + url, {}, catalog="new-project" + ) + assert str(uri) == "datastore://new-project/" + + # Schema parameter is ignored (Datastore adjusts only catalog) + uri, _ = DatastoreEngineSpec.adjust_engine_params(url, {}, schema="some_schema") + assert str(uri) == "datastore://original-project" + + def test_class_attributes(self): + """ + DB Eng Specs (datastore): Test key class attributes are set correctly + """ + assert DatastoreEngineSpec.engine == "datastore" + assert DatastoreEngineSpec.engine_name == "Google Datastore" + assert DatastoreEngineSpec.max_column_name_length == 128 + assert DatastoreEngineSpec.disable_ssh_tunneling is True + assert DatastoreEngineSpec.default_driver == "datastore" + assert DatastoreEngineSpec.run_multiple_statements_as_one is True + assert DatastoreEngineSpec.allows_hidden_cc_in_orderby is True + assert 
DatastoreEngineSpec.supports_dynamic_schema is True + assert DatastoreEngineSpec.supports_catalog is True + assert DatastoreEngineSpec.supports_dynamic_catalog is True + assert DatastoreEngineSpec.arraysize == 5000 + assert DatastoreEngineSpec.encrypted_extra_sensitive_fields == { + "$.credentials_info.private_key" + } diff --git a/tests/unit_tests/db_engine_specs/test_datastore.py b/tests/unit_tests/db_engine_specs/test_datastore.py new file mode 100644 index 00000000000..57a745bff3e --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_datastore.py @@ -0,0 +1,985 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=line-too-long, import-outside-toplevel, protected-access, invalid-name + +from __future__ import annotations + +from datetime import datetime + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy import select +from sqlalchemy.engine.url import make_url +from sqlalchemy.sql import sqltypes + +pytest.importorskip("sqlalchemy_datastore") +from sqlalchemy_datastore import CloudDatastoreDialect # noqa: E402 + +from superset.db_engine_specs.datastore import ( + DatastoreEngineSpec, + DatastoreParametersType, +) +from superset.sql.parse import Table +from superset.superset_typing import ResultSetColumnType +from superset.utils import json +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm # noqa: F401 + + +def test_get_fields() -> None: + """ + Test the custom ``_get_fields`` method. + + The method adds custom labels (aliases) to the columns to prevent + collision when referencing record fields. Eg, if we had these two + columns: + + name STRING + project STRUCT + + One could write this query: + + SELECT + `name`, + `project`.`name` + FROM + the_table + + But then both columns would get aliased as "name". 
+ + The custom method will replace the fields so that the final query + looks like this: + + SELECT + `name` AS `name`, + `project`.`name` AS project__name + FROM + the_table + + """ + + columns: list[ResultSetColumnType] = [ + {"column_name": "limit", "name": "limit", "type": "STRING", "is_dttm": False}, + {"column_name": "name", "name": "name", "type": "STRING", "is_dttm": False}, + { + "column_name": "project.name", + "name": "project.name", + "type": "STRING", + "is_dttm": False, + }, + ] + fields = DatastoreEngineSpec._get_fields(columns) + + query = select(fields) + assert str(query.compile(dialect=CloudDatastoreDialect())) == ( + 'SELECT "limit" AS "limit", name AS name, "project.name" AS project__name' + ) + + +def test_select_star(mocker: MockerFixture) -> None: + """ + Test the ``select_star`` method. + + The method removes pseudo-columns from structures inside arrays. While these + pseudo-columns show up as "columns" for metadata reasons, we can't select them + in the query, as opposed to fields from non-array structures. 
+ """ + + cols: list[ResultSetColumnType] = [ + { + "column_name": "trailer", + "name": "trailer", + "type": sqltypes.ARRAY(sqltypes.JSON()), + "nullable": True, + "comment": None, + "default": None, + "precision": None, + "scale": None, + "max_length": None, + "is_dttm": False, + }, + { + "column_name": "trailer.key", + "name": "trailer.key", + "type": sqltypes.String(), + "nullable": True, + "comment": None, + "default": None, + "precision": None, + "scale": None, + "max_length": None, + "is_dttm": False, + }, + { + "column_name": "trailer.value", + "name": "trailer.value", + "type": sqltypes.String(), + "nullable": True, + "comment": None, + "default": None, + "precision": None, + "scale": None, + "max_length": None, + "is_dttm": False, + }, + { + "column_name": "trailer.email", + "name": "trailer.email", + "type": sqltypes.String(), + "nullable": True, + "comment": None, + "default": None, + "precision": None, + "scale": None, + "max_length": None, + "is_dttm": False, + }, + ] + + # mock the database so we can compile the query + database = mocker.MagicMock() + database.compile_sqla_query = lambda query, catalog, schema: str( + query.compile( + dialect=CloudDatastoreDialect(), compile_kwargs={"literal_binds": True} + ) + ) + + dialect = CloudDatastoreDialect() + + sql = DatastoreEngineSpec.select_star( + database=database, + table=Table("my_table"), + dialect=dialect, + limit=100, + show_cols=True, + indent=True, + latest_partition=False, + cols=cols, + ) + assert ( + sql + == """SELECT + trailer AS trailer +FROM my_table +LIMIT 100""" + ) + + +def test_get_parameters_from_uri_serializable() -> None: + """ + Test that the result from ``get_parameters_from_uri`` is JSON serializable. 
+ """ + + parameters = DatastoreEngineSpec.get_parameters_from_uri( + "datastore://dbt-tutorial-347100/", + {"access_token": "TOP_SECRET"}, + ) + assert parameters == {"access_token": "TOP_SECRET", "query": {}} + assert json.loads(json.dumps(parameters)) == parameters + + +def test_unmask_encrypted_extra() -> None: + """ + Test that the private key can be reused from the previous `encrypted_extra`. + """ + + old = json.dumps( + { + "credentials_info": { + "project_id": "black-sanctum-314419", + "private_key": "SECRET", + }, + } + ) + new = json.dumps( + { + "credentials_info": { + "project_id": "yellow-unicorn-314419", + "private_key": "XXXXXXXXXX", + }, + } + ) + + assert DatastoreEngineSpec.unmask_encrypted_extra(old, new) == json.dumps( + { + "credentials_info": { + "project_id": "yellow-unicorn-314419", + "private_key": "SECRET", + }, + } + ) + + +def test_unmask_encrypted_extra_field_changed() -> None: + """ + Test that the private key is not reused when the field has changed. + """ + + old = json.dumps( + { + "credentials_info": { + "project_id": "black-sanctum-314419", + "private_key": "SECRET", + }, + } + ) + new = json.dumps( + { + "credentials_info": { + "project_id": "yellow-unicorn-314419", + "private_key": "NEW-SECRET", + }, + } + ) + + assert DatastoreEngineSpec.unmask_encrypted_extra(old, new) == json.dumps( + { + "credentials_info": { + "project_id": "yellow-unicorn-314419", + "private_key": "NEW-SECRET", + }, + } + ) + + +def test_unmask_encrypted_extra_when_old_is_none() -> None: + """ + Test that a `None` value for the old field works for `encrypted_extra`. 
+ """ + + old = None + new = json.dumps( + { + "credentials_info": { + "project_id": "yellow-unicorn-314419", + "private_key": "XXXXXXXXXX", + }, + } + ) + + assert DatastoreEngineSpec.unmask_encrypted_extra(old, new) == json.dumps( + { + "credentials_info": { + "project_id": "yellow-unicorn-314419", + "private_key": "XXXXXXXXXX", + }, + } + ) + + +def test_unmask_encrypted_extra_when_new_is_none() -> None: + """ + Test that a `None` value for the new field works for `encrypted_extra`. + """ + + old = json.dumps( + { + "credentials_info": { + "project_id": "black-sanctum-314419", + "private_key": "SECRET", + }, + } + ) + new = None + + assert DatastoreEngineSpec.unmask_encrypted_extra(old, new) is None + + +def test_mask_encrypted_extra() -> None: + """ + Test that the private key is masked when the database is edited. + """ + + config = json.dumps( + { + "credentials_info": { + "project_id": "black-sanctum-314419", + "private_key": "SECRET", + }, + } + ) + + assert DatastoreEngineSpec.mask_encrypted_extra(config) == json.dumps( + { + "credentials_info": { + "project_id": "black-sanctum-314419", + "private_key": "XXXXXXXXXX", + }, + } + ) + + +def test_mask_encrypted_extra_when_empty() -> None: + """ + Test that the encrypted extra will return a none value if the field is empty. + """ + + assert DatastoreEngineSpec.mask_encrypted_extra(None) is None + + +def test_parse_error_message() -> None: + """ + Test that we parse a received message and just extract the useful information. + + Example errors: + datastore error: 400 Syntax error: Table \"case_detail_all_suites\" must be qualified with a dataset (e.g. dataset.table). + + (job ID: ddf30b05-44e8-4fbf-aa29-40bfccaed886) + -----Query Job SQL Follows----- + | . | . | . |\n 1:select * from case_detail_all_suites\n 2:LIMIT 1001\n | . | . | . | + """ # noqa: E501 + + message = 'datastore error: 400 Syntax error: Table "case_detail_all_suites" must be qualified with a dataset (e.g. 
def test_parse_error_raises_exception() -> None:
    """
    ``parse_error_exception`` must be robust: single-line messages (even
    trivial ones) pass through unchanged rather than raising.
    """
    qualified_error = 'datastore error: 400 Syntax error: Table "case_detail_all_suites" must be qualified with a dataset (e.g. dataset.table).'  # noqa: E501
    trivial_error = "6"

    parsed = DatastoreEngineSpec.parse_error_exception(Exception(qualified_error))
    assert str(parsed) == qualified_error

    parsed = DatastoreEngineSpec.parse_error_exception(Exception(trivial_error))
    assert str(parsed) == trivial_error


@pytest.mark.parametrize(
    "sql_type,expected",
    [
        ("Date", "CAST('2019-01-02' AS DATE)"),
        ("DateTime", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"),
        ("TimeStamp", "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"),
        ("Time", "CAST('03:04:05.678900' AS TIME)"),
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    sql_type: str,
    expected: str | None,
    dttm: datetime,  # noqa: F811
) -> None:
    """
    DB Eng Specs (datastore): Test conversion to date time
    """
    assert_convert_dttm(DatastoreEngineSpec, sql_type, expected, dttm)


def test_get_default_catalog(mocker: MockerFixture) -> None:
    """
    The default catalog comes from the connection URI (host or database
    part), falling back to the client's project when neither is set.
    """
    from superset.models.core import Database

    mocker.patch.object(Database, "get_sqla_engine")
    get_client = mocker.patch.object(DatastoreEngineSpec, "_get_client")
    get_client().project = "project"

    # Catalog as host, catalog as database, and no catalog at all: every
    # variant must resolve to the same project.
    for uri in (
        "datastore://project",
        "datastore:///project",
        "datastore://",
    ):
        database = Database(database_name="my_db", sqlalchemy_uri=uri)
        assert DatastoreEngineSpec.get_default_catalog(database) == "project"


def test_adjust_engine_params_catalog_as_host() -> None:
    """
    Passing a custom catalog when the original URI stores the catalog in
    the host position.
    """
    url = make_url("datastore://project")

    unchanged, _ = DatastoreEngineSpec.adjust_engine_params(url, {})
    assert str(unchanged) == "datastore://project"

    changed, _ = DatastoreEngineSpec.adjust_engine_params(
        url,
        {},
        catalog="other-project",
    )
    assert str(changed) == "datastore://other-project/"
+ """ + + url = make_url("datastore://project") + + uri = DatastoreEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "datastore://project" + + uri = DatastoreEngineSpec.adjust_engine_params( + url, + {}, + catalog="other-project", + )[0] + assert str(uri) == "datastore://other-project/" + + +def test_adjust_engine_params_catalog_as_database() -> None: + """ + Test passing a custom catalog. + + In this test, the original URI has the catalog as the database. + """ + + url = make_url("datastore:///project") + + uri = DatastoreEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "datastore:///project" + + uri = DatastoreEngineSpec.adjust_engine_params( + url, + {}, + catalog="other-project", + )[0] + assert str(uri) == "datastore://other-project/" + + +def test_adjust_engine_params_no_catalog() -> None: + """ + Test passing a custom catalog. + + In this test, the original URI has no catalog. + """ + + url = make_url("datastore://") + + uri = DatastoreEngineSpec.adjust_engine_params(url, {})[0] + assert str(uri) == "datastore://" + + uri = DatastoreEngineSpec.adjust_engine_params( + url, + {}, + catalog="other-project", + )[0] + assert str(uri) == "datastore://other-project/" + + +def test_get_client_passes_database_from_url(mocker: MockerFixture) -> None: + """ + Test that ``_get_client`` passes the ``database`` query parameter + from the engine URL through to ``datastore.Client``. 
+ """ + + mock_client_cls = mocker.patch( + "superset.db_engine_specs.datastore.datastore.Client" + ) + mocker.patch( + "superset.db_engine_specs.datastore.service_account" + ".Credentials.from_service_account_info", + return_value=mocker.MagicMock(), + ) + + engine = mocker.MagicMock() + engine.dialect.credentials_info = {"project_id": "my-project", "private_key": "k"} + engine.url.query = {"database": "my-db"} + + database = mocker.MagicMock() + + DatastoreEngineSpec._get_client(engine, database) + mock_client_cls.assert_called_once_with(credentials=mocker.ANY, database="my-db") + + +def test_get_client_passes_none_when_no_database(mocker: MockerFixture) -> None: + """ + Test that ``_get_client`` passes ``database=None`` when the URL + has no ``database`` query parameter. + """ + + mock_client_cls = mocker.patch( + "superset.db_engine_specs.datastore.datastore.Client" + ) + mocker.patch( + "superset.db_engine_specs.datastore.service_account" + ".Credentials.from_service_account_info", + return_value=mocker.MagicMock(), + ) + + engine = mocker.MagicMock() + engine.dialect.credentials_info = {"project_id": "my-project", "private_key": "k"} + engine.url.query = {} + + database = mocker.MagicMock() + + DatastoreEngineSpec._get_client(engine, database) + mock_client_cls.assert_called_once_with(credentials=mocker.ANY, database=None) + + +def test_get_client_default_credentials_passes_database( + mocker: MockerFixture, +) -> None: + """ + Test that ``_get_client`` passes ``database`` when falling back + to default credentials. 
+ """ + + mock_client_cls = mocker.patch( + "superset.db_engine_specs.datastore.datastore.Client" + ) + + engine = mocker.MagicMock() + engine.dialect.credentials_info = None + engine.url.query = {"database": "other-db"} + + database = mocker.MagicMock() + + DatastoreEngineSpec._get_client(engine, database) + mock_client_cls.assert_called_once_with(credentials=mocker.ANY, database="other-db") + + +def test_parameters_json_schema_has_encrypted_extra() -> None: + """ + Test that ``parameters_json_schema`` marks ``credentials_info`` with + ``x-encrypted-extra`` so the frontend moves credentials into + ``masked_encrypted_extra``. + """ + + schema = DatastoreEngineSpec.parameters_json_schema() + assert schema is not None + credentials_info = schema["properties"]["credentials_info"] + assert credentials_info["x-encrypted-extra"] is True + + +def test_execute_with_cursor_no_warnings(mocker: MockerFixture) -> None: + """ + Test ``execute_with_cursor`` delegates to the base class and does not + set warnings when the cursor has none. + """ + + mocker.patch( + "superset.db_engine_specs.base.BaseEngineSpec.execute_with_cursor", + ) + + cursor = mocker.MagicMock() + cursor.warnings = [] + query = mocker.MagicMock() + + DatastoreEngineSpec.execute_with_cursor(cursor, "SELECT 1", query) + query.set_extra_json_key.assert_not_called() + + +def test_execute_with_cursor_with_warnings(mocker: MockerFixture) -> None: + """ + Test ``execute_with_cursor`` stores cursor warnings in the query's + ``extra_json`` when the cursor reports warnings. 
+ """ + + mocker.patch( + "superset.db_engine_specs.base.BaseEngineSpec.execute_with_cursor", + ) + + cursor = mocker.MagicMock() + cursor.warnings = ["Missing composite index for query"] + query = mocker.MagicMock() + + DatastoreEngineSpec.execute_with_cursor(cursor, "SELECT * FROM Kind", query) + query.set_extra_json_key.assert_called_once_with( + "warnings", ["Missing composite index for query"] + ) + + +def test_execute_with_cursor_no_warnings_attr(mocker: MockerFixture) -> None: + """ + Test ``execute_with_cursor`` does not fail when the cursor has no + ``warnings`` attribute. + """ + + mocker.patch( + "superset.db_engine_specs.base.BaseEngineSpec.execute_with_cursor", + ) + + cursor = mocker.MagicMock(spec=[]) # no attributes at all + query = mocker.MagicMock() + + DatastoreEngineSpec.execute_with_cursor(cursor, "SELECT 1", query) + query.set_extra_json_key.assert_not_called() + + +def test_get_client_dependencies_not_installed(mocker: MockerFixture) -> None: + """ + Test that ``_get_client`` raises ``SupersetException`` when the + google-cloud-datastore package is not installed. + """ + from superset.exceptions import SupersetException + + mocker.patch( + "superset.db_engine_specs.datastore.dependencies_installed", + False, + ) + + engine = mocker.MagicMock() + database = mocker.MagicMock() + + with pytest.raises(SupersetException, match="Could not import libraries"): + DatastoreEngineSpec._get_client(engine, database) + + +def test_get_client_default_credentials_error(mocker: MockerFixture) -> None: + """ + Test that ``_get_client`` raises ``SupersetDBAPIConnectionError`` when + google.auth.default() fails. 
+ """ + from google.auth.exceptions import DefaultCredentialsError + + from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError + + mocker.patch( + "superset.db_engine_specs.datastore.google.auth.default", + side_effect=DefaultCredentialsError("No credentials found"), + ) + + engine = mocker.MagicMock() + engine.dialect.credentials_info = None + engine.url.query = {} + database = mocker.MagicMock() + + with pytest.raises( + SupersetDBAPIConnectionError, + match="database credentials could not be found", + ): + DatastoreEngineSpec._get_client(engine, database) + + +def test_fetch_data_regular_tuples(mocker: MockerFixture) -> None: + """ + Test ``fetch_data`` with regular tuple rows passes them through unchanged. + """ + from superset.db_engine_specs.base import BaseEngineSpec + + data = [(1, "foo"), (2, "bar")] + mocker.patch.object(BaseEngineSpec, "fetch_data", return_value=data) + + result = DatastoreEngineSpec.fetch_data(mocker.MagicMock(), 0) + assert result == [(1, "foo"), (2, "bar")] + + +def test_fetch_data_with_row_objects(mocker: MockerFixture) -> None: + """ + Test ``fetch_data`` with google.cloud.datastore Row-like objects that + have a ``values()`` method. + """ + from superset.db_engine_specs.base import BaseEngineSpec + + class Row: + def __init__(self, val: tuple[int, str]) -> None: + self._val = val + + def values(self) -> tuple[int, str]: + return self._val + + data = [Row((1, "a")), Row((2, "b"))] + mocker.patch.object(BaseEngineSpec, "fetch_data", return_value=data) + + result = DatastoreEngineSpec.fetch_data(mocker.MagicMock(), 0) + assert result == [(1, "a"), (2, "b")] + + +def test_fetch_data_empty(mocker: MockerFixture) -> None: + """ + Test ``fetch_data`` with an empty result set. 
+ """ + from superset.db_engine_specs.base import BaseEngineSpec + + mocker.patch.object(BaseEngineSpec, "fetch_data", return_value=[]) + + result = DatastoreEngineSpec.fetch_data(mocker.MagicMock(), 0) + assert result == [] + + +def test_build_sqlalchemy_uri() -> None: + """ + Test building a SQLAlchemy URI from parameters and encrypted_extra. + """ + + parameters: DatastoreParametersType = { + "credentials_info": {}, + "query": {}, + } + encrypted_extra = { + "credentials_info": { + "project_id": "my-project", + "private_key": "SECRET", + } + } + result = DatastoreEngineSpec.build_sqlalchemy_uri(parameters, encrypted_extra) + assert result == "datastore://my-project/?" + + +def test_build_sqlalchemy_uri_with_query_params() -> None: + """ + Test building a SQLAlchemy URI with query parameters. + """ + + parameters: DatastoreParametersType = { + "credentials_info": {}, + "query": {"database": "my-db"}, + } + encrypted_extra = { + "credentials_info": { + "project_id": "my-project", + "private_key": "SECRET", + } + } + result = DatastoreEngineSpec.build_sqlalchemy_uri(parameters, encrypted_extra) + assert result == "datastore://my-project/?database=my-db" + + +def test_build_sqlalchemy_uri_string_credentials() -> None: + """ + Test building a URI when ``credentials_info`` is a JSON string. + """ + + parameters: DatastoreParametersType = { + "credentials_info": {}, + "query": {}, + } + encrypted_extra = { + "credentials_info": json.dumps( + {"project_id": "string-project", "private_key": "SECRET"} + ) + } + result = DatastoreEngineSpec.build_sqlalchemy_uri(parameters, encrypted_extra) + assert result == "datastore://string-project/?" + + +def test_build_sqlalchemy_uri_missing_encrypted_extra() -> None: + """ + Test that ``build_sqlalchemy_uri`` raises ``ValidationError`` when + ``encrypted_extra`` is None. 
+ """ + from marshmallow.exceptions import ValidationError + + parameters: DatastoreParametersType = {"credentials_info": {}, "query": {}} + with pytest.raises(ValidationError, match="Missing service credentials"): + DatastoreEngineSpec.build_sqlalchemy_uri(parameters, None) + + +def test_build_sqlalchemy_uri_missing_project_id() -> None: + """ + Test that ``build_sqlalchemy_uri`` raises ``ValidationError`` when + ``project_id`` is missing from credentials. + """ + from marshmallow.exceptions import ValidationError + + parameters: DatastoreParametersType = {"credentials_info": {}, "query": {}} + encrypted_extra = {"credentials_info": {"private_key": "SECRET"}} + with pytest.raises(ValidationError, match="Invalid service credentials"): + DatastoreEngineSpec.build_sqlalchemy_uri(parameters, encrypted_extra) + + +def test_get_parameters_from_uri() -> None: + """ + Test extracting parameters from a URI with encrypted_extra. + """ + + encrypted_extra = { + "credentials_info": { + "project_id": "my-project", + "private_key": "SECRET", + } + } + result = DatastoreEngineSpec.get_parameters_from_uri( + "datastore://my-project/?database=my-db", + encrypted_extra, + ) + assert result == { + "credentials_info": { + "project_id": "my-project", + "private_key": "SECRET", + }, + "query": {"database": "my-db"}, + } + + +def test_get_parameters_from_uri_missing_credentials() -> None: + """ + Test that ``get_parameters_from_uri`` raises ``ValidationError`` when + ``encrypted_extra`` is None. + """ + from marshmallow.exceptions import ValidationError + + with pytest.raises(ValidationError, match="Invalid service credentials"): + DatastoreEngineSpec.get_parameters_from_uri("datastore://project/", None) + + +def test_validate_parameters_returns_empty() -> None: + """ + Test that ``validate_parameters`` returns an empty list (validation + is a no-op for Datastore). 
+ """ + + result = DatastoreEngineSpec.validate_parameters( + { + "parameters": { + "host": "", + "port": 0, + "username": "", + "password": "", + "database": "", + "query": {}, + }, + } + ) + assert result == [] + + +def test_get_allow_cost_estimate() -> None: + """ + Test that cost estimation is not supported. + """ + + assert DatastoreEngineSpec.get_allow_cost_estimate({}) is False + + +def test_get_function_names(mocker: MockerFixture) -> None: + """ + Test that ``get_function_names`` returns the expected GQL functions. + """ + + database = mocker.MagicMock() + result = DatastoreEngineSpec.get_function_names(database) + assert result == ["sum", "avg", "count", "count_up_to", "min", "max"] + + +def test_get_view_names(mocker: MockerFixture) -> None: + """ + Test that ``get_view_names`` returns an empty set because Datastore + has no view concept. + """ + + result = DatastoreEngineSpec.get_view_names( + mocker.MagicMock(), mocker.MagicMock(), "some_schema" + ) + assert result == set() + + +def test_get_dbapi_exception_mapping() -> None: + """ + Test that the DBAPI exception mapping maps ``DefaultCredentialsError`` + to ``SupersetDBAPIConnectionError``. + """ + from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError + + mapping = DatastoreEngineSpec.get_dbapi_exception_mapping() + assert len(mapping) == 1 + exc_cls = next(iter(mapping)) + assert exc_cls.__name__ == "DefaultCredentialsError" + assert mapping[exc_cls] is SupersetDBAPIConnectionError + + +def test_mutate_label_simple() -> None: + """ + Test ``_mutate_label`` with labels that need no mutation. + """ + + assert DatastoreEngineSpec._mutate_label("col") == "col" + assert DatastoreEngineSpec._mutate_label("my_column") == "my_column" + assert DatastoreEngineSpec._mutate_label("_private") == "_private" + + +def test_mutate_label_starts_with_digit() -> None: + """ + Test ``_mutate_label`` prefixes an underscore when the label starts + with a digit. 
+ """ + + result = DatastoreEngineSpec._mutate_label("123col") + assert result.startswith("_123col") + # Hash suffix is added because the label was mutated + assert len(result) > len("_123col") + + +def test_mutate_label_special_characters() -> None: + """ + Test ``_mutate_label`` replaces non-alphanumeric characters and adds + a hash suffix. + """ + + result = DatastoreEngineSpec._mutate_label("SUM(x)") + assert result.startswith("SUM_x_") + # Should have a hash suffix + assert "_" in result[5:] + + +def test_truncate_label() -> None: + """ + Test ``_truncate_label`` returns a hash prefixed with underscore. + """ + + result = DatastoreEngineSpec._truncate_label("some_very_long_label") + assert result.startswith("_") + # The hash should be deterministic + assert result == DatastoreEngineSpec._truncate_label("some_very_long_label") + # Different labels produce different hashes + assert result != DatastoreEngineSpec._truncate_label("another_label") + + +def test_select_star_without_cols(mocker: MockerFixture) -> None: + """ + Test ``select_star`` when no columns are provided (cols=None). + """ + + database = mocker.MagicMock() + database.compile_sqla_query = lambda query, catalog, schema: str( + query.compile( + dialect=CloudDatastoreDialect(), compile_kwargs={"literal_binds": True} + ) + ) + dialect = CloudDatastoreDialect() + + sql = DatastoreEngineSpec.select_star( + database=database, + table=Table("my_table"), + dialect=dialect, + limit=100, + show_cols=False, + indent=True, + latest_partition=False, + cols=None, + ) + assert "FROM my_table" in sql + assert "LIMIT 100" in sql + + +def test_get_catalog_names(mocker: MockerFixture) -> None: + """ + Test that ``get_catalog_names`` delegates to the base class. 
+ """ + + database = mocker.MagicMock() + inspector = mocker.MagicMock() + inspector.bind.execute.return_value = [] + + mocker.patch( + "superset.db_engine_specs.base.BaseEngineSpec.get_catalog_names", + return_value={"my-project"}, + ) + + result = DatastoreEngineSpec.get_catalog_names(database, inspector) + assert result == {"my-project"}