mirror of
https://github.com/apache/superset.git
synced 2026-04-28 12:34:23 +00:00
Compare commits
4 Commits
dashboard-
...
snowflake-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b1f732082 | ||
|
|
da2bc91a32 | ||
|
|
83accd751a | ||
|
|
d6d2277ed6 |
@@ -24,10 +24,11 @@ import {
|
||||
useRef,
|
||||
useCallback,
|
||||
} from 'react';
|
||||
import { styled, SupersetClient, t } from '@superset-ui/core';
|
||||
import { styled, SupersetClient, SupersetError, t } from '@superset-ui/core';
|
||||
import type { LabeledValue as AntdLabeledValue } from 'antd/lib/select';
|
||||
import rison from 'rison';
|
||||
import { AsyncSelect, Select } from 'src/components';
|
||||
import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace';
|
||||
import Label from 'src/components/Label';
|
||||
import { FormLabel } from 'src/components/Form';
|
||||
import RefreshLabel from 'src/components/RefreshLabel';
|
||||
@@ -154,6 +155,7 @@ export default function DatabaseSelector({
|
||||
}: DatabaseSelectorProps) {
|
||||
const showCatalogSelector = !!db?.allow_multi_catalog;
|
||||
const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
|
||||
const [errorPayload, setErrorPayload] = useState<SupersetError | null>();
|
||||
const [currentCatalog, setCurrentCatalog] = useState<
|
||||
CatalogOption | null | undefined
|
||||
>(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
|
||||
@@ -267,6 +269,7 @@ export default function DatabaseSelector({
|
||||
dbId: currentDb?.value,
|
||||
catalog: currentCatalog?.value,
|
||||
onSuccess: (schemas, isFetched) => {
|
||||
setErrorPayload(null);
|
||||
if (schemas.length === 1) {
|
||||
changeSchema(schemas[0]);
|
||||
} else if (
|
||||
@@ -279,7 +282,13 @@ export default function DatabaseSelector({
|
||||
addSuccessToast('List refreshed');
|
||||
}
|
||||
},
|
||||
onError: () => handleError(t('There was an error loading the schemas')),
|
||||
onError: error => {
|
||||
if (error?.errors) {
|
||||
setErrorPayload(error?.errors?.[0]);
|
||||
} else {
|
||||
handleError(t('There was an error loading the schemas'));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
|
||||
@@ -299,6 +308,7 @@ export default function DatabaseSelector({
|
||||
} = useCatalogs({
|
||||
dbId: showCatalogSelector ? currentDb?.value : undefined,
|
||||
onSuccess: (catalogs, isFetched) => {
|
||||
setErrorPayload(null);
|
||||
if (!showCatalogSelector) {
|
||||
changeCatalog(null);
|
||||
} else if (catalogs.length === 1) {
|
||||
@@ -315,9 +325,13 @@ export default function DatabaseSelector({
|
||||
addSuccessToast('List refreshed');
|
||||
}
|
||||
},
|
||||
onError: () => {
|
||||
onError: error => {
|
||||
if (showCatalogSelector) {
|
||||
handleError(t('There was an error loading the catalogs'));
|
||||
if (error?.errors) {
|
||||
setErrorPayload(error?.errors?.[0]);
|
||||
} else {
|
||||
handleError(t('There was an error loading the catalogs'));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -423,9 +437,16 @@ export default function DatabaseSelector({
|
||||
);
|
||||
}
|
||||
|
||||
function renderError() {
|
||||
return errorPayload ? (
|
||||
<ErrorMessageWithStackTrace error={errorPayload} source="crud" />
|
||||
) : null;
|
||||
}
|
||||
|
||||
return (
|
||||
<DatabaseSelectorWrapper data-test="DatabaseSelector">
|
||||
{renderDatabaseSelect()}
|
||||
{renderError()}
|
||||
{showCatalogSelector && renderCatalogSelect()}
|
||||
{renderSchemaSelect()}
|
||||
</DatabaseSelectorWrapper>
|
||||
|
||||
@@ -162,7 +162,7 @@ function OAuth2RedirectMessage({
|
||||
>
|
||||
provide authorization
|
||||
</a>{' '}
|
||||
in order to run this query.
|
||||
in order to run this operation.
|
||||
</>
|
||||
);
|
||||
|
||||
|
||||
@@ -161,23 +161,20 @@ export const httpPathField = ({
|
||||
getValidation,
|
||||
validationErrors,
|
||||
db,
|
||||
}: FieldPropTypes) => {
|
||||
console.error(db);
|
||||
return (
|
||||
<ValidatedInput
|
||||
id="http_path_field"
|
||||
name="http_path_field"
|
||||
required={required}
|
||||
value={db?.parameters?.http_path_field}
|
||||
validationMethods={{ onBlur: getValidation }}
|
||||
errorMessage={validationErrors?.http_path}
|
||||
placeholder={t('e.g. sql/protocolv1/o/12345')}
|
||||
label="HTTP Path"
|
||||
onChange={changeMethods.onParametersChange}
|
||||
helpText={t('Copy the name of the HTTP Path of your cluster.')}
|
||||
/>
|
||||
);
|
||||
};
|
||||
}: FieldPropTypes) => (
|
||||
<ValidatedInput
|
||||
id="http_path_field"
|
||||
name="http_path_field"
|
||||
required={required}
|
||||
value={db?.parameters?.http_path_field}
|
||||
validationMethods={{ onBlur: getValidation }}
|
||||
errorMessage={validationErrors?.http_path}
|
||||
placeholder={t('e.g. sql/protocolv1/o/12345')}
|
||||
label="HTTP Path"
|
||||
onChange={changeMethods.onParametersChange}
|
||||
helpText={t('Copy the name of the HTTP Path of your cluster.')}
|
||||
/>
|
||||
);
|
||||
export const usernameField = ({
|
||||
required,
|
||||
changeMethods,
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useState } from 'react';
|
||||
|
||||
import Collapse from 'src/components/Collapse';
|
||||
import { Input } from 'src/components/Input';
|
||||
import { FormItem } from 'src/components/Form';
|
||||
import { FieldPropTypes } from '../../types';
|
||||
|
||||
interface OAuth2ClientInfo {
|
||||
id: string;
|
||||
secret: string;
|
||||
authorization_request_uri: string;
|
||||
token_request_uri: string;
|
||||
scope: string;
|
||||
}
|
||||
|
||||
export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
|
||||
const encryptedExtra = JSON.parse(db?.masked_encrypted_extra || '{}');
|
||||
const [oauth2ClientInfo, setOauth2ClientInfo] = useState<OAuth2ClientInfo>({
|
||||
id: encryptedExtra.oauth2_client_info?.id || '',
|
||||
secret: encryptedExtra.oauth2_client_info?.secret || '',
|
||||
authorization_request_uri:
|
||||
encryptedExtra.oauth2_client_info?.authorization_request_uri || '',
|
||||
token_request_uri:
|
||||
encryptedExtra.oauth2_client_info?.token_request_uri || '',
|
||||
scope: encryptedExtra.oauth2_client_info?.scope || '',
|
||||
});
|
||||
|
||||
if (db?.engine_information?.supports_oauth2 !== true) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handleChange = (key: any) => (e: any) => {
|
||||
const updatedInfo = {
|
||||
...oauth2ClientInfo,
|
||||
[key]: e.target.value,
|
||||
};
|
||||
|
||||
setOauth2ClientInfo(updatedInfo);
|
||||
|
||||
const event = {
|
||||
target: {
|
||||
name: 'oauth2_client_info',
|
||||
value: updatedInfo,
|
||||
},
|
||||
};
|
||||
changeMethods.onEncryptedExtraInputChange(event);
|
||||
};
|
||||
|
||||
return (
|
||||
<Collapse>
|
||||
<Collapse.Panel header="OAuth2 client information" key="1">
|
||||
<FormItem label="Client ID">
|
||||
<Input
|
||||
placeholder="Enter your Client ID"
|
||||
value={oauth2ClientInfo.id}
|
||||
onChange={handleChange('id')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Client Secret">
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter your Client Secret"
|
||||
value={oauth2ClientInfo.secret}
|
||||
onChange={handleChange('secret')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Authorization Request URI">
|
||||
<Input
|
||||
placeholder="https://"
|
||||
value={oauth2ClientInfo.authorization_request_uri}
|
||||
onChange={handleChange('authorization_request_uri')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Token Request URI">
|
||||
<Input
|
||||
placeholder="https://"
|
||||
value={oauth2ClientInfo.token_request_uri}
|
||||
onChange={handleChange('token_request_uri')}
|
||||
/>
|
||||
</FormItem>
|
||||
<FormItem label="Scope">
|
||||
<Input
|
||||
value={oauth2ClientInfo.scope}
|
||||
onChange={handleChange('scope')}
|
||||
/>
|
||||
</FormItem>
|
||||
</Collapse.Panel>
|
||||
</Collapse>
|
||||
);
|
||||
};
|
||||
@@ -22,16 +22,19 @@ import { FieldPropTypes } from '../../types';
|
||||
|
||||
const FIELD_TEXT_MAP = {
|
||||
account: {
|
||||
label: 'Account',
|
||||
helpText: t(
|
||||
'Copy the identifier of the account you are trying to connect to.',
|
||||
),
|
||||
placeholder: t('e.g. xy12345.us-east-2.aws'),
|
||||
},
|
||||
warehouse: {
|
||||
label: 'Warehouse',
|
||||
placeholder: t('e.g. compute_wh'),
|
||||
className: 'form-group-w-50',
|
||||
},
|
||||
role: {
|
||||
label: 'Default role',
|
||||
placeholder: t('e.g. AccountAdmin'),
|
||||
className: 'form-group-w-50',
|
||||
},
|
||||
@@ -54,7 +57,7 @@ export const validatedInputField = ({
|
||||
errorMessage={validationErrors?.[field]}
|
||||
placeholder={FIELD_TEXT_MAP[field].placeholder}
|
||||
helpText={FIELD_TEXT_MAP[field].helpText}
|
||||
label={field}
|
||||
label={FIELD_TEXT_MAP[field].label || field}
|
||||
onChange={changeMethods.onParametersChange}
|
||||
className={FIELD_TEXT_MAP[field].className || field}
|
||||
/>
|
||||
|
||||
@@ -41,6 +41,7 @@ import {
|
||||
} from './CommonParameters';
|
||||
import { validatedInputField } from './ValidatedInputField';
|
||||
import { EncryptedField } from './EncryptedField';
|
||||
import { OAuth2ClientField } from './OAuth2ClientField';
|
||||
import { TableCatalog } from './TableCatalog';
|
||||
import { formScrollableStyles, validatedFormStyles } from '../styles';
|
||||
import { DatabaseForm, DatabaseObject } from '../../types';
|
||||
@@ -67,6 +68,7 @@ export const FormFieldOrder = [
|
||||
'warehouse',
|
||||
'role',
|
||||
'ssh',
|
||||
'oauth2_client',
|
||||
];
|
||||
|
||||
const extensionsRegistry = getExtensionsRegistry();
|
||||
@@ -84,6 +86,7 @@ const FORM_FIELD_MAP = {
|
||||
default_schema: defaultSchemaField,
|
||||
username: usernameField,
|
||||
password: passwordField,
|
||||
oauth2_client: OAuth2ClientField,
|
||||
access_token: accessTokenField,
|
||||
database_name: displayField,
|
||||
query: queryField,
|
||||
@@ -118,6 +121,9 @@ interface DatabaseConnectionFormProps {
|
||||
onExtraInputChange: (
|
||||
event: FormEvent<InputProps> | { target: HTMLInputElement },
|
||||
) => void;
|
||||
onEncryptedExtraInputChange: (
|
||||
event: FormEvent<InputProps> | { target: HTMLInputElement },
|
||||
) => void;
|
||||
onAddTableCatalog: () => void;
|
||||
onRemoveTableCatalog: (idx: number) => void;
|
||||
validationErrors: JsonObject | null;
|
||||
@@ -136,6 +142,7 @@ const DatabaseConnectionForm = ({
|
||||
onAddTableCatalog,
|
||||
onChange,
|
||||
onExtraInputChange,
|
||||
onEncryptedExtraInputChange,
|
||||
onParametersChange,
|
||||
onParametersUploadFileChange,
|
||||
onQueryChange,
|
||||
@@ -171,6 +178,7 @@ const DatabaseConnectionForm = ({
|
||||
onAddTableCatalog,
|
||||
onRemoveTableCatalog,
|
||||
onExtraInputChange,
|
||||
onEncryptedExtraInputChange,
|
||||
},
|
||||
validationErrors,
|
||||
getValidation,
|
||||
|
||||
@@ -154,6 +154,7 @@ export enum ActionType {
|
||||
EditorChange,
|
||||
ExtraEditorChange,
|
||||
ExtraInputChange,
|
||||
EncryptedExtraInputChange,
|
||||
Fetched,
|
||||
InputChange,
|
||||
ParametersChange,
|
||||
@@ -185,6 +186,7 @@ export type DBReducerActionType =
|
||||
type:
|
||||
| ActionType.ExtraEditorChange
|
||||
| ActionType.ExtraInputChange
|
||||
| ActionType.EncryptedExtraInputChange
|
||||
| ActionType.TextChange
|
||||
| ActionType.QueryChange
|
||||
| ActionType.InputChange
|
||||
@@ -269,6 +271,14 @@ export function dbReducer(
|
||||
[action.payload.name]: actionPayloadJson,
|
||||
}),
|
||||
};
|
||||
case ActionType.EncryptedExtraInputChange:
|
||||
return {
|
||||
...trimmedState,
|
||||
masked_encrypted_extra: JSON.stringify({
|
||||
...JSON.parse(trimmedState.masked_encrypted_extra || '{}'),
|
||||
[action.payload.name]: action.payload.value,
|
||||
}),
|
||||
};
|
||||
case ActionType.ExtraInputChange:
|
||||
// "extra" payload in state is a string
|
||||
if (
|
||||
@@ -1656,6 +1666,16 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
|
||||
value: target.value,
|
||||
})
|
||||
}
|
||||
onEncryptedExtraInputChange={({
|
||||
target,
|
||||
}: {
|
||||
target: HTMLInputElement;
|
||||
}) =>
|
||||
onChange(ActionType.EncryptedExtraInputChange, {
|
||||
name: target.name,
|
||||
value: target.value,
|
||||
})
|
||||
}
|
||||
onRemoveTableCatalog={(idx: number) => {
|
||||
setDB({
|
||||
type: ActionType.RemoveTableCatalogSheet,
|
||||
|
||||
@@ -113,6 +113,7 @@ export type DatabaseObject = {
|
||||
supports_file_upload?: boolean;
|
||||
disable_ssh_tunneling?: boolean;
|
||||
supports_dynamic_catalog?: boolean;
|
||||
supports_oauth2?: boolean;
|
||||
};
|
||||
|
||||
// SSH Tunnel information
|
||||
@@ -301,6 +302,7 @@ export interface FieldPropTypes {
|
||||
onRemoveTableCatalog: (idx: number) => void;
|
||||
} & {
|
||||
onExtraInputChange: (value: any) => void;
|
||||
onEncryptedExtraInputChange: (value: any) => void;
|
||||
onSSHTunnelParametersChange: CustomEventHandlerType;
|
||||
};
|
||||
validationErrors: JsonObject | null;
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
* under the License.
|
||||
*/
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import { ClientErrorObject } from '@superset-ui/core';
|
||||
import useEffectEvent from 'src/hooks/useEffectEvent';
|
||||
import { api, JsonResponse } from './queryApi';
|
||||
|
||||
@@ -30,7 +31,7 @@ export type FetchCatalogsQueryParams = {
|
||||
dbId?: string | number;
|
||||
forceRefresh: boolean;
|
||||
onSuccess?: (data: CatalogOption[], isRefetched: boolean) => void;
|
||||
onError?: () => void;
|
||||
onError?: (error: ClientErrorObject) => void;
|
||||
};
|
||||
|
||||
type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>;
|
||||
@@ -77,6 +78,12 @@ export function useCatalogs(options: Params) {
|
||||
},
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (result.isError) {
|
||||
onError?.(result.error);
|
||||
}
|
||||
}, [result.isError, result.error, onError]);
|
||||
|
||||
const fetchData = useEffectEvent(
|
||||
(dbId: FetchCatalogsQueryParams['dbId'], forceRefresh = false) => {
|
||||
if (dbId && (!result.currentData || forceRefresh)) {
|
||||
@@ -85,7 +92,7 @@ export function useCatalogs(options: Params) {
|
||||
onSuccess?.(data || EMPTY_CATALOGS, forceRefresh);
|
||||
}
|
||||
if (isError) {
|
||||
onError?.();
|
||||
onError?.(result.error);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ export const supersetClientQuery: BaseQueryFn<
|
||||
getClientErrorObject(response).then(errorObj => ({
|
||||
error: {
|
||||
error: errorObj?.message || errorObj?.error || response.statusText,
|
||||
errors: errorObj?.errors || [], // used by <ErrorMessageWithStackTrace />
|
||||
status: response.status,
|
||||
},
|
||||
})),
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
* under the License.
|
||||
*/
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import { ClientErrorObject } from '@superset-ui/core';
|
||||
import useEffectEvent from 'src/hooks/useEffectEvent';
|
||||
import { api, JsonResponse } from './queryApi';
|
||||
|
||||
@@ -31,7 +32,7 @@ export type FetchSchemasQueryParams = {
|
||||
catalog?: string;
|
||||
forceRefresh: boolean;
|
||||
onSuccess?: (data: SchemaOption[], isRefetched: boolean) => void;
|
||||
onError?: () => void;
|
||||
onError?: (error: ClientErrorObject) => void;
|
||||
};
|
||||
|
||||
type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>;
|
||||
@@ -81,6 +82,12 @@ export function useSchemas(options: Params) {
|
||||
},
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (result.isError) {
|
||||
onError?.(result.error);
|
||||
}
|
||||
}, [result.isError, result.error, onError]);
|
||||
|
||||
const fetchData = useEffectEvent(
|
||||
(
|
||||
dbId: FetchSchemasQueryParams['dbId'],
|
||||
@@ -94,7 +101,7 @@ export function useSchemas(options: Params) {
|
||||
onSuccess?.(data || EMPTY_SCHEMAS, forceRefresh);
|
||||
}
|
||||
if (isError) {
|
||||
onError?.();
|
||||
onError?.(result.error);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
@@ -42,7 +42,7 @@ from superset.commands.database.test_connection import TestConnectionDatabaseCom
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
from superset.exceptions import SupersetErrorsException
|
||||
from superset.exceptions import OAuth2RedirectError, SupersetErrorsException
|
||||
from superset.extensions import event_logger, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
@@ -55,13 +55,21 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
def __init__(self, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
|
||||
@transaction(
|
||||
on_error=partial(on_error, reraise=DatabaseCreateFailedError),
|
||||
allowed=(OAuth2RedirectError,),
|
||||
)
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
|
||||
try:
|
||||
# Test connection before starting create transaction
|
||||
TestConnectionDatabaseCommand(self._properties).run()
|
||||
except OAuth2RedirectError:
|
||||
# If we can't connect to the database due to an OAuth2 error we can still
|
||||
# save the database. Later, the user can sync permissions when setting up
|
||||
# data access rules.
|
||||
return self._create_database()
|
||||
except (
|
||||
SupersetErrorsException,
|
||||
SSHTunnelingNotEnabledError,
|
||||
@@ -80,12 +88,6 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
)
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
# when creating a new database we don't need to unmask encrypted extra
|
||||
self._properties["encrypted_extra"] = self._properties.pop(
|
||||
"masked_encrypted_extra",
|
||||
"{}",
|
||||
)
|
||||
|
||||
ssh_tunnel: Optional[SSHTunnel] = None
|
||||
|
||||
try:
|
||||
@@ -195,6 +197,12 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
raise exception
|
||||
|
||||
def _create_database(self) -> Database:
|
||||
# when creating a new database we don't need to unmask encrypted extra
|
||||
self._properties["encrypted_extra"] = self._properties.pop(
|
||||
"masked_encrypted_extra",
|
||||
"{}",
|
||||
)
|
||||
|
||||
database = DatabaseDAO.create(attributes=self._properties)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
return database
|
||||
|
||||
@@ -41,6 +41,7 @@ from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.errors import ErrorLevel, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
OAuth2RedirectError,
|
||||
SupersetErrorsException,
|
||||
SupersetSecurityException,
|
||||
SupersetTimeoutException,
|
||||
@@ -162,6 +163,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
|
||||
) from ex
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
if (
|
||||
database.is_oauth2_enabled()
|
||||
and database.db_engine_spec.needs_oauth2(ex)
|
||||
):
|
||||
database.start_oauth2_dance()
|
||||
|
||||
alive = False
|
||||
# So we stop losing the original message if any
|
||||
ex_str = str(ex)
|
||||
@@ -197,6 +204,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||
# check for custom errors (wrong username, wrong password, etc)
|
||||
errors = database.db_engine_spec.extract_errors(ex, self._context)
|
||||
raise SupersetErrorsException(errors) from ex
|
||||
except OAuth2RedirectError:
|
||||
raise
|
||||
except SupersetSecurityException as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
@@ -205,23 +214,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
|
||||
except SupersetTimeoutException as ex:
|
||||
except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
# bubble up the exception to return a 408
|
||||
raise
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
# bubble up the exception to return a 400
|
||||
raise
|
||||
except Exception as ex:
|
||||
event_logger.log_with_context(
|
||||
|
||||
@@ -42,6 +42,7 @@ from superset.daos.database import DatabaseDAO
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
from superset.models.core import Database
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
@@ -80,7 +81,10 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
database = DatabaseDAO.update(self._model, self._properties)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||
try:
|
||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||
except OAuth2RedirectError:
|
||||
pass
|
||||
|
||||
return database
|
||||
|
||||
@@ -123,6 +127,8 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||
force=True,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except OAuth2RedirectError:
|
||||
raise
|
||||
except GenericDBException as ex:
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
|
||||
@@ -107,6 +107,15 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
alive = engine.dialect.do_ping(conn)
|
||||
except Exception as ex:
|
||||
# if the connection failed because OAuth2 is needed, we can save the
|
||||
# database and trigger the OAuth2 flow whenever a user tries to run a
|
||||
# query.
|
||||
if (
|
||||
database.is_oauth2_enabled()
|
||||
and database.db_engine_spec.needs_oauth2(ex)
|
||||
):
|
||||
return
|
||||
|
||||
url = make_url_safe(sqlalchemy_uri)
|
||||
context = {
|
||||
"hostname": url.host,
|
||||
|
||||
@@ -110,6 +110,7 @@ from superset.exceptions import (
|
||||
DatabaseNotFoundException,
|
||||
InvalidPayloadSchemaError,
|
||||
OAuth2Error,
|
||||
OAuth2RedirectError,
|
||||
SupersetErrorsException,
|
||||
SupersetException,
|
||||
SupersetSecurityException,
|
||||
@@ -398,7 +399,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
|
||||
@expose("/", methods=("POST",))
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
|
||||
@@ -462,6 +462,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel)
|
||||
|
||||
return self.response(201, id=new_model.id, result=item)
|
||||
except OAuth2RedirectError:
|
||||
raise
|
||||
except DatabaseInvalidError as ex:
|
||||
return self.response_422(message=ex.normalized_messages())
|
||||
except DatabaseConnectionFailedError as ex:
|
||||
@@ -621,7 +623,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
|
||||
@expose("/<int:pk>/catalogs/")
|
||||
@protect()
|
||||
@safe
|
||||
@rison(database_catalogs_query_schema)
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
@@ -680,12 +681,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
500,
|
||||
message="There was an error connecting to the database",
|
||||
)
|
||||
except OAuth2RedirectError:
|
||||
raise
|
||||
except SupersetException as ex:
|
||||
return self.response(ex.status, message=ex.message)
|
||||
|
||||
@expose("/<int:pk>/schemas/")
|
||||
@protect()
|
||||
@safe
|
||||
@rison(database_schemas_query_schema)
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
@@ -746,6 +748,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
return self.response(
|
||||
500, message="There was an error connecting to the database"
|
||||
)
|
||||
except OAuth2RedirectError:
|
||||
raise
|
||||
except SupersetException as ex:
|
||||
return self.response(ex.status, message=ex.message)
|
||||
|
||||
@@ -1894,9 +1898,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
@protect()
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self,
|
||||
*args,
|
||||
**kwargs: f"{self.__class__.__name__}.columnar_upload",
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.columnar_upload",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
@requires_form_data
|
||||
|
||||
@@ -136,7 +136,9 @@ builtin_time_grains: dict[str | None, str] = {
|
||||
}
|
||||
|
||||
|
||||
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
|
||||
class TimestampExpression(
|
||||
ColumnClause
|
||||
): # pylint: disable=abstract-method, too-many-ancestors
|
||||
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
|
||||
"""Sqlalchemy class that can be used to render native column elements respecting
|
||||
engine-specific quoting rules as part of a string-based expression.
|
||||
@@ -394,9 +396,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
max_column_name_length: int | None = None
|
||||
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
||||
run_multiple_statements_as_one = False
|
||||
custom_errors: dict[
|
||||
Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
|
||||
] = {}
|
||||
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
|
||||
{}
|
||||
)
|
||||
|
||||
# Whether the engine supports file uploads
|
||||
# if True, database will be listed as option in the upload file form
|
||||
@@ -423,8 +425,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
# the user impersonation methods to handle personal tokens.
|
||||
supports_oauth2 = False
|
||||
oauth2_scope = ""
|
||||
oauth2_authorization_request_uri = "" # pylint: disable=invalid-name
|
||||
oauth2_token_request_uri = ""
|
||||
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
|
||||
oauth2_token_request_uri: str | None = None
|
||||
|
||||
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||
oauth2_exception = OAuth2RedirectError
|
||||
@@ -473,11 +475,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
# message.
|
||||
"tab_id": tab_id,
|
||||
}
|
||||
oauth2_config = database.get_oauth2_config()
|
||||
if oauth2_config is None:
|
||||
config = database.get_oauth2_config()
|
||||
if config is None:
|
||||
raise OAuth2Error("No configuration found for OAuth2")
|
||||
|
||||
oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
|
||||
oauth_url = cls.get_oauth2_authorization_uri(config, state)
|
||||
|
||||
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
|
||||
|
||||
@@ -2196,6 +2198,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"supports_file_upload": cls.supports_file_upload,
|
||||
"disable_ssh_tunneling": cls.disable_ssh_tunneling,
|
||||
"supports_dynamic_catalog": cls.supports_dynamic_catalog,
|
||||
"supports_oauth2": cls.supports_oauth2,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -14,23 +14,28 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from re import Pattern
|
||||
from typing import Any, Optional, TYPE_CHECKING, TypedDict
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict
|
||||
from urllib import parse
|
||||
|
||||
import requests
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from flask import current_app
|
||||
from flask import current_app, g
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
|
||||
from superset.constants import TimeGrain, USER_AGENT
|
||||
from superset.databases.utils import make_url_safe
|
||||
@@ -38,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
|
||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse
|
||||
from superset.utils import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -57,12 +63,29 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SnowflakeParametersSchema(Schema):
|
||||
username = fields.Str(required=True)
|
||||
password = fields.Str(required=True)
|
||||
account = fields.Str(required=True)
|
||||
database = fields.Str(required=True)
|
||||
role = fields.Str(required=True)
|
||||
warehouse = fields.Str(required=True)
|
||||
username = fields.Str(
|
||||
required=False,
|
||||
allow_none=True,
|
||||
metadata={"description": "Username"},
|
||||
)
|
||||
password = fields.Str(
|
||||
required=False,
|
||||
allow_none=True,
|
||||
metadata={"description": "Password"},
|
||||
)
|
||||
oauth2_client = fields.Str(
|
||||
required=False,
|
||||
allow_none=True,
|
||||
metadata={"description": "OAuth2 client information"},
|
||||
)
|
||||
account = fields.Str(required=True, metadata={"description": "Account name"})
|
||||
database = fields.Str(required=True, metadata={"description": "Database name"})
|
||||
role = fields.Str(
|
||||
required=False,
|
||||
allow_none=True,
|
||||
metadata={"description": "Default role"},
|
||||
)
|
||||
warehouse = fields.Str(required=True, metadata={"description": "Warehouse name"})
|
||||
|
||||
|
||||
class SnowflakeParametersType(TypedDict):
|
||||
@@ -74,6 +97,17 @@ class SnowflakeParametersType(TypedDict):
|
||||
warehouse: str
|
||||
|
||||
|
||||
SnowflakeParametersKey = Literal[
|
||||
"username",
|
||||
"password",
|
||||
"account",
|
||||
"database",
|
||||
"role",
|
||||
"warehouse",
|
||||
]
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = "snowflake"
|
||||
engine_name = "Snowflake"
|
||||
@@ -87,6 +121,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
supports_dynamic_schema = True
|
||||
supports_catalog = supports_dynamic_catalog = True
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
encrypted_extra_sensitive_fields = ["$.oauth2_client_info.secret"]
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
|
||||
@@ -123,19 +160,6 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(database: "Database") -> dict[str, Any]:
|
||||
"""
|
||||
Add a user agent to be used in the requests.
|
||||
"""
|
||||
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
|
||||
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
|
||||
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
|
||||
|
||||
connect_args.setdefault("application", USER_AGENT)
|
||||
|
||||
return extra
|
||||
|
||||
@classmethod
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
@@ -278,6 +302,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
dict[str, Any]
|
||||
] = None,
|
||||
) -> str:
|
||||
query_keys: list[SnowflakeParametersKey] = ["role", "warehouse"]
|
||||
query = {key: parameters[key] for key in query_keys if parameters.get(key)}
|
||||
return str(
|
||||
URL.create(
|
||||
"snowflake",
|
||||
@@ -285,10 +311,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
password=parameters.get("password"),
|
||||
host=parameters.get("account"),
|
||||
database=parameters.get("database"),
|
||||
query={
|
||||
"role": parameters.get("role"),
|
||||
"warehouse": parameters.get("warehouse"),
|
||||
},
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -317,12 +340,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
) -> list[SupersetError]:
|
||||
errors: list[SupersetError] = []
|
||||
required = {
|
||||
"warehouse",
|
||||
"username",
|
||||
"database",
|
||||
"account",
|
||||
"role",
|
||||
"password",
|
||||
"database",
|
||||
"warehouse",
|
||||
}
|
||||
parameters = properties.get("parameters", {})
|
||||
present = {key for key in parameters if parameters.get(key, ())}
|
||||
@@ -405,3 +425,131 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
|
||||
)
|
||||
connect_args["auth"] = snowflake_auth(**auth_params)
|
||||
|
||||
supports_oauth2 = True
|
||||
oauth2_scope = "refresh_token session:role:PUBLIC"
|
||||
# pylint: disable=invalid-name
|
||||
oauth2_authorization_request_uri = None
|
||||
oauth2_token_request_uri = None
|
||||
|
||||
@classmethod
|
||||
def get_extra_params(cls, database: "Database") -> dict[str, Any]:
|
||||
"""
|
||||
Add a user agent to be used in the requests.
|
||||
"""
|
||||
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
|
||||
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
|
||||
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
|
||||
|
||||
connect_args.setdefault("application", USER_AGENT)
|
||||
|
||||
# populate OAuth2 URLs if not set, since they can be inferred from the account
|
||||
if oauth2_client_info := extra.get("oauth2_client_info"):
|
||||
account = database.url_object.host
|
||||
oauth2_client_info.setdefault(
|
||||
"authorization_request_uri",
|
||||
f"https://{account}.snowflakecomputing.com/oauth/authorize",
|
||||
)
|
||||
oauth2_client_info.setdefault(
|
||||
"token_request_uri",
|
||||
f"https://{account}.snowflakecomputing.com/oauth/token-request",
|
||||
)
|
||||
oauth2_client_info.setdefault("scope", cls.oauth2_scope)
|
||||
|
||||
return extra
|
||||
|
||||
@classmethod
|
||||
def update_impersonation_config(
|
||||
cls,
|
||||
connect_args: dict[str, Any],
|
||||
uri: str,
|
||||
username: str | None,
|
||||
access_token: str | None,
|
||||
) -> None:
|
||||
if access_token:
|
||||
connect_args.update(
|
||||
{
|
||||
"authenticator": "oauth",
|
||||
"token": access_token,
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_url_for_impersonation(
|
||||
cls,
|
||||
url: URL,
|
||||
impersonate_user: bool,
|
||||
username: str | None,
|
||||
access_token: str | None,
|
||||
) -> URL:
|
||||
# force OAuth2
|
||||
if impersonate_user:
|
||||
# remove username/password if present
|
||||
url = url._replace(username="", password="")
|
||||
# remove hardcoded role so that the one from OAuth2 is used
|
||||
url = url.difference_update_query(["role"])
|
||||
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
cursor: Any,
|
||||
query: str,
|
||||
database: Database,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
try:
|
||||
cursor.execute(query)
|
||||
except Exception as ex:
|
||||
if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
|
||||
cls.start_oauth2_dance(database)
|
||||
raise cls.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
@classmethod
|
||||
def needs_oauth2(cls, ex: Exception) -> bool:
|
||||
return (
|
||||
g
|
||||
and g.user
|
||||
and isinstance(ex, ProgrammingError)
|
||||
and "User is empty" in str(ex)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_token(
|
||||
cls,
|
||||
config: OAuth2ClientConfig,
|
||||
code: str,
|
||||
) -> OAuth2TokenResponse:
|
||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
uri = config["token_request_uri"]
|
||||
response = requests.post(
|
||||
uri,
|
||||
data={
|
||||
"code": code,
|
||||
"redirect_uri": config["redirect_uri"],
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
auth=HTTPBasicAuth(config["id"], config["secret"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_fresh_token(
|
||||
cls,
|
||||
config: OAuth2ClientConfig,
|
||||
refresh_token: str,
|
||||
) -> OAuth2TokenResponse:
|
||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
uri = config["token_request_uri"]
|
||||
response = requests.post(
|
||||
uri,
|
||||
data={
|
||||
"refresh_token": refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
auth=HTTPBasicAuth(config["id"], config["secret"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@@ -844,6 +844,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||
) as inspector:
|
||||
return self.db_engine_spec.get_schema_names(inspector)
|
||||
except Exception as ex:
|
||||
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
|
||||
self.start_oauth2_dance()
|
||||
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
@cache_util.memoized_func(
|
||||
@@ -865,6 +868,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||
with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
|
||||
return self.db_engine_spec.get_catalog_names(self, inspector)
|
||||
except Exception as ex:
|
||||
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
|
||||
self.start_oauth2_dance()
|
||||
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
@property
|
||||
@@ -1096,6 +1102,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||
|
||||
return self.db_engine_spec.get_oauth2_config()
|
||||
|
||||
def start_oauth2_dance(self) -> None:
|
||||
return self.db_engine_spec.start_oauth2_dance(self)
|
||||
|
||||
|
||||
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
|
||||
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
|
||||
|
||||
@@ -238,6 +238,7 @@ def on_error(
|
||||
|
||||
def transaction( # pylint: disable=redefined-outer-name
|
||||
on_error: Callable[..., Any] | None = on_error,
|
||||
allowed: tuple[type[Exception], ...] = (),
|
||||
) -> Callable[..., Any]:
|
||||
"""
|
||||
Perform a "unit of work".
|
||||
@@ -246,7 +247,12 @@ def transaction( # pylint: disable=redefined-outer-name
|
||||
proved rather complicated, likely due to many architectural facets, and thus has
|
||||
been left for a follow up exercise.
|
||||
|
||||
In certain cases it might be desirable to commit even though an exception was
|
||||
raised. For OAuth2, foe example, we use exceptions as a way to signal the client to
|
||||
redirect to the login page. In this case, we ignore the exception and commit.
|
||||
|
||||
:param on_error: Callback invoked when an exception is caught
|
||||
:param allowed: Exception types to ignore and not rollback
|
||||
:see: https://github.com/apache/superset/issues/25108
|
||||
"""
|
||||
|
||||
@@ -259,6 +265,10 @@ def transaction( # pylint: disable=redefined-outer-name
|
||||
result = func(*args, **kwargs)
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
return result
|
||||
except allowed:
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
raise
|
||||
|
||||
except Exception as ex:
|
||||
db.session.rollback() # pylint: disable=consider-using-transaction
|
||||
|
||||
|
||||
Reference in New Issue
Block a user