mirror of
https://github.com/apache/superset.git
synced 2026-06-13 03:29:17 +00:00
Compare commits
7 Commits
fix/chart-
...
default-db
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a2850a33c | ||
|
|
ff17daa424 | ||
|
|
89c5c55dcb | ||
|
|
db201285e5 | ||
|
|
1910a5c607 | ||
|
|
b377ce564b | ||
|
|
4c6df01353 |
@@ -73,11 +73,14 @@ beforeEach(() => {
|
|||||||
dbId: expectDbId,
|
dbId: expectDbId,
|
||||||
forceRefresh: false,
|
forceRefresh: false,
|
||||||
},
|
},
|
||||||
fakeSchemaApiResult.map(value => ({
|
{
|
||||||
value,
|
schemas: fakeSchemaApiResult.map(value => ({
|
||||||
label: value,
|
value,
|
||||||
title: value,
|
label: value,
|
||||||
})),
|
title: value,
|
||||||
|
})),
|
||||||
|
defaultSchema: null,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
store.dispatch(
|
store.dispatch(
|
||||||
@@ -307,11 +310,14 @@ test('returns long keywords with docText', async () => {
|
|||||||
dbId: expectLongKeywordDbId,
|
dbId: expectLongKeywordDbId,
|
||||||
forceRefresh: false,
|
forceRefresh: false,
|
||||||
},
|
},
|
||||||
['short', longKeyword].map(value => ({
|
{
|
||||||
value,
|
schemas: ['short', longKeyword].map(value => ({
|
||||||
label: value,
|
value,
|
||||||
title: value,
|
label: value,
|
||||||
})),
|
title: value,
|
||||||
|
})),
|
||||||
|
defaultSchema: null,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ export function useKeywords(
|
|||||||
// skipFetch is used to prevent re-evaluating memoized keywords
|
// skipFetch is used to prevent re-evaluating memoized keywords
|
||||||
// due to updated api results by skip flag
|
// due to updated api results by skip flag
|
||||||
const skipFetch = hasFetchedKeywords && skip;
|
const skipFetch = hasFetchedKeywords && skip;
|
||||||
const { currentData: schemaOptions } = useSchemasQueryState(
|
const { currentData: schemaData } = useSchemasQueryState(
|
||||||
{
|
{
|
||||||
dbId,
|
dbId,
|
||||||
catalog: catalog || undefined,
|
catalog: catalog || undefined,
|
||||||
@@ -86,6 +86,7 @@ export function useKeywords(
|
|||||||
},
|
},
|
||||||
{ skip: skipFetch || !dbId },
|
{ skip: skipFetch || !dbId },
|
||||||
);
|
);
|
||||||
|
const schemaOptions = schemaData?.schemas;
|
||||||
const { currentData: tableData } = useTablesQueryState(
|
const { currentData: tableData } = useTablesQueryState(
|
||||||
{
|
{
|
||||||
dbId,
|
dbId,
|
||||||
|
|||||||
@@ -163,11 +163,13 @@ const fakeDatabaseApiResultInReverseOrder = {
|
|||||||
const fakeSchemaApiResult = {
|
const fakeSchemaApiResult = {
|
||||||
count: 2,
|
count: 2,
|
||||||
result: ['information_schema', 'public'],
|
result: ['information_schema', 'public'],
|
||||||
|
default: 'public',
|
||||||
};
|
};
|
||||||
|
|
||||||
const fakeCatalogApiResult = {
|
const fakeCatalogApiResult = {
|
||||||
count: 0,
|
count: 0,
|
||||||
result: [],
|
result: [],
|
||||||
|
default: null,
|
||||||
};
|
};
|
||||||
|
|
||||||
const fakeFunctionNamesApiResult = {
|
const fakeFunctionNamesApiResult = {
|
||||||
@@ -369,10 +371,11 @@ test('Sends the correct schema when changing the schema', async () => {
|
|||||||
});
|
});
|
||||||
await waitFor(() => expect(fetchMock.calls(databaseApiRoute).length).toBe(1));
|
await waitFor(() => expect(fetchMock.calls(databaseApiRoute).length).toBe(1));
|
||||||
rerender(<DatabaseSelector {...props} />);
|
rerender(<DatabaseSelector {...props} />);
|
||||||
expect(props.onSchemaChange).toHaveBeenCalledTimes(0);
|
// Wait for schema data to load
|
||||||
const select = screen.getByRole('combobox', {
|
const select = await screen.findByRole('combobox', {
|
||||||
name: 'Select schema or type to search schemas: public',
|
name: 'Select schema or type to search schemas: public',
|
||||||
});
|
});
|
||||||
|
expect(props.onSchemaChange).toHaveBeenCalledTimes(0);
|
||||||
expect(select).toBeInTheDocument();
|
expect(select).toBeInTheDocument();
|
||||||
await userEvent.click(select);
|
await userEvent.click(select);
|
||||||
const schemaOption = await screen.findByText('information_schema');
|
const schemaOption = await screen.findByText('information_schema');
|
||||||
@@ -382,3 +385,82 @@ test('Sends the correct schema when changing the schema', async () => {
|
|||||||
);
|
);
|
||||||
expect(props.onSchemaChange).toHaveBeenCalledTimes(1);
|
expect(props.onSchemaChange).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('Auto-selects default schema on first load when no schema is provided', async () => {
|
||||||
|
fetchMock.get(
|
||||||
|
schemaApiRoute,
|
||||||
|
{
|
||||||
|
result: ['information_schema', 'public', 'other_schema'],
|
||||||
|
default: 'public',
|
||||||
|
},
|
||||||
|
{ overwriteRoutes: true },
|
||||||
|
);
|
||||||
|
|
||||||
|
const props = {
|
||||||
|
...createProps(),
|
||||||
|
schema: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<DatabaseSelector {...props} />, { useRedux: true, store });
|
||||||
|
|
||||||
|
// Wait for schemas to load and default to be applied
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(props.onSchemaChange).toHaveBeenCalledWith('public');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test('Does not auto-select default schema when schema is already provided', async () => {
|
||||||
|
fetchMock.get(
|
||||||
|
schemaApiRoute,
|
||||||
|
{
|
||||||
|
result: ['information_schema', 'public', 'other_schema'],
|
||||||
|
default: 'public',
|
||||||
|
},
|
||||||
|
{ overwriteRoutes: true },
|
||||||
|
);
|
||||||
|
|
||||||
|
const props = {
|
||||||
|
...createProps(),
|
||||||
|
schema: 'information_schema',
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<DatabaseSelector {...props} />, { useRedux: true, store });
|
||||||
|
|
||||||
|
// Wait for schemas to load
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(fetchMock.calls(schemaApiRoute).length).toBe(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should not call onSchemaChange since schema is already set
|
||||||
|
expect(props.onSchemaChange).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('Auto-selects default catalog on first load for multi-catalog database', async () => {
|
||||||
|
fetchMock.get(
|
||||||
|
catalogApiRoute,
|
||||||
|
{
|
||||||
|
result: ['catalog_a', 'catalog_b', 'catalog_c'],
|
||||||
|
default: 'catalog_b',
|
||||||
|
},
|
||||||
|
{ overwriteRoutes: true },
|
||||||
|
);
|
||||||
|
|
||||||
|
const props = {
|
||||||
|
...createProps(),
|
||||||
|
db: {
|
||||||
|
id: 1,
|
||||||
|
database_name: 'test-multicatalog',
|
||||||
|
backend: 'test-postgresql',
|
||||||
|
allow_multi_catalog: true,
|
||||||
|
},
|
||||||
|
catalog: undefined,
|
||||||
|
onCatalogChange: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<DatabaseSelector {...props} />, { useRedux: true, store });
|
||||||
|
|
||||||
|
// Wait for catalogs to load and default to be applied
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(props.onCatalogChange).toHaveBeenCalledWith('catalog_b');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -144,6 +144,9 @@ export function DatabaseSelector({
|
|||||||
);
|
);
|
||||||
const schemaRef = useRef(schema);
|
const schemaRef = useRef(schema);
|
||||||
schemaRef.current = schema;
|
schemaRef.current = schema;
|
||||||
|
// Track if we've applied defaults to avoid re-applying after user clears selection
|
||||||
|
const appliedCatalogDefaultRef = useRef<string | null>(null);
|
||||||
|
const appliedSchemaDefaultRef = useRef<string | null>(null);
|
||||||
const { addSuccessToast } = useToasts();
|
const { addSuccessToast } = useToasts();
|
||||||
const sortComparator = useCallback(
|
const sortComparator = useCallback(
|
||||||
(itemA: AntdLabeledValueWithOrder, itemB: AntdLabeledValueWithOrder) =>
|
(itemA: AntdLabeledValueWithOrder, itemB: AntdLabeledValueWithOrder) =>
|
||||||
@@ -240,22 +243,14 @@ export function DatabaseSelector({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const {
|
const {
|
||||||
currentData: schemaData,
|
data: schemaData,
|
||||||
isFetching: loadingSchemas,
|
isFetching: loadingSchemas,
|
||||||
refetch: refetchSchemas,
|
refetch: refetchSchemas,
|
||||||
|
defaultSchema,
|
||||||
} = useSchemas({
|
} = useSchemas({
|
||||||
dbId: currentDb?.value,
|
dbId: currentDb?.value,
|
||||||
catalog: currentCatalog?.value,
|
catalog: currentCatalog?.value,
|
||||||
onSuccess: (schemas, isFetched) => {
|
onSuccess: (schemas, isFetched) => {
|
||||||
setErrorPayload(null);
|
|
||||||
if (schemas.length === 1) {
|
|
||||||
changeSchema(schemas[0]);
|
|
||||||
} else if (
|
|
||||||
!schemas.find(schemaOption => schemaRef.current === schemaOption.value)
|
|
||||||
) {
|
|
||||||
changeSchema(undefined);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isFetched) {
|
if (isFetched) {
|
||||||
addSuccessToast('List refreshed');
|
addSuccessToast('List refreshed');
|
||||||
}
|
}
|
||||||
@@ -271,9 +266,41 @@ export function DatabaseSelector({
|
|||||||
|
|
||||||
const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
|
const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;
|
||||||
|
|
||||||
|
// Handle schema auto-selection when data changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (!schemaData || loadingSchemas) return;
|
||||||
|
|
||||||
|
setErrorPayload(null);
|
||||||
|
|
||||||
|
if (schemaData.length === 1) {
|
||||||
|
changeSchema(schemaData[0]);
|
||||||
|
} else if (
|
||||||
|
!schemaData.find(schemaOption => schemaRef.current === schemaOption.value)
|
||||||
|
) {
|
||||||
|
// Current selection not in list - try to apply default on first load
|
||||||
|
if (
|
||||||
|
defaultSchema &&
|
||||||
|
appliedSchemaDefaultRef.current !== defaultSchema
|
||||||
|
) {
|
||||||
|
const defaultOption = schemaData.find(s => s.value === defaultSchema);
|
||||||
|
if (defaultOption) {
|
||||||
|
appliedSchemaDefaultRef.current = defaultSchema;
|
||||||
|
changeSchema(defaultOption);
|
||||||
|
} else {
|
||||||
|
changeSchema(undefined);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
changeSchema(undefined);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [schemaData, defaultSchema, loadingSchemas]);
|
||||||
|
|
||||||
function changeCatalog(catalog: CatalogOption | null | undefined) {
|
function changeCatalog(catalog: CatalogOption | null | undefined) {
|
||||||
setCurrentCatalog(catalog);
|
setCurrentCatalog(catalog);
|
||||||
setCurrentSchema(undefined);
|
setCurrentSchema(undefined);
|
||||||
|
// Reset schema default ref so default can be applied for the new catalog
|
||||||
|
appliedSchemaDefaultRef.current = null;
|
||||||
if (onCatalogChange && catalog?.value !== catalogRef.current) {
|
if (onCatalogChange && catalog?.value !== catalogRef.current) {
|
||||||
onCatalogChange(catalog?.value);
|
onCatalogChange(catalog?.value);
|
||||||
}
|
}
|
||||||
@@ -283,22 +310,10 @@ export function DatabaseSelector({
|
|||||||
data: catalogData,
|
data: catalogData,
|
||||||
isFetching: loadingCatalogs,
|
isFetching: loadingCatalogs,
|
||||||
refetch: refetchCatalogs,
|
refetch: refetchCatalogs,
|
||||||
|
defaultCatalog,
|
||||||
} = useCatalogs({
|
} = useCatalogs({
|
||||||
dbId: showCatalogSelector ? currentDb?.value : undefined,
|
dbId: showCatalogSelector ? currentDb?.value : undefined,
|
||||||
onSuccess: (catalogs, isFetched) => {
|
onSuccess: (catalogs, isFetched) => {
|
||||||
setErrorPayload(null);
|
|
||||||
if (!showCatalogSelector) {
|
|
||||||
changeCatalog(null);
|
|
||||||
} else if (catalogs.length === 1) {
|
|
||||||
changeCatalog(catalogs[0]);
|
|
||||||
} else if (
|
|
||||||
!catalogs.find(
|
|
||||||
catalogOption => catalogRef.current === catalogOption.value,
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
changeCatalog(undefined);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (showCatalogSelector && isFetched) {
|
if (showCatalogSelector && isFetched) {
|
||||||
addSuccessToast('List refreshed');
|
addSuccessToast('List refreshed');
|
||||||
}
|
}
|
||||||
@@ -316,6 +331,49 @@ export function DatabaseSelector({
|
|||||||
|
|
||||||
const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS;
|
const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS;
|
||||||
|
|
||||||
|
// Handle catalog auto-selection when data changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (loadingCatalogs) return;
|
||||||
|
|
||||||
|
setErrorPayload(null);
|
||||||
|
|
||||||
|
if (!showCatalogSelector) {
|
||||||
|
// Only clear catalog if it's not already null
|
||||||
|
if (currentCatalog !== null) {
|
||||||
|
setCurrentCatalog(null);
|
||||||
|
if (onCatalogChange && catalogRef.current != null) {
|
||||||
|
onCatalogChange(undefined);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (catalogData && catalogData.length === 1) {
|
||||||
|
changeCatalog(catalogData[0]);
|
||||||
|
} else if (
|
||||||
|
catalogData &&
|
||||||
|
!catalogData.find(
|
||||||
|
catalogOption => catalogRef.current === catalogOption.value,
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
// Current selection not in list - try to apply default on first load
|
||||||
|
if (
|
||||||
|
defaultCatalog &&
|
||||||
|
appliedCatalogDefaultRef.current !== defaultCatalog
|
||||||
|
) {
|
||||||
|
const defaultOption = catalogData.find(
|
||||||
|
c => c.value === defaultCatalog,
|
||||||
|
);
|
||||||
|
if (defaultOption) {
|
||||||
|
appliedCatalogDefaultRef.current = defaultCatalog;
|
||||||
|
changeCatalog(defaultOption);
|
||||||
|
} else {
|
||||||
|
changeCatalog(undefined);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
changeCatalog(undefined);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [catalogData, defaultCatalog, loadingCatalogs, showCatalogSelector]);
|
||||||
|
|
||||||
function changeDatabase(
|
function changeDatabase(
|
||||||
value: { label: string; value: number },
|
value: { label: string; value: number },
|
||||||
database: DatabaseValue,
|
database: DatabaseValue,
|
||||||
@@ -326,6 +384,9 @@ export function DatabaseSelector({
|
|||||||
setCurrentDb(databaseWithId);
|
setCurrentDb(databaseWithId);
|
||||||
setCurrentCatalog(undefined);
|
setCurrentCatalog(undefined);
|
||||||
setCurrentSchema(undefined);
|
setCurrentSchema(undefined);
|
||||||
|
// Reset default refs so defaults can be applied for the new database
|
||||||
|
appliedCatalogDefaultRef.current = null;
|
||||||
|
appliedSchemaDefaultRef.current = null;
|
||||||
if (onDbChange) {
|
if (onDbChange) {
|
||||||
onDbChange(databaseWithId);
|
onDbChange(databaseWithId);
|
||||||
}
|
}
|
||||||
|
|||||||
204
superset-frontend/src/hooks/apiResources/catalogs.test.ts
Normal file
204
superset-frontend/src/hooks/apiResources/catalogs.test.ts
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
* or more contributor license agreements. See the NOTICE file
|
||||||
|
* distributed with this work for additional information
|
||||||
|
* regarding copyright ownership. The ASF licenses this file
|
||||||
|
* to you under the Apache License, Version 2.0 (the
|
||||||
|
* "License"); you may not use this file except in compliance
|
||||||
|
* with the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
import rison from 'rison';
|
||||||
|
import fetchMock from 'fetch-mock';
|
||||||
|
import { act, renderHook } from '@testing-library/react-hooks';
|
||||||
|
import {
|
||||||
|
createWrapper,
|
||||||
|
defaultStore as store,
|
||||||
|
} from 'spec/helpers/testing-library';
|
||||||
|
import { api } from 'src/hooks/apiResources/queryApi';
|
||||||
|
import { useCatalogs } from './catalogs';
|
||||||
|
|
||||||
|
const fakeApiResult = {
|
||||||
|
result: ['catalog_a', 'catalog_b'],
|
||||||
|
default: 'catalog_a',
|
||||||
|
};
|
||||||
|
const fakeApiResult2 = {
|
||||||
|
result: ['catalog_c', 'catalog_d'],
|
||||||
|
default: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
const expectedResult = fakeApiResult.result.map((value: string) => ({
|
||||||
|
value,
|
||||||
|
label: value,
|
||||||
|
title: value,
|
||||||
|
}));
|
||||||
|
const expectedResult2 = fakeApiResult2.result.map((value: string) => ({
|
||||||
|
value,
|
||||||
|
label: value,
|
||||||
|
title: value,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// eslint-disable-next-line no-restricted-globals -- TODO: Migrate from describe blocks
|
||||||
|
describe('useCatalogs hook', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
fetchMock.reset();
|
||||||
|
store.dispatch(api.util.resetApiState());
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns api response mapping json result with default catalog', async () => {
|
||||||
|
const expectDbId = 'db1';
|
||||||
|
const forceRefresh = false;
|
||||||
|
const catalogApiRoute = `glob:*/api/v1/database/${expectDbId}/catalogs/*`;
|
||||||
|
fetchMock.get(catalogApiRoute, fakeApiResult);
|
||||||
|
const onSuccess = jest.fn();
|
||||||
|
const { result, waitFor } = renderHook(
|
||||||
|
() =>
|
||||||
|
useCatalogs({
|
||||||
|
dbId: expectDbId,
|
||||||
|
onSuccess,
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
wrapper: createWrapper({
|
||||||
|
useRedux: true,
|
||||||
|
store,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
await waitFor(() =>
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(1),
|
||||||
|
);
|
||||||
|
expect(result.current.data).toEqual(expectedResult);
|
||||||
|
expect(result.current.defaultCatalog).toBe('catalog_a');
|
||||||
|
expect(
|
||||||
|
fetchMock.calls(
|
||||||
|
`end:/api/v1/database/${expectDbId}/catalogs/?q=${rison.encode({
|
||||||
|
force: forceRefresh,
|
||||||
|
})}`,
|
||||||
|
).length,
|
||||||
|
).toBe(1);
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(1);
|
||||||
|
act(() => {
|
||||||
|
result.current.refetch();
|
||||||
|
});
|
||||||
|
await waitFor(() =>
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(2),
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
fetchMock.calls(
|
||||||
|
`end:/api/v1/database/${expectDbId}/catalogs/?q=${rison.encode({
|
||||||
|
force: true,
|
||||||
|
})}`,
|
||||||
|
).length,
|
||||||
|
).toBe(1);
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(2);
|
||||||
|
expect(result.current.data).toEqual(expectedResult);
|
||||||
|
expect(result.current.defaultCatalog).toBe('catalog_a');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns cached data without api request', async () => {
|
||||||
|
const expectDbId = 'db1';
|
||||||
|
const catalogApiRoute = `glob:*/api/v1/database/${expectDbId}/catalogs/*`;
|
||||||
|
fetchMock.get(catalogApiRoute, fakeApiResult);
|
||||||
|
const { result, rerender, waitFor } = renderHook(
|
||||||
|
() =>
|
||||||
|
useCatalogs({
|
||||||
|
dbId: expectDbId,
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
wrapper: createWrapper({
|
||||||
|
useRedux: true,
|
||||||
|
store,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
|
expect(result.current.defaultCatalog).toBe('catalog_a');
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(1);
|
||||||
|
rerender();
|
||||||
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
|
expect(result.current.defaultCatalog).toBe('catalog_a');
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns refreshed data after switching databases', async () => {
|
||||||
|
const expectDbId = 'db1';
|
||||||
|
const catalogApiRoute = `glob:*/api/v1/database/*/catalogs/*`;
|
||||||
|
fetchMock.get(catalogApiRoute, url =>
|
||||||
|
url.includes(expectDbId) ? fakeApiResult : fakeApiResult2,
|
||||||
|
);
|
||||||
|
const onSuccess = jest.fn();
|
||||||
|
const { result, rerender, waitFor } = renderHook(
|
||||||
|
({ dbId }) =>
|
||||||
|
useCatalogs({
|
||||||
|
dbId,
|
||||||
|
onSuccess,
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
initialProps: { dbId: expectDbId },
|
||||||
|
wrapper: createWrapper({
|
||||||
|
useRedux: true,
|
||||||
|
store,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
|
expect(result.current.defaultCatalog).toBe('catalog_a');
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(1);
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
rerender({ dbId: 'db2' });
|
||||||
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult2));
|
||||||
|
expect(result.current.defaultCatalog).toBeNull();
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(2);
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
|
rerender({ dbId: expectDbId });
|
||||||
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
|
expect(result.current.defaultCatalog).toBe('catalog_a');
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(2);
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
|
// clean up cache
|
||||||
|
act(() => {
|
||||||
|
store.dispatch(api.util.invalidateTags(['Catalogs']));
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() =>
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(4),
|
||||||
|
);
|
||||||
|
expect(fetchMock.calls(catalogApiRoute)[2][0]).toContain(expectDbId);
|
||||||
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
|
expect(result.current.defaultCatalog).toBe('catalog_a');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('returns null defaultCatalog when API response has no default', async () => {
|
||||||
|
const expectDbId = 'db-no-default';
|
||||||
|
const catalogApiRoute = `glob:*/api/v1/database/${expectDbId}/catalogs/*`;
|
||||||
|
fetchMock.get(catalogApiRoute, { result: ['catalog1', 'catalog2'] });
|
||||||
|
const { result, waitFor } = renderHook(
|
||||||
|
() =>
|
||||||
|
useCatalogs({
|
||||||
|
dbId: expectDbId,
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
wrapper: createWrapper({
|
||||||
|
useRedux: true,
|
||||||
|
store,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
await waitFor(() =>
|
||||||
|
expect(fetchMock.calls(catalogApiRoute).length).toBe(1),
|
||||||
|
);
|
||||||
|
expect(result.current.defaultCatalog).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -30,27 +30,39 @@ export type CatalogOption = {
|
|||||||
export type FetchCatalogsQueryParams = {
|
export type FetchCatalogsQueryParams = {
|
||||||
dbId?: string | number;
|
dbId?: string | number;
|
||||||
forceRefresh: boolean;
|
forceRefresh: boolean;
|
||||||
onSuccess?: (data: CatalogOption[], isRefetched: boolean) => void;
|
onSuccess?: (
|
||||||
|
data: CatalogOption[],
|
||||||
|
isRefetched: boolean,
|
||||||
|
defaultCatalog: string | null,
|
||||||
|
) => void;
|
||||||
onError?: (error: ClientErrorObject) => void;
|
onError?: (error: ClientErrorObject) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>;
|
type Params = Omit<FetchCatalogsQueryParams, 'forceRefresh'>;
|
||||||
|
|
||||||
|
// Internal type for transformed API response
|
||||||
|
type CatalogsApiResponse = {
|
||||||
|
catalogs: CatalogOption[];
|
||||||
|
defaultCatalog: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
const catalogApi = api.injectEndpoints({
|
const catalogApi = api.injectEndpoints({
|
||||||
endpoints: builder => ({
|
endpoints: builder => ({
|
||||||
catalogs: builder.query<CatalogOption[], FetchCatalogsQueryParams>({
|
catalogs: builder.query<CatalogsApiResponse, FetchCatalogsQueryParams>({
|
||||||
providesTags: [{ type: 'Catalogs', id: 'LIST' }],
|
providesTags: [{ type: 'Catalogs', id: 'LIST' }],
|
||||||
query: ({ dbId, forceRefresh }) => ({
|
query: ({ dbId, forceRefresh }) => ({
|
||||||
endpoint: `/api/v1/database/${dbId}/catalogs/`,
|
endpoint: `/api/v1/database/${dbId}/catalogs/`,
|
||||||
urlParams: {
|
urlParams: {
|
||||||
force: forceRefresh,
|
force: forceRefresh,
|
||||||
},
|
},
|
||||||
transformResponse: ({ json }: JsonResponse) =>
|
transformResponse: ({ json }: JsonResponse) => ({
|
||||||
json.result.sort().map((value: string) => ({
|
catalogs: json.result.sort().map((value: string) => ({
|
||||||
value,
|
value,
|
||||||
label: value,
|
label: value,
|
||||||
title: value,
|
title: value,
|
||||||
})),
|
})),
|
||||||
|
defaultCatalog: json.default ?? null,
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
serializeQueryArgs: ({ queryArgs: { dbId } }) => ({
|
serializeQueryArgs: ({ queryArgs: { dbId } }) => ({
|
||||||
dbId,
|
dbId,
|
||||||
@@ -89,7 +101,11 @@ export function useCatalogs(options: Params) {
|
|||||||
if (dbId && (!result.currentData || forceRefresh)) {
|
if (dbId && (!result.currentData || forceRefresh)) {
|
||||||
trigger({ dbId, forceRefresh }).then(({ isSuccess, isError, data }) => {
|
trigger({ dbId, forceRefresh }).then(({ isSuccess, isError, data }) => {
|
||||||
if (isSuccess) {
|
if (isSuccess) {
|
||||||
onSuccess?.(data || EMPTY_CATALOGS, forceRefresh);
|
onSuccess?.(
|
||||||
|
data?.catalogs || EMPTY_CATALOGS,
|
||||||
|
forceRefresh,
|
||||||
|
data?.defaultCatalog ?? null,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if (isError) {
|
if (isError) {
|
||||||
onError?.(result.error as ClientErrorObject);
|
onError?.(result.error as ClientErrorObject);
|
||||||
@@ -110,5 +126,7 @@ export function useCatalogs(options: Params) {
|
|||||||
return {
|
return {
|
||||||
...result,
|
...result,
|
||||||
refetch,
|
refetch,
|
||||||
|
data: result.data?.catalogs,
|
||||||
|
defaultCatalog: result.data?.defaultCatalog ?? null,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,12 +28,15 @@ import { useSchemas } from './schemas';
|
|||||||
|
|
||||||
const fakeApiResult = {
|
const fakeApiResult = {
|
||||||
result: ['test schema 1', 'test schema b'],
|
result: ['test schema 1', 'test schema b'],
|
||||||
|
default: 'test schema 1',
|
||||||
};
|
};
|
||||||
const fakeApiResult2 = {
|
const fakeApiResult2 = {
|
||||||
result: ['test schema 2', 'test schema a'],
|
result: ['test schema 2', 'test schema a'],
|
||||||
|
default: null,
|
||||||
};
|
};
|
||||||
const fakeApiResult3 = {
|
const fakeApiResult3 = {
|
||||||
result: ['test schema 3', 'test schema c'],
|
result: ['test schema 3', 'test schema c'],
|
||||||
|
default: 'test schema c',
|
||||||
};
|
};
|
||||||
|
|
||||||
const expectedResult = fakeApiResult.result.map((value: string) => ({
|
const expectedResult = fakeApiResult.result.map((value: string) => ({
|
||||||
@@ -80,6 +83,7 @@ describe('useSchemas hook', () => {
|
|||||||
);
|
);
|
||||||
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(1));
|
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(1));
|
||||||
expect(result.current.data).toEqual(expectedResult);
|
expect(result.current.data).toEqual(expectedResult);
|
||||||
|
expect(result.current.defaultSchema).toBe('test schema 1');
|
||||||
expect(
|
expect(
|
||||||
fetchMock.calls(
|
fetchMock.calls(
|
||||||
`end:/api/v1/database/${expectDbId}/schemas/?q=${rison.encode({
|
`end:/api/v1/database/${expectDbId}/schemas/?q=${rison.encode({
|
||||||
@@ -120,9 +124,11 @@ describe('useSchemas hook', () => {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
|
expect(result.current.defaultSchema).toBe('test schema 1');
|
||||||
expect(fetchMock.calls(schemaApiRoute).length).toBe(1);
|
expect(fetchMock.calls(schemaApiRoute).length).toBe(1);
|
||||||
rerender();
|
rerender();
|
||||||
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
|
expect(result.current.defaultSchema).toBe('test schema 1');
|
||||||
expect(fetchMock.calls(schemaApiRoute).length).toBe(1);
|
expect(fetchMock.calls(schemaApiRoute).length).toBe(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -148,23 +154,20 @@ describe('useSchemas hook', () => {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
await waitFor(() =>
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
expect(result.current.currentData).toEqual(expectedResult),
|
expect(result.current.defaultSchema).toBe('test schema 1');
|
||||||
);
|
|
||||||
expect(fetchMock.calls(schemaApiRoute).length).toBe(1);
|
expect(fetchMock.calls(schemaApiRoute).length).toBe(1);
|
||||||
expect(onSuccess).toHaveBeenCalledTimes(1);
|
expect(onSuccess).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
rerender({ dbId: 'db2' });
|
rerender({ dbId: 'db2' });
|
||||||
await waitFor(() =>
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult2));
|
||||||
expect(result.current.currentData).toEqual(expectedResult2),
|
expect(result.current.defaultSchema).toBeNull();
|
||||||
);
|
|
||||||
expect(fetchMock.calls(schemaApiRoute).length).toBe(2);
|
expect(fetchMock.calls(schemaApiRoute).length).toBe(2);
|
||||||
expect(onSuccess).toHaveBeenCalledTimes(2);
|
expect(onSuccess).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
rerender({ dbId: expectDbId });
|
rerender({ dbId: expectDbId });
|
||||||
await waitFor(() =>
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
expect(result.current.currentData).toEqual(expectedResult),
|
expect(result.current.defaultSchema).toBe('test schema 1');
|
||||||
);
|
|
||||||
expect(fetchMock.calls(schemaApiRoute).length).toBe(2);
|
expect(fetchMock.calls(schemaApiRoute).length).toBe(2);
|
||||||
expect(onSuccess).toHaveBeenCalledTimes(2);
|
expect(onSuccess).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
@@ -175,9 +178,7 @@ describe('useSchemas hook', () => {
|
|||||||
|
|
||||||
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(4));
|
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(4));
|
||||||
expect(fetchMock.calls(schemaApiRoute)[2][0]).toContain(expectDbId);
|
expect(fetchMock.calls(schemaApiRoute)[2][0]).toContain(expectDbId);
|
||||||
await waitFor(() =>
|
await waitFor(() => expect(result.current.data).toEqual(expectedResult));
|
||||||
expect(result.current.currentData).toEqual(expectedResult),
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
test('returns correct schema list by a catalog', async () => {
|
test('returns correct schema list by a catalog', async () => {
|
||||||
@@ -208,14 +209,37 @@ describe('useSchemas hook', () => {
|
|||||||
|
|
||||||
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(1));
|
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(1));
|
||||||
expect(result.current.data).toEqual(expectedResult3);
|
expect(result.current.data).toEqual(expectedResult3);
|
||||||
|
expect(result.current.defaultSchema).toBe('test schema c');
|
||||||
expect(onSuccess).toHaveBeenCalledTimes(1);
|
expect(onSuccess).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
rerender({ dbId, catalog: 'catalog2' });
|
rerender({ dbId, catalog: 'catalog2' });
|
||||||
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(2));
|
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(2));
|
||||||
expect(result.current.data).toEqual(expectedResult2);
|
expect(result.current.data).toEqual(expectedResult2);
|
||||||
|
expect(result.current.defaultSchema).toBeNull();
|
||||||
|
|
||||||
rerender({ dbId, catalog: expectCatalog });
|
rerender({ dbId, catalog: expectCatalog });
|
||||||
expect(result.current.data).toEqual(expectedResult3);
|
expect(result.current.data).toEqual(expectedResult3);
|
||||||
|
expect(result.current.defaultSchema).toBe('test schema c');
|
||||||
expect(fetchMock.calls(schemaApiRoute).length).toBe(2);
|
expect(fetchMock.calls(schemaApiRoute).length).toBe(2);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('returns null defaultSchema when API response has no default', async () => {
|
||||||
|
const expectDbId = 'db-no-default';
|
||||||
|
const schemaApiRoute = `glob:*/api/v1/database/${expectDbId}/schemas/*`;
|
||||||
|
fetchMock.get(schemaApiRoute, { result: ['schema1', 'schema2'] });
|
||||||
|
const { result, waitFor } = renderHook(
|
||||||
|
() =>
|
||||||
|
useSchemas({
|
||||||
|
dbId: expectDbId,
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
wrapper: createWrapper({
|
||||||
|
useRedux: true,
|
||||||
|
store,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
await waitFor(() => expect(fetchMock.calls(schemaApiRoute).length).toBe(1));
|
||||||
|
expect(result.current.defaultSchema).toBeNull();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -31,15 +31,25 @@ export type FetchSchemasQueryParams = {
|
|||||||
dbId?: string | number;
|
dbId?: string | number;
|
||||||
catalog?: string;
|
catalog?: string;
|
||||||
forceRefresh: boolean;
|
forceRefresh: boolean;
|
||||||
onSuccess?: (data: SchemaOption[], isRefetched: boolean) => void;
|
onSuccess?: (
|
||||||
|
data: SchemaOption[],
|
||||||
|
isRefetched: boolean,
|
||||||
|
defaultSchema: string | null,
|
||||||
|
) => void;
|
||||||
onError?: (error: ClientErrorObject) => void;
|
onError?: (error: ClientErrorObject) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>;
|
type Params = Omit<FetchSchemasQueryParams, 'forceRefresh'>;
|
||||||
|
|
||||||
|
// Internal type for transformed API response
|
||||||
|
type SchemasApiResponse = {
|
||||||
|
schemas: SchemaOption[];
|
||||||
|
defaultSchema: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
const schemaApi = api.injectEndpoints({
|
const schemaApi = api.injectEndpoints({
|
||||||
endpoints: builder => ({
|
endpoints: builder => ({
|
||||||
schemas: builder.query<SchemaOption[], FetchSchemasQueryParams>({
|
schemas: builder.query<SchemasApiResponse, FetchSchemasQueryParams>({
|
||||||
providesTags: [{ type: 'Schemas', id: 'LIST' }],
|
providesTags: [{ type: 'Schemas', id: 'LIST' }],
|
||||||
query: ({ dbId, catalog, forceRefresh }) => ({
|
query: ({ dbId, catalog, forceRefresh }) => ({
|
||||||
endpoint: `/api/v1/database/${dbId}/schemas/`,
|
endpoint: `/api/v1/database/${dbId}/schemas/`,
|
||||||
@@ -48,12 +58,14 @@ const schemaApi = api.injectEndpoints({
|
|||||||
force: forceRefresh,
|
force: forceRefresh,
|
||||||
...(catalog !== undefined && { catalog }),
|
...(catalog !== undefined && { catalog }),
|
||||||
},
|
},
|
||||||
transformResponse: ({ json }: JsonResponse) =>
|
transformResponse: ({ json }: JsonResponse) => ({
|
||||||
json.result.sort().map((value: string) => ({
|
schemas: json.result.sort().map((value: string) => ({
|
||||||
value,
|
value,
|
||||||
label: value,
|
label: value,
|
||||||
title: value,
|
title: value,
|
||||||
})),
|
})),
|
||||||
|
defaultSchema: json.default ?? null,
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
serializeQueryArgs: ({ queryArgs: { dbId, catalog } }) => ({
|
serializeQueryArgs: ({ queryArgs: { dbId, catalog } }) => ({
|
||||||
dbId,
|
dbId,
|
||||||
@@ -98,7 +110,11 @@ export function useSchemas(options: Params) {
|
|||||||
trigger({ dbId, catalog, forceRefresh }).then(
|
trigger({ dbId, catalog, forceRefresh }).then(
|
||||||
({ isSuccess, isError, data }) => {
|
({ isSuccess, isError, data }) => {
|
||||||
if (isSuccess) {
|
if (isSuccess) {
|
||||||
onSuccess?.(data || EMPTY_SCHEMAS, forceRefresh);
|
onSuccess?.(
|
||||||
|
data?.schemas || EMPTY_SCHEMAS,
|
||||||
|
forceRefresh,
|
||||||
|
data?.defaultSchema ?? null,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if (isError) {
|
if (isError) {
|
||||||
onError?.(result.error as ClientErrorObject);
|
onError?.(result.error as ClientErrorObject);
|
||||||
@@ -120,5 +136,7 @@ export function useSchemas(options: Params) {
|
|||||||
return {
|
return {
|
||||||
...result,
|
...result,
|
||||||
refetch,
|
refetch,
|
||||||
|
data: result.currentData?.schemas,
|
||||||
|
defaultSchema: result.currentData?.defaultSchema ?? null,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ export const {
|
|||||||
export function useTables(options: Params) {
|
export function useTables(options: Params) {
|
||||||
const { dbId, catalog, schema, onSuccess, onError } = options || {};
|
const { dbId, catalog, schema, onSuccess, onError } = options || {};
|
||||||
const isMountedRef = useRef(false);
|
const isMountedRef = useRef(false);
|
||||||
const { currentData: schemaOptions, isFetching } = useSchemas({
|
const { data: schemaOptions, isFetching } = useSchemas({
|
||||||
dbId,
|
dbId,
|
||||||
catalog: catalog || undefined,
|
catalog: catalog || undefined,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -317,6 +317,32 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
"changed_by": [["id", BaseFilterRelatedUsers, lambda: []]],
|
"changed_by": [["id", BaseFilterRelatedUsers, lambda: []]],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_default_schema(
|
||||||
|
database: Database,
|
||||||
|
catalog: str | None,
|
||||||
|
accessible_schemas: set[str],
|
||||||
|
pk: int,
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Get the default schema for a database/catalog, with error handling.
|
||||||
|
|
||||||
|
Returns None if the default cannot be determined or is not accessible.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
default_schema = database.get_default_schema(catalog)
|
||||||
|
# Only include if user has access to it
|
||||||
|
if default_schema and default_schema not in accessible_schemas:
|
||||||
|
return None
|
||||||
|
return default_schema
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
logger.debug(
|
||||||
|
"Could not get default schema for database %s, catalog %s",
|
||||||
|
pk,
|
||||||
|
catalog,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
@expose("/<int:pk>/connection", methods=("GET",))
|
@expose("/<int:pk>/connection", methods=("GET",))
|
||||||
@protect()
|
@protect()
|
||||||
@safe
|
@safe
|
||||||
@@ -726,7 +752,18 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
database,
|
database,
|
||||||
catalogs,
|
catalogs,
|
||||||
)
|
)
|
||||||
return self.response(200, result=list(catalogs))
|
|
||||||
|
# Get default catalog with error handling
|
||||||
|
default_catalog = None
|
||||||
|
try:
|
||||||
|
default_catalog = database.get_default_catalog()
|
||||||
|
# Only include if user has access to it
|
||||||
|
if default_catalog and default_catalog not in catalogs:
|
||||||
|
default_catalog = None
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
logger.debug("Could not get default catalog for database %s", pk)
|
||||||
|
|
||||||
|
return self.response(200, result=list(catalogs), default=default_catalog)
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
return self.response(
|
return self.response(
|
||||||
500,
|
500,
|
||||||
@@ -795,23 +832,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||||||
catalog,
|
catalog,
|
||||||
schemas,
|
schemas,
|
||||||
)
|
)
|
||||||
|
default_schema = self._get_default_schema(database, catalog, schemas, pk)
|
||||||
|
|
||||||
if params.get("upload_allowed"):
|
if params.get("upload_allowed"):
|
||||||
if not database.allow_file_upload:
|
if not database.allow_file_upload:
|
||||||
return self.response(200, result=[])
|
return self.response(200, result=[], default=None)
|
||||||
if allowed_schemas := database.get_schema_access_for_file_upload():
|
if allowed_schemas := database.get_schema_access_for_file_upload():
|
||||||
# some databases might return the list of schemas in uppercase,
|
# some databases might return the list of schemas in uppercase,
|
||||||
# while the list of allowed schemas is manually inputted so
|
# while the list of allowed schemas is manually inputted so
|
||||||
# could be lowercase
|
# could be lowercase
|
||||||
allowed_schemas = {schema.lower() for schema in allowed_schemas}
|
allowed_schemas = {schema.lower() for schema in allowed_schemas}
|
||||||
|
filtered_schemas = [
|
||||||
|
schema
|
||||||
|
for schema in schemas
|
||||||
|
if schema.lower() in allowed_schemas
|
||||||
|
]
|
||||||
|
# Check if default is in filtered list
|
||||||
|
if default_schema and default_schema.lower() not in allowed_schemas:
|
||||||
|
default_schema = None
|
||||||
return self.response(
|
return self.response(
|
||||||
200,
|
200,
|
||||||
result=[
|
result=filtered_schemas,
|
||||||
schema
|
default=default_schema,
|
||||||
for schema in schemas
|
|
||||||
if schema.lower() in allowed_schemas
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
return self.response(200, result=list(schemas))
|
return self.response(200, result=list(schemas), default=default_schema)
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
return self.response(
|
return self.response(
|
||||||
500, message="There was an error connecting to the database"
|
500, message="There was an error connecting to the database"
|
||||||
|
|||||||
@@ -742,12 +742,22 @@ class SchemasResponseSchema(Schema):
|
|||||||
result = fields.List(
|
result = fields.List(
|
||||||
fields.String(metadata={"description": "A database schema name"})
|
fields.String(metadata={"description": "A database schema name"})
|
||||||
)
|
)
|
||||||
|
default = fields.String(
|
||||||
|
allow_none=True,
|
||||||
|
load_default=None,
|
||||||
|
metadata={"description": "The default schema for this database/catalog"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CatalogsResponseSchema(Schema):
|
class CatalogsResponseSchema(Schema):
|
||||||
result = fields.List(
|
result = fields.List(
|
||||||
fields.String(metadata={"description": "A database catalog name"})
|
fields.String(metadata={"description": "A database catalog name"})
|
||||||
)
|
)
|
||||||
|
default = fields.String(
|
||||||
|
allow_none=True,
|
||||||
|
load_default=None,
|
||||||
|
metadata={"description": "The default catalog for this database"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DatabaseTablesResponse(Schema):
|
class DatabaseTablesResponse(Schema):
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ def test_database_connection(
|
|||||||
"service_account_info": {
|
"service_account_info": {
|
||||||
"type": "service_account",
|
"type": "service_account",
|
||||||
"project_id": "black-sanctum-314419",
|
"project_id": "black-sanctum-314419",
|
||||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", # noqa: E501
|
||||||
"private_key": "XXXXXXXXXX",
|
"private_key": "XXXXXXXXXX",
|
||||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
|
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
|
||||||
"client_id": "114567578578109757129",
|
"client_id": "114567578578109757129",
|
||||||
@@ -2104,6 +2104,7 @@ def test_catalogs(
|
|||||||
"""
|
"""
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
database.get_all_catalog_names.return_value = {"db1", "db2"}
|
database.get_all_catalog_names.return_value = {"db1", "db2"}
|
||||||
|
database.get_default_catalog.return_value = "db2"
|
||||||
DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO") # noqa: N806
|
DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO") # noqa: N806
|
||||||
DatabaseDAO.find_by_id.return_value = database
|
DatabaseDAO.find_by_id.return_value = database
|
||||||
|
|
||||||
@@ -2115,7 +2116,7 @@ def test_catalogs(
|
|||||||
|
|
||||||
response = client.get("/api/v1/database/1/catalogs/")
|
response = client.get("/api/v1/database/1/catalogs/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json == {"result": ["db2"]}
|
assert response.json == {"result": ["db2"], "default": "db2"}
|
||||||
database.get_all_catalog_names.assert_called_with(
|
database.get_all_catalog_names.assert_called_with(
|
||||||
cache=database.catalog_cache_enabled,
|
cache=database.catalog_cache_enabled,
|
||||||
cache_timeout=database.catalog_cache_timeout,
|
cache_timeout=database.catalog_cache_timeout,
|
||||||
@@ -2187,6 +2188,7 @@ def test_schemas(
|
|||||||
|
|
||||||
database = mocker.MagicMock()
|
database = mocker.MagicMock()
|
||||||
database.get_all_schema_names.return_value = {"schema1", "schema2"}
|
database.get_all_schema_names.return_value = {"schema1", "schema2"}
|
||||||
|
database.get_default_schema.return_value = "schema2"
|
||||||
datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
|
datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
|
||||||
datamodel.get.return_value = database
|
datamodel.get.return_value = database
|
||||||
|
|
||||||
@@ -2198,7 +2200,7 @@ def test_schemas(
|
|||||||
|
|
||||||
response = client.get("/api/v1/database/1/schemas/")
|
response = client.get("/api/v1/database/1/schemas/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json == {"result": ["schema2"]}
|
assert response.json == {"result": ["schema2"], "default": "schema2"}
|
||||||
database.get_all_schema_names.assert_called_with(
|
database.get_all_schema_names.assert_called_with(
|
||||||
catalog=None,
|
catalog=None,
|
||||||
cache=database.schema_cache_enabled,
|
cache=database.schema_cache_enabled,
|
||||||
@@ -2274,3 +2276,184 @@ def test_schemas_with_oauth2(
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_catalogs_default_not_accessible(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that `default` is null when the default catalog is not accessible to the user.
|
||||||
|
"""
|
||||||
|
database = mocker.MagicMock()
|
||||||
|
database.get_all_catalog_names.return_value = {"db1", "db2"}
|
||||||
|
database.get_default_catalog.return_value = "db1" # default is db1
|
||||||
|
DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO") # noqa: N806
|
||||||
|
DatabaseDAO.find_by_id.return_value = database
|
||||||
|
|
||||||
|
security_manager = mocker.patch(
|
||||||
|
"superset.databases.api.security_manager",
|
||||||
|
new=mocker.MagicMock(),
|
||||||
|
)
|
||||||
|
# User only has access to db2, not the default db1
|
||||||
|
security_manager.get_catalogs_accessible_by_user.return_value = {"db2"}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/database/1/catalogs/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json == {"result": ["db2"], "default": None}
|
||||||
|
|
||||||
|
|
||||||
|
def test_catalogs_default_retrieval_fails(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that the endpoint still works when get_default_catalog fails.
|
||||||
|
"""
|
||||||
|
database = mocker.MagicMock()
|
||||||
|
database.get_all_catalog_names.return_value = {"db1", "db2"}
|
||||||
|
database.get_default_catalog.side_effect = Exception("Connection failed")
|
||||||
|
DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO") # noqa: N806
|
||||||
|
DatabaseDAO.find_by_id.return_value = database
|
||||||
|
|
||||||
|
security_manager = mocker.patch(
|
||||||
|
"superset.databases.api.security_manager",
|
||||||
|
new=mocker.MagicMock(),
|
||||||
|
)
|
||||||
|
security_manager.get_catalogs_accessible_by_user.return_value = {"db1", "db2"}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/database/1/catalogs/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Result should still be returned, default is null due to error
|
||||||
|
assert set(response.json["result"]) == {"db1", "db2"}
|
||||||
|
assert response.json["default"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_schemas_default_not_accessible(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that `default` is null when the default schema is not accessible to the user.
|
||||||
|
"""
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
|
||||||
|
database = mocker.MagicMock()
|
||||||
|
database.get_all_schema_names.return_value = {"public", "private"}
|
||||||
|
database.get_default_schema.return_value = "public" # default is public
|
||||||
|
datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
|
||||||
|
datamodel.get.return_value = database
|
||||||
|
|
||||||
|
security_manager = mocker.patch(
|
||||||
|
"superset.databases.api.security_manager",
|
||||||
|
new=mocker.MagicMock(),
|
||||||
|
)
|
||||||
|
# User only has access to private, not the default public
|
||||||
|
security_manager.get_schemas_accessible_by_user.return_value = {"private"}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/database/1/schemas/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json == {"result": ["private"], "default": None}
|
||||||
|
|
||||||
|
|
||||||
|
def test_schemas_default_retrieval_fails(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that the endpoint still works when get_default_schema fails.
|
||||||
|
"""
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
|
||||||
|
database = mocker.MagicMock()
|
||||||
|
database.get_all_schema_names.return_value = {"public", "private"}
|
||||||
|
database.get_default_schema.side_effect = Exception("Connection failed")
|
||||||
|
datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
|
||||||
|
datamodel.get.return_value = database
|
||||||
|
|
||||||
|
security_manager = mocker.patch(
|
||||||
|
"superset.databases.api.security_manager",
|
||||||
|
new=mocker.MagicMock(),
|
||||||
|
)
|
||||||
|
security_manager.get_schemas_accessible_by_user.return_value = {"public", "private"}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/database/1/schemas/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Result should still be returned, default is null due to error
|
||||||
|
assert set(response.json["result"]) == {"public", "private"}
|
||||||
|
assert response.json["default"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_schemas_default_with_upload_allowed(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that default schema is returned correctly with upload_allowed filter.
|
||||||
|
"""
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
|
||||||
|
database = mocker.MagicMock()
|
||||||
|
database.get_all_schema_names.return_value = {"public", "uploads", "private"}
|
||||||
|
database.get_default_schema.return_value = "public"
|
||||||
|
database.allow_file_upload = True
|
||||||
|
database.get_schema_access_for_file_upload.return_value = ["uploads", "public"]
|
||||||
|
datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
|
||||||
|
datamodel.get.return_value = database
|
||||||
|
|
||||||
|
security_manager = mocker.patch(
|
||||||
|
"superset.databases.api.security_manager",
|
||||||
|
new=mocker.MagicMock(),
|
||||||
|
)
|
||||||
|
security_manager.get_schemas_accessible_by_user.return_value = {
|
||||||
|
"public",
|
||||||
|
"uploads",
|
||||||
|
"private",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/database/1/schemas/?q=(upload_allowed:!t)")
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Only upload-allowed schemas should be returned
|
||||||
|
assert set(response.json["result"]) == {"public", "uploads"}
|
||||||
|
# Default should be public since it's in the allowed list
|
||||||
|
assert response.json["default"] == "public"
|
||||||
|
|
||||||
|
|
||||||
|
def test_schemas_default_not_in_upload_allowed(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that default schema is null when not in upload_allowed schemas.
|
||||||
|
"""
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
|
||||||
|
database = mocker.MagicMock()
|
||||||
|
database.get_all_schema_names.return_value = {"public", "uploads", "private"}
|
||||||
|
database.get_default_schema.return_value = "private" # default not in allowed list
|
||||||
|
database.allow_file_upload = True
|
||||||
|
database.get_schema_access_for_file_upload.return_value = ["uploads", "public"]
|
||||||
|
datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
|
||||||
|
datamodel.get.return_value = database
|
||||||
|
|
||||||
|
security_manager = mocker.patch(
|
||||||
|
"superset.databases.api.security_manager",
|
||||||
|
new=mocker.MagicMock(),
|
||||||
|
)
|
||||||
|
security_manager.get_schemas_accessible_by_user.return_value = {
|
||||||
|
"public",
|
||||||
|
"uploads",
|
||||||
|
"private",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/database/1/schemas/?q=(upload_allowed:!t)")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert set(response.json["result"]) == {"public", "uploads"}
|
||||||
|
# Default should be null since "private" is not in allowed list
|
||||||
|
assert response.json["default"] is None
|
||||||
|
|||||||
Reference in New Issue
Block a user