Mirror of https://github.com/apache/superset.git, synced 2026-05-06 16:34:32 +00:00

Compare commits — 8 commits on branch sc-103393-... by alexandrus
| Author | SHA1 | Date |
|---|---|---|
| | 41d8aeff5e | |
| | c2a35e2eea | |
| | 5138aa2c11 | |
| | 66a9e2e16e | |
| | 0f417f0040 | |
| | 1462ac9282 | |
| | da371217ef | |
| | c971ea3ec6 | |
@@ -67,9 +67,7 @@ export const renameOperator: PostProcessingFactory<PostProcessingRename> = (
    [...metricOffsetMap.entries()].forEach(
      ([metricWithOffset, metricOnly]) => {
        const offsetLabel = timeOffsets.find(offset =>
          metricWithOffset.endsWith(
            `${TIME_COMPARISON_SEPARATOR}${offset}`,
          ),
          metricWithOffset.includes(offset),
        );
        renamePairs.push([
          formData.comparison_type === ComparisonType.Values
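The behavioral difference at stake in this hunk, as a hedged standalone sketch (assuming `TIME_COMPARISON_SEPARATOR` is `'__'`, matching column names like `'count(*)__11 year ago'` in the tests below):

```ts
const SEPARATOR = '__'; // assumption for this sketch
const timeOffsets = ['1 year ago', '11 year ago'];
const column = 'count(*)__11 year ago';

// Naive matching: '1 year ago' is a substring of '11 year ago', and `find`
// returns the first hit, so the column gets the wrong label.
const naive = timeOffsets.find(offset => column.includes(offset));
// naive === '1 year ago'  (wrong)

// Anchored matching: requiring separator + offset as a suffix is unambiguous,
// because '...__11 year ago' does not end with '__1 year ago'.
const anchored = timeOffsets.find(offset =>
  column.endsWith(`${SEPARATOR}${offset}`),
);
// anchored === '11 year ago'  (correct)
```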
@@ -26,7 +26,7 @@ export const getTimeOffset = (
  timeCompare.find(
    timeOffset =>
      // offset is represented as <offset>, group by list
      series.name.startsWith(`${timeOffset},`) ||
      series.name.includes(`${timeOffset},`) ||
      // offset is represented as <metric>__<offset>
      series.name.includes(`__${timeOffset}`) ||
      // offset is represented as <metric>, <offset>
@@ -50,9 +50,7 @@ export const getOriginalSeries = (
  // offset in the middle: <metric>, <offset>, <dimension>
  result = result.replace(`, ${compare},`, ',');
  // offset at start: <offset>, <dimension>
  if (result.startsWith(`${compare},`)) {
    result = result.slice(`${compare},`.length);
  }
  result = result.replace(`${compare},`, '');
  // offset with double underscore: <metric>__<offset>
  result = result.replace(`__${compare}`, '');
  // offset at end: <metric>, <offset>
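A minimal sketch of why the `startsWith` guard matters here, using only the series-name conventions from the comments above:

```ts
// A bare replace also fires mid-string:
// '11 year ago, Alexander'.replace('1 year ago,', '') === '1 Alexander'
// The guard strips the offset only when it genuinely starts the name.
const stripLeadingOffset = (name: string, compare: string): string =>
  name.startsWith(`${compare},`)
    ? name.slice(`${compare},`.length).trimStart()
    : name;

stripLeadingOffset('11 year ago, Alexander', '1 year ago'); // '11 year ago, Alexander'
stripLeadingOffset('11 year ago, Alexander', '11 year ago'); // 'Alexander'
```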
@@ -303,30 +303,6 @@ test('should add renameOperator if multiple metrics exist', () => {
  });
});

test('should correctly match offsets that share a numeric prefix', () => {
  expect(
    renameOperator(
      {
        ...formData,
        comparison_type: ComparisonType.Values,
        time_compare: ['1 year ago', '11 year ago'],
      },
      queryObject,
    ),
  ).toEqual({
    operation: 'rename',
    options: {
      columns: {
        'count(*)__1 year ago': '1 year ago',
        'count(*)__11 year ago': '11 year ago',
      },
      inplace: true,
      level: 0,
    },
  });
});

test('should remove renameOperator', () => {
  expect(
    renameOperator(
@@ -114,26 +114,3 @@ test('hasTimeOffset returns false when series name is not a string', () => {
  const timeCompare = ['1 year ago'];
  expect(hasTimeOffset(series, timeCompare)).toBe(false);
});

test('getTimeOffset correctly matches offsets that share a numeric prefix', () => {
  const timeCompare = ['1 year ago', '11 year ago'];
  expect(
    getTimeOffset({ name: '11 year ago, Alexander' }, timeCompare),
  ).toEqual('11 year ago');
  expect(getTimeOffset({ name: '1 year ago, Alexander' }, timeCompare)).toEqual(
    '1 year ago',
  );
  expect(getTimeOffset({ name: 'Births__11 year ago' }, timeCompare)).toEqual(
    '11 year ago',
  );
});

test('getOriginalSeries correctly strips offsets that share a numeric prefix', () => {
  const timeCompare = ['1 year ago', '11 year ago'];
  expect(getOriginalSeries('11 year ago, Alexander', timeCompare)).toEqual(
    'Alexander',
  );
  expect(getOriginalSeries('1 year ago, Alexander', timeCompare)).toEqual(
    'Alexander',
  );
});
@@ -25,6 +25,7 @@ import {
  waitFor,
  within,
} from '@superset-ui/core/spec';
import { formatNumber } from '@superset-ui/core';
import { Select } from '.';

type Option = {
@@ -85,8 +86,10 @@ const getElementsByClassName = (className: string) =>
const getSelect = () =>
  screen.getByRole('combobox', { name: new RegExp(ARIA_LABEL, 'i') });

const selectAllButtonText = (length: number) => `Select all (${length})`;
const deselectAllButtonText = (length: number) => `Clear (${length})`;
const selectAllButtonText = (length: number) =>
  `Select all (${formatNumber('SMART_NUMBER', length)})`;
const deselectAllButtonText = (length: number) =>
  `Clear (${formatNumber('SMART_NUMBER', length)})`;

const findSelectOption = (text: string) =>
  waitFor(() =>
@@ -811,6 +814,62 @@ test('Maintains stable maxTagCount to prevent click target disappearing in oneLine
  expect(withinSelector.getByText('+ 2 ...')).toBeVisible();
});

test('dropdown width matches input width after tags collapse in oneLine mode', async () => {
  render(
    <div style={{ width: '300px' }}>
      <Select
        {...defaultProps}
        value={[OPTIONS[0], OPTIONS[1], OPTIONS[2]]}
        mode="multiple"
        oneLine
      />
    </div>,
  );

  await open();

  // Wait for RAF to complete and tags to collapse
  await waitFor(() => {
    const withinSelector = within(
      getElementByClassName('.ant-select-selector'),
    );
    expect(
      withinSelector.queryByText(OPTIONS[0].label),
    ).not.toBeInTheDocument();
    expect(withinSelector.getByText('+ 3 ...')).toBeVisible();
  });

  const selectElement = document.querySelector('.ant-select') as HTMLElement;
  expect(selectElement).toBeInTheDocument();

  // Mock the select element's width since JSDOM doesn't perform real layout
  jest.spyOn(selectElement, 'getBoundingClientRect').mockReturnValue({
    width: 300,
    height: 32,
    top: 0,
    left: 0,
    right: 300,
    bottom: 32,
    x: 0,
    y: 0,
    toJSON: () => ({}),
  } as DOMRect);

  // Close and reopen to trigger width measurement with mocked value
  await type('{esc}');
  await open();

  const dropdown = document.querySelector(
    '.ant-select-dropdown',
  ) as HTMLElement;
  expect(dropdown).toBeInTheDocument();

  // Verify the dropdown has inline width matching the mocked select width
  await waitFor(() => {
    expect(parseInt(dropdown.style.width, 10)).toBe(300);
  });
});

test('does not render "Select all" when there are 0 or 1 options', async () => {
  const { rerender } = render(
    <Select {...defaultProps} options={[]} mode="multiple" allowNewOptions />,
@@ -915,6 +974,17 @@ test('"Select all" does not affect disabled options', async () => {
  expect(await findSelectValue()).not.toHaveTextContent(options[1].label);
});

test('abbreviates large numbers in bulk action buttons', async () => {
  const manyOptions = Array.from({ length: 1500 }, (_, i) => ({
    label: `Option ${i}`,
    value: i,
  }));
  render(<Select {...defaultProps} mode="multiple" options={manyOptions} />);
  await open();
  // SMART_NUMBER format uses lowercase 'k' for thousands (d3-format)
  expect(await screen.findByText('Select all (1.5k)')).toBeInTheDocument();
});

test('dropdown takes full width of the select input for multi select', async () => {
  render(
    <div style={{ width: '400px' }}>
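For context on the '(1.5k)' expectation: SMART_NUMBER abbreviation is d3-format-backed (per the test comment), so thousands render with a lowercase 'k' while small counts appear to pass through unchanged — which is why the same helpers still satisfy the existing small-count assertions. A hedged sketch:

```ts
import { formatNumber } from '@superset-ui/core';

// Behavior implied by the tests in this diff (assumed, not exhaustively
// verified for other magnitudes):
formatNumber('SMART_NUMBER', 1500); // '1.5k' — SI prefix, lowercase 'k'
formatNumber('SMART_NUMBER', 3); // '3' — small counts are not abbreviated
```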
@@ -31,7 +31,7 @@ import {
} from 'react';

import { t } from '@apache-superset/core/translation';
import { ensureIsArray, usePrevious } from '@superset-ui/core';
import { ensureIsArray, formatNumber, usePrevious } from '@superset-ui/core';
import { Constants } from '@superset-ui/core/components';
import {
  LabeledValue as AntdLabeledValue,
@@ -149,6 +149,8 @@ const Select = forwardRef(
    // Prevent maxTagCount change during click events to avoid click target disappearing
    const [stableMaxTagCount, setStableMaxTagCount] = useState(maxTagCount);
    const isOpeningRef = useRef(false);
    const selectContainerRef = useRef<HTMLDivElement>(null);
    const [dropdownWidth, setDropdownWidth] = useState<number | true>(true);

    useEffect(() => {
      if (oneLine) {
@@ -159,12 +161,23 @@ const Select = forwardRef(
        requestAnimationFrame(() => {
          setStableMaxTagCount(0);
          isOpeningRef.current = false;

          // Measure collapsed width and update dropdown width
          const selectElement =
            selectContainerRef.current?.querySelector('.ant-select');
          if (selectElement) {
            const { width } = selectElement.getBoundingClientRect();
            if (width > 0) {
              setDropdownWidth(width);
            }
          }
        });
        return;
      }
      if (!isDropdownVisible) {
        // When closing, immediately show the first tag
        setStableMaxTagCount(1);
        setDropdownWidth(true); // Reset to default when closing
        isOpeningRef.current = false;
      }
      return;
@@ -494,7 +507,7 @@ const Select = forwardRef(

    const bulkSelectComponent = useMemo(
      () => (
        <StyledBulkActionsContainer justify="center">
        <StyledBulkActionsContainer justify="space-between">
          <Button
            type="link"
            buttonStyle="link"
@@ -506,7 +519,7 @@ const Select = forwardRef(
              handleSelectAll();
            }}
          >
            {`${t('Select all')} (${bulkSelectCounts.selectable})`}
            {`${t('Select all')} (${formatNumber('SMART_NUMBER', bulkSelectCounts.selectable)})`}
          </Button>
          <Button
            type="link"
@@ -523,7 +536,7 @@ const Select = forwardRef(
              handleDeselectAll();
            }}
          >
            {`${t('Clear')} (${bulkSelectCounts.deselectable})`}
            {`${t('Clear')} (${formatNumber('SMART_NUMBER', bulkSelectCounts.deselectable)})`}
          </Button>
        </StyledBulkActionsContainer>
      ),
@@ -717,7 +730,11 @@ const Select = forwardRef(
    };

    return (
      <StyledContainer className={className} headerPosition={headerPosition}>
      <StyledContainer
        ref={selectContainerRef}
        className={className}
        headerPosition={headerPosition}
      >
        {header && (
          <StyledHeader headerPosition={headerPosition}>{header}</StyledHeader>
        )}
@@ -777,7 +794,7 @@ const Select = forwardRef(
          options={visibleOptions}
          optionRender={option => <Space>{option.label || option.value}</Space>}
          oneLine={oneLine}
          popupMatchSelectWidth
          popupMatchSelectWidth={oneLine ? dropdownWidth : true}
          css={props.css}
          dropdownAlign={DROPDOWN_ALIGN_BOTTOM}
          {...props}
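The width-pinning logic above reduces to a simple pattern: collapse the tags, wait one animation frame for the re-render, measure the select, and hand that number to the dropdown. A hedged standalone sketch (the hook and its names are illustrative, not exported by the component; it assumes the antd '.ant-select' class the component queries):

```ts
import { useRef, useState } from 'react';

function useCollapsedDropdownWidth() {
  const containerRef = useRef<HTMLDivElement>(null);
  // `true` means "let the dropdown match the select automatically".
  const [dropdownWidth, setDropdownWidth] = useState<number | true>(true);

  const onOpen = () => {
    // Measure inside requestAnimationFrame so the DOM has re-rendered with
    // collapsed tags before getBoundingClientRect is read.
    requestAnimationFrame(() => {
      const el = containerRef.current?.querySelector('.ant-select');
      const width = el?.getBoundingClientRect().width ?? 0;
      if (width > 0) setDropdownWidth(width); // pin to the collapsed width
    });
  };

  const onClose = () => setDropdownWidth(true); // fall back to automatic sizing

  return { containerRef, dropdownWidth, onOpen, onClose };
}
```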
@@ -140,11 +140,17 @@ export const StyledErrorMessage = styled.div`

export const StyledBulkActionsContainer = styled(Flex)`
  ${({ theme }) => `
    padding: ${theme.sizeUnit}px;
    padding: ${theme.sizeUnit}px 0;
    border-top: 1px solid ${theme.colorSplit};
    gap: ${theme.sizeUnit * 2}px;
    & .superset-button {
      font-family: inherit;
      margin-left: 0 !important;
    }
    & .superset-button:first-of-type {
      padding-right: 0 !important;
    }
    & .superset-button:last-of-type {
      padding-left: 0 !important;
    }
  `}
`;
@@ -56,6 +56,7 @@ import SearchSelectDropdown from './components/SearchSelectDropdown';
import { SearchOption, SortByItem } from '../types';
import getInitialSortState, { shouldSort } from '../utils/getInitialSortState';
import getInitialFilterModel from '../utils/getInitialFilterModel';
import reconcileColumnState from '../utils/reconcileColumnState';
import { PAGE_SIZE_OPTIONS } from '../consts';
import { getCompleteFilterState } from '../utils/filterStateManager';

@@ -429,10 +430,17 @@ const AgGridDataTable: FunctionComponent<AgGridTableProps> = memo(
      // Note: filterModel is now handled via gridInitialState for better performance
      if (chartState?.columnState && params.api) {
        try {
          params.api.applyColumnState?.({
            state: chartState.columnState as ColumnState[],
            applyOrder: true,
          });
          const reconciledColumnState = reconcileColumnState(
            chartState.columnState as ColumnState[],
            colDefsFromProps as ColDef[],
          );

          if (reconciledColumnState) {
            params.api.applyColumnState?.({
              state: reconciledColumnState.columnState,
              applyOrder: reconciledColumnState.applyOrder,
            });
          }
        } catch {
          // Silently fail if state restoration fails
        }
@@ -0,0 +1,86 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
import {
  type ColDef,
  type ColumnState,
} from '@superset-ui/core/components/ThemedAgGridReact';
import reconcileColumnState, { getLeafColumnIds } from './reconcileColumnState';

test('getLeafColumnIds flattens grouped column defs in visual order', () => {
  const colDefs: ColDef[] = [
    { field: 'Manufacture_Date' },
    {
      headerName: 'Metrics',
      children: [
        { field: 'SUM(Sales_Amount)' },
        { field: 'SUM(Discount_Applied)' },
      ],
    } as ColDef,
    { field: 'SUM(Quantity_Sold)' },
  ];

  expect(getLeafColumnIds(colDefs)).toEqual([
    'Manufacture_Date',
    'SUM(Sales_Amount)',
    'SUM(Discount_Applied)',
    'SUM(Quantity_Sold)',
  ]);
});

test('preserves saved order when the current column set is unchanged', () => {
  const colDefs: ColDef[] = [
    { field: 'Transaction_Timestamp' },
    { field: 'SUM(Sales_Amount)' },
    { field: 'SUM(Discount_Applied)' },
  ];
  const savedColumnState: ColumnState[] = [
    { colId: 'SUM(Sales_Amount)' },
    { colId: 'Transaction_Timestamp' },
    { colId: 'SUM(Discount_Applied)' },
  ];

  expect(reconcileColumnState(savedColumnState, colDefs)).toEqual({
    applyOrder: true,
    columnState: savedColumnState,
  });
});

test('drops stale order when a dynamic group by swaps the dimension column', () => {
  const currentColDefs: ColDef[] = [
    { field: 'Manufacture_Date' },
    { field: 'SUM(Sales_Amount)' },
    { field: 'SUM(Discount_Applied)' },
    { field: 'SUM(Quantity_Sold)' },
  ];
  const savedColumnState: ColumnState[] = [
    { colId: 'Transaction_Timestamp' },
    { colId: 'SUM(Sales_Amount)' },
    { colId: 'SUM(Discount_Applied)' },
    { colId: 'SUM(Quantity_Sold)' },
  ];

  expect(reconcileColumnState(savedColumnState, currentColDefs)).toEqual({
    applyOrder: false,
    columnState: [
      { colId: 'SUM(Sales_Amount)' },
      { colId: 'SUM(Discount_Applied)' },
      { colId: 'SUM(Quantity_Sold)' },
    ],
  });
});
@@ -0,0 +1,83 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
import {
  type ColDef,
  type ColumnState,
} from '@superset-ui/core/components/ThemedAgGridReact';

type ColumnGroupDef = ColDef & {
  children?: ColumnDefLike[];
};

type ColumnDefLike = ColDef | ColumnGroupDef;

function hasChildren(colDef: ColumnDefLike): colDef is ColumnGroupDef {
  return 'children' in colDef;
}

export interface ReconciledColumnState {
  applyOrder: boolean;
  columnState: ColumnState[];
}

export function getLeafColumnIds(colDefs: ColumnDefLike[]): string[] {
  return colDefs.flatMap(colDef => {
    if (
      hasChildren(colDef) &&
      Array.isArray(colDef.children) &&
      colDef.children.length > 0
    ) {
      return getLeafColumnIds(colDef.children);
    }

    return typeof colDef.field === 'string' ? [colDef.field] : [];
  });
}

export default function reconcileColumnState(
  savedColumnState: ColumnState[] | undefined,
  colDefs: ColumnDefLike[],
): ReconciledColumnState | null {
  if (!Array.isArray(savedColumnState) || savedColumnState.length === 0) {
    return null;
  }

  const currentColumnIds = getLeafColumnIds(colDefs);
  const currentColumnIdSet = new Set(currentColumnIds);
  const filteredColumnState = savedColumnState.filter(
    column =>
      typeof column.colId === 'string' && currentColumnIdSet.has(column.colId),
  );

  if (filteredColumnState.length === 0) {
    return null;
  }

  const savedColumnIdSet = new Set(
    filteredColumnState.map(column => column.colId),
  );
  const hasSameColumnSet =
    currentColumnIds.length === savedColumnIdSet.size &&
    currentColumnIds.every(columnId => savedColumnIdSet.has(columnId));

  return {
    columnState: filteredColumnState,
    applyOrder: hasSameColumnSet,
  };
}
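A condensed sketch of the call-site contract (the AgGrid handler earlier in this diff follows the same shape): `null` means nothing is restorable, and `applyOrder: false` restores per-column state such as widths and visibility while keeping the grid's natural column order. `api` below stands in for the AG Grid API object available in the grid-ready handler.

```ts
const reconciled = reconcileColumnState(savedColumnState, colDefs);
if (reconciled) {
  api.applyColumnState?.({
    state: reconciled.columnState, // stale colIds already filtered out
    applyOrder: reconciled.applyOrder, // reorder only when the column set matches
  });
}
```

Reordering is the risky part when a dynamic GROUP BY swaps a dimension column, which is why order is applied only on an exact set match.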
@@ -0,0 +1,73 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
import { renderHook, act } from '@testing-library/react-hooks';
import { SupersetClient } from '@superset-ui/core';
import { useDownloadScreenshot } from './useDownloadScreenshot';
import { DownloadScreenshotFormat } from '../components/menu/DownloadMenuItems/types';

jest.mock('@superset-ui/core', () => ({
  SupersetClient: {
    post: jest.fn(),
    get: jest.fn(),
  },
  SupersetApiError: class SupersetApiError extends Error {
    status: number;

    constructor(message: string, status: number) {
      super(message);
      this.status = status;
    }
  },
}));

jest.mock('react-redux', () => ({
  useSelector: jest.fn(() => undefined),
}));

jest.mock('src/components/MessageToasts/withToasts', () => ({
  useToasts: () => ({
    addDangerToast: jest.fn(),
    addSuccessToast: jest.fn(),
    addInfoToast: jest.fn(),
  }),
}));

jest.mock('src/utils/urlUtils', () => ({
  getDashboardUrlParams: jest.fn(() => []),
}));

test('downloadScreenshot calls API with force=true to ensure fresh screenshots', async () => {
  const mockCacheKey = 'test-cache-key';
  (SupersetClient.post as jest.Mock).mockResolvedValue({
    json: { cache_key: mockCacheKey },
  });

  const { result } = renderHook(() => useDownloadScreenshot(123));

  await act(async () => {
    result.current(DownloadScreenshotFormat.PNG);
  });

  expect(SupersetClient.post).toHaveBeenCalledTimes(1);
  const callArgs = (SupersetClient.post as jest.Mock).mock.calls[0][0];

  // Verify that force=true is included in the endpoint URL
  // This prevents regression where stale cached screenshots are returned
  expect(callArgs.endpoint).toContain('force');
  expect(callArgs.endpoint).toMatch(/force[:%]true|force[:%]!t/);
});
@@ -20,6 +20,7 @@ import { useCallback, useEffect, useRef } from 'react';
import { useSelector } from 'react-redux';
import { useToasts } from 'src/components/MessageToasts/withToasts';
import { last } from 'lodash';
import rison from 'rison';
import contentDisposition from 'content-disposition';
import { t } from '@apache-superset/core/translation';
import { SupersetClient, SupersetApiError } from '@superset-ui/core';
@@ -145,7 +146,7 @@ export const useDownloadScreenshot = (
    };

    SupersetClient.post({
      endpoint: `/api/v1/dashboard/${dashboardId}/cache_dashboard_screenshot/`,
      endpoint: `/api/v1/dashboard/${dashboardId}/cache_dashboard_screenshot/?q=${rison.encode({ force: true })}`,
      jsonPayload: {
        anchor,
        activeTabs,
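Why the test above accepts two spellings of the flag: rison encodes the boolean `true` as `!t`, so the raw query string is `?q=(force:!t)` before any URL percent-encoding is applied. A small sketch:

```ts
import rison from 'rison';

// The regex /force[:%]true|force[:%]!t/ covers both the rison spelling and a
// plain-JSON-style spelling, with or without percent-encoded separators.
rison.encode({ force: true }); // '(force:!t)'
```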
@@ -32,6 +32,7 @@ import { URL_PARAMS } from 'src/constants';
import { JsonObject, VizType } from '@superset-ui/core';
import { useUnsavedChangesPrompt } from 'src/hooks/useUnsavedChangesPrompt';
import { getParsedExploreURLParams } from 'src/explore/exploreUtils/getParsedExploreURLParams';
import * as messageToastActions from 'src/components/MessageToasts/actions';
import ChartPage from '.';

jest.mock('src/hooks/useUnsavedChangesPrompt', () => ({
@@ -358,4 +359,96 @@ describe('ChartPage', () => {
      );
    });
  });

  test('does not show error toast when request is aborted on unmount', async () => {
    const addDangerToastSpy = jest.spyOn(messageToastActions, 'addDangerToast');
    const exploreApiRoute = 'glob:*/api/v1/explore/*';
    let rejectRequest: (error: Error) => void;
    const pendingPromise = new Promise((_, reject) => {
      rejectRequest = reject;
    });

    fetchMock.get(exploreApiRoute, () => pendingPromise);

    const { unmount } = render(<ChartPage />, {
      useRouter: true,
      useRedux: true,
      useDnd: true,
    });

    // Unmount before the request completes
    unmount();

    // Simulate the aborted request rejection
    const abortError = new Error('The operation was aborted.');
    abortError.name = 'AbortError';
    rejectRequest!(abortError);

    // Wait for the rejected request to settle before asserting no toast was shown
    await pendingPromise.catch(() => undefined);
    expect(addDangerToastSpy).not.toHaveBeenCalled();

    addDangerToastSpy.mockRestore();
  });

  test('aborts in-flight request when a new request is made', async () => {
    const addDangerToastSpy = jest.spyOn(messageToastActions, 'addDangerToast');
    const exploreApiRoute = 'glob:*/api/v1/explore/*';
    const exploreFormData = getExploreFormData({
      viz_type: VizType.Table,
      show_cell_bars: true,
    });

    // First request will reject with AbortError when aborted
    let rejectFirstRequest: (error: Error) => void;
    const firstRequestPromise = new Promise((_, reject) => {
      rejectFirstRequest = reject;
    });

    fetchMock.get(exploreApiRoute, () => firstRequestPromise);

    render(
      <>
        <Link to="/?slice_id=99">Navigate</Link>
        <ChartPage />
      </>,
      {
        useRouter: true,
        useRedux: true,
        useDnd: true,
      },
    );

    // Wait for the first request to be initiated
    await waitFor(() =>
      expect(fetchMock.callHistory.calls(exploreApiRoute).length).toBe(1),
    );

    // Set up second request to return immediately
    fetchMock.clearHistory().removeRoutes();
    fetchMock.get(exploreApiRoute, {
      result: { dataset: { id: 1 }, form_data: exploreFormData },
    });

    // Navigate to trigger a new request (which should abort the first)
    fireEvent.click(screen.getByText('Navigate'));

    // Simulate the first request being aborted
    const abortError = new Error('The operation was aborted.');
    abortError.name = 'AbortError';
    rejectFirstRequest!(abortError);

    // Wait for the first request to settle before asserting
    await firstRequestPromise.catch(() => undefined);

    // Wait for the second request to complete
    await waitFor(() =>
      expect(fetchMock.callHistory.calls(exploreApiRoute).length).toBe(1),
    );

    // No error toast should be shown from the aborted first request
    expect(addDangerToastSpy).not.toHaveBeenCalled();

    addDangerToastSpy.mockRestore();
  });
});
@@ -50,10 +50,14 @@ const isValidResult = (rv: JsonObject): boolean =>
const hasDatasetId = (rv: JsonObject): boolean =>
  isDefined(rv?.result?.dataset?.id);

const fetchExploreData = async (exploreUrlParams: URLSearchParams) => {
const fetchExploreData = async (
  exploreUrlParams: URLSearchParams,
  signal?: AbortSignal,
) => {
  const rv = await makeApi<{}, ExploreResponsePayload>({
    method: 'GET',
    endpoint: 'api/v1/explore/',
    signal,
  })(exploreUrlParams);
  if (isValidResult(rv)) {
    if (hasDatasetId(rv)) {
@@ -130,6 +134,7 @@ const getDashboardContextFormData = (search: string) => {
export default function ExplorePage() {
  const [isLoaded, setIsLoaded] = useState(false);
  const fetchGeneration = useRef(0);
  const abortControllerRef = useRef<AbortController | null>(null);
  const dispatch = useDispatch();
  const history = useHistory();

@@ -138,6 +143,11 @@ export default function ExplorePage() {
      loc: { search: string; pathname: string },
      saveAction?: SaveActionType | null,
    ) => {
      // Abort any in-flight request before starting a new one
      abortControllerRef.current?.abort();
      const controller = new AbortController();
      abortControllerRef.current = controller;

      fetchGeneration.current += 1;
      const generation = fetchGeneration.current;
      const exploreUrlParams = getParsedExploreURLParams(loc);
@@ -145,7 +155,7 @@ export default function ExplorePage() {

      const isStale = () => generation !== fetchGeneration.current;

      fetchExploreData(exploreUrlParams)
      fetchExploreData(exploreUrlParams, controller.signal)
        .then(({ result }) => {
          if (isStale()) {
            return;
@@ -183,7 +193,19 @@ export default function ExplorePage() {
          );
        })
        .catch(err => Promise.all([getClientErrorObject(err), err]))
        .catch(err => {
          // Silently ignore aborted requests - AbortError may be wrapped in SupersetApiError by makeApi
          // or come through with statusText === 'abort' from SupersetClient
          if (
            err.name === 'AbortError' ||
            err.statusText === 'abort' ||
            err.originalError?.name === 'AbortError' ||
            err.originalError?.statusText === 'abort'
          ) {
            return;
          }
          return Promise.all([getClientErrorObject(err), err]);
        })
        .then(resolved => {
          if (isStale()) {
            return;
@@ -251,7 +273,7 @@ export default function ExplorePage() {
          return Promise.resolve();
        })
        .finally(() => {
          if (!isStale()) {
          if (!isStale() && !controller.signal.aborted) {
            setIsLoaded(true);
          }
        });
@@ -259,6 +281,14 @@ export default function ExplorePage() {
    [dispatch],
  );

  // Cleanup: abort in-flight requests on unmount
  useEffect(
    () => () => {
      abortControllerRef.current?.abort();
    },
    [],
  );

  // Initial fetch on mount
  useEffect(() => {
    loadExploreData(history.location);
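The two staleness guards used above, reduced to a hedged standalone sketch (names are illustrative, not the ExplorePage exports): an AbortController cancels the network request itself, while a monotonically increasing generation counter discards responses that resolve after a newer request has started.

```ts
function createLatestOnlyFetcher(
  fetchJson: (url: string, signal: AbortSignal) => Promise<unknown>,
) {
  let controller: AbortController | null = null;
  let generation = 0;

  return async (url: string, onData: (data: unknown) => void) => {
    controller?.abort(); // cancel the previous in-flight request
    controller = new AbortController();
    generation += 1;
    const myGeneration = generation;

    try {
      const data = await fetchJson(url, controller.signal);
      if (myGeneration === generation) onData(data); // drop stale results
    } catch (err) {
      if ((err as Error).name === 'AbortError') return; // expected on abort
      throw err;
    }
  };
}
```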
@@ -446,6 +446,7 @@ from superset.mcp_service.database.tool import (  # noqa: F401, E402
    list_databases,
)
from superset.mcp_service.dataset.tool import (  # noqa: F401, E402
    get_column_sample_data,
    get_dataset_info,
    list_datasets,
)
@@ -306,6 +306,36 @@ class GetDatasetInfoRequest(MetadataCacheControl):
    ]


class GetColumnSampleDataRequest(BaseModel):
    """Request schema for get_column_sample_data."""

    dataset_id: int = Field(..., description="The dataset ID to query")
    column_name: str = Field(
        ..., description="The column name to get distinct values for"
    )
    limit: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Maximum number of distinct values to return (default 20, max 100)",
    )


class ColumnSampleDataResponse(BaseModel):
    """Response schema for get_column_sample_data."""

    dataset_id: int = Field(..., description="The dataset ID queried")
    column_name: str = Field(..., description="The column name queried")
    values: List[str | int | float | bool | None] = Field(
        ..., description="Distinct values found in the column"
    )
    count: int = Field(..., description="Number of distinct values returned")
    truncated: bool = Field(
        False,
        description="True if more distinct values exist beyond the limit",
    )


def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
    """Parse a field that may be stored as a JSON string into a dict."""
    value = getattr(obj, field_name, None)
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.

from .get_column_sample_data import get_column_sample_data
from .get_dataset_info import get_dataset_info
from .list_datasets import list_datasets

__all__ = [
    "list_datasets",
    "get_column_sample_data",
    "get_dataset_info",
    "list_datasets",
]
superset/mcp_service/dataset/tool/get_column_sample_data.py (new file, 158 lines)
@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Get column sample data FastMCP tool

This module contains the FastMCP tool for retrieving distinct values
from a dataset column, useful for building filters in charts.
"""

import logging
from datetime import datetime, timezone

from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations

from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.mcp_service.dataset.schemas import (
    ColumnSampleDataResponse,
    DatasetError,
    GetColumnSampleDataRequest,
)

logger = logging.getLogger(__name__)


@tool(
    tags=["discovery"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Get column sample data",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_column_sample_data(
    request: GetColumnSampleDataRequest, ctx: Context
) -> ColumnSampleDataResponse | DatasetError:
    """Get distinct values for a dataset column.

    Returns up to `limit` distinct values from the specified column.
    Useful for discovering valid filter values when building charts.
    Respects row-level security and dataset fetch_values_predicate.

    IMPORTANT FOR LLM CLIENTS:
    - Use this tool BEFORE creating charts with filters to discover actual
      column values instead of guessing
    - Use get_dataset_info first to find column names and types
    - Low-cardinality columns (gender, status, category) work best

    Example usage:
    ```json
    {
        "dataset_id": 123,
        "column_name": "gender",
        "limit": 20
    }
    ```
    """
    await ctx.info(
        "Retrieving column sample data: dataset_id=%s, column=%s, limit=%s"
        % (request.dataset_id, request.column_name, request.limit)
    )

    try:
        from superset.daos.dataset import DatasetDAO

        with event_logger.log_context(action="mcp.get_column_sample_data.lookup"):
            dataset = DatasetDAO.find_by_id(request.dataset_id)

        if not dataset:
            await ctx.warning(
                "Dataset not found: dataset_id=%s" % (request.dataset_id,)
            )
            return DatasetError(
                error=f"Dataset with ID {request.dataset_id} not found",
                error_type="NotFound",
                timestamp=datetime.now(timezone.utc),
            )

        try:
            dataset.raise_for_access()
        except SupersetSecurityException as ex:
            await ctx.warning(
                "Permission denied for dataset_id=%s: %s"
                % (request.dataset_id, str(ex))
            )
            return DatasetError(
                error=f"Permission denied for dataset {request.dataset_id}",
                error_type="PermissionDenied",
                timestamp=datetime.now(timezone.utc),
            )

        # Fetch one extra value to detect truncation without a COUNT query
        fetch_limit = request.limit + 1
        denormalize_column = not dataset.normalize_columns

        with event_logger.log_context(action="mcp.get_column_sample_data.query"):
            raw_values = dataset.values_for_column(
                column_name=request.column_name,
                limit=fetch_limit,
                denormalize_column=denormalize_column,
            )

        truncated = len(raw_values) > request.limit
        values = raw_values[: request.limit]

        await ctx.info(
            "Column sample data retrieved: dataset_id=%s, column=%s, "
            "count=%s, truncated=%s"
            % (request.dataset_id, request.column_name, len(values), truncated)
        )

        return ColumnSampleDataResponse(
            dataset_id=request.dataset_id,
            column_name=request.column_name,
            values=values,
            count=len(values),
            truncated=truncated,
        )

    except KeyError:
        await ctx.warning(
            "Column not found: column=%s in dataset_id=%s"
            % (request.column_name, request.dataset_id)
        )
        return DatasetError(
            error=f"Column '{request.column_name}' does not exist "
            f"in dataset {request.dataset_id}",
            error_type="ColumnNotFound",
            timestamp=datetime.now(timezone.utc),
        )
    except Exception as e:
        await ctx.error(
            "Column sample data retrieval failed: dataset_id=%s, column=%s, "
            "error=%s, error_type=%s"
            % (request.dataset_id, request.column_name, str(e), type(e).__name__)
        )
        return DatasetError(
            error=f"Failed to get column sample data: {str(e)}",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )
@@ -1881,26 +1881,10 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
            time_grain
        )

        if not time_grain:
            has_temporal_join_key = any(
                pd.api.types.is_datetime64_any_dtype(df[key])
                for key in join_keys
                if key in df.columns
        if join_column_producer and not time_grain:
            raise QueryObjectValidationError(
                _("Time Grain must be specified when using Time Shift.")
            )
            if has_temporal_join_key:
                has_relative_offset = any(
                    not (
                        self.is_valid_date_range(offset)
                        and feature_flag_manager.is_feature_enabled(
                            "DATE_RANGE_TIMESHIFTS_ENABLED"
                        )
                    )
                    for offset in offset_dfs
                )
                if has_relative_offset:
                    raise QueryObjectValidationError(
                        _("Time Grain must be specified when using Time Comparison.")
                    )

        for offset, offset_df in offset_dfs.items():
            is_date_range_offset = self.is_valid_date_range(
@@ -1783,9 +1783,9 @@ def extract_dataframe_dtypes(
        columns_by_name[column.column_name] = column

    generic_types: list[GenericDataType] = []
    for i, column in enumerate(df.columns):
    for column in df.columns:
        column_object = columns_by_name.get(column)
        series = df.iloc[:, i]
        series = df[column]
        inferred_type: str = ""
        if series.isna().all():
            sql_type: Optional[str] = ""
@@ -24,7 +24,7 @@ from flask_babel import gettext as _
from pandas import DataFrame, MultiIndex

from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingContributionOrientation, TIME_COMPARISON
from superset.utils.core import PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing.utils import validate_column_args


@@ -130,7 +130,7 @@ def get_column_groups(
    time_shift = None
    if time_shifts and isinstance(col_0, str):
        for ts in time_shifts:
            if col_0.endswith(TIME_COMPARISON + ts):
            if col_0.endswith(ts):
                time_shift = ts
                break
    if time_shift is not None:
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pandas import DataFrame, Series, Timestamp
from pandas.testing import assert_frame_equal
from pytest import fixture, mark  # noqa: PT013
@@ -24,7 +23,6 @@ from superset.common.query_context import QueryContext
from superset.common.query_context_processor import QueryContextProcessor
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import TimeGrain
from superset.exceptions import QueryObjectValidationError
from superset.models.helpers import ExploreMixin

# Create processor and bind ExploreMixin methods to datasource
@@ -244,47 +244,3 @@ def test_join_offset_dfs_totals_query_no_dimensions():
    )

    assert_frame_equal(expected, result)


def test_join_offset_dfs_raises_without_time_grain():
    """Time comparison with relative offsets requires a time grain."""
    df = DataFrame({"ds": [Timestamp("2021-01-01")], "D": [1]})
    offset_df = DataFrame({"ds": [Timestamp("2021-02-01")], "B": [5]})
    offset_dfs = {"1 year ago": offset_df}

    with pytest.raises(
        QueryObjectValidationError, match="Time Grain must be specified"
    ):
        query_context_processor.join_offset_dfs(
            df, offset_dfs, time_grain=None, join_keys=["ds"]
        )


def test_join_offset_dfs_allows_non_temporal_join_without_time_grain():
    """Time comparison without time grain is valid when join keys are non-temporal."""
    df = DataFrame({"country": ["US", "UK"], "metric": [10, 20]})
    offset_df = DataFrame({"country": ["US", "UK"], "metric__1 year ago": [8, 15]})
    offset_dfs = {"1 year ago": offset_df}

    result = query_context_processor.join_offset_dfs(
        df, offset_dfs, time_grain=None, join_keys=["country"]
    )
    assert "metric__1 year ago" in result.columns


def test_join_offset_dfs_raises_when_temporal_key_not_first():
    """Temporal join key detection works even when it's not the first key."""
    df = DataFrame(
        {
            "country": ["US", "UK"],
            "ds": [Timestamp("2021-01-01"), Timestamp("2021-02-01")],
            "D": [1, 2],
        }
    )
    offset_df = DataFrame(
        {
            "country": ["US", "UK"],
            "ds": [Timestamp("2021-03-01"), Timestamp("2021-04-01")],
            "B": [5, 6],
        }
    )
    offset_dfs = {"1 year ago": offset_df}

    with pytest.raises(
        QueryObjectValidationError, match="Time Grain must be specified"
    ):
        query_context_processor.join_offset_dfs(
            df, offset_dfs, time_grain=None, join_keys=["country", "ds"]
        )
@@ -0,0 +1,248 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import logging
from unittest.mock import MagicMock, Mock, patch

import pytest
from fastmcp import Client
from pydantic import ValidationError

from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.mcp_service.app import mcp
from superset.mcp_service.dataset.schemas import GetColumnSampleDataRequest
from superset.utils import json

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


@pytest.fixture
def mcp_server():
    return mcp


@pytest.fixture(autouse=True)
def mock_auth():
    """Mock authentication for all tests."""
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user


def _create_mock_dataset(dataset_id=1, normalize_columns=True):
    """Create a mock dataset with values_for_column support."""
    dataset = MagicMock()
    dataset.id = dataset_id
    dataset.normalize_columns = normalize_columns
    dataset.raise_for_access = MagicMock()
    dataset.values_for_column = MagicMock()
    return dataset


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_success(mock_find, mcp_server):
    """Test successful retrieval of column sample data."""
    dataset = _create_mock_dataset()
    dataset.values_for_column.return_value = ["Male", "Female", "Other"]
    mock_find.return_value = dataset

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "gender", "limit": 20}},
        )
        data = json.loads(result.content[0].text)
        assert data["dataset_id"] == 1
        assert data["column_name"] == "gender"
        assert data["values"] == ["Male", "Female", "Other"]
        assert data["count"] == 3
        assert data["truncated"] is False

    dataset.values_for_column.assert_called_once_with(
        column_name="gender",
        limit=21,  # limit + 1 for truncation detection
        denormalize_column=False,  # normalize_columns=True -> denormalize=False
    )


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_truncated(mock_find, mcp_server):
    """Test truncation detection when more values exist than the limit."""
    dataset = _create_mock_dataset()
    # Return 4 values when limit is 3 (tool fetches limit+1=4)
    dataset.values_for_column.return_value = ["A", "B", "C", "D"]
    mock_find.return_value = dataset

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "status", "limit": 3}},
        )
        data = json.loads(result.content[0].text)
        assert data["values"] == ["A", "B", "C"]
        assert data["count"] == 3
        assert data["truncated"] is True


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_dataset_not_found(mock_find, mcp_server):
    """Test error when dataset does not exist."""
    mock_find.return_value = None

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 999, "column_name": "gender"}},
        )
        data = json.loads(result.content[0].text)
        assert data["error_type"] == "NotFound"
        assert "999" in data["error"]


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_permission_denied(mock_find, mcp_server):
    """Test error when user lacks permission to access the dataset."""
    dataset = _create_mock_dataset()
    dataset.raise_for_access.side_effect = SupersetSecurityException(
        SupersetError(
            message="Access denied",
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            level=ErrorLevel.ERROR,
        )
    )
    mock_find.return_value = dataset

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "gender"}},
        )
        data = json.loads(result.content[0].text)
        assert data["error_type"] == "PermissionDenied"


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_column_not_found(mock_find, mcp_server):
    """Test error when column does not exist in the dataset."""
    dataset = _create_mock_dataset()
    dataset.values_for_column.side_effect = KeyError("nonexistent_col")
    mock_find.return_value = dataset

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "nonexistent_col"}},
        )
        data = json.loads(result.content[0].text)
        assert data["error_type"] == "ColumnNotFound"
        assert "nonexistent_col" in data["error"]


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_default_limit(mock_find, mcp_server):
    """Test that omitting limit defaults to 20."""
    dataset = _create_mock_dataset()
    dataset.values_for_column.return_value = list(range(10))
    mock_find.return_value = dataset

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "category"}},
        )
        data = json.loads(result.content[0].text)
        assert data["count"] == 10

    # Default limit=20, so fetch_limit should be 21
    dataset.values_for_column.assert_called_once_with(
        column_name="category",
        limit=21,
        denormalize_column=False,
    )


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_denormalize_column(mock_find, mcp_server):
    """Test that denormalize_column is set based on dataset.normalize_columns."""
    dataset = _create_mock_dataset(normalize_columns=False)
    dataset.values_for_column.return_value = ["val1"]
    mock_find.return_value = dataset

    async with Client(mcp_server) as client:
        await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "col"}},
        )

    dataset.values_for_column.assert_called_once_with(
        column_name="col",
        limit=21,
        denormalize_column=True,  # normalize_columns=False -> denormalize=True
    )


@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_with_none_values(mock_find, mcp_server):
    """Test that None values (from NULL columns) are handled correctly."""
    dataset = _create_mock_dataset()
    dataset.values_for_column.return_value = ["Active", None, "Inactive"]
    mock_find.return_value = dataset

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "status"}},
        )
        data = json.loads(result.content[0].text)
        assert data["values"] == ["Active", None, "Inactive"]
        assert data["count"] == 3


def test_get_column_sample_data_request_limit_validation():
    """Test that Pydantic rejects invalid limit values."""
    with pytest.raises(ValidationError, match="greater than or equal to 1"):
        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=0)

    with pytest.raises(ValidationError, match="less than or equal to 100"):
        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=101)

    with pytest.raises(ValidationError, match="greater than or equal to 1"):
        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=-1)

    # Valid limits should work
    req = GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=1)
    assert req.limit == 1

    req = GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=100)
    assert req.limit == 100

    req = GetColumnSampleDataRequest(dataset_id=1, column_name="col")
    assert req.limit == 20
@@ -124,36 +124,3 @@ def test_contribution_with_time_shift_columns():
    assert_array_equal(processed_df["a__1 week ago"].tolist(), [0.5, 0.5])
    assert_array_equal(processed_df["b__1 week ago"].tolist(), [0.25, 0.25])
    assert_array_equal(processed_df["c__1 week ago"].tolist(), [0.25, 0.25])


def test_contribution_with_numeric_prefix_time_shifts():
    """Time shifts like '2 weeks ago' and '22 weeks ago' share a numeric suffix;
    columns must be grouped by their exact offset, not by suffix matching."""
    df = DataFrame(
        {
            DTTM_ALIAS: [
                datetime(2020, 7, 16, 14, 49),
                datetime(2020, 7, 16, 14, 50),
            ],
            "a": [3, 6],
            "b": [6, 3],
            "a__2 weeks ago": [1, 1],
            "b__2 weeks ago": [1, 1],
            "a__22 weeks ago": [2, 4],
            "b__22 weeks ago": [4, 2],
        }
    )
    processed_df = contribution(
        df,
        orientation=PostProcessingContributionOrientation.ROW,
        time_shifts=["2 weeks ago", "22 weeks ago"],
    )
    # Non-time-shift columns: a=3,b=6 -> a=1/3, b=2/3; a=6,b=3 -> a=2/3, b=1/3
    assert_array_equal(processed_df["a"].tolist(), [1 / 3, 2 / 3])
    assert_array_equal(processed_df["b"].tolist(), [2 / 3, 1 / 3])
    # "2 weeks ago" group: a=1,b=1 -> 0.5,0.5 each row
    assert_array_equal(processed_df["a__2 weeks ago"].tolist(), [0.5, 0.5])
    assert_array_equal(processed_df["b__2 weeks ago"].tolist(), [0.5, 0.5])
    # "22 weeks ago" group: a=2,b=4 -> 1/3,2/3; a=4,b=2 -> 2/3,1/3
    assert_array_equal(processed_df["a__22 weeks ago"].tolist(), [1 / 3, 2 / 3])
    assert_array_equal(processed_df["b__22 weeks ago"].tolist(), [2 / 3, 1 / 3])
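Worked arithmetic behind the '22 weeks ago' assertions above, taking the first row of that group (a = 2, b = 4):

$$\frac{a}{a+b} = \frac{2}{6} = \frac{1}{3}, \qquad \frac{b}{a+b} = \frac{4}{6} = \frac{2}{3}.$$

Had the groups been formed by bare suffix matching, 'a__22 weeks ago' (which ends with '2 weeks ago') would fall into the '2 weeks ago' group, making the first-row denominator 1 + 1 + 2 + 4 = 8 and the value 2/8 = 0.25 instead of 1/3.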
@@ -31,7 +31,6 @@ from superset.utils.core import (
    cast_to_boolean,
    check_is_safe_zip,
    DateColumn,
    extract_dataframe_dtypes,
    FilterOperator,
    generic_find_constraint_name,
    generic_find_fk_constraint_name,
@@ -647,9 +646,8 @@ def test_get_user_agent(mocker: MockerFixture, app_context: None) -> None:

@with_config(
    {
        "USER_AGENT_FUNC": lambda database, source: (
            f"{database.database_name} {source.name}"
        )
        "USER_AGENT_FUNC": lambda database,
        source: f"{database.database_name} {source.name}"
    }
)
def test_get_user_agent_custom(mocker: MockerFixture, app_context: None) -> None:
@@ -1690,10 +1688,3 @@ def test_sanitize_url_blocks_dangerous():
    """Test that dangerous URL schemes are blocked."""
    assert sanitize_url("javascript:alert('xss')") == ""
    assert sanitize_url("data:text/html,<script>alert(1)</script>") == ""


def test_extract_dataframe_dtypes_with_duplicate_columns():
    """extract_dataframe_dtypes should not crash on duplicate column names."""
    df = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "a"])
    result = extract_dataframe_dtypes(df)
    assert len(result) == 3