mirror of
https://github.com/apache/superset.git
synced 2026-04-28 12:34:23 +00:00
Compare commits
1 Commits
fix-webpac
...
pre-cost-e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
992226527f |
@@ -97,6 +97,10 @@ export const COST_ESTIMATE_STARTED = 'COST_ESTIMATE_STARTED';
|
||||
export const COST_ESTIMATE_RETURNED = 'COST_ESTIMATE_RETURNED';
|
||||
export const COST_ESTIMATE_FAILED = 'COST_ESTIMATE_FAILED';
|
||||
|
||||
export const COST_THRESHOLD_CHECK_STARTED = 'COST_THRESHOLD_CHECK_STARTED';
|
||||
export const COST_THRESHOLD_CHECK_RETURNED = 'COST_THRESHOLD_CHECK_RETURNED';
|
||||
export const COST_THRESHOLD_CHECK_FAILED = 'COST_THRESHOLD_CHECK_FAILED';
|
||||
|
||||
export const CREATE_DATASOURCE_STARTED = 'CREATE_DATASOURCE_STARTED';
|
||||
export const CREATE_DATASOURCE_SUCCESS = 'CREATE_DATASOURCE_SUCCESS';
|
||||
export const CREATE_DATASOURCE_FAILED = 'CREATE_DATASOURCE_FAILED';
|
||||
@@ -233,6 +237,45 @@ export function estimateQueryCost(queryEditor) {
|
||||
};
|
||||
}
|
||||
|
||||
export function checkCostThreshold(queryEditor) {
|
||||
return (dispatch, getState) => {
|
||||
const { dbId, catalog, schema, sql, selectedText, templateParams } =
|
||||
getUpToDateQuery(getState(), queryEditor);
|
||||
const requestSql = selectedText || sql;
|
||||
const postPayload = {
|
||||
database_id: dbId,
|
||||
catalog,
|
||||
schema,
|
||||
sql: requestSql,
|
||||
template_params: JSON.parse(templateParams || '{}'),
|
||||
};
|
||||
return Promise.all([
|
||||
dispatch({ type: COST_THRESHOLD_CHECK_STARTED, query: queryEditor }),
|
||||
SupersetClient.post({
|
||||
endpoint: '/api/v1/sqllab/check_cost_threshold/',
|
||||
body: JSON.stringify(postPayload),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
})
|
||||
.then(({ json }) =>
|
||||
dispatch({ type: COST_THRESHOLD_CHECK_RETURNED, query: queryEditor, json }),
|
||||
)
|
||||
.catch(response =>
|
||||
getClientErrorObject(response).then(error => {
|
||||
const message =
|
||||
error.error ||
|
||||
error.statusText ||
|
||||
t('Failed at checking cost threshold');
|
||||
return dispatch({
|
||||
type: COST_THRESHOLD_CHECK_FAILED,
|
||||
query: queryEditor,
|
||||
error: message,
|
||||
});
|
||||
}),
|
||||
),
|
||||
]);
|
||||
};
|
||||
}
|
||||
|
||||
export function clearInactiveQueries(interval) {
|
||||
return { type: CLEAR_INACTIVE_QUERIES, interval };
|
||||
}
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { render, screen, fireEvent } from '@testing-library/react';
|
||||
import { ThemeProvider } from '@superset-ui/core';
|
||||
import { theme } from 'src/preamble';
|
||||
import CostWarningModal from './index';
|
||||
|
||||
const mockProps = {
|
||||
visible: true,
|
||||
onHide: jest.fn(),
|
||||
onProceed: jest.fn(),
|
||||
warningMessage: 'This query will scan 10 GB of data, which exceeds the threshold of 5 GB.',
|
||||
thresholdInfo: {
|
||||
bytes_threshold: 5 * 1024 ** 3, // 5 GB
|
||||
estimated_bytes: 10 * 1024 ** 3, // 10 GB
|
||||
},
|
||||
};
|
||||
|
||||
const renderWithTheme = (ui: React.ReactElement) =>
|
||||
render(<ThemeProvider theme={theme}>{ui}</ThemeProvider>);
|
||||
|
||||
describe('CostWarningModal', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('renders with warning message', () => {
|
||||
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||
|
||||
expect(screen.getByText('Query Cost Warning')).toBeInTheDocument();
|
||||
expect(screen.getByText(mockProps.warningMessage)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows threshold details when provided', () => {
|
||||
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||
|
||||
expect(screen.getByText('Threshold Details:')).toBeInTheDocument();
|
||||
expect(screen.getByText('Data to scan:')).toBeInTheDocument();
|
||||
expect(screen.getByText('10.0 GB')).toBeInTheDocument();
|
||||
expect(screen.getByText('5.0 GB')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('disables proceed button until checkbox is checked', () => {
|
||||
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||
|
||||
const proceedButton = screen.getByText('Run Query Anyway');
|
||||
const checkbox = screen.getByRole('checkbox');
|
||||
|
||||
expect(proceedButton).toBeDisabled();
|
||||
|
||||
fireEvent.click(checkbox);
|
||||
expect(proceedButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it('calls onProceed when proceed button is clicked with checkbox checked', () => {
|
||||
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||
|
||||
const checkbox = screen.getByRole('checkbox');
|
||||
const proceedButton = screen.getByText('Run Query Anyway');
|
||||
|
||||
fireEvent.click(checkbox);
|
||||
fireEvent.click(proceedButton);
|
||||
|
||||
expect(mockProps.onProceed).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('calls onHide when cancel button is clicked', () => {
|
||||
renderWithTheme(<CostWarningModal {...mockProps} />);
|
||||
|
||||
const cancelButton = screen.getByText('Cancel');
|
||||
fireEvent.click(cancelButton);
|
||||
|
||||
expect(mockProps.onHide).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('renders without threshold details when not provided', () => {
|
||||
const propsWithoutThreshold = {
|
||||
...mockProps,
|
||||
thresholdInfo: undefined,
|
||||
};
|
||||
|
||||
renderWithTheme(<CostWarningModal {...propsWithoutThreshold} />);
|
||||
|
||||
expect(screen.queryByText('Threshold Details:')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows default message when warningMessage is null', () => {
|
||||
const propsWithNoMessage = {
|
||||
...mockProps,
|
||||
warningMessage: null,
|
||||
};
|
||||
|
||||
renderWithTheme(<CostWarningModal {...propsWithNoMessage} />);
|
||||
|
||||
expect(screen.getByText('This query may be expensive to run.')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('handles cost threshold details', () => {
|
||||
const propsWithCostThreshold = {
|
||||
...mockProps,
|
||||
thresholdInfo: {
|
||||
cost_threshold: 100,
|
||||
estimated_cost: 250,
|
||||
},
|
||||
};
|
||||
|
||||
renderWithTheme(<CostWarningModal {...propsWithCostThreshold} />);
|
||||
|
||||
expect(screen.getByText('Estimated cost:')).toBeInTheDocument();
|
||||
expect(screen.getByText('250')).toBeInTheDocument();
|
||||
expect(screen.getByText('Cost threshold:')).toBeInTheDocument();
|
||||
expect(screen.getByText('100')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,166 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useState } from 'react';
|
||||
import { styled, t } from '@superset-ui/core';
|
||||
import { Button, Modal, Checkbox } from '@superset-ui/core/components';
|
||||
import { ModalTitleWithIcon } from 'src/components/ModalTitleWithIcon';
|
||||
|
||||
const StyledModal = styled(Modal)`
|
||||
.ant-modal-body {
|
||||
padding: 24px;
|
||||
}
|
||||
`;
|
||||
|
||||
const WarningContent = styled.div`
|
||||
margin: 16px 0;
|
||||
font-size: 14px;
|
||||
line-height: 1.5;
|
||||
`;
|
||||
|
||||
const DetailsSection = styled.div`
|
||||
margin: 16px 0;
|
||||
padding: 12px;
|
||||
background-color: ${({ theme }) => theme.colors.grayscale.light4};
|
||||
border-radius: 4px;
|
||||
font-size: 12px;
|
||||
`;
|
||||
|
||||
const CheckboxWrapper = styled.div`
|
||||
margin: 16px 0;
|
||||
`;
|
||||
|
||||
interface CostWarningModalProps {
|
||||
visible: boolean;
|
||||
onHide: () => void;
|
||||
onProceed: () => void;
|
||||
warningMessage: string | null;
|
||||
thresholdInfo?: {
|
||||
bytes_threshold?: number;
|
||||
estimated_bytes?: number;
|
||||
cost_threshold?: number;
|
||||
estimated_cost?: number;
|
||||
};
|
||||
}
|
||||
|
||||
export default function CostWarningModal({
|
||||
visible,
|
||||
onHide,
|
||||
onProceed,
|
||||
warningMessage,
|
||||
thresholdInfo,
|
||||
}: CostWarningModalProps) {
|
||||
const [proceedAnyway, setProceedAnyway] = useState(false);
|
||||
|
||||
const handleProceed = () => {
|
||||
if (proceedAnyway) {
|
||||
onProceed();
|
||||
}
|
||||
};
|
||||
|
||||
const formatBytes = (bytes: number) => {
|
||||
if (bytes < 1024) return `${bytes} B`;
|
||||
if (bytes < 1024 ** 2) return `${(bytes / 1024).toFixed(1)} KB`;
|
||||
if (bytes < 1024 ** 3) return `${(bytes / 1024 ** 2).toFixed(1)} MB`;
|
||||
if (bytes < 1024 ** 4) return `${(bytes / 1024 ** 3).toFixed(1)} GB`;
|
||||
if (bytes < 1024 ** 5) return `${(bytes / 1024 ** 4).toFixed(1)} TB`;
|
||||
return `${(bytes / 1024 ** 5).toFixed(1)} PB`;
|
||||
};
|
||||
|
||||
const renderThresholdDetails = () => {
|
||||
if (!thresholdInfo) return null;
|
||||
|
||||
const details = [];
|
||||
|
||||
if (thresholdInfo.bytes_threshold && thresholdInfo.estimated_bytes) {
|
||||
details.push(
|
||||
<div key="bytes">
|
||||
<strong>{t('Data to scan:')}</strong> {formatBytes(thresholdInfo.estimated_bytes)}
|
||||
<br />
|
||||
<strong>{t('Threshold:')}</strong> {formatBytes(thresholdInfo.bytes_threshold)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (thresholdInfo.cost_threshold && thresholdInfo.estimated_cost) {
|
||||
details.push(
|
||||
<div key="cost">
|
||||
<strong>{t('Estimated cost:')}</strong> {thresholdInfo.estimated_cost}
|
||||
<br />
|
||||
<strong>{t('Cost threshold:')}</strong> {thresholdInfo.cost_threshold}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return details.length > 0 ? (
|
||||
<DetailsSection>
|
||||
<div style={{ marginBottom: '8px' }}>
|
||||
<strong>{t('Threshold Details:')}</strong>
|
||||
</div>
|
||||
{details.map((detail, index) => (
|
||||
<div key={index} style={{ marginBottom: index < details.length - 1 ? '8px' : '0' }}>
|
||||
{detail}
|
||||
</div>
|
||||
))}
|
||||
</DetailsSection>
|
||||
) : null;
|
||||
};
|
||||
|
||||
return (
|
||||
<StyledModal
|
||||
show={visible}
|
||||
onHide={onHide}
|
||||
title={
|
||||
<ModalTitleWithIcon
|
||||
icon="exclamation-triangle"
|
||||
title={t('Query Cost Warning')}
|
||||
/>
|
||||
}
|
||||
footer={
|
||||
<>
|
||||
<Button onClick={onHide}>
|
||||
{t('Cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
buttonStyle="primary"
|
||||
onClick={handleProceed}
|
||||
disabled={!proceedAnyway}
|
||||
>
|
||||
{t('Run Query Anyway')}
|
||||
</Button>
|
||||
</>
|
||||
}
|
||||
>
|
||||
<WarningContent>
|
||||
{warningMessage || t('This query may be expensive to run.')}
|
||||
</WarningContent>
|
||||
|
||||
{renderThresholdDetails()}
|
||||
|
||||
<CheckboxWrapper>
|
||||
<Checkbox
|
||||
checked={proceedAnyway}
|
||||
onChange={(e) => setProceedAnyway(e.target.checked)}
|
||||
>
|
||||
{t('I understand the cost implications and want to proceed anyway')}
|
||||
</Checkbox>
|
||||
</CheckboxWrapper>
|
||||
</StyledModal>
|
||||
);
|
||||
}
|
||||
@@ -71,6 +71,7 @@ import {
|
||||
addNewQueryEditor,
|
||||
CtasEnum,
|
||||
estimateQueryCost,
|
||||
checkCostThreshold,
|
||||
persistEditorHeight,
|
||||
postStopQuery,
|
||||
queryEditorSetAutorun,
|
||||
@@ -123,6 +124,7 @@ import SouthPane from '../SouthPane';
|
||||
import SaveQuery, { QueryPayload } from '../SaveQuery';
|
||||
import ScheduleQueryButton from '../ScheduleQueryButton';
|
||||
import EstimateQueryCostButton from '../EstimateQueryCostButton';
|
||||
import CostWarningModal from '../CostWarningModal';
|
||||
import ShareSqlLabQuery from '../ShareSqlLabQuery';
|
||||
import SqlEditorLeftBar from '../SqlEditorLeftBar';
|
||||
import AceEditorWrapper from '../AceEditorWrapper';
|
||||
@@ -270,6 +272,7 @@ const SqlEditor: FC<Props> = ({
|
||||
hideLeftBar,
|
||||
currentQueryEditorId,
|
||||
hasSqlStatement,
|
||||
costThresholdData,
|
||||
} = useSelector<
|
||||
SqlLabRootState,
|
||||
{
|
||||
@@ -278,8 +281,9 @@ const SqlEditor: FC<Props> = ({
|
||||
hideLeftBar?: boolean;
|
||||
currentQueryEditorId: QueryEditor['id'];
|
||||
hasSqlStatement: boolean;
|
||||
costThresholdData?: any;
|
||||
}
|
||||
>(({ sqlLab: { unsavedQueryEditor, databases, queries, tabHistory } }) => {
|
||||
>(({ sqlLab: { unsavedQueryEditor, databases, queries, tabHistory, queryCostThresholds } }) => {
|
||||
let { dbId, latestQueryId, hideLeftBar } = queryEditor;
|
||||
if (unsavedQueryEditor?.id === queryEditor.id) {
|
||||
dbId = unsavedQueryEditor.dbId || dbId;
|
||||
@@ -295,6 +299,7 @@ const SqlEditor: FC<Props> = ({
|
||||
latestQuery: queries[latestQueryId || ''],
|
||||
hideLeftBar,
|
||||
currentQueryEditorId: tabHistory.slice(-1)[0],
|
||||
costThresholdData: queryCostThresholds[queryEditor.id],
|
||||
};
|
||||
}, shallowEqual);
|
||||
|
||||
@@ -317,6 +322,11 @@ const SqlEditor: FC<Props> = ({
|
||||
);
|
||||
const [showCreateAsModal, setShowCreateAsModal] = useState(false);
|
||||
const [createAs, setCreateAs] = useState('');
|
||||
const [showCostWarningModal, setShowCostWarningModal] = useState(false);
|
||||
const [costWarningData, setCostWarningData] = useState<{
|
||||
warningMessage: string | null;
|
||||
thresholdInfo?: any;
|
||||
} | null>(null);
|
||||
const currentSQL = useRef<string>(queryEditor.sql);
|
||||
const showEmptyState = useMemo(
|
||||
() => !database || isEmpty(database),
|
||||
@@ -330,7 +340,69 @@ const SqlEditor: FC<Props> = ({
|
||||
|
||||
const isTempId = (value: unknown): boolean => Number.isNaN(Number(value));
|
||||
|
||||
const checkCostThresholdAndRun = useCallback(
|
||||
(ctasArg = false, ctas_method = CtasEnum.Table) => {
|
||||
if (!database) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if cost threshold checking is enabled via feature flag or configuration
|
||||
// For now, we'll implement the logic directly
|
||||
dispatch(checkCostThreshold(queryEditor)).then(([_, response]) => {
|
||||
if (response && response.json) {
|
||||
const { exceeds_threshold, formatted_warning, threshold_info } = response.json;
|
||||
|
||||
if (exceeds_threshold && formatted_warning) {
|
||||
// Show warning modal
|
||||
setCostWarningData({
|
||||
warningMessage: formatted_warning,
|
||||
thresholdInfo: threshold_info,
|
||||
});
|
||||
setShowCostWarningModal(true);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If no threshold exceeded or checking failed, proceed with query
|
||||
dispatch(
|
||||
runQueryFromSqlEditor(
|
||||
database,
|
||||
queryEditor,
|
||||
defaultQueryLimit,
|
||||
ctasArg ? ctas : '',
|
||||
ctasArg,
|
||||
ctas_method,
|
||||
),
|
||||
);
|
||||
dispatch(setActiveSouthPaneTab('Results'));
|
||||
}).catch(() => {
|
||||
// If cost checking fails, proceed with query anyway
|
||||
dispatch(
|
||||
runQueryFromSqlEditor(
|
||||
database,
|
||||
queryEditor,
|
||||
defaultQueryLimit,
|
||||
ctasArg ? ctas : '',
|
||||
ctasArg,
|
||||
ctas_method,
|
||||
),
|
||||
);
|
||||
dispatch(setActiveSouthPaneTab('Results'));
|
||||
});
|
||||
},
|
||||
[ctas, database, defaultQueryLimit, dispatch, queryEditor],
|
||||
);
|
||||
|
||||
const startQuery = useCallback(
|
||||
(ctasArg = false, ctas_method = CtasEnum.Table) => {
|
||||
// Use cost threshold checking for regular queries
|
||||
checkCostThresholdAndRun(ctasArg, ctas_method);
|
||||
},
|
||||
[checkCostThresholdAndRun],
|
||||
);
|
||||
|
||||
// Direct query execution without cost checking (for modal "proceed anyway")
|
||||
const executeQueryDirectly = useCallback(
|
||||
(ctasArg = false, ctas_method = CtasEnum.Table) => {
|
||||
if (!database) {
|
||||
return;
|
||||
@@ -1121,6 +1193,20 @@ const SqlEditor: FC<Props> = ({
|
||||
<span>{t('Name')}</span>
|
||||
<Input placeholder={createModalPlaceHolder} onChange={ctasChanged} />
|
||||
</Modal>
|
||||
<CostWarningModal
|
||||
visible={showCostWarningModal}
|
||||
onHide={() => {
|
||||
setShowCostWarningModal(false);
|
||||
setCostWarningData(null);
|
||||
}}
|
||||
onProceed={() => {
|
||||
setShowCostWarningModal(false);
|
||||
setCostWarningData(null);
|
||||
executeQueryDirectly();
|
||||
}}
|
||||
warningMessage={costWarningData?.warningMessage || null}
|
||||
thresholdInfo={costWarningData?.thresholdInfo}
|
||||
/>
|
||||
</StyledSqlEditor>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -264,6 +264,7 @@ export default function getInitialState({
|
||||
queriesLastUpdate: Date.now(),
|
||||
editorTabLastUpdatedAt,
|
||||
queryCostEstimates: {},
|
||||
queryCostThresholds: {},
|
||||
unsavedQueryEditor,
|
||||
lastUpdatedActiveTab,
|
||||
destroyedQueryEditors,
|
||||
|
||||
@@ -315,6 +315,51 @@ export default function sqlLabReducer(state = {}, action) {
|
||||
},
|
||||
};
|
||||
},
|
||||
[actions.COST_THRESHOLD_CHECK_STARTED]() {
|
||||
return {
|
||||
...state,
|
||||
queryCostThresholds: {
|
||||
...state.queryCostThresholds,
|
||||
[action.query.id]: {
|
||||
completed: false,
|
||||
exceedsThreshold: false,
|
||||
thresholdInfo: null,
|
||||
formattedWarning: null,
|
||||
error: null,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
[actions.COST_THRESHOLD_CHECK_RETURNED]() {
|
||||
return {
|
||||
...state,
|
||||
queryCostThresholds: {
|
||||
...state.queryCostThresholds,
|
||||
[action.query.id]: {
|
||||
completed: true,
|
||||
exceedsThreshold: action.json.exceeds_threshold,
|
||||
thresholdInfo: action.json.threshold_info,
|
||||
formattedWarning: action.json.formatted_warning,
|
||||
error: null,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
[actions.COST_THRESHOLD_CHECK_FAILED]() {
|
||||
return {
|
||||
...state,
|
||||
queryCostThresholds: {
|
||||
...state.queryCostThresholds,
|
||||
[action.query.id]: {
|
||||
completed: false,
|
||||
exceedsThreshold: false,
|
||||
thresholdInfo: null,
|
||||
formattedWarning: null,
|
||||
error: action.error,
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
[actions.START_QUERY]() {
|
||||
let newState = { ...state };
|
||||
if (action.query.sqlEditorId) {
|
||||
|
||||
230
superset/commands/sql_lab/check_cost_threshold.py
Normal file
230
superset/commands/sql_lab/check_cost_threshold.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from superset import app
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.sql_lab.estimate import QueryEstimationCommand, EstimateQueryCostType
|
||||
|
||||
config = app.config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CostThresholdResult(TypedDict):
|
||||
exceeds_threshold: bool
|
||||
estimated_cost: list[dict[str, Any]]
|
||||
threshold_info: dict[str, Any]
|
||||
formatted_warning: str | None
|
||||
|
||||
|
||||
class QueryCostThresholdCheckCommand(BaseCommand):
|
||||
"""
|
||||
Command to check if a query's estimated cost exceeds configured thresholds.
|
||||
"""
|
||||
|
||||
_estimation_command: QueryEstimationCommand
|
||||
|
||||
def __init__(self, estimation_params: EstimateQueryCostType) -> None:
|
||||
self._estimation_command = QueryEstimationCommand(estimation_params)
|
||||
|
||||
def validate(self) -> None:
|
||||
# Use the estimation command's validation
|
||||
self._estimation_command.validate()
|
||||
|
||||
def run(self) -> CostThresholdResult:
|
||||
"""
|
||||
Check if query cost exceeds thresholds.
|
||||
|
||||
Returns a result indicating whether the query exceeds cost thresholds
|
||||
and provides information for user warnings.
|
||||
"""
|
||||
self.validate()
|
||||
|
||||
# Check if cost checking is enabled
|
||||
if not config.get("SQLLAB_QUERY_COST_CHECKING_ENABLED", False):
|
||||
return self._create_empty_result()
|
||||
|
||||
estimated_cost = self._get_estimated_cost()
|
||||
if not estimated_cost:
|
||||
return self._create_empty_result()
|
||||
|
||||
thresholds = self._get_engine_thresholds()
|
||||
if not thresholds:
|
||||
return CostThresholdResult(
|
||||
exceeds_threshold=False,
|
||||
estimated_cost=estimated_cost,
|
||||
threshold_info={},
|
||||
formatted_warning=None,
|
||||
)
|
||||
|
||||
return self._check_thresholds(estimated_cost, thresholds)
|
||||
|
||||
def _create_empty_result(self) -> CostThresholdResult:
|
||||
"""Create an empty result when cost checking is disabled or fails."""
|
||||
return CostThresholdResult(
|
||||
exceeds_threshold=False,
|
||||
estimated_cost=[],
|
||||
threshold_info={},
|
||||
formatted_warning=None,
|
||||
)
|
||||
|
||||
def _get_estimated_cost(self) -> list[dict[str, Any]] | None:
|
||||
"""Get cost estimation, returning None if it fails."""
|
||||
try:
|
||||
return self._estimation_command.run()
|
||||
except Exception as ex:
|
||||
logger.warning("Cost estimation failed: %s", str(ex))
|
||||
return None
|
||||
|
||||
def _get_engine_thresholds(self) -> dict[str, Any]:
|
||||
"""Get thresholds for the current database engine."""
|
||||
database = self._estimation_command._database
|
||||
engine_name = database.db_engine_spec.engine_name
|
||||
if engine_name is None:
|
||||
return {}
|
||||
|
||||
engine_name = engine_name.lower()
|
||||
return config.get("SQLLAB_QUERY_COST_THRESHOLDS", {}).get(engine_name, {})
|
||||
|
||||
def _check_thresholds(
|
||||
self, estimated_cost: list[dict[str, Any]], thresholds: dict[str, Any]
|
||||
) -> CostThresholdResult:
|
||||
"""Check if estimated cost exceeds configured thresholds."""
|
||||
exceeds_threshold = False
|
||||
warning_messages = []
|
||||
threshold_info = {}
|
||||
|
||||
for cost_item in estimated_cost:
|
||||
if self._check_bytes_threshold(cost_item, thresholds, threshold_info, warning_messages):
|
||||
exceeds_threshold = True
|
||||
if self._check_cost_threshold(cost_item, thresholds, threshold_info, warning_messages):
|
||||
exceeds_threshold = True
|
||||
|
||||
formatted_warning = None
|
||||
if warning_messages:
|
||||
formatted_warning = (
|
||||
" ".join(warning_messages) + " Are you sure you want to continue?"
|
||||
)
|
||||
|
||||
return CostThresholdResult(
|
||||
exceeds_threshold=exceeds_threshold,
|
||||
estimated_cost=estimated_cost,
|
||||
threshold_info=threshold_info,
|
||||
formatted_warning=formatted_warning,
|
||||
)
|
||||
|
||||
def _check_bytes_threshold(
|
||||
self,
|
||||
cost_item: dict[str, Any],
|
||||
thresholds: dict[str, Any],
|
||||
threshold_info: dict[str, Any],
|
||||
warning_messages: list[str]
|
||||
) -> bool:
|
||||
"""Check bytes scanned threshold. Returns True if threshold exceeded."""
|
||||
if "bytes_scanned" not in thresholds or "Bytes Scanned" not in cost_item:
|
||||
return False
|
||||
|
||||
try:
|
||||
bytes_scanned = self._parse_bytes_from_cost_item(cost_item["Bytes Scanned"])
|
||||
threshold_bytes = thresholds["bytes_scanned"]
|
||||
threshold_info["bytes_threshold"] = threshold_bytes
|
||||
threshold_info["estimated_bytes"] = bytes_scanned
|
||||
|
||||
if bytes_scanned > threshold_bytes:
|
||||
warning_messages.append(
|
||||
f"This query will scan approximately {self._format_bytes(bytes_scanned)} "
|
||||
f"of data, which exceeds the threshold of {self._format_bytes(threshold_bytes)}."
|
||||
)
|
||||
return True
|
||||
except (ValueError, KeyError) as ex:
|
||||
logger.warning("Failed to parse bytes from cost estimation: %s", str(ex))
|
||||
|
||||
return False
|
||||
|
||||
def _check_cost_threshold(
|
||||
self,
|
||||
cost_item: dict[str, Any],
|
||||
thresholds: dict[str, Any],
|
||||
threshold_info: dict[str, Any],
|
||||
warning_messages: list[str]
|
||||
) -> bool:
|
||||
"""Check cost threshold. Returns True if threshold exceeded."""
|
||||
if "cost_threshold" not in thresholds or "Cost" not in cost_item:
|
||||
return False
|
||||
|
||||
try:
|
||||
cost_value = float(cost_item["Cost"])
|
||||
threshold_cost = thresholds["cost_threshold"]
|
||||
threshold_info["cost_threshold"] = threshold_cost
|
||||
threshold_info["estimated_cost"] = cost_value
|
||||
|
||||
if cost_value > threshold_cost:
|
||||
warning_messages.append(
|
||||
f"This query has an estimated cost of {cost_value}, "
|
||||
f"which exceeds the threshold of {threshold_cost}."
|
||||
)
|
||||
return True
|
||||
except (ValueError, KeyError) as ex:
|
||||
logger.warning("Failed to parse cost from cost estimation: %s", str(ex))
|
||||
|
||||
return False
|
||||
|
||||
def _parse_bytes_from_cost_item(self, bytes_str: str) -> int:
|
||||
"""Parse bytes from formatted string like '5.2 GB' or '1024 MB'."""
|
||||
if not isinstance(bytes_str, str):
|
||||
return int(bytes_str)
|
||||
|
||||
# Remove commas and split
|
||||
parts = bytes_str.replace(",", "").strip().split()
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Cannot parse bytes from: {bytes_str}")
|
||||
|
||||
value_str, unit = parts
|
||||
value = float(value_str)
|
||||
unit = unit.upper()
|
||||
|
||||
multipliers = {
|
||||
"B": 1,
|
||||
"KB": 1024,
|
||||
"MB": 1024**2,
|
||||
"GB": 1024**3,
|
||||
"TB": 1024**4,
|
||||
"PB": 1024**5,
|
||||
}
|
||||
|
||||
if unit not in multipliers:
|
||||
raise ValueError(f"Unknown unit: {unit}")
|
||||
|
||||
return int(value * multipliers[unit])
|
||||
|
||||
def _format_bytes(self, bytes_count: int) -> str:
|
||||
"""Format bytes into human-readable string."""
|
||||
if bytes_count < 1024:
|
||||
return f"{bytes_count} B"
|
||||
elif bytes_count < 1024**2:
|
||||
return f"{bytes_count / 1024:.1f} KB"
|
||||
elif bytes_count < 1024**3:
|
||||
return f"{bytes_count / (1024**2):.1f} MB"
|
||||
elif bytes_count < 1024**4:
|
||||
return f"{bytes_count / (1024**3):.1f} GB"
|
||||
elif bytes_count < 1024**5:
|
||||
return f"{bytes_count / (1024**4):.1f} TB"
|
||||
else:
|
||||
return f"{bytes_count / (1024**5):.1f} PB"
|
||||
@@ -1191,6 +1191,18 @@ SQLLAB_ASYNC_TIME_LIMIT_SEC = int(timedelta(hours=6).total_seconds())
|
||||
# timeout.
|
||||
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
|
||||
|
||||
# Query cost governance configuration
|
||||
# Enable automatic cost checking before query execution
|
||||
SQLLAB_QUERY_COST_CHECKING_ENABLED = False
|
||||
|
||||
# Cost thresholds that trigger warnings before query execution
|
||||
# This is a dictionary where keys are database engine names and values are threshold configs
|
||||
# Each threshold config can contain:
|
||||
# - 'bytes_scanned': maximum bytes that can be scanned without warning
|
||||
# - 'cost_threshold': monetary cost threshold (engine-specific units)
|
||||
# Example: {'bigquery': {'bytes_scanned': 5 * 1024**4}, 'presto': {'cost_threshold': 1000}}
|
||||
SQLLAB_QUERY_COST_THRESHOLDS = {}
|
||||
|
||||
# Timeout duration for SQL Lab fetching query results by the resultsKey.
|
||||
# 0 means no timeout.
|
||||
SQLLAB_QUERY_RESULT_TIMEOUT = 0
|
||||
|
||||
@@ -25,6 +25,9 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import app, is_feature_enabled
|
||||
from superset.commands.sql_lab.check_cost_threshold import (
|
||||
QueryCostThresholdCheckCommand,
|
||||
)
|
||||
from superset.commands.sql_lab.estimate import QueryEstimationCommand
|
||||
from superset.commands.sql_lab.execute import CommandResult, ExecuteSqlCommand
|
||||
from superset.commands.sql_lab.export import SqlResultExportCommand
|
||||
@@ -188,6 +191,66 @@ class SqlLabRestApi(BaseSupersetApi):
|
||||
result = command.run()
|
||||
return self.response(200, result=result)
|
||||
|
||||
@expose("/check_cost_threshold/", methods=("POST",))
|
||||
@protect()
|
||||
@statsd_metrics
|
||||
@requires_json
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
|
||||
f".check_cost_threshold",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def check_cost_threshold(self) -> Response:
|
||||
"""Check if query cost exceeds configured thresholds.
|
||||
---
|
||||
post:
|
||||
summary: Check if query cost exceeds thresholds
|
||||
requestBody:
|
||||
description: SQL query and params
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/EstimateQueryCostSchema'
|
||||
responses:
|
||||
200:
|
||||
description: Cost threshold check result
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
exceeds_threshold:
|
||||
type: boolean
|
||||
description: Whether query exceeds cost thresholds
|
||||
estimated_cost:
|
||||
type: array
|
||||
description: Detailed cost estimation
|
||||
threshold_info:
|
||||
type: object
|
||||
description: Information about thresholds and estimates
|
||||
formatted_warning:
|
||||
type: string
|
||||
nullable: true
|
||||
description: Human-readable warning message
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
model = self.estimate_model_schema.load(request.json)
|
||||
except ValidationError as error:
|
||||
return self.response_400(message=error.messages)
|
||||
|
||||
command = QueryCostThresholdCheckCommand(model)
|
||||
result = command.run()
|
||||
return self.response(200, **result)
|
||||
|
||||
@expose("/format_sql/", methods=("POST",))
|
||||
@statsd_metrics
|
||||
@protect()
|
||||
|
||||
Reference in New Issue
Block a user