mirror of
https://github.com/apache/superset.git
synced 2026-05-06 16:34:32 +00:00
Compare commits
4 Commits
fix-webpac
...
snowflake-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b1f732082 | ||
|
|
da2bc91a32 | ||
|
|
83accd751a | ||
|
|
d6d2277ed6 |
@@ -24,10 +24,11 @@ import {
|
|||||||
useRef,
|
useRef,
|
||||||
useCallback,
|
useCallback,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
import { styled, SupersetClient, t } from '@superset-ui/core';
|
import { styled, SupersetClient, SupersetError, t } from '@superset-ui/core';
|
||||||
import type { LabeledValue as AntdLabeledValue } from 'antd/lib/select';
|
import type { LabeledValue as AntdLabeledValue } from 'antd/lib/select';
|
||||||
import rison from 'rison';
|
import rison from 'rison';
|
||||||
import { AsyncSelect, Select } from 'src/components';
|
import { AsyncSelect, Select } from 'src/components';
|
||||||
|
import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace';
|
||||||
import Label from 'src/components/Label';
|
import Label from 'src/components/Label';
|
||||||
import { FormLabel } from 'src/components/Form';
|
import { FormLabel } from 'src/components/Form';
|
||||||
import RefreshLabel from 'src/components/RefreshLabel';
|
import RefreshLabel from 'src/components/RefreshLabel';
|
||||||
@@ -154,6 +155,7 @@ export default function DatabaseSelector({
|
|||||||
}: DatabaseSelectorProps) {
|
}: DatabaseSelectorProps) {
|
||||||
const showCatalogSelector = !!db?.allow_multi_catalog;
|
const showCatalogSelector = !!db?.allow_multi_catalog;
|
||||||
const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
|
const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
|
||||||
|
const [errorPayload, setErrorPayload] = useState<SupersetError | null>();
|
||||||
const [currentCatalog, setCurrentCatalog] = useState<
|
const [currentCatalog, setCurrentCatalog] = useState<
|
||||||
CatalogOption | null | undefined
|
CatalogOption | null | undefined
|
||||||
>(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
|
>(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
|
||||||
@@ -267,6 +269,7 @@ export default function DatabaseSelector({
|
|||||||
dbId: currentDb?.value,
|
dbId: currentDb?.value,
|
||||||
catalog: currentCatalog?.value,
|
catalog: currentCatalog?.value,
|
||||||
onSuccess: (schemas, isFetched) => {
|
onSuccess: (schemas, isFetched) => {
|
||||||
|
setErrorPayload(null);
|
||||||
if (schemas.length === 1) {
|
if (schemas.length === 1) {
|
||||||
changeSchema(schemas[0]);
|
changeSchema(schemas[0]);
|
||||||
} else if (
|
} else if (
|
||||||
@@ -279,7 +282,13 @@ export default function DatabaseSelector({
|
|||||||
addSuccessToast('List refreshed');
|
addSuccessToast('List refreshed');
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onError: () => handleError(t('There was an error loading the schemas')),
|
onError: error => {
|
||||||
|
if (error?.errors) {
|
||||||
|
setErrorPayload(error?.errors?.[0]);
|
||||||
|
} else {
|
||||||
|
handleError(t('There was an error loading the schemas'));
|
||||||
|
}
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
|
const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
|
||||||
@@ -299,6 +308,7 @@ export default function DatabaseSelector({
|
|||||||
} = useCatalogs({
|
} = useCatalogs({
|
||||||
dbId: showCatalogSelector ? currentDb?.value : undefined,
|
dbId: showCatalogSelector ? currentDb?.value : undefined,
|
||||||
onSuccess: (catalogs, isFetched) => {
|
onSuccess: (catalogs, isFetched) => {
|
||||||
|
setErrorPayload(null);
|
||||||
if (!showCatalogSelector) {
|
if (!showCatalogSelector) {
|
||||||
changeCatalog(null);
|
changeCatalog(null);
|
||||||
} else if (catalogs.length === 1) {
|
} else if (catalogs.length === 1) {
|
||||||
@@ -315,9 +325,13 @@ export default function DatabaseSelector({
|
|||||||
addSuccessToast('List refreshed');
|
addSuccessToast('List refreshed');
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: error => {
|
||||||
if (showCatalogSelector) {
|
if (showCatalogSelector) {
|
||||||
handleError(t('There was an error loading the catalogs'));
|
if (error?.errors) {
|
||||||
|
setErrorPayload(error?.errors?.[0]);
|
||||||
|
} else {
|
||||||
|
handleError(t('There was an error loading the catalogs'));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@@ -423,9 +437,16 @@ export default function DatabaseSelector({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function renderError() {
|
||||||
|
return errorPayload ? (
|
||||||
|
<ErrorMessageWithStackTrace error={errorPayload} source="crud" />
|
||||||
|
) : null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<DatabaseSelectorWrapper data-test="DatabaseSelector">
|
<DatabaseSelectorWrapper data-test="DatabaseSelector">
|
||||||
{renderDatabaseSelect()}
|
{renderDatabaseSelect()}
|
||||||
|
{renderError()}
|
||||||
{showCatalogSelector && renderCatalogSelect()}
|
{showCatalogSelector && renderCatalogSelect()}
|
||||||
{renderSchemaSelect()}
|
{renderSchemaSelect()}
|
||||||
</DatabaseSelectorWrapper>
|
</DatabaseSelectorWrapper>
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ function OAuth2RedirectMessage({
|
|||||||
>
|
>
|
||||||
provide authorization
|
provide authorization
|
||||||
</a>{' '}
|
</a>{' '}
|
||||||
in order to run this query.
|
in order to run this operation.
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -161,23 +161,20 @@ export const httpPathField = ({
|
|||||||
getValidation,
|
getValidation,
|
||||||
validationErrors,
|
validationErrors,
|
||||||
db,
|
db,
|
||||||
}: FieldPropTypes) => {
|
}: FieldPropTypes) => (
|
||||||
console.error(db);
|
<ValidatedInput
|
||||||
return (
|
id="http_path_field"
|
||||||
<ValidatedInput
|
name="http_path_field"
|
||||||
id="http_path_field"
|
required={required}
|
||||||
name="http_path_field"
|
value={db?.parameters?.http_path_field}
|
||||||
required={required}
|
validationMethods={{ onBlur: getValidation }}
|
||||||
value={db?.parameters?.http_path_field}
|
errorMessage={validationErrors?.http_path}
|
||||||
validationMethods={{ onBlur: getValidation }}
|
placeholder={t('e.g. sql/protocolv1/o/12345')}
|
||||||
errorMessage={validationErrors?.http_path}
|
label="HTTP Path"
|
||||||
placeholder={t('e.g. sql/protocolv1/o/12345')}
|
onChange={changeMethods.onParametersChange}
|
||||||
label="HTTP Path"
|
helpText={t('Copy the name of the HTTP Path of your cluster.')}
|
||||||
onChange={changeMethods.onParametersChange}
|
/>
|
||||||
helpText={t('Copy the name of the HTTP Path of your cluster.')}
|
);
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
export const usernameField = ({
|
export const usernameField = ({
|
||||||
required,
|
required,
|
||||||
changeMethods,
|
changeMethods,
|
||||||
|
|||||||
@@ -0,0 +1,109 @@
|
|||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useState } from 'react';
|
||||||
|
|
||||||
|
import Collapse from 'src/components/Collapse';
|
||||||
|
import { Input } from 'src/components/Input';
|
||||||
|
import { FormItem } from 'src/components/Form';
|
||||||
|
import { FieldPropTypes } from '../../types';
|
||||||
|
|
||||||
|
interface OAuth2ClientInfo {
|
||||||
|
id: string;
|
||||||
|
secret: string;
|
||||||
|
authorization_request_uri: string;
|
||||||
|
token_request_uri: string;
|
||||||
|
scope: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
|
||||||
|
const encryptedExtra = JSON.parse(db?.masked_encrypted_extra || '{}');
|
||||||
|
const [oauth2ClientInfo, setOauth2ClientInfo] = useState<OAuth2ClientInfo>({
|
||||||
|
id: encryptedExtra.oauth2_client_info?.id || '',
|
||||||
|
secret: encryptedExtra.oauth2_client_info?.secret || '',
|
||||||
|
authorization_request_uri:
|
||||||
|
encryptedExtra.oauth2_client_info?.authorization_request_uri || '',
|
||||||
|
token_request_uri:
|
||||||
|
encryptedExtra.oauth2_client_info?.token_request_uri || '',
|
||||||
|
scope: encryptedExtra.oauth2_client_info?.scope || '',
|
||||||
|
});
|
||||||
|
|
||||||
|
if (db?.engine_information?.supports_oauth2 !== true) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleChange = (key: any) => (e: any) => {
|
||||||
|
const updatedInfo = {
|
||||||
|
...oauth2ClientInfo,
|
||||||
|
[key]: e.target.value,
|
||||||
|
};
|
||||||
|
|
||||||
|
setOauth2ClientInfo(updatedInfo);
|
||||||
|
|
||||||
|
const event = {
|
||||||
|
target: {
|
||||||
|
name: 'oauth2_client_info',
|
||||||
|
value: updatedInfo,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
changeMethods.onEncryptedExtraInputChange(event);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Collapse>
|
||||||
|
<Collapse.Panel header="OAuth2 client information" key="1">
|
||||||
|
<FormItem label="Client ID">
|
||||||
|
<Input
|
||||||
|
placeholder="Enter your Client ID"
|
||||||
|
value={oauth2ClientInfo.id}
|
||||||
|
onChange={handleChange('id')}
|
||||||
|
/>
|
||||||
|
</FormItem>
|
||||||
|
<FormItem label="Client Secret">
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
placeholder="Enter your Client Secret"
|
||||||
|
value={oauth2ClientInfo.secret}
|
||||||
|
onChange={handleChange('secret')}
|
||||||
|
/>
|
||||||
|
</FormItem>
|
||||||
|
<FormItem label="Authorization Request URI">
|
||||||
|
<Input
|
||||||
|
placeholder="https://"
|
||||||
|
value={oauth2ClientInfo.authorization_request_uri}
|
||||||
|
onChange={handleChange('authorization_request_uri')}
|
||||||
|
/>
|
||||||
|
</FormItem>
|
||||||
|
<FormItem label="Token Request URI">
|
||||||
|
<Input
|
||||||
|
placeholder="https://"
|
||||||
|
value={oauth2ClientInfo.token_request_uri}
|
||||||
|
onChange={handleChange('token_request_uri')}
|
||||||
|
/>
|
||||||
|
</FormItem>
|
||||||
|
<FormItem label="Scope">
|
||||||
|
<Input
|
||||||
|
value={oauth2ClientInfo.scope}
|
||||||
|
onChange={handleChange('scope')}
|
||||||
|
/>
|
||||||
|
</FormItem>
|
||||||
|
</Collapse.Panel>
|
||||||
|
</Collapse>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -22,16 +22,19 @@ import { FieldPropTypes } from '../../types';
|
|||||||
|
|
||||||
const FIELD_TEXT_MAP = {
|
const FIELD_TEXT_MAP = {
|
||||||
account: {
|
account: {
|
||||||
|
label: 'Account',
|
||||||
helpText: t(
|
helpText: t(
|
||||||
'Copy the identifier of the account you are trying to connect to.',
|
'Copy the identifier of the account you are trying to connect to.',
|
||||||
),
|
),
|
||||||
placeholder: t('e.g. xy12345.us-east-2.aws'),
|
placeholder: t('e.g. xy12345.us-east-2.aws'),
|
||||||
},
|
},
|
||||||
warehouse: {
|
warehouse: {
|
||||||
|
label: 'Warehouse',
|
||||||
placeholder: t('e.g. compute_wh'),
|
placeholder: t('e.g. compute_wh'),
|
||||||
className: 'form-group-w-50',
|
className: 'form-group-w-50',
|
||||||
},
|
},
|
||||||
role: {
|
role: {
|
||||||
|
label: 'Default role',
|
||||||
placeholder: t('e.g. AccountAdmin'),
|
placeholder: t('e.g. AccountAdmin'),
|
||||||
className: 'form-group-w-50',
|
className: 'form-group-w-50',
|
||||||
},
|
},
|
||||||
@@ -54,7 +57,7 @@ export const validatedInputField = ({
|
|||||||
errorMessage={validationErrors?.[field]}
|
errorMessage={validationErrors?.[field]}
|
||||||
placeholder={FIELD_TEXT_MAP[field].placeholder}
|
placeholder={FIELD_TEXT_MAP[field].placeholder}
|
||||||
helpText={FIELD_TEXT_MAP[field].helpText}
|
helpText={FIELD_TEXT_MAP[field].helpText}
|
||||||
label={field}
|
label={FIELD_TEXT_MAP[field].label || field}
|
||||||
onChange={changeMethods.onParametersChange}
|
onChange={changeMethods.onParametersChange}
|
||||||
className={FIELD_TEXT_MAP[field].className || field}
|
className={FIELD_TEXT_MAP[field].className || field}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ import {
|
|||||||
} from './CommonParameters';
|
} from './CommonParameters';
|
||||||
import { validatedInputField } from './ValidatedInputField';
|
import { validatedInputField } from './ValidatedInputField';
|
||||||
import { EncryptedField } from './EncryptedField';
|
import { EncryptedField } from './EncryptedField';
|
||||||
|
import { OAuth2ClientField } from './OAuth2ClientField';
|
||||||
import { TableCatalog } from './TableCatalog';
|
import { TableCatalog } from './TableCatalog';
|
||||||
import { formScrollableStyles, validatedFormStyles } from '../styles';
|
import { formScrollableStyles, validatedFormStyles } from '../styles';
|
||||||
import { DatabaseForm, DatabaseObject } from '../../types';
|
import { DatabaseForm, DatabaseObject } from '../../types';
|
||||||
@@ -67,6 +68,7 @@ export const FormFieldOrder = [
|
|||||||
'warehouse',
|
'warehouse',
|
||||||
'role',
|
'role',
|
||||||
'ssh',
|
'ssh',
|
||||||
|
'oauth2_client',
|
||||||
];
|
];
|
||||||
|
|
||||||
const extensionsRegistry = getExtensionsRegistry();
|
const extensionsRegistry = getExtensionsRegistry();
|
||||||
@@ -84,6 +86,7 @@ const FORM_FIELD_MAP = {
|
|||||||
default_schema: defaultSchemaField,
|
default_schema: defaultSchemaField,
|
||||||
username: usernameField,
|
username: usernameField,
|
||||||
password: passwordField,
|
password: passwordField,
|
||||||
|
oauth2_client: OAuth2ClientField,
|
||||||
access_token: accessTokenField,
|
access_token: accessTokenField,
|
||||||
database_name: displayField,
|
database_name: displayField,
|
||||||
query: queryField,
|
query: queryField,
|
||||||
@@ -118,6 +121,9 @@ interface DatabaseConnectionFormProps {
|
|||||||
onExtraInputChange: (
|
onExtraInputChange: (
|
||||||
event: FormEvent<InputProps> | { target: HTMLInputElement },
|
event: FormEvent<InputProps> | { target: HTMLInputElement },
|
||||||
) => void;
|
) => void;
|
||||||
|
onEncryptedExtraInputChange: (
|
||||||
|
event: FormEvent<InputProps> | { target: HTMLInputElement },
|
||||||
|
) => void;
|
||||||
onAddTableCatalog: () => void;
|
onAddTableCatalog: () => void;
|
||||||
onRemoveTableCatalog: (idx: number) => void;
|
onRemoveTableCatalog: (idx: number) => void;
|
||||||
validationErrors: JsonObject | null;
|
validationErrors: JsonObject | null;
|
||||||
@@ -136,6 +142,7 @@ const DatabaseConnectionForm = ({
|
|||||||
onAddTableCatalog,
|
onAddTableCatalog,
|
||||||
onChange,
|
onChange,
|
||||||
onExtraInputChange,
|
onExtraInputChange,
|
||||||
|
onEncryptedExtraInputChange,
|
||||||
onParametersChange,
|
onParametersChange,
|
||||||
onParametersUploadFileChange,
|
onParametersUploadFileChange,
|
||||||
onQueryChange,
|
onQueryChange,
|
||||||
@@ -171,6 +178,7 @@ const DatabaseConnectionForm = ({
|
|||||||
onAddTableCatalog,
|
onAddTableCatalog,
|
||||||
onRemoveTableCatalog,
|
onRemoveTableCatalog,
|
||||||
onExtraInputChange,
|
onExtraInputChange,
|
||||||
|
onEncryptedExtraInputChange,
|
||||||
},
|
},
|
||||||
validationErrors,
|
validationErrors,
|
||||||
getValidation,
|
getValidation,
|
||||||
|
|||||||
@@ -154,6 +154,7 @@ export enum ActionType {
|
|||||||
EditorChange,
|
EditorChange,
|
||||||
ExtraEditorChange,
|
ExtraEditorChange,
|
||||||
ExtraInputChange,
|
ExtraInputChange,
|
||||||
|
EncryptedExtraInputChange,
|
||||||
Fetched,
|
Fetched,
|
||||||
InputChange,
|
InputChange,
|
||||||
ParametersChange,
|
ParametersChange,
|
||||||
@@ -185,6 +186,7 @@ export type DBReducerActionType =
|
|||||||
type:
|
type:
|
||||||
| ActionType.ExtraEditorChange
|
| ActionType.ExtraEditorChange
|
||||||
| ActionType.ExtraInputChange
|
| ActionType.ExtraInputChange
|
||||||
|
| ActionType.EncryptedExtraInputChange
|
||||||
| ActionType.TextChange
|
| ActionType.TextChange
|
||||||
| ActionType.QueryChange
|
| ActionType.QueryChange
|
||||||
| ActionType.InputChange
|
| ActionType.InputChange
|
||||||
@@ -269,6 +271,14 @@ export function dbReducer(
|
|||||||
[action.payload.name]: actionPayloadJson,
|
[action.payload.name]: actionPayloadJson,
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
case ActionType.EncryptedExtraInputChange:
|
||||||
|
return {
|
||||||
|
...trimmedState,
|
||||||
|
masked_encrypted_extra: JSON.stringify({
|
||||||
|
...JSON.parse(trimmedState.masked_encrypted_extra || '{}'),
|
||||||
|
[action.payload.name]: action.payload.value,
|
||||||
|
}),
|
||||||
|
};
|
||||||
case ActionType.ExtraInputChange:
|
case ActionType.ExtraInputChange:
|
||||||
// "extra" payload in state is a string
|
// "extra" payload in state is a string
|
||||||
if (
|
if (
|
||||||
@@ -1656,6 +1666,16 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
|||||||
value: target.value,
|
value: target.value,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
onEncryptedExtraInputChange={({
|
||||||
|
target,
|
||||||
|
}: {
|
||||||
|
target: HTMLInputElement;
|
||||||
|
}) =>
|
||||||
|
onChange(ActionType.EncryptedExtraInputChange, {
|
||||||
|
name: target.name,
|
||||||
|
value: target.value,
|
||||||
|
})
|
||||||
|
}
|
||||||
onRemoveTableCatalog={(idx: number) => {
|
onRemoveTableCatalog={(idx: number) => {
|
||||||
setDB({
|
setDB({
|
||||||
type: ActionType.RemoveTableCatalogSheet,
|
type: ActionType.RemoveTableCatalogSheet,
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ export type DatabaseObject = {
|
|||||||
supports_file_upload?: boolean;
|
supports_file_upload?: boolean;
|
||||||
disable_ssh_tunneling?: boolean;
|
disable_ssh_tunneling?: boolean;
|
||||||
supports_dynamic_catalog?: boolean;
|
supports_dynamic_catalog?: boolean;
|
||||||
|
supports_oauth2?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
// SSH Tunnel information
|
// SSH Tunnel information
|
||||||
@@ -301,6 +302,7 @@ export interface FieldPropTypes {
|
|||||||
onRemoveTableCatalog: (idx: number) => void;
|
onRemoveTableCatalog: (idx: number) => void;
|
||||||
} & {
|
} & {
|
||||||
onExtraInputChange: (value: any) => void;
|
onExtraInputChange: (value: any) => void;
|
||||||
|
onEncryptedExtraInputChange: (value: any) => void;
|
||||||
onSSHTunnelParametersChange: CustomEventHandlerType;
|
onSSHTunnelParametersChange: CustomEventHandlerType;
|
||||||
};
|
};
|
||||||
validationErrors: JsonObject | null;
|
validationErrors: JsonObject | null;
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
import { useCallback, useEffect } from 'react';
|
import { useCallback, useEffect } from 'react';
|
||||||
|
import { ClientErrorObject } from '@superset-ui/core';
|
||||||
import useEffectEvent from 'src/hooks/useEffectEvent';
|
import useEffectEvent from 'src/hooks/useEffectEvent';
|
||||||
import { api, JsonResponse } from './queryApi';
|
import { api, JsonResponse } from './queryApi';
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ export type FetchCatalogsQueryParams = {
|
|||||||
dbId?: string | number;
|
dbId?: string | number;
|
||||||
forceRefresh: boolean;
|
forceRefresh: boolean;
|
||||||
onSuccess?: (data: CatalogOption[], isRefetched: boolean) => void;
|
onSuccess?: (data: CatalogOption[], isRefetched: boolean) => void;
|
||||||
onError?: () => void;
|
onError?: (error: ClientErrorObject) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>;
|
type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>;
|
||||||
@@ -77,6 +78,12 @@ export function useCatalogs(options: Params) {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (result.isError) {
|
||||||
|
onError?.(result.error);
|
||||||
|
}
|
||||||
|
}, [result.isError, result.error, onError]);
|
||||||
|
|
||||||
const fetchData = useEffectEvent(
|
const fetchData = useEffectEvent(
|
||||||
(dbId: FetchCatalogsQueryParams['dbId'], forceRefresh = false) => {
|
(dbId: FetchCatalogsQueryParams['dbId'], forceRefresh = false) => {
|
||||||
if (dbId && (!result.currentData || forceRefresh)) {
|
if (dbId && (!result.currentData || forceRefresh)) {
|
||||||
@@ -85,7 +92,7 @@ export function useCatalogs(options: Params) {
|
|||||||
onSuccess?.(data || EMPTY_CATALOGS, forceRefresh);
|
onSuccess?.(data || EMPTY_CATALOGS, forceRefresh);
|
||||||
}
|
}
|
||||||
if (isError) {
|
if (isError) {
|
||||||
onError?.();
|
onError?.(result.error);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ export const supersetClientQuery: BaseQueryFn<
|
|||||||
getClientErrorObject(response).then(errorObj => ({
|
getClientErrorObject(response).then(errorObj => ({
|
||||||
error: {
|
error: {
|
||||||
error: errorObj?.message || errorObj?.error || response.statusText,
|
error: errorObj?.message || errorObj?.error || response.statusText,
|
||||||
|
errors: errorObj?.errors || [], // used by <ErrorMessageWithStackTrace />
|
||||||
status: response.status,
|
status: response.status,
|
||||||
},
|
},
|
||||||
})),
|
})),
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
import { useCallback, useEffect } from 'react';
|
import { useCallback, useEffect } from 'react';
|
||||||
|
import { ClientErrorObject } from '@superset-ui/core';
|
||||||
import useEffectEvent from 'src/hooks/useEffectEvent';
|
import useEffectEvent from 'src/hooks/useEffectEvent';
|
||||||
import { api, JsonResponse } from './queryApi';
|
import { api, JsonResponse } from './queryApi';
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@ export type FetchSchemasQueryParams = {
|
|||||||
catalog?: string;
|
catalog?: string;
|
||||||
forceRefresh: boolean;
|
forceRefresh: boolean;
|
||||||
onSuccess?: (data: SchemaOption[], isRefetched: boolean) => void;
|
onSuccess?: (data: SchemaOption[], isRefetched: boolean) => void;
|
||||||
onError?: () => void;
|
onError?: (error: ClientErrorObject) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>;
|
type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>;
|
||||||
@@ -81,6 +82,12 @@ export function useSchemas(options: Params) {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (result.isError) {
|
||||||
|
onError?.(result.error);
|
||||||
|
}
|
||||||
|
}, [result.isError, result.error, onError]);
|
||||||
|
|
||||||
const fetchData = useEffectEvent(
|
const fetchData = useEffectEvent(
|
||||||
(
|
(
|
||||||
dbId: FetchSchemasQueryParams['dbId'],
|
dbId: FetchSchemasQueryParams['dbId'],
|
||||||
@@ -94,7 +101,7 @@ export function useSchemas(options: Params) {
|
|||||||
onSuccess?.(data || EMPTY_SCHEMAS, forceRefresh);
|
onSuccess?.(data || EMPTY_SCHEMAS, forceRefresh);
|
||||||
}
|
}
|
||||||
if (isError) {
|
if (isError) {
|
||||||
onError?.();
|
onError?.(result.error);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from superset.commands.database.test_connection import TestConnectionDatabaseCom
|
|||||||
from superset.daos.database import DatabaseDAO
|
from superset.daos.database import DatabaseDAO
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.db_engine_specs.base import GenericDBException
|
from superset.db_engine_specs.base import GenericDBException
|
||||||
from superset.exceptions import SupersetErrorsException
|
from superset.exceptions import OAuth2RedirectError, SupersetErrorsException
|
||||||
from superset.extensions import event_logger, security_manager
|
from superset.extensions import event_logger, security_manager
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.utils.decorators import on_error, transaction
|
from superset.utils.decorators import on_error, transaction
|
||||||
@@ -55,13 +55,21 @@ class CreateDatabaseCommand(BaseCommand):
|
|||||||
def __init__(self, data: dict[str, Any]):
|
def __init__(self, data: dict[str, Any]):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
|
@transaction(
|
||||||
|
on_error=partial(on_error, reraise=DatabaseCreateFailedError),
|
||||||
|
allowed=(OAuth2RedirectError,),
|
||||||
|
)
|
||||||
def run(self) -> Model:
|
def run(self) -> Model:
|
||||||
self.validate()
|
self.validate()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test connection before starting create transaction
|
# Test connection before starting create transaction
|
||||||
TestConnectionDatabaseCommand(self._properties).run()
|
TestConnectionDatabaseCommand(self._properties).run()
|
||||||
|
except OAuth2RedirectError:
|
||||||
|
# If we can't connect to the database due to an OAuth2 error we can still
|
||||||
|
# save the database. Later, the user can sync permissions when setting up
|
||||||
|
# data access rules.
|
||||||
|
return self._create_database()
|
||||||
except (
|
except (
|
||||||
SupersetErrorsException,
|
SupersetErrorsException,
|
||||||
SSHTunnelingNotEnabledError,
|
SSHTunnelingNotEnabledError,
|
||||||
@@ -80,12 +88,6 @@ class CreateDatabaseCommand(BaseCommand):
|
|||||||
)
|
)
|
||||||
raise DatabaseConnectionFailedError() from ex
|
raise DatabaseConnectionFailedError() from ex
|
||||||
|
|
||||||
# when creating a new database we don't need to unmask encrypted extra
|
|
||||||
self._properties["encrypted_extra"] = self._properties.pop(
|
|
||||||
"masked_encrypted_extra",
|
|
||||||
"{}",
|
|
||||||
)
|
|
||||||
|
|
||||||
ssh_tunnel: Optional[SSHTunnel] = None
|
ssh_tunnel: Optional[SSHTunnel] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -195,6 +197,12 @@ class CreateDatabaseCommand(BaseCommand):
|
|||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
def _create_database(self) -> Database:
|
def _create_database(self) -> Database:
|
||||||
|
# when creating a new database we don't need to unmask encrypted extra
|
||||||
|
self._properties["encrypted_extra"] = self._properties.pop(
|
||||||
|
"masked_encrypted_extra",
|
||||||
|
"{}",
|
||||||
|
)
|
||||||
|
|
||||||
database = DatabaseDAO.create(attributes=self._properties)
|
database = DatabaseDAO.create(attributes=self._properties)
|
||||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||||
return database
|
return database
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from superset.databases.ssh_tunnel.models import SSHTunnel
|
|||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
from superset.errors import ErrorLevel, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetErrorType
|
||||||
from superset.exceptions import (
|
from superset.exceptions import (
|
||||||
|
OAuth2RedirectError,
|
||||||
SupersetErrorsException,
|
SupersetErrorsException,
|
||||||
SupersetSecurityException,
|
SupersetSecurityException,
|
||||||
SupersetTimeoutException,
|
SupersetTimeoutException,
|
||||||
@@ -162,6 +163,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||||||
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
||||||
) from ex
|
) from ex
|
||||||
except Exception as ex: # pylint: disable=broad-except
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
|
if (
|
||||||
|
database.is_oauth2_enabled()
|
||||||
|
and database.db_engine_spec.needs_oauth2(ex)
|
||||||
|
):
|
||||||
|
database.start_oauth2_dance()
|
||||||
|
|
||||||
alive = False
|
alive = False
|
||||||
# So we stop losing the original message if any
|
# So we stop losing the original message if any
|
||||||
ex_str = str(ex)
|
ex_str = str(ex)
|
||||||
@@ -197,6 +204,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||||||
# check for custom errors (wrong username, wrong password, etc)
|
# check for custom errors (wrong username, wrong password, etc)
|
||||||
errors = database.db_engine_spec.extract_errors(ex, self._context)
|
errors = database.db_engine_spec.extract_errors(ex, self._context)
|
||||||
raise SupersetErrorsException(errors) from ex
|
raise SupersetErrorsException(errors) from ex
|
||||||
|
except OAuth2RedirectError:
|
||||||
|
raise
|
||||||
except SupersetSecurityException as ex:
|
except SupersetSecurityException as ex:
|
||||||
event_logger.log_with_context(
|
event_logger.log_with_context(
|
||||||
action=get_log_connection_action(
|
action=get_log_connection_action(
|
||||||
@@ -205,23 +214,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||||||
engine=database.db_engine_spec.__name__,
|
engine=database.db_engine_spec.__name__,
|
||||||
)
|
)
|
||||||
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
|
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
|
||||||
except SupersetTimeoutException as ex:
|
except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex:
|
||||||
event_logger.log_with_context(
|
event_logger.log_with_context(
|
||||||
action=get_log_connection_action(
|
action=get_log_connection_action(
|
||||||
"test_connection_error", ssh_tunnel, ex
|
"test_connection_error", ssh_tunnel, ex
|
||||||
),
|
),
|
||||||
engine=database.db_engine_spec.__name__,
|
engine=database.db_engine_spec.__name__,
|
||||||
)
|
)
|
||||||
# bubble up the exception to return a 408
|
|
||||||
raise
|
|
||||||
except SSHTunnelingNotEnabledError as ex:
|
|
||||||
event_logger.log_with_context(
|
|
||||||
action=get_log_connection_action(
|
|
||||||
"test_connection_error", ssh_tunnel, ex
|
|
||||||
),
|
|
||||||
engine=database.db_engine_spec.__name__,
|
|
||||||
)
|
|
||||||
# bubble up the exception to return a 400
|
|
||||||
raise
|
raise
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
event_logger.log_with_context(
|
event_logger.log_with_context(
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from superset.daos.database import DatabaseDAO
|
|||||||
from superset.daos.dataset import DatasetDAO
|
from superset.daos.dataset import DatasetDAO
|
||||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.db_engine_specs.base import GenericDBException
|
from superset.db_engine_specs.base import GenericDBException
|
||||||
|
from superset.exceptions import OAuth2RedirectError
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.utils.decorators import on_error, transaction
|
from superset.utils.decorators import on_error, transaction
|
||||||
|
|
||||||
@@ -80,7 +81,10 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||||||
database = DatabaseDAO.update(self._model, self._properties)
|
database = DatabaseDAO.update(self._model, self._properties)
|
||||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
try:
|
||||||
|
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||||
|
except OAuth2RedirectError:
|
||||||
|
pass
|
||||||
|
|
||||||
return database
|
return database
|
||||||
|
|
||||||
@@ -123,6 +127,8 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||||||
force=True,
|
force=True,
|
||||||
ssh_tunnel=ssh_tunnel,
|
ssh_tunnel=ssh_tunnel,
|
||||||
)
|
)
|
||||||
|
except OAuth2RedirectError:
|
||||||
|
raise
|
||||||
except GenericDBException as ex:
|
except GenericDBException as ex:
|
||||||
raise DatabaseConnectionFailedError() from ex
|
raise DatabaseConnectionFailedError() from ex
|
||||||
|
|
||||||
|
|||||||
@@ -107,6 +107,15 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||||||
with closing(engine.raw_connection()) as conn:
|
with closing(engine.raw_connection()) as conn:
|
||||||
alive = engine.dialect.do_ping(conn)
|
alive = engine.dialect.do_ping(conn)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
# if the connection failed because OAuth2 is needed, we can save the
|
||||||
|
# database and trigger the OAuth2 flow whenever a user tries to run a
|
||||||
|
# query.
|
||||||
|
if (
|
||||||
|
database.is_oauth2_enabled()
|
||||||
|
and database.db_engine_spec.needs_oauth2(ex)
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
url = make_url_safe(sqlalchemy_uri)
|
url = make_url_safe(sqlalchemy_uri)
|
||||||
context = {
|
context = {
|
||||||
"hostname": url.host,
|
"hostname": url.host,
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ from superset.exceptions import (
|
|||||||
DatabaseNotFoundException,
|
DatabaseNotFoundException,
|
||||||
InvalidPayloadSchemaError,
|
InvalidPayloadSchemaError,
|
||||||
OAuth2Error,
|
OAuth2Error,
|
||||||
|
OAuth2RedirectError,
|
||||||
SupersetErrorsException,
|
SupersetErrorsException,
|
||||||
SupersetException,
|
SupersetException,
|
||||||
SupersetSecurityException,
|
SupersetSecurityException,
|
||||||
@@ -398,7 +399,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
|
|
||||||
@expose("/", methods=("POST",))
|
@expose("/", methods=("POST",))
|
||||||
@protect()
|
@protect()
|
||||||
@safe
|
|
||||||
@statsd_metrics
|
@statsd_metrics
|
||||||
@event_logger.log_this_with_context(
|
@event_logger.log_this_with_context(
|
||||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
|
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
|
||||||
@@ -462,6 +462,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel)
|
item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel)
|
||||||
|
|
||||||
return self.response(201, id=new_model.id, result=item)
|
return self.response(201, id=new_model.id, result=item)
|
||||||
|
except OAuth2RedirectError:
|
||||||
|
raise
|
||||||
except DatabaseInvalidError as ex:
|
except DatabaseInvalidError as ex:
|
||||||
return self.response_422(message=ex.normalized_messages())
|
return self.response_422(message=ex.normalized_messages())
|
||||||
except DatabaseConnectionFailedError as ex:
|
except DatabaseConnectionFailedError as ex:
|
||||||
@@ -621,7 +623,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
|
|
||||||
@expose("/<int:pk>/catalogs/")
|
@expose("/<int:pk>/catalogs/")
|
||||||
@protect()
|
@protect()
|
||||||
@safe
|
|
||||||
@rison(database_catalogs_query_schema)
|
@rison(database_catalogs_query_schema)
|
||||||
@statsd_metrics
|
@statsd_metrics
|
||||||
@event_logger.log_this_with_context(
|
@event_logger.log_this_with_context(
|
||||||
@@ -680,12 +681,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
500,
|
500,
|
||||||
message="There was an error connecting to the database",
|
message="There was an error connecting to the database",
|
||||||
)
|
)
|
||||||
|
except OAuth2RedirectError:
|
||||||
|
raise
|
||||||
except SupersetException as ex:
|
except SupersetException as ex:
|
||||||
return self.response(ex.status, message=ex.message)
|
return self.response(ex.status, message=ex.message)
|
||||||
|
|
||||||
@expose("/<int:pk>/schemas/")
|
@expose("/<int:pk>/schemas/")
|
||||||
@protect()
|
@protect()
|
||||||
@safe
|
|
||||||
@rison(database_schemas_query_schema)
|
@rison(database_schemas_query_schema)
|
||||||
@statsd_metrics
|
@statsd_metrics
|
||||||
@event_logger.log_this_with_context(
|
@event_logger.log_this_with_context(
|
||||||
@@ -746,6 +748,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
return self.response(
|
return self.response(
|
||||||
500, message="There was an error connecting to the database"
|
500, message="There was an error connecting to the database"
|
||||||
)
|
)
|
||||||
|
except OAuth2RedirectError:
|
||||||
|
raise
|
||||||
except SupersetException as ex:
|
except SupersetException as ex:
|
||||||
return self.response(ex.status, message=ex.message)
|
return self.response(ex.status, message=ex.message)
|
||||||
|
|
||||||
@@ -1894,9 +1898,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
@protect()
|
@protect()
|
||||||
@statsd_metrics
|
@statsd_metrics
|
||||||
@event_logger.log_this_with_context(
|
@event_logger.log_this_with_context(
|
||||||
action=lambda self,
|
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.columnar_upload",
|
||||||
*args,
|
|
||||||
**kwargs: f"{self.__class__.__name__}.columnar_upload",
|
|
||||||
log_to_statsd=False,
|
log_to_statsd=False,
|
||||||
)
|
)
|
||||||
@requires_form_data
|
@requires_form_data
|
||||||
|
|||||||
@@ -136,7 +136,9 @@ builtin_time_grains: dict[str | None, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
|
class TimestampExpression(
|
||||||
|
ColumnClause
|
||||||
|
): # pylint: disable=abstract-method, too-many-ancestors
|
||||||
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
|
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
|
||||||
"""Sqlalchemy class that can be used to render native column elements respecting
|
"""Sqlalchemy class that can be used to render native column elements respecting
|
||||||
engine-specific quoting rules as part of a string-based expression.
|
engine-specific quoting rules as part of a string-based expression.
|
||||||
@@ -394,9 +396,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
max_column_name_length: int | None = None
|
max_column_name_length: int | None = None
|
||||||
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
||||||
run_multiple_statements_as_one = False
|
run_multiple_statements_as_one = False
|
||||||
custom_errors: dict[
|
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
|
||||||
Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
|
{}
|
||||||
] = {}
|
)
|
||||||
|
|
||||||
# Whether the engine supports file uploads
|
# Whether the engine supports file uploads
|
||||||
# if True, database will be listed as option in the upload file form
|
# if True, database will be listed as option in the upload file form
|
||||||
@@ -423,8 +425,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
# the user impersonation methods to handle personal tokens.
|
# the user impersonation methods to handle personal tokens.
|
||||||
supports_oauth2 = False
|
supports_oauth2 = False
|
||||||
oauth2_scope = ""
|
oauth2_scope = ""
|
||||||
oauth2_authorization_request_uri = "" # pylint: disable=invalid-name
|
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
|
||||||
oauth2_token_request_uri = ""
|
oauth2_token_request_uri: str | None = None
|
||||||
|
|
||||||
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||||
oauth2_exception = OAuth2RedirectError
|
oauth2_exception = OAuth2RedirectError
|
||||||
@@ -473,11 +475,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
# message.
|
# message.
|
||||||
"tab_id": tab_id,
|
"tab_id": tab_id,
|
||||||
}
|
}
|
||||||
oauth2_config = database.get_oauth2_config()
|
config = database.get_oauth2_config()
|
||||||
if oauth2_config is None:
|
if config is None:
|
||||||
raise OAuth2Error("No configuration found for OAuth2")
|
raise OAuth2Error("No configuration found for OAuth2")
|
||||||
|
|
||||||
oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
|
oauth_url = cls.get_oauth2_authorization_uri(config, state)
|
||||||
|
|
||||||
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
|
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
|
||||||
|
|
||||||
@@ -2196,6 +2198,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
"supports_file_upload": cls.supports_file_upload,
|
"supports_file_upload": cls.supports_file_upload,
|
||||||
"disable_ssh_tunneling": cls.disable_ssh_tunneling,
|
"disable_ssh_tunneling": cls.disable_ssh_tunneling,
|
||||||
"supports_dynamic_catalog": cls.supports_dynamic_catalog,
|
"supports_dynamic_catalog": cls.supports_dynamic_catalog,
|
||||||
|
"supports_oauth2": cls.supports_oauth2,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -14,23 +14,28 @@
|
|||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from re import Pattern
|
from re import Pattern
|
||||||
from typing import Any, Optional, TYPE_CHECKING, TypedDict
|
from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
|
import requests
|
||||||
from apispec import APISpec
|
from apispec import APISpec
|
||||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
from flask import current_app
|
from flask import current_app, g
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
from marshmallow import fields, Schema
|
from marshmallow import fields, Schema
|
||||||
|
from requests.auth import HTTPBasicAuth
|
||||||
from sqlalchemy import types
|
from sqlalchemy import types
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.engine.url import URL
|
from sqlalchemy.engine.url import URL
|
||||||
|
from sqlalchemy.exc import ProgrammingError
|
||||||
|
|
||||||
from superset.constants import TimeGrain, USER_AGENT
|
from superset.constants import TimeGrain, USER_AGENT
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
@@ -38,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
|
|||||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse
|
||||||
from superset.utils import json
|
from superset.utils import json
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -57,12 +63,29 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class SnowflakeParametersSchema(Schema):
|
class SnowflakeParametersSchema(Schema):
|
||||||
username = fields.Str(required=True)
|
username = fields.Str(
|
||||||
password = fields.Str(required=True)
|
required=False,
|
||||||
account = fields.Str(required=True)
|
allow_none=True,
|
||||||
database = fields.Str(required=True)
|
metadata={"description": "Username"},
|
||||||
role = fields.Str(required=True)
|
)
|
||||||
warehouse = fields.Str(required=True)
|
password = fields.Str(
|
||||||
|
required=False,
|
||||||
|
allow_none=True,
|
||||||
|
metadata={"description": "Password"},
|
||||||
|
)
|
||||||
|
oauth2_client = fields.Str(
|
||||||
|
required=False,
|
||||||
|
allow_none=True,
|
||||||
|
metadata={"description": "OAuth2 client information"},
|
||||||
|
)
|
||||||
|
account = fields.Str(required=True, metadata={"description": "Account name"})
|
||||||
|
database = fields.Str(required=True, metadata={"description": "Database name"})
|
||||||
|
role = fields.Str(
|
||||||
|
required=False,
|
||||||
|
allow_none=True,
|
||||||
|
metadata={"description": "Default role"},
|
||||||
|
)
|
||||||
|
warehouse = fields.Str(required=True, metadata={"description": "Warehouse name"})
|
||||||
|
|
||||||
|
|
||||||
class SnowflakeParametersType(TypedDict):
|
class SnowflakeParametersType(TypedDict):
|
||||||
@@ -74,6 +97,17 @@ class SnowflakeParametersType(TypedDict):
|
|||||||
warehouse: str
|
warehouse: str
|
||||||
|
|
||||||
|
|
||||||
|
SnowflakeParametersKey = Literal[
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"account",
|
||||||
|
"database",
|
||||||
|
"role",
|
||||||
|
"warehouse",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=too-many-public-methods
|
||||||
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||||
engine = "snowflake"
|
engine = "snowflake"
|
||||||
engine_name = "Snowflake"
|
engine_name = "Snowflake"
|
||||||
@@ -87,6 +121,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
supports_dynamic_schema = True
|
supports_dynamic_schema = True
|
||||||
supports_catalog = supports_dynamic_catalog = True
|
supports_catalog = supports_dynamic_catalog = True
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
encrypted_extra_sensitive_fields = ["$.oauth2_client_info.secret"]
|
||||||
|
|
||||||
_time_grain_expressions = {
|
_time_grain_expressions = {
|
||||||
None: "{col}",
|
None: "{col}",
|
||||||
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
|
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
|
||||||
@@ -123,19 +160,6 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_extra_params(database: "Database") -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Add a user agent to be used in the requests.
|
|
||||||
"""
|
|
||||||
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
|
|
||||||
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
|
|
||||||
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
|
|
||||||
|
|
||||||
connect_args.setdefault("application", USER_AGENT)
|
|
||||||
|
|
||||||
return extra
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def adjust_engine_params(
|
def adjust_engine_params(
|
||||||
cls,
|
cls,
|
||||||
@@ -278,6 +302,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
dict[str, Any]
|
dict[str, Any]
|
||||||
] = None,
|
] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
query_keys: list[SnowflakeParametersKey] = ["role", "warehouse"]
|
||||||
|
query = {key: parameters[key] for key in query_keys if parameters.get(key)}
|
||||||
return str(
|
return str(
|
||||||
URL.create(
|
URL.create(
|
||||||
"snowflake",
|
"snowflake",
|
||||||
@@ -285,10 +311,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
password=parameters.get("password"),
|
password=parameters.get("password"),
|
||||||
host=parameters.get("account"),
|
host=parameters.get("account"),
|
||||||
database=parameters.get("database"),
|
database=parameters.get("database"),
|
||||||
query={
|
query=query,
|
||||||
"role": parameters.get("role"),
|
|
||||||
"warehouse": parameters.get("warehouse"),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -317,12 +340,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
) -> list[SupersetError]:
|
) -> list[SupersetError]:
|
||||||
errors: list[SupersetError] = []
|
errors: list[SupersetError] = []
|
||||||
required = {
|
required = {
|
||||||
"warehouse",
|
|
||||||
"username",
|
|
||||||
"database",
|
|
||||||
"account",
|
"account",
|
||||||
"role",
|
"database",
|
||||||
"password",
|
"warehouse",
|
||||||
}
|
}
|
||||||
parameters = properties.get("parameters", {})
|
parameters = properties.get("parameters", {})
|
||||||
present = {key for key in parameters if parameters.get(key, ())}
|
present = {key for key in parameters if parameters.get(key, ())}
|
||||||
@@ -405,3 +425,131 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
|
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
|
||||||
)
|
)
|
||||||
connect_args["auth"] = snowflake_auth(**auth_params)
|
connect_args["auth"] = snowflake_auth(**auth_params)
|
||||||
|
|
||||||
|
supports_oauth2 = True
|
||||||
|
oauth2_scope = "refresh_token session:role:PUBLIC"
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
oauth2_authorization_request_uri = None
|
||||||
|
oauth2_token_request_uri = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_extra_params(cls, database: "Database") -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Add a user agent to be used in the requests.
|
||||||
|
"""
|
||||||
|
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
|
||||||
|
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
|
||||||
|
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
|
||||||
|
|
||||||
|
connect_args.setdefault("application", USER_AGENT)
|
||||||
|
|
||||||
|
# populate OAuth2 URLs if not set, since they can be inferred from the account
|
||||||
|
if oauth2_client_info := extra.get("oauth2_client_info"):
|
||||||
|
account = database.url_object.host
|
||||||
|
oauth2_client_info.setdefault(
|
||||||
|
"authorization_request_uri",
|
||||||
|
f"https://{account}.snowflakecomputing.com/oauth/authorize",
|
||||||
|
)
|
||||||
|
oauth2_client_info.setdefault(
|
||||||
|
"token_request_uri",
|
||||||
|
f"https://{account}.snowflakecomputing.com/oauth/token-request",
|
||||||
|
)
|
||||||
|
oauth2_client_info.setdefault("scope", cls.oauth2_scope)
|
||||||
|
|
||||||
|
return extra
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_impersonation_config(
|
||||||
|
cls,
|
||||||
|
connect_args: dict[str, Any],
|
||||||
|
uri: str,
|
||||||
|
username: str | None,
|
||||||
|
access_token: str | None,
|
||||||
|
) -> None:
|
||||||
|
if access_token:
|
||||||
|
connect_args.update(
|
||||||
|
{
|
||||||
|
"authenticator": "oauth",
|
||||||
|
"token": access_token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_url_for_impersonation(
|
||||||
|
cls,
|
||||||
|
url: URL,
|
||||||
|
impersonate_user: bool,
|
||||||
|
username: str | None,
|
||||||
|
access_token: str | None,
|
||||||
|
) -> URL:
|
||||||
|
# force OAuth2
|
||||||
|
if impersonate_user:
|
||||||
|
# remove username/password if present
|
||||||
|
url = url._replace(username="", password="")
|
||||||
|
# remove hardcoded role so that the one from OAuth2 is used
|
||||||
|
url = url.difference_update_query(["role"])
|
||||||
|
|
||||||
|
return url
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(
|
||||||
|
cls,
|
||||||
|
cursor: Any,
|
||||||
|
query: str,
|
||||||
|
database: Database,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
cursor.execute(query)
|
||||||
|
except Exception as ex:
|
||||||
|
if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
|
||||||
|
cls.start_oauth2_dance(database)
|
||||||
|
raise cls.get_dbapi_mapped_exception(ex) from ex
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def needs_oauth2(cls, ex: Exception) -> bool:
|
||||||
|
return (
|
||||||
|
g
|
||||||
|
and g.user
|
||||||
|
and isinstance(ex, ProgrammingError)
|
||||||
|
and "User is empty" in str(ex)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_oauth2_token(
|
||||||
|
cls,
|
||||||
|
config: OAuth2ClientConfig,
|
||||||
|
code: str,
|
||||||
|
) -> OAuth2TokenResponse:
|
||||||
|
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||||
|
uri = config["token_request_uri"]
|
||||||
|
response = requests.post(
|
||||||
|
uri,
|
||||||
|
data={
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": config["redirect_uri"],
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
},
|
||||||
|
auth=HTTPBasicAuth(config["id"], config["secret"]),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_oauth2_fresh_token(
|
||||||
|
cls,
|
||||||
|
config: OAuth2ClientConfig,
|
||||||
|
refresh_token: str,
|
||||||
|
) -> OAuth2TokenResponse:
|
||||||
|
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||||
|
uri = config["token_request_uri"]
|
||||||
|
response = requests.post(
|
||||||
|
uri,
|
||||||
|
data={
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
},
|
||||||
|
auth=HTTPBasicAuth(config["id"], config["secret"]),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|||||||
@@ -844,6 +844,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
) as inspector:
|
) as inspector:
|
||||||
return self.db_engine_spec.get_schema_names(inspector)
|
return self.db_engine_spec.get_schema_names(inspector)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
|
||||||
|
self.start_oauth2_dance()
|
||||||
|
|
||||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||||
|
|
||||||
@cache_util.memoized_func(
|
@cache_util.memoized_func(
|
||||||
@@ -865,6 +868,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
|
with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
|
||||||
return self.db_engine_spec.get_catalog_names(self, inspector)
|
return self.db_engine_spec.get_catalog_names(self, inspector)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
|
||||||
|
self.start_oauth2_dance()
|
||||||
|
|
||||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1096,6 +1102,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
|
|
||||||
return self.db_engine_spec.get_oauth2_config()
|
return self.db_engine_spec.get_oauth2_config()
|
||||||
|
|
||||||
|
def start_oauth2_dance(self) -> None:
|
||||||
|
return self.db_engine_spec.start_oauth2_dance(self)
|
||||||
|
|
||||||
|
|
||||||
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
|
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
|
||||||
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
|
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
|
||||||
|
|||||||
@@ -238,6 +238,7 @@ def on_error(
|
|||||||
|
|
||||||
def transaction( # pylint: disable=redefined-outer-name
|
def transaction( # pylint: disable=redefined-outer-name
|
||||||
on_error: Callable[..., Any] | None = on_error,
|
on_error: Callable[..., Any] | None = on_error,
|
||||||
|
allowed: tuple[type[Exception], ...] = (),
|
||||||
) -> Callable[..., Any]:
|
) -> Callable[..., Any]:
|
||||||
"""
|
"""
|
||||||
Perform a "unit of work".
|
Perform a "unit of work".
|
||||||
@@ -246,7 +247,12 @@ def transaction( # pylint: disable=redefined-outer-name
|
|||||||
proved rather complicated, likely due to many architectural facets, and thus has
|
proved rather complicated, likely due to many architectural facets, and thus has
|
||||||
been left for a follow up exercise.
|
been left for a follow up exercise.
|
||||||
|
|
||||||
|
In certain cases it might be desirable to commit even though an exception was
|
||||||
|
raised. For OAuth2, foe example, we use exceptions as a way to signal the client to
|
||||||
|
redirect to the login page. In this case, we ignore the exception and commit.
|
||||||
|
|
||||||
:param on_error: Callback invoked when an exception is caught
|
:param on_error: Callback invoked when an exception is caught
|
||||||
|
:param allowed: Exception types to ignore and not rollback
|
||||||
:see: https://github.com/apache/superset/issues/25108
|
:see: https://github.com/apache/superset/issues/25108
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -259,6 +265,10 @@ def transaction( # pylint: disable=redefined-outer-name
|
|||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
db.session.commit() # pylint: disable=consider-using-transaction
|
db.session.commit() # pylint: disable=consider-using-transaction
|
||||||
return result
|
return result
|
||||||
|
except allowed:
|
||||||
|
db.session.commit() # pylint: disable=consider-using-transaction
|
||||||
|
raise
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
db.session.rollback() # pylint: disable=consider-using-transaction
|
db.session.rollback() # pylint: disable=consider-using-transaction
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user