Compare commits

...

4 Commits

Author SHA1 Message Date
Beto Dealmeida
2b1f732082 More frontend 2024-09-01 18:58:18 -04:00
Beto Dealmeida
da2bc91a32 Frontend 2024-08-28 17:16:17 -04:00
Beto Dealmeida
83accd751a WIP 2024-08-28 11:35:17 -04:00
Beto Dealmeida
d6d2277ed6 WIP 2024-08-14 16:59:24 -04:00
20 changed files with 461 additions and 92 deletions

View File

@@ -24,10 +24,11 @@ import {
useRef,
useCallback,
} from 'react';
import { styled, SupersetClient, t } from '@superset-ui/core';
import { styled, SupersetClient, SupersetError, t } from '@superset-ui/core';
import type { LabeledValue as AntdLabeledValue } from 'antd/lib/select';
import rison from 'rison';
import { AsyncSelect, Select } from 'src/components';
import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace';
import Label from 'src/components/Label';
import { FormLabel } from 'src/components/Form';
import RefreshLabel from 'src/components/RefreshLabel';
@@ -154,6 +155,7 @@ export default function DatabaseSelector({
}: DatabaseSelectorProps) {
const showCatalogSelector = !!db?.allow_multi_catalog;
const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
const [errorPayload, setErrorPayload] = useState<SupersetError | null>();
const [currentCatalog, setCurrentCatalog] = useState<
CatalogOption | null | undefined
>(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
@@ -267,6 +269,7 @@ export default function DatabaseSelector({
dbId: currentDb?.value,
catalog: currentCatalog?.value,
onSuccess: (schemas, isFetched) => {
setErrorPayload(null);
if (schemas.length === 1) {
changeSchema(schemas[0]);
} else if (
@@ -279,7 +282,13 @@ export default function DatabaseSelector({
addSuccessToast('List refreshed');
}
},
onError: () => handleError(t('There was an error loading the schemas')),
onError: error => {
if (error?.errors) {
setErrorPayload(error?.errors?.[0]);
} else {
handleError(t('There was an error loading the schemas'));
}
},
});
const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
@@ -299,6 +308,7 @@ export default function DatabaseSelector({
} = useCatalogs({
dbId: showCatalogSelector ? currentDb?.value : undefined,
onSuccess: (catalogs, isFetched) => {
setErrorPayload(null);
if (!showCatalogSelector) {
changeCatalog(null);
} else if (catalogs.length === 1) {
@@ -315,9 +325,13 @@ export default function DatabaseSelector({
addSuccessToast('List refreshed');
}
},
onError: () => {
onError: error => {
if (showCatalogSelector) {
handleError(t('There was an error loading the catalogs'));
if (error?.errors) {
setErrorPayload(error?.errors?.[0]);
} else {
handleError(t('There was an error loading the catalogs'));
}
}
},
});
@@ -423,9 +437,16 @@ export default function DatabaseSelector({
);
}
function renderError() {
return errorPayload ? (
<ErrorMessageWithStackTrace error={errorPayload} source="crud" />
) : null;
}
return (
<DatabaseSelectorWrapper data-test="DatabaseSelector">
{renderDatabaseSelect()}
{renderError()}
{showCatalogSelector && renderCatalogSelect()}
{renderSchemaSelect()}
</DatabaseSelectorWrapper>

View File

@@ -162,7 +162,7 @@ function OAuth2RedirectMessage({
>
provide authorization
</a>{' '}
in order to run this query.
in order to run this operation.
</>
);

View File

@@ -161,23 +161,20 @@ export const httpPathField = ({
getValidation,
validationErrors,
db,
}: FieldPropTypes) => {
console.error(db);
return (
<ValidatedInput
id="http_path_field"
name="http_path_field"
required={required}
value={db?.parameters?.http_path_field}
validationMethods={{ onBlur: getValidation }}
errorMessage={validationErrors?.http_path}
placeholder={t('e.g. sql/protocolv1/o/12345')}
label="HTTP Path"
onChange={changeMethods.onParametersChange}
helpText={t('Copy the name of the HTTP Path of your cluster.')}
/>
);
};
}: FieldPropTypes) => (
<ValidatedInput
id="http_path_field"
name="http_path_field"
required={required}
value={db?.parameters?.http_path_field}
validationMethods={{ onBlur: getValidation }}
errorMessage={validationErrors?.http_path}
placeholder={t('e.g. sql/protocolv1/o/12345')}
label="HTTP Path"
onChange={changeMethods.onParametersChange}
helpText={t('Copy the name of the HTTP Path of your cluster.')}
/>
);
export const usernameField = ({
required,
changeMethods,

View File

@@ -0,0 +1,109 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useState } from 'react';
import Collapse from 'src/components/Collapse';
import { Input } from 'src/components/Input';
import { FormItem } from 'src/components/Form';
import { FieldPropTypes } from '../../types';
interface OAuth2ClientInfo {
id: string;
secret: string;
authorization_request_uri: string;
token_request_uri: string;
scope: string;
}
export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
const encryptedExtra = JSON.parse(db?.masked_encrypted_extra || '{}');
const [oauth2ClientInfo, setOauth2ClientInfo] = useState<OAuth2ClientInfo>({
id: encryptedExtra.oauth2_client_info?.id || '',
secret: encryptedExtra.oauth2_client_info?.secret || '',
authorization_request_uri:
encryptedExtra.oauth2_client_info?.authorization_request_uri || '',
token_request_uri:
encryptedExtra.oauth2_client_info?.token_request_uri || '',
scope: encryptedExtra.oauth2_client_info?.scope || '',
});
if (db?.engine_information?.supports_oauth2 !== true) {
return null;
}
const handleChange = (key: any) => (e: any) => {
const updatedInfo = {
...oauth2ClientInfo,
[key]: e.target.value,
};
setOauth2ClientInfo(updatedInfo);
const event = {
target: {
name: 'oauth2_client_info',
value: updatedInfo,
},
};
changeMethods.onEncryptedExtraInputChange(event);
};
return (
<Collapse>
<Collapse.Panel header="OAuth2 client information" key="1">
<FormItem label="Client ID">
<Input
placeholder="Enter your Client ID"
value={oauth2ClientInfo.id}
onChange={handleChange('id')}
/>
</FormItem>
<FormItem label="Client Secret">
<Input
type="password"
placeholder="Enter your Client Secret"
value={oauth2ClientInfo.secret}
onChange={handleChange('secret')}
/>
</FormItem>
<FormItem label="Authorization Request URI">
<Input
placeholder="https://"
value={oauth2ClientInfo.authorization_request_uri}
onChange={handleChange('authorization_request_uri')}
/>
</FormItem>
<FormItem label="Token Request URI">
<Input
placeholder="https://"
value={oauth2ClientInfo.token_request_uri}
onChange={handleChange('token_request_uri')}
/>
</FormItem>
<FormItem label="Scope">
<Input
value={oauth2ClientInfo.scope}
onChange={handleChange('scope')}
/>
</FormItem>
</Collapse.Panel>
</Collapse>
);
};

View File

@@ -22,16 +22,19 @@ import { FieldPropTypes } from '../../types';
const FIELD_TEXT_MAP = {
account: {
label: 'Account',
helpText: t(
'Copy the identifier of the account you are trying to connect to.',
),
placeholder: t('e.g. xy12345.us-east-2.aws'),
},
warehouse: {
label: 'Warehouse',
placeholder: t('e.g. compute_wh'),
className: 'form-group-w-50',
},
role: {
label: 'Default role',
placeholder: t('e.g. AccountAdmin'),
className: 'form-group-w-50',
},
@@ -54,7 +57,7 @@ export const validatedInputField = ({
errorMessage={validationErrors?.[field]}
placeholder={FIELD_TEXT_MAP[field].placeholder}
helpText={FIELD_TEXT_MAP[field].helpText}
label={field}
label={FIELD_TEXT_MAP[field].label || field}
onChange={changeMethods.onParametersChange}
className={FIELD_TEXT_MAP[field].className || field}
/>

View File

@@ -41,6 +41,7 @@ import {
} from './CommonParameters';
import { validatedInputField } from './ValidatedInputField';
import { EncryptedField } from './EncryptedField';
import { OAuth2ClientField } from './OAuth2ClientField';
import { TableCatalog } from './TableCatalog';
import { formScrollableStyles, validatedFormStyles } from '../styles';
import { DatabaseForm, DatabaseObject } from '../../types';
@@ -67,6 +68,7 @@ export const FormFieldOrder = [
'warehouse',
'role',
'ssh',
'oauth2_client',
];
const extensionsRegistry = getExtensionsRegistry();
@@ -84,6 +86,7 @@ const FORM_FIELD_MAP = {
default_schema: defaultSchemaField,
username: usernameField,
password: passwordField,
oauth2_client: OAuth2ClientField,
access_token: accessTokenField,
database_name: displayField,
query: queryField,
@@ -118,6 +121,9 @@ interface DatabaseConnectionFormProps {
onExtraInputChange: (
event: FormEvent<InputProps> | { target: HTMLInputElement },
) => void;
onEncryptedExtraInputChange: (
event: FormEvent<InputProps> | { target: HTMLInputElement },
) => void;
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
validationErrors: JsonObject | null;
@@ -136,6 +142,7 @@ const DatabaseConnectionForm = ({
onAddTableCatalog,
onChange,
onExtraInputChange,
onEncryptedExtraInputChange,
onParametersChange,
onParametersUploadFileChange,
onQueryChange,
@@ -171,6 +178,7 @@ const DatabaseConnectionForm = ({
onAddTableCatalog,
onRemoveTableCatalog,
onExtraInputChange,
onEncryptedExtraInputChange,
},
validationErrors,
getValidation,

View File

@@ -154,6 +154,7 @@ export enum ActionType {
EditorChange,
ExtraEditorChange,
ExtraInputChange,
EncryptedExtraInputChange,
Fetched,
InputChange,
ParametersChange,
@@ -185,6 +186,7 @@ export type DBReducerActionType =
type:
| ActionType.ExtraEditorChange
| ActionType.ExtraInputChange
| ActionType.EncryptedExtraInputChange
| ActionType.TextChange
| ActionType.QueryChange
| ActionType.InputChange
@@ -269,6 +271,14 @@ export function dbReducer(
[action.payload.name]: actionPayloadJson,
}),
};
case ActionType.EncryptedExtraInputChange:
return {
...trimmedState,
masked_encrypted_extra: JSON.stringify({
...JSON.parse(trimmedState.masked_encrypted_extra || '{}'),
[action.payload.name]: action.payload.value,
}),
};
case ActionType.ExtraInputChange:
// "extra" payload in state is a string
if (
@@ -1656,6 +1666,16 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
value: target.value,
})
}
onEncryptedExtraInputChange={({
target,
}: {
target: HTMLInputElement;
}) =>
onChange(ActionType.EncryptedExtraInputChange, {
name: target.name,
value: target.value,
})
}
onRemoveTableCatalog={(idx: number) => {
setDB({
type: ActionType.RemoveTableCatalogSheet,

View File

@@ -113,6 +113,7 @@ export type DatabaseObject = {
supports_file_upload?: boolean;
disable_ssh_tunneling?: boolean;
supports_dynamic_catalog?: boolean;
supports_oauth2?: boolean;
};
// SSH Tunnel information
@@ -301,6 +302,7 @@ export interface FieldPropTypes {
onRemoveTableCatalog: (idx: number) => void;
} & {
onExtraInputChange: (value: any) => void;
onEncryptedExtraInputChange: (value: any) => void;
onSSHTunnelParametersChange: CustomEventHandlerType;
};
validationErrors: JsonObject | null;

View File

@@ -17,6 +17,7 @@
* under the License.
*/
import { useCallback, useEffect } from 'react';
import { ClientErrorObject } from '@superset-ui/core';
import useEffectEvent from 'src/hooks/useEffectEvent';
import { api, JsonResponse } from './queryApi';
@@ -30,7 +31,7 @@ export type FetchCatalogsQueryParams = {
dbId?: string | number;
forceRefresh: boolean;
onSuccess?: (data: CatalogOption[], isRefetched: boolean) => void;
onError?: () => void;
onError?: (error: ClientErrorObject) => void;
};
type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>;
@@ -77,6 +78,12 @@ export function useCatalogs(options: Params) {
},
);
useEffect(() => {
if (result.isError) {
onError?.(result.error);
}
}, [result.isError, result.error, onError]);
const fetchData = useEffectEvent(
(dbId: FetchCatalogsQueryParams['dbId'], forceRefresh = false) => {
if (dbId && (!result.currentData || forceRefresh)) {
@@ -85,7 +92,7 @@ export function useCatalogs(options: Params) {
onSuccess?.(data || EMPTY_CATALOGS, forceRefresh);
}
if (isError) {
onError?.();
onError?.(result.error);
}
});
}

View File

@@ -64,6 +64,7 @@ export const supersetClientQuery: BaseQueryFn<
getClientErrorObject(response).then(errorObj => ({
error: {
error: errorObj?.message || errorObj?.error || response.statusText,
errors: errorObj?.errors || [], // used by <ErrorMessageWithStackTrace />
status: response.status,
},
})),

View File

@@ -17,6 +17,7 @@
* under the License.
*/
import { useCallback, useEffect } from 'react';
import { ClientErrorObject } from '@superset-ui/core';
import useEffectEvent from 'src/hooks/useEffectEvent';
import { api, JsonResponse } from './queryApi';
@@ -31,7 +32,7 @@ export type FetchSchemasQueryParams = {
catalog?: string;
forceRefresh: boolean;
onSuccess?: (data: SchemaOption[], isRefetched: boolean) => void;
onError?: () => void;
onError?: (error: ClientErrorObject) => void;
};
type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>;
@@ -81,6 +82,12 @@ export function useSchemas(options: Params) {
},
);
useEffect(() => {
if (result.isError) {
onError?.(result.error);
}
}, [result.isError, result.error, onError]);
const fetchData = useEffectEvent(
(
dbId: FetchSchemasQueryParams['dbId'],
@@ -94,7 +101,7 @@ export function useSchemas(options: Params) {
onSuccess?.(data || EMPTY_SCHEMAS, forceRefresh);
}
if (isError) {
onError?.();
onError?.(result.error);
}
},
);

View File

@@ -42,7 +42,7 @@ from superset.commands.database.test_connection import TestConnectionDatabaseCom
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import SupersetErrorsException
from superset.exceptions import OAuth2RedirectError, SupersetErrorsException
from superset.extensions import event_logger, security_manager
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
@@ -55,13 +55,21 @@ class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
@transaction(
on_error=partial(on_error, reraise=DatabaseCreateFailedError),
allowed=(OAuth2RedirectError,),
)
def run(self) -> Model:
self.validate()
try:
# Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run()
except OAuth2RedirectError:
# If we can't connect to the database due to an OAuth2 error we can still
# save the database. Later, the user can sync permissions when setting up
# data access rules.
return self._create_database()
except (
SupersetErrorsException,
SSHTunnelingNotEnabledError,
@@ -80,12 +88,6 @@ class CreateDatabaseCommand(BaseCommand):
)
raise DatabaseConnectionFailedError() from ex
# when creating a new database we don't need to unmask encrypted extra
self._properties["encrypted_extra"] = self._properties.pop(
"masked_encrypted_extra",
"{}",
)
ssh_tunnel: Optional[SSHTunnel] = None
try:
@@ -195,6 +197,12 @@ class CreateDatabaseCommand(BaseCommand):
raise exception
def _create_database(self) -> Database:
# when creating a new database we don't need to unmask encrypted extra
self._properties["encrypted_extra"] = self._properties.pop(
"masked_encrypted_extra",
"{}",
)
database = DatabaseDAO.create(attributes=self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database

View File

@@ -41,6 +41,7 @@ from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import (
OAuth2RedirectError,
SupersetErrorsException,
SupersetSecurityException,
SupersetTimeoutException,
@@ -162,6 +163,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
if (
database.is_oauth2_enabled()
and database.db_engine_spec.needs_oauth2(ex)
):
database.start_oauth2_dance()
alive = False
# So we stop losing the original message if any
ex_str = str(ex)
@@ -197,6 +204,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
# check for custom errors (wrong username, wrong password, etc)
errors = database.db_engine_spec.extract_errors(ex, self._context)
raise SupersetErrorsException(errors) from ex
except OAuth2RedirectError:
raise
except SupersetSecurityException as ex:
event_logger.log_with_context(
action=get_log_connection_action(
@@ -205,23 +214,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
engine=database.db_engine_spec.__name__,
)
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
except SupersetTimeoutException as ex:
except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 408
raise
except SSHTunnelingNotEnabledError as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 400
raise
except Exception as ex:
event_logger.log_with_context(

View File

@@ -42,6 +42,7 @@ from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import OAuth2RedirectError
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
@@ -80,7 +81,10 @@ class UpdateDatabaseCommand(BaseCommand):
database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
try:
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except OAuth2RedirectError:
pass
return database
@@ -123,6 +127,8 @@ class UpdateDatabaseCommand(BaseCommand):
force=True,
ssh_tunnel=ssh_tunnel,
)
except OAuth2RedirectError:
raise
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex

View File

@@ -107,6 +107,15 @@ class ValidateDatabaseParametersCommand(BaseCommand):
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
# if the connection failed because OAuth2 is needed, we can save the
# database and trigger the OAuth2 flow whenever a user tries to run a
# query.
if (
database.is_oauth2_enabled()
and database.db_engine_spec.needs_oauth2(ex)
):
return
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,

View File

@@ -110,6 +110,7 @@ from superset.exceptions import (
DatabaseNotFoundException,
InvalidPayloadSchemaError,
OAuth2Error,
OAuth2RedirectError,
SupersetErrorsException,
SupersetException,
SupersetSecurityException,
@@ -398,7 +399,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@expose("/", methods=("POST",))
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
@@ -462,6 +462,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel)
return self.response(201, id=new_model.id, result=item)
except OAuth2RedirectError:
raise
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
except DatabaseConnectionFailedError as ex:
@@ -621,7 +623,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@expose("/<int:pk>/catalogs/")
@protect()
@safe
@rison(database_catalogs_query_schema)
@statsd_metrics
@event_logger.log_this_with_context(
@@ -680,12 +681,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500,
message="There was an error connecting to the database",
)
except OAuth2RedirectError:
raise
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
@expose("/<int:pk>/schemas/")
@protect()
@safe
@rison(database_schemas_query_schema)
@statsd_metrics
@event_logger.log_this_with_context(
@@ -746,6 +748,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
return self.response(
500, message="There was an error connecting to the database"
)
except OAuth2RedirectError:
raise
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
@@ -1894,9 +1898,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self,
*args,
**kwargs: f"{self.__class__.__name__}.columnar_upload",
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.columnar_upload",
log_to_statsd=False,
)
@requires_form_data

View File

@@ -136,7 +136,9 @@ builtin_time_grains: dict[str | None, str] = {
}
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
class TimestampExpression(
ColumnClause
): # pylint: disable=abstract-method, too-many-ancestors
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be used to render native column elements respecting
engine-specific quoting rules as part of a string-based expression.
@@ -394,9 +396,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
max_column_name_length: int | None = None
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
custom_errors: dict[
Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
] = {}
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
{}
)
# Whether the engine supports file uploads
# if True, database will be listed as option in the upload file form
@@ -423,8 +425,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# the user impersonation methods to handle personal tokens.
supports_oauth2 = False
oauth2_scope = ""
oauth2_authorization_request_uri = "" # pylint: disable=invalid-name
oauth2_token_request_uri = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
@@ -473,11 +475,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# message.
"tab_id": tab_id,
}
oauth2_config = database.get_oauth2_config()
if oauth2_config is None:
config = database.get_oauth2_config()
if config is None:
raise OAuth2Error("No configuration found for OAuth2")
oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
oauth_url = cls.get_oauth2_authorization_uri(config, state)
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
@@ -2196,6 +2198,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"supports_file_upload": cls.supports_file_upload,
"disable_ssh_tunneling": cls.disable_ssh_tunneling,
"supports_dynamic_catalog": cls.supports_dynamic_catalog,
"supports_oauth2": cls.supports_oauth2,
}
@classmethod

View File

@@ -14,23 +14,28 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import re
from datetime import datetime
from re import Pattern
from typing import Any, Optional, TYPE_CHECKING, TypedDict
from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict
from urllib import parse
import requests
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from flask import current_app
from flask import current_app, g
from flask_babel import gettext as __
from marshmallow import fields, Schema
from requests.auth import HTTPBasicAuth
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import ProgrammingError
from superset.constants import TimeGrain, USER_AGENT
from superset.databases.utils import make_url_safe
@@ -38,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse
from superset.utils import json
if TYPE_CHECKING:
@@ -57,12 +63,29 @@ logger = logging.getLogger(__name__)
class SnowflakeParametersSchema(Schema):
username = fields.Str(required=True)
password = fields.Str(required=True)
account = fields.Str(required=True)
database = fields.Str(required=True)
role = fields.Str(required=True)
warehouse = fields.Str(required=True)
username = fields.Str(
required=False,
allow_none=True,
metadata={"description": "Username"},
)
password = fields.Str(
required=False,
allow_none=True,
metadata={"description": "Password"},
)
oauth2_client = fields.Str(
required=False,
allow_none=True,
metadata={"description": "OAuth2 client information"},
)
account = fields.Str(required=True, metadata={"description": "Account name"})
database = fields.Str(required=True, metadata={"description": "Database name"})
role = fields.Str(
required=False,
allow_none=True,
metadata={"description": "Default role"},
)
warehouse = fields.Str(required=True, metadata={"description": "Warehouse name"})
class SnowflakeParametersType(TypedDict):
@@ -74,6 +97,17 @@ class SnowflakeParametersType(TypedDict):
warehouse: str
SnowflakeParametersKey = Literal[
"username",
"password",
"account",
"database",
"role",
"warehouse",
]
# pylint: disable=too-many-public-methods
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = "snowflake"
engine_name = "Snowflake"
@@ -87,6 +121,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = True
# pylint: disable=invalid-name
encrypted_extra_sensitive_fields = ["$.oauth2_client_info.secret"]
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
@@ -123,19 +160,6 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
),
}
@staticmethod
def get_extra_params(database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("application", USER_AGENT)
return extra
@classmethod
def adjust_engine_params(
cls,
@@ -278,6 +302,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
dict[str, Any]
] = None,
) -> str:
query_keys: list[SnowflakeParametersKey] = ["role", "warehouse"]
query = {key: parameters[key] for key in query_keys if parameters.get(key)}
return str(
URL.create(
"snowflake",
@@ -285,10 +311,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
password=parameters.get("password"),
host=parameters.get("account"),
database=parameters.get("database"),
query={
"role": parameters.get("role"),
"warehouse": parameters.get("warehouse"),
},
query=query,
)
)
@@ -317,12 +340,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
) -> list[SupersetError]:
errors: list[SupersetError] = []
required = {
"warehouse",
"username",
"database",
"account",
"role",
"password",
"database",
"warehouse",
}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
@@ -405,3 +425,131 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
)
connect_args["auth"] = snowflake_auth(**auth_params)
supports_oauth2 = True
oauth2_scope = "refresh_token session:role:PUBLIC"
# pylint: disable=invalid-name
oauth2_authorization_request_uri = None
oauth2_token_request_uri = None
@classmethod
def get_extra_params(cls, database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("application", USER_AGENT)
# populate OAuth2 URLs if not set, since they can be inferred from the account
if oauth2_client_info := extra.get("oauth2_client_info"):
account = database.url_object.host
oauth2_client_info.setdefault(
"authorization_request_uri",
f"https://{account}.snowflakecomputing.com/oauth/authorize",
)
oauth2_client_info.setdefault(
"token_request_uri",
f"https://{account}.snowflakecomputing.com/oauth/token-request",
)
oauth2_client_info.setdefault("scope", cls.oauth2_scope)
return extra
@classmethod
def update_impersonation_config(
cls,
connect_args: dict[str, Any],
uri: str,
username: str | None,
access_token: str | None,
) -> None:
if access_token:
connect_args.update(
{
"authenticator": "oauth",
"token": access_token,
},
)
@classmethod
def get_url_for_impersonation(
cls,
url: URL,
impersonate_user: bool,
username: str | None,
access_token: str | None,
) -> URL:
# force OAuth2
if impersonate_user:
# remove username/password if present
url = url._replace(username="", password="")
# remove hardcoded role so that the one from OAuth2 is used
url = url.difference_update_query(["role"])
return url
@classmethod
def execute(
cls,
cursor: Any,
query: str,
database: Database,
**kwargs: Any,
) -> None:
try:
cursor.execute(query)
except Exception as ex:
if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
cls.start_oauth2_dance(database)
raise cls.get_dbapi_mapped_exception(ex) from ex
@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
return (
g
and g.user
and isinstance(ex, ProgrammingError)
and "User is empty" in str(ex)
)
@classmethod
def get_oauth2_token(
cls,
config: OAuth2ClientConfig,
code: str,
) -> OAuth2TokenResponse:
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
data={
"code": code,
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
},
auth=HTTPBasicAuth(config["id"], config["secret"]),
timeout=timeout,
)
return response.json()
@classmethod
def get_oauth2_fresh_token(
cls,
config: OAuth2ClientConfig,
refresh_token: str,
) -> OAuth2TokenResponse:
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
data={
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
auth=HTTPBasicAuth(config["id"], config["secret"]),
timeout=timeout,
)
return response.json()

View File

@@ -844,6 +844,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@cache_util.memoized_func(
@@ -865,6 +868,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
return self.db_engine_spec.get_catalog_names(self, inspector)
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@property
@@ -1096,6 +1102,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return self.db_engine_spec.get_oauth2_config()
def start_oauth2_dance(self) -> None:
return self.db_engine_spec.start_oauth2_dance(self)
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update)

View File

@@ -238,6 +238,7 @@ def on_error(
def transaction( # pylint: disable=redefined-outer-name
on_error: Callable[..., Any] | None = on_error,
allowed: tuple[type[Exception], ...] = (),
) -> Callable[..., Any]:
"""
Perform a "unit of work".
@@ -246,7 +247,12 @@ def transaction( # pylint: disable=redefined-outer-name
proved rather complicated, likely due to many architectural facets, and thus has
been left for a follow up exercise.
In certain cases it might be desirable to commit even though an exception was
raised. For OAuth2, foe example, we use exceptions as a way to signal the client to
redirect to the login page. In this case, we ignore the exception and commit.
:param on_error: Callback invoked when an exception is caught
:param allowed: Exception types to ignore and not rollback
:see: https://github.com/apache/superset/issues/25108
"""
@@ -259,6 +265,10 @@ def transaction( # pylint: disable=redefined-outer-name
result = func(*args, **kwargs)
db.session.commit() # pylint: disable=consider-using-transaction
return result
except allowed:
db.session.commit() # pylint: disable=consider-using-transaction
raise
except Exception as ex:
db.session.rollback() # pylint: disable=consider-using-transaction