mirror of
https://github.com/apache/superset.git
synced 2026-05-04 15:34:18 +00:00
Compare commits
1 Commits
fix-webpac
...
pre-cost-e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
992226527f |
@@ -97,6 +97,10 @@ export const COST_ESTIMATE_STARTED = 'COST_ESTIMATE_STARTED';
|
|||||||
export const COST_ESTIMATE_RETURNED = 'COST_ESTIMATE_RETURNED';
|
export const COST_ESTIMATE_RETURNED = 'COST_ESTIMATE_RETURNED';
|
||||||
export const COST_ESTIMATE_FAILED = 'COST_ESTIMATE_FAILED';
|
export const COST_ESTIMATE_FAILED = 'COST_ESTIMATE_FAILED';
|
||||||
|
|
||||||
|
export const COST_THRESHOLD_CHECK_STARTED = 'COST_THRESHOLD_CHECK_STARTED';
|
||||||
|
export const COST_THRESHOLD_CHECK_RETURNED = 'COST_THRESHOLD_CHECK_RETURNED';
|
||||||
|
export const COST_THRESHOLD_CHECK_FAILED = 'COST_THRESHOLD_CHECK_FAILED';
|
||||||
|
|
||||||
export const CREATE_DATASOURCE_STARTED = 'CREATE_DATASOURCE_STARTED';
|
export const CREATE_DATASOURCE_STARTED = 'CREATE_DATASOURCE_STARTED';
|
||||||
export const CREATE_DATASOURCE_SUCCESS = 'CREATE_DATASOURCE_SUCCESS';
|
export const CREATE_DATASOURCE_SUCCESS = 'CREATE_DATASOURCE_SUCCESS';
|
||||||
export const CREATE_DATASOURCE_FAILED = 'CREATE_DATASOURCE_FAILED';
|
export const CREATE_DATASOURCE_FAILED = 'CREATE_DATASOURCE_FAILED';
|
||||||
@@ -233,6 +237,45 @@ export function estimateQueryCost(queryEditor) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function checkCostThreshold(queryEditor) {
|
||||||
|
return (dispatch, getState) => {
|
||||||
|
const { dbId, catalog, schema, sql, selectedText, templateParams } =
|
||||||
|
getUpToDateQuery(getState(), queryEditor);
|
||||||
|
const requestSql = selectedText || sql;
|
||||||
|
const postPayload = {
|
||||||
|
database_id: dbId,
|
||||||
|
catalog,
|
||||||
|
schema,
|
||||||
|
sql: requestSql,
|
||||||
|
template_params: JSON.parse(templateParams || '{}'),
|
||||||
|
};
|
||||||
|
return Promise.all([
|
||||||
|
dispatch({ type: COST_THRESHOLD_CHECK_STARTED, query: queryEditor }),
|
||||||
|
SupersetClient.post({
|
||||||
|
endpoint: '/api/v1/sqllab/check_cost_threshold/',
|
||||||
|
body: JSON.stringify(postPayload),
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
})
|
||||||
|
.then(({ json }) =>
|
||||||
|
dispatch({ type: COST_THRESHOLD_CHECK_RETURNED, query: queryEditor, json }),
|
||||||
|
)
|
||||||
|
.catch(response =>
|
||||||
|
getClientErrorObject(response).then(error => {
|
||||||
|
const message =
|
||||||
|
error.error ||
|
||||||
|
error.statusText ||
|
||||||
|
t('Failed at checking cost threshold');
|
||||||
|
return dispatch({
|
||||||
|
type: COST_THRESHOLD_CHECK_FAILED,
|
||||||
|
query: queryEditor,
|
||||||
|
error: message,
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export function clearInactiveQueries(interval) {
|
export function clearInactiveQueries(interval) {
|
||||||
return { type: CLEAR_INACTIVE_QUERIES, interval };
|
return { type: CLEAR_INACTIVE_QUERIES, interval };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,131 @@
|
|||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { render, screen, fireEvent } from '@testing-library/react';
|
||||||
|
import { ThemeProvider } from '@superset-ui/core';
|
||||||
|
import { theme } from 'src/preamble';
|
||||||
|
import CostWarningModal from './index';
|
||||||
|
|
||||||
|
const mockProps = {
|
||||||
|
visible: true,
|
||||||
|
onHide: jest.fn(),
|
||||||
|
onProceed: jest.fn(),
|
||||||
|
warningMessage: 'This query will scan 10 GB of data, which exceeds the threshold of 5 GB.',
|
||||||
|
thresholdInfo: {
|
||||||
|
bytes_threshold: 5 * 1024 ** 3, // 5 GB
|
||||||
|
estimated_bytes: 10 * 1024 ** 3, // 10 GB
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderWithTheme = (ui: React.ReactElement) =>
|
||||||
|
render(<ThemeProvider theme={theme}>{ui}</ThemeProvider>);
|
||||||
|
|
||||||
|
describe('CostWarningModal', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('renders with warning message', () => {
|
||||||
|
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||||
|
|
||||||
|
expect(screen.getByText('Query Cost Warning')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(mockProps.warningMessage)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('shows threshold details when provided', () => {
|
||||||
|
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||||
|
|
||||||
|
expect(screen.getByText('Threshold Details:')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('Data to scan:')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('10.0 GB')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('5.0 GB')).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('disables proceed button until checkbox is checked', () => {
|
||||||
|
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||||
|
|
||||||
|
const proceedButton = screen.getByText('Run Query Anyway');
|
||||||
|
const checkbox = screen.getByRole('checkbox');
|
||||||
|
|
||||||
|
expect(proceedButton).toBeDisabled();
|
||||||
|
|
||||||
|
fireEvent.click(checkbox);
|
||||||
|
expect(proceedButton).not.toBeDisabled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('calls onProceed when proceed button is clicked with checkbox checked', () => {
|
||||||
|
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||||
|
|
||||||
|
const checkbox = screen.getByRole('checkbox');
|
||||||
|
const proceedButton = screen.getByText('Run Query Anyway');
|
||||||
|
|
||||||
|
fireEvent.click(checkbox);
|
||||||
|
fireEvent.click(proceedButton);
|
||||||
|
|
||||||
|
expect(mockProps.onProceed).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('calls onHide when cancel button is clicked', () => {
|
||||||
|
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||||
|
|
||||||
|
const cancelButton = screen.getByText('Cancel');
|
||||||
|
fireEvent.click(cancelButton);
|
||||||
|
|
||||||
|
expect(mockProps.onHide).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('renders without threshold details when not provided', () => {
|
||||||
|
const propsWithoutThreshold = {
|
||||||
|
...mockProps,
|
||||||
|
thresholdInfo: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
renderWithTheme(<CostWarningModal {...propsWithoutThreshold} />);
|
||||||
|
|
||||||
|
expect(screen.queryByText('Threshold Details:')).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('shows default message when warningMessage is null', () => {
|
||||||
|
const propsWithNoMessage = {
|
||||||
|
...mockProps,
|
||||||
|
warningMessage: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
renderWithTheme(<CostWarningModal {...propsWithNoMessage} />);
|
||||||
|
|
||||||
|
expect(screen.getByText('This query may be expensive to run.')).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('handles cost threshold details', () => {
|
||||||
|
const propsWithCostThreshold = {
|
||||||
|
...mockProps,
|
||||||
|
thresholdInfo: {
|
||||||
|
cost_threshold: 100,
|
||||||
|
estimated_cost: 250,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
renderWithTheme(<CostWarningModal {...propsWithCostThreshold} />);
|
||||||
|
|
||||||
|
expect(screen.getByText('Estimated cost:')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('250')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('Cost threshold:')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('100')).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,166 @@
|
|||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useState } from 'react';
|
||||||
|
import { styled, t } from '@superset-ui/core';
|
||||||
|
import { Button, Modal, Checkbox } from '@superset-ui/core/components';
|
||||||
|
import { ModalTitleWithIcon } from 'src/components/ModalTitleWithIcon';
|
||||||
|
|
||||||
|
const StyledModal = styled(Modal)`
|
||||||
|
.ant-modal-body {
|
||||||
|
padding: 24px;
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
const WarningContent = styled.div`
|
||||||
|
margin: 16px 0;
|
||||||
|
font-size: 14px;
|
||||||
|
line-height: 1.5;
|
||||||
|
`;
|
||||||
|
|
||||||
|
const DetailsSection = styled.div`
|
||||||
|
margin: 16px 0;
|
||||||
|
padding: 12px;
|
||||||
|
background-color: ${({ theme }) => theme.colors.grayscale.light4};
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 12px;
|
||||||
|
`;
|
||||||
|
|
||||||
|
const CheckboxWrapper = styled.div`
|
||||||
|
margin: 16px 0;
|
||||||
|
`;
|
||||||
|
|
||||||
|
interface CostWarningModalProps {
|
||||||
|
visible: boolean;
|
||||||
|
onHide: () => void;
|
||||||
|
onProceed: () => void;
|
||||||
|
warningMessage: string | null;
|
||||||
|
thresholdInfo?: {
|
||||||
|
bytes_threshold?: number;
|
||||||
|
estimated_bytes?: number;
|
||||||
|
cost_threshold?: number;
|
||||||
|
estimated_cost?: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function CostWarningModal({
|
||||||
|
visible,
|
||||||
|
onHide,
|
||||||
|
onProceed,
|
||||||
|
warningMessage,
|
||||||
|
thresholdInfo,
|
||||||
|
}: CostWarningModalProps) {
|
||||||
|
const [proceedAnyway, setProceedAnyway] = useState(false);
|
||||||
|
|
||||||
|
const handleProceed = () => {
|
||||||
|
if (proceedAnyway) {
|
||||||
|
onProceed();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const formatBytes = (bytes: number) => {
|
||||||
|
if (bytes < 1024) return `${bytes} B`;
|
||||||
|
if (bytes < 1024 ** 2) return `${(bytes / 1024).toFixed(1)} KB`;
|
||||||
|
if (bytes < 1024 ** 3) return `${(bytes / 1024 ** 2).toFixed(1)} MB`;
|
||||||
|
if (bytes < 1024 ** 4) return `${(bytes / 1024 ** 3).toFixed(1)} GB`;
|
||||||
|
if (bytes < 1024 ** 5) return `${(bytes / 1024 ** 4).toFixed(1)} TB`;
|
||||||
|
return `${(bytes / 1024 ** 5).toFixed(1)} PB`;
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderThresholdDetails = () => {
|
||||||
|
if (!thresholdInfo) return null;
|
||||||
|
|
||||||
|
const details = [];
|
||||||
|
|
||||||
|
if (thresholdInfo.bytes_threshold && thresholdInfo.estimated_bytes) {
|
||||||
|
details.push(
|
||||||
|
<div key="bytes">
|
||||||
|
<strong>{t('Data to scan:')}</strong> {formatBytes(thresholdInfo.estimated_bytes)}
|
||||||
|
<br />
|
||||||
|
<strong>{t('Threshold:')}</strong> {formatBytes(thresholdInfo.bytes_threshold)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (thresholdInfo.cost_threshold && thresholdInfo.estimated_cost) {
|
||||||
|
details.push(
|
||||||
|
<div key="cost">
|
||||||
|
<strong>{t('Estimated cost:')}</strong> {thresholdInfo.estimated_cost}
|
||||||
|
<br />
|
||||||
|
<strong>{t('Cost threshold:')}</strong> {thresholdInfo.cost_threshold}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return details.length > 0 ? (
|
||||||
|
<DetailsSection>
|
||||||
|
<div style={{ marginBottom: '8px' }}>
|
||||||
|
<strong>{t('Threshold Details:')}</strong>
|
||||||
|
</div>
|
||||||
|
{details.map((detail, index) => (
|
||||||
|
<div key={index} style={{ marginBottom: index < details.length - 1 ? '8px' : '0' }}>
|
||||||
|
{detail}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</DetailsSection>
|
||||||
|
) : null;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<StyledModal
|
||||||
|
show={visible}
|
||||||
|
onHide={onHide}
|
||||||
|
title={
|
||||||
|
<ModalTitleWithIcon
|
||||||
|
icon="exclamation-triangle"
|
||||||
|
title={t('Query Cost Warning')}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
footer={
|
||||||
|
<>
|
||||||
|
<Button onClick={onHide}>
|
||||||
|
{t('Cancel')}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
buttonStyle="primary"
|
||||||
|
onClick={handleProceed}
|
||||||
|
disabled={!proceedAnyway}
|
||||||
|
>
|
||||||
|
{t('Run Query Anyway')}
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<WarningContent>
|
||||||
|
{warningMessage || t('This query may be expensive to run.')}
|
||||||
|
</WarningContent>
|
||||||
|
|
||||||
|
{renderThresholdDetails()}
|
||||||
|
|
||||||
|
<CheckboxWrapper>
|
||||||
|
<Checkbox
|
||||||
|
checked={proceedAnyway}
|
||||||
|
onChange={(e) => setProceedAnyway(e.target.checked)}
|
||||||
|
>
|
||||||
|
{t('I understand the cost implications and want to proceed anyway')}
|
||||||
|
</Checkbox>
|
||||||
|
</CheckboxWrapper>
|
||||||
|
</StyledModal>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -71,6 +71,7 @@ import {
|
|||||||
addNewQueryEditor,
|
addNewQueryEditor,
|
||||||
CtasEnum,
|
CtasEnum,
|
||||||
estimateQueryCost,
|
estimateQueryCost,
|
||||||
|
checkCostThreshold,
|
||||||
persistEditorHeight,
|
persistEditorHeight,
|
||||||
postStopQuery,
|
postStopQuery,
|
||||||
queryEditorSetAutorun,
|
queryEditorSetAutorun,
|
||||||
@@ -123,6 +124,7 @@ import SouthPane from '../SouthPane';
|
|||||||
import SaveQuery, { QueryPayload } from '../SaveQuery';
|
import SaveQuery, { QueryPayload } from '../SaveQuery';
|
||||||
import ScheduleQueryButton from '../ScheduleQueryButton';
|
import ScheduleQueryButton from '../ScheduleQueryButton';
|
||||||
import EstimateQueryCostButton from '../EstimateQueryCostButton';
|
import EstimateQueryCostButton from '../EstimateQueryCostButton';
|
||||||
|
import CostWarningModal from '../CostWarningModal';
|
||||||
import ShareSqlLabQuery from '../ShareSqlLabQuery';
|
import ShareSqlLabQuery from '../ShareSqlLabQuery';
|
||||||
import SqlEditorLeftBar from '../SqlEditorLeftBar';
|
import SqlEditorLeftBar from '../SqlEditorLeftBar';
|
||||||
import AceEditorWrapper from '../AceEditorWrapper';
|
import AceEditorWrapper from '../AceEditorWrapper';
|
||||||
@@ -270,6 +272,7 @@ const SqlEditor: FC<Props> = ({
|
|||||||
hideLeftBar,
|
hideLeftBar,
|
||||||
currentQueryEditorId,
|
currentQueryEditorId,
|
||||||
hasSqlStatement,
|
hasSqlStatement,
|
||||||
|
costThresholdData,
|
||||||
} = useSelector<
|
} = useSelector<
|
||||||
SqlLabRootState,
|
SqlLabRootState,
|
||||||
{
|
{
|
||||||
@@ -278,8 +281,9 @@ const SqlEditor: FC<Props> = ({
|
|||||||
hideLeftBar?: boolean;
|
hideLeftBar?: boolean;
|
||||||
currentQueryEditorId: QueryEditor['id'];
|
currentQueryEditorId: QueryEditor['id'];
|
||||||
hasSqlStatement: boolean;
|
hasSqlStatement: boolean;
|
||||||
|
costThresholdData?: any;
|
||||||
}
|
}
|
||||||
>(({ sqlLab: { unsavedQueryEditor, databases, queries, tabHistory } }) => {
|
>(({ sqlLab: { unsavedQueryEditor, databases, queries, tabHistory, queryCostThresholds } }) => {
|
||||||
let { dbId, latestQueryId, hideLeftBar } = queryEditor;
|
let { dbId, latestQueryId, hideLeftBar } = queryEditor;
|
||||||
if (unsavedQueryEditor?.id === queryEditor.id) {
|
if (unsavedQueryEditor?.id === queryEditor.id) {
|
||||||
dbId = unsavedQueryEditor.dbId || dbId;
|
dbId = unsavedQueryEditor.dbId || dbId;
|
||||||
@@ -295,6 +299,7 @@ const SqlEditor: FC<Props> = ({
|
|||||||
latestQuery: queries[latestQueryId || ''],
|
latestQuery: queries[latestQueryId || ''],
|
||||||
hideLeftBar,
|
hideLeftBar,
|
||||||
currentQueryEditorId: tabHistory.slice(-1)[0],
|
currentQueryEditorId: tabHistory.slice(-1)[0],
|
||||||
|
costThresholdData: queryCostThresholds[queryEditor.id],
|
||||||
};
|
};
|
||||||
}, shallowEqual);
|
}, shallowEqual);
|
||||||
|
|
||||||
@@ -317,6 +322,11 @@ const SqlEditor: FC<Props> = ({
|
|||||||
);
|
);
|
||||||
const [showCreateAsModal, setShowCreateAsModal] = useState(false);
|
const [showCreateAsModal, setShowCreateAsModal] = useState(false);
|
||||||
const [createAs, setCreateAs] = useState('');
|
const [createAs, setCreateAs] = useState('');
|
||||||
|
const [showCostWarningModal, setShowCostWarningModal] = useState(false);
|
||||||
|
const [costWarningData, setCostWarningData] = useState<{
|
||||||
|
warningMessage: string | null;
|
||||||
|
thresholdInfo?: any;
|
||||||
|
} | null>(null);
|
||||||
const currentSQL = useRef<string>(queryEditor.sql);
|
const currentSQL = useRef<string>(queryEditor.sql);
|
||||||
const showEmptyState = useMemo(
|
const showEmptyState = useMemo(
|
||||||
() => !database || isEmpty(database),
|
() => !database || isEmpty(database),
|
||||||
@@ -330,7 +340,69 @@ const SqlEditor: FC<Props> = ({
|
|||||||
|
|
||||||
const isTempId = (value: unknown): boolean => Number.isNaN(Number(value));
|
const isTempId = (value: unknown): boolean => Number.isNaN(Number(value));
|
||||||
|
|
||||||
|
const checkCostThresholdAndRun = useCallback(
|
||||||
|
(ctasArg = false, ctas_method = CtasEnum.Table) => {
|
||||||
|
if (!database) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if cost threshold checking is enabled via feature flag or configuration
|
||||||
|
// For now, we'll implement the logic directly
|
||||||
|
dispatch(checkCostThreshold(queryEditor)).then(([_, response]) => {
|
||||||
|
if (response && response.json) {
|
||||||
|
const { exceeds_threshold, formatted_warning, threshold_info } = response.json;
|
||||||
|
|
||||||
|
if (exceeds_threshold && formatted_warning) {
|
||||||
|
// Show warning modal
|
||||||
|
setCostWarningData({
|
||||||
|
warningMessage: formatted_warning,
|
||||||
|
thresholdInfo: threshold_info,
|
||||||
|
});
|
||||||
|
setShowCostWarningModal(true);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no threshold exceeded or checking failed, proceed with query
|
||||||
|
dispatch(
|
||||||
|
runQueryFromSqlEditor(
|
||||||
|
database,
|
||||||
|
queryEditor,
|
||||||
|
defaultQueryLimit,
|
||||||
|
ctasArg ? ctas : '',
|
||||||
|
ctasArg,
|
||||||
|
ctas_method,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
dispatch(setActiveSouthPaneTab('Results'));
|
||||||
|
}).catch(() => {
|
||||||
|
// If cost checking fails, proceed with query anyway
|
||||||
|
dispatch(
|
||||||
|
runQueryFromSqlEditor(
|
||||||
|
database,
|
||||||
|
queryEditor,
|
||||||
|
defaultQueryLimit,
|
||||||
|
ctasArg ? ctas : '',
|
||||||
|
ctasArg,
|
||||||
|
ctas_method,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
dispatch(setActiveSouthPaneTab('Results'));
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[ctas, database, defaultQueryLimit, dispatch, queryEditor],
|
||||||
|
);
|
||||||
|
|
||||||
const startQuery = useCallback(
|
const startQuery = useCallback(
|
||||||
|
(ctasArg = false, ctas_method = CtasEnum.Table) => {
|
||||||
|
// Use cost threshold checking for regular queries
|
||||||
|
checkCostThresholdAndRun(ctasArg, ctas_method);
|
||||||
|
},
|
||||||
|
[checkCostThresholdAndRun],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Direct query execution without cost checking (for modal "proceed anyway")
|
||||||
|
const executeQueryDirectly = useCallback(
|
||||||
(ctasArg = false, ctas_method = CtasEnum.Table) => {
|
(ctasArg = false, ctas_method = CtasEnum.Table) => {
|
||||||
if (!database) {
|
if (!database) {
|
||||||
return;
|
return;
|
||||||
@@ -1121,6 +1193,20 @@ const SqlEditor: FC<Props> = ({
|
|||||||
<span>{t('Name')}</span>
|
<span>{t('Name')}</span>
|
||||||
<Input placeholder={createModalPlaceHolder} onChange={ctasChanged} />
|
<Input placeholder={createModalPlaceHolder} onChange={ctasChanged} />
|
||||||
</Modal>
|
</Modal>
|
||||||
|
<CostWarningModal
|
||||||
|
visible={showCostWarningModal}
|
||||||
|
onHide={() => {
|
||||||
|
setShowCostWarningModal(false);
|
||||||
|
setCostWarningData(null);
|
||||||
|
}}
|
||||||
|
onProceed={() => {
|
||||||
|
setShowCostWarningModal(false);
|
||||||
|
setCostWarningData(null);
|
||||||
|
executeQueryDirectly();
|
||||||
|
}}
|
||||||
|
warningMessage={costWarningData?.warningMessage || null}
|
||||||
|
thresholdInfo={costWarningData?.thresholdInfo}
|
||||||
|
/>
|
||||||
</StyledSqlEditor>
|
</StyledSqlEditor>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -264,6 +264,7 @@ export default function getInitialState({
|
|||||||
queriesLastUpdate: Date.now(),
|
queriesLastUpdate: Date.now(),
|
||||||
editorTabLastUpdatedAt,
|
editorTabLastUpdatedAt,
|
||||||
queryCostEstimates: {},
|
queryCostEstimates: {},
|
||||||
|
queryCostThresholds: {},
|
||||||
unsavedQueryEditor,
|
unsavedQueryEditor,
|
||||||
lastUpdatedActiveTab,
|
lastUpdatedActiveTab,
|
||||||
destroyedQueryEditors,
|
destroyedQueryEditors,
|
||||||
|
|||||||
@@ -315,6 +315,51 @@ export default function sqlLabReducer(state = {}, action) {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
[actions.COST_THRESHOLD_CHECK_STARTED]() {
|
||||||
|
return {
|
||||||
|
...state,
|
||||||
|
queryCostThresholds: {
|
||||||
|
...state.queryCostThresholds,
|
||||||
|
[action.query.id]: {
|
||||||
|
completed: false,
|
||||||
|
exceedsThreshold: false,
|
||||||
|
thresholdInfo: null,
|
||||||
|
formattedWarning: null,
|
||||||
|
error: null,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
[actions.COST_THRESHOLD_CHECK_RETURNED]() {
|
||||||
|
return {
|
||||||
|
...state,
|
||||||
|
queryCostThresholds: {
|
||||||
|
...state.queryCostThresholds,
|
||||||
|
[action.query.id]: {
|
||||||
|
completed: true,
|
||||||
|
exceedsThreshold: action.json.exceeds_threshold,
|
||||||
|
thresholdInfo: action.json.threshold_info,
|
||||||
|
formattedWarning: action.json.formatted_warning,
|
||||||
|
error: null,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
[actions.COST_THRESHOLD_CHECK_FAILED]() {
|
||||||
|
return {
|
||||||
|
...state,
|
||||||
|
queryCostThresholds: {
|
||||||
|
...state.queryCostThresholds,
|
||||||
|
[action.query.id]: {
|
||||||
|
completed: false,
|
||||||
|
exceedsThreshold: false,
|
||||||
|
thresholdInfo: null,
|
||||||
|
formattedWarning: null,
|
||||||
|
error: action.error,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
[actions.START_QUERY]() {
|
[actions.START_QUERY]() {
|
||||||
let newState = { ...state };
|
let newState = { ...state };
|
||||||
if (action.query.sqlEditorId) {
|
if (action.query.sqlEditorId) {
|
||||||
|
|||||||
230
superset/commands/sql_lab/check_cost_threshold.py
Normal file
230
superset/commands/sql_lab/check_cost_threshold.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
from superset import app
|
||||||
|
from superset.commands.base import BaseCommand
|
||||||
|
from superset.commands.sql_lab.estimate import QueryEstimationCommand, EstimateQueryCostType
|
||||||
|
|
||||||
|
config = app.config
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CostThresholdResult(TypedDict):
|
||||||
|
exceeds_threshold: bool
|
||||||
|
estimated_cost: list[dict[str, Any]]
|
||||||
|
threshold_info: dict[str, Any]
|
||||||
|
formatted_warning: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class QueryCostThresholdCheckCommand(BaseCommand):
|
||||||
|
"""
|
||||||
|
Command to check if a query's estimated cost exceeds configured thresholds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_estimation_command: QueryEstimationCommand
|
||||||
|
|
||||||
|
def __init__(self, estimation_params: EstimateQueryCostType) -> None:
|
||||||
|
self._estimation_command = QueryEstimationCommand(estimation_params)
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
# Use the estimation command's validation
|
||||||
|
self._estimation_command.validate()
|
||||||
|
|
||||||
|
def run(self) -> CostThresholdResult:
|
||||||
|
"""
|
||||||
|
Check if query cost exceeds thresholds.
|
||||||
|
|
||||||
|
Returns a result indicating whether the query exceeds cost thresholds
|
||||||
|
and provides information for user warnings.
|
||||||
|
"""
|
||||||
|
self.validate()
|
||||||
|
|
||||||
|
# Check if cost checking is enabled
|
||||||
|
if not config.get("SQLLAB_QUERY_COST_CHECKING_ENABLED", False):
|
||||||
|
return self._create_empty_result()
|
||||||
|
|
||||||
|
estimated_cost = self._get_estimated_cost()
|
||||||
|
if not estimated_cost:
|
||||||
|
return self._create_empty_result()
|
||||||
|
|
||||||
|
thresholds = self._get_engine_thresholds()
|
||||||
|
if not thresholds:
|
||||||
|
return CostThresholdResult(
|
||||||
|
exceeds_threshold=False,
|
||||||
|
estimated_cost=estimated_cost,
|
||||||
|
threshold_info={},
|
||||||
|
formatted_warning=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._check_thresholds(estimated_cost, thresholds)
|
||||||
|
|
||||||
|
def _create_empty_result(self) -> CostThresholdResult:
|
||||||
|
"""Create an empty result when cost checking is disabled or fails."""
|
||||||
|
return CostThresholdResult(
|
||||||
|
exceeds_threshold=False,
|
||||||
|
estimated_cost=[],
|
||||||
|
threshold_info={},
|
||||||
|
formatted_warning=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_estimated_cost(self) -> list[dict[str, Any]] | None:
|
||||||
|
"""Get cost estimation, returning None if it fails."""
|
||||||
|
try:
|
||||||
|
return self._estimation_command.run()
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning("Cost estimation failed: %s", str(ex))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_engine_thresholds(self) -> dict[str, Any]:
|
||||||
|
"""Get thresholds for the current database engine."""
|
||||||
|
database = self._estimation_command._database
|
||||||
|
engine_name = database.db_engine_spec.engine_name
|
||||||
|
if engine_name is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
engine_name = engine_name.lower()
|
||||||
|
return config.get("SQLLAB_QUERY_COST_THRESHOLDS", {}).get(engine_name, {})
|
||||||
|
|
||||||
|
def _check_thresholds(
|
||||||
|
self, estimated_cost: list[dict[str, Any]], thresholds: dict[str, Any]
|
||||||
|
) -> CostThresholdResult:
|
||||||
|
"""Check if estimated cost exceeds configured thresholds."""
|
||||||
|
exceeds_threshold = False
|
||||||
|
warning_messages = []
|
||||||
|
threshold_info = {}
|
||||||
|
|
||||||
|
for cost_item in estimated_cost:
|
||||||
|
if self._check_bytes_threshold(cost_item, thresholds, threshold_info, warning_messages):
|
||||||
|
exceeds_threshold = True
|
||||||
|
if self._check_cost_threshold(cost_item, thresholds, threshold_info, warning_messages):
|
||||||
|
exceeds_threshold = True
|
||||||
|
|
||||||
|
formatted_warning = None
|
||||||
|
if warning_messages:
|
||||||
|
formatted_warning = (
|
||||||
|
" ".join(warning_messages) + " Are you sure you want to continue?"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CostThresholdResult(
|
||||||
|
exceeds_threshold=exceeds_threshold,
|
||||||
|
estimated_cost=estimated_cost,
|
||||||
|
threshold_info=threshold_info,
|
||||||
|
formatted_warning=formatted_warning,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_bytes_threshold(
|
||||||
|
self,
|
||||||
|
cost_item: dict[str, Any],
|
||||||
|
thresholds: dict[str, Any],
|
||||||
|
threshold_info: dict[str, Any],
|
||||||
|
warning_messages: list[str]
|
||||||
|
) -> bool:
|
||||||
|
"""Check bytes scanned threshold. Returns True if threshold exceeded."""
|
||||||
|
if "bytes_scanned" not in thresholds or "Bytes Scanned" not in cost_item:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
bytes_scanned = self._parse_bytes_from_cost_item(cost_item["Bytes Scanned"])
|
||||||
|
threshold_bytes = thresholds["bytes_scanned"]
|
||||||
|
threshold_info["bytes_threshold"] = threshold_bytes
|
||||||
|
threshold_info["estimated_bytes"] = bytes_scanned
|
||||||
|
|
||||||
|
if bytes_scanned > threshold_bytes:
|
||||||
|
warning_messages.append(
|
||||||
|
f"This query will scan approximately {self._format_bytes(bytes_scanned)} "
|
||||||
|
f"of data, which exceeds the threshold of {self._format_bytes(threshold_bytes)}."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except (ValueError, KeyError) as ex:
|
||||||
|
logger.warning("Failed to parse bytes from cost estimation: %s", str(ex))
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_cost_threshold(
|
||||||
|
self,
|
||||||
|
cost_item: dict[str, Any],
|
||||||
|
thresholds: dict[str, Any],
|
||||||
|
threshold_info: dict[str, Any],
|
||||||
|
warning_messages: list[str]
|
||||||
|
) -> bool:
|
||||||
|
"""Check cost threshold. Returns True if threshold exceeded."""
|
||||||
|
if "cost_threshold" not in thresholds or "Cost" not in cost_item:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
cost_value = float(cost_item["Cost"])
|
||||||
|
threshold_cost = thresholds["cost_threshold"]
|
||||||
|
threshold_info["cost_threshold"] = threshold_cost
|
||||||
|
threshold_info["estimated_cost"] = cost_value
|
||||||
|
|
||||||
|
if cost_value > threshold_cost:
|
||||||
|
warning_messages.append(
|
||||||
|
f"This query has an estimated cost of {cost_value}, "
|
||||||
|
f"which exceeds the threshold of {threshold_cost}."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except (ValueError, KeyError) as ex:
|
||||||
|
logger.warning("Failed to parse cost from cost estimation: %s", str(ex))
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _parse_bytes_from_cost_item(self, bytes_str: str) -> int:
|
||||||
|
"""Parse bytes from formatted string like '5.2 GB' or '1024 MB'."""
|
||||||
|
if not isinstance(bytes_str, str):
|
||||||
|
return int(bytes_str)
|
||||||
|
|
||||||
|
# Remove commas and split
|
||||||
|
parts = bytes_str.replace(",", "").strip().split()
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"Cannot parse bytes from: {bytes_str}")
|
||||||
|
|
||||||
|
value_str, unit = parts
|
||||||
|
value = float(value_str)
|
||||||
|
unit = unit.upper()
|
||||||
|
|
||||||
|
multipliers = {
|
||||||
|
"B": 1,
|
||||||
|
"KB": 1024,
|
||||||
|
"MB": 1024**2,
|
||||||
|
"GB": 1024**3,
|
||||||
|
"TB": 1024**4,
|
||||||
|
"PB": 1024**5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if unit not in multipliers:
|
||||||
|
raise ValueError(f"Unknown unit: {unit}")
|
||||||
|
|
||||||
|
return int(value * multipliers[unit])
|
||||||
|
|
||||||
|
def _format_bytes(self, bytes_count: int) -> str:
|
||||||
|
"""Format bytes into human-readable string."""
|
||||||
|
if bytes_count < 1024:
|
||||||
|
return f"{bytes_count} B"
|
||||||
|
elif bytes_count < 1024**2:
|
||||||
|
return f"{bytes_count / 1024:.1f} KB"
|
||||||
|
elif bytes_count < 1024**3:
|
||||||
|
return f"{bytes_count / (1024**2):.1f} MB"
|
||||||
|
elif bytes_count < 1024**4:
|
||||||
|
return f"{bytes_count / (1024**3):.1f} GB"
|
||||||
|
elif bytes_count < 1024**5:
|
||||||
|
return f"{bytes_count / (1024**4):.1f} TB"
|
||||||
|
else:
|
||||||
|
return f"{bytes_count / (1024**5):.1f} PB"
|
||||||
@@ -1191,6 +1191,18 @@ SQLLAB_ASYNC_TIME_LIMIT_SEC = int(timedelta(hours=6).total_seconds())
|
|||||||
# timeout.
|
# timeout.
|
||||||
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
|
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
|
||||||
|
|
||||||
|
# Query cost governance configuration
|
||||||
|
# Enable automatic cost checking before query execution
|
||||||
|
SQLLAB_QUERY_COST_CHECKING_ENABLED = False
|
||||||
|
|
||||||
|
# Cost thresholds that trigger warnings before query execution
|
||||||
|
# This is a dictionary where keys are database engine names and values are threshold configs
|
||||||
|
# Each threshold config can contain:
|
||||||
|
# - 'bytes_scanned': maximum bytes that can be scanned without warning
|
||||||
|
# - 'cost_threshold': monetary cost threshold (engine-specific units)
|
||||||
|
# Example: {'bigquery': {'bytes_scanned': 5 * 1024**4}, 'presto': {'cost_threshold': 1000}}
|
||||||
|
SQLLAB_QUERY_COST_THRESHOLDS = {}
|
||||||
|
|
||||||
# Timeout duration for SQL Lab fetching query results by the resultsKey.
|
# Timeout duration for SQL Lab fetching query results by the resultsKey.
|
||||||
# 0 means no timeout.
|
# 0 means no timeout.
|
||||||
SQLLAB_QUERY_RESULT_TIMEOUT = 0
|
SQLLAB_QUERY_RESULT_TIMEOUT = 0
|
||||||
|
|||||||
@@ -25,6 +25,9 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface
|
|||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
from superset import app, is_feature_enabled
|
from superset import app, is_feature_enabled
|
||||||
|
from superset.commands.sql_lab.check_cost_threshold import (
|
||||||
|
QueryCostThresholdCheckCommand,
|
||||||
|
)
|
||||||
from superset.commands.sql_lab.estimate import QueryEstimationCommand
|
from superset.commands.sql_lab.estimate import QueryEstimationCommand
|
||||||
from superset.commands.sql_lab.execute import CommandResult, ExecuteSqlCommand
|
from superset.commands.sql_lab.execute import CommandResult, ExecuteSqlCommand
|
||||||
from superset.commands.sql_lab.export import SqlResultExportCommand
|
from superset.commands.sql_lab.export import SqlResultExportCommand
|
||||||
@@ -188,6 +191,66 @@ class SqlLabRestApi(BaseSupersetApi):
|
|||||||
result = command.run()
|
result = command.run()
|
||||||
return self.response(200, result=result)
|
return self.response(200, result=result)
|
||||||
|
|
||||||
|
@expose("/check_cost_threshold/", methods=("POST",))
|
||||||
|
@protect()
|
||||||
|
@statsd_metrics
|
||||||
|
@requires_json
|
||||||
|
@event_logger.log_this_with_context(
|
||||||
|
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
|
||||||
|
f".check_cost_threshold",
|
||||||
|
log_to_statsd=False,
|
||||||
|
)
|
||||||
|
def check_cost_threshold(self) -> Response:
|
||||||
|
"""Check if query cost exceeds configured thresholds.
|
||||||
|
---
|
||||||
|
post:
|
||||||
|
summary: Check if query cost exceeds thresholds
|
||||||
|
requestBody:
|
||||||
|
description: SQL query and params
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/EstimateQueryCostSchema'
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Cost threshold check result
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
exceeds_threshold:
|
||||||
|
type: boolean
|
||||||
|
description: Whether query exceeds cost thresholds
|
||||||
|
estimated_cost:
|
||||||
|
type: array
|
||||||
|
description: Detailed cost estimation
|
||||||
|
threshold_info:
|
||||||
|
type: object
|
||||||
|
description: Information about thresholds and estimates
|
||||||
|
formatted_warning:
|
||||||
|
type: string
|
||||||
|
nullable: true
|
||||||
|
description: Human-readable warning message
|
||||||
|
400:
|
||||||
|
$ref: '#/components/responses/400'
|
||||||
|
401:
|
||||||
|
$ref: '#/components/responses/401'
|
||||||
|
403:
|
||||||
|
$ref: '#/components/responses/403'
|
||||||
|
500:
|
||||||
|
$ref: '#/components/responses/500'
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model = self.estimate_model_schema.load(request.json)
|
||||||
|
except ValidationError as error:
|
||||||
|
return self.response_400(message=error.messages)
|
||||||
|
|
||||||
|
command = QueryCostThresholdCheckCommand(model)
|
||||||
|
result = command.run()
|
||||||
|
return self.response(200, **result)
|
||||||
|
|
||||||
@expose("/format_sql/", methods=("POST",))
|
@expose("/format_sql/", methods=("POST",))
|
||||||
@statsd_metrics
|
@statsd_metrics
|
||||||
@protect()
|
@protect()
|
||||||
|
|||||||
Reference in New Issue
Block a user