Compare commits

...

4 Commits

Author SHA1 Message Date
Beto Dealmeida
2b1f732082 More frontend 2024-09-01 18:58:18 -04:00
Beto Dealmeida
da2bc91a32 Frontend 2024-08-28 17:16:17 -04:00
Beto Dealmeida
83accd751a WIP 2024-08-28 11:35:17 -04:00
Beto Dealmeida
d6d2277ed6 WIP 2024-08-14 16:59:24 -04:00
20 changed files with 461 additions and 92 deletions

View File

@@ -24,10 +24,11 @@ import {
useRef, useRef,
useCallback, useCallback,
} from 'react'; } from 'react';
import { styled, SupersetClient, t } from '@superset-ui/core'; import { styled, SupersetClient, SupersetError, t } from '@superset-ui/core';
import type { LabeledValue as AntdLabeledValue } from 'antd/lib/select'; import type { LabeledValue as AntdLabeledValue } from 'antd/lib/select';
import rison from 'rison'; import rison from 'rison';
import { AsyncSelect, Select } from 'src/components'; import { AsyncSelect, Select } from 'src/components';
import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace';
import Label from 'src/components/Label'; import Label from 'src/components/Label';
import { FormLabel } from 'src/components/Form'; import { FormLabel } from 'src/components/Form';
import RefreshLabel from 'src/components/RefreshLabel'; import RefreshLabel from 'src/components/RefreshLabel';
@@ -154,6 +155,7 @@ export default function DatabaseSelector({
}: DatabaseSelectorProps) { }: DatabaseSelectorProps) {
const showCatalogSelector = !!db?.allow_multi_catalog; const showCatalogSelector = !!db?.allow_multi_catalog;
const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>(); const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
const [errorPayload, setErrorPayload] = useState<SupersetError | null>();
const [currentCatalog, setCurrentCatalog] = useState< const [currentCatalog, setCurrentCatalog] = useState<
CatalogOption | null | undefined CatalogOption | null | undefined
>(catalog ? { label: catalog, value: catalog, title: catalog } : undefined); >(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
@@ -267,6 +269,7 @@ export default function DatabaseSelector({
dbId: currentDb?.value, dbId: currentDb?.value,
catalog: currentCatalog?.value, catalog: currentCatalog?.value,
onSuccess: (schemas, isFetched) => { onSuccess: (schemas, isFetched) => {
setErrorPayload(null);
if (schemas.length === 1) { if (schemas.length === 1) {
changeSchema(schemas[0]); changeSchema(schemas[0]);
} else if ( } else if (
@@ -279,7 +282,13 @@ export default function DatabaseSelector({
addSuccessToast('List refreshed'); addSuccessToast('List refreshed');
} }
}, },
onError: () => handleError(t('There was an error loading the schemas')), onError: error => {
if (error?.errors) {
setErrorPayload(error?.errors?.[0]);
} else {
handleError(t('There was an error loading the schemas'));
}
},
}); });
const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS; const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
@@ -299,6 +308,7 @@ export default function DatabaseSelector({
} = useCatalogs({ } = useCatalogs({
dbId: showCatalogSelector ? currentDb?.value : undefined, dbId: showCatalogSelector ? currentDb?.value : undefined,
onSuccess: (catalogs, isFetched) => { onSuccess: (catalogs, isFetched) => {
setErrorPayload(null);
if (!showCatalogSelector) { if (!showCatalogSelector) {
changeCatalog(null); changeCatalog(null);
} else if (catalogs.length === 1) { } else if (catalogs.length === 1) {
@@ -315,9 +325,13 @@ export default function DatabaseSelector({
addSuccessToast('List refreshed'); addSuccessToast('List refreshed');
} }
}, },
onError: () => { onError: error => {
if (showCatalogSelector) { if (showCatalogSelector) {
handleError(t('There was an error loading the catalogs')); if (error?.errors) {
setErrorPayload(error?.errors?.[0]);
} else {
handleError(t('There was an error loading the catalogs'));
}
} }
}, },
}); });
@@ -423,9 +437,16 @@ export default function DatabaseSelector({
); );
} }
function renderError() {
return errorPayload ? (
<ErrorMessageWithStackTrace error={errorPayload} source="crud" />
) : null;
}
return ( return (
<DatabaseSelectorWrapper data-test="DatabaseSelector"> <DatabaseSelectorWrapper data-test="DatabaseSelector">
{renderDatabaseSelect()} {renderDatabaseSelect()}
{renderError()}
{showCatalogSelector && renderCatalogSelect()} {showCatalogSelector && renderCatalogSelect()}
{renderSchemaSelect()} {renderSchemaSelect()}
</DatabaseSelectorWrapper> </DatabaseSelectorWrapper>

View File

@@ -162,7 +162,7 @@ function OAuth2RedirectMessage({
> >
provide authorization provide authorization
</a>{' '} </a>{' '}
in order to run this query. in order to run this operation.
</> </>
); );

View File

@@ -161,23 +161,20 @@ export const httpPathField = ({
getValidation, getValidation,
validationErrors, validationErrors,
db, db,
}: FieldPropTypes) => { }: FieldPropTypes) => (
console.error(db); <ValidatedInput
return ( id="http_path_field"
<ValidatedInput name="http_path_field"
id="http_path_field" required={required}
name="http_path_field" value={db?.parameters?.http_path_field}
required={required} validationMethods={{ onBlur: getValidation }}
value={db?.parameters?.http_path_field} errorMessage={validationErrors?.http_path}
validationMethods={{ onBlur: getValidation }} placeholder={t('e.g. sql/protocolv1/o/12345')}
errorMessage={validationErrors?.http_path} label="HTTP Path"
placeholder={t('e.g. sql/protocolv1/o/12345')} onChange={changeMethods.onParametersChange}
label="HTTP Path" helpText={t('Copy the name of the HTTP Path of your cluster.')}
onChange={changeMethods.onParametersChange} />
helpText={t('Copy the name of the HTTP Path of your cluster.')} );
/>
);
};
export const usernameField = ({ export const usernameField = ({
required, required,
changeMethods, changeMethods,

View File

@@ -0,0 +1,109 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useState } from 'react';
import Collapse from 'src/components/Collapse';
import { Input } from 'src/components/Input';
import { FormItem } from 'src/components/Form';
import { FieldPropTypes } from '../../types';
interface OAuth2ClientInfo {
id: string;
secret: string;
authorization_request_uri: string;
token_request_uri: string;
scope: string;
}
export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
const encryptedExtra = JSON.parse(db?.masked_encrypted_extra || '{}');
const [oauth2ClientInfo, setOauth2ClientInfo] = useState<OAuth2ClientInfo>({
id: encryptedExtra.oauth2_client_info?.id || '',
secret: encryptedExtra.oauth2_client_info?.secret || '',
authorization_request_uri:
encryptedExtra.oauth2_client_info?.authorization_request_uri || '',
token_request_uri:
encryptedExtra.oauth2_client_info?.token_request_uri || '',
scope: encryptedExtra.oauth2_client_info?.scope || '',
});
if (db?.engine_information?.supports_oauth2 !== true) {
return null;
}
const handleChange = (key: any) => (e: any) => {
const updatedInfo = {
...oauth2ClientInfo,
[key]: e.target.value,
};
setOauth2ClientInfo(updatedInfo);
const event = {
target: {
name: 'oauth2_client_info',
value: updatedInfo,
},
};
changeMethods.onEncryptedExtraInputChange(event);
};
return (
<Collapse>
<Collapse.Panel header="OAuth2 client information" key="1">
<FormItem label="Client ID">
<Input
placeholder="Enter your Client ID"
value={oauth2ClientInfo.id}
onChange={handleChange('id')}
/>
</FormItem>
<FormItem label="Client Secret">
<Input
type="password"
placeholder="Enter your Client Secret"
value={oauth2ClientInfo.secret}
onChange={handleChange('secret')}
/>
</FormItem>
<FormItem label="Authorization Request URI">
<Input
placeholder="https://"
value={oauth2ClientInfo.authorization_request_uri}
onChange={handleChange('authorization_request_uri')}
/>
</FormItem>
<FormItem label="Token Request URI">
<Input
placeholder="https://"
value={oauth2ClientInfo.token_request_uri}
onChange={handleChange('token_request_uri')}
/>
</FormItem>
<FormItem label="Scope">
<Input
value={oauth2ClientInfo.scope}
onChange={handleChange('scope')}
/>
</FormItem>
</Collapse.Panel>
</Collapse>
);
};

View File

@@ -22,16 +22,19 @@ import { FieldPropTypes } from '../../types';
const FIELD_TEXT_MAP = { const FIELD_TEXT_MAP = {
account: { account: {
label: 'Account',
helpText: t( helpText: t(
'Copy the identifier of the account you are trying to connect to.', 'Copy the identifier of the account you are trying to connect to.',
), ),
placeholder: t('e.g. xy12345.us-east-2.aws'), placeholder: t('e.g. xy12345.us-east-2.aws'),
}, },
warehouse: { warehouse: {
label: 'Warehouse',
placeholder: t('e.g. compute_wh'), placeholder: t('e.g. compute_wh'),
className: 'form-group-w-50', className: 'form-group-w-50',
}, },
role: { role: {
label: 'Default role',
placeholder: t('e.g. AccountAdmin'), placeholder: t('e.g. AccountAdmin'),
className: 'form-group-w-50', className: 'form-group-w-50',
}, },
@@ -54,7 +57,7 @@ export const validatedInputField = ({
errorMessage={validationErrors?.[field]} errorMessage={validationErrors?.[field]}
placeholder={FIELD_TEXT_MAP[field].placeholder} placeholder={FIELD_TEXT_MAP[field].placeholder}
helpText={FIELD_TEXT_MAP[field].helpText} helpText={FIELD_TEXT_MAP[field].helpText}
label={field} label={FIELD_TEXT_MAP[field].label || field}
onChange={changeMethods.onParametersChange} onChange={changeMethods.onParametersChange}
className={FIELD_TEXT_MAP[field].className || field} className={FIELD_TEXT_MAP[field].className || field}
/> />

View File

@@ -41,6 +41,7 @@ import {
} from './CommonParameters'; } from './CommonParameters';
import { validatedInputField } from './ValidatedInputField'; import { validatedInputField } from './ValidatedInputField';
import { EncryptedField } from './EncryptedField'; import { EncryptedField } from './EncryptedField';
import { OAuth2ClientField } from './OAuth2ClientField';
import { TableCatalog } from './TableCatalog'; import { TableCatalog } from './TableCatalog';
import { formScrollableStyles, validatedFormStyles } from '../styles'; import { formScrollableStyles, validatedFormStyles } from '../styles';
import { DatabaseForm, DatabaseObject } from '../../types'; import { DatabaseForm, DatabaseObject } from '../../types';
@@ -67,6 +68,7 @@ export const FormFieldOrder = [
'warehouse', 'warehouse',
'role', 'role',
'ssh', 'ssh',
'oauth2_client',
]; ];
const extensionsRegistry = getExtensionsRegistry(); const extensionsRegistry = getExtensionsRegistry();
@@ -84,6 +86,7 @@ const FORM_FIELD_MAP = {
default_schema: defaultSchemaField, default_schema: defaultSchemaField,
username: usernameField, username: usernameField,
password: passwordField, password: passwordField,
oauth2_client: OAuth2ClientField,
access_token: accessTokenField, access_token: accessTokenField,
database_name: displayField, database_name: displayField,
query: queryField, query: queryField,
@@ -118,6 +121,9 @@ interface DatabaseConnectionFormProps {
onExtraInputChange: ( onExtraInputChange: (
event: FormEvent<InputProps> | { target: HTMLInputElement }, event: FormEvent<InputProps> | { target: HTMLInputElement },
) => void; ) => void;
onEncryptedExtraInputChange: (
event: FormEvent<InputProps> | { target: HTMLInputElement },
) => void;
onAddTableCatalog: () => void; onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void; onRemoveTableCatalog: (idx: number) => void;
validationErrors: JsonObject | null; validationErrors: JsonObject | null;
@@ -136,6 +142,7 @@ const DatabaseConnectionForm = ({
onAddTableCatalog, onAddTableCatalog,
onChange, onChange,
onExtraInputChange, onExtraInputChange,
onEncryptedExtraInputChange,
onParametersChange, onParametersChange,
onParametersUploadFileChange, onParametersUploadFileChange,
onQueryChange, onQueryChange,
@@ -171,6 +178,7 @@ const DatabaseConnectionForm = ({
onAddTableCatalog, onAddTableCatalog,
onRemoveTableCatalog, onRemoveTableCatalog,
onExtraInputChange, onExtraInputChange,
onEncryptedExtraInputChange,
}, },
validationErrors, validationErrors,
getValidation, getValidation,

View File

@@ -154,6 +154,7 @@ export enum ActionType {
EditorChange, EditorChange,
ExtraEditorChange, ExtraEditorChange,
ExtraInputChange, ExtraInputChange,
EncryptedExtraInputChange,
Fetched, Fetched,
InputChange, InputChange,
ParametersChange, ParametersChange,
@@ -185,6 +186,7 @@ export type DBReducerActionType =
type: type:
| ActionType.ExtraEditorChange | ActionType.ExtraEditorChange
| ActionType.ExtraInputChange | ActionType.ExtraInputChange
| ActionType.EncryptedExtraInputChange
| ActionType.TextChange | ActionType.TextChange
| ActionType.QueryChange | ActionType.QueryChange
| ActionType.InputChange | ActionType.InputChange
@@ -269,6 +271,14 @@ export function dbReducer(
[action.payload.name]: actionPayloadJson, [action.payload.name]: actionPayloadJson,
}), }),
}; };
case ActionType.EncryptedExtraInputChange:
return {
...trimmedState,
masked_encrypted_extra: JSON.stringify({
...JSON.parse(trimmedState.masked_encrypted_extra || '{}'),
[action.payload.name]: action.payload.value,
}),
};
case ActionType.ExtraInputChange: case ActionType.ExtraInputChange:
// "extra" payload in state is a string // "extra" payload in state is a string
if ( if (
@@ -1656,6 +1666,16 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
value: target.value, value: target.value,
}) })
} }
onEncryptedExtraInputChange={({
target,
}: {
target: HTMLInputElement;
}) =>
onChange(ActionType.EncryptedExtraInputChange, {
name: target.name,
value: target.value,
})
}
onRemoveTableCatalog={(idx: number) => { onRemoveTableCatalog={(idx: number) => {
setDB({ setDB({
type: ActionType.RemoveTableCatalogSheet, type: ActionType.RemoveTableCatalogSheet,

View File

@@ -113,6 +113,7 @@ export type DatabaseObject = {
supports_file_upload?: boolean; supports_file_upload?: boolean;
disable_ssh_tunneling?: boolean; disable_ssh_tunneling?: boolean;
supports_dynamic_catalog?: boolean; supports_dynamic_catalog?: boolean;
supports_oauth2?: boolean;
}; };
// SSH Tunnel information // SSH Tunnel information
@@ -301,6 +302,7 @@ export interface FieldPropTypes {
onRemoveTableCatalog: (idx: number) => void; onRemoveTableCatalog: (idx: number) => void;
} & { } & {
onExtraInputChange: (value: any) => void; onExtraInputChange: (value: any) => void;
onEncryptedExtraInputChange: (value: any) => void;
onSSHTunnelParametersChange: CustomEventHandlerType; onSSHTunnelParametersChange: CustomEventHandlerType;
}; };
validationErrors: JsonObject | null; validationErrors: JsonObject | null;

View File

@@ -17,6 +17,7 @@
* under the License. * under the License.
*/ */
import { useCallback, useEffect } from 'react'; import { useCallback, useEffect } from 'react';
import { ClientErrorObject } from '@superset-ui/core';
import useEffectEvent from 'src/hooks/useEffectEvent'; import useEffectEvent from 'src/hooks/useEffectEvent';
import { api, JsonResponse } from './queryApi'; import { api, JsonResponse } from './queryApi';
@@ -30,7 +31,7 @@ export type FetchCatalogsQueryParams = {
dbId?: string | number; dbId?: string | number;
forceRefresh: boolean; forceRefresh: boolean;
onSuccess?: (data: CatalogOption[], isRefetched: boolean) => void; onSuccess?: (data: CatalogOption[], isRefetched: boolean) => void;
onError?: () => void; onError?: (error: ClientErrorObject) => void;
}; };
type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>; type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>;
@@ -77,6 +78,12 @@ export function useCatalogs(options: Params) {
}, },
); );
useEffect(() => {
if (result.isError) {
onError?.(result.error);
}
}, [result.isError, result.error, onError]);
const fetchData = useEffectEvent( const fetchData = useEffectEvent(
(dbId: FetchCatalogsQueryParams['dbId'], forceRefresh = false) => { (dbId: FetchCatalogsQueryParams['dbId'], forceRefresh = false) => {
if (dbId && (!result.currentData || forceRefresh)) { if (dbId && (!result.currentData || forceRefresh)) {
@@ -85,7 +92,7 @@ export function useCatalogs(options: Params) {
onSuccess?.(data || EMPTY_CATALOGS, forceRefresh); onSuccess?.(data || EMPTY_CATALOGS, forceRefresh);
} }
if (isError) { if (isError) {
onError?.(); onError?.(result.error);
} }
}); });
} }

View File

@@ -64,6 +64,7 @@ export const supersetClientQuery: BaseQueryFn<
getClientErrorObject(response).then(errorObj => ({ getClientErrorObject(response).then(errorObj => ({
error: { error: {
error: errorObj?.message || errorObj?.error || response.statusText, error: errorObj?.message || errorObj?.error || response.statusText,
errors: errorObj?.errors || [], // used by <ErrorMessageWithStackTrace />
status: response.status, status: response.status,
}, },
})), })),

View File

@@ -17,6 +17,7 @@
* under the License. * under the License.
*/ */
import { useCallback, useEffect } from 'react'; import { useCallback, useEffect } from 'react';
import { ClientErrorObject } from '@superset-ui/core';
import useEffectEvent from 'src/hooks/useEffectEvent'; import useEffectEvent from 'src/hooks/useEffectEvent';
import { api, JsonResponse } from './queryApi'; import { api, JsonResponse } from './queryApi';
@@ -31,7 +32,7 @@ export type FetchSchemasQueryParams = {
catalog?: string; catalog?: string;
forceRefresh: boolean; forceRefresh: boolean;
onSuccess?: (data: SchemaOption[], isRefetched: boolean) => void; onSuccess?: (data: SchemaOption[], isRefetched: boolean) => void;
onError?: () => void; onError?: (error: ClientErrorObject) => void;
}; };
type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>; type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>;
@@ -81,6 +82,12 @@ export function useSchemas(options: Params) {
}, },
); );
useEffect(() => {
if (result.isError) {
onError?.(result.error);
}
}, [result.isError, result.error, onError]);
const fetchData = useEffectEvent( const fetchData = useEffectEvent(
( (
dbId: FetchSchemasQueryParams['dbId'], dbId: FetchSchemasQueryParams['dbId'],
@@ -94,7 +101,7 @@ export function useSchemas(options: Params) {
onSuccess?.(data || EMPTY_SCHEMAS, forceRefresh); onSuccess?.(data || EMPTY_SCHEMAS, forceRefresh);
} }
if (isError) { if (isError) {
onError?.(); onError?.(result.error);
} }
}, },
); );

View File

@@ -42,7 +42,7 @@ from superset.commands.database.test_connection import TestConnectionDatabaseCom
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import SupersetErrorsException from superset.exceptions import OAuth2RedirectError, SupersetErrorsException
from superset.extensions import event_logger, security_manager from superset.extensions import event_logger, security_manager
from superset.models.core import Database from superset.models.core import Database
from superset.utils.decorators import on_error, transaction from superset.utils.decorators import on_error, transaction
@@ -55,13 +55,21 @@ class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError)) @transaction(
on_error=partial(on_error, reraise=DatabaseCreateFailedError),
allowed=(OAuth2RedirectError,),
)
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try: try:
# Test connection before starting create transaction # Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run() TestConnectionDatabaseCommand(self._properties).run()
except OAuth2RedirectError:
# If we can't connect to the database due to an OAuth2 error we can still
# save the database. Later, the user can sync permissions when setting up
# data access rules.
return self._create_database()
except ( except (
SupersetErrorsException, SupersetErrorsException,
SSHTunnelingNotEnabledError, SSHTunnelingNotEnabledError,
@@ -80,12 +88,6 @@ class CreateDatabaseCommand(BaseCommand):
) )
raise DatabaseConnectionFailedError() from ex raise DatabaseConnectionFailedError() from ex
# when creating a new database we don't need to unmask encrypted extra
self._properties["encrypted_extra"] = self._properties.pop(
"masked_encrypted_extra",
"{}",
)
ssh_tunnel: Optional[SSHTunnel] = None ssh_tunnel: Optional[SSHTunnel] = None
try: try:
@@ -195,6 +197,12 @@ class CreateDatabaseCommand(BaseCommand):
raise exception raise exception
def _create_database(self) -> Database: def _create_database(self) -> Database:
# when creating a new database we don't need to unmask encrypted extra
self._properties["encrypted_extra"] = self._properties.pop(
"masked_encrypted_extra",
"{}",
)
database = DatabaseDAO.create(attributes=self._properties) database = DatabaseDAO.create(attributes=self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri) database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database return database

View File

@@ -41,6 +41,7 @@ from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetErrorType from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import ( from superset.exceptions import (
OAuth2RedirectError,
SupersetErrorsException, SupersetErrorsException,
SupersetSecurityException, SupersetSecurityException,
SupersetTimeoutException, SupersetTimeoutException,
@@ -162,6 +163,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
extra={"sqlalchemy_uri": database.sqlalchemy_uri}, extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex ) from ex
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
if (
database.is_oauth2_enabled()
and database.db_engine_spec.needs_oauth2(ex)
):
database.start_oauth2_dance()
alive = False alive = False
# So we stop losing the original message if any # So we stop losing the original message if any
ex_str = str(ex) ex_str = str(ex)
@@ -197,6 +204,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
# check for custom errors (wrong username, wrong password, etc) # check for custom errors (wrong username, wrong password, etc)
errors = database.db_engine_spec.extract_errors(ex, self._context) errors = database.db_engine_spec.extract_errors(ex, self._context)
raise SupersetErrorsException(errors) from ex raise SupersetErrorsException(errors) from ex
except OAuth2RedirectError:
raise
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
event_logger.log_with_context( event_logger.log_with_context(
action=get_log_connection_action( action=get_log_connection_action(
@@ -205,23 +214,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
engine=database.db_engine_spec.__name__, engine=database.db_engine_spec.__name__,
) )
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
except SupersetTimeoutException as ex: except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex:
event_logger.log_with_context( event_logger.log_with_context(
action=get_log_connection_action( action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex "test_connection_error", ssh_tunnel, ex
), ),
engine=database.db_engine_spec.__name__, engine=database.db_engine_spec.__name__,
) )
# bubble up the exception to return a 408
raise
except SSHTunnelingNotEnabledError as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 400
raise raise
except Exception as ex: except Exception as ex:
event_logger.log_with_context( event_logger.log_with_context(

View File

@@ -42,6 +42,7 @@ from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import OAuth2RedirectError
from superset.models.core import Database from superset.models.core import Database
from superset.utils.decorators import on_error, transaction from superset.utils.decorators import on_error, transaction
@@ -80,7 +81,10 @@ class UpdateDatabaseCommand(BaseCommand):
database = DatabaseDAO.update(self._model, self._properties) database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri) database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database) ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel) try:
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except OAuth2RedirectError:
pass
return database return database
@@ -123,6 +127,8 @@ class UpdateDatabaseCommand(BaseCommand):
force=True, force=True,
ssh_tunnel=ssh_tunnel, ssh_tunnel=ssh_tunnel,
) )
except OAuth2RedirectError:
raise
except GenericDBException as ex: except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex raise DatabaseConnectionFailedError() from ex

View File

@@ -107,6 +107,15 @@ class ValidateDatabaseParametersCommand(BaseCommand):
with closing(engine.raw_connection()) as conn: with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn) alive = engine.dialect.do_ping(conn)
except Exception as ex: except Exception as ex:
# if the connection failed because OAuth2 is needed, we can save the
# database and trigger the OAuth2 flow whenever a user tries to run a
# query.
if (
database.is_oauth2_enabled()
and database.db_engine_spec.needs_oauth2(ex)
):
return
url = make_url_safe(sqlalchemy_uri) url = make_url_safe(sqlalchemy_uri)
context = { context = {
"hostname": url.host, "hostname": url.host,

View File

@@ -110,6 +110,7 @@ from superset.exceptions import (
DatabaseNotFoundException, DatabaseNotFoundException,
InvalidPayloadSchemaError, InvalidPayloadSchemaError,
OAuth2Error, OAuth2Error,
OAuth2RedirectError,
SupersetErrorsException, SupersetErrorsException,
SupersetException, SupersetException,
SupersetSecurityException, SupersetSecurityException,
@@ -398,7 +399,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@expose("/", methods=("POST",)) @expose("/", methods=("POST",))
@protect() @protect()
@safe
@statsd_metrics @statsd_metrics
@event_logger.log_this_with_context( @event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
@@ -462,6 +462,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel) item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel)
return self.response(201, id=new_model.id, result=item) return self.response(201, id=new_model.id, result=item)
except OAuth2RedirectError:
raise
except DatabaseInvalidError as ex: except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages()) return self.response_422(message=ex.normalized_messages())
except DatabaseConnectionFailedError as ex: except DatabaseConnectionFailedError as ex:
@@ -621,7 +623,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@expose("/<int:pk>/catalogs/") @expose("/<int:pk>/catalogs/")
@protect() @protect()
@safe
@rison(database_catalogs_query_schema) @rison(database_catalogs_query_schema)
@statsd_metrics @statsd_metrics
@event_logger.log_this_with_context( @event_logger.log_this_with_context(
@@ -680,12 +681,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500, 500,
message="There was an error connecting to the database", message="There was an error connecting to the database",
) )
except OAuth2RedirectError:
raise
except SupersetException as ex: except SupersetException as ex:
return self.response(ex.status, message=ex.message) return self.response(ex.status, message=ex.message)
@expose("/<int:pk>/schemas/") @expose("/<int:pk>/schemas/")
@protect() @protect()
@safe
@rison(database_schemas_query_schema) @rison(database_schemas_query_schema)
@statsd_metrics @statsd_metrics
@event_logger.log_this_with_context( @event_logger.log_this_with_context(
@@ -746,6 +748,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
return self.response( return self.response(
500, message="There was an error connecting to the database" 500, message="There was an error connecting to the database"
) )
except OAuth2RedirectError:
raise
except SupersetException as ex: except SupersetException as ex:
return self.response(ex.status, message=ex.message) return self.response(ex.status, message=ex.message)
@@ -1894,9 +1898,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@protect() @protect()
@statsd_metrics @statsd_metrics
@event_logger.log_this_with_context( @event_logger.log_this_with_context(
action=lambda self, action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.columnar_upload",
*args,
**kwargs: f"{self.__class__.__name__}.columnar_upload",
log_to_statsd=False, log_to_statsd=False,
) )
@requires_form_data @requires_form_data

View File

@@ -136,7 +136,9 @@ builtin_time_grains: dict[str | None, str] = {
} }
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors class TimestampExpression(
ColumnClause
): # pylint: disable=abstract-method, too-many-ancestors
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None: def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be used to render native column elements respecting """Sqlalchemy class that can be used to render native column elements respecting
engine-specific quoting rules as part of a string-based expression. engine-specific quoting rules as part of a string-based expression.
@@ -394,9 +396,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
max_column_name_length: int | None = None max_column_name_length: int | None = None
try_remove_schema_from_table_name = True # pylint: disable=invalid-name try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False run_multiple_statements_as_one = False
custom_errors: dict[ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]] {}
] = {} )
# Whether the engine supports file uploads # Whether the engine supports file uploads
# if True, database will be listed as option in the upload file form # if True, database will be listed as option in the upload file form
@@ -423,8 +425,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# the user impersonation methods to handle personal tokens. # the user impersonation methods to handle personal tokens.
supports_oauth2 = False supports_oauth2 = False
oauth2_scope = "" oauth2_scope = ""
oauth2_authorization_request_uri = "" # pylint: disable=invalid-name oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri = "" oauth2_token_request_uri: str | None = None
# Driver-specific exception that should be mapped to OAuth2RedirectError # Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError oauth2_exception = OAuth2RedirectError
@@ -473,11 +475,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# message. # message.
"tab_id": tab_id, "tab_id": tab_id,
} }
oauth2_config = database.get_oauth2_config() config = database.get_oauth2_config()
if oauth2_config is None: if config is None:
raise OAuth2Error("No configuration found for OAuth2") raise OAuth2Error("No configuration found for OAuth2")
oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state) oauth_url = cls.get_oauth2_authorization_uri(config, state)
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri) raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
@@ -2196,6 +2198,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"supports_file_upload": cls.supports_file_upload, "supports_file_upload": cls.supports_file_upload,
"disable_ssh_tunneling": cls.disable_ssh_tunneling, "disable_ssh_tunneling": cls.disable_ssh_tunneling,
"supports_dynamic_catalog": cls.supports_dynamic_catalog, "supports_dynamic_catalog": cls.supports_dynamic_catalog,
"supports_oauth2": cls.supports_oauth2,
} }
@classmethod @classmethod

View File

@@ -14,23 +14,28 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from __future__ import annotations
import logging import logging
import re import re
from datetime import datetime from datetime import datetime
from re import Pattern from re import Pattern
from typing import Any, Optional, TYPE_CHECKING, TypedDict from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict
from urllib import parse from urllib import parse
import requests
from apispec import APISpec from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow import MarshmallowPlugin
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from flask import current_app from flask import current_app, g
from flask_babel import gettext as __ from flask_babel import gettext as __
from marshmallow import fields, Schema from marshmallow import fields, Schema
from requests.auth import HTTPBasicAuth
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL from sqlalchemy.engine.url import URL
from sqlalchemy.exc import ProgrammingError
from superset.constants import TimeGrain, USER_AGENT from superset.constants import TimeGrain, USER_AGENT
from superset.databases.utils import make_url_safe from superset.databases.utils import make_url_safe
@@ -38,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse
from superset.utils import json from superset.utils import json
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -57,12 +63,29 @@ logger = logging.getLogger(__name__)
class SnowflakeParametersSchema(Schema): class SnowflakeParametersSchema(Schema):
username = fields.Str(required=True) username = fields.Str(
password = fields.Str(required=True) required=False,
account = fields.Str(required=True) allow_none=True,
database = fields.Str(required=True) metadata={"description": "Username"},
role = fields.Str(required=True) )
warehouse = fields.Str(required=True) password = fields.Str(
required=False,
allow_none=True,
metadata={"description": "Password"},
)
oauth2_client = fields.Str(
required=False,
allow_none=True,
metadata={"description": "OAuth2 client information"},
)
account = fields.Str(required=True, metadata={"description": "Account name"})
database = fields.Str(required=True, metadata={"description": "Database name"})
role = fields.Str(
required=False,
allow_none=True,
metadata={"description": "Default role"},
)
warehouse = fields.Str(required=True, metadata={"description": "Warehouse name"})
class SnowflakeParametersType(TypedDict): class SnowflakeParametersType(TypedDict):
@@ -74,6 +97,17 @@ class SnowflakeParametersType(TypedDict):
warehouse: str warehouse: str
SnowflakeParametersKey = Literal[
"username",
"password",
"account",
"database",
"role",
"warehouse",
]
# pylint: disable=too-many-public-methods
class SnowflakeEngineSpec(PostgresBaseEngineSpec): class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = "snowflake" engine = "snowflake"
engine_name = "Snowflake" engine_name = "Snowflake"
@@ -87,6 +121,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
supports_dynamic_schema = True supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = True supports_catalog = supports_dynamic_catalog = True
# pylint: disable=invalid-name
encrypted_extra_sensitive_fields = ["$.oauth2_client_info.secret"]
_time_grain_expressions = { _time_grain_expressions = {
None: "{col}", None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})", TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
@@ -123,19 +160,6 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
), ),
} }
@staticmethod
def get_extra_params(database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("application", USER_AGENT)
return extra
@classmethod @classmethod
def adjust_engine_params( def adjust_engine_params(
cls, cls,
@@ -278,6 +302,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
dict[str, Any] dict[str, Any]
] = None, ] = None,
) -> str: ) -> str:
query_keys: list[SnowflakeParametersKey] = ["role", "warehouse"]
query = {key: parameters[key] for key in query_keys if parameters.get(key)}
return str( return str(
URL.create( URL.create(
"snowflake", "snowflake",
@@ -285,10 +311,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
password=parameters.get("password"), password=parameters.get("password"),
host=parameters.get("account"), host=parameters.get("account"),
database=parameters.get("database"), database=parameters.get("database"),
query={ query=query,
"role": parameters.get("role"),
"warehouse": parameters.get("warehouse"),
},
) )
) )
@@ -317,12 +340,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
) -> list[SupersetError]: ) -> list[SupersetError]:
errors: list[SupersetError] = [] errors: list[SupersetError] = []
required = { required = {
"warehouse",
"username",
"database",
"account", "account",
"role", "database",
"password", "warehouse",
} }
parameters = properties.get("parameters", {}) parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())} present = {key for key in parameters if parameters.get(key, ())}
@@ -405,3 +425,131 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
) )
connect_args["auth"] = snowflake_auth(**auth_params) connect_args["auth"] = snowflake_auth(**auth_params)
supports_oauth2 = True
oauth2_scope = "refresh_token session:role:PUBLIC"
# pylint: disable=invalid-name
oauth2_authorization_request_uri = None
oauth2_token_request_uri = None
@classmethod
def get_extra_params(cls, database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("application", USER_AGENT)
# populate OAuth2 URLs if not set, since they can be inferred from the account
if oauth2_client_info := extra.get("oauth2_client_info"):
account = database.url_object.host
oauth2_client_info.setdefault(
"authorization_request_uri",
f"https://{account}.snowflakecomputing.com/oauth/authorize",
)
oauth2_client_info.setdefault(
"token_request_uri",
f"https://{account}.snowflakecomputing.com/oauth/token-request",
)
oauth2_client_info.setdefault("scope", cls.oauth2_scope)
return extra
@classmethod
def update_impersonation_config(
cls,
connect_args: dict[str, Any],
uri: str,
username: str | None,
access_token: str | None,
) -> None:
if access_token:
connect_args.update(
{
"authenticator": "oauth",
"token": access_token,
},
)
@classmethod
def get_url_for_impersonation(
cls,
url: URL,
impersonate_user: bool,
username: str | None,
access_token: str | None,
) -> URL:
# force OAuth2
if impersonate_user:
# remove username/password if present
url = url._replace(username="", password="")
# remove hardcoded role so that the one from OAuth2 is used
url = url.difference_update_query(["role"])
return url
@classmethod
def execute(
cls,
cursor: Any,
query: str,
database: Database,
**kwargs: Any,
) -> None:
try:
cursor.execute(query)
except Exception as ex:
if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
cls.start_oauth2_dance(database)
raise cls.get_dbapi_mapped_exception(ex) from ex
@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
return (
g
and g.user
and isinstance(ex, ProgrammingError)
and "User is empty" in str(ex)
)
@classmethod
def get_oauth2_token(
cls,
config: OAuth2ClientConfig,
code: str,
) -> OAuth2TokenResponse:
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
data={
"code": code,
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
},
auth=HTTPBasicAuth(config["id"], config["secret"]),
timeout=timeout,
)
return response.json()
@classmethod
def get_oauth2_fresh_token(
cls,
config: OAuth2ClientConfig,
refresh_token: str,
) -> OAuth2TokenResponse:
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
data={
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
auth=HTTPBasicAuth(config["id"], config["secret"]),
timeout=timeout,
)
return response.json()

View File

@@ -844,6 +844,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
) as inspector: ) as inspector:
return self.db_engine_spec.get_schema_names(inspector) return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex: except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@cache_util.memoized_func( @cache_util.memoized_func(
@@ -865,6 +868,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector: with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
return self.db_engine_spec.get_catalog_names(self, inspector) return self.db_engine_spec.get_catalog_names(self, inspector)
except Exception as ex: except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@property @property
@@ -1096,6 +1102,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return self.db_engine_spec.get_oauth2_config() return self.db_engine_spec.get_oauth2_config()
def start_oauth2_dance(self) -> None:
return self.db_engine_spec.start_oauth2_dance(self)
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update) sqla.event.listen(Database, "after_update", security_manager.database_after_update)

View File

@@ -238,6 +238,7 @@ def on_error(
def transaction( # pylint: disable=redefined-outer-name def transaction( # pylint: disable=redefined-outer-name
on_error: Callable[..., Any] | None = on_error, on_error: Callable[..., Any] | None = on_error,
allowed: tuple[type[Exception], ...] = (),
) -> Callable[..., Any]: ) -> Callable[..., Any]:
""" """
Perform a "unit of work". Perform a "unit of work".
@@ -246,7 +247,12 @@ def transaction( # pylint: disable=redefined-outer-name
proved rather complicated, likely due to many architectural facets, and thus has proved rather complicated, likely due to many architectural facets, and thus has
been left for a follow up exercise. been left for a follow up exercise.
In certain cases it might be desirable to commit even though an exception was
raised. For OAuth2, foe example, we use exceptions as a way to signal the client to
redirect to the login page. In this case, we ignore the exception and commit.
:param on_error: Callback invoked when an exception is caught :param on_error: Callback invoked when an exception is caught
:param allowed: Exception types to ignore and not rollback
:see: https://github.com/apache/superset/issues/25108 :see: https://github.com/apache/superset/issues/25108
""" """
@@ -259,6 +265,10 @@ def transaction( # pylint: disable=redefined-outer-name
result = func(*args, **kwargs) result = func(*args, **kwargs)
db.session.commit() # pylint: disable=consider-using-transaction db.session.commit() # pylint: disable=consider-using-transaction
return result return result
except allowed:
db.session.commit() # pylint: disable=consider-using-transaction
raise
except Exception as ex: except Exception as ex:
db.session.rollback() # pylint: disable=consider-using-transaction db.session.rollback() # pylint: disable=consider-using-transaction