Compare commits

...

12 Commits

Author SHA1 Message Date
Sam Firke
120a92728e add two new options for boxplot percentiles 2025-01-31 15:34:59 -05:00
Beto Dealmeida
1064ad5d58 fix: enforce ALERT_REPORTS_MAX_CUSTOM_SCREENSHOT_WIDTH (#32053) 2025-01-31 14:56:56 -05:00
Jack
7db0589340 fix(thumbnail cache): Enabling force parameter on screenshot/thumbnail cache (#31757)
Co-authored-by: Kamil Gabryjelski <kamil.gabryjelski@gmail.com>
2025-01-31 19:22:31 +01:00
JUST.in DO IT
c590e90c87 feat(sqllab): improve table metadata UI (#32051) 2025-01-31 15:19:37 -03:00
Vitor Avila
101d3fa78d chore: Re-enable asnyc event API tests (#32062) 2025-01-31 13:54:05 -03:00
Enzo Martellucci
468bb5f47a refactor(Radio): Upgrade Radio Component to Ant Design 5 (#32004) 2025-01-31 17:45:06 +01:00
Geido
1c1494d3e0 fix(DatePicker): Increase z-index over Modal (#32061) 2025-01-31 15:59:35 +01:00
Vitor Avila
5fc11fb706 chore: Add more database-related tests (follow up to #31948) (#32054) 2025-01-31 08:36:09 -03:00
JUST.in DO IT
f73d61a597 feat(sqllab): Replace FilterableTable by AgGrid Table (#29900) 2025-01-30 16:43:22 -08:00
Vitor Avila
3f46bcf142 chore: Skip the creation of secondary perms during catalog migrations (#32043) 2025-01-30 18:29:07 -03:00
Đỗ Trọng Hải
aa67525b70 fix(fe/explore): prevent runtime error when editing Dataset-origin Chart with empty title (#32031)
Signed-off-by: hainenber <dotronghai96@gmail.com>
2025-01-30 11:40:39 -07:00
Maxime Beauchemin
568f6d958b fix: Revert "fix: re-enable cypress checks" (#32045) 2025-01-30 14:20:55 -03:00
81 changed files with 3570 additions and 875 deletions

1
.gitignore vendored
View File

@@ -21,7 +21,6 @@
*.swp
__pycache__
.aider.*
.local
.cache
.bento*

View File

@@ -60,7 +60,6 @@ def run_cypress_for_test_file(
f"--browser {browser} "
f"--record --group {group_id} --tag {REPO},{GITHUB_EVENT_NAME} "
f"--ci-build-id {build_id} "
f"--wait-for-missing-groups "
f"-- {chrome_flags}"
)
else:

View File

@@ -599,7 +599,7 @@ describe('Drill by modal', () => {
]);
});
it('Radar Chart', () => {
it.skip('Radar Chart', () => {
testEchart('radar', 'Radar Chart', [
[182, 49],
[423, 91],

View File

@@ -335,7 +335,7 @@ describe('Drill to detail modal', () => {
});
});
describe('Bar Chart', () => {
describe.skip('Bar Chart', () => {
it('opens the modal with the correct filters', () => {
interceptSamples();
@@ -373,7 +373,7 @@ describe('Drill to detail modal', () => {
});
});
describe('Area Chart', () => {
describe.skip('Area Chart', () => {
it('opens the modal with the correct filters', () => {
testTimeChart('echarts_area');
});
@@ -407,7 +407,7 @@ describe('Drill to detail modal', () => {
});
});
describe('World Map', () => {
describe.skip('World Map', () => {
it('opens the modal with the correct filters', () => {
interceptSamples();
@@ -567,7 +567,7 @@ describe('Drill to detail modal', () => {
});
});
describe('Radar Chart', () => {
describe.skip('Radar Chart', () => {
it('opens the modal with the correct filters', () => {
interceptSamples();

View File

@@ -176,7 +176,7 @@ describe('Horizontal FilterBar', () => {
validateFilterNameOnDashboard(testItems.topTenChart.filterColumn);
});
it('should spot changes in "more filters" and apply their values', () => {
it.skip('should spot changes in "more filters" and apply their values', () => {
cy.intercept(`/api/v1/chart/data?form_data=**`).as('chart');
prepareDashboardFilters([
{ name: 'test_1', column: 'country_name', datasetId: 2 },
@@ -204,7 +204,7 @@ describe('Horizontal FilterBar', () => {
);
});
it('should focus filter and open "more filters" programmatically', () => {
it.skip('should focus filter and open "more filters" programmatically', () => {
prepareDashboardFilters([
{ name: 'test_1', column: 'country_name', datasetId: 2 },
{ name: 'test_2', column: 'country_code', datasetId: 2 },
@@ -231,7 +231,7 @@ describe('Horizontal FilterBar', () => {
cy.get('.ant-select-focused').should('be.visible');
});
it('should show tag count and one plain tag on focus and only count on blur in select ', () => {
it.skip('should show tag count and one plain tag on focus and only count on blur in select ', () => {
prepareDashboardFilters([
{ name: 'test_1', column: 'country_name', datasetId: 2 },
]);

View File

@@ -56,6 +56,8 @@
"@visx/xychart": "^3.5.1",
"abortcontroller-polyfill": "^1.7.8",
"ace-builds": "^1.36.3",
"ag-grid-community": "32.2.1",
"ag-grid-react": "32.2.1",
"antd": "4.10.3",
"antd-v5": "npm:antd@^5.18.0",
"bootstrap": "^3.4.1",
@@ -14478,6 +14480,32 @@
"node": ">= 10.0.0"
}
},
"node_modules/ag-charts-types": {
"version": "10.2.0",
"resolved": "https://registry.npmjs.org/ag-charts-types/-/ag-charts-types-10.2.0.tgz",
"integrity": "sha512-PUqH1QtugpYLnlbMdeSZVf5PpT1XZVsP69qN1JXhetLtQpVC28zaj7ikwu9CMA9N9b+dBboA9QcjUQUJZVUokQ=="
},
"node_modules/ag-grid-community": {
"version": "32.2.1",
"resolved": "https://registry.npmjs.org/ag-grid-community/-/ag-grid-community-32.2.1.tgz",
"integrity": "sha512-mrnm1DnLI9Wd408mMwP+6p7lbTC3FYgzNIUPygBvNh3SzZnbzTEUJF/BTKXi+MARWtG5S0IMUYy4hqBiLbobaQ==",
"dependencies": {
"ag-charts-types": "10.2.0"
}
},
"node_modules/ag-grid-react": {
"version": "32.2.1",
"resolved": "https://registry.npmjs.org/ag-grid-react/-/ag-grid-react-32.2.1.tgz",
"integrity": "sha512-lojTKsT/ncRZ81mrDa7qkIhZePfYlLCHIiAL1WbzL1mNPrglaa7QQKkE6hhhuAXvAm2uUhK1OfkMPnrqsEFldA==",
"dependencies": {
"ag-grid-community": "32.2.1",
"prop-types": "^15.8.1"
},
"peerDependencies": {
"react": "^16.3.0 || ^17.0.0 || ^18.0.0",
"react-dom": "^16.3.0 || ^17.0.0 || ^18.0.0"
}
},
"node_modules/agent-base": {
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz",

View File

@@ -123,6 +123,8 @@
"@visx/xychart": "^3.5.1",
"abortcontroller-polyfill": "^1.7.8",
"ace-builds": "^1.36.3",
"ag-grid-community": "32.2.1",
"ag-grid-react": "32.2.1",
"antd": "4.10.3",
"antd-v5": "npm:antd@^5.18.0",
"bootstrap": "^3.4.1",

View File

@@ -144,6 +144,13 @@ export interface SQLResultTableExtensionProps {
allowHTML?: boolean;
}
export interface SQLTablePreviewExtensionProps {
dbId: number;
catalog?: string;
schema: string;
tableName: string;
}
/**
* Interface for extensions to Slice Header
*/
@@ -229,4 +236,8 @@ export type Extensions = Partial<{
'sqleditor.extension.customAutocomplete': (
args: CustomAutoCompleteArgs,
) => CustomAutocomplete[] | undefined;
'sqleditor.extension.tablePreview': [
string,
ComponentType<SQLTablePreviewExtensionProps>,
][];
}>;

View File

@@ -94,7 +94,9 @@ const config: ControlPanelConfig = {
['Tukey', t('Tukey')],
['Min/max (no outliers)', t('Min/max (no outliers)')],
['2/98 percentiles', t('2/98 percentiles')],
['5/95 percentiles', t('5/95 percentiles')],
['9/91 percentiles', t('9/91 percentiles')],
['10/90 percentiles', t('10/90 percentiles')],
],
},
},

View File

@@ -36,7 +36,9 @@ export type BoxPlotFormDataWhiskerOptions =
| 'Tukey'
| 'Min/max (no outliers)'
| '2/98 percentiles'
| '9/91 percentiles';
| '5/95 percentiles'
| '9/91 percentiles'
| '10/90 percentiles';
export type BoxPlotFormXTickLayout =
| '45°'

View File

@@ -47,7 +47,8 @@ export const GlobalStyles = () => (
.ant-dropdown,
.ant-select-dropdown,
.antd5-modal-wrap,
.antd5-modal-mask {
.antd5-modal-mask,
.antd5-picker-dropdown {
z-index: ${theme.zIndex.max} !important;
}

View File

@@ -237,7 +237,7 @@ export function clearInactiveQueries(interval) {
return { type: CLEAR_INACTIVE_QUERIES, interval };
}
export function startQuery(query) {
export function startQuery(query, runPreviewOnly) {
Object.assign(query, {
id: query.id ? query.id : nanoid(11),
progress: 0,
@@ -245,7 +245,7 @@ export function startQuery(query) {
state: query.runAsync ? 'pending' : 'running',
cached: false,
});
return { type: START_QUERY, query };
return { type: START_QUERY, query, runPreviewOnly };
}
export function querySuccess(query, results) {
@@ -327,9 +327,9 @@ export function fetchQueryResults(query, displayLimit, timeoutInMs) {
};
}
export function runQuery(query) {
export function runQuery(query, runPreviewOnly) {
return function (dispatch) {
dispatch(startQuery(query));
dispatch(startQuery(query, runPreviewOnly));
const postPayload = {
client_id: query.id,
database_id: query.dbId,
@@ -947,29 +947,25 @@ export function mergeTable(table, query, prepend) {
export function addTable(queryEditor, tableName, catalogName, schemaName) {
return function (dispatch, getState) {
const query = getUpToDateQuery(getState(), queryEditor, queryEditor.id);
const { dbId } = getUpToDateQuery(getState(), queryEditor, queryEditor.id);
const table = {
dbId: query.dbId,
queryEditorId: query.id,
dbId,
queryEditorId: queryEditor.id,
catalog: catalogName,
schema: schemaName,
name: tableName,
};
dispatch(
mergeTable(
{
...table,
id: nanoid(11),
expanded: true,
},
null,
true,
),
mergeTable({
...table,
id: nanoid(11),
expanded: true,
}),
);
};
}
export function runTablePreviewQuery(newTable) {
export function runTablePreviewQuery(newTable, runPreviewOnly) {
return function (dispatch, getState) {
const {
sqlLab: { databases },
@@ -979,7 +975,7 @@ export function runTablePreviewQuery(newTable) {
if (database && !database.disable_data_preview) {
const dataPreviewQuery = {
id: nanoid(11),
id: newTable.previewQueryId ?? nanoid(11),
dbId,
catalog,
schema,
@@ -991,6 +987,9 @@ export function runTablePreviewQuery(newTable) {
ctas: false,
isDataPreview: true,
};
if (runPreviewOnly) {
return dispatch(runQuery(dataPreviewQuery, runPreviewOnly));
}
return Promise.all([
dispatch(
mergeTable(
@@ -1024,7 +1023,7 @@ export function syncTable(table, tableMetadata) {
return sync
.then(({ json: resultJson }) => {
const newTable = { ...table, id: resultJson.id };
const newTable = { ...table, id: `${resultJson.id}` };
dispatch(
mergeTable({
...newTable,
@@ -1032,9 +1031,6 @@ export function syncTable(table, tableMetadata) {
initialized: true,
}),
);
if (!table.dataPreviewQueryId) {
dispatch(runTablePreviewQuery({ ...tableMetadata, ...newTable }));
}
})
.catch(() =>
dispatch(

View File

@@ -1031,30 +1031,38 @@ describe('async actions', () => {
});
describe('runTablePreviewQuery', () => {
it('updates and runs data preview query when configured', () => {
expect.assertions(3);
const results = {
data: mockBigNumber,
query: { sqlEditorId: 'null', dbId: 1 },
query_id: 'efgh',
};
const tableName = 'table';
const catalogName = null;
const schemaName = 'schema';
const store = mockStore({
...initialState,
sqlLab: {
...initialState.sqlLab,
databases: {
1: { disable_data_preview: false },
},
},
});
const results = {
data: mockBigNumber,
query: { sqlEditorId: 'null', dbId: 1 },
query_id: 'efgh',
};
beforeEach(() => {
fetchMock.post(runQueryEndpoint, JSON.stringify(results), {
overwriteRoutes: true,
});
});
afterEach(() => {
store.clearActions();
fetchMock.resetHistory();
});
it('updates and runs data preview query when configured', () => {
expect.assertions(3);
const tableName = 'table';
const catalogName = null;
const schemaName = 'schema';
const store = mockStore({
...initialState,
sqlLab: {
...initialState.sqlLab,
databases: {
1: { disable_data_preview: false },
},
},
});
const expectedActionTypes = [
actions.MERGE_TABLE, // addTable (data preview)
actions.START_QUERY, // runQuery (data preview)
@@ -1075,6 +1083,30 @@ describe('async actions', () => {
expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(0);
});
});
it('runs data preview query only', () => {
const expectedActionTypes = [
actions.START_QUERY, // runQuery (data preview)
actions.QUERY_SUCCESS, // querySuccess
];
const request = actions.runTablePreviewQuery(
{
dbId: 1,
name: tableName,
catalog: catalogName,
schema: schemaName,
},
true,
);
return request(store.dispatch, store.getState).then(() => {
expect(store.getActions().map(a => a.type)).toEqual(
expectedActionTypes,
);
expect(fetchMock.calls(runQueryEndpoint)).toHaveLength(1);
// tab state is not updated, since the query is a data preview
expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(0);
});
});
});
describe('expandTable', () => {

View File

@@ -16,8 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
import configureStore from 'redux-mock-store';
import thunk from 'redux-thunk';
import reducerIndex from 'spec/helpers/reducerIndex';
import { render, waitFor, createStore } from 'spec/helpers/testing-library';
import { QueryEditor } from 'src/SqlLab/types';
@@ -34,9 +32,6 @@ import {
} from 'src/SqlLab/actions/sqlLab';
import fetchMock from 'fetch-mock';
const middlewares = [thunk];
const mockStore = configureStore(middlewares);
fetchMock.get('glob:*/api/v1/database/*/function_names/', {
function_names: [],
});
@@ -79,7 +74,8 @@ describe('AceEditorWrapper', () => {
});
it('renders ace editor including sql value', async () => {
const { getByTestId } = setup(defaultQueryEditor, mockStore(initialState));
const store = createStore(initialState, reducerIndex);
const { getByTestId } = setup(defaultQueryEditor, store);
await waitFor(() => expect(getByTestId('react-ace')).toBeInTheDocument());
expect(getByTestId('react-ace')).toHaveTextContent(
@@ -89,9 +85,8 @@ describe('AceEditorWrapper', () => {
it('renders current sql for unrelated unsaved changes', () => {
const expectedSql = 'SELECT updated_column\nFROM updated_table\nWHERE';
const { getByTestId } = setup(
defaultQueryEditor,
mockStore({
const store = createStore(
{
...initialState,
sqlLab: {
...initialState.sqlLab,
@@ -100,8 +95,10 @@ describe('AceEditorWrapper', () => {
sql: expectedSql,
},
},
}),
},
reducerIndex,
);
const { getByTestId } = setup(defaultQueryEditor, store);
expect(getByTestId('react-ace')).not.toHaveTextContent(
JSON.stringify({ value: expectedSql }).slice(1, -1),
@@ -122,7 +119,7 @@ describe('AceEditorWrapper', () => {
queryEditorSetCursorPosition(defaultQueryEditor, updatedCursorPosition),
);
expect(FullSQLEditor).toHaveBeenCalledTimes(renderCount);
store.dispatch(queryEditorSetDb(defaultQueryEditor, 1));
store.dispatch(queryEditorSetDb(defaultQueryEditor, 2));
expect(FullSQLEditor).toHaveBeenCalledTimes(renderCount + 1);
});
});

View File

@@ -202,6 +202,7 @@ test('returns column keywords among selected tables', async () => {
{
name: expectColumn,
type: 'VARCHAR',
longType: 'VARCHAR',
},
],
},
@@ -223,6 +224,7 @@ test('returns column keywords among selected tables', async () => {
{
name: unexpectedColumn,
type: 'VARCHAR',
longType: 'VARCHAR',
},
],
},

View File

@@ -47,7 +47,7 @@ const SqlLabStyles = styled.div`
left: 0;
padding: 0 ${theme.gridUnit * 2}px;
pre {
pre:not(.code) {
padding: 0 !important;
margin: 0;
border: none;

View File

@@ -354,7 +354,7 @@ describe('ResultSet', () => {
);
});
const { getByRole } = setup(mockedProps, mockStore(initialState));
expect(getByRole('table')).toBeInTheDocument();
expect(getByRole('treegrid')).toBeInTheDocument();
});
test('renders if there is a limit in query.results but not queryLimit', async () => {
@@ -372,7 +372,7 @@ describe('ResultSet', () => {
},
}),
);
expect(getByRole('table')).toBeInTheDocument();
expect(getByRole('treegrid')).toBeInTheDocument();
});
test('Async queries - renders "Fetch data preview" button when data preview has no results', () => {
@@ -400,7 +400,7 @@ describe('ResultSet', () => {
name: /fetch data preview/i,
}),
).toBeVisible();
expect(screen.queryByRole('table')).not.toBeInTheDocument();
expect(screen.queryByRole('treegrid')).not.toBeInTheDocument();
});
test('Async queries - renders "Refetch results" button when a query has no results', () => {
@@ -429,7 +429,7 @@ describe('ResultSet', () => {
name: /refetch results/i,
}),
).toBeVisible();
expect(screen.queryByRole('table')).not.toBeInTheDocument();
expect(screen.queryByRole('treegrid')).not.toBeInTheDocument();
});
test('Async queries - renders on the first call', () => {
@@ -449,7 +449,7 @@ describe('ResultSet', () => {
},
}),
);
expect(screen.getByRole('table')).toBeVisible();
expect(screen.getByRole('treegrid')).toBeVisible();
expect(
screen.queryByRole('button', {
name: /fetch data preview/i,

View File

@@ -19,8 +19,8 @@
import { useCallback, useState, FormEvent } from 'react';
import { Radio } from 'src/components/Radio';
import { RadioChangeEvent, AsyncSelect } from 'src/components';
import { Radio, RadioChangeEvent } from 'src/components/Radio';
import { AsyncSelect } from 'src/components';
import { Input } from 'src/components/Input';
import StyledModal from 'src/components/Modal';
import Button from 'src/components/Button';

View File

@@ -28,21 +28,25 @@ interface ShowSQLProps {
sql: string;
title: string;
tooltipText: string;
triggerNode?: React.ReactNode;
}
export default function ShowSQL({
tooltipText,
title,
sql: sqlString,
triggerNode,
}: ShowSQLProps) {
return (
<ModalTrigger
modalTitle={title}
triggerNode={
<IconTooltip
className="fa fa-eye pull-left m-l-2"
tooltip={tooltipText}
/>
triggerNode || (
<IconTooltip
className="fa fa-eye pull-left m-l-2"
tooltip={tooltipText}
/>
)
}
modalBody={
<div>

View File

@@ -135,7 +135,7 @@ test('should render empty result state when latestQuery is empty', () => {
expect(resultPanel).toHaveTextContent('Run a query to display results');
});
test('should render tabs for table preview queries', () => {
test('should render tabs for table metadata view', () => {
const { getAllByRole } = render(<SouthPane {...mockedProps} />, {
useRedux: true,
initialState: mockState,
@@ -145,7 +145,7 @@ test('should render tabs for table preview queries', () => {
expect(tabs).toHaveLength(mockState.sqlLab.tables.length + 2);
expect(tabs[0]).toHaveTextContent('Results');
expect(tabs[1]).toHaveTextContent('Query history');
mockState.sqlLab.tables.forEach(({ name }, index) => {
expect(tabs[index + 2]).toHaveTextContent(`Preview: \`${name}\``);
mockState.sqlLab.tables.forEach(({ name, schema }, index) => {
expect(tabs[index + 2]).toHaveTextContent(`${schema}.${name}`);
});
});

View File

@@ -16,24 +16,25 @@
* specific language governing permissions and limitations
* under the License.
*/
import { createRef, useMemo } from 'react';
import { createRef, useCallback, useMemo } from 'react';
import { shallowEqual, useDispatch, useSelector } from 'react-redux';
import { nanoid } from 'nanoid';
import Tabs from 'src/components/Tabs';
import { styled, t } from '@superset-ui/core';
import { css, styled, t } from '@superset-ui/core';
import { setActiveSouthPaneTab } from 'src/SqlLab/actions/sqlLab';
import { removeTables, setActiveSouthPaneTab } from 'src/SqlLab/actions/sqlLab';
import Label from 'src/components/Label';
import Icons from 'src/components/Icons';
import { SqlLabRootState } from 'src/SqlLab/types';
import QueryHistory from '../QueryHistory';
import ResultSet from '../ResultSet';
import {
STATUS_OPTIONS,
STATE_TYPE_MAP,
STATUS_OPTIONS_LOCALIZED,
} from '../../constants';
import Results from './Results';
import TablePreview from '../TablePreview';
const TAB_HEIGHT = 130;
@@ -98,31 +99,45 @@ const SouthPane = ({
}),
shallowEqual,
);
const queries = useSelector(
({ sqlLab: { queries } }: SqlLabRootState) => Object.keys(queries),
shallowEqual,
);
const activeSouthPaneTab =
useSelector<SqlLabRootState, string>(
state => state.sqlLab.activeSouthPaneTab as string,
) ?? 'Results';
const querySet = useMemo(() => new Set(queries), [queries]);
const dataPreviewQueries = useMemo(
const pinnedTables = useMemo(
() =>
tables.filter(
({ dataPreviewQueryId, queryEditorId: qeId }) =>
dataPreviewQueryId &&
queryEditorId === qeId &&
querySet.has(dataPreviewQueryId),
({ queryEditorId: qeId }) => String(queryEditorId) === qeId,
),
[queryEditorId, tables, querySet],
[queryEditorId, tables],
);
const pinnedTableKeys = useMemo(
() =>
Object.fromEntries(
pinnedTables.map(({ id, dbId, catalog, schema, name }) => [
id,
[dbId, catalog, schema, name].join(':'),
]),
),
[pinnedTables],
);
const innerTabContentHeight = height - TAB_HEIGHT;
const southPaneRef = createRef<HTMLDivElement>();
const switchTab = (id: string) => {
dispatch(setActiveSouthPaneTab(id));
};
const removeTable = useCallback(
(key, action) => {
if (action === 'remove') {
const table = pinnedTables.find(
({ dbId, catalog, schema, name }) =>
[dbId, catalog, schema, name].join(':') === key,
);
dispatch(removeTables([table]));
}
},
[dispatch, queryEditorId],
);
return offline ? (
<Label className="m-r-3" type={STATE_TYPE_MAP[STATUS_OPTIONS.offline]}>
@@ -136,14 +151,17 @@ const SouthPane = ({
ref={southPaneRef}
>
<Tabs
activeKey={activeSouthPaneTab}
type="editable-card"
activeKey={pinnedTableKeys[activeSouthPaneTab] || activeSouthPaneTab}
className="SouthPaneTabs"
onChange={switchTab}
id={nanoid(11)}
fullWidth={false}
animated={false}
onEdit={removeTable}
hideAdd
>
<Tabs.TabPane tab={t('Results')} key="Results">
<Tabs.TabPane tab={t('Results')} key="Results" closable={false}>
<Results
height={innerTabContentHeight}
latestQueryId={latestQueryId}
@@ -151,32 +169,37 @@ const SouthPane = ({
defaultQueryLimit={defaultQueryLimit}
/>
</Tabs.TabPane>
<Tabs.TabPane tab={t('Query history')} key="History">
<Tabs.TabPane tab={t('Query history')} key="History" closable={false}>
<QueryHistory
queryEditorId={queryEditorId}
displayLimit={displayLimit}
latestQueryId={latestQueryId}
/>
</Tabs.TabPane>
{dataPreviewQueries.map(
({ name, dataPreviewQueryId }) =>
dataPreviewQueryId && (
<Tabs.TabPane
tab={t('Preview: `%s`', decodeURIComponent(name))}
key={dataPreviewQueryId}
>
<ResultSet
queryId={dataPreviewQueryId}
visualize={false}
csv={false}
cache
height={innerTabContentHeight}
displayLimit={displayLimit}
defaultQueryLimit={defaultQueryLimit}
{pinnedTables.map(({ id, dbId, catalog, schema, name }) => (
<Tabs.TabPane
tab={
<>
<Icons.Table
iconSize="s"
css={css`
margin-bottom: 2px;
margin-right: 4px;
`}
/>
</Tabs.TabPane>
),
)}
{`${schema}.${decodeURIComponent(name)}`}
</>
}
key={pinnedTableKeys[id]}
>
<TablePreview
dbId={dbId}
catalog={catalog}
schema={schema}
tableName={name}
/>
</Tabs.TabPane>
))}
</Tabs>
</StyledPane>
);

View File

@@ -27,6 +27,7 @@ import {
initialState,
defaultQueryEditor,
extraQueryEditor1,
extraQueryEditor2,
} from 'src/SqlLab/fixtures';
import type { RootState } from 'src/views/store';
import type { Store } from 'redux';
@@ -206,13 +207,13 @@ test('should toggle the table when the header is clicked', async () => {
});
test('When changing database the schema and table list must be updated', async () => {
const { rerender } = await renderAndWait(mockedProps, undefined, {
const reduxState = {
...initialState,
sqlLab: {
...initialState.sqlLab,
unsavedQueryEditor: {
id: defaultQueryEditor.id,
schema: 'new_schema',
schema: 'db1_schema',
},
queryEditors: [
defaultQueryEditor,
@@ -223,16 +224,22 @@ test('When changing database the schema and table list must be updated', async (
},
],
tables: [
table,
{
...table,
dbId: defaultQueryEditor.dbId,
schema: 'db1_schema',
},
{
...table,
dbId: 2,
schema: 'new_schema',
name: 'new_table',
queryEditorId: extraQueryEditor1.id,
},
],
},
});
};
const { rerender } = await renderAndWait(mockedProps, undefined, reduxState);
expect(screen.getAllByText(/main/i)[0]).toBeInTheDocument();
expect(screen.getAllByText(/ab_user/i)[0]).toBeInTheDocument();
@@ -250,30 +257,60 @@ test('When changing database the schema and table list must be updated', async (
);
const updatedDbSelector = await screen.findAllByText(/new_db/i);
expect(updatedDbSelector[0]).toBeInTheDocument();
const updatedTableSelector = await screen.findAllByText(/new_table/i);
expect(updatedTableSelector[0]).toBeInTheDocument();
const select = screen.getByRole('combobox', {
name: 'Select schema or type to search schemas',
});
userEvent.click(select);
expect(
await screen.findByRole('option', { name: 'main' }),
).toBeInTheDocument();
expect(
await screen.findByRole('option', { name: 'new_schema' }),
).toBeInTheDocument();
rerender(
<SqlEditorLeftBar
{...mockedProps}
database={{
userEvent.click(screen.getAllByText('new_schema')[1]);
const updatedTableSelector = await screen.findAllByText(/new_table/i);
expect(updatedTableSelector[0]).toBeInTheDocument();
});
test('display no compatible schema found when schema api throws errors', async () => {
const reduxState = {
...initialState,
sqlLab: {
...initialState.sqlLab,
queryEditors: [
{
...extraQueryEditor2,
dbId: 3,
schema: undefined,
},
],
},
};
await renderAndWait(
{
...mockedProps,
queryEditorId: extraQueryEditor2.id,
database: {
id: 3,
database_name: 'unauth_db',
backend: 'minervasql',
}}
queryEditorId={extraQueryEditor1.id}
/>,
},
},
undefined,
reduxState,
);
await waitFor(() =>
expect(fetchMock.calls('glob:*/api/v1/database/3/schemas/?*')).toHaveLength(
1,
),
);
const select = screen.getByRole('combobox', {
name: 'Select schema or type to search schemas',
});
userEvent.click(select);
expect(
await screen.findByText('No compatible schema found'),

View File

@@ -101,7 +101,7 @@ const SqlEditorLeftBar = ({
queryEditorId,
height = 500,
}: SqlEditorLeftBarProps) => {
const tables = useSelector<SqlLabRootState, Table[]>(
const allSelectedTables = useSelector<SqlLabRootState, Table[]>(
({ sqlLab }) =>
sqlLab.tables.filter(table => table.queryEditorId === queryEditorId),
shallowEqual,
@@ -117,7 +117,14 @@ const SqlEditorLeftBar = ({
const [userSelectedDb, setUserSelected] = useState<DatabaseObject | null>(
null,
);
const { catalog, schema } = queryEditor;
const { dbId, catalog, schema } = queryEditor;
const tables = useMemo(
() =>
allSelectedTables.filter(
table => table.dbId === dbId && table.schema === schema,
),
[allSelectedTables, dbId, schema],
);
useEffect(() => {
const bool = querystring.parse(window.location.search).db;

View File

@@ -92,7 +92,7 @@ test('has 4 IconTooltip elements', async () => {
initialState,
});
await waitFor(() =>
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(5),
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(6),
);
});
@@ -112,7 +112,7 @@ test('fades table', async () => {
initialState,
});
await waitFor(() =>
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(5),
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(6),
);
const style = window.getComputedStyle(getAllByTestId('fade')[0]);
expect(style.opacity).toBe('0');
@@ -133,7 +133,7 @@ test('sorts columns', async () => {
},
);
await waitFor(() =>
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(5),
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(6),
);
expect(
getAllByTestId('mock-column-element').map(el => el.textContent),
@@ -160,7 +160,7 @@ test('removes the table', async () => {
},
);
await waitFor(() =>
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(5),
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(6),
);
expect(fetchMock.calls(updateTableSchemaEndpoint)).toHaveLength(0);
fireEvent.click(getByText('Remove table preview'));
@@ -193,7 +193,7 @@ test('refreshes table metadata when triggered', async () => {
},
);
await waitFor(() =>
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(5),
expect(getAllByTestId('mock-icon-tooltip')).toHaveLength(6),
);
expect(fetchMock.calls(updateTableSchemaEndpoint)).toHaveLength(0);
expect(fetchMock.calls(getTableMetadataEndpoint)).toHaveLength(1);

View File

@@ -0,0 +1,173 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { type ReactChild } from 'react';
import fetchMock from 'fetch-mock';
import { table, initialState } from 'src/SqlLab/fixtures';
import {
render,
waitFor,
fireEvent,
screen,
} from 'spec/helpers/testing-library';
import TablePreview from '.';
jest.mock(
'src/components/FilterableTable',
() =>
({ data }: { data: Record<string, any>[] }) => (
<div>
{data.map((record, i) => (
<div key={i} data-test="mock-record-row">
{JSON.stringify(record)}
</div>
))}
</div>
),
);
jest.mock(
'react-virtualized-auto-sizer',
() =>
({ children }: { children: (params: { height: number }) => ReactChild }) =>
children({ height: 500 }),
);
jest.mock('src/components/IconTooltip', () => ({
IconTooltip: ({
onClick,
tooltip,
}: {
onClick: () => void;
tooltip: string;
}) => (
<button type="button" data-test="mock-icon-tooltip" onClick={onClick}>
{tooltip}
</button>
),
}));
const getTableMetadataEndpoint =
/\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/;
const getExtraTableMetadataEndpoint =
/\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/;
const fetchPreviewEndpoint = 'glob:*/api/v1/sqllab/execute/';
beforeEach(() => {
fetchMock.get(getTableMetadataEndpoint, table);
fetchMock.get(getExtraTableMetadataEndpoint, {});
fetchMock.post(fetchPreviewEndpoint, `{ "data": 123 }`);
});
afterEach(() => {
fetchMock.reset();
});
const mockedProps = {
dbId: table.dbId,
catalog: table.catalog,
schema: table.schema,
tableName: table.name,
};
test('renders columns', async () => {
const { getAllByTestId, queryByText } = render(
<TablePreview {...mockedProps} />,
{
useRedux: true,
initialState,
},
);
await waitFor(() =>
expect(getAllByTestId('mock-record-row')).toHaveLength(
table.columns.length,
),
);
expect(queryByText(`Columns (${table.columns.length})`)).toBeInTheDocument();
});
test('renders indexes', async () => {
const { queryByText } = render(<TablePreview {...mockedProps} />, {
useRedux: true,
initialState,
});
await waitFor(() =>
expect(fetchMock.calls(getTableMetadataEndpoint)).toHaveLength(1),
);
expect(queryByText(`Indexes (${table.indexes.length})`)).toBeInTheDocument();
});
test('renders preview', async () => {
const { getByText } = render(<TablePreview {...mockedProps} />, {
useRedux: true,
initialState: {
...initialState,
sqlLab: {
...initialState.sqlLab,
databases: {
[table.dbId]: {
id: table.dbId,
database_name: 'mysql',
disable_data_preview: false,
},
},
},
},
});
await waitFor(() =>
expect(fetchMock.calls(getTableMetadataEndpoint)).toHaveLength(1),
);
expect(fetchMock.calls(fetchPreviewEndpoint)).toHaveLength(0);
fireEvent.click(getByText('Data preview'));
await waitFor(() =>
expect(fetchMock.calls(fetchPreviewEndpoint)).toHaveLength(1),
);
});
describe('table actions', () => {
test('refreshes table metadata when triggered', async () => {
const { getByRole, getByText } = render(<TablePreview {...mockedProps} />, {
useRedux: true,
initialState,
});
await waitFor(() =>
expect(fetchMock.calls(getTableMetadataEndpoint)).toHaveLength(1),
);
const menuButton = getByRole('button', { name: /Table actions/i });
fireEvent.click(menuButton);
fireEvent.click(getByText('Refresh table schema'));
await waitFor(() =>
expect(fetchMock.calls(getTableMetadataEndpoint)).toHaveLength(2),
);
});
test('shows CREATE VIEW statement', async () => {
const { getByRole, getByText } = render(<TablePreview {...mockedProps} />, {
useRedux: true,
initialState,
});
await waitFor(() =>
expect(fetchMock.calls(getTableMetadataEndpoint)).toHaveLength(1),
);
const menuButton = getByRole('button', { name: /Table actions/i });
fireEvent.click(menuButton);
fireEvent.click(getByText('Show CREATE VIEW statement'));
await waitFor(() =>
expect(
screen.queryByRole('dialog', { name: 'CREATE VIEW statement' }),
).toBeInTheDocument(),
);
});
});

View File

@@ -0,0 +1,430 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { type FC, useCallback, useMemo, useRef, useState } from 'react';
import { shallowEqual, useDispatch, useSelector } from 'react-redux';
import { nanoid } from 'nanoid';
import {
ClientErrorObject,
css,
getExtensionsRegistry,
SafeMarkdown,
styled,
t,
} from '@superset-ui/core';
import AutoSizer from 'react-virtualized-auto-sizer';
import Icons from 'src/components/Icons';
import type { SqlLabRootState } from 'src/SqlLab/types';
import {
Skeleton,
AntdBreadcrumb as Breadcrumb,
AntdDropdown,
} from 'src/components';
import FilterableTable from 'src/components/FilterableTable';
import Tabs from 'src/components/Tabs';
import {
tableApiUtil,
TableMetaData,
useTableExtendedMetadataQuery,
useTableMetadataQuery,
} from 'src/hooks/apiResources';
import { runTablePreviewQuery } from 'src/SqlLab/actions/sqlLab';
import Alert from 'src/components/Alert';
import { Menu } from 'src/components/Menu';
import Card from 'src/components/Card';
import CopyToClipboard from 'src/components/CopyToClipboard';
import ResultSet from '../ResultSet';
import ShowSQL from '../ShowSQL';
type Props = {
dbId: number | string;
schema?: string;
catalog?: string | null;
tableName: string;
};
const extensionsRegistry = getExtensionsRegistry();
const COLUMN_KEYS = ['column_name', 'column_type', 'keys', 'comment'];
const MENUS = [
{
key: 'refresh-table',
label: t('Refresh table schema'),
icon: <i aria-hidden className="fa fa-refresh" />,
},
{
key: 'copy-select-statement',
label: t('Copy SELECT statement'),
icon: <i aria-hidden className="fa fa-clipboard m-l-2" />,
},
{
key: 'show-create-view-statement',
label: t('Show CREATE VIEW statement'),
icon: <i aria-hidden className="fa fa-eye" />,
},
];
const TAB_HEADER_HEIGHT = 80;
const PREVIEW_TOP_ACTION_HEIGHT = 30;
const PREVIEW_QUERY_LIMIT = 100;
const Title = styled.div`
display: flex;
flex-direction: row;
align-items: center;
column-gap: ${({ theme }) => theme.gridUnit}px;
font-size: ${({ theme }) => theme.typography.sizes.l}px;
font-weight: ${({ theme }) => theme.typography.weights.bold};
`;
const renderWell = (partitions: TableMetaData['partitions']) => {
if (!partitions) {
return null;
}
const { partitionQuery } = partitions;
let partitionClipBoard;
if (partitionQuery) {
const tt = t('Copy partition query to clipboard');
partitionClipBoard = (
<CopyToClipboard
text={partitionQuery}
shouldShowText={false}
tooltipText={tt}
copyNode={<i className="fa fa-clipboard" />}
/>
);
}
const latest = Object.entries(partitions.latest || [])
.map(([key, value]) => `${key}=${value}`)
.join('/');
return (
<Card size="small">
<div>
<small>
{t('latest partition:')} {latest}
</small>{' '}
{partitionClipBoard}
</div>
</Card>
);
};
const TablePreview: FC<Props> = ({ dbId, catalog, schema, tableName }) => {
const dispatch = useDispatch();
const [databaseName, backend, disableDataPreview] = useSelector<
SqlLabRootState,
string[]
>(
({ sqlLab: { databases } }) => [
databases[dbId]?.database_name,
databases[dbId]?.backend,
databases[dbId]?.disable_data_preview,
],
shallowEqual,
);
const copyStatementActionRef = useRef<HTMLButtonElement | null>(null);
const showViewStatementActionRef = useRef<HTMLButtonElement | null>(null);
const [previewQueryId, setPreviewQueryId] = useState<string>();
const {
currentData: tableMetadata,
isLoading: isMetadataLoading,
isFetching: isMetadataRefreshing,
isError: hasMetadataError,
error: metadataError,
} = useTableMetadataQuery(
{
dbId,
catalog,
schema: schema ?? '',
table: tableName ?? '',
},
{ skip: !dbId || !schema || !tableName },
);
const { currentData: tableExtendedMetadata, error: metadataExtrError } =
useTableExtendedMetadataQuery(
{
dbId,
catalog,
schema: schema ?? '',
table: tableName ?? '',
},
{ skip: !dbId || !schema || !tableName },
);
const data = useMemo(
() =>
(tableMetadata?.columns.length ?? 0) > 0
? tableMetadata?.columns.map(
({ name, type, longType, keys, comment }) => ({
column_name: name,
column_type: longType || type,
keys,
comment,
}),
)
: undefined,
[tableMetadata],
);
const hasKeys = useMemo(
() => data?.some(({ keys }) => Boolean(keys?.length)),
[data],
);
const columns = useMemo(
() => (hasKeys ? COLUMN_KEYS : COLUMN_KEYS.filter(name => name !== 'keys')),
[hasKeys],
);
const tableData = {
dataPreviewQueryId: previewQueryId,
...tableMetadata,
...tableExtendedMetadata,
};
const refreshTableMetadata = () => {
dispatch(
tableApiUtil.invalidateTags([{ type: 'TableMetadatas', id: tableName }]),
);
};
const ResultTable =
extensionsRegistry.get('sqleditor.extension.resultTable') ??
FilterableTable;
const customTabs =
extensionsRegistry.get('sqleditor.extension.tablePreview') ?? [];
const onTabSwitch = useCallback(
(activeKey: string) => {
if (activeKey === 'sample' && !previewQueryId) {
const queryId = nanoid(11);
dispatch(
runTablePreviewQuery(
{
previewQueryId: queryId,
dbId,
catalog,
schema,
name: tableName,
selectStar: tableData.selectStar,
},
true,
),
);
setPreviewQueryId(queryId);
}
},
[
previewQueryId,
dbId,
catalog,
schema,
tableName,
tableData.selectStar,
dispatch,
],
);
const dropdownMenu = useMemo(() => {
let menus = [...MENUS];
if (!tableData.selectStar) {
menus = menus.filter(({ key }) => key !== 'copy-select-statement');
}
if (!tableData.view) {
menus = menus.filter(({ key }) => key !== 'show-create-view-statement');
}
return menus;
}, [tableData.view, tableData.selectStar]);
if (isMetadataLoading) {
return <Skeleton active />;
}
if (hasMetadataError || metadataExtrError) {
return (
<Alert
type="warning"
message={
((metadataError || metadataExtrError) as ClientErrorObject)?.error
}
/>
);
}
if (!data) {
return (
<Alert
type="warning"
message={t('Cannot find the table (%s) metadata.', tableName)}
closable={false}
/>
);
}
return (
<div
css={css`
height: 100%;
display: flex;
flex-direction: column;
`}
>
<Breadcrumb separator=">">
<Breadcrumb.Item>{backend}</Breadcrumb.Item>
<Breadcrumb.Item>{databaseName}</Breadcrumb.Item>
{catalog && <Breadcrumb.Item>{catalog}</Breadcrumb.Item>}
{schema && <Breadcrumb.Item>{schema}</Breadcrumb.Item>}
<Breadcrumb.Item> </Breadcrumb.Item>
</Breadcrumb>
<div style={{ display: 'none' }}>
<CopyToClipboard
copyNode={
<button type="button" ref={copyStatementActionRef}>
invisible button
</button>
}
text={tableData.selectStar}
shouldShowText={false}
/>
{tableData.view && (
<ShowSQL
sql={tableData.view}
tooltipText={t('Show CREATE VIEW statement')}
title={t('CREATE VIEW statement')}
triggerNode={
<button type="button" ref={showViewStatementActionRef}>
invisible button
</button>
}
/>
)}
</div>
<Title>
<Icons.Table iconSize="l" />
{tableName}
<AntdDropdown
overlay={
<Menu
onClick={({ key }) => {
if (key === 'refresh-table') {
refreshTableMetadata();
}
if (key === 'copy-select-statement') {
copyStatementActionRef.current?.click();
}
if (key === 'show-create-view-statement') {
showViewStatementActionRef.current?.click();
}
}}
items={dropdownMenu}
/>
}
trigger={['click']}
>
<Icons.DownSquareOutlined
iconSize="m"
style={{ marginTop: 2, marginLeft: 4 }}
aria-label={t('Table actions')}
/>
</AntdDropdown>
</Title>
{isMetadataRefreshing ? (
<Skeleton active />
) : (
<>
{tableData.comment && <SafeMarkdown source={tableData.comment} />}
{renderWell(tableData.partitions)}
<div
css={css`
flex: 1 1 auto;
`}
>
<AutoSizer disableWidth>
{({ height }) => (
<Tabs
fullWidth={false}
onTabClick={onTabSwitch}
css={css`
height: ${height}px;
`}
>
<Tabs.TabPane
tab={t('Columns (%s)', data.length)}
key="columns"
>
<ResultTable
queryId="table-columns"
height={height - TAB_HEADER_HEIGHT}
data={data}
orderedColumnKeys={columns}
/>
</Tabs.TabPane>
{tableData?.selectStar && !disableDataPreview && (
<Tabs.TabPane tab={t('Data preview')} key="sample">
{previewQueryId && (
<ResultSet
queryId={previewQueryId}
visualize={false}
csv={false}
cache
height={
height -
TAB_HEADER_HEIGHT -
PREVIEW_TOP_ACTION_HEIGHT
}
displayLimit={PREVIEW_QUERY_LIMIT}
defaultQueryLimit={PREVIEW_QUERY_LIMIT}
/>
)}
</Tabs.TabPane>
)}
{tableData?.indexes && tableData.indexes.length > 0 && (
<Tabs.TabPane
tab={t('Indexes (%s)', tableData.indexes.length)}
key="indexes"
>
{tableData.indexes.map((ix, i) => (
<pre className="code" key={i}>
{JSON.stringify(ix, null, ' ')}
</pre>
))}
</Tabs.TabPane>
)}
{tableData?.metadata && (
<Tabs.TabPane tab={t('Metadata')} key="metadata">
<ResultTable
queryId="table-metadata"
height={height - TAB_HEADER_HEIGHT}
data={Object.entries(tableData.metadata).map(
([name, value]) => ({ name, value }),
)}
orderedColumnKeys={['name', 'value']}
/>
</Tabs.TabPane>
)}
{customTabs.map(([title, ExtComponent]) => (
<Tabs.TabPane tab={title} key={title}>
<ExtComponent
dbId={Number(dbId)}
schema={schema ?? ''}
tableName={tableName}
/>
</Tabs.TabPane>
))}
</Tabs>
)}
</AutoSizer>
</div>
</>
)}
</div>
);
};
export default TablePreview;

View File

@@ -38,9 +38,10 @@ export const table = {
selectStar: 'SELECT * FROM ab_user',
queryEditorId: 'dfsadfs',
catalog: null,
schema: 'superset',
schema: 'main',
name: 'ab_user',
id: 'r11Vgt60',
view: 'SELECT * FROM ab_user',
dataPreviewQueryId: null,
partitions: {
cols: ['username'],
@@ -188,7 +189,7 @@ export const defaultQueryEditor = {
version: LatestQueryEditorVersion,
id: 'dfsadfs',
autorun: false,
dbId: undefined,
dbId: 1,
latestQueryId: null,
selectedText: undefined,
sql: 'SELECT *\nFROM\nWHERE',

View File

@@ -167,9 +167,10 @@ describe('getInitialState', () => {
table: 'table1',
tab_state_id: 1,
description: {
name: 'table1',
columns: [
{ name: 'id', type: 'INT' },
{ name: 'column2', type: 'STRING' },
{ name: 'id', type: 'INT', longType: 'INT()' },
{ name: 'column2', type: 'STRING', longType: 'STRING()' },
],
},
},
@@ -178,9 +179,10 @@ describe('getInitialState', () => {
table: 'table2',
tab_state_id: 1,
description: {
name: 'table2',
columns: [
{ name: 'id', type: 'INT' },
{ name: 'column2', type: 'STRING' },
{ name: 'id', type: 'INT', longType: 'INT()' },
{ name: 'column2', type: 'STRING', longType: 'STRING()' },
],
},
},

View File

@@ -122,12 +122,12 @@ export default function getInitialState({
.forEach(tableSchema => {
const { dataPreviewQueryId, ...persistData } = tableSchema.description;
const table = {
dbId: tableSchema.database_id,
dbId: tableSchema.database_id ?? 0,
queryEditorId: tableSchema.tab_state_id.toString(),
catalog: tableSchema.catalog,
schema: tableSchema.schema,
name: tableSchema.table,
expanded: tableSchema.expanded,
expanded: Boolean(tableSchema.expanded),
id: tableSchema.id,
dataPreviewQueryId,
persistData,
@@ -147,7 +147,8 @@ export default function getInitialState({
}),
};
const destroyedQueryEditors: Record<string, number> = {};
const destroyedQueryEditors: SqlLabRootState['sqlLab']['destroyedQueryEditors'] =
{};
/**
* If the `SQLLAB_BACKEND_PERSISTENCE` feature flag is off, or if the user

View File

@@ -187,30 +187,40 @@ export default function sqlLabReducer(state = {}, action) {
},
[actions.MERGE_TABLE]() {
const at = { ...action.table };
let existingTable;
state.tables.forEach(xt => {
if (
const existingTableIndex = state.tables.findIndex(
xt =>
xt.dbId === at.dbId &&
xt.queryEditorId === at.queryEditorId &&
xt.catalog === at.catalog &&
xt.schema === at.schema &&
xt.name === at.name
) {
existingTable = xt;
}
});
if (existingTable) {
xt.name === at.name,
);
if (existingTableIndex >= 0) {
if (action.query) {
at.dataPreviewQueryId = action.query.id;
}
if (existingTable.initialized) {
at.id = existingTable.id;
}
return alterInArr(state, 'tables', existingTable, at);
return {
...state,
tables: [
...state.tables.slice(0, existingTableIndex),
{
...state.tables[existingTableIndex],
...at,
...(state.tables[existingTableIndex].initialized && {
id: state.tables[existingTableIndex].id,
}),
},
...state.tables.slice(existingTableIndex + 1),
],
...(at.expanded && {
activeSouthPaneTab: at.id,
}),
};
}
// for new table, associate Id of query for data preview
at.dataPreviewQueryId = null;
let newState = addToArr(state, 'tables', at, Boolean(action.prepend));
newState.activeSouthPaneTab = at.id;
if (action.query) {
newState = alterInArr(newState, 'tables', at, {
dataPreviewQueryId: action.query.id,
@@ -245,7 +255,6 @@ export default function sqlLabReducer(state = {}, action) {
...state,
queries,
tables: newTables,
activeSouthPaneTab: action.newQuery.id,
};
},
[actions.COLLAPSE_TABLE]() {
@@ -253,9 +262,18 @@ export default function sqlLabReducer(state = {}, action) {
},
[actions.REMOVE_TABLES]() {
const tableIds = action.tables.map(table => table.id);
const tables = state.tables.filter(table => !tableIds.includes(table.id));
return {
...state,
tables: state.tables.filter(table => !tableIds.includes(table.id)),
tables,
...(tableIds.includes(state.activeSouthPaneTab) && {
activeSouthPaneTab:
tables.find(
({ queryEditorId }) =>
queryEditorId === action.tables[0].queryEditorId,
)?.id ?? 'Results',
}),
};
},
[actions.COST_ESTIMATE_STARTED]() {
@@ -315,8 +333,6 @@ export default function sqlLabReducer(state = {}, action) {
const queries = { ...state.queries, [q.id]: q };
newState = { ...state, queries };
}
} else {
newState.activeSouthPaneTab = action.query.id;
}
newState = addToObject(newState, 'queries', action.query);

View File

@@ -86,7 +86,7 @@ export interface Table {
schema: string;
name: string;
queryEditorId: QueryEditor['id'];
dataPreviewQueryId: string | null;
dataPreviewQueryId?: string | null;
expanded: boolean;
initialized?: boolean;
inLocalStorage?: boolean;

View File

@@ -32,33 +32,21 @@ export const useDisplayModeToggle = () => {
<div
css={(theme: SupersetTheme) => css`
margin-bottom: ${theme.gridUnit * 6}px;
.ant-radio-button-wrapper-checked:not(
.ant-radio-button-wrapper-disabled
):focus-within {
box-shadow: none;
}
`}
data-test="drill-by-display-toggle"
>
<Radio.Group
<Radio.GroupWrapper
onChange={({ target: { value } }) => {
setDrillByDisplayMode(value);
}}
defaultValue={DrillByType.Chart}
>
<Radio.Button
value={DrillByType.Chart}
data-test="drill-by-chart-radio"
>
{t('Chart')}
</Radio.Button>
<Radio.Button
value={DrillByType.Table}
data-test="drill-by-table-radio"
>
{t('Table')}
</Radio.Button>
</Radio.Group>
options={[
{ label: t('Chart'), value: DrillByType.Chart },
{ label: t('Table'), value: DrillByType.Table },
]}
optionType="button"
buttonStyle="outline"
/>
</div>
),
[],

View File

@@ -38,7 +38,7 @@ describe('FilterableTable', () => {
const { getByRole, getByText } = render(
<FilterableTable {...mockedProps} />,
);
expect(getByRole('table')).toBeInTheDocument();
expect(getByRole('treegrid')).toBeInTheDocument();
mockedProps.data.forEach(({ b: columnBContent }) => {
expect(getByText(columnBContent)).toBeInTheDocument();
});
@@ -78,11 +78,10 @@ describe('FilterableTable sorting - RTL', () => {
};
render(<FilterableTable {...stringProps} />);
const stringColumn = within(screen.getByRole('table'))
const stringColumn = within(screen.getByRole('treegrid'))
.getByText('columnA')
.closest('th');
// Antd 4.x Table does not follow the table role structure. Need a hacky selector to point the cell item
const gridCells = screen.getByTitle('Bravo').closest('.virtual-grid');
.closest('[role=button]');
const gridCells = screen.getByText('Bravo').closest('[role=rowgroup]');
// Original order
expect(gridCells?.textContent).toEqual(
@@ -124,10 +123,10 @@ describe('FilterableTable sorting - RTL', () => {
};
render(<FilterableTable {...integerProps} />);
const integerColumn = within(screen.getByRole('table'))
const integerColumn = within(screen.getByRole('treegrid'))
.getByText('columnB')
.closest('th');
const gridCells = screen.getByTitle('21').closest('.virtual-grid');
.closest('[role=button]');
const gridCells = screen.getByText('21').closest('[role=rowgroup]');
// Original order
expect(gridCells?.textContent).toEqual(['21', '0', '623'].join(''));
@@ -159,10 +158,10 @@ describe('FilterableTable sorting - RTL', () => {
};
render(<FilterableTable {...floatProps} />);
const floatColumn = within(screen.getByRole('table'))
const floatColumn = within(screen.getByRole('treegrid'))
.getByText('columnC')
.closest('th');
const gridCells = screen.getByTitle('45.67').closest('.virtual-grid');
.closest('[role=button]');
const gridCells = screen.getByText('45.67').closest('[role=rowgroup]');
// Original order
expect(gridCells?.textContent).toEqual(
@@ -214,10 +213,10 @@ describe('FilterableTable sorting - RTL', () => {
};
render(<FilterableTable {...mixedFloatProps} />);
const mixedFloatColumn = within(screen.getByRole('table'))
const mixedFloatColumn = within(screen.getByRole('treegrid'))
.getByText('columnD')
.closest('th');
const gridCells = screen.getByTitle('48710.92').closest('.virtual-grid');
.closest('[role=button]');
const gridCells = screen.getByText('48710.92').closest('[role=rowgroup]');
// Original order
expect(gridCells?.textContent).toEqual(
@@ -312,10 +311,10 @@ describe('FilterableTable sorting - RTL', () => {
};
render(<FilterableTable {...dsProps} />);
const dsColumn = within(screen.getByRole('table'))
const dsColumn = within(screen.getByRole('treegrid'))
.getByText('columnDS')
.closest('th');
const gridCells = screen.getByTitle('2021-01-01').closest('.virtual-grid');
.closest('[role=button]');
const gridCells = screen.getByText('2021-01-01').closest('[role=rowgroup]');
// Original order
expect(gridCells?.textContent).toEqual(

View File

@@ -16,55 +16,20 @@
* specific language governing permissions and limitations
* under the License.
*/
import _JSONbig from 'json-bigint';
import { useEffect, useRef, useState, useMemo } from 'react';
import { getMultipleTextDimensions, styled } from '@superset-ui/core';
import { useDebounceValue } from 'src/hooks/useDebounceValue';
import { useMemo, useRef, useCallback } from 'react';
import { styled } from '@superset-ui/core';
import { useCellContentParser } from './useCellContentParser';
import { renderResultCell } from './utils';
import { Table, TableSize } from '../Table';
import GridTable, { GridSize, ColDef } from '../GridTable';
const JSONbig = _JSONbig({
storeAsString: true,
constructorAction: 'preserve',
});
const SCROLL_BAR_HEIGHT = 15;
// This regex handles all possible number formats in javascript, including ints, floats,
// exponential notation, NaN, and Infinity.
// See https://stackoverflow.com/a/30987109 for more details
const ONLY_NUMBER_REGEX = /^(NaN|-?((\d*\.\d+|\d+)([Ee][+-]?\d+)?|Infinity))$/;
const StyledFilterableTable = styled.div`
${({ theme }) => `
height: 100%;
overflow: hidden;
.ant-table-cell {
font-weight: ${theme.typography.weights.bold};
background-color: ${theme.colors.grayscale.light5};
}
.ant-table-cell,
.virtual-table-cell {
min-width: 0px;
align-self: center;
font-size: ${theme.typography.sizes.s}px;
}
.even-row {
background: ${theme.colors.grayscale.light4};
}
.odd-row {
background: ${theme.colors.grayscale.light5};
}
.cell-text-for-measuring {
font-family: ${theme.typography.families.sansSerif};
font-size: ${theme.typography.sizes.s}px;
}
`}
height: 100%;
overflow: hidden;
`;
type CellDataType = string | number | null;
@@ -79,12 +44,38 @@ export interface FilterableTableProps {
overscanColumnCount?: number;
overscanRowCount?: number;
rowHeight?: number;
// need antd 5.0 to support striped color pattern
striped?: boolean;
expandedColumns?: string[];
allowHTML?: boolean;
}
const parseNumberFromString = (value: string | number | null) => {
if (typeof value === 'string' && ONLY_NUMBER_REGEX.test(value)) {
return parseFloat(value);
}
return value;
};
const sortResults = (valueA: string | number, valueB: string | number) => {
const aValue = parseNumberFromString(valueA);
const bValue = parseNumberFromString(valueB);
// equal items sort equally
if (aValue === bValue) {
return 0;
}
// nulls sort after anything else
if (aValue === null) {
return 1;
}
if (bValue === null) {
return -1;
}
return aValue < bValue ? -1 : 1;
};
const FilterableTable = ({
orderedColumnKeys,
data,
@@ -92,83 +83,13 @@ const FilterableTable = ({
filterText = '',
expandedColumns = [],
allowHTML = true,
striped,
}: FilterableTableProps) => {
const formatTableData = (data: Record<string, unknown>[]): Datum[] =>
data.map(row => {
const newRow: Record<string, any> = {};
Object.entries(row).forEach(([key, val]) => {
if (['string', 'number'].indexOf(typeof val) >= 0) {
newRow[key] = val;
} else {
newRow[key] = val === null ? null : JSONbig.stringify(val);
}
});
return newRow;
});
const [fitted, setFitted] = useState(false);
const [list] = useState<Datum[]>(() => formatTableData(data));
const getCellContent = useCellContentParser({
columnKeys: orderedColumnKeys,
expandedColumns,
});
const getWidthsForColumns = () => {
const PADDING = 50; // accounts for cell padding and width of sorting icon
const widthsByColumnKey: Record<string, number> = {};
const cellContent = ([] as string[]).concat(
...orderedColumnKeys.map(key => {
const cellContentList = list.map((data: Datum) =>
getCellContent({ cellData: data[key], columnKey: key }),
);
cellContentList.push(key);
return cellContentList;
}),
);
const colWidths = getMultipleTextDimensions({
className: 'cell-text-for-measuring',
texts: cellContent,
}).map(dimension => dimension.width);
orderedColumnKeys.forEach((key, index) => {
// we can't use Math.max(...colWidths.slice(...)) here since the number
// of elements might be bigger than the number of allowed arguments in a
// JavaScript function
widthsByColumnKey[key] =
colWidths
.slice(index * (list.length + 1), (index + 1) * (list.length + 1))
.reduce((a, b) => Math.max(a, b)) + PADDING;
});
return widthsByColumnKey;
};
const [widthsForColumnsByKey] = useState<Record<string, number>>(() =>
getWidthsForColumns(),
);
const totalTableWidth = useRef(
orderedColumnKeys
.map(key => widthsForColumnsByKey[key])
.reduce((curr, next) => curr + next),
);
const container = useRef<HTMLDivElement>(null);
const fitTableToWidthIfNeeded = () => {
const containerWidth = container.current?.clientWidth ?? 0;
if (totalTableWidth.current < containerWidth) {
// fit table width if content doesn't fill the width of the container
totalTableWidth.current = containerWidth;
}
setFitted(true);
};
useEffect(() => {
fitTableToWidthIfNeeded();
}, []);
const hasMatch = (text: string, row: Datum) => {
const values: string[] = [];
Object.keys(row).forEach(key => {
@@ -188,86 +109,52 @@ const FilterableTable = ({
return values.some(v => v.includes(lowerCaseText));
};
// Parse any numbers from strings so they'll sort correctly
const parseNumberFromString = (value: string | number | null) => {
if (typeof value === 'string') {
if (ONLY_NUMBER_REGEX.test(value)) {
return parseFloat(value);
}
}
return value;
};
const sortResults = (key: string, a: Datum, b: Datum) => {
const aValue = parseNumberFromString(a[key]);
const bValue = parseNumberFromString(b[key]);
// equal items sort equally
if (aValue === bValue) {
return 0;
}
// nulls sort after anything else
if (aValue === null) {
return 1;
}
if (bValue === null) {
return -1;
}
return aValue < bValue ? -1 : 1;
};
const keyword = useDebounceValue(filterText);
const filteredList = useMemo(
const columns = useMemo(
() =>
keyword ? list.filter((row: Datum) => hasMatch(keyword, row)) : list,
[list, keyword],
orderedColumnKeys.map(key => ({
key,
label: key,
fieldName: key,
headerName: key,
comparator: sortResults,
render: ({ value, colDef }: { value: CellDataType; colDef: ColDef }) =>
renderResultCell({
cellData: value,
columnKey: colDef.field,
allowHTML,
getCellContent,
}),
})),
[orderedColumnKeys, allowHTML, getCellContent],
);
// exclude the height of the horizontal scroll bar from the height of the table
// and the height of the table container if the content overflows
const totalTableHeight =
container.current && totalTableWidth.current > container.current.clientWidth
? height - SCROLL_BAR_HEIGHT
: height;
const keyword = useRef<string | undefined>(filterText);
keyword.current = filterText;
const columns = orderedColumnKeys.map(key => ({
key,
title: key,
dataIndex: key,
width: widthsForColumnsByKey[key],
sorter: (a: Datum, b: Datum) => sortResults(key, a, b),
render: (text: CellDataType) =>
renderResultCell({
cellData: text,
columnKey: key,
allowHTML,
getCellContent,
}),
}));
const keywordFilter = useCallback(node => {
if (keyword.current && node.data) {
return hasMatch(keyword.current, node.data);
}
return true;
}, []);
return (
<StyledFilterableTable
className="filterable-table-container"
data-test="table-container"
ref={container}
>
{fitted && (
<Table
loading={filterText !== keyword}
size={TableSize.Small}
height={totalTableHeight + 42}
usePagination={false}
columns={columns}
data={filteredList}
childrenColumnName=""
virtualize
bordered
/>
)}
<GridTable
size={GridSize.Small}
height={height}
usePagination={false}
columns={columns}
data={data}
externalFilter={keywordFilter}
showRowNumber
striped={striped}
enableActions
columnReorderable
/>
</StyledFilterableTable>
);
};

View File

@@ -0,0 +1,66 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { render } from 'spec/helpers/testing-library';
import GridTable from '.';
jest.mock('src/components/ErrorBoundary', () => ({
__esModule: true,
default: ({ children }: { children: React.ReactNode }) => <>{children}</>,
}));
const mockedProps = {
queryId: 'abc',
columns: ['a', 'b', 'c'].map(key => ({
key,
label: key,
headerName: key,
render: ({ value }: { value: any }) => value,
})),
data: [
{ a: 'a1', b: 'b1', c: 'c1', d: 0 },
{ a: 'a2', b: 'b2', c: 'c2', d: 100 },
{ a: null, b: 'b3', c: 'c3', d: 50 },
],
height: 500,
};
test('renders a grid with 3 Table rows', () => {
const { queryByText } = render(<GridTable {...mockedProps} />);
mockedProps.data.forEach(({ b: columnBContent }) => {
expect(queryByText(columnBContent)).toBeInTheDocument();
});
});
test('sorts strings correctly', () => {
const stringProps = {
...mockedProps,
columns: ['columnA'].map(key => ({
key,
label: key,
headerName: key,
render: ({ value }: { value: any }) => value,
})),
data: [{ columnA: 'Bravo' }, { columnA: 'Alpha' }, { columnA: 'Charlie' }],
height: 500,
};
const { container } = render(<GridTable {...stringProps} />);
// Original order
expect(container).toHaveTextContent(['Bravo', 'Alpha', 'Charlie'].join(''));
});

View File

@@ -0,0 +1,109 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import type { Column, GridApi } from 'ag-grid-community';
import { act, fireEvent, render } from 'spec/helpers/testing-library';
import Header from './Header';
import { PIVOT_COL_ID } from './constants';
jest.mock('src/components/Dropdown', () => ({
Dropdown: () => <div data-test="mock-dropdown" />,
}));
jest.mock('src/components/Icons', () => ({
Sort: () => <div data-test="mock-sort" />,
SortAsc: () => <div data-test="mock-sort-asc" />,
SortDesc: () => <div data-test="mock-sort-desc" />,
}));
class MockApi extends EventTarget {
getAllDisplayedColumns() {
return [];
}
isDestroyed() {
return false;
}
}
const mockedProps = {
displayName: 'test column',
setSort: jest.fn(),
enableSorting: true,
column: {
getColId: () => '123',
isPinnedLeft: () => true,
isPinnedRight: () => false,
getSort: () => 'asc',
getSortIndex: () => null,
} as any as Column,
api: new MockApi() as any as GridApi,
};
test('renders display name for the column', () => {
const { queryByText } = render(<Header {...mockedProps} />);
expect(queryByText(mockedProps.displayName)).toBeInTheDocument();
});
test('sorts by clicking a column header', () => {
const { getByText, queryByTestId } = render(<Header {...mockedProps} />);
fireEvent.click(getByText(mockedProps.displayName));
expect(mockedProps.setSort).toHaveBeenCalledWith('asc', false);
expect(queryByTestId('mock-sort-asc')).toBeInTheDocument();
fireEvent.click(getByText(mockedProps.displayName));
expect(mockedProps.setSort).toHaveBeenCalledWith('desc', false);
expect(queryByTestId('mock-sort-desc')).toBeInTheDocument();
fireEvent.click(getByText(mockedProps.displayName));
expect(mockedProps.setSort).toHaveBeenCalledWith(null, false);
expect(queryByTestId('mock-sort-asc')).not.toBeInTheDocument();
expect(queryByTestId('mock-sort-desc')).not.toBeInTheDocument();
});
test('synchronizes the current sort when sortChanged event occured', async () => {
const { findByTestId } = render(<Header {...mockedProps} />);
act(() => {
mockedProps.api.dispatchEvent(new Event('sortChanged'));
});
const sortAsc = await findByTestId('mock-sort-asc');
expect(sortAsc).toBeInTheDocument();
});
test('disable menu when enableFilterButton is false', () => {
const { queryByText, queryByTestId } = render(
<Header {...mockedProps} enableFilterButton={false} />,
);
expect(queryByText(mockedProps.displayName)).toBeInTheDocument();
expect(queryByTestId('mock-dropdown')).not.toBeInTheDocument();
});
test('hide display name for PIVOT_COL_ID', () => {
const { queryByText } = render(
<Header
{...mockedProps}
column={
{
getColId: () => PIVOT_COL_ID,
isPinnedLeft: () => true,
isPinnedRight: () => false,
getSortIndex: () => null,
} as any as Column
}
/>,
);
expect(queryByText(mockedProps.displayName)).not.toBeInTheDocument();
});

View File

@@ -0,0 +1,200 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useCallback, useEffect, useRef, useState } from 'react';
import { styled, useTheme, t } from '@superset-ui/core';
import type { Column, GridApi } from 'ag-grid-community';
import Icons from 'src/components/Icons';
import { PIVOT_COL_ID } from './constants';
import HeaderMenu from './HeaderMenu';
interface Params {
enableFilterButton?: boolean;
enableSorting?: boolean;
displayName: string;
column: Column;
api: GridApi;
setSort: (sort: string | null, multiSort: boolean) => void;
}
const SORT_DIRECTION = [null, 'asc', 'desc'];
const HeaderCell = styled.div`
display: flex;
flex: 1;
&[role='button'] {
cursor: pointer;
}
`;
const HeaderCellSort = styled.div`
position: relative;
display: inline-flex;
align-items: center;
`;
const SortSeqLabel = styled.span`
position: absolute;
right: 0;
`;
const HeaderAction = styled.div`
display: none;
position: absolute;
right: ${({ theme }) => theme.gridUnit * 3}px;
&.main {
margin: 0 auto;
left: 0;
right: 0;
width: 20px;
}
& .ant-dropdown-trigger {
cursor: context-menu;
padding: ${({ theme }) => theme.gridUnit * 2}px;
background-color: var(--ag-background-color);
box-shadow: 0 0 2px var(--ag-chip-border-color);
border-radius: 50%;
&:hover {
box-shadow: 0 0 4px ${({ theme }) => theme.colors.grayscale.light1};
}
}
`;
const IconPlaceholder = styled.div`
position: absolute;
top: 0;
`;
const Header: React.FC<Params> = ({
enableFilterButton,
enableSorting,
displayName,
setSort,
column,
api,
}: Params) => {
const theme = useTheme();
const colId = column.getColId();
const pinnedLeft = column.isPinnedLeft();
const pinnedRight = column.isPinnedRight();
const sortOption = useRef<number>(0);
const [invisibleColumns, setInvisibleColumns] = useState<Column[]>([]);
const [currentSort, setCurrentSort] = useState<string | null>(null);
const [sortIndex, setSortIndex] = useState<number | null>();
const onSort = useCallback(
event => {
sortOption.current = (sortOption.current + 1) % SORT_DIRECTION.length;
const sort = SORT_DIRECTION[sortOption.current];
setSort(sort, event.shiftKey);
setCurrentSort(sort);
},
[setSort],
);
const onVisibleChange = useCallback(
(isVisible: boolean) => {
if (isVisible) {
setInvisibleColumns(
api.getColumns()?.filter(c => !c.isVisible()) || [],
);
}
},
[api],
);
const onSortChanged = useCallback(() => {
const hasMultiSort =
api.getAllDisplayedColumns().findIndex(c => c.getSortIndex()) !== -1;
const updatedSortIndex = column.getSortIndex();
sortOption.current = SORT_DIRECTION.indexOf(column.getSort() ?? null);
setCurrentSort(column.getSort() ?? null);
setSortIndex(hasMultiSort ? updatedSortIndex : null);
}, [api, column]);
useEffect(() => {
api.addEventListener('sortChanged', onSortChanged);
return () => {
if (api.isDestroyed()) return;
api.removeEventListener('sortChanged', onSortChanged);
};
}, [api, onSortChanged]);
return (
<>
{colId !== PIVOT_COL_ID && (
<HeaderCell
tabIndex={0}
className="ag-header-cell-label"
{...(enableSorting && {
role: 'button',
onClick: onSort,
title: t(
'To enable multiple column sorting, hold down the ⇧ Shift key while clicking the column header.',
),
})}
>
<div className="ag-header-cell-text">{displayName}</div>
{enableSorting && (
<HeaderCellSort>
<Icons.Sort iconSize="xxl" />
<IconPlaceholder>
{currentSort === 'asc' && (
<Icons.SortAsc
iconSize="xxl"
iconColor={theme.colors.primary.base}
/>
)}
{currentSort === 'desc' && (
<Icons.SortDesc
iconSize="xxl"
iconColor={theme.colors.primary.base}
/>
)}
</IconPlaceholder>
{typeof sortIndex === 'number' && (
<SortSeqLabel>{sortIndex + 1}</SortSeqLabel>
)}
</HeaderCellSort>
)}
</HeaderCell>
)}
{enableFilterButton && colId && api && (
<HeaderAction
className={`customHeaderAction${
colId === PIVOT_COL_ID ? ' main' : ''
}`}
>
{colId && (
<HeaderMenu
colId={colId}
api={api}
pinnedLeft={pinnedLeft}
pinnedRight={pinnedRight}
invisibleColumns={invisibleColumns}
isMain={colId === PIVOT_COL_ID}
onVisibleChange={onVisibleChange}
/>
)}
</HeaderAction>
)}
</>
);
};
export default Header;

View File

@@ -0,0 +1,266 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import type { Column, GridApi } from 'ag-grid-community';
import {
fireEvent,
render,
waitFor,
screen,
} from 'spec/helpers/testing-library';
import HeaderMenu from './HeaderMenu';
jest.mock('src/components/Menu', () => {
const Menu = ({ children }: { children: React.ReactChild }) => (
<div data-test="mock-Menu">{children}</div>
);
Menu.Item = ({
children,
onClick,
}: {
children: React.ReactChild;
onClick: () => void;
}) => (
<button type="button" data-test="mock-Item" onClick={onClick}>
{children}
</button>
);
Menu.SubMenu = ({
title,
children,
}: {
title: React.ReactNode;
children: React.ReactNode;
}) => (
<div>
{title}
<button type="button" data-test="mock-SubMenu">
{children}
</button>
</div>
);
Menu.Divider = () => <div data-test="mock-Divider" />;
return { Menu };
});
jest.mock('src/components/Icons', () => ({
DownloadOutlined: () => <div data-test="mock-DownloadOutlined" />,
CopyOutlined: () => <div data-test="mock-CopyOutlined" />,
UnlockOutlined: () => <div data-test="mock-UnlockOutlined" />,
VerticalRightOutlined: () => <div data-test="mock-VerticalRightOutlined" />,
VerticalLeftOutlined: () => <div data-test="mock-VerticalLeftOutlined" />,
EyeInvisibleOutlined: () => <div data-test="mock-EyeInvisibleOutlined" />,
EyeOutlined: () => <div data-test="mock-EyeOutlined" />,
ColumnWidthOutlined: () => <div data-test="mock-column-width" />,
}));
jest.mock('src/components/Dropdown', () => ({
Dropdown: ({ overlay }: { overlay: React.ReactChild }) => (
<div data-test="mock-Dropdown">{overlay}</div>
),
}));
jest.mock('src/utils/copy', () => jest.fn().mockImplementation(f => f()));
const mockInvisibleColumn = {
getColId: jest.fn().mockReturnValue('column2'),
getColDef: jest.fn().mockReturnValue({ headerName: 'column2' }),
getDataAsCsv: jest.fn().mockReturnValue('csv'),
} as any as Column;
const mockInvisibleColumn3 = {
getColId: jest.fn().mockReturnValue('column3'),
getColDef: jest.fn().mockReturnValue({ headerName: 'column3' }),
getDataAsCsv: jest.fn().mockReturnValue('csv'),
} as any as Column;
const mockGridApi = {
autoSizeColumns: jest.fn(),
autoSizeAllColumns: jest.fn(),
getColumn: jest.fn().mockReturnValue({
getColDef: jest.fn().mockReturnValue({}),
}),
getColumns: jest.fn().mockReturnValue([]),
getDataAsCsv: jest.fn().mockReturnValue('csv'),
exportDataAsCsv: jest.fn().mockReturnValue('csv'),
getAllDisplayedColumns: jest.fn().mockReturnValue([]),
setColumnsPinned: jest.fn(),
setColumnsVisible: jest.fn(),
setColumnVisible: jest.fn(),
moveColumns: jest.fn(),
} as any as GridApi;
const mockedProps = {
colId: 'column1',
invisibleColumns: [],
api: mockGridApi,
onVisibleChange: jest.fn(),
};
afterEach(() => {
(mockGridApi.getDataAsCsv as jest.Mock).mockClear();
(mockGridApi.setColumnsPinned as jest.Mock).mockClear();
(mockGridApi.setColumnsVisible as jest.Mock).mockClear();
(mockGridApi.setColumnsVisible as jest.Mock).mockClear();
(mockGridApi.setColumnsPinned as jest.Mock).mockClear();
(mockGridApi.autoSizeColumns as jest.Mock).mockClear();
(mockGridApi.autoSizeAllColumns as jest.Mock).mockClear();
(mockGridApi.moveColumns as jest.Mock).mockClear();
});
test('renders copy data', async () => {
const { getByText } = render(<HeaderMenu {...mockedProps} />);
fireEvent.click(getByText('Copy'));
await waitFor(() =>
expect(mockGridApi.getDataAsCsv).toHaveBeenCalledTimes(1),
);
expect(mockGridApi.getDataAsCsv).toHaveBeenCalledWith({
columnKeys: [mockedProps.colId],
suppressQuotes: true,
});
});
test('renders buttons pinning both sides', () => {
const { queryByText, getByText } = render(<HeaderMenu {...mockedProps} />);
expect(queryByText('Pin Left')).toBeInTheDocument();
expect(queryByText('Pin Right')).toBeInTheDocument();
fireEvent.click(getByText('Pin Left'));
expect(mockGridApi.setColumnsPinned).toHaveBeenCalledTimes(1);
expect(mockGridApi.setColumnsPinned).toHaveBeenCalledWith(
[mockedProps.colId],
'left',
);
fireEvent.click(getByText('Pin Right'));
expect(mockGridApi.setColumnsPinned).toHaveBeenLastCalledWith(
[mockedProps.colId],
'right',
);
});
test('renders unpin on pinned left', () => {
const { queryByText, getByText } = render(
<HeaderMenu {...mockedProps} pinnedLeft />,
);
expect(queryByText('Pin Left')).not.toBeInTheDocument();
expect(queryByText('Unpin')).toBeInTheDocument();
fireEvent.click(getByText('Unpin'));
expect(mockGridApi.setColumnsPinned).toHaveBeenCalledTimes(1);
expect(mockGridApi.setColumnsPinned).toHaveBeenCalledWith(
[mockedProps.colId],
null,
);
});
test('renders unpin on pinned right', () => {
const { queryByText } = render(<HeaderMenu {...mockedProps} pinnedRight />);
expect(queryByText('Pin Right')).not.toBeInTheDocument();
expect(queryByText('Unpin')).toBeInTheDocument();
});
test('renders autosize column', async () => {
const { getByText } = render(<HeaderMenu {...mockedProps} />);
fireEvent.click(getByText('Autosize Column'));
await waitFor(() =>
expect(mockGridApi.autoSizeColumns).toHaveBeenCalledTimes(1),
);
});
test('renders unhide when invisible column exists', async () => {
const { queryByText } = render(
<HeaderMenu {...mockedProps} invisibleColumns={[mockInvisibleColumn]} />,
);
expect(queryByText('Unhide')).toBeInTheDocument();
const unhideColumnsButton = await screen.findByText('column2');
fireEvent.click(unhideColumnsButton);
expect(mockGridApi.setColumnsVisible).toHaveBeenCalledTimes(1);
expect(mockGridApi.setColumnsVisible).toHaveBeenCalledWith(['column2'], true);
});
describe('for main menu', () => {
test('renders Copy to Clipboard', async () => {
const { getByText } = render(<HeaderMenu {...mockedProps} isMain />);
fireEvent.click(getByText('Copy the current data'));
await waitFor(() =>
expect(mockGridApi.getDataAsCsv).toHaveBeenCalledTimes(1),
);
expect(mockGridApi.getDataAsCsv).toHaveBeenCalledWith({
columnKeys: [],
columnSeparator: '\t',
suppressQuotes: true,
});
});
test('renders Download to CSV', async () => {
const { getByText } = render(<HeaderMenu {...mockedProps} isMain />);
fireEvent.click(getByText('Download to CSV'));
await waitFor(() =>
expect(mockGridApi.exportDataAsCsv).toHaveBeenCalledTimes(1),
);
expect(mockGridApi.exportDataAsCsv).toHaveBeenCalledWith({
columnKeys: [],
});
});
test('renders autosize column', async () => {
const { getByText } = render(<HeaderMenu {...mockedProps} isMain />);
fireEvent.click(getByText('Autosize all columns'));
await waitFor(() =>
expect(mockGridApi.autoSizeAllColumns).toHaveBeenCalledTimes(1),
);
});
test('renders all unhide all hidden columns when multiple invisible columns exist', async () => {
render(
<HeaderMenu
{...mockedProps}
isMain
invisibleColumns={[mockInvisibleColumn, mockInvisibleColumn3]}
/>,
);
const unhideColumnsButton = await screen.findByText(
`All ${2} hidden columns`,
);
fireEvent.click(unhideColumnsButton);
expect(mockGridApi.setColumnsVisible).toHaveBeenCalledTimes(1);
expect(mockGridApi.setColumnsVisible).toHaveBeenCalledWith(
[mockInvisibleColumn, mockInvisibleColumn3],
true,
);
});
test('reset columns configuration', async () => {
const { getByText } = render(
<HeaderMenu
{...mockedProps}
isMain
invisibleColumns={[mockInvisibleColumn]}
/>,
);
fireEvent.click(getByText('Reset columns'));
await waitFor(() =>
expect(mockGridApi.setColumnsVisible).toHaveBeenCalledTimes(1),
);
expect(mockGridApi.setColumnsVisible).toHaveBeenCalledWith(
[mockInvisibleColumn],
true,
);
expect(mockGridApi.setColumnsPinned).toHaveBeenCalledTimes(1);
expect(mockGridApi.setColumnsPinned).toHaveBeenCalledWith([], null);
expect(mockGridApi.moveColumns).toHaveBeenCalledTimes(1);
});
});

View File

@@ -0,0 +1,247 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useCallback } from 'react';
import { styled, t } from '@superset-ui/core';
import type { Column, ColumnPinnedType, GridApi } from 'ag-grid-community';
import Icons from 'src/components/Icons';
import { Dropdown, DropdownProps } from 'src/components/Dropdown';
import { Menu } from 'src/components/Menu';
import copyTextToClipboard from 'src/utils/copy';
import { PIVOT_COL_ID } from './constants';
const IconMenuItem = styled(Menu.Item)`
display: flex;
align-items: center;
`;
const IconEmpty = styled.span`
width: 20px;
`;
type Params = {
colId: string;
column?: Column;
api: GridApi;
pinnedLeft?: boolean;
pinnedRight?: boolean;
invisibleColumns: Column[];
isMain?: boolean;
onVisibleChange: DropdownProps['onVisibleChange'];
};
const HeaderMenu: React.FC<Params> = ({
colId,
api,
pinnedLeft,
pinnedRight,
invisibleColumns,
isMain,
onVisibleChange,
}: Params) => {
const pinColumn = useCallback(
(pinLoc: ColumnPinnedType) => {
api.setColumnsPinned([colId], pinLoc);
},
[api, colId],
);
const unHideAction = invisibleColumns.length > 0 && (
<Menu.SubMenu
title={
<>
<Icons.EyeOutlined iconSize="m" />
{t('Unhide')}
</>
}
>
{invisibleColumns.length > 1 && (
<Menu.Item
onClick={() => {
api.setColumnsVisible(invisibleColumns, true);
}}
>
<b>{t('All %s hidden columns', invisibleColumns.length)}</b>
</Menu.Item>
)}
{invisibleColumns.map(c => (
<Menu.Item
key={c.getColId()}
onClick={() => {
api.setColumnsVisible([c.getColId()], true);
}}
>
{c.getColDef().headerName}
</Menu.Item>
))}
</Menu.SubMenu>
);
if (isMain) {
return (
<Dropdown
placement="bottomLeft"
trigger={['click']}
onVisibleChange={onVisibleChange}
overlay={
<Menu style={{ width: 250 }} mode="vertical">
<IconMenuItem
onClick={() => {
copyTextToClipboard(
() =>
new Promise((resolve, reject) => {
const data = api.getDataAsCsv({
columnKeys: api
.getAllDisplayedColumns()
.map(c => c.getColId())
.filter(id => id !== colId),
suppressQuotes: true,
columnSeparator: '\t',
});
if (data) {
resolve(data);
} else {
reject();
}
}),
);
}}
>
<Icons.CopyOutlined iconSize="m" /> {t('Copy the current data')}
</IconMenuItem>
<IconMenuItem
onClick={() => {
api.exportDataAsCsv({
columnKeys: api
.getAllDisplayedColumns()
.map(c => c.getColId())
.filter(id => id !== colId),
});
}}
>
<Icons.DownloadOutlined iconSize="m" /> {t('Download to CSV')}
</IconMenuItem>
<Menu.Divider />
<IconMenuItem
onClick={() => {
api.autoSizeAllColumns();
}}
>
<Icons.ColumnWidthOutlined iconSize="m" />
{t('Autosize all columns')}
</IconMenuItem>
{unHideAction}
<Menu.Divider />
<IconMenuItem
onClick={() => {
api.setColumnsVisible(invisibleColumns, true);
const columns = api.getColumns();
if (columns) {
const pinnedColumns = columns.filter(
c => c.getColId() !== PIVOT_COL_ID && c.isPinned(),
);
api.setColumnsPinned(pinnedColumns, null);
api.moveColumns(columns, 0);
const firstColumn = columns.find(
c => c.getColId() !== PIVOT_COL_ID,
);
if (firstColumn) {
api.ensureColumnVisible(firstColumn, 'start');
}
}
}}
>
<IconEmpty className="anticon" />
{t('Reset columns')}
</IconMenuItem>
</Menu>
}
/>
);
}
return (
<Dropdown
placement="bottomRight"
trigger={['click']}
onVisibleChange={onVisibleChange}
overlay={
<Menu style={{ width: 180 }} mode="vertical">
<IconMenuItem
onClick={() => {
copyTextToClipboard(
() =>
new Promise((resolve, reject) => {
const data = api.getDataAsCsv({
columnKeys: [colId],
suppressQuotes: true,
});
if (data) {
resolve(data);
} else {
reject();
}
}),
);
}}
>
<Icons.CopyOutlined iconSize="m" /> {t('Copy')}
</IconMenuItem>
{(pinnedLeft || pinnedRight) && (
<IconMenuItem onClick={() => pinColumn(null)}>
<Icons.UnlockOutlined iconSize="m" /> {t('Unpin')}
</IconMenuItem>
)}
{!pinnedLeft && (
<IconMenuItem onClick={() => pinColumn('left')}>
<Icons.VerticalRightOutlined iconSize="m" />
{t('Pin Left')}
</IconMenuItem>
)}
{!pinnedRight && (
<IconMenuItem onClick={() => pinColumn('right')}>
<Icons.VerticalLeftOutlined iconSize="m" />
{t('Pin Right')}
</IconMenuItem>
)}
<Menu.Divider />
<IconMenuItem
onClick={() => {
api.autoSizeColumns([colId]);
}}
>
<Icons.ColumnWidthOutlined iconSize="m" />
{t('Autosize Column')}
</IconMenuItem>
<IconMenuItem
onClick={() => {
api.setColumnsVisible([colId], false);
}}
disabled={api.getColumns()?.length === invisibleColumns.length + 1}
>
<Icons.EyeInvisibleOutlined iconSize="m" />
{t('Hide Column')}
</IconMenuItem>
{unHideAction}
</Menu>
}
/>
);
};
export default HeaderMenu;

View File

@@ -0,0 +1,24 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
export const PIVOT_COL_ID = '-1';
export enum GridSize {
Small = 'small',
Middle = 'middle',
}

View File

@@ -0,0 +1,241 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useCallback, useMemo } from 'react';
import { Global } from '@emotion/react';
import { css, useTheme } from '@superset-ui/core';
import type { Column } from 'ag-grid-community';
import { AgGridReact, type AgGridReactProps } from 'ag-grid-react';
import 'ag-grid-community/styles/ag-grid.css';
import 'ag-grid-community/styles/ag-theme-quartz.css';
import copyTextToClipboard from 'src/utils/copy';
import ErrorBoundary from 'src/components/ErrorBoundary';
import { PIVOT_COL_ID, GridSize } from './constants';
import Header from './Header';
const gridComponents = {
agColumnHeader: Header,
};
export { GridSize };
export type ColDef = {
type: string;
field: string;
};
export interface TableProps<RecordType> {
/**
* Data that will populate the each row and map to the column key.
*/
data: RecordType[];
/**
* Table column definitions.
*/
columns: {
label: string;
headerName?: string;
width?: number;
comparator?: (valueA: string | number, valueB: string | number) => number;
render?: (value: any) => React.ReactNode;
}[];
size?: GridSize;
externalFilter?: AgGridReactProps['doesExternalFilterPass'];
height: number;
columnReorderable?: boolean;
sortable?: boolean;
enableActions?: boolean;
showRowNumber?: boolean;
usePagination?: boolean;
striped?: boolean;
}
const onSortChanged: AgGridReactProps['onSortChanged'] = ({ api }) =>
api.refreshCells();
function GridTable<RecordType extends object>({
data,
columns,
sortable = true,
columnReorderable,
height,
externalFilter,
showRowNumber,
enableActions,
size = GridSize.Middle,
striped,
}: TableProps<RecordType>) {
const theme = useTheme();
const isExternalFilterPresent = useCallback(
() => Boolean(externalFilter),
[externalFilter],
);
const rowIndexLength = `${data.length}}`.length;
const onKeyDown: AgGridReactProps<Record<string, any>>['onCellKeyDown'] =
useCallback(({ event, column, data, value, api }) => {
if (
!document.getSelection?.()?.toString?.() &&
event &&
event.key === 'c' &&
(event.ctrlKey || event.metaKey)
) {
const columns =
column.getColId() === PIVOT_COL_ID
? api
.getAllDisplayedColumns()
.filter((column: Column) => column.getColId() !== PIVOT_COL_ID)
: [column];
const record =
column.getColId() === PIVOT_COL_ID
? [
columns.map((column: Column) => column.getColId()).join('\t'),
columns
.map((column: Column) => data?.[column.getColId()])
.join('\t'),
].join('\n')
: String(value);
copyTextToClipboard(() => Promise.resolve(record));
}
}, []);
const columnDefs = useMemo(
() =>
[
{
field: PIVOT_COL_ID,
valueGetter: 'node.rowIndex+1',
cellClass: 'locked-col',
width: 20 + rowIndexLength * 6,
suppressNavigable: true,
resizable: false,
pinned: 'left' as const,
sortable: false,
...(columnReorderable && { suppressMovable: true }),
},
...columns.map(
(
{ label, headerName, width, render: cellRenderer, comparator },
index,
) => ({
field: label,
headerName,
cellRenderer,
sortable,
comparator,
...(index === columns.length - 1 && {
flex: 1,
width,
minWidth: 150,
}),
}),
),
].slice(showRowNumber ? 0 : 1),
[rowIndexLength, columnReorderable, columns, showRowNumber, sortable],
);
const defaultColDef: AgGridReactProps['defaultColDef'] = {
...(!columnReorderable && { suppressMovable: true }),
resizable: true,
sortable,
filter: Boolean(enableActions),
};
const rowHeight = theme.gridUnit * (size === GridSize.Middle ? 9 : 7);
return (
<ErrorBoundary>
<Global
styles={() => css`
#grid-table.ag-theme-quartz {
--ag-icon-font-family: agGridMaterial;
--ag-grid-size: ${theme.gridUnit}px;
--ag-font-size: ${theme.typography.sizes[
size === GridSize.Middle ? 'm' : 's'
]}px;
--ag-font-family: ${theme.typography.families.sansSerif};
--ag-row-height: ${rowHeight}px;
${!striped &&
`--ag-odd-row-background-color: ${theme.colors.grayscale.light5};`}
--ag-border-color: ${theme.colors.grayscale.light2};
--ag-row-border-color: ${theme.colors.grayscale.light2};
--ag-header-background-color: ${theme.colors.grayscale.light4};
}
#grid-table .ag-cell {
-webkit-font-smoothing: antialiased;
}
.locked-col {
background: var(--ag-row-border-color);
padding: 0;
text-align: center;
font-size: calc(var(--ag-font-size) * 0.9);
color: var(--ag-disabled-foreground-color);
}
.ag-row-hover .locked-col {
background: var(--ag-row-hover-color);
}
.ag-header-cell {
overflow: hidden;
}
& [role='columnheader']:hover .customHeaderAction {
display: block;
}
`}
/>
<div
id="grid-table"
className="ag-theme-quartz"
css={css`
width: 100%;
height: ${height}px;
`}
>
<AgGridReact
rowData={data}
columnDefs={columnDefs}
defaultColDef={defaultColDef}
onSortChanged={onSortChanged}
isExternalFilterPresent={isExternalFilterPresent}
doesExternalFilterPass={externalFilter}
components={gridComponents}
gridOptions={{
enableCellTextSelection: true,
ensureDomOrder: true,
suppressFieldDotNotation: true,
headerHeight: rowHeight,
rowSelection: 'multiple',
rowHeight,
}}
onCellKeyDown={onKeyDown}
/>
</div>
</ErrorBoundary>
);
}
export default GridTable;

View File

@@ -27,17 +27,22 @@ import {
BarChartOutlined,
BellOutlined,
BookOutlined,
CaretDownOutlined,
CalendarOutlined,
CaretUpOutlined,
CheckOutlined,
CheckSquareOutlined,
CloseOutlined,
ColumnWidthOutlined,
CommentOutlined,
ConsoleSqlOutlined,
CopyOutlined,
DashboardOutlined,
DatabaseOutlined,
DeleteFilled,
DownSquareOutlined,
DownOutlined,
DownloadOutlined,
EditOutlined,
ExclamationCircleOutlined,
EyeOutlined,
@@ -65,8 +70,11 @@ import {
StopOutlined,
SyncOutlined,
TagsOutlined,
UnlockOutlined,
UpOutlined,
UserOutlined,
VerticalLeftOutlined,
VerticalRightOutlined,
} from '@ant-design/icons';
import { StyledIcon } from './Icon';
import IconType from './IconType';
@@ -80,17 +88,22 @@ const AntdIcons = {
BarChartOutlined,
BellOutlined,
BookOutlined,
CaretDownOutlined,
CalendarOutlined,
CaretUpOutlined,
CheckOutlined,
CheckSquareOutlined,
CloseOutlined,
ColumnWidthOutlined,
CommentOutlined,
ConsoleSqlOutlined,
CopyOutlined,
DashboardOutlined,
DatabaseOutlined,
DeleteFilled,
DownSquareOutlined,
DownOutlined,
DownloadOutlined,
EditOutlined,
ExclamationCircleOutlined,
EyeOutlined,
@@ -118,8 +131,11 @@ const AntdIcons = {
StopOutlined,
SyncOutlined,
TagsOutlined,
UnlockOutlined,
UpOutlined,
UserOutlined,
VerticalLeftOutlined,
VerticalRightOutlined,
};
const AntdEnhancedIcons = Object.keys(AntdIcons)

View File

@@ -16,40 +16,139 @@
* specific language governing permissions and limitations
* under the License.
*/
import { useArgs } from '@storybook/preview-api';
import { Radio } from './index';
import { Space } from 'src/components/Space';
import {
BarChartOutlined,
DotChartOutlined,
LineChartOutlined,
PieChartOutlined,
} from '@ant-design/icons';
import { Radio, RadioProps, RadioGroupWrapperProps } from './index';
export default {
title: 'Radio',
component: Radio,
parameters: {
controls: { hideNoControlsWarning: true },
tags: ['autodocs'],
};
const RadioArgsType = {
value: {
control: 'text',
description: 'The value of the radio button.',
},
argTypes: {
theme: {
table: {
disable: true,
},
},
checked: { control: 'boolean' },
disabled: { control: 'boolean' },
disabled: {
control: 'boolean',
description: 'Whether the radio button is disabled or not.',
},
checked: {
control: 'boolean',
description: 'The checked state of the radio button.',
},
};
export const SupersetRadio = () => {
const [{ checked, ...rest }, updateArgs] = useArgs();
return (
<Radio
checked={checked}
onChange={() => updateArgs({ checked: !checked })}
{...rest}
>
Example
</Radio>
);
const radioGroupWrapperArgsType = {
onChange: { action: 'changed' },
disabled: { control: 'boolean' },
size: {
control: 'select',
options: ['small', 'middle', 'large'],
},
options: { control: 'object' },
'spaceConfig.direction': {
control: 'select',
options: ['horizontal', 'vertical'],
description: 'Direction of the Space layout',
if: { arg: 'enableSpaceConfig', truthy: true },
},
'spaceConfig.size': {
control: 'select',
options: ['small', 'middle', 'large'],
description: 'Layout size Space',
if: { arg: 'enableSpaceConfig', truthy: true },
},
'spaceConfig.align': {
control: 'select',
options: ['start', 'center', 'end'],
description: 'Alignment of the Space layout',
if: { arg: 'enableSpaceConfig', truthy: true },
},
'spaceConfig.wrap': {
control: 'boolean',
description:
'Controls whether the items inside the Space component should wrap to the next line when space is insufficient',
if: { arg: 'enableSpaceConfig', truthy: true },
},
};
SupersetRadio.args = {
export const RadioStory = {
args: {
value: 'radio1',
disabled: false,
checked: false,
children: 'Radio',
},
argTypes: RadioArgsType,
};
export const RadioButtonStory = (args: RadioProps) => (
<Radio.Button {...args}>Radio Button</Radio.Button>
);
RadioButtonStory.args = {
value: 'button1',
disabled: false,
checked: false,
};
RadioButtonStory.argTypes = RadioArgsType;
export const RadioGroupWithOptionsStory = (args: RadioGroupWrapperProps) => (
<Radio.GroupWrapper {...args} />
);
RadioGroupWithOptionsStory.args = {
spaceConfig: {
direction: 'vertical',
size: 'middle',
align: 'center',
wrap: false,
},
size: 'middle',
options: [
{
value: 1,
label: (
<Space align="center" direction="vertical">
<LineChartOutlined style={{ fontSize: 18 }} />
LineChart
</Space>
),
},
{
value: 2,
label: (
<Space align="center" direction="vertical">
<DotChartOutlined style={{ fontSize: 18 }} />
DotChart
</Space>
),
},
{
value: 3,
label: (
<Space align="center" direction="vertical">
<BarChartOutlined style={{ fontSize: 18 }} />
BarChart
</Space>
),
},
{
value: 4,
label: (
<Space align="center" direction="vertical">
<PieChartOutlined style={{ fontSize: 18 }} />
PieChart
</Space>
),
},
],
disabled: false,
};
RadioGroupWithOptionsStory.argTypes = radioGroupWrapperArgsType;

View File

@@ -16,46 +16,48 @@
* specific language governing permissions and limitations
* under the License.
*/
import { styled } from '@superset-ui/core';
import { Radio as AntdRadio } from 'antd';
import { Radio as Antd5Radio, CheckboxOptionType } from 'antd-v5';
import type {
RadioChangeEvent,
RadioProps,
RadioGroupProps,
} from 'antd-v5/lib/radio';
const StyledRadio = styled(AntdRadio)`
.ant-radio-inner {
top: -1px;
left: 2px;
width: ${({ theme }) => theme.gridUnit * 4}px;
height: ${({ theme }) => theme.gridUnit * 4}px;
border-width: 2px;
border-color: ${({ theme }) => theme.colors.grayscale.light2};
}
import { Space, SpaceProps } from 'src/components/Space';
.ant-radio.ant-radio-checked {
.ant-radio-inner {
border-width: ${({ theme }) => theme.gridUnit + 1}px;
border-color: ${({ theme }) => theme.colors.primary.base};
}
export type RadioGroupWrapperProps = RadioGroupProps & {
spaceConfig?: {
direction?: SpaceProps['direction'];
size?: SpaceProps['size'];
align?: SpaceProps['align'];
wrap?: SpaceProps['wrap'];
};
options: CheckboxOptionType[];
};
.ant-radio-inner::after {
background-color: ${({ theme }) => theme.colors.grayscale.light5};
top: 0;
left: 0;
width: ${({ theme }) => theme.gridUnit + 2}px;
height: ${({ theme }) => theme.gridUnit + 2}px;
}
}
.ant-radio:hover,
.ant-radio:focus {
.ant-radio-inner {
border-color: ${({ theme }) => theme.colors.primary.dark1};
}
}
`;
const StyledGroup = styled(AntdRadio.Group)`
font-size: inherit;
`;
export const Radio = Object.assign(StyledRadio, {
Group: StyledGroup,
Button: AntdRadio.Button,
const RadioGroup = ({
spaceConfig,
options,
...props
}: RadioGroupWrapperProps) => {
const content = options.map((option: CheckboxOptionType) => (
<Radio key={option.value} value={option.value}>
{option.label}
</Radio>
));
return (
<Radio.Group {...props}>
{spaceConfig ? <Space {...spaceConfig}>{content}</Space> : content}
</Radio.Group>
);
};
export type {
RadioChangeEvent,
RadioGroupProps,
RadioProps,
CheckboxOptionType,
};
export const Radio = Object.assign(Antd5Radio, {
GroupWrapper: RadioGroup,
Button: Antd5Radio.Button,
});

View File

@@ -19,7 +19,6 @@
import { useState } from 'react';
import { css, useTheme } from '@superset-ui/core';
import { Radio } from 'src/components/Radio';
import { Space } from 'src/components/Space';
import Icons from 'src/components/Icons';
import Popover from 'src/components/Popover';
@@ -56,21 +55,20 @@ function HeaderWithRadioGroup(props: HeaderWithRadioGroupProps) {
>
{groupTitle}
</div>
<Radio.Group
<Radio.GroupWrapper
spaceConfig={{
direction: 'vertical',
size: 4,
wrap: false,
align: 'start',
}}
value={value}
onChange={e => {
onChange(e.target.value);
setPopoverVisible(false);
}}
>
<Space direction="vertical">
{groupOptions.map(option => (
<Radio key={option.value} value={option.value}>
{option.label}
</Radio>
))}
</Space>
</Radio.Group>
options={groupOptions}
/>
</div>
}
placement="bottomLeft"

View File

@@ -1106,15 +1106,16 @@ const FiltersConfigForm = (
initialValue={sort}
label={<StyledLabel>{t('Sort type')}</StyledLabel>}
>
<Radio.Group
<Radio.GroupWrapper
options={[
{ value: true, label: t('Sort ascending') },
{ value: false, label: t('Sort descending') },
]}
onChange={value => {
onSortChanged(value.target.value);
formChanged();
}}
>
<Radio value>{t('Sort ascending')}</Radio>
<Radio value={false}>{t('Sort descending')}</Radio>
</Radio.Group>
/>
</StyledRowFormItem>
{hasMetrics && (
<StyledRowSubFormItem
@@ -1181,22 +1182,23 @@ const FiltersConfigForm = (
<StyledLabel>{t('Single value type')}</StyledLabel>
}
>
<Radio.Group
<Radio.GroupWrapper
onChange={value => {
onEnableSingleValueChanged(value.target.value);
formChanged();
}}
>
<Radio value={SingleValueType.Minimum}>
{t('Minimum')}
</Radio>
<Radio value={SingleValueType.Exact}>
{t('Exact')}
</Radio>
<Radio value={SingleValueType.Maximum}>
{t('Maximum')}
</Radio>
</Radio.Group>
options={[
{
label: t('Minimum'),
value: SingleValueType.Minimum,
},
{ label: t('Exact'), value: SingleValueType.Exact },
{
label: t('Maximum'),
value: SingleValueType.Maximum,
},
]}
/>
</StyledRowFormItem>
</CollapsibleControl>
</CleanFormItem>

View File

@@ -30,7 +30,6 @@ import {
import { Global } from '@emotion/react';
import { Column } from 'react-table';
import { debounce } from 'lodash';
import { Space } from 'src/components/Space';
import { Input } from 'src/components/Input';
import {
BOOL_FALSE_DISPLAY,
@@ -141,12 +140,21 @@ const FormatPicker = ({
onChange: any;
value: FormatPickerValue;
}) => (
<Radio.Group value={value} onChange={onChange}>
<Space direction="vertical">
<Radio value={FormatPickerValue.Formatted}>{t('Formatted date')}</Radio>
<Radio value={FormatPickerValue.Original}>{t('Original value')}</Radio>
</Space>
</Radio.Group>
<Radio.GroupWrapper
spaceConfig={{
direction: 'vertical',
align: 'start',
size: 15,
wrap: false,
}}
size="large"
value={value}
onChange={onChange}
options={[
{ label: t('Formatted date'), value: FormatPickerValue.Formatted },
{ label: t('Original value'), value: FormatPickerValue.Original },
]}
/>
);
const FormatPickerContainer = styled.div`

View File

@@ -167,7 +167,7 @@ export const ExploreChartHeader = ({
<>
<PageHeaderWithActions
editableTitleProps={{
title: sliceName,
title: sliceName ?? '',
canEdit:
!slice ||
canOverwrite ||

View File

@@ -87,12 +87,6 @@ const ContentStyleWrapper = styled.div`
margin: 8px 0;
}
.vertical-radio {
display: block;
height: 40px;
line-height: 40px;
}
.section-title {
font-style: normal;
font-weight: ${theme.typography.weights.bold};

View File

@@ -45,16 +45,18 @@ export function CalendarFrame({ onChange, value }: FrameComponentProps) {
<div className="section-title">
{t('Configure Time Range: Previous...')}
</div>
<Radio.Group
<Radio.GroupWrapper
spaceConfig={{
direction: 'vertical',
size: 15,
align: 'start',
wrap: false,
}}
size="large"
value={value}
onChange={(e: any) => onChange(e.target.value)}
>
{CALENDAR_RANGE_OPTIONS.map(({ value, label }) => (
<Radio key={value} value={value} className="vertical-radio">
{label}
</Radio>
))}
</Radio.Group>
options={CALENDAR_RANGE_OPTIONS}
/>
</>
);
}

View File

@@ -41,16 +41,18 @@ export function CommonFrame(props: FrameComponentProps) {
<div className="section-title" data-test={DateFilterTestKey.CommonFrame}>
{t('Configure Time Range: Last...')}
</div>
<Radio.Group
<Radio.GroupWrapper
spaceConfig={{
direction: 'vertical',
size: 15,
align: 'start',
wrap: false,
}}
size="large"
value={commonRange}
onChange={(e: any) => props.onChange(e.target.value)}
>
{COMMON_RANGE_OPTIONS.map(({ value, label }) => (
<Radio key={value} value={value} className="vertical-radio">
{label}
</Radio>
))}
</Radio.Group>
options={COMMON_RANGE_OPTIONS}
/>
</>
);
}

View File

@@ -41,25 +41,22 @@ export function CurrentCalendarFrame({ onChange, value }: FrameComponentProps) {
<div className="section-title">
{t('Configure Time Range: Current...')}
</div>
<Radio.Group
value={value}
<Radio.GroupWrapper
spaceConfig={{
direction: 'vertical',
size: 15,
align: 'start',
wrap: true,
}}
size="large"
onChange={(e: any) => {
let newValue = e.target.value;
// Sanitization: Trim whitespace
newValue = newValue.trim();
// Validation: Check if the value is non-empty
if (newValue === '') {
return;
}
if (newValue === '') return;
onChange(newValue);
}}
>
{CURRENT_RANGE_OPTIONS.map(({ value, label }) => (
<Radio key={value} value={value} className="vertical-radio">
{label}
</Radio>
))}
</Radio.Group>
options={CURRENT_RANGE_OPTIONS}
/>
</>
);
}

View File

@@ -238,18 +238,15 @@ export function CustomFrame(props: FrameComponentProps) {
<div className="control-label">{t('Anchor to')}</div>
<Row align="middle">
<Col>
<Radio.Group
<Radio.GroupWrapper
options={[
{ value: 'now', label: t('Now') },
{ value: 'specific', label: t('Date/Time') },
]}
onChange={onAnchorMode}
defaultValue="now"
value={anchorMode}
>
<Radio key="now" value="now">
{t('NOW')}
</Radio>
<Radio key="specific" value="specific">
{t('Date/Time')}
</Radio>
</Radio.Group>
/>
</Col>
{anchorMode !== 'now' && (
<Col>

View File

@@ -42,7 +42,7 @@ describe('CalendarFrame', () => {
const radios = screen.getAllByRole('radio');
expect(radios).toHaveLength(CALENDAR_RANGE_OPTIONS.length);
CALENDAR_RANGE_OPTIONS.forEach(option => {
expect(screen.getByText(option.label)).toBeInTheDocument();
expect(screen.getByText(option.label as string)).toBeInTheDocument();
});
});
@@ -56,7 +56,7 @@ describe('CalendarFrame', () => {
);
const secondOption = CALENDAR_RANGE_OPTIONS[1];
const radio = screen.getByLabelText(secondOption.label);
const radio = screen.getByLabelText(secondOption.label as string);
fireEvent.click(radio);
expect(mockOnChange).toHaveBeenCalledWith(secondOption.value);
@@ -85,6 +85,8 @@ describe('CalendarFrame', () => {
const thirdOption = CALENDAR_RANGE_OPTIONS[2];
expect(thirdOption.value).toBe(PreviousCalendarQuarter);
expect(screen.getByLabelText(thirdOption.label)).toBeInTheDocument();
expect(
screen.getByLabelText(thirdOption.label as string),
).toBeInTheDocument();
});
});

View File

@@ -167,8 +167,8 @@ test('renders anchor with now option', async () => {
);
await waitForElementToBeRemoved(() => screen.queryByLabelText('Loading'));
expect(screen.getByText('Anchor to')).toBeInTheDocument();
expect(screen.getByRole('radio', { name: 'NOW' })).toBeInTheDocument();
expect(screen.getByRole('radio', { name: 'Date/Time' })).toBeInTheDocument();
expect(screen.getByLabelText('Now')).toBeInTheDocument();
expect(screen.getByLabelText('Date/Time')).toBeInTheDocument();
expect(screen.queryByPlaceholderText('Select date')).not.toBeInTheDocument();
});
@@ -180,8 +180,8 @@ test('renders anchor with date/time option', async () => {
);
await waitForElementToBeRemoved(() => screen.queryByLabelText('Loading'));
expect(screen.getByText('Anchor to')).toBeInTheDocument();
expect(screen.getByRole('radio', { name: 'NOW' })).toBeInTheDocument();
expect(screen.getByRole('radio', { name: 'Date/Time' })).toBeInTheDocument();
expect(screen.getByLabelText('Now')).toBeInTheDocument();
expect(screen.getByLabelText('Date/Time')).toBeInTheDocument();
expect(screen.getByPlaceholderText('Select date')).toBeInTheDocument();
});

View File

@@ -32,6 +32,7 @@ import {
CurrentQuarter,
CurrentDay,
} from 'src/explore/components/controls/DateFilterControl/types';
import { CheckboxOptionType } from 'src/components/Radio';
import { extendedDayjs } from 'src/utils/dates';
export const FRAME_OPTIONS: SelectOptionType[] = [
@@ -43,7 +44,7 @@ export const FRAME_OPTIONS: SelectOptionType[] = [
{ value: 'No filter', label: t('No filter') },
];
export const COMMON_RANGE_OPTIONS: SelectOptionType[] = [
export const COMMON_RANGE_OPTIONS: CheckboxOptionType[] = [
{ value: 'Last day', label: t('Last day') },
{ value: 'Last week', label: t('Last week') },
{ value: 'Last month', label: t('Last month') },
@@ -51,20 +52,20 @@ export const COMMON_RANGE_OPTIONS: SelectOptionType[] = [
{ value: 'Last year', label: t('Last year') },
];
export const COMMON_RANGE_VALUES_SET = new Set(
COMMON_RANGE_OPTIONS.map(({ value }) => value),
COMMON_RANGE_OPTIONS.map(value => value.value),
);
export const CALENDAR_RANGE_OPTIONS: SelectOptionType[] = [
export const CALENDAR_RANGE_OPTIONS: CheckboxOptionType[] = [
{ value: PreviousCalendarWeek, label: t('previous calendar week') },
{ value: PreviousCalendarMonth, label: t('previous calendar month') },
{ value: PreviousCalendarQuarter, label: t('previous calendar quarter') },
{ value: PreviousCalendarYear, label: t('previous calendar year') },
];
export const CALENDAR_RANGE_VALUES_SET = new Set(
CALENDAR_RANGE_OPTIONS.map(({ value }) => value),
CALENDAR_RANGE_OPTIONS.map(value => value.value),
);
export const CURRENT_RANGE_OPTIONS: SelectOptionType[] = [
export const CURRENT_RANGE_OPTIONS: CheckboxOptionType[] = [
{ value: CurrentDay, label: t('Current day') },
{ value: CurrentWeek, label: t('Current week') },
{ value: CurrentMonth, label: t('Current month') },
@@ -72,7 +73,7 @@ export const CURRENT_RANGE_OPTIONS: SelectOptionType[] = [
{ value: CurrentYear, label: t('Current year') },
];
export const CURRENT_RANGE_VALUES_SET = new Set(
CURRENT_RANGE_OPTIONS.map(({ value }) => value),
CURRENT_RANGE_OPTIONS.map(value => value.value),
);
const GRAIN_OPTIONS = [

View File

@@ -41,7 +41,7 @@ import TimezoneSelector from 'src/components/TimezoneSelector';
import LabeledErrorBoundInput from 'src/components/Form/LabeledErrorBoundInput';
import Icons from 'src/components/Icons';
import { CronError } from 'src/components/CronPicker';
import { RadioChangeEvent } from 'src/components';
import { Radio, RadioChangeEvent } from 'src/components/Radio';
import { Input } from 'src/components/Input';
import withToasts from 'src/components/MessageToasts/withToasts';
import { ChartState } from 'src/explore/types';
@@ -68,8 +68,6 @@ import {
TimezoneHeaderStyle,
SectionHeaderStyle,
StyledMessageContentTitle,
StyledRadio,
StyledRadioGroup,
} from './styles';
interface ReportProps {
@@ -257,24 +255,32 @@ function ReportModal({
<h4>{t('Message content')}</h4>
</StyledMessageContentTitle>
<div className="inline-container">
<StyledRadioGroup
<Radio.GroupWrapper
spaceConfig={{
direction: 'vertical',
size: 'middle',
align: 'start',
wrap: false,
}}
onChange={(event: RadioChangeEvent) => {
setCurrentReport({ report_format: event.target.value });
}}
value={currentReport.report_format || defaultNotificationFormat}
>
{isTextBasedChart && (
<StyledRadio value={NotificationFormats.Text}>
{t('Text embedded in email')}
</StyledRadio>
)}
<StyledRadio value={NotificationFormats.PNG}>
{t('Image (PNG) embedded in email')}
</StyledRadio>
<StyledRadio value={NotificationFormats.CSV}>
{t('Formatted CSV attached in email')}
</StyledRadio>
</StyledRadioGroup>
options={[
{
label: t('Text embedded in email'),
value: NotificationFormats.Text,
},
{
label: t('Image (PNG) embedded in email'),
value: NotificationFormats.PNG,
},
{
label: t('Formatted CSV attached in email'),
value: NotificationFormats.CSV,
},
]}
/>
</div>
</>
);

View File

@@ -108,10 +108,6 @@ export const StyledRadio = styled(Radio)`
line-height: ${({ theme }) => theme.gridUnit * 8}px;
`;
export const StyledRadioGroup = styled(Radio.Group)`
margin-left: ${({ theme }) => theme.gridUnit * 0.5}px;
`;
export const antDErrorAlertStyles = (theme: SupersetTheme) => css`
margin: ${theme.gridUnit * 4}px;
margin-top: 0;

View File

@@ -33,9 +33,11 @@ export type InitialState = {
id: number;
table: string;
description: {
columns?: {
name: string;
columns: {
name: string;
type: string;
longType: string;
}[];
dataPreviewQueryId?: string;
} & Record<string, any>;

View File

@@ -66,10 +66,12 @@ export type FetchTableMetadataQueryParams = {
};
type ColumnKeyTypeType = 'pk' | 'fk' | 'index';
interface Column {
export interface Column {
name: string;
keys?: { type: ColumnKeyTypeType }[];
type: string;
comment?: string;
longType: string;
}
export type TableMetaData = {
@@ -83,6 +85,7 @@ export type TableMetaData = {
selectStar?: string;
view?: string;
columns: Column[];
comment?: string;
};
type TableMetadataResponse = {
@@ -143,6 +146,9 @@ const tableApi = api.injectEndpoints({
)}`,
transformResponse: ({ json }: JsonResponse) => json,
}),
providesTags: (result, error, { table }) => [
{ type: 'TableMetadatas', id: table },
],
}),
}),
});
@@ -150,6 +156,8 @@ const tableApi = api.injectEndpoints({
export const {
useLazyTablesQuery,
useTablesQuery,
useLazyTableMetadataQuery,
useLazyTableExtendedMetadataQuery,
useTableMetadataQuery,
useTableExtendedMetadataQuery,
endpoints: tableEndpoints,

View File

@@ -30,7 +30,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import app, is_feature_enabled, thumbnail_cache
from superset import app, is_feature_enabled
from superset.charts.filters import (
ChartAllTextFilter,
ChartCertifiedFilter,
@@ -84,7 +84,12 @@ from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.tasks.utils import get_current_user
from superset.utils import json
from superset.utils.screenshots import ChartScreenshot, DEFAULT_CHART_WINDOW_SIZE
from superset.utils.screenshots import (
ChartScreenshot,
DEFAULT_CHART_WINDOW_SIZE,
ScreenshotCachePayload,
StatusValues,
)
from superset.utils.urls import get_url_path
from superset.views.base_api import (
BaseSupersetModelRestApi,
@@ -564,8 +569,14 @@ class ChartRestApi(BaseSupersetModelRestApi):
schema:
$ref: '#/components/schemas/screenshot_query_schema'
responses:
200:
description: Chart async result
content:
application/json:
schema:
$ref: "#/components/schemas/ChartCacheScreenshotResponseSchema"
202:
description: Chart async result
description: Chart screenshot task created
content:
application/json:
schema:
@@ -580,6 +591,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500'
"""
rison_dict = kwargs["rison"]
force = rison_dict.get("force")
window_size = rison_dict.get("window_size") or DEFAULT_CHART_WINDOW_SIZE
# Don't shrink the image if thumb_size is not specified
@@ -591,25 +603,36 @@ class ChartRestApi(BaseSupersetModelRestApi):
chart_url = get_url_path("Superset.slice", slice_id=chart.id)
screenshot_obj = ChartScreenshot(chart_url, chart.digest)
cache_key = screenshot_obj.cache_key(window_size, thumb_size)
cache_key = screenshot_obj.get_cache_key(window_size, thumb_size)
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
image_url = get_url_path(
"ChartRestApi.screenshot", pk=chart.id, digest=cache_key
)
def trigger_celery() -> WerkzeugResponse:
def build_response(status_code: int) -> WerkzeugResponse:
return self.response(
status_code,
cache_key=cache_key,
chart_url=chart_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
if cache_payload.should_trigger_task(force):
logger.info("Triggering screenshot ASYNC")
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_chart_thumbnail.delay(
current_user=get_current_user(),
chart_id=chart.id,
force=True,
window_size=window_size,
thumb_size=thumb_size,
force=force,
)
return self.response(
202, cache_key=cache_key, chart_url=chart_url, image_url=image_url
)
return trigger_celery()
return build_response(202)
return build_response(200)
@expose("/<pk>/screenshot/<digest>/", methods=("GET",))
@protect()
@@ -635,7 +658,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
name: digest
responses:
200:
description: Chart thumbnail image
description: Chart screenshot image
content:
image/*:
schema:
@@ -652,16 +675,16 @@ class ChartRestApi(BaseSupersetModelRestApi):
"""
chart = self.datamodel.get(pk, self._base_filters)
# Making sure the chart still exists
if not chart:
return self.response_404()
# fetch the chart screenshot using the current user and cache if set
if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest):
return Response(
FileWrapper(img), mimetype="image/png", direct_passthrough=True
)
# TODO: return an empty image
if cache_payload := ChartScreenshot.get_from_cache_key(digest):
if cache_payload.status == StatusValues.UPDATED:
return Response(
FileWrapper(cache_payload.get_image()),
mimetype="image/png",
direct_passthrough=True,
)
return self.response_404()
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@@ -685,9 +708,10 @@ class ChartRestApi(BaseSupersetModelRestApi):
type: integer
name: pk
- in: path
name: digest
description: A hex digest that makes this chart unique
schema:
type: string
name: digest
responses:
200:
description: Chart thumbnail image
@@ -712,34 +736,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_404()
current_user = get_current_user()
url = get_url_path("Superset.slice", slice_id=chart.id)
if kwargs["rison"].get("force", False):
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=True,
)
return self.response(202, message="OK Async")
# fetch the chart screenshot using the current user and cache if set
screenshot = ChartScreenshot(url, chart.digest).get_from_cache(
cache=thumbnail_cache
)
# If not screenshot then send request to compute thumb to celery
if not screenshot:
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=True,
)
return self.response(202, message="OK Async")
# If digests
if chart.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__)
return redirect(
@@ -747,9 +743,34 @@ class ChartRestApi(BaseSupersetModelRestApi):
f"{self.__class__.__name__}.thumbnail", pk=pk, digest=chart.digest
)
)
url = get_url_path("Superset.slice", slice_id=chart.id)
screenshot_obj = ChartScreenshot(url, chart.digest)
cache_key = screenshot_obj.get_cache_key()
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
if cache_payload.should_trigger_task():
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=False,
)
return self.response(
202,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
self.incr_stats("from_cache", self.thumbnail.__name__)
return Response(
FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
FileWrapper(cache_payload.get_image()),
mimetype="image/png",
direct_passthrough=True,
)
@expose("/export/", methods=("GET",))

View File

@@ -304,6 +304,21 @@ class ChartCacheScreenshotResponseSchema(Schema):
image_url = fields.String(
metadata={"description": "The url to fetch the screenshot"}
)
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
class ChartGetCachedScreenshotResponseSchema(Schema):
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
class ChartDataColumnSchema(Schema):

View File

@@ -60,6 +60,9 @@ class TablesDatabaseCommand(BaseCommand):
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
# get_all_table_names_in_schema may return raw (unserialized) cached
# results, so we wrap them as DatasourceName objects here instead of
# directly in the method to ensure consistency.
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_table_names_in_schema(
catalog=self._catalog_name,
@@ -76,6 +79,9 @@ class TablesDatabaseCommand(BaseCommand):
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
# get_all_view_names_in_schema may return raw (unserialized) cached
# results, so we wrap them as DatasourceName objects here instead of
# directly in the method to ensure consistency.
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_view_names_in_schema(
catalog=self._catalog_name,

View File

@@ -300,13 +300,16 @@ class BaseReportState:
)
user = security_manager.find_user(username)
max_width = app.config["ALERT_REPORTS_MAX_CUSTOM_SCREENSHOT_WIDTH"]
if self._report_schedule.chart:
url = self._get_url()
window_width, window_height = app.config["WEBDRIVER_WINDOW"]["slice"]
window_size = (
self._report_schedule.custom_width or window_width,
self._report_schedule.custom_height or window_height,
)
width = min(max_width, self._report_schedule.custom_width or window_width)
height = self._report_schedule.custom_height or window_height
window_size = (width, height)
screenshots: list[Union[ChartScreenshot, DashboardScreenshot]] = [
ChartScreenshot(
url,
@@ -317,11 +320,12 @@ class BaseReportState:
]
else:
urls = self.get_dashboard_urls()
window_width, window_height = app.config["WEBDRIVER_WINDOW"]["dashboard"]
window_size = (
self._report_schedule.custom_width or window_width,
self._report_schedule.custom_height or window_height,
)
width = min(max_width, self._report_schedule.custom_width or window_width)
height = self._report_schedule.custom_height or window_height
window_size = (width, height)
screenshots = [
DashboardScreenshot(
url,
@@ -578,9 +582,9 @@ class BaseReportState:
SupersetError(
message=ex.message,
error_type=SupersetErrorType.REPORT_NOTIFICATION_ERROR,
level=ErrorLevel.ERROR
if ex.status >= 500
else ErrorLevel.WARNING,
level=(
ErrorLevel.ERROR if ex.status >= 500 else ErrorLevel.WARNING
),
)
)
if notification_errors:

View File

@@ -729,8 +729,10 @@ THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str | None] |
THUMBNAIL_CACHE_CONFIG: CacheConfig = {
"CACHE_TYPE": "NullCache",
"CACHE_DEFAULT_TIMEOUT": int(timedelta(days=7).total_seconds()),
"CACHE_NO_NULL_WARNING": True,
}
THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())
# Time before selenium times out after trying to locate an element on the page and wait
# for that element to load for a screenshot.
@@ -1905,6 +1907,15 @@ class ExtraDynamicQueryFilters(TypedDict, total=False):
EXTRA_DYNAMIC_QUERY_FILTERS: ExtraDynamicQueryFilters = {}
# The migrations that add catalog permissions might take a considerably long time
# to execute as it has to create permissions to all schemas and catalogs from all
# other catalogs accessible by the credentials. This flag allows to skip the
# creation of these secondary perms, and focus only on permissions for the default
# catalog. These secondary permissions can be created later by editing the DB
# connection via the UI (without downtime).
CATALOGS_SIMPLIFIED_MIGRATION: bool = False
# -------------------------------------------------------------------
# * WARNING: STOP EDITING HERE *
# -------------------------------------------------------------------

View File

@@ -31,7 +31,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import db, thumbnail_cache
from superset import db
from superset.charts.schemas import ChartEntityResponseSchema
from superset.commands.dashboard.copy import CopyDashboardCommand
from superset.commands.dashboard.create import CreateDashboardCommand
@@ -115,6 +115,7 @@ from superset.utils.pdf import build_pdf_from_screenshots
from superset.utils.screenshots import (
DashboardScreenshot,
DEFAULT_DASHBOARD_WINDOW_SIZE,
ScreenshotCachePayload,
)
from superset.utils.urls import get_url_path
from superset.views.base_api import (
@@ -1022,110 +1023,6 @@ class DashboardRestApi(BaseSupersetModelRestApi):
response.set_cookie(token, "done", max_age=600)
return response
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS"])
@protect()
@safe
@rison(thumbnail_query_schema)
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail",
log_to_statsd=False,
)
def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse:
"""Compute async or get already computed dashboard thumbnail from cache.
---
get:
summary: Get dashboard's thumbnail
description: >-
Computes async or get already computed dashboard thumbnail from cache.
parameters:
- in: path
schema:
type: integer
name: pk
- in: path
name: digest
description: A hex digest that makes this dashboard unique
schema:
type: string
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/thumbnail_query_schema'
responses:
200:
description: Dashboard thumbnail image
content:
image/*:
schema:
type: string
format: binary
202:
description: Thumbnail does not exist on cache, fired async to compute
content:
application/json:
schema:
type: object
properties:
message:
type: string
302:
description: Redirects to the current digest
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard:
return self.response_404()
dashboard_url = get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
# If force, request a screenshot from the workers
current_user = get_current_user()
if kwargs["rison"].get("force", False):
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=True,
)
return self.response(202, message="OK Async")
# fetch the dashboard screenshot using the current user and cache if set
screenshot = DashboardScreenshot(
dashboard_url, dashboard.digest
).get_from_cache(cache=thumbnail_cache)
# If the screenshot does not exist, request one from the workers
if not screenshot:
self.incr_stats("async", self.thumbnail.__name__)
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=True,
)
return self.response(202, message="OK Async")
# If digests
if dashboard.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__)
return redirect(
url_for(
f"{self.__class__.__name__}.thumbnail",
pk=pk,
digest=dashboard.digest,
)
)
self.incr_stats("from_cache", self.thumbnail.__name__)
return Response(
FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
)
@expose("/<pk>/cache_dashboard_screenshot/", methods=("POST",))
@validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"])
@protect()
@@ -1172,7 +1069,6 @@ class DashboardRestApi(BaseSupersetModelRestApi):
payload = CacheScreenshotSchema().load(request.json)
except ValidationError as error:
return self.response_400(message=error.messages)
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard:
return self.response_404()
@@ -1182,7 +1078,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
)
# Don't shrink the image if thumb_size is not specified
thumb_size = kwargs["rison"].get("thumb_size") or window_size
force = kwargs["rison"].get("force", False)
dashboard_state: DashboardPermalinkState = {
"dataMask": payload.get("dataMask", {}),
"activeTabs": payload.get("activeTabs", []),
@@ -1197,13 +1093,29 @@ class DashboardRestApi(BaseSupersetModelRestApi):
dashboard_url = get_url_path("Superset.dashboard_permalink", key=permalink_key)
screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest)
cache_key = screenshot_obj.cache_key(window_size, thumb_size, dashboard_state)
cache_key = screenshot_obj.get_cache_key(
window_size, thumb_size, dashboard_state
)
image_url = get_url_path(
"DashboardRestApi.screenshot", pk=dashboard.id, digest=cache_key
)
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
def trigger_celery() -> WerkzeugResponse:
def build_response(status_code: int) -> WerkzeugResponse:
return self.response(
status_code,
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
if cache_payload.should_trigger_task(force):
logger.info("Triggering screenshot ASYNC")
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_dashboard_screenshot.delay(
username=get_current_user(),
guest_token=(
@@ -1213,19 +1125,12 @@ class DashboardRestApi(BaseSupersetModelRestApi):
),
dashboard_id=dashboard.id,
dashboard_url=dashboard_url,
cache_key=cache_key,
force=False,
thumb_size=thumb_size,
window_size=window_size,
force=force,
)
return self.response(
202,
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
)
return trigger_celery()
return build_response(202)
return build_response(200)
@expose("/<pk>/screenshot/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"])
@@ -1282,9 +1187,12 @@ class DashboardRestApi(BaseSupersetModelRestApi):
# fetch the dashboard screenshot using the current user and cache if set
if img := DashboardScreenshot.get_from_cache_key(thumbnail_cache, digest):
if cache_payload := DashboardScreenshot.get_from_cache_key(digest):
image = cache_payload.get_image()
if not image:
return self.response_404()
if download_format == "pdf":
pdf_img = img.getvalue()
pdf_img = image.getvalue()
# Convert the screenshot to PDF
pdf_data = build_pdf_from_screenshots([pdf_img])
@@ -1296,13 +1204,120 @@ class DashboardRestApi(BaseSupersetModelRestApi):
)
if download_format == "png":
return Response(
FileWrapper(img),
FileWrapper(image),
mimetype="image/png",
direct_passthrough=True,
)
return self.response_404()
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS"])
@protect()
@safe
@rison(thumbnail_query_schema)
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail",
log_to_statsd=False,
)
def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse:
"""Compute async or get already computed dashboard thumbnail from cache.
---
get:
summary: Get dashboard's thumbnail
description: >-
Computes async or get already computed dashboard thumbnail from cache.
parameters:
- in: path
schema:
type: integer
name: pk
- in: path
name: digest
description: A hex digest that makes this dashboard unique
schema:
type: string
responses:
200:
description: Dashboard thumbnail image
content:
image/*:
schema:
type: string
format: binary
202:
description: Thumbnail does not exist on cache, fired async to compute
content:
application/json:
schema:
type: object
properties:
message:
type: string
302:
description: Redirects to the current digest
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard:
return self.response_404()
current_user = get_current_user()
dashboard_url = get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
if dashboard.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__)
return redirect(
url_for(
f"{self.__class__.__name__}.thumbnail",
pk=pk,
digest=dashboard.digest,
)
)
screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest)
cache_key = screenshot_obj.get_cache_key()
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
image_url = get_url_path(
"DashboardRestApi.thumbnail", pk=dashboard.id, digest=cache_key
)
if cache_payload.should_trigger_task():
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (dashboard id: %s) ASYNC",
str(dashboard.id),
)
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=False,
)
return self.response(
202,
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
self.incr_stats("from_cache", self.thumbnail.__name__)
return Response(
FileWrapper(cache_payload.get_image()),
mimetype="image/png",
direct_passthrough=True,
)
@expose("/favorite_status/", methods=("GET",))
@protect()
@safe

View File

@@ -507,6 +507,12 @@ class DashboardCacheScreenshotResponseSchema(Schema):
image_url = fields.String(
metadata={"description": "The url to fetch the screenshot"}
)
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
class CacheScreenshotSchema(Schema):

View File

@@ -23,6 +23,7 @@ from typing import Any, Type, Union
import sqlalchemy as sa
from alembic import op
from flask import current_app
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
@@ -425,9 +426,13 @@ def upgrade_database_catalogs(
# update `schema_perm` and `catalog_perm` for tables and charts
update_schema_catalog_perms(session, database, catalog_perm, default_catalog, False)
# add any new catalogs discovered and their schemas
new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session)
pvms.update(new_catalog_pvms)
if (
not current_app.config["CATALOGS_SIMPLIFIED_MIGRATION"]
and not database.is_oauth2_enabled()
):
# add any new catalogs discovered and their schemas
new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session)
pvms.update(new_catalog_pvms)
# add default catalog permission and permissions for any new found schemas, and also
# permissions for new catalogs and their schemas

View File

@@ -380,9 +380,7 @@ def event_after_chart_changed(
_mapper: Mapper, _connection: Connection, target: Slice
) -> None:
cache_chart_thumbnail.delay(
current_user=get_current_user(),
chart_id=target.id,
force=True,
current_user=get_current_user(), chart_id=target.id, force=True
)

View File

@@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
def cache_chart_thumbnail(
current_user: Optional[str],
chart_id: int,
force: bool = False,
force: bool,
window_size: Optional[WindowSize] = None,
thumb_size: Optional[WindowSize] = None,
) -> None:
@@ -64,10 +64,9 @@ def cache_chart_thumbnail(
screenshot = ChartScreenshot(url, chart.digest)
screenshot.compute_and_cache(
user=user,
cache=thumbnail_cache,
force=force,
window_size=window_size,
thumb_size=thumb_size,
force=force,
)
return None
@@ -76,7 +75,7 @@ def cache_chart_thumbnail(
def cache_dashboard_thumbnail(
current_user: Optional[str],
dashboard_id: int,
force: bool = False,
force: bool,
thumb_size: Optional[WindowSize] = None,
window_size: Optional[WindowSize] = None,
) -> None:
@@ -101,10 +100,9 @@ def cache_dashboard_thumbnail(
screenshot = DashboardScreenshot(url, dashboard.digest)
screenshot.compute_and_cache(
user=user,
cache=thumbnail_cache,
force=force,
window_size=window_size,
thumb_size=thumb_size,
force=force,
)
@@ -113,7 +111,7 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments
username: str,
dashboard_id: int,
dashboard_url: str,
force: bool = True,
force: bool,
cache_key: Optional[str] = None,
guest_token: Optional[GuestToken] = None,
thumb_size: Optional[WindowSize] = None,
@@ -145,9 +143,8 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments
screenshot = DashboardScreenshot(dashboard_url, dashboard.digest)
screenshot.compute_and_cache(
user=current_user,
cache=thumbnail_cache,
force=force,
window_size=window_size,
thumb_size=thumb_size,
cache_key=cache_key,
force=force,
)

View File

@@ -17,12 +17,14 @@
from __future__ import annotations
import logging
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING
from flask import current_app
from superset import feature_flag_manager
from superset import app, feature_flag_manager, thumbnail_cache
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.extensions import event_logger
from superset.utils.hashing import md5_sha_from_dict
@@ -54,6 +56,70 @@ if TYPE_CHECKING:
from flask_caching import Cache
class StatusValues(Enum):
PENDING = "Pending"
COMPUTING = "Computing"
UPDATED = "Updated"
ERROR = "Error"
class ScreenshotCachePayload:
def __init__(self, image: bytes | None = None):
self._image = image
self._timestamp = datetime.now().isoformat()
self.status = StatusValues.PENDING
if image:
self.status = StatusValues.UPDATED
def update_timestamp(self) -> None:
self._timestamp = datetime.now().isoformat()
def pending(self) -> None:
self.update_timestamp()
self._image = None
self.status = StatusValues.PENDING
def computing(self) -> None:
self.update_timestamp()
self._image = None
self.status = StatusValues.COMPUTING
def update(self, image: bytes) -> None:
self.update_timestamp()
self.status = StatusValues.UPDATED
self._image = image
def error(
self,
) -> None:
self.update_timestamp()
self.status = StatusValues.ERROR
def get_image(self) -> BytesIO | None:
if not self._image:
return None
return BytesIO(self._image)
def get_timestamp(self) -> str:
return self._timestamp
def get_status(self) -> str:
return self.status.value
def is_error_cache_ttl_expired(self) -> bool:
error_cache_ttl = app.config["THUMBNAIL_ERROR_CACHE_TTL"]
return (
datetime.now() - datetime.fromisoformat(self.get_timestamp())
).total_seconds() > error_cache_ttl
def should_trigger_task(self, force: bool = False) -> bool:
return (
force
or self.status == StatusValues.PENDING
or (self.status == StatusValues.ERROR and self.is_error_cache_ttl_expired())
)
class BaseScreenshot:
driver_type = current_app.config["WEBDRIVER_TYPE"]
url: str
@@ -63,6 +129,7 @@ class BaseScreenshot:
element: str = ""
window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE
thumb_size: WindowSize = DEFAULT_SCREENSHOT_THUMBNAIL_SIZE
cache: Cache = thumbnail_cache
def __init__(self, url: str, digest: str | None):
self.digest = digest
@@ -75,7 +142,14 @@ class BaseScreenshot:
return WebDriverPlaywright(self.driver_type, window_size)
return WebDriverSelenium(self.driver_type, window_size)
def cache_key(
def get_screenshot(
self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
driver = self.driver(window_size)
self.screenshot = driver.get_screenshot(self.url, self.element, user)
return self.screenshot
def get_cache_key(
self,
window_size: bool | WindowSize | None = None,
thumb_size: bool | WindowSize | None = None,
@@ -91,69 +165,35 @@ class BaseScreenshot:
}
return md5_sha_from_dict(args)
def get_screenshot(
self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
driver = self.driver(window_size)
with event_logger.log_context("screenshot", screenshot_url=self.url):
self.screenshot = driver.get_screenshot(self.url, self.element, user)
return self.screenshot
def get(
self,
user: User = None,
cache: Cache = None,
thumb_size: WindowSize | None = None,
) -> BytesIO | None:
"""
Get thumbnail screenshot has BytesIO from cache or fetch
:param user: None to use current user or User Model to login and fetch
:param cache: The cache to use
:param thumb_size: Override thumbnail site
"""
payload: bytes | None = None
cache_key = self.cache_key(self.window_size, thumb_size)
if cache:
payload = cache.get(cache_key)
if not payload:
payload = self.compute_and_cache(
user=user, thumb_size=thumb_size, cache=cache
)
else:
logger.info("Loaded thumbnail from cache: %s", cache_key)
if payload:
return BytesIO(payload)
return None
def get_from_cache(
self,
cache: Cache,
window_size: WindowSize | None = None,
thumb_size: WindowSize | None = None,
) -> BytesIO | None:
cache_key = self.cache_key(window_size, thumb_size)
return self.get_from_cache_key(cache, cache_key)
) -> ScreenshotCachePayload | None:
cache_key = self.get_cache_key(window_size, thumb_size)
return self.get_from_cache_key(cache_key)
@staticmethod
def get_from_cache_key(cache: Cache, cache_key: str) -> BytesIO | None:
@classmethod
def get_from_cache_key(cls, cache_key: str) -> ScreenshotCachePayload | None:
logger.info("Attempting to get from cache: %s", cache_key)
if payload := cache.get(cache_key):
return BytesIO(payload)
if payload := cls.cache.get(cache_key):
# for backwards compatability, byte objects should be converted
if not isinstance(payload, ScreenshotCachePayload):
payload = ScreenshotCachePayload(payload)
return payload
logger.info("Failed at getting from cache: %s", cache_key)
return None
def compute_and_cache( # pylint: disable=too-many-arguments
self,
force: bool,
user: User = None,
window_size: WindowSize | None = None,
thumb_size: WindowSize | None = None,
cache: Cache = None,
force: bool = True,
cache_key: str | None = None,
) -> bytes | None:
) -> None:
"""
Fetches the screenshot, computes the thumbnail and caches the result
Computes the thumbnail and caches the result
:param user: If no user is given will use the current context
:param cache: The cache to keep the thumbnail payload
@@ -162,40 +202,46 @@ class BaseScreenshot:
:param force: Will force the computation even if it's already cached
:return: Image payload
"""
cache_key = cache_key or self.cache_key(window_size, thumb_size)
cache_key = cache_key or self.get_cache_key(window_size, thumb_size)
cache_payload = self.get_from_cache_key(cache_key) or ScreenshotCachePayload()
if (
cache_payload.status in [StatusValues.COMPUTING, StatusValues.UPDATED]
and not force
):
logger.info(
"Skipping compute - already processed for thumbnail: %s", cache_key
)
return
window_size = window_size or self.window_size
thumb_size = thumb_size or self.thumb_size
if not force and cache and cache.get(cache_key):
logger.info("Thumb already cached, skipping...")
return None
logger.info("Processing url for thumbnail: %s", cache_key)
payload = None
cache_payload.computing()
self.cache.set(cache_key, cache_payload)
image = None
# Assuming all sorts of things can go wrong with Selenium
try:
with event_logger.log_context(
f"screenshot.compute.{self.thumbnail_type}", force=force
):
payload = self.get_screenshot(user=user, window_size=window_size)
logger.info("trying to generate screenshot")
with event_logger.log_context(f"screenshot.compute.{self.thumbnail_type}"):
image = self.get_screenshot(user=user, window_size=window_size)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed at generating thumbnail %s", ex, exc_info=True)
if payload and window_size != thumb_size:
cache_payload.error()
if image and window_size != thumb_size:
try:
payload = self.resize_image(payload, thumb_size=thumb_size)
image = self.resize_image(image, thumb_size=thumb_size)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed at resizing thumbnail %s", ex, exc_info=True)
payload = None
cache_payload.error()
image = None
if payload:
if image:
logger.info("Caching thumbnail: %s", cache_key)
with event_logger.log_context(
f"screenshot.cache.{self.thumbnail_type}", force=force
):
cache.set(cache_key, payload)
logger.info("Done caching thumbnail")
return payload
with event_logger.log_context(f"screenshot.cache.{self.thumbnail_type}"):
cache_payload.update(image)
self.cache.set(cache_key, cache_payload)
logger.info("Updated thumbnail cache; Status: %s", cache_payload.get_status())
return
@classmethod
def resize_image(
@@ -265,7 +311,7 @@ class DashboardScreenshot(BaseScreenshot):
self.window_size = window_size or DEFAULT_DASHBOARD_WINDOW_SIZE
self.thumb_size = thumb_size or DEFAULT_DASHBOARD_THUMBNAIL_SIZE
def cache_key(
def get_cache_key(
self,
window_size: bool | WindowSize | None = None,
thumb_size: bool | WindowSize | None = None,

View File

@@ -380,7 +380,7 @@ class WebDriverSelenium(WebDriverProxy):
return error_messages
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901
driver = self.auth(user)
driver.set_window_size(*self._window)
driver.get(url)
@@ -411,6 +411,7 @@ class WebDriverSelenium(WebDriverProxy):
)
)
except TimeoutException:
logger.info("Timeout Exception caught")
# Fallback to allow a screenshot of an empty dashboard
try:
WebDriverWait(driver, 0).until(
@@ -461,18 +462,23 @@ class WebDriverSelenium(WebDriverProxy):
)
img = element.screenshot_as_png
except Exception as ex:
logger.warning("exception in webdriver", exc_info=ex)
raise
except TimeoutException:
# raise again for the finally block, but handled above
pass
raise
except StaleElementReferenceException:
logger.exception(
"Selenium got a stale element while requesting url %s",
url,
)
raise
except WebDriverException:
logger.exception(
"Encountered an unexpected error when requesting url %s", url
)
raise
finally:
self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"])
return img

View File

@@ -17,20 +17,17 @@
from typing import Any, Optional, Type
from unittest import mock
import pytest
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.extensions import async_query_manager
from superset.extensions import async_query_manager, async_query_manager_factory
from superset.utils import json
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.test_app import app
@pytest.skip(reason="Needs to investigate this test", allow_module_level=True)
class TestAsyncEventApi(SupersetTestCase):
UUID = "943c920-32a5-412a-977d-b8e47d36f5a4"
@@ -41,7 +38,7 @@ class TestAsyncEventApi(SupersetTestCase):
def run_test_with_cache_backend(self, cache_backend_cls: Type[Any], test_func):
app._got_first_request = False
async_query_manager.init_app(app)
async_query_manager_factory.init_app(app)
# Create a mock cache backend instance
mock_cache = mock.Mock(spec=cache_backend_cls)
@@ -130,7 +127,7 @@ class TestAsyncEventApi(SupersetTestCase):
def test_events_no_login(self):
app._got_first_request = False
async_query_manager.init_app(app)
async_query_manager_factory.init_app(app)
rv = self.fetch_events()
assert rv.status_code == 401

View File

@@ -319,9 +319,5 @@ def test_compute_thumbnails(thumbnail_mock, app_context, fs):
["-d", "-i", dashboard.id],
)
thumbnail_mock.assert_called_with(
None,
dashboard.id,
force=False,
)
thumbnail_mock.assert_called_with(None, dashboard.id, force=False)
assert response.exit_code == 0

View File

@@ -37,6 +37,7 @@ from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.tags.models import Tag, TaggedObject, TagType, ObjectType
from superset.utils.core import backend, override_user
from superset.utils.screenshots import ScreenshotCachePayload
from superset.utils import json
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
@@ -3069,13 +3070,15 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_success_png(self, mock_get_cache, mock_cache_task):
def test_screenshot_success_png(self, mock_get_from_cache_key, mock_cache_task):
"""
Validate screenshot returns png
"""
self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None
mock_get_cache.return_value = BytesIO(b"fake image data")
mock_get_from_cache_key.return_value = ScreenshotCachePayload(
b"fake image data"
)
dashboard = (
db.session.query(Dashboard)
@@ -3083,7 +3086,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.first()
)
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
response = self._get_screenshot(dashboard.id, cache_key, "png")
@@ -3091,20 +3094,29 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
assert response.mimetype == "image/png"
assert response.data == b"fake image data"
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
@with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True)
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.build_pdf_from_screenshots")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_success_pdf(
self, mock_get_from_cache, mock_build_pdf, mock_cache_task
self,
mock_get_from_cache_key,
mock_build_pdf,
mock_cache_task,
):
"""
Validate screenshot can return pdf.
"""
self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None
mock_get_from_cache.return_value = BytesIO(b"fake image data")
mock_get_from_cache_key.return_value = ScreenshotCachePayload(
b"fake image data"
)
mock_build_pdf.return_value = b"fake pdf data"
dashboard = (
@@ -3113,7 +3125,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.first()
)
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
response = self._get_screenshot(dashboard.id, cache_key, "pdf")
@@ -3121,6 +3133,10 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
assert response.mimetype == "application/pdf"
assert response.data == b"fake pdf data"
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
@with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True)
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@@ -3153,10 +3169,12 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_invalid_download_format(self, mock_get_cache, mock_cache_task):
def test_screenshot_invalid_download_format(
self, mock_get_from_cache_key, mock_cache_task
):
self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None
mock_get_cache.return_value = BytesIO(b"fake png data")
mock_get_from_cache_key.return_value = ScreenshotCachePayload(b"fake png data")
dashboard = (
db.session.query(Dashboard)
@@ -3165,9 +3183,13 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
)
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
response = self._get_screenshot(dashboard.id, cache_key, "invalid")
assert response.status_code == 404

View File

@@ -18,7 +18,6 @@
# from superset.models.dashboard import Dashboard
import urllib.request
from io import BytesIO
from unittest import skipUnless
from unittest.mock import ANY, call, MagicMock, patch
@@ -32,7 +31,11 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tasks.types import ExecutorType, FixedExecutor
from superset.utils import json
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
from superset.utils.screenshots import (
ChartScreenshot,
DashboardScreenshot,
ScreenshotCachePayload,
)
from superset.utils.urls import get_url_path
from superset.utils.webdriver import WebDriverSelenium
from tests.integration_tests.base_tests import SupersetTestCase
@@ -287,14 +290,14 @@ class TestThumbnails(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_not_allowed(self):
def test_get_async_dashboard_created(self):
"""
Thumbnails: Simple get async dashboard not allowed
"""
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
rv = self.client.get(thumbnail_url)
assert rv.status_code == 404
assert rv.status_code == 202
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@@ -370,7 +373,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get chart with wrong digest
"""
with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
ChartScreenshot,
"get_from_cache",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
@@ -385,7 +390,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get cached dashboard screenshot
"""
with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
DashboardScreenshot,
"get_from_cache_key",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
@@ -400,7 +407,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get cached chart screenshot
"""
with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
ChartScreenshot,
"get_from_cache_key",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
@@ -415,7 +424,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get dashboard with wrong digest
"""
with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
DashboardScreenshot,
"get_from_cache",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)

View File

@@ -34,13 +34,13 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
database = mocker.MagicMock()
database.database_name = "test_database"
database.get_all_table_names_in_schema.return_value = [
DatasourceName("table1", "schema1", "catalog1"),
DatasourceName("table2", "schema1", "catalog1"),
]
database.get_all_view_names_in_schema.return_value = [
DatasourceName("view1", "schema1", "catalog1"),
]
database.get_all_table_names_in_schema.return_value = {
("table1", "schema1", "catalog1"),
("table2", "schema1", "catalog1"),
}
database.get_all_view_names_in_schema.return_value = {
("view1", "schema1", "catalog1"),
}
DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO") # noqa: N806
DatabaseDAO.find_by_id.return_value = database
@@ -57,13 +57,13 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
database = mocker.MagicMock()
database.database_name = "test_database"
database.get_all_table_names_in_schema.return_value = [
DatasourceName("table1", "schema1"),
DatasourceName("table2", "schema1"),
]
database.get_all_view_names_in_schema.return_value = [
DatasourceName("view1", "schema1"),
]
database.get_all_table_names_in_schema.return_value = {
("table1", "schema1", None),
("table2", "schema1", None),
}
database.get_all_view_names_in_schema.return_value = {
("view1", "schema1", None),
}
DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO") # noqa: N806
DatabaseDAO.find_by_id.return_value = database

View File

@@ -16,19 +16,24 @@
# under the License.
import json
from datetime import datetime
from unittest.mock import patch
from uuid import UUID
import pytest
from pytest_mock import MockerFixture
from superset.app import SupersetApp
from superset.commands.report.execute import BaseReportState
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.reports.models import (
ReportRecipientType,
ReportSchedule,
ReportScheduleType,
ReportSourceFormat,
)
from superset.utils.core import HeaderDataType
from superset.utils.screenshots import ChartScreenshot
from tests.integration_tests.conftest import with_feature_flags
@@ -365,3 +370,113 @@ def test_get_tab_url(
)
result: str = class_instance._get_tab_url(dashboard_state)
assert result == "http://0.0.0.0:8080/superset/dashboard/p/uri/"
def create_report_schedule(
mocker: MockerFixture,
custom_width: int | None = None,
custom_height: int | None = None,
) -> ReportSchedule:
"""Helper function to create a ReportSchedule instance with specified dimensions."""
schedule = ReportSchedule()
schedule.type = ReportScheduleType.REPORT
schedule.name = "Test Report"
schedule.description = "Test Description"
schedule.chart = mocker.MagicMock()
schedule.chart.id = 1
schedule.dashboard = None
schedule.database = None
schedule.custom_width = custom_width
schedule.custom_height = custom_height
return schedule
@pytest.mark.parametrize(
"test_id,custom_width,max_width,window_width,expected_width",
[
# Test when custom width exceeds max width
("exceeds_max", 2000, 1600, 800, 1600),
# Test when custom width is less than max width
("under_max", 1200, 1600, 800, 1200),
# Test when custom width is None (should use window width)
("no_custom", None, 1600, 800, 800),
# Test when custom width equals max width
("equals_max", 1600, 1600, 800, 1600),
],
)
def test_screenshot_width_calculation(
app: SupersetApp,
mocker: MockerFixture,
test_id: str,
custom_width: int | None,
max_width: int,
window_width: int,
expected_width: int,
) -> None:
"""
Test that screenshot width is correctly calculated.
The width should be:
- Limited by max_width when custom_width exceeds it
- Equal to custom_width when it's less than max_width
- Equal to window_width when custom_width is None
"""
from superset.commands.report.execute import BaseReportState
# Mock configuration
app.config.update(
{
"ALERT_REPORTS_MAX_CUSTOM_SCREENSHOT_WIDTH": max_width,
"WEBDRIVER_WINDOW": {
"slice": (window_width, 600),
"dashboard": (window_width, 600),
},
"ALERT_REPORTS_EXECUTORS": {},
}
)
# Create report schedule with specified custom width
report_schedule = create_report_schedule(mocker, custom_width=custom_width)
# Initialize BaseReportState
report_state = BaseReportState(
report_schedule=report_schedule,
scheduled_dttm=datetime.now(),
execution_id=UUID("084e7ee6-5557-4ecd-9632-b7f39c9ec524"),
)
# Mock security manager and screenshot
with (
patch(
"superset.commands.report.execute.security_manager"
) as mock_security_manager,
patch(
"superset.utils.screenshots.ChartScreenshot.get_screenshot"
) as mock_get_screenshot,
):
# Mock user
mock_user = mocker.MagicMock()
mock_security_manager.find_user.return_value = mock_user
mock_get_screenshot.return_value = b"screenshot bytes"
# Mock get_executor to avoid database lookups
with patch(
"superset.commands.report.execute.get_executor"
) as mock_get_executor:
mock_get_executor.return_value = ("executor", "username")
# Capture the ChartScreenshot instantiation
with patch(
"superset.commands.report.execute.ChartScreenshot",
wraps=ChartScreenshot,
) as mock_chart_screenshot:
# Call the method that triggers screenshot creation
report_state._get_screenshots()
# Verify ChartScreenshot was created with correct window_size
mock_chart_screenshot.assert_called_once()
_, kwargs = mock_chart_screenshot.call_args
assert kwargs["window_size"][0] == expected_width, (
f"Test {test_id}: Expected width {expected_width}, "
f"but got {kwargs['window_size'][0]}"
)

View File

@@ -18,6 +18,7 @@
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
from superset import app
from superset.migrations.shared.catalogs import (
downgrade_catalog_perms,
upgrade_catalog_perms,
@@ -329,3 +330,252 @@ def test_upgrade_catalog_perms_graceful(
("[my_db].[my_table](id:1)",),
("[my_db].[public]",),
]
def test_upgrade_catalog_perms_oauth_connection(
mocker: MockerFixture,
session: Session,
) -> None:
"""
Test the `upgrade_catalog_perms` function when the DB is set up using OAuth.
During the migration we try to connect to the analytical database to get the list of
schemas. This step should be skipped if the database is set up using OAuth and not
raise an exception.
"""
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
engine = session.get_bind()
Database.metadata.create_all(engine)
mocker.patch("superset.migrations.shared.catalogs.op")
db = mocker.patch("superset.migrations.shared.catalogs.db")
db.Session.return_value = session
add_non_default_catalogs = mocker.patch(
"superset.migrations.shared.catalogs.add_non_default_catalogs"
)
mocker.patch("superset.migrations.shared.catalogs.op", session)
database = Database(
database_name="my_db",
sqlalchemy_uri="bigquery://my-test-project",
encrypted_extra='{"oauth2_client_info": "fake_mock_oauth_conn"}',
)
dataset = SqlaTable(
table_name="my_table",
database=database,
catalog=None,
schema="public",
schema_perm="[my_db].[public]",
)
session.add(dataset)
session.commit()
chart = Slice(
slice_name="my_chart",
datasource_type="table",
datasource_id=dataset.id,
)
query = Query(
client_id="foo",
database=database,
catalog=None,
schema="public",
)
saved_query = SavedQuery(
database=database,
sql="SELECT * FROM public.t",
catalog=None,
schema="public",
)
tab_state = TabState(
database=database,
catalog=None,
schema="public",
)
table_schema = TableSchema(
database=database,
catalog=None,
schema="public",
)
session.add_all([chart, query, saved_query, tab_state, table_schema])
session.commit()
# before migration
assert dataset.catalog is None
assert query.catalog is None
assert saved_query.catalog is None
assert tab_state.catalog is None
assert table_schema.catalog is None
assert dataset.schema_perm == "[my_db].[public]"
assert chart.schema_perm == "[my_db].[public]"
assert session.query(ViewMenu.name).all() == [
("[my_db].(id:1)",),
("[my_db].[my_table](id:1)",),
("[my_db].[public]",),
]
upgrade_catalog_perms()
session.commit()
# after migration
assert dataset.catalog == "my-test-project"
assert query.catalog == "my-test-project"
assert saved_query.catalog == "my-test-project"
assert tab_state.catalog == "my-test-project"
assert table_schema.catalog == "my-test-project"
assert dataset.schema_perm == "[my_db].[my-test-project].[public]"
assert chart.schema_perm == "[my_db].[my-test-project].[public]"
assert session.query(ViewMenu.name).all() == [
("[my_db].(id:1)",),
("[my_db].[my_table](id:1)",),
("[my_db].[my-test-project].[public]",),
("[my_db].[my-test-project]",),
]
add_non_default_catalogs.assert_not_called()
downgrade_catalog_perms()
session.commit()
# revert
assert dataset.catalog is None
assert query.catalog is None
assert saved_query.catalog is None
assert tab_state.catalog is None
assert table_schema.catalog is None
assert dataset.schema_perm == "[my_db].[public]"
assert chart.schema_perm == "[my_db].[public]"
assert session.query(ViewMenu.name).all() == [
("[my_db].(id:1)",),
("[my_db].[my_table](id:1)",),
("[my_db].[public]",),
]
def test_upgrade_catalog_perms_simplified_migration(
mocker: MockerFixture,
session: Session,
) -> None:
"""
Test the `upgrade_catalog_perms` function when the ``CATALOGS_SIMPLIFIED_MIGRATION``
config is set to ``True``.
This should only update existing permissions + create a new permission
for the default catalog.
"""
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
engine = session.get_bind()
Database.metadata.create_all(engine)
mocker.patch("superset.migrations.shared.catalogs.op")
db = mocker.patch("superset.migrations.shared.catalogs.db")
db.Session.return_value = session
add_non_default_catalogs = mocker.patch(
"superset.migrations.shared.catalogs.add_non_default_catalogs"
)
mocker.patch("superset.migrations.shared.catalogs.op", session)
database = Database(
database_name="my_db",
sqlalchemy_uri="bigquery://my-test-project",
)
dataset = SqlaTable(
table_name="my_table",
database=database,
catalog=None,
schema="public",
schema_perm="[my_db].[public]",
)
session.add(dataset)
session.commit()
chart = Slice(
slice_name="my_chart",
datasource_type="table",
datasource_id=dataset.id,
)
query = Query(
client_id="foo",
database=database,
catalog=None,
schema="public",
)
saved_query = SavedQuery(
database=database,
sql="SELECT * FROM public.t",
catalog=None,
schema="public",
)
tab_state = TabState(
database=database,
catalog=None,
schema="public",
)
table_schema = TableSchema(
database=database,
catalog=None,
schema="public",
)
session.add_all([chart, query, saved_query, tab_state, table_schema])
session.commit()
# before migration
assert dataset.catalog is None
assert query.catalog is None
assert saved_query.catalog is None
assert tab_state.catalog is None
assert table_schema.catalog is None
assert dataset.schema_perm == "[my_db].[public]"
assert chart.schema_perm == "[my_db].[public]"
assert session.query(ViewMenu.name).all() == [
("[my_db].(id:1)",),
("[my_db].[my_table](id:1)",),
("[my_db].[public]",),
]
with app.test_request_context():
app.config["CATALOGS_SIMPLIFIED_MIGRATION"] = True
upgrade_catalog_perms()
session.commit()
# after migration
assert dataset.catalog == "my-test-project"
assert query.catalog == "my-test-project"
assert saved_query.catalog == "my-test-project"
assert tab_state.catalog == "my-test-project"
assert table_schema.catalog == "my-test-project"
assert dataset.schema_perm == "[my_db].[my-test-project].[public]"
assert chart.schema_perm == "[my_db].[my-test-project].[public]"
assert session.query(ViewMenu.name).all() == [
("[my_db].(id:1)",),
("[my_db].[my_table](id:1)",),
("[my_db].[my-test-project].[public]",),
("[my_db].[my-test-project]",),
]
add_non_default_catalogs.assert_not_called()
downgrade_catalog_perms()
session.commit()
# revert
assert dataset.catalog is None
assert query.catalog is None
assert saved_query.catalog is None
assert tab_state.catalog is None
assert table_schema.catalog is None
assert dataset.schema_perm == "[my_db].[public]"
assert chart.schema_perm == "[my_db].[public]"
assert session.query(ViewMenu.name).all() == [
("[my_db].(id:1)",),
("[my_db].[my_table](id:1)",),
("[my_db].[public]",),
]

View File

@@ -768,3 +768,59 @@ FROM (
WHERE
TRUE AND TRUE"""
)
def test_get_all_table_names_in_schema(mocker: MockerFixture) -> None:
"""
Test the `get_all_table_names_in_schema` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)
mocker.patch.object(database, "get_inspector")
get_table_names = mocker.patch(
"superset.db_engine_specs.postgres.PostgresEngineSpec.get_table_names"
)
get_table_names.return_value = {"first_table", "second_table", "third_table"}
tables_list = database.get_all_table_names_in_schema(
catalog="examples",
schema="public",
)
assert sorted(tables_list) == sorted(
{
("first_table", "public", "examples"),
("second_table", "public", "examples"),
("third_table", "public", "examples"),
}
)
def test_get_all_view_names_in_schema(mocker: MockerFixture) -> None:
"""
Test the `get_all_view_names_in_schema` method.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)
mocker.patch.object(database, "get_inspector")
get_view_names = mocker.patch(
"superset.db_engine_specs.base.BaseEngineSpec.get_view_names"
)
get_view_names.return_value = {"first_view", "second_view", "third_view"}
views_list = database.get_all_view_names_in_schema(
catalog="examples",
schema="public",
)
assert sorted(views_list) == sorted(
{
("first_view", "public", "examples"),
("second_view", "public", "examples"),
("third_view", "public", "examples"),
}
)

View File

@@ -0,0 +1,194 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.utils.hashing import md5_sha_from_dict
from superset.utils.screenshots import (
BaseScreenshot,
ScreenshotCachePayload,
StatusValues,
)
BASE_SCREENSHOT_PATH = "superset.utils.screenshots.BaseScreenshot"
class MockCache:
"""A class to manage screenshot cache."""
def __init__(self):
self._cache = None # Store the cached value
def set(self, _key, value):
"""Set the cache with a new value."""
self._cache = value
def get(self, _key):
"""Get the cached value."""
return self._cache
@pytest.fixture
def mock_user():
"""Fixture to create a mock user."""
mock_user = MagicMock()
mock_user.id = 1
return mock_user
@pytest.fixture
def screenshot_obj():
"""Fixture to create a BaseScreenshot object."""
url = "http://example.com"
digest = "sample_digest"
return BaseScreenshot(url, digest)
def test_get_screenshot(mocker: MockerFixture, screenshot_obj):
"""Get screenshot should return a Bytes object"""
fake_bytes = b"fake_screenshot_data"
driver = mocker.patch(BASE_SCREENSHOT_PATH + ".driver")
driver.return_value.get_screenshot.return_value = fake_bytes
screenshot_data = screenshot_obj.get_screenshot(mock_user)
assert screenshot_data == fake_bytes
def test_get_cache_key(screenshot_obj):
"""Test get_cache_key method"""
expected_cache_key = md5_sha_from_dict(
{
"thumbnail_type": "",
"digest": screenshot_obj.digest,
"type": "thumb",
"window_size": screenshot_obj.window_size,
"thumb_size": screenshot_obj.thumb_size,
}
)
cache_key = screenshot_obj.get_cache_key()
assert cache_key == expected_cache_key
def test_get_from_cache_key(mocker: MockerFixture, screenshot_obj):
"""get_from_cache_key should always return a ScreenshotCachePayload Object"""
# backwards compatability test for retrieving plain bytes
fake_bytes = b"fake_screenshot_data"
BaseScreenshot.cache = MockCache()
BaseScreenshot.cache.set("key", fake_bytes)
cache_payload = screenshot_obj.get_from_cache_key("key")
assert isinstance(cache_payload, ScreenshotCachePayload)
assert cache_payload._image == fake_bytes # pylint: disable=protected-access
class TestComputeAndCache:
def _setup_compute_and_cache(self, mocker: MockerFixture, screenshot_obj):
"""Helper method to handle the common setup for the tests."""
# Patch the methods
get_from_cache_key = mocker.patch(
BASE_SCREENSHOT_PATH + ".get_from_cache_key", return_value=None
)
get_screenshot = mocker.patch(
BASE_SCREENSHOT_PATH + ".get_screenshot", return_value=b"new_image_data"
)
resize_image = mocker.patch(
BASE_SCREENSHOT_PATH + ".resize_image", return_value=b"resized_image_data"
)
BaseScreenshot.cache = MockCache()
return {
"get_from_cache_key": get_from_cache_key,
"get_screenshot": get_screenshot,
"resize_image": resize_image,
}
def test_happy_path(self, mocker: MockerFixture, screenshot_obj):
self._setup_compute_and_cache(mocker, screenshot_obj)
screenshot_obj.compute_and_cache(force=False)
cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
assert cache_payload.status == StatusValues.UPDATED
def test_screenshot_error(self, mocker: MockerFixture, screenshot_obj):
mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
get_screenshot: MagicMock = mocks.get("get_screenshot")
get_screenshot.side_effect = Exception
screenshot_obj.compute_and_cache(force=False)
cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
assert cache_payload.status == StatusValues.ERROR
def test_resize_error(self, mocker: MockerFixture, screenshot_obj):
mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
resize_image: MagicMock = mocks.get("resize_image")
resize_image.side_effect = Exception
screenshot_obj.compute_and_cache(force=False)
cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
assert cache_payload.status == StatusValues.ERROR
def test_skips_if_computing(self, mocker: MockerFixture, screenshot_obj):
mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
cached_value = ScreenshotCachePayload()
cached_value.computing()
get_from_cache_key = mocks.get("get_from_cache_key")
get_from_cache_key.return_value = cached_value
# Ensure that it skips when thumbnail status is computing
screenshot_obj.compute_and_cache(force=False)
get_screenshot = mocks.get("get_screenshot")
get_screenshot.assert_not_called()
# Ensure that it processes when force = True
screenshot_obj.compute_and_cache(force=True)
get_screenshot.assert_called_once()
cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
assert cache_payload.status == StatusValues.UPDATED
def test_skips_if_updated(self, mocker: MockerFixture, screenshot_obj):
mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
cached_value = ScreenshotCachePayload(image=b"initial_value")
get_from_cache_key = mocks.get("get_from_cache_key")
get_from_cache_key.return_value = cached_value
# Ensure that it skips when thumbnail status is updated
window_size = thumb_size = (10, 10)
screenshot_obj.compute_and_cache(
force=False, window_size=window_size, thumb_size=thumb_size
)
get_screenshot = mocks.get("get_screenshot")
get_screenshot.assert_not_called()
# Ensure that it processes when force = True
screenshot_obj.compute_and_cache(
force=True, window_size=window_size, thumb_size=thumb_size
)
get_screenshot.assert_called_once()
cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
assert cache_payload._image != b"initial_value"
def test_resize(self, mocker: MockerFixture, screenshot_obj):
mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
window_size = thumb_size = (10, 10)
resize_image: MagicMock = mocks.get("resize_image")
screenshot_obj.compute_and_cache(
force=False, window_size=window_size, thumb_size=thumb_size
)
resize_image.assert_not_called()
screenshot_obj.compute_and_cache(
force=False, window_size=(1, 1), thumb_size=thumb_size
)
resize_image.assert_called_once()