mirror of
https://github.com/apache/superset.git
synced 2026-04-30 05:24:31 +00:00
Compare commits
19 Commits
embedded-e
...
new-dar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40787496e2 | ||
|
|
b180cc4fca | ||
|
|
78e4f55ede | ||
|
|
853d57337d | ||
|
|
3ad694e5a9 | ||
|
|
b66729ad08 | ||
|
|
c5d2329297 | ||
|
|
ca7635dfc2 | ||
|
|
0aaa13ab79 | ||
|
|
902509b1f0 | ||
|
|
2e7df4614c | ||
|
|
3554325104 | ||
|
|
b469b01e0f | ||
|
|
4d9378a818 | ||
|
|
0141bdd2b0 | ||
|
|
808ba668ff | ||
|
|
e9fc7c6f6c | ||
|
|
5c61c40704 | ||
|
|
57a210f7d6 |
@@ -105,7 +105,7 @@ class CeleryConfig:
|
||||
|
||||
CELERY_CONFIG = CeleryConfig
|
||||
|
||||
FEATURE_FLAGS = {"ALERT_REPORTS": True}
|
||||
FEATURE_FLAGS = {"ALERT_REPORTS": True, "DATA_ACCESS_RULES": True}
|
||||
ALERT_REPORTS_NOTIFICATION_DRY_RUN = True
|
||||
WEBDRIVER_BASEURL = f"http://superset_app{os.environ.get('SUPERSET_APP_ROOT', '/')}/" # When using docker compose baseurl should be http://superset_nginx{ENV{BASEPATH}}/ # noqa: E501
|
||||
# The base URL for the email report hyperlinks.
|
||||
|
||||
@@ -134,6 +134,8 @@ export function mapRows<T extends object>(
|
||||
) {
|
||||
return rows.map(row => {
|
||||
prepareRow(row);
|
||||
return { rowId: row.id, ...row.original, ...row.getRowProps() };
|
||||
// Spread getRowProps first so data properties from row.original take precedence
|
||||
// This prevents HTML attributes like `role: "row"` from overwriting data properties
|
||||
return { ...row.getRowProps(), rowId: row.id, ...row.original };
|
||||
});
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ export enum FeatureFlag {
|
||||
CssTemplates = 'CSS_TEMPLATES',
|
||||
DashboardVirtualization = 'DASHBOARD_VIRTUALIZATION',
|
||||
DashboardRbac = 'DASHBOARD_RBAC',
|
||||
DataAccessRules = 'DATA_ACCESS_RULES',
|
||||
DatapanelClosedByDefault = 'DATAPANEL_CLOSED_BY_DEFAULT',
|
||||
DateRangeTimeshiftsEnabled = 'DATE_RANGE_TIMESHIFTS_ENABLED',
|
||||
/** @deprecated */
|
||||
|
||||
@@ -0,0 +1,496 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { SupersetClient, t } from '@superset-ui/core';
|
||||
import { css, styled } from '@apache-superset/core/ui';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { ModalTitleWithIcon } from 'src/components/ModalTitleWithIcon';
|
||||
import {
|
||||
Modal,
|
||||
AsyncSelect,
|
||||
InfoTooltip,
|
||||
Input,
|
||||
Collapse,
|
||||
} from '@superset-ui/core/components';
|
||||
import rison from 'rison';
|
||||
import { useSingleViewResource } from 'src/views/CRUD/hooks';
|
||||
import { DataAccessRuleObject } from './types';
|
||||
import PermissionsTree from './PermissionsTree';
|
||||
import type { PermissionsPayload } from './PermissionsTree/types';
|
||||
|
||||
const StyledModal = styled(Modal)`
|
||||
max-width: 1200px;
|
||||
min-width: min-content;
|
||||
width: 100%;
|
||||
.ant-modal-footer {
|
||||
white-space: nowrap;
|
||||
}
|
||||
`;
|
||||
|
||||
const StyledSectionContainer = styled.div`
|
||||
${({ theme }) => css`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: ${theme.sizeUnit * 3}px ${theme.sizeUnit * 4}px
|
||||
${theme.sizeUnit * 2}px;
|
||||
|
||||
label,
|
||||
.control-label {
|
||||
display: flex;
|
||||
font-size: ${theme.fontSizeSM}px;
|
||||
color: ${theme.colorTextLabel};
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.info-solid-small {
|
||||
vertical-align: middle;
|
||||
padding-bottom: ${theme.sizeUnit / 2}px;
|
||||
}
|
||||
`}
|
||||
`;
|
||||
|
||||
const StyledInputContainer = styled.div`
|
||||
${({ theme }) => css`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
margin: ${theme.sizeUnit}px;
|
||||
margin-bottom: ${theme.sizeUnit * 4}px;
|
||||
|
||||
.input-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
||||
> div {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
input,
|
||||
textarea {
|
||||
flex: 1 1 auto;
|
||||
}
|
||||
|
||||
.required {
|
||||
margin-left: ${theme.sizeUnit / 2}px;
|
||||
color: ${theme.colorErrorText};
|
||||
}
|
||||
`}
|
||||
`;
|
||||
|
||||
const StyledTextArea = styled(Input.TextArea)`
|
||||
resize: vertical;
|
||||
margin-top: ${({ theme }) => theme.sizeUnit}px;
|
||||
font-family: monospace;
|
||||
`;
|
||||
|
||||
const StyledCollapse = styled(Collapse)`
|
||||
margin-top: ${({ theme }) => theme.sizeUnit * 2}px;
|
||||
background: transparent;
|
||||
|
||||
.ant-collapse-header {
|
||||
padding: ${({ theme }) => theme.sizeUnit}px 0 !important;
|
||||
}
|
||||
|
||||
.ant-collapse-content-box {
|
||||
padding: ${({ theme }) => theme.sizeUnit}px 0 !important;
|
||||
}
|
||||
`;
|
||||
|
||||
export interface DataAccessRuleModalProps {
|
||||
rule: DataAccessRuleObject | null;
|
||||
addSuccessToast: (msg: string) => void;
|
||||
addDangerToast: (msg: string) => void;
|
||||
onAdd?: (rule?: DataAccessRuleObject) => void;
|
||||
onHide: () => void;
|
||||
show: boolean;
|
||||
}
|
||||
|
||||
const DEFAULT_RULE: DataAccessRuleObject = {
|
||||
name: '',
|
||||
description: '',
|
||||
role_id: 0,
|
||||
rule: JSON.stringify(
|
||||
{
|
||||
allowed: [],
|
||||
denied: [],
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
};
|
||||
|
||||
type SelectValue = {
|
||||
value: number;
|
||||
label: string;
|
||||
};
|
||||
|
||||
function DataAccessRuleModal(props: DataAccessRuleModalProps) {
|
||||
const { rule, addDangerToast, addSuccessToast, onHide, show } = props;
|
||||
|
||||
const [currentRule, setCurrentRule] = useState<DataAccessRuleObject>({
|
||||
...DEFAULT_RULE,
|
||||
});
|
||||
const [selectedRole, setSelectedRole] = useState<SelectValue | null>(null);
|
||||
const [disableSave, setDisableSave] = useState<boolean>(true);
|
||||
const [jsonError, setJsonError] = useState<string | null>(null);
|
||||
const [permissionsPayload, setPermissionsPayload] =
|
||||
useState<PermissionsPayload>({ allowed: [], denied: [] });
|
||||
const [showAdvanced, setShowAdvanced] = useState<string[]>([]);
|
||||
|
||||
const isEditMode = rule !== null;
|
||||
|
||||
const {
|
||||
state: { loading, resource, error: fetchError },
|
||||
fetchResource,
|
||||
createResource,
|
||||
updateResource,
|
||||
clearError,
|
||||
} = useSingleViewResource<DataAccessRuleObject>(
|
||||
'dar',
|
||||
t('data access rule'),
|
||||
addDangerToast,
|
||||
);
|
||||
|
||||
const updateRuleState = (name: string, value: any) => {
|
||||
setCurrentRule(currentRuleData => ({
|
||||
...currentRuleData,
|
||||
[name]: value,
|
||||
}));
|
||||
};
|
||||
|
||||
// Validate form
|
||||
const validate = useCallback(() => {
|
||||
// Check role is selected
|
||||
if (!selectedRole?.value) {
|
||||
setDisableSave(true);
|
||||
return;
|
||||
}
|
||||
|
||||
// If advanced mode is open, validate JSON
|
||||
if (showAdvanced.includes('advanced')) {
|
||||
try {
|
||||
const parsed = JSON.parse(currentRule.rule);
|
||||
if (typeof parsed !== 'object' || parsed === null) {
|
||||
setJsonError(t('Rule must be a JSON object'));
|
||||
setDisableSave(true);
|
||||
return;
|
||||
}
|
||||
setJsonError(null);
|
||||
} catch {
|
||||
setJsonError(t('Invalid JSON'));
|
||||
setDisableSave(true);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
setDisableSave(false);
|
||||
}, [currentRule.rule, selectedRole, showAdvanced]);
|
||||
|
||||
// Initialize
|
||||
useEffect(() => {
|
||||
if (!isEditMode) {
|
||||
setCurrentRule({ ...DEFAULT_RULE });
|
||||
setSelectedRole(null);
|
||||
setPermissionsPayload({ allowed: [], denied: [] });
|
||||
setShowAdvanced([]);
|
||||
} else if (rule?.id !== null && !loading && !fetchError) {
|
||||
fetchResource(rule.id as number);
|
||||
}
|
||||
}, [rule]);
|
||||
|
||||
useEffect(() => {
|
||||
if (resource) {
|
||||
const ruleStr =
|
||||
typeof resource.rule === 'string'
|
||||
? resource.rule
|
||||
: JSON.stringify(resource.rule, null, 2);
|
||||
|
||||
setCurrentRule({
|
||||
...resource,
|
||||
id: rule?.id,
|
||||
rule: ruleStr,
|
||||
});
|
||||
|
||||
// Parse rule into permissions payload
|
||||
try {
|
||||
const parsed =
|
||||
typeof resource.rule === 'string'
|
||||
? JSON.parse(resource.rule)
|
||||
: resource.rule;
|
||||
setPermissionsPayload({
|
||||
allowed: parsed.allowed || [],
|
||||
denied: parsed.denied || [],
|
||||
});
|
||||
} catch {
|
||||
setPermissionsPayload({ allowed: [], denied: [] });
|
||||
}
|
||||
|
||||
if (resource.role) {
|
||||
setSelectedRole({
|
||||
value: resource.role.id,
|
||||
label: resource.role.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
}, [resource]);
|
||||
|
||||
useEffect(() => {
|
||||
validate();
|
||||
}, [validate]);
|
||||
|
||||
const onRuleChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
updateRuleState('rule', event.target.value);
|
||||
};
|
||||
|
||||
const onRoleChange = (value: SelectValue | null) => {
|
||||
setSelectedRole(value);
|
||||
if (value) {
|
||||
updateRuleState('role_id', value.value);
|
||||
}
|
||||
};
|
||||
|
||||
const onPermissionsChange = useCallback((payload: PermissionsPayload) => {
|
||||
setPermissionsPayload(payload);
|
||||
// Update JSON representation
|
||||
const ruleJson = JSON.stringify(payload, null, 2);
|
||||
updateRuleState('rule', ruleJson);
|
||||
}, []);
|
||||
|
||||
const onAdvancedChange = (keys: string | string[]) => {
|
||||
const keyArray = Array.isArray(keys) ? keys : [keys];
|
||||
setShowAdvanced(keyArray);
|
||||
|
||||
// When opening advanced, sync JSON from permissions
|
||||
if (keyArray.includes('advanced')) {
|
||||
const ruleJson = JSON.stringify(permissionsPayload, null, 2);
|
||||
updateRuleState('rule', ruleJson);
|
||||
}
|
||||
// When closing advanced, sync permissions from JSON
|
||||
if (!keyArray.includes('advanced') && showAdvanced.includes('advanced')) {
|
||||
try {
|
||||
const parsed = JSON.parse(currentRule.rule);
|
||||
setPermissionsPayload({
|
||||
allowed: parsed.allowed || [],
|
||||
denied: parsed.denied || [],
|
||||
});
|
||||
} catch {
|
||||
// Keep existing payload if JSON is invalid
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const hide = () => {
|
||||
clearError();
|
||||
setCurrentRule({ ...DEFAULT_RULE });
|
||||
setSelectedRole(null);
|
||||
setJsonError(null);
|
||||
setPermissionsPayload({ allowed: [], denied: [] });
|
||||
setShowAdvanced([]);
|
||||
onHide();
|
||||
};
|
||||
|
||||
const onSave = () => {
|
||||
const data = {
|
||||
name: currentRule.name || null,
|
||||
description: currentRule.description || null,
|
||||
role_id: selectedRole?.value,
|
||||
rule: currentRule.rule,
|
||||
};
|
||||
|
||||
if (isEditMode && currentRule.id) {
|
||||
updateResource(currentRule.id, data).then(response => {
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
addSuccessToast(t('Rule updated'));
|
||||
hide();
|
||||
});
|
||||
} else {
|
||||
createResource(data).then(response => {
|
||||
if (!response) return;
|
||||
addSuccessToast(t('Rule added'));
|
||||
hide();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const loadRoleOptions = useMemo(
|
||||
() =>
|
||||
(input = '', page: number, pageSize: number) => {
|
||||
const query = rison.encode({
|
||||
filter: input,
|
||||
page,
|
||||
page_size: pageSize,
|
||||
});
|
||||
return SupersetClient.get({
|
||||
endpoint: `/api/v1/dar/related/role?q=${query}`,
|
||||
}).then(response => {
|
||||
const list = response.json.result.map(
|
||||
(item: { value: number; text: string }) => ({
|
||||
label: item.text,
|
||||
value: item.value,
|
||||
}),
|
||||
);
|
||||
return { data: list, totalCount: response.json.count };
|
||||
});
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
return (
|
||||
<StyledModal
|
||||
className="no-content-padding"
|
||||
responsive
|
||||
show={show}
|
||||
onHide={hide}
|
||||
primaryButtonName={isEditMode ? t('Save') : t('Add')}
|
||||
disablePrimaryButton={disableSave}
|
||||
onHandledPrimaryAction={onSave}
|
||||
width="50%"
|
||||
maxWidth="800px"
|
||||
title={
|
||||
<ModalTitleWithIcon
|
||||
isEditMode={isEditMode}
|
||||
title={isEditMode ? t('Edit Data Access Rule') : t('Add Data Access Rule')}
|
||||
data-test="dar-modal-title"
|
||||
/>
|
||||
}
|
||||
>
|
||||
<StyledSectionContainer>
|
||||
<div className="main-section">
|
||||
<StyledInputContainer>
|
||||
<div className="control-label">
|
||||
{t('Role')} <span className="required">*</span>
|
||||
<InfoTooltip
|
||||
tooltip={t(
|
||||
'Select the role this rule applies to. Each role can have multiple data access rules.',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="input-container">
|
||||
<AsyncSelect
|
||||
ariaLabel={t('Role')}
|
||||
onChange={onRoleChange}
|
||||
value={selectedRole}
|
||||
options={loadRoleOptions}
|
||||
data-test="role-select"
|
||||
/>
|
||||
</div>
|
||||
</StyledInputContainer>
|
||||
|
||||
<StyledInputContainer>
|
||||
<div className="control-label">
|
||||
{t('Name')}
|
||||
<InfoTooltip
|
||||
tooltip={t('Optional name to help identify this rule.')}
|
||||
/>
|
||||
</div>
|
||||
<div className="input-container">
|
||||
<Input
|
||||
name="name"
|
||||
value={currentRule.name || ''}
|
||||
onChange={e => updateRuleState('name', e.target.value)}
|
||||
placeholder={t('e.g., Sales team access')}
|
||||
data-test="rule-name"
|
||||
/>
|
||||
</div>
|
||||
</StyledInputContainer>
|
||||
|
||||
<StyledInputContainer>
|
||||
<div className="control-label">
|
||||
{t('Description')}
|
||||
<InfoTooltip
|
||||
tooltip={t('Optional description of what this rule grants or restricts.')}
|
||||
/>
|
||||
</div>
|
||||
<div className="input-container">
|
||||
<Input.TextArea
|
||||
name="description"
|
||||
value={currentRule.description || ''}
|
||||
onChange={e => updateRuleState('description', e.target.value)}
|
||||
placeholder={t('Describe the purpose of this rule...')}
|
||||
rows={2}
|
||||
data-test="rule-description"
|
||||
/>
|
||||
</div>
|
||||
</StyledInputContainer>
|
||||
|
||||
<StyledInputContainer>
|
||||
<div className="control-label">
|
||||
{t('Table Permissions')}
|
||||
<InfoTooltip
|
||||
tooltip={t(
|
||||
'Select databases, schemas, and tables to allow or deny access. Click the icons to toggle between Allow (green), Deny (red), and Inherit (gray).',
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<PermissionsTree
|
||||
value={permissionsPayload}
|
||||
onChange={onPermissionsChange}
|
||||
/>
|
||||
</StyledInputContainer>
|
||||
|
||||
<StyledCollapse
|
||||
activeKey={showAdvanced}
|
||||
onChange={onAdvancedChange}
|
||||
ghost
|
||||
items={[
|
||||
{
|
||||
key: 'advanced',
|
||||
label: t('Advanced: Edit JSON directly'),
|
||||
children: (
|
||||
<>
|
||||
<StyledInputContainer>
|
||||
<div className="control-label">
|
||||
{t('Rule (JSON)')}
|
||||
<InfoTooltip
|
||||
tooltip={t(
|
||||
`Define the access rule as a JSON document with "allowed" and "denied" arrays. Each entry specifies database, catalog, schema, table, and optional RLS/CLS configurations.`,
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="input-container">
|
||||
<StyledTextArea
|
||||
rows={15}
|
||||
name="rule"
|
||||
value={currentRule.rule}
|
||||
onChange={onRuleChange}
|
||||
status={jsonError ? 'error' : undefined}
|
||||
data-test="rule-json"
|
||||
/>
|
||||
</div>
|
||||
{jsonError && (
|
||||
<div style={{ color: 'red', marginTop: '4px' }}>
|
||||
{jsonError}
|
||||
</div>
|
||||
)}
|
||||
</StyledInputContainer>
|
||||
|
||||
</>
|
||||
),
|
||||
},
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
</StyledSectionContainer>
|
||||
</StyledModal>
|
||||
);
|
||||
}
|
||||
|
||||
export default DataAccessRuleModal;
|
||||
@@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { SupersetClient } from '@superset-ui/core';
|
||||
import rison from 'rison';
|
||||
|
||||
export interface DatabaseInfo {
|
||||
id: number;
|
||||
database_name: string;
|
||||
}
|
||||
|
||||
export interface FetchResult<T> {
|
||||
result: T[];
|
||||
count: number;
|
||||
}
|
||||
|
||||
export async function fetchDatabases(): Promise<DatabaseInfo[]> {
|
||||
const query = rison.encode({
|
||||
columns: ['id', 'database_name'],
|
||||
page_size: 1000,
|
||||
});
|
||||
const response = await SupersetClient.get({
|
||||
endpoint: `/api/v1/database/?q=${query}`,
|
||||
});
|
||||
return response.json.result;
|
||||
}
|
||||
|
||||
export async function fetchCatalogs(
|
||||
databaseId: number,
|
||||
filter?: string,
|
||||
page?: number,
|
||||
pageSize?: number,
|
||||
): Promise<FetchResult<string>> {
|
||||
const params: Record<string, unknown> = {};
|
||||
if (filter) params.filter = filter;
|
||||
if (page !== undefined) params.page = page;
|
||||
if (pageSize !== undefined) params.page_size = pageSize;
|
||||
|
||||
const query = rison.encode(params);
|
||||
const response = await SupersetClient.get({
|
||||
endpoint: `/api/v1/database/${databaseId}/catalogs/?q=${query}`,
|
||||
});
|
||||
return {
|
||||
result: response.json.result,
|
||||
count: response.json.count ?? response.json.result.length,
|
||||
};
|
||||
}
|
||||
|
||||
export async function fetchSchemas(
|
||||
databaseId: number,
|
||||
catalog?: string,
|
||||
filter?: string,
|
||||
page?: number,
|
||||
pageSize?: number,
|
||||
): Promise<FetchResult<string>> {
|
||||
const params: Record<string, unknown> = {};
|
||||
if (catalog) params.catalog = catalog;
|
||||
if (filter) params.filter = filter;
|
||||
if (page !== undefined) params.page = page;
|
||||
if (pageSize !== undefined) params.page_size = pageSize;
|
||||
|
||||
const query = rison.encode(params);
|
||||
const response = await SupersetClient.get({
|
||||
endpoint: `/api/v1/database/${databaseId}/schemas/?q=${query}`,
|
||||
});
|
||||
return {
|
||||
result: response.json.result,
|
||||
count: response.json.count ?? response.json.result.length,
|
||||
};
|
||||
}
|
||||
|
||||
export interface TableInfo {
|
||||
value: string;
|
||||
type: 'table' | 'view' | 'materialized_view';
|
||||
extra?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export async function fetchTables(
|
||||
databaseId: number,
|
||||
schemaName: string,
|
||||
catalogName?: string,
|
||||
filter?: string,
|
||||
page?: number,
|
||||
pageSize?: number,
|
||||
): Promise<FetchResult<TableInfo>> {
|
||||
const params: Record<string, unknown> = {
|
||||
schema_name: schemaName,
|
||||
};
|
||||
if (catalogName) params.catalog_name = catalogName;
|
||||
if (filter) params.filter = filter;
|
||||
if (page !== undefined) params.page = page;
|
||||
if (pageSize !== undefined) params.page_size = pageSize;
|
||||
|
||||
const query = rison.encode(params);
|
||||
const response = await SupersetClient.get({
|
||||
endpoint: `/api/v1/database/${databaseId}/tables/?q=${query}`,
|
||||
});
|
||||
return {
|
||||
result: response.json.result,
|
||||
count: response.json.count,
|
||||
};
|
||||
}
|
||||
|
||||
export interface ColumnInfo {
|
||||
name: string;
|
||||
type: string;
|
||||
nullable?: boolean;
|
||||
default?: string;
|
||||
}
|
||||
|
||||
export async function fetchColumns(
|
||||
databaseId: number,
|
||||
tableName: string,
|
||||
schemaName?: string,
|
||||
catalogName?: string,
|
||||
): Promise<ColumnInfo[]> {
|
||||
const params: Record<string, string> = {
|
||||
name: tableName,
|
||||
};
|
||||
if (schemaName) params.schema = schemaName;
|
||||
if (catalogName) params.catalog = catalogName;
|
||||
|
||||
const queryString = new URLSearchParams(params).toString();
|
||||
const response = await SupersetClient.get({
|
||||
endpoint: `/api/v1/database/${databaseId}/table_metadata/?${queryString}`,
|
||||
});
|
||||
return response.json.columns || [];
|
||||
}
|
||||
@@ -0,0 +1,832 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { t } from '@superset-ui/core';
|
||||
import { css, styled } from '@apache-superset/core/ui';
|
||||
import { Tree, Tooltip, Spin, Button, Select, Input } from 'antd';
|
||||
import {
|
||||
CheckCircleOutlined,
|
||||
CloseCircleOutlined,
|
||||
MinusCircleOutlined,
|
||||
DatabaseOutlined,
|
||||
FolderOutlined,
|
||||
TableOutlined,
|
||||
ColumnHeightOutlined,
|
||||
NumberOutlined,
|
||||
EyeInvisibleOutlined,
|
||||
StopOutlined,
|
||||
StarOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import type { TreeProps } from 'antd';
|
||||
import type {
|
||||
PermissionNode,
|
||||
TreeState,
|
||||
PermissionsPayload,
|
||||
CLSAction,
|
||||
RLSRule,
|
||||
} from './types';
|
||||
import {
|
||||
fetchDatabases,
|
||||
fetchCatalogs,
|
||||
fetchSchemas,
|
||||
fetchTables,
|
||||
fetchColumns,
|
||||
type DatabaseInfo,
|
||||
} from './api';
|
||||
import {
|
||||
makeKey,
|
||||
makeClsKey,
|
||||
getEffectiveState,
|
||||
getExplicitState,
|
||||
cyclePermissionState,
|
||||
updateTreeData,
|
||||
generatePermissionsPayload,
|
||||
loadPermissionsFromPayload,
|
||||
countDescendantPermissions,
|
||||
} from './utils';
|
||||
|
||||
const DEFAULT_PAGE_SIZE = 50;
|
||||
|
||||
const StyledContainer = styled.div`
|
||||
${({ theme }) => css`
|
||||
.search-input {
|
||||
margin-bottom: ${theme.sizeUnit * 2}px;
|
||||
}
|
||||
|
||||
.tree-container {
|
||||
border: 1px solid ${theme.colorBorder};
|
||||
border-radius: ${theme.borderRadius}px;
|
||||
padding: ${theme.sizeUnit * 2}px;
|
||||
max-height: 500px;
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
.ant-tree-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: ${theme.sizeUnit}px;
|
||||
}
|
||||
|
||||
.permission-icon {
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.permission-icon-allow {
|
||||
color: ${theme.colorSuccess};
|
||||
}
|
||||
|
||||
.permission-icon-deny {
|
||||
color: ${theme.colorError};
|
||||
}
|
||||
|
||||
.permission-icon-inherit {
|
||||
color: ${theme.colorTextSecondary};
|
||||
}
|
||||
|
||||
.node-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: ${theme.sizeUnit}px;
|
||||
}
|
||||
|
||||
.load-more-btn {
|
||||
margin-left: ${theme.sizeUnit * 3}px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.node-count {
|
||||
font-size: 11px;
|
||||
color: ${theme.colorTextSecondary};
|
||||
margin-left: ${theme.sizeUnit}px;
|
||||
}
|
||||
|
||||
.cls-select {
|
||||
width: 90px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.column-node {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
width: 100%;
|
||||
max-width: 300px;
|
||||
}
|
||||
|
||||
.column-name {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: ${theme.sizeUnit}px;
|
||||
}
|
||||
|
||||
.column-type {
|
||||
font-size: 11px;
|
||||
color: ${theme.colorTextSecondary};
|
||||
}
|
||||
|
||||
.rls-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: ${theme.sizeUnit}px;
|
||||
padding: ${theme.sizeUnit}px 0;
|
||||
margin-bottom: ${theme.sizeUnit}px;
|
||||
border-bottom: 1px dashed ${theme.colorBorder};
|
||||
}
|
||||
|
||||
.rls-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: ${theme.sizeUnit}px;
|
||||
}
|
||||
|
||||
.rls-label {
|
||||
font-size: 11px;
|
||||
color: ${theme.colorTextSecondary};
|
||||
min-width: 70px;
|
||||
}
|
||||
|
||||
.rls-input {
|
||||
flex: 1;
|
||||
max-width: 250px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.count-allowed {
|
||||
color: ${theme.colorSuccess};
|
||||
font-weight: ${theme.fontWeightSemiBold};
|
||||
}
|
||||
|
||||
.count-denied {
|
||||
color: ${theme.colorError};
|
||||
font-weight: ${theme.fontWeightSemiBold};
|
||||
}
|
||||
`}
|
||||
`;
|
||||
|
||||
export interface PermissionsTreeProps {
|
||||
value?: PermissionsPayload;
|
||||
onChange?: (payload: PermissionsPayload) => void;
|
||||
pageSize?: number;
|
||||
}
|
||||
|
||||
function PermissionsTree({
|
||||
value,
|
||||
onChange,
|
||||
pageSize = DEFAULT_PAGE_SIZE,
|
||||
}: PermissionsTreeProps) {
|
||||
const [treeState, setTreeState] = useState<TreeState>({
|
||||
expandedKeys: [],
|
||||
loadedKeys: [],
|
||||
treeData: [],
|
||||
permissionStates: {},
|
||||
clsRules: {},
|
||||
rlsRules: {},
|
||||
});
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [databases, setDatabases] = useState<DatabaseInfo[]>([]);
|
||||
|
||||
// Map of database ID -> name and name -> ID
|
||||
const databaseMaps = useMemo(() => {
|
||||
const idToName = new Map<number, string>();
|
||||
const nameToId = new Map<string, number>();
|
||||
databases.forEach(db => {
|
||||
idToName.set(db.id, db.database_name);
|
||||
nameToId.set(db.database_name, db.id);
|
||||
});
|
||||
return { idToName, nameToId };
|
||||
}, [databases]);
|
||||
|
||||
// Track if we're in the middle of an internal update to avoid circular updates
|
||||
const isInternalUpdateRef = useRef(false);
|
||||
|
||||
// Load initial databases
|
||||
useEffect(() => {
|
||||
loadDatabases();
|
||||
}, []);
|
||||
|
||||
// Load initial value (only on mount or when value changes externally)
|
||||
useEffect(() => {
|
||||
if (isInternalUpdateRef.current) {
|
||||
isInternalUpdateRef.current = false;
|
||||
return;
|
||||
}
|
||||
if (value && databaseMaps.nameToId.size > 0) {
|
||||
const { states, clsRules, rlsRules } = loadPermissionsFromPayload(
|
||||
value,
|
||||
databaseMaps.nameToId,
|
||||
);
|
||||
setTreeState(prev => ({
|
||||
...prev,
|
||||
permissionStates: states,
|
||||
clsRules,
|
||||
rlsRules,
|
||||
}));
|
||||
}
|
||||
}, [value, databaseMaps.nameToId]);
|
||||
|
||||
const loadDatabases = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const dbs = await fetchDatabases();
|
||||
setDatabases(dbs);
|
||||
|
||||
const treeData: PermissionNode[] = dbs.map(db => ({
|
||||
key: makeKey({ databaseId: db.id }),
|
||||
title: db.database_name,
|
||||
nodeType: 'database',
|
||||
databaseId: db.id,
|
||||
databaseName: db.database_name,
|
||||
children: [],
|
||||
isLeaf: false,
|
||||
}));
|
||||
|
||||
setTreeState(prev => ({
|
||||
...prev,
|
||||
treeData,
|
||||
}));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const loadChildren = useCallback(
|
||||
async (node: PermissionNode): Promise<PermissionNode[]> => {
|
||||
const { databaseId, catalogName, schemaName, nodeType } = node;
|
||||
|
||||
if (!databaseId) return [];
|
||||
|
||||
// Database level -> load catalogs or schemas
|
||||
if (nodeType === 'database') {
|
||||
// First try to load catalogs
|
||||
try {
|
||||
const catalogs = await fetchCatalogs(databaseId, undefined, 0, 1);
|
||||
if (catalogs.result.length > 0) {
|
||||
// Database has catalogs
|
||||
const catalogsResult = await fetchCatalogs(
|
||||
databaseId,
|
||||
undefined,
|
||||
0,
|
||||
pageSize,
|
||||
);
|
||||
return catalogsResult.result.map(cat => ({
|
||||
key: makeKey({ databaseId, catalogName: cat }),
|
||||
title: cat,
|
||||
nodeType: 'catalog' as const,
|
||||
databaseId,
|
||||
databaseName: node.databaseName,
|
||||
catalogName: cat,
|
||||
children: [],
|
||||
isLeaf: false,
|
||||
hasMore: catalogsResult.count > pageSize,
|
||||
totalCount: catalogsResult.count,
|
||||
}));
|
||||
}
|
||||
} catch {
|
||||
// Database doesn't support catalogs, load schemas directly
|
||||
}
|
||||
|
||||
// Load schemas directly
|
||||
const schemasResult = await fetchSchemas(
|
||||
databaseId,
|
||||
undefined,
|
||||
undefined,
|
||||
0,
|
||||
pageSize,
|
||||
);
|
||||
return schemasResult.result.map(schema => ({
|
||||
key: makeKey({ databaseId, schemaName: schema }),
|
||||
title: schema,
|
||||
nodeType: 'schema' as const,
|
||||
databaseId,
|
||||
databaseName: node.databaseName,
|
||||
schemaName: schema,
|
||||
children: [],
|
||||
isLeaf: false,
|
||||
hasMore: schemasResult.count > pageSize,
|
||||
totalCount: schemasResult.count,
|
||||
}));
|
||||
}
|
||||
|
||||
// Catalog level -> load schemas
|
||||
if (nodeType === 'catalog' && catalogName) {
|
||||
const schemasResult = await fetchSchemas(
|
||||
databaseId,
|
||||
catalogName,
|
||||
undefined,
|
||||
0,
|
||||
pageSize,
|
||||
);
|
||||
return schemasResult.result.map(schema => ({
|
||||
key: makeKey({ databaseId, catalogName, schemaName: schema }),
|
||||
title: schema,
|
||||
nodeType: 'schema' as const,
|
||||
databaseId,
|
||||
databaseName: node.databaseName,
|
||||
catalogName,
|
||||
schemaName: schema,
|
||||
children: [],
|
||||
isLeaf: false,
|
||||
hasMore: schemasResult.count > pageSize,
|
||||
totalCount: schemasResult.count,
|
||||
}));
|
||||
}
|
||||
|
||||
// Schema level -> load tables
|
||||
if (nodeType === 'schema' && schemaName) {
|
||||
const tablesResult = await fetchTables(
|
||||
databaseId,
|
||||
schemaName,
|
||||
catalogName,
|
||||
undefined,
|
||||
0,
|
||||
pageSize,
|
||||
);
|
||||
return tablesResult.result.map(table => ({
|
||||
key: makeKey({
|
||||
databaseId,
|
||||
catalogName,
|
||||
schemaName,
|
||||
tableName: table.value,
|
||||
}),
|
||||
title: table.value,
|
||||
nodeType: 'table' as const,
|
||||
databaseId,
|
||||
databaseName: node.databaseName,
|
||||
catalogName,
|
||||
schemaName,
|
||||
tableName: table.value,
|
||||
isLeaf: false, // Tables are expandable to show columns
|
||||
children: [],
|
||||
}));
|
||||
}
|
||||
|
||||
// Table level -> load columns
|
||||
if (nodeType === 'table' && node.tableName) {
|
||||
const columns = await fetchColumns(
|
||||
databaseId,
|
||||
node.tableName,
|
||||
schemaName,
|
||||
catalogName,
|
||||
);
|
||||
const tableKey = node.key as string;
|
||||
return columns.map(col => ({
|
||||
key: `${tableKey}|col:${col.name}`,
|
||||
title: col.name,
|
||||
nodeType: 'column' as const,
|
||||
databaseId,
|
||||
databaseName: node.databaseName,
|
||||
catalogName,
|
||||
schemaName,
|
||||
tableName: node.tableName,
|
||||
columnName: col.name,
|
||||
isLeaf: true,
|
||||
}));
|
||||
}
|
||||
|
||||
return [];
|
||||
},
|
||||
[pageSize],
|
||||
);
|
||||
|
||||
// Lazy-load handler for the antd Tree: fetches a node's children the first
// time it is expanded, then splices them into the tree immutably.
const onLoadData = async (node: PermissionNode): Promise<void> => {
  // Children already present (preloaded or previously fetched) — skip.
  if (node.children && node.children.length > 0) {
    return;
  }

  const children = await loadChildren(node);
  const updatedTreeData = updateTreeData(
    treeState.treeData,
    node.key as string,
    children,
  );

  // NOTE(review): loadedKeys is appended here, but the <Tree> rendered below
  // is not passed a loadedKeys prop — confirm whether this tracking is used.
  setTreeState(prev => ({
    ...prev,
    treeData: updatedTreeData,
    loadedKeys: [...prev.loadedKeys, node.key as string],
  }));
};

// Keep the expansion state controlled so titleRender can vary by expansion.
const onExpand: TreeProps['onExpand'] = expandedKeysValue => {
  setTreeState(prev => ({
    ...prev,
    expandedKeys: expandedKeysValue as string[],
  }));
};

// Cycle the clicked node's permission (inherit <-> explicit allow/deny)
// and push the regenerated payload to the parent.
const handlePermissionClick = (
  nodeKey: string,
  event: React.MouseEvent,
) => {
  // Prevent the click from also toggling node expansion.
  event.stopPropagation();
  const newStates = cyclePermissionState(nodeKey, treeState.permissionStates);

  // Mark as internal update to prevent circular updates
  isInternalUpdateRef.current = true;

  setTreeState(prev => ({
    ...prev,
    permissionStates: newStates,
  }));

  // Notify parent of changes
  if (onChange && databaseMaps.idToName.size > 0) {
    const payload = generatePermissionsPayload(
      newStates,
      databaseMaps.idToName,
      treeState.clsRules,
      treeState.rlsRules,
    );
    onChange(payload);
  }
};
|
||||
|
||||
// Update one field (predicate or groupKey) of a table's RLS rule and push
// the regenerated payload to the parent.
//
// FIX: the rule computation and the onChange notification now happen
// outside the setTreeState updater. React state updaters must be pure —
// under StrictMode they can be invoked twice, which previously fired
// duplicate onChange notifications. This also makes the handler consistent
// with handlePermissionClick, which already notifies outside the updater.
const handleRlsChange = (
  tableKey: string,
  field: 'predicate' | 'groupKey',
  value: string,
) => {
  // Mark as internal update to prevent circular updates
  isInternalUpdateRef.current = true;

  const currentRls = treeState.rlsRules[tableKey] || { predicate: '' };
  const newRls: RLSRule = { ...currentRls, [field]: value };

  // Remove the rule if both fields are empty
  const newRlsRules = { ...treeState.rlsRules };
  if (!newRls.predicate && !newRls.groupKey) {
    delete newRlsRules[tableKey];
  } else {
    newRlsRules[tableKey] = newRls;
  }

  setTreeState(prev => ({
    ...prev,
    rlsRules: newRlsRules,
  }));

  // Notify parent of changes
  if (onChange && databaseMaps.idToName.size > 0) {
    const payload = generatePermissionsPayload(
      treeState.permissionStates,
      databaseMaps.idToName,
      treeState.clsRules,
      newRlsRules,
    );
    onChange(payload);
  }
};
|
||||
|
||||
// Set or clear the CLS action for one column of a table and push the
// regenerated payload to the parent. An undefined action removes the rule.
//
// FIX: the rule computation and the onChange notification now happen
// outside the setTreeState updater. React state updaters must be pure —
// under StrictMode they can be invoked twice, which previously fired
// duplicate onChange notifications. This also makes the handler consistent
// with handlePermissionClick, which already notifies outside the updater.
const handleClsChange = (
  tableKey: string,
  columnName: string,
  action: CLSAction | undefined,
) => {
  const clsKey = makeClsKey(tableKey, columnName);

  // Mark as internal update to prevent circular updates
  isInternalUpdateRef.current = true;

  const newClsRules = { ...treeState.clsRules };
  if (action) {
    newClsRules[clsKey] = action;
  } else {
    delete newClsRules[clsKey];
  }

  setTreeState(prev => ({
    ...prev,
    clsRules: newClsRules,
  }));

  // Notify parent of changes
  if (onChange && databaseMaps.idToName.size > 0) {
    const payload = generatePermissionsPayload(
      treeState.permissionStates,
      databaseMaps.idToName,
      newClsRules,
      treeState.rlsRules,
    );
    onChange(payload);
  }
};
|
||||
|
||||
// Render the clickable permission indicator for a node: a check for an
// explicit allow, a cross for an explicit deny, and a minus for an
// inherited state. Clicking cycles the state via handlePermissionClick.
const getStateIcon = (nodeKey: string) => {
  const explicitState = getExplicitState(nodeKey, treeState.permissionStates);
  const effectiveState = getEffectiveState(
    nodeKey,
    treeState.permissionStates,
  );

  // Tooltip describes the current state and what a click will do next.
  const getTooltipText = () => {
    if (explicitState === 'allow') return t('Allowed (click to deny)');
    if (explicitState === 'deny') return t('Denied (click to allow)');
    if (effectiveState === 'allow')
      return t('Inherits Allow (click to deny)');
    return t('Inherits Deny (click to allow)');
  };

  if (explicitState === 'allow') {
    return (
      <Tooltip title={getTooltipText()}>
        <CheckCircleOutlined
          className="permission-icon permission-icon-allow"
          onClick={e => handlePermissionClick(nodeKey, e)}
        />
      </Tooltip>
    );
  }
  if (explicitState === 'deny') {
    return (
      <Tooltip title={getTooltipText()}>
        <CloseCircleOutlined
          className="permission-icon permission-icon-deny"
          onClick={e => handlePermissionClick(nodeKey, e)}
        />
      </Tooltip>
    );
  }
  // No explicit state stored: show the "inherit" indicator.
  return (
    <Tooltip title={getTooltipText()}>
      <MinusCircleOutlined
        className="permission-icon permission-icon-inherit"
        onClick={e => handlePermissionClick(nodeKey, e)}
      />
    </Tooltip>
  );
};

// Map a tree level to its decorative icon; unknown levels get no icon.
const getNodeIcon = (nodeType: string) => {
  switch (nodeType) {
    case 'database':
      return <DatabaseOutlined />;
    case 'catalog':
    case 'schema':
      return <FolderOutlined />;
    case 'table':
      return <TableOutlined />;
    case 'column':
      return <ColumnHeightOutlined />;
    default:
      return null;
  }
};
|
||||
|
||||
// Options for the per-column CLS <Select>. The empty-value entry renders
// a blank label and maps to "no CLS action" (handleClsChange receives
// undefined and removes the rule).
const clsOptions = [
  { value: '', label: ' ' },
  {
    value: 'hash',
    label: (
      <span>
        <NumberOutlined /> {t('Hash')}
      </span>
    ),
  },
  {
    value: 'mask',
    label: (
      <span>
        <StarOutlined /> {t('Mask')}
      </span>
    ),
  },
  {
    value: 'nullify',
    label: (
      <span>
        <StopOutlined /> {t('Null')}
      </span>
    ),
  },
  {
    value: 'hide',
    label: (
      <span>
        <EyeInvisibleOutlined /> {t('Hide')}
      </span>
    ),
  },
];

// Count CLS rules for a given table
// (clsRules keys are "{tableKey}::{columnName}", so a prefix match finds
// every column rule belonging to the table).
const countClsRules = (tableKey: string): number => {
  const prefix = `${tableKey}::`;
  return Object.keys(treeState.clsRules).filter(key => key.startsWith(prefix)).length;
};

// Check if a table has RLS rules
// (a rule with neither a predicate nor a group key counts as absent).
const hasRlsRules = (tableKey: string): boolean => {
  const rls = treeState.rlsRules[tableKey];
  return !!(rls && (rls.predicate || rls.groupKey));
};
|
||||
|
||||
// Custom renderer for every tree node title. Three cases:
//  - column nodes: name plus a CLS action <Select>;
//  - expanded table nodes: name plus inline RLS predicate/group-key inputs;
//  - everything else: name with state icon, bolding when the node or its
//    descendants carry explicit permissions, and summary badges.
const titleRender = (node: PermissionNode) => {
  const nodeKey = node.key as string;
  const isExpanded = treeState.expandedKeys.includes(nodeKey);
  const title = node.title as string;

  // Column nodes get special rendering with CLS dropdown
  if (node.nodeType === 'column' && node.tableName && node.columnName) {
    // Get the table key by removing the column part
    // (column keys are built as `${tableKey}|col:${name}` in loadChildren).
    const tableKey = nodeKey.substring(0, nodeKey.lastIndexOf('|col:'));
    const clsKey = makeClsKey(tableKey, node.columnName);
    const currentAction = treeState.clsRules[clsKey] || '';

    return (
      <div className="column-node">
        <span className="column-name">
          {getNodeIcon(node.nodeType)}
          <span>{title}</span>
        </span>
        <Select
          className="cls-select"
          size="small"
          value={currentAction}
          onChange={(val: string) =>
            handleClsChange(
              tableKey,
              node.columnName as string,
              (val || undefined) as CLSAction | undefined,
            )
          }
          onClick={e => e.stopPropagation()}
          options={clsOptions}
        />
      </div>
    );
  }

  // Table nodes get RLS inputs when expanded
  if (node.nodeType === 'table' && isExpanded) {
    const currentRls = treeState.rlsRules[nodeKey] || {
      predicate: '',
      groupKey: '',
    };

    // Check if table has permissions, CLS, or RLS
    const tableCounts = countDescendantPermissions(nodeKey, treeState.permissionStates);
    const tableExplicitState = getExplicitState(nodeKey, treeState.permissionStates);
    const clsCount = countClsRules(nodeKey);
    const hasRls = hasRlsRules(nodeKey);
    const tableHasConfig =
      tableExplicitState === 'allow' ||
      tableExplicitState === 'deny' ||
      (tableCounts && (tableCounts.allowed > 0 || tableCounts.denied > 0)) ||
      clsCount > 0 ||
      hasRls;

    return (
      <div>
        <div className="node-title">
          {getStateIcon(nodeKey)}
          {getNodeIcon(node.nodeType)}
          <span style={tableHasConfig ? { fontWeight: 600 } : undefined}>{title}</span>
          {(hasRls || clsCount > 0) && (
            <span className="node-count">
              ({[
                hasRls ? 'RLS' : null,
                clsCount > 0 ? `CLS: ${clsCount}` : null,
              ].filter(Boolean).join(', ')})
            </span>
          )}
        </div>
        <div className="rls-container">
          <div className="rls-row">
            <span className="rls-label">{t('RLS Predicate')}</span>
            <Input
              className="rls-input"
              size="small"
              placeholder={t('e.g., org_id = {{current_user_id}}')}
              value={currentRls.predicate}
              onChange={e => handleRlsChange(nodeKey, 'predicate', e.target.value)}
              onClick={e => e.stopPropagation()}
            />
          </div>
          <div className="rls-row">
            <span className="rls-label">{t('Group Key')}</span>
            <Input
              className="rls-input"
              size="small"
              placeholder={t('Optional grouping key')}
              value={currentRls.groupKey || ''}
              onChange={e => handleRlsChange(nodeKey, 'groupKey', e.target.value)}
              onClick={e => e.stopPropagation()}
            />
          </div>
        </div>
      </div>
    );
  }

  // Non-column nodes: check if leaf (columns are leafs, tables are not anymore)
  const isLeaf = node.nodeType === 'column';

  // Count descendants with custom permissions
  const counts = !isLeaf
    ? countDescendantPermissions(nodeKey, treeState.permissionStates)
    : null;
  const hasCustomRules = counts && (counts.allowed > 0 || counts.denied > 0);

  // Check if this node itself has explicit permissions
  const explicitState = getExplicitState(nodeKey, treeState.permissionStates);
  const hasExplicitPermission = explicitState === 'allow' || explicitState === 'deny';

  // For tables, also check CLS/RLS
  const isTable = node.nodeType === 'table';
  const clsCount = isTable ? countClsRules(nodeKey) : 0;
  const hasRls = isTable ? hasRlsRules(nodeKey) : false;

  // Bold if node or any children have permissions, or if table has CLS/RLS
  const shouldBeBold = hasExplicitPermission || hasCustomRules || clsCount > 0 || hasRls;

  return (
    <div className="node-title">
      {getStateIcon(nodeKey)}
      {getNodeIcon(node.nodeType)}
      <span style={shouldBeBold ? { fontWeight: 600 } : undefined}>{title}</span>
      {hasCustomRules && !isExpanded && (
        <span className="node-count">
          (
          <span className="count-allowed">{counts.allowed}</span>
          {' / '}
          <span className="count-denied">{counts.denied}</span>
          )
        </span>
      )}
      {isTable && (hasRls || clsCount > 0) && !isExpanded && (
        <span className="node-count">
          ({[
            hasRls ? 'RLS' : null,
            clsCount > 0 ? `CLS: ${clsCount}` : null,
          ].filter(Boolean).join(', ')})
        </span>
      )}
      {node.hasMore && node.totalCount && (
        <span className="node-count" style={{ marginLeft: 4 }}>
          [{node.children?.length || 0}/{node.totalCount}]
        </span>
      )}
    </div>
  );
};
|
||||
|
||||
// Wipe every explicit permission plus all CLS and RLS rules, then tell the
// parent the rule set is now empty.
const clearAll = () => {
  // Mark as internal so the reset does not loop back through value syncing.
  isInternalUpdateRef.current = true;
  setTreeState(prev => ({
    ...prev,
    permissionStates: {},
    clsRules: {},
    rlsRules: {},
  }));
  if (onChange) {
    onChange({ allowed: [], denied: [] });
  }
};
|
||||
|
||||
return (
|
||||
<StyledContainer>
|
||||
<div style={{ marginBottom: 8 }}>
|
||||
<Button size="small" onClick={clearAll}>
|
||||
{t('Reset All')}
|
||||
</Button>
|
||||
</div>
|
||||
<Spin spinning={loading}>
|
||||
<div className="tree-container">
|
||||
<Tree
|
||||
loadData={onLoadData as TreeProps['loadData']}
|
||||
treeData={treeState.treeData}
|
||||
expandedKeys={treeState.expandedKeys}
|
||||
onExpand={onExpand}
|
||||
titleRender={titleRender as TreeProps['titleRender']}
|
||||
showLine={{ showLeafIcon: false }}
|
||||
blockNode
|
||||
/>
|
||||
</div>
|
||||
</Spin>
|
||||
</StyledContainer>
|
||||
);
|
||||
}
|
||||
|
||||
export default PermissionsTree;
|
||||
@@ -0,0 +1,73 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import type { DataNode } from 'antd/es/tree';
|
||||
|
||||
// Tri-state permission a node can carry; 'inherit' means no explicit
// setting on this node — resolution falls back to the nearest ancestor.
export type PermissionState = 'allow' | 'deny' | 'inherit';

// Hierarchy levels shown in the permissions tree, root to leaf.
export type NodeType = 'database' | 'catalog' | 'schema' | 'table' | 'column';

// Column-level security transformations applicable to a single column.
export type CLSAction = 'hash' | 'mask' | 'nullify' | 'hide';

// A node of the permissions tree. Extends antd's DataNode with the
// database/catalog/schema/table/column coordinates it represents.
export interface PermissionNode extends DataNode {
  key: string;
  title: string;
  nodeType: NodeType;
  databaseId?: number;
  databaseName?: string;
  catalogName?: string;
  schemaName?: string;
  tableName?: string;
  columnName?: string;
  children?: PermissionNode[];
  isLeaf?: boolean;
  // True when more children exist than have been loaded so far.
  hasMore?: boolean;
  // Total number of children reported by the backend.
  totalCount?: number;
}

// A row-level security rule attached to a table.
export interface RLSRule {
  predicate: string;
  groupKey?: string;
}

// Full client-side state of the permissions tree widget.
export interface TreeState {
  expandedKeys: string[];
  loadedKeys: string[];
  treeData: PermissionNode[];
  permissionStates: Record<string, PermissionState>;
  clsRules: Record<string, CLSAction>; // key format: "tableKey::columnName" -> action
  rlsRules: Record<string, RLSRule>; // key format: tableKey -> RLSRule
}

// Serialized form sent to the parent via onChange.
export interface PermissionsPayload {
  allowed: PermissionEntry[];
  denied: PermissionEntry[];
}

// One allow/deny entry; granularity depends on which fields are present.
export interface PermissionEntry {
  database: string;
  catalog?: string;
  schema?: string;
  table?: string;
  rls?: {
    predicate: string;
    group_key?: string;
  };
  cls?: Record<string, string>; // column name -> action
}
|
||||
@@ -0,0 +1,548 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
import type {
|
||||
PermissionNode,
|
||||
PermissionState,
|
||||
PermissionsPayload,
|
||||
PermissionEntry,
|
||||
NodeType,
|
||||
CLSAction,
|
||||
RLSRule,
|
||||
} from './types';
|
||||
|
||||
// Key format: db:{id}|cat:{name}|schema:{name}|table:{name}
|
||||
// Each part is optional based on the level
|
||||
|
||||
export function makeKey(parts: {
|
||||
databaseId: number;
|
||||
catalogName?: string;
|
||||
schemaName?: string;
|
||||
tableName?: string;
|
||||
}): string {
|
||||
let key = `db:${parts.databaseId}`;
|
||||
if (parts.catalogName) {
|
||||
key += `|cat:${parts.catalogName}`;
|
||||
}
|
||||
if (parts.schemaName) {
|
||||
key += `|schema:${parts.schemaName}`;
|
||||
}
|
||||
if (parts.tableName) {
|
||||
key += `|table:${parts.tableName}`;
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
export function parseKey(key: string): {
|
||||
databaseId: number;
|
||||
catalogName?: string;
|
||||
schemaName?: string;
|
||||
tableName?: string;
|
||||
nodeType: NodeType;
|
||||
} {
|
||||
const parts = key.split('|');
|
||||
const result: ReturnType<typeof parseKey> = {
|
||||
databaseId: 0,
|
||||
nodeType: 'database',
|
||||
};
|
||||
|
||||
for (const part of parts) {
|
||||
if (part.startsWith('db:')) {
|
||||
result.databaseId = parseInt(part.slice(3), 10);
|
||||
} else if (part.startsWith('cat:')) {
|
||||
result.catalogName = part.slice(4);
|
||||
result.nodeType = 'catalog';
|
||||
} else if (part.startsWith('schema:')) {
|
||||
result.schemaName = part.slice(7);
|
||||
result.nodeType = 'schema';
|
||||
} else if (part.startsWith('table:')) {
|
||||
result.tableName = part.slice(6);
|
||||
result.nodeType = 'table';
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export function getParentKey(key: string): string | null {
|
||||
const parts = key.split('|');
|
||||
if (parts.length <= 1) return null;
|
||||
return parts.slice(0, -1).join('|');
|
||||
}
|
||||
|
||||
export function getEffectiveState(
|
||||
nodeKey: string,
|
||||
permissionStates: Record<string, PermissionState>,
|
||||
): PermissionState {
|
||||
// Check if node has explicit state
|
||||
const explicitState = permissionStates[nodeKey];
|
||||
if (explicitState && explicitState !== 'inherit') {
|
||||
return explicitState;
|
||||
}
|
||||
|
||||
// Walk up ancestors to find effective state
|
||||
let currentKey: string | null = nodeKey;
|
||||
while (currentKey) {
|
||||
const parentKey = getParentKey(currentKey);
|
||||
if (!parentKey) break;
|
||||
|
||||
const parentState = permissionStates[parentKey];
|
||||
if (parentState && parentState !== 'inherit') {
|
||||
return parentState;
|
||||
}
|
||||
currentKey = parentKey;
|
||||
}
|
||||
|
||||
// Default is deny (implicit)
|
||||
return 'deny';
|
||||
}
|
||||
|
||||
export function getExplicitState(
|
||||
nodeKey: string,
|
||||
permissionStates: Record<string, PermissionState>,
|
||||
): PermissionState {
|
||||
return permissionStates[nodeKey] || 'inherit';
|
||||
}
|
||||
|
||||
export function cyclePermissionState(
|
||||
nodeKey: string,
|
||||
permissionStates: Record<string, PermissionState>,
|
||||
): Record<string, PermissionState> {
|
||||
const newStates = { ...permissionStates };
|
||||
const currentState = newStates[nodeKey] || 'inherit';
|
||||
const parentKey = getParentKey(nodeKey);
|
||||
|
||||
let nextState: PermissionState;
|
||||
|
||||
if (!parentKey) {
|
||||
// Root element (database) - only toggle between inherit and allow
|
||||
// (deny is the default state, so there's no point in explicitly denying)
|
||||
if (currentState === 'inherit') {
|
||||
nextState = 'allow';
|
||||
} else {
|
||||
nextState = 'inherit';
|
||||
}
|
||||
} else {
|
||||
// Child element - toggle between inherit and opposite of parent's effective state
|
||||
const parentEffective = getEffectiveState(parentKey, permissionStates);
|
||||
const oppositeState = parentEffective === 'allow' ? 'deny' : 'allow';
|
||||
|
||||
if (currentState === 'inherit') {
|
||||
nextState = oppositeState;
|
||||
} else {
|
||||
nextState = 'inherit';
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the new state
|
||||
if (nextState === 'inherit') {
|
||||
delete newStates[nodeKey];
|
||||
} else {
|
||||
newStates[nodeKey] = nextState;
|
||||
}
|
||||
|
||||
// Clear all descendant states when setting a non-inherit state on a parent
|
||||
// This ensures children return to "inherit" when parent is set
|
||||
if (nextState !== 'inherit') {
|
||||
clearDescendantStates(nodeKey, newStates);
|
||||
}
|
||||
|
||||
return newStates;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all permission states for descendants of a given node key.
|
||||
*/
|
||||
function clearDescendantStates(
|
||||
parentKey: string,
|
||||
states: Record<string, PermissionState>,
|
||||
): void {
|
||||
const keysToDelete: string[] = [];
|
||||
|
||||
Object.keys(states).forEach(key => {
|
||||
if (isDescendantOf(key, parentKey)) {
|
||||
keysToDelete.push(key);
|
||||
}
|
||||
});
|
||||
|
||||
keysToDelete.forEach(key => {
|
||||
delete states[key];
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a key is a descendant of another key.
|
||||
*/
|
||||
function isDescendantOf(childKey: string, ancestorKey: string): boolean {
|
||||
// A key is a descendant if it starts with the ancestor key + '|'
|
||||
return childKey.startsWith(`${ancestorKey}|`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Count children with explicit allow/deny states.
|
||||
* Returns counts for direct children only (not all descendants).
|
||||
*/
|
||||
export function countChildPermissions(
|
||||
parentKey: string,
|
||||
permissionStates: Record<string, PermissionState>,
|
||||
): { allowed: number; denied: number } {
|
||||
let allowed = 0;
|
||||
let denied = 0;
|
||||
|
||||
Object.entries(permissionStates).forEach(([key, state]) => {
|
||||
// Check if this is a direct child (parent key + one more segment)
|
||||
if (isDirectChildOf(key, parentKey)) {
|
||||
if (state === 'allow') {
|
||||
allowed += 1;
|
||||
} else if (state === 'deny') {
|
||||
denied += 1;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return { allowed, denied };
|
||||
}
|
||||
|
||||
/**
|
||||
* Count all descendants with explicit allow/deny states.
|
||||
*/
|
||||
export function countDescendantPermissions(
|
||||
parentKey: string,
|
||||
permissionStates: Record<string, PermissionState>,
|
||||
): { allowed: number; denied: number } {
|
||||
let allowed = 0;
|
||||
let denied = 0;
|
||||
|
||||
Object.entries(permissionStates).forEach(([key, state]) => {
|
||||
if (isDescendantOf(key, parentKey)) {
|
||||
if (state === 'allow') {
|
||||
allowed += 1;
|
||||
} else if (state === 'deny') {
|
||||
denied += 1;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return { allowed, denied };
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a key is a direct child of another key.
|
||||
*/
|
||||
function isDirectChildOf(childKey: string, parentKey: string): boolean {
|
||||
if (!childKey.startsWith(`${parentKey}|`)) {
|
||||
return false;
|
||||
}
|
||||
// Count the number of '|' segments after the parent
|
||||
const suffix = childKey.slice(parentKey.length + 1);
|
||||
return !suffix.includes('|');
|
||||
}
|
||||
|
||||
export function updateTreeData(
|
||||
list: PermissionNode[],
|
||||
key: string,
|
||||
children: PermissionNode[],
|
||||
): PermissionNode[] {
|
||||
return list.map(node => {
|
||||
if (node.key === key) {
|
||||
return {
|
||||
...node,
|
||||
children,
|
||||
};
|
||||
}
|
||||
if (node.children) {
|
||||
return {
|
||||
...node,
|
||||
children: updateTreeData(node.children, key, children),
|
||||
};
|
||||
}
|
||||
return node;
|
||||
});
|
||||
}
|
||||
|
||||
export function findNodeByKey(
|
||||
data: PermissionNode[],
|
||||
key: string,
|
||||
): PermissionNode | null {
|
||||
for (const node of data) {
|
||||
if (node.key === key) return node;
|
||||
if (node.children) {
|
||||
const found = findNodeByKey(node.children, key);
|
||||
if (found) return found;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Make a CLS key from a table key and column name.
|
||||
* Format: "{tableKey}::{columnName}"
|
||||
*/
|
||||
export function makeClsKey(tableKey: string, columnName: string): string {
|
||||
return `${tableKey}::${columnName}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a CLS key into table key and column name.
|
||||
*/
|
||||
export function parseClsKey(clsKey: string): {
|
||||
tableKey: string;
|
||||
columnName: string;
|
||||
} | null {
|
||||
const separatorIndex = clsKey.indexOf('::');
|
||||
if (separatorIndex === -1) return null;
|
||||
return {
|
||||
tableKey: clsKey.slice(0, separatorIndex),
|
||||
columnName: clsKey.slice(separatorIndex + 2),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get CLS rules for a specific table from the clsRules state.
|
||||
*/
|
||||
export function getTableClsRules(
|
||||
tableKey: string,
|
||||
clsRules: Record<string, CLSAction>,
|
||||
): Record<string, CLSAction> {
|
||||
const tableRules: Record<string, CLSAction> = {};
|
||||
const prefix = `${tableKey}::`;
|
||||
|
||||
Object.entries(clsRules).forEach(([key, action]) => {
|
||||
if (key.startsWith(prefix)) {
|
||||
const columnName = key.slice(prefix.length);
|
||||
tableRules[columnName] = action;
|
||||
}
|
||||
});
|
||||
|
||||
return tableRules;
|
||||
}
|
||||
|
||||
/**
 * Serialize the tree widget's state into the allowed/denied payload sent
 * to the parent. Redundant entries (states matching what a node would
 * inherit anyway) are dropped first, then CLS/RLS rules are attached to
 * their table entries. Tables that carry CLS/RLS rules but no explicit
 * permission are appended to `allowed` when their inherited state is
 * 'allow', so the rules are not silently lost.
 */
export function generatePermissionsPayload(
  permissionStates: Record<string, PermissionState>,
  databases: Map<number, string>,
  clsRules: Record<string, CLSAction> = {},
  rlsRules: Record<string, RLSRule> = {},
): PermissionsPayload {
  const allowed: PermissionEntry[] = [];
  const denied: PermissionEntry[] = [];

  // Clean up states - remove redundant entries
  const cleanedStates = cleanupStates(permissionStates);

  // Group CLS rules by table key
  const clsByTable: Record<string, Record<string, string>> = {};
  Object.entries(clsRules).forEach(([clsKey, action]) => {
    const parsed = parseClsKey(clsKey);
    if (!parsed) return;
    if (!clsByTable[parsed.tableKey]) {
      clsByTable[parsed.tableKey] = {};
    }
    clsByTable[parsed.tableKey][parsed.columnName] = action;
  });

  Object.entries(cleanedStates).forEach(([key, state]) => {
    const parsed = parseKey(key);
    const databaseName = databases.get(parsed.databaseId);

    // Skip entries whose database is no longer known by id.
    if (!databaseName) return;

    const entry: PermissionEntry = {
      database: databaseName,
    };

    if (parsed.catalogName) entry.catalog = parsed.catalogName;
    if (parsed.schemaName) entry.schema = parsed.schemaName;
    if (parsed.tableName) entry.table = parsed.tableName;

    // Add RLS rules if this is a table entry and has RLS rules
    // (a rule without a predicate is considered empty and not emitted).
    if (parsed.tableName && rlsRules[key]) {
      const rls = rlsRules[key];
      if (rls.predicate) {
        entry.rls = {
          predicate: rls.predicate,
        };
        if (rls.groupKey) {
          entry.rls.group_key = rls.groupKey;
        }
      }
    }

    // Add CLS rules if this is a table entry and has CLS rules
    if (parsed.tableName && clsByTable[key]) {
      entry.cls = clsByTable[key];
    }

    if (state === 'allow') {
      allowed.push(entry);
    } else if (state === 'deny') {
      denied.push(entry);
    }
  });

  // Also add RLS/CLS rules for tables that don't have explicit permission states
  // but do have RLS or CLS rules (they inherit permission from parent)
  const tablesWithRules = new Set([
    ...Object.keys(clsByTable),
    ...Object.keys(rlsRules),
  ]);

  tablesWithRules.forEach(tableKey => {
    // Check if we already added this table
    const parsed = parseKey(tableKey);
    const databaseName = databases.get(parsed.databaseId);
    if (!databaseName) return;

    // Only add if not already in allowed list
    const alreadyInAllowed = allowed.some(
      entry =>
        entry.database === databaseName &&
        entry.catalog === parsed.catalogName &&
        entry.schema === parsed.schemaName &&
        entry.table === parsed.tableName,
    );

    const hasClsRules =
      clsByTable[tableKey] && Object.keys(clsByTable[tableKey]).length > 0;
    const hasRlsRules =
      rlsRules[tableKey] && rlsRules[tableKey].predicate;

    if (!alreadyInAllowed && (hasClsRules || hasRlsRules)) {
      // Check if the table has effective allow state (inherited)
      const effectiveState = getEffectiveState(tableKey, permissionStates);
      if (effectiveState === 'allow') {
        const entry: PermissionEntry = {
          database: databaseName,
          table: parsed.tableName,
        };
        if (parsed.catalogName) entry.catalog = parsed.catalogName;
        if (parsed.schemaName) entry.schema = parsed.schemaName;

        if (hasRlsRules) {
          const rls = rlsRules[tableKey];
          entry.rls = { predicate: rls.predicate };
          if (rls.groupKey) {
            entry.rls.group_key = rls.groupKey;
          }
        }

        if (hasClsRules) {
          entry.cls = clsByTable[tableKey];
        }

        allowed.push(entry);
      }
    }
  });

  return { allowed, denied };
}
|
||||
|
||||
function cleanupStates(
|
||||
states: Record<string, PermissionState>,
|
||||
): Record<string, PermissionState> {
|
||||
const cleaned: Record<string, PermissionState> = {};
|
||||
|
||||
Object.entries(states).forEach(([key, state]) => {
|
||||
if (state === 'inherit') return;
|
||||
|
||||
// Check ancestor state
|
||||
let ancestorState: PermissionState | null = null;
|
||||
let currentKey: string | null = key;
|
||||
while (currentKey) {
|
||||
const parentKey = getParentKey(currentKey);
|
||||
if (!parentKey) break;
|
||||
if (states[parentKey] && states[parentKey] !== 'inherit') {
|
||||
ancestorState = states[parentKey];
|
||||
break;
|
||||
}
|
||||
currentKey = parentKey;
|
||||
}
|
||||
|
||||
// Only keep state if it differs from ancestor
|
||||
if (ancestorState) {
|
||||
if (ancestorState !== state) {
|
||||
cleaned[key] = state;
|
||||
}
|
||||
} else {
|
||||
// No ancestor - only keep 'allow' (deny is default)
|
||||
if (state === 'allow') {
|
||||
cleaned[key] = state;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return cleaned;
|
||||
}
|
||||
|
||||
/**
 * Inverse of generatePermissionsPayload: rebuild the widget's internal
 * maps (permission states, CLS rules, RLS rules) from a saved payload.
 * Entries referencing a database name not present in the name->id map are
 * silently skipped. RLS/CLS are only restored from allowed table entries,
 * mirroring how the payload is generated.
 */
export function loadPermissionsFromPayload(
  payload: PermissionsPayload,
  databases: Map<string, number>,
): {
  states: Record<string, PermissionState>;
  clsRules: Record<string, CLSAction>;
  rlsRules: Record<string, RLSRule>;
} {
  const states: Record<string, PermissionState> = {};
  const clsRules: Record<string, CLSAction> = {};
  const rlsRules: Record<string, RLSRule> = {};

  payload.allowed.forEach(entry => {
    const databaseId = databases.get(entry.database);
    if (databaseId === undefined) return;

    const key = makeKey({
      databaseId,
      catalogName: entry.catalog,
      schemaName: entry.schema,
      tableName: entry.table,
    });
    states[key] = 'allow';

    // Load RLS rules if present
    if (entry.rls && entry.table) {
      rlsRules[key] = {
        predicate: entry.rls.predicate,
        groupKey: entry.rls.group_key,
      };
    }

    // Load CLS rules if present
    if (entry.cls && entry.table) {
      Object.entries(entry.cls).forEach(([columnName, action]) => {
        const clsKey = makeClsKey(key, columnName);
        clsRules[clsKey] = action as CLSAction;
      });
    }
  });

  payload.denied.forEach(entry => {
    const databaseId = databases.get(entry.database);
    if (databaseId === undefined) return;

    const key = makeKey({
      databaseId,
      catalogName: entry.catalog,
      schemaName: entry.schema,
      tableName: entry.table,
    });
    states[key] = 'deny';
  });

  return { states, clsRules, rlsRules };
}
|
||||
38
superset-frontend/src/features/dataAccessRules/types.ts
Normal file
38
superset-frontend/src/features/dataAccessRules/types.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
export type RoleObject = {
|
||||
id: number;
|
||||
name: string;
|
||||
};
|
||||
|
||||
export type DataAccessRuleObject = {
|
||||
id?: number;
|
||||
name?: string;
|
||||
description?: string;
|
||||
role_id: number;
|
||||
role?: RoleObject;
|
||||
rule: string;
|
||||
changed_on_delta_humanized?: string;
|
||||
changed_by?: {
|
||||
id: number;
|
||||
first_name: string;
|
||||
last_name: string;
|
||||
};
|
||||
};
|
||||
407
superset-frontend/src/pages/DataAccessRulesList/index.tsx
Normal file
407
superset-frontend/src/pages/DataAccessRulesList/index.tsx
Normal file
@@ -0,0 +1,407 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import { t, SupersetClient } from '@superset-ui/core';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { ConfirmStatusChange, Tooltip } from '@superset-ui/core/components';
|
||||
import {
|
||||
ModifiedInfo,
|
||||
ListView,
|
||||
ListViewFilterOperator as FilterOperator,
|
||||
type ListViewProps,
|
||||
type ListViewFilters,
|
||||
type ListViewFetchDataConfig as FetchDataConfig,
|
||||
} from 'src/components';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
import withToasts from 'src/components/MessageToasts/withToasts';
|
||||
import SubMenu, { SubMenuProps } from 'src/features/home/SubMenu';
|
||||
import rison from 'rison';
|
||||
import { useListViewResource } from 'src/views/CRUD/hooks';
|
||||
import DataAccessRuleModal from 'src/features/dataAccessRules/DataAccessRuleModal';
|
||||
import { DataAccessRuleObject } from 'src/features/dataAccessRules/types';
|
||||
import { createErrorHandler, createFetchRelated } from 'src/views/CRUD/utils';
|
||||
import { QueryObjectColumns } from 'src/views/CRUD/types';
|
||||
|
||||
interface DataAccessRulesListProps {
|
||||
addDangerToast: (msg: string) => void;
|
||||
addSuccessToast: (msg: string) => void;
|
||||
user: {
|
||||
userId: string | number;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
};
|
||||
}
|
||||
|
||||
function DataAccessRulesList(props: DataAccessRulesListProps) {
|
||||
const { addDangerToast, addSuccessToast, user } = props;
|
||||
const [ruleModalOpen, setRuleModalOpen] = useState<boolean>(false);
|
||||
const [currentRule, setCurrentRule] =
|
||||
useState<DataAccessRuleObject | null>(null);
|
||||
|
||||
const {
|
||||
state: {
|
||||
loading,
|
||||
resourceCount: rulesCount,
|
||||
resourceCollection: rules,
|
||||
bulkSelectEnabled,
|
||||
},
|
||||
hasPerm,
|
||||
fetchData,
|
||||
refreshData,
|
||||
toggleBulkSelect,
|
||||
} = useListViewResource<DataAccessRuleObject>(
|
||||
'dar',
|
||||
t('Data Access Rule'),
|
||||
addDangerToast,
|
||||
true,
|
||||
undefined,
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
|
||||
function handleRuleEdit(rule: DataAccessRuleObject | null) {
|
||||
setCurrentRule(rule);
|
||||
setRuleModalOpen(true);
|
||||
}
|
||||
|
||||
function handleRuleDelete(
|
||||
{ id, name: ruleName }: DataAccessRuleObject,
|
||||
refreshData: (arg0?: FetchDataConfig | null) => void,
|
||||
addSuccessToast: (arg0: string) => void,
|
||||
addDangerToast: (arg0: string) => void,
|
||||
) {
|
||||
const name = ruleName || `Rule ${id}`;
|
||||
return SupersetClient.delete({
|
||||
endpoint: `/api/v1/dar/${id}`,
|
||||
}).then(
|
||||
() => {
|
||||
refreshData();
|
||||
addSuccessToast(t('Deleted %s', name));
|
||||
},
|
||||
createErrorHandler(errMsg =>
|
||||
addDangerToast(t('There was an issue deleting %s: %s', name, errMsg)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
function handleBulkRulesDelete(rulesToDelete: DataAccessRuleObject[]) {
|
||||
const ids = rulesToDelete.map(({ id }) => id);
|
||||
return SupersetClient.delete({
|
||||
endpoint: `/api/v1/dar/?q=${rison.encode(ids)}`,
|
||||
}).then(
|
||||
() => {
|
||||
refreshData();
|
||||
addSuccessToast(t(`Deleted`));
|
||||
},
|
||||
createErrorHandler(errMsg =>
|
||||
addDangerToast(t('There was an issue deleting rules: %s', errMsg)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
function handleRuleModalHide() {
|
||||
setCurrentRule(null);
|
||||
setRuleModalOpen(false);
|
||||
refreshData();
|
||||
}
|
||||
|
||||
const canWrite = hasPerm('can_write');
|
||||
const canEdit = hasPerm('can_write');
|
||||
const canExport = hasPerm('can_export');
|
||||
|
||||
const columns = useMemo(
|
||||
() => [
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { name },
|
||||
},
|
||||
}: {
|
||||
row: { original: DataAccessRuleObject };
|
||||
}) => name || '-',
|
||||
accessor: 'name',
|
||||
Header: t('Name'),
|
||||
size: 'lg',
|
||||
id: 'name',
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { description },
|
||||
},
|
||||
}: {
|
||||
row: { original: DataAccessRuleObject };
|
||||
}) => {
|
||||
if (!description) return '-';
|
||||
const truncated =
|
||||
description.length > 100
|
||||
? `${description.substring(0, 100)}...`
|
||||
: description;
|
||||
return (
|
||||
<Tooltip id="desc-tooltip" title={description} placement="top">
|
||||
<span>{truncated}</span>
|
||||
</Tooltip>
|
||||
);
|
||||
},
|
||||
accessor: 'description',
|
||||
Header: t('Description'),
|
||||
size: 'xl',
|
||||
id: 'description',
|
||||
disableSortBy: true,
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: { role },
|
||||
},
|
||||
}: {
|
||||
row: { original: DataAccessRuleObject };
|
||||
}) => role?.name || '-',
|
||||
accessor: 'role',
|
||||
Header: t('Role'),
|
||||
size: 'lg',
|
||||
id: 'role',
|
||||
disableSortBy: true,
|
||||
},
|
||||
{
|
||||
Cell: ({
|
||||
row: {
|
||||
original: {
|
||||
changed_on_delta_humanized: changedOn,
|
||||
changed_by: changedBy,
|
||||
},
|
||||
},
|
||||
}: {
|
||||
row: { original: DataAccessRuleObject };
|
||||
}) => <ModifiedInfo date={changedOn} user={changedBy} />,
|
||||
Header: t('Last modified'),
|
||||
accessor: 'changed_on_delta_humanized',
|
||||
size: 'xl',
|
||||
id: 'changed_on_delta_humanized',
|
||||
},
|
||||
{
|
||||
Cell: ({ row: { original } }: { row: { original: DataAccessRuleObject } }) => {
|
||||
const handleDelete = () =>
|
||||
handleRuleDelete(
|
||||
original,
|
||||
refreshData,
|
||||
addSuccessToast,
|
||||
addDangerToast,
|
||||
);
|
||||
const handleEdit = () => handleRuleEdit(original);
|
||||
return (
|
||||
<div className="actions">
|
||||
{canWrite && (
|
||||
<ConfirmStatusChange
|
||||
title={t('Please confirm')}
|
||||
description={
|
||||
<>
|
||||
{t('Are you sure you want to delete')}{' '}
|
||||
<b>{original.name || `Rule ${original.id}`}</b>
|
||||
</>
|
||||
}
|
||||
onConfirm={handleDelete}
|
||||
>
|
||||
{confirmDelete => (
|
||||
<Tooltip
|
||||
id="delete-action-tooltip"
|
||||
title={t('Delete')}
|
||||
placement="bottom"
|
||||
>
|
||||
<span
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
className="action-button"
|
||||
onClick={confirmDelete}
|
||||
>
|
||||
<Icons.DeleteOutlined
|
||||
data-test="dar-list-trash-icon"
|
||||
iconSize="l"
|
||||
/>
|
||||
</span>
|
||||
</Tooltip>
|
||||
)}
|
||||
</ConfirmStatusChange>
|
||||
)}
|
||||
{canEdit && (
|
||||
<Tooltip
|
||||
id="edit-action-tooltip"
|
||||
title={t('Edit')}
|
||||
placement="bottom"
|
||||
>
|
||||
<span
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
className="action-button"
|
||||
onClick={handleEdit}
|
||||
>
|
||||
<Icons.EditOutlined data-test="edit-alt" iconSize="l" />
|
||||
</span>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
Header: t('Actions'),
|
||||
id: 'actions',
|
||||
hidden: !canEdit && !canWrite && !canExport,
|
||||
disableSortBy: true,
|
||||
size: 'lg',
|
||||
},
|
||||
{
|
||||
accessor: QueryObjectColumns.ChangedBy,
|
||||
hidden: true,
|
||||
id: QueryObjectColumns.ChangedBy,
|
||||
},
|
||||
],
|
||||
[
|
||||
user.userId,
|
||||
canEdit,
|
||||
canWrite,
|
||||
canExport,
|
||||
hasPerm,
|
||||
refreshData,
|
||||
addDangerToast,
|
||||
addSuccessToast,
|
||||
],
|
||||
);
|
||||
|
||||
const emptyState = {
|
||||
title: t('No Data Access Rules yet'),
|
||||
image: 'filter-results.svg',
|
||||
buttonAction: () => handleRuleEdit(null),
|
||||
buttonIcon: canEdit ? (
|
||||
<Icons.PlusOutlined iconSize="m" data-test="add-rule-empty" />
|
||||
) : undefined,
|
||||
buttonText: canEdit ? t('Rule') : null,
|
||||
};
|
||||
|
||||
const filters: ListViewFilters = useMemo(
|
||||
() => [
|
||||
{
|
||||
Header: t('Role'),
|
||||
key: 'role',
|
||||
id: 'role',
|
||||
input: 'select',
|
||||
operator: FilterOperator.RelationOneMany,
|
||||
unfilteredLabel: t('All'),
|
||||
fetchSelects: createFetchRelated(
|
||||
'dar',
|
||||
'role',
|
||||
createErrorHandler(errMsg =>
|
||||
t('An error occurred while fetching roles: %s', errMsg),
|
||||
),
|
||||
user,
|
||||
),
|
||||
paginate: true,
|
||||
},
|
||||
{
|
||||
Header: t('Modified by'),
|
||||
key: 'changed_by',
|
||||
id: 'changed_by',
|
||||
input: 'select',
|
||||
operator: FilterOperator.RelationOneMany,
|
||||
unfilteredLabel: t('All'),
|
||||
fetchSelects: createFetchRelated(
|
||||
'dar',
|
||||
'changed_by',
|
||||
createErrorHandler(errMsg =>
|
||||
t('An error occurred while fetching users: %s', errMsg),
|
||||
),
|
||||
user,
|
||||
),
|
||||
paginate: true,
|
||||
},
|
||||
],
|
||||
[user],
|
||||
);
|
||||
|
||||
const initialSort = [{ id: 'changed_on_delta_humanized', desc: true }];
|
||||
const PAGE_SIZE = 25;
|
||||
|
||||
const subMenuButtons: SubMenuProps['buttons'] = [];
|
||||
|
||||
if (canWrite) {
|
||||
subMenuButtons.push({
|
||||
name: t('Bulk select'),
|
||||
buttonStyle: 'secondary',
|
||||
'data-test': 'bulk-select',
|
||||
onClick: toggleBulkSelect,
|
||||
});
|
||||
subMenuButtons.push({
|
||||
name: t('Rule'),
|
||||
icon: <Icons.PlusOutlined iconSize="m" data-test="add-rule" />,
|
||||
buttonStyle: 'primary',
|
||||
onClick: () => handleRuleEdit(null),
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<SubMenu name={t('Data Access Rules')} buttons={subMenuButtons} />
|
||||
<ConfirmStatusChange
|
||||
title={t('Please confirm')}
|
||||
description={t('Are you sure you want to delete the selected rules?')}
|
||||
onConfirm={handleBulkRulesDelete}
|
||||
>
|
||||
{confirmDelete => {
|
||||
const bulkActions: ListViewProps['bulkActions'] = [];
|
||||
if (canWrite) {
|
||||
bulkActions.push({
|
||||
key: 'delete',
|
||||
name: t('Delete'),
|
||||
type: 'danger',
|
||||
onSelect: confirmDelete,
|
||||
});
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<DataAccessRuleModal
|
||||
rule={currentRule}
|
||||
addDangerToast={addDangerToast}
|
||||
onHide={handleRuleModalHide}
|
||||
addSuccessToast={addSuccessToast}
|
||||
show={ruleModalOpen}
|
||||
/>
|
||||
<ListView<DataAccessRuleObject>
|
||||
className="dar-list-view"
|
||||
bulkActions={bulkActions}
|
||||
bulkSelectEnabled={bulkSelectEnabled}
|
||||
disableBulkSelect={toggleBulkSelect}
|
||||
columns={columns}
|
||||
count={rulesCount}
|
||||
data={rules}
|
||||
emptyState={emptyState}
|
||||
fetchData={fetchData}
|
||||
filters={filters}
|
||||
initialSort={initialSort}
|
||||
loading={loading}
|
||||
addDangerToast={addDangerToast}
|
||||
addSuccessToast={addSuccessToast}
|
||||
refreshData={() => {}}
|
||||
pageSize={PAGE_SIZE}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}}
|
||||
</ConfirmStatusChange>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default withToasts(DataAccessRulesList);
|
||||
@@ -138,6 +138,13 @@ const RowLevelSecurityList = lazy(
|
||||
),
|
||||
);
|
||||
|
||||
const DataAccessRulesList = lazy(
|
||||
() =>
|
||||
import(
|
||||
/* webpackChunkName: "DataAccessRulesList" */ 'src/pages/DataAccessRulesList'
|
||||
),
|
||||
);
|
||||
|
||||
const RolesList = lazy(
|
||||
() => import(/* webpackChunkName: "RolesList" */ 'src/pages/RolesList'),
|
||||
);
|
||||
@@ -315,6 +322,13 @@ if (isFeatureEnabled(FeatureFlag.TaggingSystem)) {
|
||||
});
|
||||
}
|
||||
|
||||
if (isFeatureEnabled(FeatureFlag.DataAccessRules)) {
|
||||
routes.push({
|
||||
path: '/dataaccessrules/list/',
|
||||
Component: DataAccessRulesList,
|
||||
});
|
||||
}
|
||||
|
||||
const user = getBootstrapData()?.user;
|
||||
const authRegistrationEnabled =
|
||||
getBootstrapData()?.common.conf.AUTH_USER_REGISTRATION;
|
||||
|
||||
@@ -46,11 +46,17 @@ class TablesDatabaseCommand(BaseCommand):
|
||||
catalog_name: str | None,
|
||||
schema_name: str,
|
||||
force: bool,
|
||||
filter_str: str | None = None,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
):
|
||||
self._db_id = db_id
|
||||
self._catalog_name = catalog_name
|
||||
self._schema_name = schema_name
|
||||
self._force = force
|
||||
self._filter_str = filter_str
|
||||
self._page = page
|
||||
self._page_size = page_size
|
||||
|
||||
def run(self) -> dict[str, Any]:
|
||||
self.validate()
|
||||
@@ -161,8 +167,24 @@ class TablesDatabaseCommand(BaseCommand):
|
||||
key=lambda item: item["value"],
|
||||
)
|
||||
|
||||
# Apply filter if provided
|
||||
if self._filter_str:
|
||||
filter_lower = self._filter_str.lower()
|
||||
options = [
|
||||
opt for opt in options if filter_lower in opt["value"].lower()
|
||||
]
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = len(options)
|
||||
|
||||
# Apply pagination if provided
|
||||
if self._page is not None and self._page_size is not None:
|
||||
start = self._page * self._page_size
|
||||
end = start + self._page_size
|
||||
options = options[start:end]
|
||||
|
||||
payload = {
|
||||
"count": len(tables) + len(views) + len(materialized_views),
|
||||
"count": total_count,
|
||||
"result": options,
|
||||
}
|
||||
return payload
|
||||
|
||||
@@ -581,6 +581,10 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
|
||||
# Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the
|
||||
# query, and might break queries and/or allow users to bypass RLS. Use with care!
|
||||
"RLS_IN_SQLLAB": False,
|
||||
# Enable the new Data Access Rules system for table-level access control,
|
||||
# row-level security (RLS), and column-level security (CLS). This replaces
|
||||
# the FAB-based permission system with a more flexible JSON-based rule system.
|
||||
"DATA_ACCESS_RULES": False,
|
||||
# Try to optimize SQL queries — for now only predicate pushdown is supported.
|
||||
"OPTIMIZE_SQL": False,
|
||||
# When impersonating a user, use the email prefix instead of the username
|
||||
|
||||
@@ -113,6 +113,7 @@ from superset.superset_typing import (
|
||||
)
|
||||
from superset.utils import core as utils, json
|
||||
from superset.utils.backports import StrEnum
|
||||
from superset.data_access_rules.utils import get_hidden_columns_for_table
|
||||
|
||||
config = current_app.config # Backward compatibility for tests
|
||||
metadata = Model.metadata # pylint: disable=no-member
|
||||
@@ -438,6 +439,21 @@ class BaseDatasource(
|
||||
@property
|
||||
def data(self) -> ExplorableData:
|
||||
"""Data representation of the datasource sent to the frontend"""
|
||||
# Filter hidden columns based on CLS rules
|
||||
columns_data = [o.data for o in self.columns]
|
||||
if is_feature_enabled("DATA_ACCESS_RULES") and hasattr(self, "database"):
|
||||
try:
|
||||
table = Table(self.datasource_name, self.schema, self.catalog)
|
||||
hidden_columns = get_hidden_columns_for_table(table, self.database)
|
||||
if hidden_columns:
|
||||
columns_data = [
|
||||
c for c in columns_data
|
||||
if c.get("column_name") not in hidden_columns
|
||||
]
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# Don't fail if CLS check fails, just return all columns
|
||||
pass
|
||||
|
||||
return {
|
||||
# simple fields
|
||||
"id": self.id,
|
||||
@@ -462,7 +478,7 @@ class BaseDatasource(
|
||||
# sqla-specific
|
||||
"sql": self.sql,
|
||||
# one to many
|
||||
"columns": [o.data for o in self.columns],
|
||||
"columns": columns_data,
|
||||
"metrics": [o.data for o in self.metrics],
|
||||
"folders": self.folders,
|
||||
# TODO deprecate, move logic to JS
|
||||
|
||||
28
superset/data_access_rules/__init__.py
Normal file
28
superset/data_access_rules/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Access Rules module.
|
||||
|
||||
This module provides a new approach to data access control in Superset,
|
||||
supporting:
|
||||
- Table-level access control (allow/deny patterns)
|
||||
- Row-level security (RLS) with predicates
|
||||
- Column-level security (CLS) with masking/hiding options
|
||||
|
||||
Unlike the FAB-based permission system, rules are stored as JSON documents
|
||||
and can reference tables directly without requiring a priori permission creation.
|
||||
"""
|
||||
235
superset/data_access_rules/api.py
Normal file
235
superset/data_access_rules/api.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Access Rules REST API.
|
||||
|
||||
This module provides the REST API for managing Data Access Rules,
|
||||
including CRUD operations and a group_keys discovery endpoint.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from flask import request, Response
|
||||
from flask_appbuilder.api import expose, protect, safe
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||
from superset.data_access_rules.models import DataAccessRule
|
||||
from superset.data_access_rules.schemas import (
|
||||
DataAccessRulePostSchema,
|
||||
DataAccessRulePutSchema,
|
||||
DataAccessRuleShowSchema,
|
||||
)
|
||||
from superset.data_access_rules.utils import get_all_group_keys
|
||||
from superset.extensions import event_logger
|
||||
from superset.views.base_api import (
|
||||
BaseSupersetModelRestApi,
|
||||
requires_json,
|
||||
statsd_metrics,
|
||||
)
|
||||
from superset.views.filters import BaseFilterRelatedRoles, BaseFilterRelatedUsers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataAccessRulesRestApi(BaseSupersetModelRestApi):
|
||||
"""REST API for Data Access Rules."""
|
||||
|
||||
datamodel = SQLAInterface(DataAccessRule)
|
||||
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
|
||||
RouteMethod.RELATED,
|
||||
"group_keys",
|
||||
}
|
||||
resource_name = "dar"
|
||||
class_permission_name = "DataAccessRule"
|
||||
openapi_spec_tag = "Data Access Rules"
|
||||
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
|
||||
allow_browser_login = True
|
||||
|
||||
list_columns = [
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"role_id",
|
||||
"role.id",
|
||||
"role.name",
|
||||
"rule",
|
||||
"changed_on_delta_humanized",
|
||||
"changed_by.first_name",
|
||||
"changed_by.last_name",
|
||||
"changed_by.id",
|
||||
]
|
||||
order_columns = [
|
||||
"id",
|
||||
"name",
|
||||
"role_id",
|
||||
"changed_on_delta_humanized",
|
||||
]
|
||||
add_columns = [
|
||||
"name",
|
||||
"description",
|
||||
"role_id",
|
||||
"rule",
|
||||
]
|
||||
edit_columns = [
|
||||
"name",
|
||||
"description",
|
||||
"role_id",
|
||||
"rule",
|
||||
]
|
||||
show_columns = [
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"role_id",
|
||||
"role.name",
|
||||
"role.id",
|
||||
"rule",
|
||||
"created_on",
|
||||
"changed_on",
|
||||
"created_by.first_name",
|
||||
"created_by.last_name",
|
||||
"changed_by.first_name",
|
||||
"changed_by.last_name",
|
||||
]
|
||||
search_columns = ["role", "changed_by"]
|
||||
|
||||
allowed_rel_fields = {"role", "changed_by"}
|
||||
base_related_field_filters = {
|
||||
"role": [["id", BaseFilterRelatedRoles, lambda: []]],
|
||||
"changed_by": [["id", BaseFilterRelatedUsers, lambda: []]],
|
||||
}
|
||||
|
||||
add_model_schema = DataAccessRulePostSchema()
|
||||
edit_model_schema = DataAccessRulePutSchema()
|
||||
show_model_schema = DataAccessRuleShowSchema()
|
||||
|
||||
openapi_spec_methods = {
|
||||
"get": {"get": {"summary": "Get a data access rule"}},
|
||||
"get_list": {"get": {"summary": "Get a list of data access rules"}},
|
||||
"post": {"post": {"summary": "Create a data access rule"}},
|
||||
"put": {"put": {"summary": "Update a data access rule"}},
|
||||
"delete": {"delete": {"summary": "Delete a data access rule"}},
|
||||
}
|
||||
|
||||
@expose("/<int:pk>", methods=("PUT",))
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@requires_json
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def put(self, pk: int) -> Response:
|
||||
"""Update a data access rule.
|
||||
---
|
||||
put:
|
||||
summary: Update a data access rule
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
description: The rule pk
|
||||
requestBody:
|
||||
description: Data access rule schema
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DataAccessRulePutSchema'
|
||||
responses:
|
||||
200:
|
||||
description: Rule updated
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: number
|
||||
result:
|
||||
$ref: '#/components/schemas/DataAccessRulePutSchema'
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
item = self.edit_model_schema.load(request.json)
|
||||
except ValidationError as error:
|
||||
return self.response_400(message=error.messages)
|
||||
|
||||
# Get existing rule
|
||||
existing = self.datamodel.get(pk)
|
||||
if not existing:
|
||||
return self.response_404()
|
||||
|
||||
# Update fields
|
||||
for key, value in item.items():
|
||||
setattr(existing, key, value)
|
||||
|
||||
try:
|
||||
self.datamodel.edit(existing)
|
||||
return self.response(200, id=existing.id, result=item)
|
||||
except Exception as ex:
|
||||
logger.error("Error updating data access rule: %s", str(ex), exc_info=True)
|
||||
return self.response_422(message=str(ex))
|
||||
|
||||
@expose("/group_keys/", methods=("GET",))
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.group_keys",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def group_keys(self) -> Response:
|
||||
"""
|
||||
Get all distinct group_keys used in RLS rules.
|
||||
|
||||
This endpoint is useful for UI discoverability - showing users
|
||||
what group_keys already exist so they can reuse them.
|
||||
---
|
||||
get:
|
||||
summary: Get all distinct RLS group keys
|
||||
description: >-
|
||||
Returns a list of all unique group_key values used in RLS rules
|
||||
across all Data Access Rules. This helps users discover existing
|
||||
keys for consistent rule grouping.
|
||||
responses:
|
||||
200:
|
||||
description: List of group keys
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/GroupKeysResponseSchema'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
group_keys = get_all_group_keys()
|
||||
return self.response(200, result=sorted(group_keys))
|
||||
119
superset/data_access_rules/models.py
Normal file
119
superset/data_access_rules/models.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Access Rules models.
|
||||
|
||||
This module defines the DataAccessRule model for storing access rules
|
||||
as JSON documents associated with roles.
|
||||
|
||||
Example rule document structure:
|
||||
{
|
||||
"allowed": [
|
||||
{
|
||||
"database": "sales",
|
||||
"schema": "orders",
|
||||
"table": "ord_main"
|
||||
},
|
||||
{
|
||||
"database": "logs",
|
||||
"catalog": "public"
|
||||
},
|
||||
{
|
||||
"database": "sales",
|
||||
"schema": "orders",
|
||||
"table": "prices",
|
||||
"rls": {
|
||||
"predicate": "org = 495",
|
||||
"group_key": "org_filter"
|
||||
}
|
||||
},
|
||||
{
|
||||
"database": "sales",
|
||||
"schema": "orders",
|
||||
"table": "user_info",
|
||||
"cls": {
|
||||
"name": "mask",
|
||||
"age": "nullify",
|
||||
"email": "hash",
|
||||
"lastname": "hide"
|
||||
}
|
||||
}
|
||||
],
|
||||
"denied": [
|
||||
{
|
||||
"database": "logs",
|
||||
"catalog": "public",
|
||||
"schema": "pii"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from superset import security_manager
|
||||
from superset.models.helpers import AuditMixinNullable
|
||||
|
||||
|
||||
class DataAccessRule(Model, AuditMixinNullable):
|
||||
"""
|
||||
Data access rule associated with a role.
|
||||
|
||||
Each rule is a JSON document that describes what databases, catalogs,
|
||||
schemas, and tables a role can access, along with optional RLS predicates
|
||||
and CLS column restrictions.
|
||||
"""
|
||||
|
||||
__tablename__ = "data_access_rules"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(250), nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
role_id = Column(Integer, ForeignKey("ab_role.id"), nullable=False)
|
||||
rule = Column(Text, nullable=False)
|
||||
|
||||
role = relationship(
|
||||
security_manager.role_model,
|
||||
backref="data_access_rules",
|
||||
foreign_keys=[role_id],
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<DataAccessRule(id={self.id}, name={self.name!r}, role_id={self.role_id})>"
|
||||
|
||||
@property
|
||||
def rule_dict(self) -> dict[str, Any]:
|
||||
"""Parse the rule JSON string into a dictionary."""
|
||||
import json
|
||||
|
||||
try:
|
||||
return json.loads(self.rule) if self.rule else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
@rule_dict.setter
|
||||
def rule_dict(self, value: dict[str, Any]) -> None:
|
||||
"""Serialize a dictionary to JSON for storage."""
|
||||
import json
|
||||
|
||||
self.rule = json.dumps(value)
|
||||
249
superset/data_access_rules/schemas.py
Normal file
249
superset/data_access_rules/schemas.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Access Rules schemas for API serialization/deserialization.
|
||||
"""
|
||||
|
||||
from marshmallow import fields, post_load, Schema, validates_schema, ValidationError
|
||||
|
||||
from superset.dashboards.schemas import UserSchema
|
||||
from superset.data_access_rules.models import DataAccessRule
|
||||
|
||||
# Field descriptions for OpenAPI documentation.
# NOTE: this text is embedded verbatim as the OpenAPI "description" metadata
# for the `rule` field of the list/show/post/put schemas below.
rule_description = """
A JSON document describing the access rule. The document should have two optional keys:
- `allowed`: List of entries describing what is allowed
- `denied`: List of entries describing what is denied

Each entry can specify:
- `database` (required): The database name
- `catalog` (optional): The catalog name
- `schema` (optional): The schema name
- `table` (optional): The table name
- `rls` (optional): Row-level security config with `predicate` and optional `group_key`
- `cls` (optional): Column-level security config mapping column names to actions

Example:
    {
        "allowed": [
            {"database": "sales", "schema": "orders"},
            {"database": "sales", "schema": "orders", "table": "prices",
             "rls": {"predicate": "org_id = 123", "group_key": "org"}},
            {"database": "sales", "schema": "users", "table": "info",
             "cls": {"email": "mask", "ssn": "hide"}}
        ],
        "denied": [
            {"database": "sales", "schema": "internal"}
        ]
    }

CLS actions: "hash", "nullify", "mask", "hide"
"""
|
||||
|
||||
|
||||
class RoleSchema(Schema):
    """Schema for role information nested inside rule payloads."""

    # Role display name and primary key.
    name = fields.String()
    id = fields.Integer()
|
||||
|
||||
|
||||
class DataAccessRuleListSchema(Schema):
    """Schema for listing data access rules."""

    id = fields.Integer(metadata={"description": "Unique ID of the rule"})
    name = fields.String(metadata={"description": "Name of the rule"})
    description = fields.String(metadata={"description": "Description of the rule"})
    role_id = fields.Integer(metadata={"description": "ID of the associated role"})
    role = fields.Nested(RoleSchema)
    rule = fields.String(metadata={"description": rule_description})
    # Human-friendly "changed N ago" string computed from the model.
    changed_on_delta_humanized = fields.Method("get_changed_on_delta_humanized")
    changed_by = fields.Nested(UserSchema(exclude=["username"]))

    def get_changed_on_delta_humanized(self, obj: DataAccessRule) -> str:
        """Serializer for ``changed_on_delta_humanized``; delegates to the model."""
        return obj.changed_on_delta_humanized()
|
||||
|
||||
|
||||
class DataAccessRuleShowSchema(Schema):
    """Schema for showing a single data access rule."""

    id = fields.Integer(metadata={"description": "Unique ID of the rule"})
    name = fields.String(metadata={"description": "Name of the rule"})
    description = fields.String(metadata={"description": "Description of the rule"})
    role_id = fields.Integer(metadata={"description": "ID of the associated role"})
    role = fields.Nested(RoleSchema)
    rule = fields.String(metadata={"description": rule_description})
    # Audit columns provided by AuditMixinNullable on the model.
    created_on = fields.DateTime()
    changed_on = fields.DateTime()
    created_by = fields.Nested(UserSchema(exclude=["username"]))
    changed_by = fields.Nested(UserSchema(exclude=["username"]))
|
||||
|
||||
|
||||
class DataAccessRulePostSchema(Schema):
    """
    Schema for creating a data access rule.

    Validates that `rule` is a JSON object of the expected shape
    (see ``rule_description``) and converts the payload into a
    ``DataAccessRule`` model instance on load.
    """

    name = fields.String(
        metadata={"description": "Name for this rule (optional)"},
        required=False,
        allow_none=True,
    )
    description = fields.String(
        metadata={"description": "Description of the rule (optional)"},
        required=False,
        allow_none=True,
    )
    role_id = fields.Integer(
        metadata={"description": "ID of the role this rule applies to"},
        required=True,
        allow_none=False,
    )
    rule = fields.String(
        metadata={"description": rule_description},
        required=True,
        allow_none=False,
    )

    @validates_schema
    def validate_rule_json(self, data: dict, **kwargs: dict) -> None:
        """
        Validate that the rule field contains valid JSON with the expected
        structure.

        Raises:
            ValidationError: if the rule is not valid JSON, is not an object,
                has malformed `allowed`/`denied` lists, entries missing a
                `database` key, or invalid CLS configuration.
        """
        import json

        rule = data.get("rule")
        if not rule:
            return

        # Keep the try body to the one statement that can raise
        # JSONDecodeError; ValidationErrors below must not be shadowed.
        try:
            parsed = json.loads(rule)
        except json.JSONDecodeError as ex:
            raise ValidationError(f"Invalid JSON: {ex}", field_name="rule") from ex

        if not isinstance(parsed, dict):
            raise ValidationError("Rule must be a JSON object", field_name="rule")

        allowed = parsed.get("allowed", [])
        denied = parsed.get("denied", [])

        if not isinstance(allowed, list):
            raise ValidationError("'allowed' must be a list", field_name="rule")
        if not isinstance(denied, list):
            raise ValidationError("'denied' must be a list", field_name="rule")

        for entry in allowed + denied:
            if not isinstance(entry, dict):
                raise ValidationError(
                    "Each entry must be an object", field_name="rule"
                )
            if "database" not in entry:
                raise ValidationError(
                    "Each entry must have a 'database' field",
                    field_name="rule",
                )

            if cls_config := entry.get("cls"):
                # Fix: a non-object `cls` (e.g. a list or string) previously
                # raised AttributeError on `.items()` instead of a 400.
                if not isinstance(cls_config, dict):
                    raise ValidationError(
                        "'cls' must be an object mapping columns to actions",
                        field_name="rule",
                    )
                valid_actions = {"hash", "nullify", "mask", "hide"}
                for col, action in cls_config.items():
                    # Fix: non-string actions (e.g. numbers) previously
                    # raised AttributeError on `.lower()`.
                    if (
                        not isinstance(action, str)
                        or action.lower() not in valid_actions
                    ):
                        raise ValidationError(
                            f"Invalid CLS action '{action}' for column '{col}'. "
                            f"Valid actions: {valid_actions}",
                            field_name="rule",
                        )

    @post_load
    def make_object(self, data: dict, **kwargs: dict) -> DataAccessRule:
        """Convert validated data to a DataAccessRule instance."""
        return DataAccessRule(**data)
|
||||
|
||||
|
||||
class DataAccessRulePutSchema(Schema):
    """
    Schema for updating a data access rule.

    All fields are optional; `rule`, when provided, is validated against the
    same structural constraints as on creation (see ``rule_description``).
    """

    name = fields.String(
        metadata={"description": "Name for this rule (optional)"},
        required=False,
        allow_none=True,
    )
    description = fields.String(
        metadata={"description": "Description of the rule (optional)"},
        required=False,
        allow_none=True,
    )
    role_id = fields.Integer(
        metadata={"description": "ID of the role this rule applies to"},
        required=False,
        allow_none=False,
    )
    rule = fields.String(
        metadata={"description": rule_description},
        required=False,
        allow_none=False,
    )

    @validates_schema
    def validate_rule_json(self, data: dict, **kwargs: dict) -> None:
        """
        Validate that the rule field contains valid JSON if provided.

        Raises:
            ValidationError: if the rule is not valid JSON, is not an object,
                has malformed `allowed`/`denied` lists, entries missing a
                `database` key, or invalid CLS configuration.
        """
        import json

        rule = data.get("rule")
        if not rule:
            return

        # Keep the try body to the one statement that can raise
        # JSONDecodeError; ValidationErrors below must not be shadowed.
        try:
            parsed = json.loads(rule)
        except json.JSONDecodeError as ex:
            raise ValidationError(f"Invalid JSON: {ex}", field_name="rule") from ex

        if not isinstance(parsed, dict):
            raise ValidationError("Rule must be a JSON object", field_name="rule")

        # Same validation as POST schema
        allowed = parsed.get("allowed", [])
        denied = parsed.get("denied", [])

        if not isinstance(allowed, list):
            raise ValidationError("'allowed' must be a list", field_name="rule")
        if not isinstance(denied, list):
            raise ValidationError("'denied' must be a list", field_name="rule")

        for entry in allowed + denied:
            if not isinstance(entry, dict):
                raise ValidationError(
                    "Each entry must be an object", field_name="rule"
                )
            if "database" not in entry:
                raise ValidationError(
                    "Each entry must have a 'database' field",
                    field_name="rule",
                )

            if cls_config := entry.get("cls"):
                # Fix: a non-object `cls` (e.g. a list or string) previously
                # raised AttributeError on `.items()` instead of a 400.
                if not isinstance(cls_config, dict):
                    raise ValidationError(
                        "'cls' must be an object mapping columns to actions",
                        field_name="rule",
                    )
                valid_actions = {"hash", "nullify", "mask", "hide"}
                for col, action in cls_config.items():
                    # Fix: non-string actions (e.g. numbers) previously
                    # raised AttributeError on `.lower()`.
                    if (
                        not isinstance(action, str)
                        or action.lower() not in valid_actions
                    ):
                        raise ValidationError(
                            f"Invalid CLS action '{action}' for column '{col}'. "
                            f"Valid actions: {valid_actions}",
                            field_name="rule",
                        )
|
||||
|
||||
|
||||
class GroupKeysResponseSchema(Schema):
    """Schema for the group_keys endpoint response."""

    result = fields.List(
        fields.String(),
        metadata={"description": "List of unique group_key values used in RLS rules"},
    )
|
||||
896
superset/data_access_rules/utils.py
Normal file
896
superset/data_access_rules/utils.py
Normal file
@@ -0,0 +1,896 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Access Rules utility functions.
|
||||
|
||||
This module provides functions for:
|
||||
- Checking if a user has access to a table
|
||||
- Collecting RLS predicates for a table
|
||||
- Collecting CLS rules for a table
|
||||
- Applying RLS and CLS to SQL queries
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from flask import g
|
||||
|
||||
from superset import db, is_feature_enabled, security_manager
|
||||
from superset.sql.parse import CLSAction, Table
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.data_access_rules.models import DataAccessRule
|
||||
from superset.models.core import Database
|
||||
from superset.sql.parse import BaseSQLStatement
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessCheckResult(Enum):
    """Result of an access check against the user's data access rules."""

    # A matching "allowed" entry won the specificity contest.
    ALLOWED = "allowed"
    # A matching "denied" entry won (deny beats allow at equal specificity).
    DENIED = "denied"
    # No rule entry matched the table at all.
    NO_RULE = "no_rule"
|
||||
|
||||
|
||||
@dataclass
class RLSPredicate:
    """An RLS predicate with optional group_key."""

    # SQL boolean expression to restrict rows of the matched table.
    predicate: str
    # Predicates sharing a group_key are OR'ed together; distinct groups and
    # ungrouped predicates are AND'ed (see get_rls_predicates_for_table).
    group_key: str | None = None
|
||||
|
||||
|
||||
@dataclass
class TableAccessInfo:
    """Information about access to a specific table."""

    # Outcome of the allow/deny evaluation.
    access: AccessCheckResult
    # RLS predicates collected from matching allowed entries (empty unless
    # access is ALLOWED).
    rls_predicates: list[RLSPredicate]
    # Column name -> CLS action, strictest action per column (empty unless
    # access is ALLOWED).
    cls_rules: dict[str, CLSAction]
|
||||
|
||||
|
||||
def get_user_rules() -> list[DataAccessRule]:
    """
    Get all data access rules for the current user's roles.

    Returns:
        List of DataAccessRule objects for the current user's roles. Empty
        when there is no Flask user bound, the user has no roles, or no
        rules exist for those roles.
    """
    # Local import to avoid a circular dependency at module load time.
    from superset.data_access_rules.models import DataAccessRule

    # Anonymous/unbound request context: no rules apply.
    if not hasattr(g, "user") or not g.user:
        return []

    user_roles = security_manager.get_user_roles()
    role_ids = [role.id for role in user_roles]

    if not role_ids:
        return []

    return (
        db.session.query(DataAccessRule)
        .filter(DataAccessRule.role_id.in_(role_ids))
        .all()
    )
|
||||
|
||||
|
||||
def _matches_rule_entry(
|
||||
entry: dict[str, Any],
|
||||
database_name: str,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
table_name: str | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a rule entry matches the given database/catalog/schema/table.
|
||||
|
||||
The rule entry can specify any level of the hierarchy:
|
||||
- database only: matches all catalogs/schemas/tables in that database
|
||||
- database + catalog: matches all schemas/tables in that catalog
|
||||
- database + schema: matches all tables in that schema (for DBs without catalogs)
|
||||
- database + catalog + schema: matches all tables in that schema
|
||||
- database + schema + table: matches the specific table (for DBs without catalogs)
|
||||
- database + catalog + schema + table: matches the specific table
|
||||
|
||||
Args:
|
||||
entry: The rule entry dict
|
||||
database_name: The database name to check
|
||||
catalog: The catalog to check (None if DB doesn't support catalogs)
|
||||
schema: The schema to check
|
||||
table_name: The table name to check
|
||||
|
||||
Returns:
|
||||
True if the entry matches, False otherwise.
|
||||
"""
|
||||
# Database must always match
|
||||
if entry.get("database") != database_name:
|
||||
return False
|
||||
|
||||
entry_catalog = entry.get("catalog")
|
||||
entry_schema = entry.get("schema")
|
||||
entry_table = entry.get("table")
|
||||
|
||||
# If the entry specifies a catalog, it must match (or catalog must be None/default)
|
||||
if entry_catalog is not None:
|
||||
if catalog is not None and entry_catalog != catalog:
|
||||
return False
|
||||
|
||||
# If the entry specifies a schema, it must match
|
||||
if entry_schema is not None:
|
||||
if schema is not None and entry_schema != schema:
|
||||
return False
|
||||
|
||||
# If the entry specifies a table, it must match
|
||||
if entry_table is not None:
|
||||
if table_name is not None and entry_table != table_name:
|
||||
return False
|
||||
|
||||
# Check specificity: entry must be at least as specific as the query
|
||||
# If querying a specific table, entry must specify that table or be broader
|
||||
if table_name is not None and entry_table is not None and entry_table != table_name:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _is_more_specific(entry: dict[str, Any], other: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if 'entry' is more specific than 'other'.
|
||||
|
||||
More specific means it specifies more levels of the hierarchy.
|
||||
"""
|
||||
entry_specificity = sum(
|
||||
[
|
||||
entry.get("catalog") is not None,
|
||||
entry.get("schema") is not None,
|
||||
entry.get("table") is not None,
|
||||
]
|
||||
)
|
||||
other_specificity = sum(
|
||||
[
|
||||
other.get("catalog") is not None,
|
||||
other.get("schema") is not None,
|
||||
other.get("table") is not None,
|
||||
]
|
||||
)
|
||||
return entry_specificity > other_specificity
|
||||
|
||||
|
||||
def check_table_access(
    database_name: str,
    table: Table,
    rules: list[DataAccessRule] | None = None,
) -> TableAccessInfo:
    """
    Check if the current user has access to a specific table.

    The function evaluates all rules for the user's roles and determines:
    1. Whether access is allowed, denied, or no rule applies
    2. Any RLS predicates that should be applied
    3. Any CLS rules for column masking/hiding

    Denied rules take precedence over allowed rules when at the same specificity level.
    More specific rules take precedence over less specific rules.

    Args:
        database_name: The database name
        table: The Table object with catalog, schema, and table name
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        TableAccessInfo with access result, RLS predicates, and CLS rules.
    """
    if rules is None:
        rules = get_user_rules()

    if not rules:
        # No rules configured for the user at all: NO_RULE, not DENIED.
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules=[],
        ) if False else TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    # Collect all matching rules
    allowed_entries: list[dict[str, Any]] = []
    denied_entries: list[dict[str, Any]] = []

    for rule in rules:
        rule_dict = rule.rule_dict

        # Check allowed entries
        for entry in rule_dict.get("allowed", []):
            if _matches_rule_entry(
                entry, database_name, table.catalog, table.schema, table.table
            ):
                allowed_entries.append(entry)

        # Check denied entries
        for entry in rule_dict.get("denied", []):
            if _matches_rule_entry(
                entry, database_name, table.catalog, table.schema, table.table
            ):
                denied_entries.append(entry)

    # If no rules match, return NO_RULE
    if not allowed_entries and not denied_entries:
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    # Find the most specific denied entry
    most_specific_denied = None
    for entry in denied_entries:
        if most_specific_denied is None or _is_more_specific(
            entry, most_specific_denied
        ):
            most_specific_denied = entry

    # Find the most specific allowed entry
    most_specific_allowed = None
    for entry in allowed_entries:
        if most_specific_allowed is None or _is_more_specific(
            entry, most_specific_allowed
        ):
            most_specific_allowed = entry

    # Determine access: deny wins at same specificity, more specific wins otherwise
    if most_specific_denied is not None and most_specific_allowed is not None:
        if _is_more_specific(most_specific_denied, most_specific_allowed):
            # Deny is strictly more specific than the best allow -> DENIED.
            return TableAccessInfo(
                access=AccessCheckResult.DENIED,
                rls_predicates=[],
                cls_rules={},
            )
        elif _is_more_specific(most_specific_allowed, most_specific_denied):
            # Access allowed, collect RLS and CLS from matching entries
            pass
        else:
            # Same specificity: denied wins
            return TableAccessInfo(
                access=AccessCheckResult.DENIED,
                rls_predicates=[],
                cls_rules={},
            )
    elif most_specific_denied is not None:
        # Only deny entries matched.
        return TableAccessInfo(
            access=AccessCheckResult.DENIED,
            rls_predicates=[],
            cls_rules={},
        )
    elif most_specific_allowed is None:
        # Unreachable given the "no entries matched" return above; kept as a
        # defensive guard.
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    # Collect RLS predicates from all matching allowed entries
    # (RLS is cumulative - all predicates are applied)
    rls_predicates: list[RLSPredicate] = []
    for entry in allowed_entries:
        rls_config = entry.get("rls")
        if rls_config and "predicate" in rls_config:
            rls_predicates.append(
                RLSPredicate(
                    predicate=rls_config["predicate"],
                    group_key=rls_config.get("group_key"),
                )
            )

    # Collect CLS rules from all matching allowed entries
    # (CLS is cumulative - strictest action wins per column)
    cls_rules: dict[str, CLSAction] = {}
    # Higher number = stricter: HIDE > NULLIFY > MASK > HASH.
    cls_precedence = {
        CLSAction.HIDE: 4,
        CLSAction.NULLIFY: 3,
        CLSAction.MASK: 2,
        CLSAction.HASH: 1,
    }
    action_map = {
        "hide": CLSAction.HIDE,
        "nullify": CLSAction.NULLIFY,
        "mask": CLSAction.MASK,
        "hash": CLSAction.HASH,
    }

    for entry in allowed_entries:
        cls_config = entry.get("cls", {})
        for column, action_str in cls_config.items():
            action = action_map.get(action_str.lower())
            if action is None:
                # Unknown actions are skipped (schema validation should have
                # rejected them upstream), not treated as a hard failure.
                logger.warning("Unknown CLS action: %s", action_str)
                continue

            existing = cls_rules.get(column)
            if existing is None or cls_precedence[action] > cls_precedence[existing]:
                cls_rules[column] = action

    return TableAccessInfo(
        access=AccessCheckResult.ALLOWED,
        rls_predicates=rls_predicates,
        cls_rules=cls_rules,
    )
|
||||
|
||||
|
||||
def get_rls_predicates_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> list[str]:
    """
    Get the RLS predicates for a table using the new Data Access Rules system.

    This function collects all RLS predicates from matching rules and combines them
    using the group_key logic:
    - Predicates without group_key are ANDed together
    - Predicates with the same group_key are ORed together
    - Groups are ANDed together

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        List of SQL predicate strings to be ANDed together.
    """
    access_info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )

    # No predicates apply unless access was explicitly allowed.
    if access_info.access != AccessCheckResult.ALLOWED:
        return []
    if not access_info.rls_predicates:
        return []

    # Partition predicates: ungrouped ones stand alone; grouped ones are
    # collected per group_key for later OR-ing.
    combined: list[str] = []
    grouped: dict[str, list[str]] = defaultdict(list)

    for rls in access_info.rls_predicates:
        wrapped = f"({rls.predicate})"
        if rls.group_key:
            grouped[rls.group_key].append(wrapped)
        else:
            combined.append(wrapped)

    # Each group collapses into a single OR expression; singleton groups stay
    # as-is to avoid redundant parentheses.
    for members in grouped.values():
        if len(members) == 1:
            combined.append(members[0])
        else:
            combined.append(f"({' OR '.join(members)})")

    return combined
|
||||
|
||||
|
||||
def get_cls_rules_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> dict[str, CLSAction]:
    """
    Get the CLS rules for a table using the new Data Access Rules system.

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Dict mapping column names to CLSAction values; empty unless access
        to the table is explicitly allowed.
    """
    info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )
    # Only an explicit ALLOWED result carries CLS rules.
    return info.cls_rules if info.access == AccessCheckResult.ALLOWED else {}
|
||||
|
||||
|
||||
def get_hidden_columns_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> set[str]:
    """
    Get the set of column names that should be hidden for a table.

    This function checks the CLS rules for the current user and returns
    the names of columns that have the "hide" action applied.

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Set of column names that should be hidden.
    """
    cls_rules = get_cls_rules_for_table(table, database, rules)

    # Idiom: set comprehension replaces the manual loop-and-add.
    return {
        column_name
        for column_name, action in cls_rules.items()
        if action == CLSAction.HIDE
    }
|
||||
|
||||
|
||||
def filter_columns_by_cls(
    columns: list[dict[str, Any]],
    table: Table,
    database: Database,
    column_name_key: str = "column_name",
) -> list[dict[str, Any]]:
    """
    Filter a list of column dictionaries to exclude hidden columns.

    This function is useful for filtering column metadata returned by
    database reflection or dataset APIs.

    Args:
        columns: List of column dictionaries
        table: The fully qualified Table object
        database: The Database object
        column_name_key: The key in the column dict that contains the column name

    Returns:
        Filtered list of columns with hidden columns removed; the input list
        is returned unchanged when the feature is off or nothing is hidden.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return columns

    hidden = get_hidden_columns_for_table(table, database)
    if not hidden:
        # Fast path: no CLS "hide" rules apply, return the original list.
        return columns

    return [entry for entry in columns if entry.get(column_name_key) not in hidden]
|
||||
|
||||
|
||||
def _build_table_schemas(
    database: Database,
    tables: list[Table],
) -> dict[str, Any]:
    """
    Build the nested column-type map that sqlglot's ``qualify()`` needs to
    expand ``SELECT *``: ``{catalog: {schema: {table: {col: type}}}}``,
    omitting hierarchy levels a table does not define.

    Tables whose columns cannot be fetched are skipped with a warning so that
    CLS application stays best-effort.
    """
    table_schemas: dict[str, Any] = {}
    for table in tables:
        try:
            columns = database.get_columns(table)
            col_types = {
                col["column_name"]: str(col.get("type", "VARCHAR"))
                for col in columns
            }
        except Exception as ex:  # noqa: BLE001 - best-effort reflection
            logger.warning(
                "Could not fetch schema for table %s: %s",
                table,
                ex,
            )
            continue

        # setdefault replaces the original's manual "if key not in dict" walk.
        if table.catalog:
            catalog_map = table_schemas.setdefault(table.catalog, {})
            if table.schema:
                catalog_map.setdefault(table.schema, {})[table.table] = col_types
            else:
                catalog_map[table.table] = col_types
        elif table.schema:
            table_schemas.setdefault(table.schema, {})[table.table] = col_types
        else:
            table_schemas[table.table] = col_types

    return table_schemas


def apply_data_access_rules(
    database: Database,
    catalog: str | None,
    schema: str,
    parsed_statement: BaseSQLStatement[Any],
) -> None:
    """
    Apply Data Access Rules (RLS and CLS) to a parsed SQL statement in place.

    This function:
    1. Checks if the DATA_ACCESS_RULES feature is enabled
    2. For each table in the query, checks access and collects RLS/CLS rules
    3. Applies CLS first, then RLS (RLS wraps the query in a subquery whose
       ``SELECT *`` must see the CLS-transformed columns)

    Args:
        database: The Database object
        catalog: The default catalog for the query
        schema: The default schema for the query
        parsed_statement: The parsed SQL statement to modify in place
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return

    from superset.sql.parse import CLSRules

    rules = get_user_rules()
    if not rules:
        return

    # How RLS should be applied is engine-specific.
    method = database.db_engine_spec.get_rls_method()

    # Collect RLS predicates and CLS rules for all tables
    rls_predicates: dict[Table, list[Any]] = {}
    cls_rules: CLSRules = {}

    for table in parsed_statement.tables:
        qualified_table = table.qualify(catalog=catalog, schema=schema)

        # Check access first
        access_info = check_table_access(
            database_name=database.database_name,
            table=qualified_table,
            rules=rules,
        )

        if access_info.access == AccessCheckResult.DENIED:
            # TODO: How should we handle denied access mid-query?
            # For now, log a warning. In the future, we might raise an exception.
            logger.warning(
                "Access denied to table %s for user %s",
                qualified_table,
                getattr(g, "user", "unknown"),
            )
            continue

        # Collect RLS predicates
        predicates = get_rls_predicates_for_table(qualified_table, database, rules)
        if predicates:
            rls_predicates[qualified_table] = [
                parsed_statement.parse_predicate(pred) for pred in predicates if pred
            ]

        # Collect CLS rules
        table_cls = get_cls_rules_for_table(qualified_table, database, rules)
        if table_cls:
            cls_rules[qualified_table] = table_cls

    # Apply CLS first (before RLS) so that hidden columns are removed
    # before RLS wraps the query in a subquery
    if cls_rules:
        table_schemas = _build_table_schemas(database, list(cls_rules.keys()))
        parsed_statement.apply_cls(
            cls_rules,
            schema=table_schemas if table_schemas else None,
        )

    # Apply RLS after CLS - RLS wraps the query in a subquery with SELECT *
    # which will pick up the already-transformed columns from CLS
    if rls_predicates:
        parsed_statement.apply_rls(catalog, schema, rls_predicates, method)
|
||||
|
||||
|
||||
def get_allowed_tables(
    database_name: str,
    schema: str | None = None,
    catalog: str | None = None,
) -> tuple[set[str], bool]:
    """
    Get all table names that the current user has access to via Data Access Rules
    for a specific database and schema.

    Args:
        database_name: The database name to check
        schema: Optional schema name to filter by
        catalog: Optional catalog name to filter by

    Returns:
        Tuple of (set of table names, bool indicating if schema-level access is granted).
        If schema-level access is granted, the set may be empty but all tables are allowed.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set(), False

    rules = get_user_rules()
    if not rules:
        return set(), False

    table_names: set[str] = set()
    schema_level_access = False

    for rule in rules:
        rule_dict = rule.rule_dict

        # Collect tables from allowed entries
        for entry in rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue

            # If catalog is specified in the entry, it must match
            entry_catalog = entry.get("catalog")
            if catalog is not None and entry_catalog is not None:
                if entry_catalog != catalog:
                    continue

            # If schema is specified, check if it matches
            entry_schema = entry.get("schema")
            if schema is not None and entry_schema is not None:
                if entry_schema != schema:
                    continue

            # If entry has a table, add it to the set
            if table := entry.get("table"):
                table_names.add(table)
            elif entry_schema == schema or (entry_schema is None and schema is None):
                # Schema-level or database-level access without table means all tables
                # NOTE(review): when a specific schema is queried but the entry is
                # database-level (entry_schema is None), this branch is skipped,
                # so a database-wide grant does not set schema_level_access for
                # that schema — confirm whether that is intended.
                schema_level_access = True

    return table_names, schema_level_access
|
||||
|
||||
|
||||
def get_allowed_schemas(database_name: str, catalog: str | None = None) -> set[str]:
    """
    Get all schema names that the current user has access to via Data Access
    Rules for a specific database.

    Args:
        database_name: The database name to check
        catalog: Optional catalog name to filter by

    Returns:
        Set of schema names the user has access to. A "*" entry marks
        database-level access (all schemas) and must be handled by the caller.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    user_rules = get_user_rules()
    if not user_rules:
        return set()

    accessible: set[str] = set()

    for user_rule in user_rules:
        for entry in user_rule.rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue

            # A catalog restriction only applies when both the entry and the
            # caller specify one; if both are set and differ, skip the entry.
            entry_catalog = entry.get("catalog")
            if (
                catalog is not None
                and entry_catalog is not None
                and entry_catalog != catalog
            ):
                continue

            if entry_schema := entry.get("schema"):
                accessible.add(entry_schema)
            else:
                # Database-level access without a schema means all schemas;
                # signal it to the caller with a special "*" marker.
                accessible.add("*")

    return accessible
|
||||
|
||||
|
||||
def get_allowed_databases() -> set[str]:
    """
    Get all database names that the current user has access to via Data Access
    Rules.

    This function is used to populate database selectors in SQL Lab and
    elsewhere.

    Returns:
        Set of database names the user has access to.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    user_rules = get_user_rules()
    if not user_rules:
        return set()

    # Any "allowed" entry that names a database grants visibility for it.
    return {
        database
        for user_rule in user_rules
        for entry in user_rule.rule_dict.get("allowed", [])
        if (database := entry.get("database"))
    }
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllowedTable:
|
||||
"""A table allowed by DAR with database context."""
|
||||
|
||||
database: str
|
||||
table: str
|
||||
schema: str | None = None
|
||||
catalog: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllowedEntry:
|
||||
"""
|
||||
An allowed entry from DAR at any level of the hierarchy.
|
||||
|
||||
Fields may be None to indicate "all" at that level:
|
||||
- database only: all catalogs/schemas/tables in that database
|
||||
- database + catalog: all schemas/tables in that catalog
|
||||
- database + schema: all tables in that schema (for DBs without catalogs)
|
||||
- database + catalog + schema: all tables in that schema
|
||||
- database + schema + table: specific table (for DBs without catalogs)
|
||||
- database + catalog + schema + table: specific table
|
||||
"""
|
||||
|
||||
database: str
|
||||
catalog: str | None = None
|
||||
schema: str | None = None
|
||||
table: str | None = None
|
||||
|
||||
|
||||
def get_all_allowed_entries() -> list[AllowedEntry]:
    """
    Get every access entry the current user has via Data Access Rules, across
    all databases.

    Entries are returned at all hierarchy levels (database, schema, table),
    so callers can build filters appropriate for their use case.

    Returns:
        List of AllowedEntry objects representing allowed access.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    user_rules = get_user_rules()
    if not user_rules:
        return []

    entries: list[AllowedEntry] = []

    for user_rule in user_rules:
        for raw_entry in user_rule.rule_dict.get("allowed", []):
            database = raw_entry.get("database")
            # Entries without a database are malformed; skip them.
            if not database:
                continue
            entries.append(
                AllowedEntry(
                    database=database,
                    catalog=raw_entry.get("catalog"),
                    schema=raw_entry.get("schema"),
                    table=raw_entry.get("table"),
                )
            )

    return entries
|
||||
|
||||
|
||||
def get_all_allowed_tables() -> list[AllowedTable]:
    """
    Get all tables the current user can access via Data Access Rules, across
    all databases.

    Used for dataset filtering, where every specific table the user can
    access must be known.

    Returns:
        List of AllowedTable objects representing allowed tables.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    user_rules = get_user_rules()
    if not user_rules:
        return []

    tables: list[AllowedTable] = []

    for user_rule in user_rules:
        for raw_entry in user_rule.rule_dict.get("allowed", []):
            database = raw_entry.get("database")
            if not database:
                continue

            table_name = raw_entry.get("table")
            # Database- or schema-level entries are skipped here: enumerating
            # their tables would require querying the database itself.
            if not table_name:
                continue

            tables.append(
                AllowedTable(
                    database=database,
                    table=table_name,
                    schema=raw_entry.get("schema"),
                    catalog=raw_entry.get("catalog"),
                )
            )

    return tables
|
||||
|
||||
|
||||
def get_all_group_keys(
    database_name: str | None = None,
    table: Table | None = None,
) -> set[str]:
    """
    Get all distinct ``group_key`` values used in Data Access Rules.

    Useful for UI discoverability: showing users which group keys already
    exist so they can be reused for consistent rule grouping.

    Args:
        database_name: Optional filter by database
        table: Optional Table object to filter by catalog/schema/table

    Returns:
        Set of unique group_key values.
    """
    from superset.data_access_rules.models import DataAccessRule

    group_keys: set[str] = set()

    for rule in db.session.query(DataAccessRule).all():
        for entry in rule.rule_dict.get("allowed", []):
            # Apply the optional database filter.
            if database_name and entry.get("database") != database_name:
                continue
            # Apply the optional catalog/schema/table filters.
            if table is not None:
                if table.catalog and entry.get("catalog") != table.catalog:
                    continue
                if table.schema and entry.get("schema") != table.schema:
                    continue
                if table.table and entry.get("table") != table.table:
                    continue

            if group_key := entry.get("rls", {}).get("group_key"):
                group_keys.add(group_key)

    return group_keys
|
||||
@@ -717,16 +717,37 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
if not database:
|
||||
return self.response_404()
|
||||
try:
|
||||
params = kwargs["rison"]
|
||||
catalogs = database.get_all_catalog_names(
|
||||
cache=database.catalog_cache_enabled,
|
||||
cache_timeout=database.catalog_cache_timeout or None,
|
||||
force=kwargs["rison"].get("force", False),
|
||||
force=params.get("force", False),
|
||||
)
|
||||
catalogs = security_manager.get_catalogs_accessible_by_user(
|
||||
database,
|
||||
catalogs,
|
||||
)
|
||||
return self.response(200, result=list(catalogs))
|
||||
|
||||
# Convert to list and sort
|
||||
catalogs = sorted(catalogs)
|
||||
|
||||
# Apply filter if provided
|
||||
filter_str = params.get("filter", "").lower()
|
||||
if filter_str:
|
||||
catalogs = [c for c in catalogs if filter_str in c.lower()]
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = len(catalogs)
|
||||
|
||||
# Apply pagination if provided
|
||||
page = params.get("page")
|
||||
page_size = params.get("page_size")
|
||||
if page is not None and page_size is not None:
|
||||
start = page * page_size
|
||||
end = start + page_size
|
||||
catalogs = catalogs[start:end]
|
||||
|
||||
return self.response(200, result=catalogs, count=total_count)
|
||||
except OperationalError:
|
||||
return self.response(
|
||||
500,
|
||||
@@ -797,21 +818,38 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
)
|
||||
if params.get("upload_allowed"):
|
||||
if not database.allow_file_upload:
|
||||
return self.response(200, result=[])
|
||||
return self.response(200, result=[], count=0)
|
||||
if allowed_schemas := database.get_schema_access_for_file_upload():
|
||||
# some databases might return the list of schemas in uppercase,
|
||||
# while the list of allowed schemas is manually inputted so
|
||||
# could be lowercase
|
||||
allowed_schemas = {schema.lower() for schema in allowed_schemas}
|
||||
return self.response(
|
||||
200,
|
||||
result=[
|
||||
schema
|
||||
for schema in schemas
|
||||
if schema.lower() in allowed_schemas
|
||||
],
|
||||
)
|
||||
return self.response(200, result=list(schemas))
|
||||
schemas = [
|
||||
schema
|
||||
for schema in schemas
|
||||
if schema.lower() in allowed_schemas
|
||||
]
|
||||
|
||||
# Convert to list and sort
|
||||
schemas = sorted(schemas)
|
||||
|
||||
# Apply filter if provided
|
||||
filter_str = params.get("filter", "").lower()
|
||||
if filter_str:
|
||||
schemas = [s for s in schemas if filter_str in s.lower()]
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = len(schemas)
|
||||
|
||||
# Apply pagination if provided
|
||||
page = params.get("page")
|
||||
page_size = params.get("page_size")
|
||||
if page is not None and page_size is not None:
|
||||
start = page * page_size
|
||||
end = start + page_size
|
||||
schemas = schemas[start:end]
|
||||
|
||||
return self.response(200, result=schemas, count=total_count)
|
||||
except OperationalError:
|
||||
return self.response(
|
||||
500, message="There was an error connecting to the database"
|
||||
@@ -874,11 +912,23 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
force = kwargs["rison"].get("force", False)
|
||||
catalog_name = kwargs["rison"].get("catalog_name")
|
||||
schema_name = kwargs["rison"].get("schema_name", "")
|
||||
params = kwargs["rison"]
|
||||
force = params.get("force", False)
|
||||
catalog_name = params.get("catalog_name")
|
||||
schema_name = params.get("schema_name", "")
|
||||
filter_str = params.get("filter")
|
||||
page = params.get("page")
|
||||
page_size = params.get("page_size")
|
||||
|
||||
command = TablesDatabaseCommand(pk, catalog_name, schema_name, force)
|
||||
command = TablesDatabaseCommand(
|
||||
pk,
|
||||
catalog_name,
|
||||
schema_name,
|
||||
force,
|
||||
filter_str=filter_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
payload = command.run()
|
||||
return self.response(200, **payload)
|
||||
|
||||
|
||||
@@ -23,11 +23,22 @@ from sqlalchemy.orm import Query
|
||||
from sqlalchemy.sql.expression import cast
|
||||
from sqlalchemy.sql.sqltypes import JSON
|
||||
|
||||
from superset import security_manager
|
||||
from superset import is_feature_enabled, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.views.base import BaseFilter
|
||||
|
||||
|
||||
def get_dar_allowed_databases() -> set[str]:
    """Return database names granted to the current user by Data Access Rules."""
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    # Imported lazily to avoid a circular dependency at module import time.
    from superset.data_access_rules.utils import get_allowed_databases

    return get_allowed_databases()
|
||||
|
||||
|
||||
def can_access_databases(view_menu_name: str) -> set[str]:
|
||||
"""
|
||||
Return names of databases available in `view_menu_name`.
|
||||
@@ -62,10 +73,15 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||
catalog_access_databases = can_access_databases("catalog_access")
|
||||
schema_access_databases = can_access_databases("schema_access")
|
||||
datasource_access_databases = can_access_databases("datasource_access")
|
||||
|
||||
# Include databases from Data Access Rules
|
||||
dar_databases = get_dar_allowed_databases()
|
||||
|
||||
database_names = sorted(
|
||||
catalog_access_databases
|
||||
| schema_access_databases
|
||||
| datasource_access_databases
|
||||
| dar_databases
|
||||
)
|
||||
|
||||
return query.filter(
|
||||
|
||||
@@ -66,12 +66,20 @@ database_schemas_query_schema = {
|
||||
"force": {"type": "boolean"},
|
||||
"upload_allowed": {"type": "boolean"},
|
||||
"catalog": {"type": "string"},
|
||||
"filter": {"type": "string"},
|
||||
"page": {"type": "integer", "minimum": 0},
|
||||
"page_size": {"type": "integer", "minimum": 1, "maximum": 1000},
|
||||
},
|
||||
}
|
||||
|
||||
database_catalogs_query_schema = {
|
||||
"type": "object",
|
||||
"properties": {"force": {"type": "boolean"}},
|
||||
"properties": {
|
||||
"force": {"type": "boolean"},
|
||||
"filter": {"type": "string"},
|
||||
"page": {"type": "integer", "minimum": 0},
|
||||
"page_size": {"type": "integer", "minimum": 1, "maximum": 1000},
|
||||
},
|
||||
}
|
||||
|
||||
database_tables_query_schema = {
|
||||
@@ -80,6 +88,9 @@ database_tables_query_schema = {
|
||||
"force": {"type": "boolean"},
|
||||
"schema_name": {"type": "string"},
|
||||
"catalog_name": {"type": "string"},
|
||||
"filter": {"type": "string"},
|
||||
"page": {"type": "integer", "minimum": 0},
|
||||
"page_size": {"type": "integer", "minimum": 1, "maximum": 1000},
|
||||
},
|
||||
"required": ["schema_name"],
|
||||
}
|
||||
|
||||
@@ -73,6 +73,12 @@ def get_table_metadata(database: Any, table: Table) -> TableMetadataResponse:
|
||||
"""
|
||||
keys = []
|
||||
columns = database.get_columns(table)
|
||||
|
||||
# Filter out columns hidden by CLS rules (lazy import to avoid circular dependency)
|
||||
from superset.data_access_rules.utils import filter_columns_by_cls
|
||||
|
||||
columns = filter_columns_by_cls(columns, table, database)
|
||||
|
||||
primary_key = database.get_pk_constraint(table)
|
||||
if primary_key and primary_key.get("constrained_columns"):
|
||||
primary_key["column_names"] = primary_key.pop("constrained_columns")
|
||||
|
||||
@@ -161,6 +161,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
from superset.dashboards.api import DashboardRestApi
|
||||
from superset.dashboards.filter_state.api import DashboardFilterStateRestApi
|
||||
from superset.dashboards.permalink.api import DashboardPermalinkRestApi
|
||||
from superset.data_access_rules.api import DataAccessRulesRestApi
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.datasets.api import DatasetRestApi
|
||||
from superset.datasets.columns.api import DatasetColumnsRestApi
|
||||
@@ -213,6 +214,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
TabStateView,
|
||||
)
|
||||
from superset.views.sqla import (
|
||||
DataAccessRulesView,
|
||||
RowLevelSecurityView,
|
||||
TableModelView,
|
||||
)
|
||||
@@ -264,6 +266,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
appbuilder.add_api(ReportScheduleRestApi)
|
||||
appbuilder.add_api(ReportExecutionLogRestApi)
|
||||
appbuilder.add_api(RLSRestApi)
|
||||
appbuilder.add_api(DataAccessRulesRestApi)
|
||||
appbuilder.add_api(SavedQueryRestApi)
|
||||
appbuilder.add_api(TagRestApi)
|
||||
appbuilder.add_api(SqlLabRestApi)
|
||||
@@ -518,6 +521,17 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
icon="fa-lock",
|
||||
)
|
||||
|
||||
if feature_flag_manager.is_feature_enabled("DATA_ACCESS_RULES"):
|
||||
appbuilder.add_view(
|
||||
DataAccessRulesView,
|
||||
"Data Access Rules",
|
||||
href="DataAccessRulesView.list",
|
||||
label=_("Data Access Rules"),
|
||||
category="Security",
|
||||
category_label=_("Security"),
|
||||
icon="fa-shield",
|
||||
)
|
||||
|
||||
def init_core_dependencies(self) -> None:
|
||||
"""Initialize core dependency injection for direct import patterns."""
|
||||
from superset.core.api.core_api_injection import (
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""add_data_access_rules_table
|
||||
|
||||
Revision ID: a352d7609189
|
||||
Revises: a9c01ec10479
|
||||
Create Date: 2025-12-17 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
from superset.migrations.shared.utils import (
|
||||
create_fks_for_table,
|
||||
create_table,
|
||||
drop_table,
|
||||
)
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a352d7609189"
|
||||
down_revision = "a9c01ec10479"
|
||||
|
||||
|
||||
def upgrade():
    """Create the ``data_access_rules`` table and its foreign keys."""
    create_table(
        "data_access_rules",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("role_id", sa.Integer(), nullable=False),
        sa.Column(
            "rule",
            sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"),
            nullable=False,
        ),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )

    # Rules belong to a role; delete them when the role is deleted.
    create_fks_for_table(
        "fk_data_access_rules_role_id_ab_role",
        "data_access_rules",
        "ab_role",
        ["role_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # Audit columns reference users, with no delete cascade.
    for constraint_name, column_name in (
        ("fk_data_access_rules_created_by_fk_ab_user", "created_by_fk"),
        ("fk_data_access_rules_changed_by_fk_ab_user", "changed_by_fk"),
    ):
        create_fks_for_table(
            constraint_name,
            "data_access_rules",
            "ab_user",
            [column_name],
            ["id"],
        )
|
||||
|
||||
|
||||
def downgrade():
    """Drop the ``data_access_rules`` table."""
    drop_table("data_access_rules")
|
||||
@@ -0,0 +1,43 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""add name and description to data_access_rules
|
||||
|
||||
Revision ID: b463d8709290
|
||||
Revises: a352d7609189
|
||||
Create Date: 2025-12-17 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from superset.migrations.shared.utils import add_columns, drop_columns
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b463d8709290"
|
||||
down_revision = "a352d7609189"
|
||||
|
||||
|
||||
def upgrade():
    """Add optional ``name`` and ``description`` columns to data_access_rules."""
    add_columns(
        "data_access_rules",
        sa.Column("name", sa.String(250), nullable=True),
        sa.Column("description", sa.Text(), nullable=True),
    )
||||
|
||||
|
||||
def downgrade():
    """Remove the ``name`` and ``description`` columns from data_access_rules."""
    drop_columns("data_access_rules", "name", "description")
|
||||
@@ -120,6 +120,7 @@ from superset.utils.core import (
|
||||
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
|
||||
from superset.utils.dates import datetime_to_epoch
|
||||
from superset.utils.rls import apply_rls
|
||||
from superset.data_access_rules.utils import apply_data_access_rules
|
||||
|
||||
|
||||
class ValidationResultDict(TypedDict):
|
||||
@@ -1049,6 +1050,22 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
)
|
||||
sql = self._apply_cte(sql, sqlaq.cte)
|
||||
|
||||
# Apply Data Access Rules (RLS and CLS) if enabled
|
||||
if is_feature_enabled("DATA_ACCESS_RULES"):
|
||||
try:
|
||||
default_schema = self.database.get_default_schema(self.catalog)
|
||||
parsed_script = SQLScript(sql, engine=self.db_engine_spec.engine)
|
||||
for statement in parsed_script.statements:
|
||||
apply_data_access_rules(
|
||||
self.database,
|
||||
self.catalog,
|
||||
self.schema or default_schema or "",
|
||||
statement,
|
||||
)
|
||||
sql = parsed_script.format()
|
||||
except Exception as ex:
|
||||
logger.warning("Failed to apply Data Access Rules: %s", ex)
|
||||
|
||||
if mutate:
|
||||
sql = self.database.mutate_sql_based_on_config(sql)
|
||||
return QueryStringExtended(
|
||||
@@ -2051,6 +2068,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
# Log the error but don't fail - RLS application is best-effort
|
||||
logger.warning("Failed to apply RLS to virtual dataset SQL: %s", ex)
|
||||
|
||||
# Apply Data Access Rules to virtual dataset SQL
|
||||
if is_feature_enabled("DATA_ACCESS_RULES") and parsed_script.statements:
|
||||
default_schema = self.database.get_default_schema(self.catalog)
|
||||
try:
|
||||
for statement in parsed_script.statements:
|
||||
apply_data_access_rules(
|
||||
self.database,
|
||||
self.catalog,
|
||||
self.schema or default_schema or "",
|
||||
statement,
|
||||
)
|
||||
from_sql = parsed_script.format()
|
||||
except Exception as ex:
|
||||
logger.warning(
|
||||
"Failed to apply Data Access Rules to virtual dataset SQL: %s", ex
|
||||
)
|
||||
|
||||
cte = self.db_engine_spec.get_cte_query(from_sql)
|
||||
from_clause = (
|
||||
sa.table(self.db_engine_spec.cte_alias)
|
||||
|
||||
@@ -976,6 +976,19 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
}
|
||||
)
|
||||
|
||||
# Data Access Rules
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset import is_feature_enabled
|
||||
|
||||
if is_feature_enabled("DATA_ACCESS_RULES"):
|
||||
from superset.data_access_rules.utils import get_allowed_schemas
|
||||
|
||||
dar_schemas = get_allowed_schemas(database.database_name, catalog)
|
||||
if "*" in dar_schemas:
|
||||
# Database-level access means all schemas
|
||||
return schemas
|
||||
accessible_schemas.update(dar_schemas)
|
||||
|
||||
return schemas & accessible_schemas
|
||||
|
||||
def get_catalogs_accessible_by_user(
|
||||
@@ -1091,6 +1104,24 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
)
|
||||
}
|
||||
|
||||
# Check Data Access Rules
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset import is_feature_enabled
|
||||
|
||||
if is_feature_enabled("DATA_ACCESS_RULES"):
|
||||
from superset.data_access_rules.utils import get_allowed_tables
|
||||
|
||||
dar_tables, schema_level_access = get_allowed_tables(
|
||||
database.database_name, schema, catalog
|
||||
)
|
||||
if schema_level_access:
|
||||
# Schema-level access means all tables in the schema
|
||||
return datasource_names
|
||||
|
||||
# Add DAR tables to accessible datasources
|
||||
for table_name in dar_tables:
|
||||
user_datasources.add(DatasourceName(table_name, schema, catalog))
|
||||
|
||||
return [
|
||||
datasource
|
||||
for datasource in datasource_names
|
||||
@@ -2410,6 +2441,20 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
# access to any datasource is sufficient
|
||||
break
|
||||
else:
|
||||
# Check Data Access Rules before denying
|
||||
if is_feature_enabled("DATA_ACCESS_RULES"):
|
||||
from superset.data_access_rules.utils import (
|
||||
AccessCheckResult,
|
||||
check_table_access,
|
||||
)
|
||||
|
||||
access_info = check_table_access(
|
||||
database_name=database.database_name,
|
||||
table=table_,
|
||||
)
|
||||
if access_info.access == AccessCheckResult.ALLOWED:
|
||||
continue
|
||||
|
||||
denied.add(table_)
|
||||
|
||||
if denied:
|
||||
@@ -2444,10 +2489,33 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
|
||||
assert datasource
|
||||
|
||||
# Check DAR access for datasource
|
||||
dar_allowed = False
|
||||
if is_feature_enabled("DATA_ACCESS_RULES") and hasattr(
|
||||
datasource, "database"
|
||||
):
|
||||
from superset.data_access_rules.utils import (
|
||||
AccessCheckResult,
|
||||
check_table_access,
|
||||
)
|
||||
from superset.sql.parse import Table
|
||||
|
||||
dar_table = Table(
|
||||
table=datasource.table_name,
|
||||
schema=datasource.schema,
|
||||
catalog=getattr(datasource, "catalog", None),
|
||||
)
|
||||
dar_access_info = check_table_access(
|
||||
database_name=datasource.database.database_name,
|
||||
table=dar_table,
|
||||
)
|
||||
dar_allowed = dar_access_info.access == AccessCheckResult.ALLOWED
|
||||
|
||||
if not (
|
||||
self.can_access_schema(datasource)
|
||||
or self.can_access("datasource_access", datasource.perm or "")
|
||||
or self.is_owner(datasource)
|
||||
or dar_allowed
|
||||
or (
|
||||
# Grant access to the datasource only if dashboard RBAC is enabled
|
||||
# or the user is an embedded guest user with access to the dashboard
|
||||
|
||||
@@ -129,6 +129,145 @@ class LimitMethod(enum.Enum):
|
||||
FETCH_MANY = enum.auto()
|
||||
|
||||
|
||||
class CLSAction(enum.Enum):
    """
    Column-Level Security actions.

    Each action determines how a sensitive column is transformed in queries.
    """

    # Pseudonymization via hashing
    HASH = enum.auto()
    # Replace with NULL
    NULLIFY = enum.auto()
    # Remove from results entirely
    HIDE = enum.auto()
    # Replace with '****'
    MASK = enum.auto()
|
||||
|
||||
@dataclass(eq=True, frozen=True)
class Table:
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.
    """

    table: str
    schema: str | None = None
    catalog: str | None = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Should not be used for SQL generation, only for logging and debugging,
        since the quoting is not engine-specific. Parts are percent-encoded and
        literal dots become %2E, so joining dots stay unambiguous.
        """
        encoded_parts = (
            urllib.parse.quote(part, safe="").replace(".", "%2E")
            for part in (self.catalog, self.schema, self.table)
            if part
        )
        return ".".join(encoded_parts)

    def __eq__(self, other: Any) -> bool:
        # Equality is defined on the canonical string form.
        return str(self) == str(other)

    def qualify(
        self,
        *,
        catalog: str | None = None,
        schema: str | None = None,
    ) -> Table:
        """
        Return a new Table with the given schema and/or catalog filled in,
        keeping any values that are already set.
        """
        return Table(
            table=self.table,
            schema=self.schema or schema,
            catalog=self.catalog or catalog,
        )
|
||||
|
||||
|
||||
# Type alias for CLS rules: {Table: {column_name: action}}
|
||||
CLSRules = dict[Table, dict[str, CLSAction]]
|
||||
|
||||
# CLS action precedence: higher value = stricter (less information revealed)
|
||||
# HIDE > NULLIFY > MASK > HASH
|
||||
CLS_ACTION_PRECEDENCE: dict[CLSAction, int] = {
|
||||
CLSAction.HASH: 1,
|
||||
CLSAction.MASK: 2,
|
||||
CLSAction.NULLIFY: 3,
|
||||
CLSAction.HIDE: 4,
|
||||
}
|
||||
|
||||
|
||||
def merge_cls_rules(*rules_list: CLSRules) -> CLSRules:
    """
    Merge multiple CLS rule sets into one, keeping the stricter action when
    conflicts occur.

    When multiple rules specify actions for the same table/column, the action
    with the higher precedence wins. Precedence (strictest to least strict):
    HIDE > NULLIFY > MASK > HASH.

    Args:
        *rules_list: Variable number of CLSRules dicts to merge

    Returns:
        A merged CLSRules dict with the strictest action for each table/column

    Example:
        >>> rules1 = {Table("foo"): {"col1": CLSAction.HASH}}
        >>> rules2 = {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}
        >>> merge_cls_rules(rules1, rules2)
        {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}
    """
    merged: CLSRules = {}

    for rules in rules_list:
        for table, column_actions in rules.items():
            target = merged.setdefault(table, {})
            for column, action in column_actions.items():
                current = target.get(column)
                # Keep the stricter action (higher precedence value).
                if (
                    current is None
                    or CLS_ACTION_PRECEDENCE[action] > CLS_ACTION_PRECEDENCE[current]
                ):
                    target[column] = action

    return merged
|
||||
|
||||
|
||||
# Hash function patterns by dialect. The placeholder {} will be replaced with the
|
||||
# column. Some databases need casting for non-string types, so we cast to string/text.
|
||||
# The fallback uses a literal since there's no universal hash function across all
|
||||
# databases.
|
||||
CLS_HASH_FUNCTIONS: dict[Dialects | type[Dialect] | None, str] = {
|
||||
None: "'[HASHED]'", # Universal fallback - no hash function available
|
||||
Dialects.DIALECT: "MD5(CAST({} AS VARCHAR))", # Generic SQL with MD5
|
||||
Dialects.POSTGRES: "MD5(CAST({} AS TEXT))",
|
||||
Dialects.MYSQL: "MD5(CAST({} AS CHAR))",
|
||||
Dialects.BIGQUERY: "TO_HEX(MD5(CAST({} AS STRING)))",
|
||||
Dialects.SNOWFLAKE: "MD5(TO_VARCHAR({}))",
|
||||
Dialects.REDSHIFT: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.PRESTO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
|
||||
Dialects.TRINO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
|
||||
Dialects.SQLITE: "HEX({})", # SQLite doesn't have MD5, use HEX as placeholder
|
||||
Dialects.DUCKDB: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.ORACLE: "STANDARD_HASH(TO_CHAR({}), 'MD5')",
|
||||
Dialects.TSQL: (
|
||||
"CONVERT(VARCHAR(32), HASHBYTES('MD5', CAST({} AS VARCHAR(MAX))), 2)"
|
||||
),
|
||||
Dialects.HIVE: "MD5(CAST({} AS STRING))",
|
||||
Dialects.SPARK: "MD5(CAST({} AS STRING))",
|
||||
Dialects.CLICKHOUSE: "MD5(toString({}))",
|
||||
Dialects.DATABRICKS: "MD5(CAST({} AS STRING))",
|
||||
Dialects.DORIS: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.STARROCKS: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.DRILL: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.DRUID: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.TERADATA: "HASH_MD5(CAST({} AS VARCHAR(10000)))",
|
||||
Dialects.RISINGWAVE: "MD5(CAST({} AS VARCHAR))",
|
||||
}
|
||||
|
||||
|
||||
class CTASMethod(enum.Enum):
|
||||
TABLE = enum.auto()
|
||||
VIEW = enum.auto()
|
||||
@@ -279,47 +418,530 @@ class RLSAsSubqueryTransformer(RLSTransformer):
|
||||
return node
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Table:
|
||||
class CLSTransformer:
|
||||
"""
|
||||
A fully qualified SQL table conforming to [[catalog.]schema.]table.
|
||||
AST transformer to apply Column-Level Security rules.
|
||||
|
||||
This transformer modifies SELECT expressions and predicates to apply CLS actions:
|
||||
- HASH: Replace column with hash function (database-specific)
|
||||
- NULLIFY: Replace with NULL AS column_name (SELECT) or FALSE (predicates)
|
||||
- HIDE: Remove column from SELECT entirely, FALSE in predicates
|
||||
- MASK: Replace column with '****' AS column_name (SELECT) or FALSE (predicates)
|
||||
|
||||
Example:
|
||||
Given rules: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
|
||||
Query: SELECT id, salary, name FROM my_table WHERE id = 1
|
||||
Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table
|
||||
WHERE MD5(CAST(id AS TEXT)) = 1
|
||||
|
||||
For predicates, HASH transforms the column to ensure filtered results also respect
|
||||
the security policy. NULLIFY/MASK/HIDE transform to FALSE to prevent information
|
||||
leakage through filtering.
|
||||
"""
|
||||
|
||||
table: str
|
||||
schema: str | None = None
|
||||
catalog: str | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return the fully qualified SQL table name.
|
||||
|
||||
Should not be used for SQL generation, only for logging and debugging, since the
|
||||
quoting is not engine-specific.
|
||||
"""
|
||||
return ".".join(
|
||||
urllib.parse.quote(part, safe="").replace(".", "%2E")
|
||||
for part in [self.catalog, self.schema, self.table]
|
||||
if part
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return str(self) == str(other)
|
||||
|
||||
def qualify(
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
catalog: str | None = None,
|
||||
rules: CLSRules,
|
||||
dialect: Dialects | type[Dialect] | None,
|
||||
) -> None:
|
||||
self.rules = self._normalize_rules(rules)
|
||||
self.dialect = dialect
|
||||
self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None])
|
||||
|
||||
def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]:
|
||||
"""
|
||||
Normalize table and column names to lowercase for case-insensitive matching.
|
||||
"""
|
||||
return {
|
||||
Table(
|
||||
table=table.table.lower(),
|
||||
schema=table.schema.lower() if table.schema else None,
|
||||
catalog=table.catalog.lower() if table.catalog else None,
|
||||
): {col.lower(): action for col, action in cols.items()}
|
||||
for table, cols in rules.items()
|
||||
}
|
||||
|
||||
def _get_action(
|
||||
self,
|
||||
table_name: str | None,
|
||||
column_name: str,
|
||||
schema: str | None = None,
|
||||
) -> Table:
|
||||
catalog: str | None = None,
|
||||
) -> CLSAction | None:
|
||||
"""
|
||||
Return a new Table with the given schema and/or catalog, if not already set.
|
||||
Get the CLS action for a column, if any.
|
||||
|
||||
Matching logic:
|
||||
1. First try exact match with schema/catalog if provided
|
||||
2. Fallback to table name match - if table names match, apply the rule
|
||||
regardless of schema/catalog (since query may not have schema info)
|
||||
"""
|
||||
return Table(
|
||||
table=self.table,
|
||||
schema=self.schema or schema,
|
||||
catalog=self.catalog or catalog,
|
||||
if not table_name:
|
||||
return None
|
||||
|
||||
# Create a normalized Table for lookup
|
||||
lookup_table = Table(
|
||||
table=table_name.lower(),
|
||||
schema=schema.lower() if schema else None,
|
||||
catalog=catalog.lower() if catalog else None,
|
||||
)
|
||||
|
||||
# First try exact match with schema/catalog
|
||||
table_rules = self.rules.get(lookup_table)
|
||||
if table_rules:
|
||||
return table_rules.get(column_name.lower())
|
||||
|
||||
# Fallback: match by table name only
|
||||
# This handles cases where the rule has schema/catalog but the query doesn't
|
||||
for rule_table, cols in self.rules.items():
|
||||
if rule_table.table == lookup_table.table:
|
||||
action = cols.get(column_name.lower())
|
||||
if action:
|
||||
return action
|
||||
|
||||
return None
|
||||
|
||||
def _create_hash_expression(
|
||||
self,
|
||||
column: exp.Column,
|
||||
alias: str,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Create a hash expression for a column.
|
||||
"""
|
||||
# Generate the column SQL without any alias
|
||||
col_sql = column.sql(dialect=self.dialect)
|
||||
hash_sql = self.hash_pattern.format(col_sql)
|
||||
hash_expr = sqlglot.parse_one(hash_sql, dialect=self.dialect)
|
||||
return exp.Alias(
|
||||
this=hash_expr,
|
||||
alias=exp.Identifier(this=alias),
|
||||
)
|
||||
|
||||
def _create_null_expression(self, alias: str) -> exp.Expression:
|
||||
"""
|
||||
Create a NULL AS alias expression.
|
||||
"""
|
||||
return exp.Alias(
|
||||
this=exp.Null(),
|
||||
alias=exp.Identifier(this=alias),
|
||||
)
|
||||
|
||||
def _create_mask_expression(
|
||||
self,
|
||||
column: exp.Column,
|
||||
alias: str,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Create a CASE expression that masks non-NULL values while preserving NULLs.
|
||||
|
||||
Generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END AS alias
|
||||
|
||||
This preserves the semantic meaning of NULL (no value) vs masked (hidden value).
|
||||
"""
|
||||
return exp.Alias(
|
||||
this=exp.Case(
|
||||
ifs=[
|
||||
exp.If(
|
||||
this=exp.Is(this=column.copy(), expression=exp.Null()),
|
||||
true=exp.Null(),
|
||||
)
|
||||
],
|
||||
default=exp.Literal(this="****", is_string=True),
|
||||
),
|
||||
alias=exp.Identifier(this=alias),
|
||||
)
|
||||
|
||||
def _create_hash_expression_no_alias(
|
||||
self,
|
||||
column: exp.Column,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Create a hash expression for a column without an alias.
|
||||
|
||||
Used for transforming columns in predicates (WHERE, ON, etc.).
|
||||
"""
|
||||
col_sql = column.sql(dialect=self.dialect)
|
||||
hash_sql = self.hash_pattern.format(col_sql)
|
||||
return sqlglot.parse_one(hash_sql, dialect=self.dialect)
|
||||
|
||||
def _get_column_alias(self, expr: exp.Expression) -> str:
|
||||
"""
|
||||
Get the alias for a column expression.
|
||||
"""
|
||||
if isinstance(expr, exp.Alias):
|
||||
return expr.alias
|
||||
if isinstance(expr, exp.Column):
|
||||
return expr.name
|
||||
return expr.sql(dialect=self.dialect)
|
||||
|
||||
def _get_table_for_column(
|
||||
self,
|
||||
column: exp.Column,
|
||||
scope_tables: dict[str, str],
|
||||
) -> str | None:
|
||||
"""
|
||||
Resolve which table a column belongs to.
|
||||
|
||||
Args:
|
||||
column: The column expression
|
||||
scope_tables: Map of alias/name to actual table name
|
||||
|
||||
Returns:
|
||||
The table name or None if cannot be resolved
|
||||
"""
|
||||
if column.table:
|
||||
# Column is qualified with table name/alias
|
||||
return scope_tables.get(column.table.lower(), column.table)
|
||||
|
||||
# For unqualified columns, if there's only one table in scope,
|
||||
# we can infer the column belongs to that table
|
||||
if len(scope_tables) == 1:
|
||||
return next(iter(scope_tables.values()))
|
||||
|
||||
# With multiple tables, check if any table in rules has this column
|
||||
# This is a best-effort match for unqualified columns
|
||||
col_lower = column.name.lower()
|
||||
for table_name in scope_tables.values():
|
||||
# Look for a rule matching this table
|
||||
for rule_table, cols in self.rules.items():
|
||||
if rule_table.table == table_name.lower() and col_lower in cols:
|
||||
return table_name
|
||||
|
||||
return None
|
||||
|
||||
def _extract_scope_tables(self, select: exp.Select) -> dict[str, str]:
|
||||
"""
|
||||
Extract table names and aliases from a SELECT statement's FROM clause.
|
||||
|
||||
Returns a dict mapping alias (or table name if no alias) to actual table name.
|
||||
"""
|
||||
tables: dict[str, str] = {}
|
||||
|
||||
if from_clause := select.args.get("from"):
|
||||
for table in from_clause.find_all(exp.Table):
|
||||
table_name = table.name
|
||||
alias = table.alias if table.alias else table_name
|
||||
tables[alias.lower()] = table_name
|
||||
|
||||
for join in select.args.get("joins") or []:
|
||||
for table in join.find_all(exp.Table):
|
||||
table_name = table.name
|
||||
alias = table.alias if table.alias else table_name
|
||||
tables[alias.lower()] = table_name
|
||||
|
||||
return tables
|
||||
|
||||
def _transform_nested_column(
|
||||
self,
|
||||
column: exp.Column,
|
||||
scope_tables: dict[str, str],
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Transform a nested column reference within a SELECT expression.
|
||||
|
||||
This handles columns inside CASE expressions, function arguments, etc.
|
||||
Unlike top-level columns, nested columns use NULL for blocking instead
|
||||
of FALSE (which works better in non-predicate contexts).
|
||||
|
||||
- HASH: Replace with hash function
|
||||
- NULLIFY/MASK/HIDE: Replace with NULL (blocks computation safely)
|
||||
"""
|
||||
table_name = self._get_table_for_column(column, scope_tables)
|
||||
action = self._get_action(table_name, column.name)
|
||||
|
||||
if action is None:
|
||||
return column
|
||||
|
||||
if action == CLSAction.HASH:
|
||||
return self._create_hash_expression_no_alias(column)
|
||||
|
||||
# NULLIFY/MASK/HIDE: Return NULL to safely block any computation
|
||||
# NULL propagates through expressions: UPPER(NULL)→NULL, 1+NULL→NULL, etc.
|
||||
return exp.Null()
|
||||
|
||||
def _transform_expression(
|
||||
self,
|
||||
expr: exp.Expression,
|
||||
scope_tables: dict[str, str],
|
||||
) -> exp.Expression | None:
|
||||
"""
|
||||
Transform a single SELECT expression based on CLS rules.
|
||||
|
||||
For simple column references: apply full transformation with alias.
|
||||
For complex expressions: transform all nested column references.
|
||||
|
||||
Returns:
|
||||
- Transformed expression
|
||||
- None if a top-level column should be hidden
|
||||
"""
|
||||
# Get the underlying column (handle aliases)
|
||||
column = expr.this if isinstance(expr, exp.Alias) else expr
|
||||
alias = self._get_column_alias(expr)
|
||||
|
||||
if isinstance(column, exp.Column):
|
||||
# Simple column reference - apply full transformation with alias
|
||||
table_name = self._get_table_for_column(column, scope_tables)
|
||||
action = self._get_action(table_name, column.name)
|
||||
|
||||
if action is None:
|
||||
return expr
|
||||
|
||||
if action == CLSAction.HIDE:
|
||||
return None
|
||||
if action == CLSAction.HASH:
|
||||
return self._create_hash_expression(column, alias)
|
||||
if action == CLSAction.NULLIFY:
|
||||
return self._create_null_expression(alias)
|
||||
# action == CLSAction.MASK
|
||||
return self._create_mask_expression(column, alias)
|
||||
|
||||
# Complex expression (CASE, function, arithmetic, etc.)
|
||||
# Transform ALL nested column references within it
|
||||
def transform_nested(node: exp.Expression) -> exp.Expression:
|
||||
if isinstance(node, exp.Column):
|
||||
return self._transform_nested_column(node, scope_tables)
|
||||
return node
|
||||
|
||||
return expr.transform(transform_nested)
|
||||
|
||||
def _transform_star(
|
||||
self,
|
||||
star: exp.Star,
|
||||
scope_tables: dict[str, str],
|
||||
) -> list[exp.Expression]:
|
||||
"""
|
||||
Transform SELECT * by expanding hidden columns conceptually.
|
||||
|
||||
Since we don't have schema information, we cannot truly expand *.
|
||||
We return the star as-is but log a warning.
|
||||
"""
|
||||
# Without schema information, we cannot expand SELECT *
|
||||
# In a real implementation, you would need to query the database schema
|
||||
logger.warning(
|
||||
"CLS cannot fully process SELECT * without schema information. "
|
||||
"Consider using explicit column lists for queries with CLS rules."
|
||||
)
|
||||
return [star]
|
||||
|
||||
def _transform_non_select_column(
|
||||
self,
|
||||
column: exp.Column,
|
||||
scope_tables: dict[str, str],
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Transform a column reference outside of SELECT list.
|
||||
|
||||
This is the SINGLE transformation function for ALL column references
|
||||
outside the SELECT list (WHERE, HAVING, ON, GROUP BY, ORDER BY,
|
||||
window functions, CASE expressions, function arguments, etc.)
|
||||
|
||||
- HASH: Replace with hash function
|
||||
- NULLIFY/MASK/HIDE: Replace with FALSE (blocks predicates, marked for
|
||||
removal in GROUP BY/ORDER BY)
|
||||
"""
|
||||
table_name = self._get_table_for_column(column, scope_tables)
|
||||
action = self._get_action(table_name, column.name)
|
||||
|
||||
if action is None:
|
||||
return column
|
||||
|
||||
if action == CLSAction.HASH:
|
||||
return self._create_hash_expression_no_alias(column)
|
||||
|
||||
# NULLIFY/MASK/HIDE: Return FALSE to block usage
|
||||
# For predicates: FALSE blocks the filter
|
||||
# For GROUP BY/ORDER BY: Will be cleaned up in post-processing
|
||||
return exp.false()
|
||||
|
||||
@staticmethod
|
||||
def _is_blocked(node: exp.Expression) -> bool:
|
||||
"""Check if an expression is a blocked column (FALSE or NULL sentinel)."""
|
||||
# FALSE is used for blocked columns in predicates (Phase 2)
|
||||
# NULL is used for blocked columns in nested expressions (Phase 1)
|
||||
if isinstance(node, exp.Boolean) and not node.this:
|
||||
return True
|
||||
if isinstance(node, exp.Null):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _transform_all_non_select_columns(
|
||||
self,
|
||||
select: exp.Select,
|
||||
scope_tables: dict[str, str],
|
||||
) -> None:
|
||||
"""
|
||||
Transform ALL column references outside the SELECT list.
|
||||
|
||||
This uses sqlglot's transform() to recursively walk through the entire
|
||||
expression tree, ensuring we catch columns in:
|
||||
- WHERE clauses
|
||||
- HAVING clauses
|
||||
- JOIN ON conditions
|
||||
- GROUP BY clauses
|
||||
- ORDER BY clauses
|
||||
- Window function PARTITION BY / ORDER BY
|
||||
- CASE expressions
|
||||
- Function arguments
|
||||
- Any other nested expression
|
||||
|
||||
This is the security-critical function that ensures NO column reference
|
||||
is missed.
|
||||
"""
|
||||
|
||||
def transform_column(node: exp.Expression) -> exp.Expression:
|
||||
if isinstance(node, exp.Column):
|
||||
return self._transform_non_select_column(node, scope_tables)
|
||||
return node
|
||||
|
||||
# Transform WHERE
|
||||
if where := select.args.get("where"):
|
||||
transformed = where.this.transform(transform_column)
|
||||
select.set("where", exp.Where(this=transformed))
|
||||
|
||||
# Transform HAVING
|
||||
if having := select.args.get("having"):
|
||||
transformed = having.this.transform(transform_column)
|
||||
select.set("having", exp.Having(this=transformed))
|
||||
|
||||
# Transform all JOINs (ON conditions)
|
||||
for join in select.args.get("joins") or []:
|
||||
if on := join.args.get("on"):
|
||||
transformed = on.transform(transform_column)
|
||||
join.set("on", transformed)
|
||||
|
||||
# Transform GROUP BY and remove blocked (FALSE) expressions
|
||||
if group := select.args.get("group"):
|
||||
new_exprs = []
|
||||
for expr in group.expressions:
|
||||
transformed = expr.transform(transform_column)
|
||||
if not self._is_blocked(transformed):
|
||||
new_exprs.append(transformed)
|
||||
if new_exprs:
|
||||
group.set("expressions", new_exprs)
|
||||
else:
|
||||
select.set("group", None)
|
||||
|
||||
# Transform ORDER BY and remove blocked (FALSE) expressions
|
||||
if order := select.args.get("order"):
|
||||
new_exprs = []
|
||||
for ordered in order.expressions:
|
||||
transformed = ordered.transform(transform_column)
|
||||
# Check the inner expression (Ordered wraps the actual expr)
|
||||
inner = transformed.this if isinstance(transformed, exp.Ordered) else transformed
|
||||
if not self._is_blocked(inner):
|
||||
new_exprs.append(transformed)
|
||||
if new_exprs:
|
||||
order.set("expressions", new_exprs)
|
||||
else:
|
||||
select.set("order", None)
|
||||
|
||||
# Transform Window functions within SELECT expressions
|
||||
# Window functions have their own PARTITION BY and ORDER BY clauses
|
||||
for expr in select.args.get("expressions", []):
|
||||
for window in expr.find_all(exp.Window):
|
||||
# Transform PARTITION BY
|
||||
if partition_by := window.args.get("partition_by"):
|
||||
new_partition = []
|
||||
for part_expr in partition_by:
|
||||
transformed = part_expr.transform(transform_column)
|
||||
if not self._is_blocked(transformed):
|
||||
new_partition.append(transformed)
|
||||
window.set("partition_by", new_partition if new_partition else None)
|
||||
|
||||
# Transform ORDER BY within window
|
||||
if window_order := window.args.get("order"):
|
||||
new_order_exprs = []
|
||||
for ordered in window_order.expressions:
|
||||
transformed = ordered.transform(transform_column)
|
||||
inner = (
|
||||
transformed.this
|
||||
if isinstance(transformed, exp.Ordered)
|
||||
else transformed
|
||||
)
|
||||
if not self._is_blocked(inner):
|
||||
new_order_exprs.append(transformed)
|
||||
if new_order_exprs:
|
||||
window_order.set("expressions", new_order_exprs)
|
||||
else:
|
||||
window.set("order", None)
|
||||
|
||||
def transform_select(self, select: exp.Select) -> exp.Select:
|
||||
"""
|
||||
Transform a SELECT statement by applying CLS rules.
|
||||
|
||||
This is the main entry point for CLS transformation. It:
|
||||
1. Extracts table scope for column resolution
|
||||
2. Transforms SELECT list expressions (with HIDE removal and aliases)
|
||||
3. Transforms ALL other column references in the query
|
||||
"""
|
||||
scope_tables = self._extract_scope_tables(select)
|
||||
|
||||
# Phase 1: Transform SELECT list expressions
|
||||
# This handles HASH/NULLIFY/MASK with aliases, and removes HIDE columns
|
||||
expressions = select.args.get("expressions", [])
|
||||
new_expressions: list[exp.Expression] = []
|
||||
for expr in expressions:
|
||||
if isinstance(expr, exp.Star):
|
||||
new_expressions.extend(self._transform_star(expr, scope_tables))
|
||||
else:
|
||||
transformed = self._transform_expression(expr, scope_tables)
|
||||
if transformed is not None:
|
||||
new_expressions.append(transformed)
|
||||
select.set("expressions", new_expressions)
|
||||
|
||||
# Phase 2: Transform ALL other column references
|
||||
# This is the security-critical phase that catches every column reference
|
||||
self._transform_all_non_select_columns(select, scope_tables)
|
||||
|
||||
return select
|
||||
|
||||
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Transform callback for sqlglot's transform method.
|
||||
"""
|
||||
if isinstance(node, exp.Select):
|
||||
return self.transform_select(node)
|
||||
return node
|
||||
|
||||
|
||||
def apply_cls(
|
||||
sql: str,
|
||||
rules: CLSRules,
|
||||
engine: str = "base",
|
||||
schema: dict[str, dict[str, str]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Apply Column-Level Security rules to a SQL query.
|
||||
|
||||
This function transforms a SQL query by applying CLS actions to sensitive columns
|
||||
in both SELECT expressions and predicates (WHERE, ON, HAVING):
|
||||
|
||||
- HASH: Pseudonymize using database-specific hash function (both SELECT and predicates)
|
||||
- NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering
|
||||
- HIDE: Remove from SELECT results, FALSE in predicates to block filtering
|
||||
- MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering
|
||||
|
||||
Args:
|
||||
sql: The SQL query to transform
|
||||
rules: CLS rules mapping Table objects to column actions
|
||||
Example: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
|
||||
Tables can include schema/catalog for fully qualified matching.
|
||||
engine: The database engine (used for dialect-specific hash functions)
|
||||
schema: Optional schema for column qualification. Required for JOINs with
|
||||
ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...}
|
||||
|
||||
Returns:
|
||||
The transformed SQL query
|
||||
"""
|
||||
if not rules:
|
||||
return sql
|
||||
|
||||
statement = SQLStatement(sql, engine)
|
||||
statement.apply_cls(rules, schema=schema)
|
||||
|
||||
return statement.format(comments=True)
|
||||
|
||||
|
||||
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
|
||||
# an "internal representation", which is the AST of the SQL statement. For most of the
|
||||
@@ -887,6 +1509,46 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
transformer = transformers[method](catalog, schema, predicates)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
def apply_cls(
|
||||
self,
|
||||
rules: CLSRules,
|
||||
schema: dict[str, dict[str, str]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Column-Level Security rules to the statement inplace.
|
||||
|
||||
CLS rules transform sensitive columns in SELECT statements and predicates:
|
||||
- HASH: Pseudonymize using database-specific hash function (both SELECT and predicates)
|
||||
- NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering
|
||||
- HIDE: Remove from SELECT results, FALSE in predicates to block filtering
|
||||
- MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering
|
||||
|
||||
:param rules: CLS rules mapping Table objects to column actions
|
||||
Example: {Table("my_table"): {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
|
||||
:param schema: Optional schema for column qualification. Required for JOINs
|
||||
with ambiguous column names. Format: {"table": {"column": "TYPE", ...}}
|
||||
"""
|
||||
if not rules:
|
||||
return
|
||||
|
||||
# Always attempt to qualify columns for better CLS resolution.
|
||||
# With schema: full qualification of all columns.
|
||||
# Without schema: qualifies single-table queries, partial for JOINs.
|
||||
from sqlglot.optimizer.qualify import qualify
|
||||
|
||||
# Only expand stars if schema is provided (from DAR with feature flag enabled)
|
||||
# to avoid potential errors in other contexts
|
||||
self._parsed = qualify(
|
||||
self._parsed,
|
||||
schema=schema,
|
||||
dialect=self._dialect,
|
||||
validate_qualify_columns=False,
|
||||
expand_stars=bool(schema),
|
||||
)
|
||||
|
||||
transformer = CLSTransformer(rules, self._dialect)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
|
||||
@@ -67,6 +67,7 @@ from superset.utils.core import (
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.utils.decorators import stats_timing
|
||||
from superset.utils.rls import apply_rls
|
||||
from superset.data_access_rules.utils import apply_data_access_rules
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
@@ -419,6 +420,13 @@ def execute_sql_statements( # noqa: C901
|
||||
for statement in parsed_script.statements:
|
||||
apply_rls(query.database, query.catalog, default_schema, statement)
|
||||
|
||||
if is_feature_enabled("DATA_ACCESS_RULES"):
|
||||
default_schema = query.database.get_default_schema_for_query(query)
|
||||
for statement in parsed_script.statements:
|
||||
apply_data_access_rules(
|
||||
query.database, query.catalog, default_schema, statement
|
||||
)
|
||||
|
||||
if query.select_as_cta:
|
||||
# CTAS is valid when the last statement is a SELECT, while CVAS is valid when
|
||||
# there is only a single statement which must be a SELECT.
|
||||
|
||||
@@ -17,10 +17,73 @@
|
||||
from typing import Any
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy.sql.elements import BooleanClauseList
|
||||
|
||||
|
||||
def get_dar_dataset_filters(base_model: type[Model]) -> list[Any]:
|
||||
"""
|
||||
Get SQLAlchemy filters for DAR-allowed datasets.
|
||||
|
||||
Handles hierarchical permissions:
|
||||
- Database-level: allows all tables in the database
|
||||
- Catalog-level: allows all tables in the catalog
|
||||
- Schema-level: allows all tables in the schema
|
||||
- Table-level: allows only the specific table
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import logging
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.connectors.sqla.models import Database
|
||||
|
||||
if not is_feature_enabled("DATA_ACCESS_RULES"):
|
||||
return []
|
||||
|
||||
try:
|
||||
from superset.data_access_rules.utils import get_all_allowed_entries
|
||||
|
||||
allowed_entries = get_all_allowed_entries()
|
||||
if not allowed_entries:
|
||||
return []
|
||||
|
||||
# Build OR filters for each allowed entry at its hierarchy level
|
||||
filters = []
|
||||
for entry in allowed_entries:
|
||||
# Start with database filter (always required)
|
||||
entry_filter = Database.database_name == entry.database
|
||||
|
||||
# Add catalog filter if specified
|
||||
if entry.catalog is not None:
|
||||
entry_filter = and_(
|
||||
entry_filter,
|
||||
base_model.catalog == entry.catalog,
|
||||
)
|
||||
|
||||
# Add schema filter if specified
|
||||
if entry.schema is not None:
|
||||
entry_filter = and_(
|
||||
entry_filter,
|
||||
base_model.schema == entry.schema,
|
||||
)
|
||||
|
||||
# Add table filter if specified
|
||||
if entry.table is not None:
|
||||
entry_filter = and_(
|
||||
entry_filter,
|
||||
base_model.table_name == entry.table,
|
||||
)
|
||||
|
||||
filters.append(entry_filter)
|
||||
|
||||
return filters
|
||||
except Exception as ex:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Error getting DAR dataset filters: %s", ex
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def get_dataset_access_filters(
|
||||
base_model: type[Model],
|
||||
*args: Any,
|
||||
@@ -34,10 +97,14 @@ def get_dataset_access_filters(
|
||||
schema_perms = security_manager.user_view_menu_names("schema_access")
|
||||
catalog_perms = security_manager.user_view_menu_names("catalog_access")
|
||||
|
||||
# Get DAR-based table filters
|
||||
dar_filters = get_dar_dataset_filters(base_model)
|
||||
|
||||
return or_(
|
||||
Database.id.in_(database_ids),
|
||||
base_model.perm.in_(perms),
|
||||
base_model.catalog_perm.in_(catalog_perms),
|
||||
base_model.schema_perm.in_(schema_perms),
|
||||
*dar_filters,
|
||||
*args,
|
||||
)
|
||||
|
||||
@@ -38,6 +38,17 @@ class RowLevelSecurityView(BaseSupersetView):
|
||||
return super().render_app_template()
|
||||
|
||||
|
||||
class DataAccessRulesView(BaseSupersetView):
|
||||
route_base = "/dataaccessrules"
|
||||
class_permission_name = "DataAccessRule"
|
||||
|
||||
@expose("/list/")
|
||||
@has_access
|
||||
@permission_name("read")
|
||||
def list(self) -> FlaskResponse:
|
||||
return super().render_app_template()
|
||||
|
||||
|
||||
class TableModelView(BaseSupersetView):
|
||||
class_permission_name = "Dataset"
|
||||
method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
|
||||
|
||||
16
tests/unit_tests/data_access_rules/__init__.py
Normal file
16
tests/unit_tests/data_access_rules/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
261
tests/unit_tests/data_access_rules/schemas_test.py
Normal file
261
tests/unit_tests/data_access_rules/schemas_test.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Unit tests for Data Access Rules schemas.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.data_access_rules.schemas import (
|
||||
DataAccessRulePostSchema,
|
||||
DataAccessRulePutSchema,
|
||||
)
|
||||
|
||||
|
||||
def test_post_schema_valid_rule():
|
||||
"""Test that valid rule JSON is accepted."""
|
||||
schema = DataAccessRulePostSchema()
|
||||
data = {
|
||||
"role_id": 1,
|
||||
"rule": json.dumps(
|
||||
{
|
||||
"allowed": [{"database": "mydb", "schema": "public"}],
|
||||
"denied": [],
|
||||
}
|
||||
),
|
||||
}
|
||||
result = schema.load(data)
|
||||
assert result["role_id"] == 1
|
||||
assert "allowed" in json.loads(result["rule"])
|
||||
|
||||
|
||||
def test_post_schema_complex_rule():
|
||||
"""Test that complex rule with RLS and CLS is accepted."""
|
||||
schema = DataAccessRulePostSchema()
|
||||
data = {
|
||||
"role_id": 1,
|
||||
"rule": json.dumps(
|
||||
{
|
||||
"allowed": [
|
||||
{"database": "mydb", "schema": "public"},
|
||||
{
|
||||
"database": "mydb",
|
||||
"schema": "orders",
|
||||
"table": "items",
|
||||
"rls": {"predicate": "org_id = 123", "group_key": "org"},
|
||||
},
|
||||
{
|
||||
"database": "mydb",
|
||||
"schema": "users",
|
||||
"table": "info",
|
||||
"cls": {"email": "mask", "ssn": "hide", "name": "hash"},
|
||||
},
|
||||
],
|
||||
"denied": [{"database": "mydb", "schema": "internal"}],
|
||||
}
|
||||
),
|
||||
}
|
||||
result = schema.load(data)
|
||||
assert result["role_id"] == 1
|
||||
|
||||
|
||||
def test_post_schema_invalid_json():
|
||||
"""Test that invalid JSON is rejected."""
|
||||
schema = DataAccessRulePostSchema()
|
||||
data = {
|
||||
"role_id": 1,
|
||||
"rule": "not valid json",
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
schema.load(data)
|
||||
assert "Invalid JSON" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_post_schema_rule_not_object():
|
||||
"""Test that non-object rule is rejected."""
|
||||
schema = DataAccessRulePostSchema()
|
||||
data = {
|
||||
"role_id": 1,
|
||||
"rule": json.dumps(["not", "an", "object"]),
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
schema.load(data)
|
||||
assert "must be a JSON object" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_post_schema_allowed_not_list():
|
||||
"""Test that non-list 'allowed' is rejected."""
|
||||
schema = DataAccessRulePostSchema()
|
||||
data = {
|
||||
"role_id": 1,
|
||||
"rule": json.dumps({"allowed": "not a list", "denied": []}),
|
||||
}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
schema.load(data)
|
||||
assert "'allowed' must be a list" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_post_schema_denied_not_list():
    """The 'denied' key must hold a list; anything else is rejected."""
    payload = {"role_id": 1, "rule": json.dumps({"allowed": [], "denied": "not a list"})}
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(payload)
    assert "'denied' must be a list" in str(err.value)
|
||||
|
||||
|
||||
def test_post_schema_entry_not_object():
    """Each entry inside 'allowed'/'denied' must be a JSON object."""
    payload = {"role_id": 1, "rule": json.dumps({"allowed": ["not an object"], "denied": []})}
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(payload)
    assert "must be an object" in str(err.value)
|
||||
|
||||
|
||||
def test_post_schema_entry_missing_database():
    """Every entry must name a database; schema-only entries are invalid."""
    payload = {"role_id": 1, "rule": json.dumps({"allowed": [{"schema": "public"}], "denied": []})}
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(payload)
    assert "'database' field" in str(err.value)
|
||||
|
||||
|
||||
def test_post_schema_invalid_cls_action():
    """Unknown column-level-security actions must be rejected."""
    rule_body = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "cls": {"email": "invalid_action"},
            }
        ],
        "denied": [],
    }
    payload = {"role_id": 1, "rule": json.dumps(rule_body)}
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(payload)
    assert "Invalid CLS action" in str(err.value)
|
||||
|
||||
|
||||
def test_post_schema_missing_role_id():
    """role_id is required; a payload without it must fail validation."""
    payload = {"rule": json.dumps({"allowed": [], "denied": []})}
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(payload)
    assert "role_id" in str(err.value)
|
||||
|
||||
|
||||
def test_post_schema_missing_rule():
    """The rule field is required; a payload without it must fail validation."""
    payload = {"role_id": 1}
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(payload)
    assert "rule" in str(err.value)
|
||||
|
||||
|
||||
def test_put_schema_partial_update():
    """The PUT schema accepts payloads that update only one field at a time."""
    put_schema = DataAccessRulePutSchema()

    # role_id alone is a valid partial update
    loaded = put_schema.load({"role_id": 2})
    assert loaded == {"role_id": 2}

    # rule alone is also a valid partial update
    rule_only = {"rule": json.dumps({"allowed": [{"database": "newdb"}], "denied": []})}
    loaded = put_schema.load(rule_only)
    assert "rule" in loaded
|
||||
|
||||
|
||||
def test_put_schema_validates_rule_if_provided():
    """When a rule is supplied on PUT, it is still fully validated."""
    payload = {"rule": "invalid json"}
    with pytest.raises(ValidationError) as err:
        DataAccessRulePutSchema().load(payload)
    assert "Invalid JSON" in str(err.value)
|
||||
|
||||
|
||||
def test_post_schema_empty_allowed_denied():
    """A rule with empty allowed/denied lists is still a valid rule."""
    payload = {"role_id": 1, "rule": json.dumps({"allowed": [], "denied": []})}
    loaded = DataAccessRulePostSchema().load(payload)
    assert loaded["role_id"] == 1
|
||||
|
||||
|
||||
def test_post_schema_cls_all_valid_actions():
    """Every supported CLS action validates, in both lower and upper case."""
    # Pair each action with its upper-case spelling to prove case-insensitivity.
    cls_map = {
        "col1": "hash",
        "col2": "HASH",  # upper-case spelling accepted too
        "col3": "nullify",
        "col4": "NULLIFY",
        "col5": "mask",
        "col6": "MASK",
        "col7": "hide",
        "col8": "HIDE",
    }
    rule_body = {
        "allowed": [{"database": "mydb", "schema": "public", "cls": cls_map}],
        "denied": [],
    }
    payload = {"role_id": 1, "rule": json.dumps(rule_body)}
    loaded = DataAccessRulePostSchema().load(payload)
    assert loaded["role_id"] == 1
|
||||
768
tests/unit_tests/data_access_rules/utils_test.py
Normal file
768
tests/unit_tests/data_access_rules/utils_test.py
Normal file
@@ -0,0 +1,768 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Unit tests for Data Access Rules utility functions.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.data_access_rules.models import DataAccessRule
|
||||
from superset.data_access_rules.utils import (
|
||||
_is_more_specific,
|
||||
_matches_rule_entry,
|
||||
AccessCheckResult,
|
||||
check_table_access,
|
||||
get_all_group_keys,
|
||||
get_cls_rules_for_table,
|
||||
get_rls_predicates_for_table,
|
||||
)
|
||||
from superset.sql.parse import CLSAction, Table
|
||||
|
||||
|
||||
# Tests for _matches_rule_entry
|
||||
def test_matches_rule_entry_database_only():
    """A database-only entry matches every object inside that database."""
    db_rule = {"database": "mydb"}
    assert _matches_rule_entry(db_rule, "mydb", None, None, None) is True
    assert _matches_rule_entry(db_rule, "mydb", "catalog1", "schema1", "table1") is True
    assert _matches_rule_entry(db_rule, "otherdb", None, None, None) is False
|
||||
|
||||
|
||||
def test_matches_rule_entry_database_and_catalog():
    """An entry pinning database and catalog matches only that pair."""
    cat_rule = {"database": "mydb", "catalog": "cat1"}
    assert _matches_rule_entry(cat_rule, "mydb", "cat1", None, None) is True
    assert _matches_rule_entry(cat_rule, "mydb", "cat1", "schema1", "table1") is True
    # Either component differing breaks the match.
    assert _matches_rule_entry(cat_rule, "mydb", "cat2", None, None) is False
    assert _matches_rule_entry(cat_rule, "otherdb", "cat1", None, None) is False
|
||||
|
||||
|
||||
def test_matches_rule_entry_database_and_schema():
    """An entry pinning database and schema (no catalog) matches that schema only."""
    schema_rule = {"database": "mydb", "schema": "public"}
    assert _matches_rule_entry(schema_rule, "mydb", None, "public", None) is True
    assert _matches_rule_entry(schema_rule, "mydb", None, "public", "table1") is True
    assert _matches_rule_entry(schema_rule, "mydb", None, "other", None) is False
|
||||
|
||||
|
||||
def test_matches_rule_entry_full_table():
    """A fully qualified entry matches exactly one table."""
    table_rule = {"database": "mydb", "schema": "public", "table": "users"}
    assert _matches_rule_entry(table_rule, "mydb", None, "public", "users") is True
    assert _matches_rule_entry(table_rule, "mydb", None, "public", "orders") is False
    assert _matches_rule_entry(table_rule, "mydb", None, "other", "users") is False
|
||||
|
||||
|
||||
def test_matches_rule_entry_with_catalog():
    """A catalog-qualified table entry only matches under that catalog."""
    full_rule = {
        "database": "mydb",
        "catalog": "main",
        "schema": "public",
        "table": "users",
    }
    assert _matches_rule_entry(full_rule, "mydb", "main", "public", "users") is True
    assert _matches_rule_entry(full_rule, "mydb", "other", "public", "users") is False
|
||||
|
||||
|
||||
# Tests for _is_more_specific
|
||||
def test_is_more_specific():
    """Entries with more path components beat entries with fewer; ties are equal."""
    db_only = {"database": "mydb"}
    db_schema = {"database": "mydb", "schema": "public"}
    db_table = {"database": "mydb", "schema": "public", "table": "users"}
    db_catalog = {"database": "mydb", "catalog": "main"}
    db_catalog_schema = {"database": "mydb", "catalog": "main", "schema": "public"}

    # Adding components increases specificity.
    assert _is_more_specific(db_schema, db_only) is True
    assert _is_more_specific(db_table, db_schema) is True
    assert _is_more_specific(db_table, db_only) is True

    # The reverse comparisons are not "more specific".
    assert _is_more_specific(db_only, db_schema) is False
    assert _is_more_specific(db_schema, db_table) is False

    # Equal component counts compare as not-more-specific in both directions.
    assert _is_more_specific(db_schema, db_catalog) is False
    assert _is_more_specific(db_catalog, db_schema) is False
|
||||
|
||||
|
||||
# Tests for check_table_access
|
||||
def test_check_table_access_no_rules():
    """With no rules at all the result is NO_RULE and carries no RLS/CLS."""
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[])
    assert outcome.access == AccessCheckResult.NO_RULE
    assert outcome.rls_predicates == []
    assert outcome.cls_rules == {}
|
||||
|
||||
|
||||
def test_check_table_access_allowed():
    """A matching 'allowed' entry yields ALLOWED with empty RLS/CLS."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [{"database": "mydb", "schema": "public"}],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    assert outcome.rls_predicates == []
    assert outcome.cls_rules == {}
|
||||
|
||||
|
||||
def test_check_table_access_denied():
    """A matching 'denied' entry yields DENIED."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "allowed": [],
        "denied": [{"database": "mydb", "schema": "public"}],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])
    assert outcome.access == AccessCheckResult.DENIED
|
||||
|
||||
|
||||
def test_check_table_access_denied_more_specific():
    """A schema-level deny overrides a database-level allow for that schema only."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "allowed": [{"database": "mydb"}],  # broad grant
        "denied": [{"database": "mydb", "schema": "secret"}],  # narrower deny
    }

    # Outside the denied schema the broad allow still applies.
    public_table = Table(table="users", schema="public", catalog=None)
    assert (
        check_table_access("mydb", public_table, rules=[access_rule]).access
        == AccessCheckResult.ALLOWED
    )

    # Inside the denied schema the more specific deny wins.
    secret_table = Table(table="data", schema="secret", catalog=None)
    assert (
        check_table_access("mydb", secret_table, rules=[access_rule]).access
        == AccessCheckResult.DENIED
    )
|
||||
|
||||
|
||||
def test_check_table_access_allowed_more_specific():
    """A table-level allow carves an exception out of a schema-level deny."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "allowed": [{"database": "mydb", "schema": "public", "table": "users"}],
        "denied": [{"database": "mydb", "schema": "public"}],
    }

    # The explicitly allowed table is reachable despite the schema deny.
    users_table = Table(table="users", schema="public", catalog=None)
    assert (
        check_table_access("mydb", users_table, rules=[access_rule]).access
        == AccessCheckResult.ALLOWED
    )

    # Sibling tables fall back to the schema-level deny.
    orders_table = Table(table="orders", schema="public", catalog=None)
    assert (
        check_table_access("mydb", orders_table, rules=[access_rule]).access
        == AccessCheckResult.DENIED
    )
|
||||
|
||||
|
||||
def test_check_table_access_same_specificity_deny_wins():
    """When allow and deny are equally specific, deny takes precedence."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [{"database": "mydb", "schema": "public"}],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])
    assert outcome.access == AccessCheckResult.DENIED
|
||||
|
||||
|
||||
def test_check_table_access_with_rls():
    """An allow entry carrying RLS surfaces its predicate on the result."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "rls": {"predicate": "org_id = 123", "group_key": "org"},
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    assert len(outcome.rls_predicates) == 1
    first = outcome.rls_predicates[0]
    assert first.predicate == "org_id = 123"
    assert first.group_key == "org"
|
||||
|
||||
|
||||
def test_check_table_access_multiple_rls():
    """RLS predicates from separate rules are all collected on the result."""
    org_rule = MagicMock(spec=DataAccessRule)
    org_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 123", "group_key": "org"},
            }
        ],
    }

    region_rule = MagicMock(spec=DataAccessRule)
    region_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "region = 'US'"},
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[org_rule, region_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    assert len(outcome.rls_predicates) == 2

    collected = [entry.predicate for entry in outcome.rls_predicates]
    assert "org_id = 123" in collected
    assert "region = 'US'" in collected
|
||||
|
||||
|
||||
def test_check_table_access_with_cls():
    """CLS strings on a matching entry are parsed into CLSAction values."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide", "name": "hash"},
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    assert outcome.cls_rules == {
        "email": CLSAction.MASK,
        "ssn": CLSAction.HIDE,
        "name": CLSAction.HASH,
    }
|
||||
|
||||
|
||||
def test_check_table_access_cls_strictest_wins():
    """When two rules set CLS on the same column, the stricter action is kept."""
    lenient_rule = MagicMock(spec=DataAccessRule)
    lenient_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "cls": {"email": "mask"},  # weaker action
            }
        ],
    }

    strict_rule = MagicMock(spec=DataAccessRule)
    strict_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "cls": {"email": "hide"},  # stricter action — expected winner
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[lenient_rule, strict_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    assert outcome.cls_rules["email"] == CLSAction.HIDE
|
||||
|
||||
|
||||
# Tests for get_rls_predicates_for_table
|
||||
def test_get_rls_predicates_for_table_no_predicates():
    """An allow entry without RLS contributes no predicates."""
    db = MagicMock()
    db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [{"database": "mydb", "schema": "public"}],
    }

    assert get_rls_predicates_for_table(target, db, rules=[access_rule]) == []
|
||||
|
||||
|
||||
def test_get_rls_predicates_for_table_with_predicates():
    """A single RLS predicate is returned, parenthesized."""
    db = MagicMock()
    db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 123"},
            }
        ],
    }

    assert get_rls_predicates_for_table(target, db, rules=[access_rule]) == [
        "(org_id = 123)"
    ]
|
||||
|
||||
|
||||
def test_get_rls_predicates_for_table_with_group_key():
    """Predicates sharing a group_key are combined into one OR expression."""
    db = MagicMock()
    db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    first_rule = MagicMock(spec=DataAccessRule)
    first_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 1", "group_key": "org"},
            }
        ],
    }

    second_rule = MagicMock(spec=DataAccessRule)
    second_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 2", "group_key": "org"},
            }
        ],
    }

    combined = get_rls_predicates_for_table(target, db, rules=[first_rule, second_rule])
    # Both grouped predicates collapse to a single ORed expression.
    assert len(combined) == 1
    assert "(org_id = 1)" in combined[0]
    assert "(org_id = 2)" in combined[0]
    assert " OR " in combined[0]
|
||||
|
||||
|
||||
def test_get_rls_predicates_for_table_mixed_group_keys():
    """Grouped predicates OR together; ungrouped ones stay separate."""
    db = MagicMock()
    db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 1", "group_key": "org"},
            },
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 2", "group_key": "org"},
            },
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "region = 'US'"},  # ungrouped
            },
        ],
    }

    combined = get_rls_predicates_for_table(target, db, rules=[access_rule])
    # One ORed "org" group plus one standalone predicate.
    assert len(combined) == 2
    assert any("region = 'US'" in item for item in combined)
    assert any("org_id = 1" in item and "org_id = 2" in item for item in combined)
|
||||
|
||||
|
||||
# Tests for get_cls_rules_for_table
|
||||
def test_get_cls_rules_for_table_no_rules():
    """An allow entry without CLS contributes no column rules."""
    db = MagicMock()
    db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [{"database": "mydb", "schema": "public"}],
    }

    assert get_cls_rules_for_table(target, db, rules=[access_rule]) == {}
|
||||
|
||||
|
||||
def test_get_cls_rules_for_table_with_rules():
    """CLS strings are mapped to CLSAction values per column."""
    db = MagicMock()
    db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide"},
            }
        ],
    }

    result = get_cls_rules_for_table(target, db, rules=[access_rule])
    assert result == {"email": CLSAction.MASK, "ssn": CLSAction.HIDE}
|
||||
|
||||
|
||||
# Tests for get_all_group_keys
|
||||
def test_get_all_group_keys_empty(app_context: None):
    """With no stored rules, no group keys are reported."""
    with patch("superset.data_access_rules.utils.db") as db_mock:
        db_mock.session.query.return_value.all.return_value = []
        assert get_all_group_keys() == set()
|
||||
|
||||
|
||||
def test_get_all_group_keys_with_keys(app_context: None):
    """Group keys are collected across rules, deduplicated, and ungrouped RLS skipped."""
    first_rule = MagicMock(spec=DataAccessRule)
    first_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {"database": "mydb", "rls": {"predicate": "x=1", "group_key": "key1"}},
            {"database": "mydb", "rls": {"predicate": "x=2", "group_key": "key2"}},
        ],
    }

    second_rule = MagicMock(spec=DataAccessRule)
    second_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {"database": "mydb", "rls": {"predicate": "x=3", "group_key": "key1"}},
            {"database": "mydb", "rls": {"predicate": "x=4"}},  # no group_key
        ],
    }

    with patch("superset.data_access_rules.utils.db") as db_mock:
        db_mock.session.query.return_value.all.return_value = [first_rule, second_rule]
        assert get_all_group_keys() == {"key1", "key2"}
|
||||
|
||||
|
||||
def test_get_all_group_keys_filtered_by_database(app_context: None):
    """Passing database_name restricts keys to entries for that database."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {"database": "db1", "rls": {"predicate": "x=1", "group_key": "key1"}},
            {"database": "db2", "rls": {"predicate": "x=2", "group_key": "key2"}},
        ],
    }

    with patch("superset.data_access_rules.utils.db") as db_mock:
        db_mock.session.query.return_value.all.return_value = [access_rule]
        assert get_all_group_keys(database_name="db1") == {"key1"}
|
||||
|
||||
|
||||
def test_get_all_group_keys_filtered_by_table(app_context: None):
    """Passing a table restricts keys to entries matching that table."""
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "db1",
                "schema": "public",
                "table": "users",
                "rls": {"predicate": "x=1", "group_key": "key1"},
            },
            {
                "database": "db1",
                "schema": "public",
                "table": "orders",
                "rls": {"predicate": "x=2", "group_key": "key2"},
            },
        ],
    }

    with patch("superset.data_access_rules.utils.db") as db_mock:
        db_mock.session.query.return_value.all.return_value = [access_rule]
        target = Table(table="users", schema="public", catalog=None)
        assert get_all_group_keys(database_name="db1", table=target) == {"key1"}
|
||||
|
||||
|
||||
# Tests for get_hidden_columns_for_table
|
||||
def test_get_hidden_columns_for_table_no_hidden(app_context: None):
    """CLS actions other than 'hide' produce no hidden columns."""
    from superset.data_access_rules.utils import get_hidden_columns_for_table

    db = MagicMock()
    db.database_name = "mydb"

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "phone": "hash"},  # nothing hidden
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)
    assert get_hidden_columns_for_table(target, db, rules=[access_rule]) == set()
|
||||
|
||||
|
||||
def test_get_hidden_columns_for_table_with_hidden(app_context: None):
    """Only columns whose CLS action is 'hide' are reported as hidden."""
    from superset.data_access_rules.utils import get_hidden_columns_for_table

    db = MagicMock()
    db.database_name = "mydb"

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide", "password": "hide"},
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)
    assert get_hidden_columns_for_table(target, db, rules=[access_rule]) == {
        "ssn",
        "password",
    }
|
||||
|
||||
|
||||
def test_get_hidden_columns_for_table_denied_access(app_context: None):
    """A denied table yields no hidden-column set (CLS is not evaluated)."""
    from superset.data_access_rules.utils import get_hidden_columns_for_table

    db = MagicMock()
    db.database_name = "mydb"

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "allowed": [],
        "denied": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)
    # Access denial short-circuits CLS collection entirely.
    assert get_hidden_columns_for_table(target, db, rules=[access_rule]) == set()
|
||||
|
||||
|
||||
# Tests for filter_columns_by_cls
|
||||
def test_filter_columns_by_cls_no_hidden(app_context: None):
    """With no 'hide' CLS rules, the column list passes through untouched."""
    from superset.data_access_rules.utils import filter_columns_by_cls

    db = MagicMock()
    db.database_name = "mydb"

    all_columns = [
        {"column_name": "id", "type": "INTEGER"},
        {"column_name": "name", "type": "VARCHAR"},
        {"column_name": "email", "type": "VARCHAR"},
    ]

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [{"database": "mydb", "schema": "public", "table": "users"}],
    }

    target = Table(table="users", schema="public", catalog=None)

    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=True,
    ), patch(
        "superset.data_access_rules.utils.get_user_rules",
        return_value=[access_rule],
    ):
        kept = filter_columns_by_cls(all_columns, target, db)
        assert len(kept) == 3
        assert kept == all_columns
|
||||
|
||||
|
||||
def test_filter_columns_by_cls_with_hidden(app_context: None):
    """Columns marked 'hide' by CLS are dropped from the column list."""
    from superset.data_access_rules.utils import filter_columns_by_cls

    db = MagicMock()
    db.database_name = "mydb"

    all_columns = [
        {"column_name": "id", "type": "INTEGER"},
        {"column_name": "name", "type": "VARCHAR"},
        {"column_name": "email", "type": "VARCHAR"},
        {"column_name": "ssn", "type": "VARCHAR"},
    ]

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"ssn": "hide"},
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)

    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=True,
    ), patch(
        "superset.data_access_rules.utils.get_user_rules",
        return_value=[access_rule],
    ):
        kept = filter_columns_by_cls(all_columns, target, db)
        assert len(kept) == 3
        kept_names = [col["column_name"] for col in kept]
        assert "ssn" not in kept_names
        assert "id" in kept_names
        assert "name" in kept_names
        assert "email" in kept_names
|
||||
|
||||
|
||||
def test_filter_columns_by_cls_feature_disabled(app_context: None):
    """With the feature flag off, no filtering happens at all."""
    from superset.data_access_rules.utils import filter_columns_by_cls

    db = MagicMock()
    db.database_name = "mydb"

    all_columns = [
        {"column_name": "id", "type": "INTEGER"},
        {"column_name": "ssn", "type": "VARCHAR"},
    ]

    target = Table(table="users", schema="public", catalog=None)

    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=False,
    ):
        # Hidden columns would exist if rules were consulted; they must not be.
        kept = filter_columns_by_cls(all_columns, target, db)
        assert len(kept) == 2
        assert kept == all_columns
|
||||
|
||||
|
||||
def test_filter_columns_by_cls_custom_key(app_context: None):
    """filter_columns_by_cls honours a non-default column_name_key."""
    from superset.data_access_rules.utils import filter_columns_by_cls

    db = MagicMock()
    db.database_name = "mydb"

    # SQL Lab-style metadata uses "name" instead of "column_name".
    all_columns = [
        {"name": "id", "type": "INTEGER"},
        {"name": "ssn", "type": "VARCHAR"},
    ]

    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "denied": [],
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"ssn": "hide"},
            }
        ],
    }

    target = Table(table="users", schema="public", catalog=None)

    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=True,
    ), patch(
        "superset.data_access_rules.utils.get_user_rules",
        return_value=[access_rule],
    ):
        kept = filter_columns_by_cls(all_columns, target, db, column_name_key="name")
        assert len(kept) == 1
        assert kept[0]["name"] == "id"
|
||||
@@ -1703,3 +1703,122 @@ def test_adhoc_column_with_spaces_in_full_query(database: Database) -> None:
|
||||
# Verify SELECT and FROM clauses are present
|
||||
assert "SELECT" in sql
|
||||
assert "FROM" in sql
|
||||
|
||||
|
||||
def test_get_query_str_extended_with_data_access_rules(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Test that get_query_str_extended calls apply_data_access_rules when enabled.

    This test mocks the get_sqla_query to return a simple SQL query, then verifies
    that apply_data_access_rules is called when the feature flag is enabled.
    """
    # NOTE: the original imported MagicMock here but never used it; the unused
    # import has been removed.
    import sqlalchemy as sa

    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.models.helpers import SqlaQuery

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )

    # Create a mock SqlaQuery result (fields must be in order per NamedTuple)
    mock_sqla_query = SqlaQuery(
        applied_template_filters=[],
        applied_filter_columns=[],
        rejected_filter_columns=[],
        cte=None,
        extra_cache_keys=[],
        labels_expected=["a"],
        prequeries=[],
        sqla_query=sa.select(sa.column("a")).select_from(sa.table("t")),
    )

    # Mock get_sqla_query to return our simple query
    mocker.patch.object(table, "get_sqla_query", return_value=mock_sqla_query)

    # Mock apply_data_access_rules so the call can be asserted on
    mock_apply_dar = mocker.patch(
        "superset.models.helpers.apply_data_access_rules"
    )

    # Mock is_feature_enabled to return True for DATA_ACCESS_RULES
    mocker.patch(
        "superset.models.helpers.is_feature_enabled",
        return_value=True,
    )

    # Mock database.get_default_schema
    mocker.patch.object(
        database, "get_default_schema", return_value="public"
    )

    query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "columns": ["a"],
        "metrics": [],
        "orderby": [],
        "row_limit": 100,
        "filter": [],
    }

    result = table.get_query_str_extended(query_obj)

    # Verify we got a result with compiled SQL
    assert result is not None
    assert result.sql is not None

    # Verify apply_data_access_rules was called
    assert mock_apply_dar.called
|
||||
|
||||
|
||||
def test_get_from_clause_with_data_access_rules(
|
||||
mocker: MockerFixture,
|
||||
database: Database,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_from_clause calls apply_data_access_rules for virtual datasets.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
|
||||
# Create a virtual dataset (has SQL defined)
|
||||
table = SqlaTable(
|
||||
database=database,
|
||||
schema=None,
|
||||
table_name="virtual_table",
|
||||
sql="SELECT * FROM base_table",
|
||||
columns=[TableColumn(column_name="a")],
|
||||
)
|
||||
|
||||
# Mock apply_data_access_rules
|
||||
mock_apply_dar = mocker.patch(
|
||||
"superset.models.helpers.apply_data_access_rules"
|
||||
)
|
||||
|
||||
# Mock is_feature_enabled to return True for DATA_ACCESS_RULES
|
||||
mocker.patch(
|
||||
"superset.models.helpers.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
# Mock database.get_default_schema
|
||||
mocker.patch.object(
|
||||
database, "get_default_schema", return_value="public"
|
||||
)
|
||||
|
||||
try:
|
||||
table.get_from_clause()
|
||||
except Exception:
|
||||
pass # We're just testing that apply_data_access_rules is called
|
||||
|
||||
# Verify apply_data_access_rules was called
|
||||
assert mock_apply_dar.called
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -36,7 +36,9 @@ from superset.sql_lab import (
|
||||
get_sql_results,
|
||||
)
|
||||
from superset.utils.rls import apply_rls, get_predicates_for_table
|
||||
from superset.data_access_rules.utils import apply_data_access_rules
|
||||
from tests.conftest import with_config
|
||||
from tests.unit_tests.conftest import with_feature_flags
|
||||
from tests.unit_tests.models.core_test import oauth2_client_info
|
||||
|
||||
|
||||
@@ -301,3 +303,119 @@ def test_get_predicates_for_table(mocker: MockerFixture) -> None:
|
||||
|
||||
table = Table("t1", "public", "examples")
|
||||
assert get_predicates_for_table(table, database, "examples") == ["c1 = 1"]
|
||||
|
||||
|
||||
def test_apply_data_access_rules(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the ``apply_data_access_rules`` helper function.
|
||||
"""
|
||||
from superset.data_access_rules.utils import (
|
||||
AccessCheckResult,
|
||||
RLSPredicate,
|
||||
TableAccessInfo,
|
||||
)
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.database_name = "test_db"
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
database.get_default_catalog.return_value = "examples"
|
||||
database.db_engine_spec = PostgresEngineSpec
|
||||
|
||||
# Mock get_user_rules to return a rule with RLS predicates
|
||||
get_user_rules = mocker.patch(
|
||||
"superset.data_access_rules.utils.get_user_rules",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
# Mock check_table_access to return allowed access with RLS predicates
|
||||
mocker.patch(
|
||||
"superset.data_access_rules.utils.check_table_access",
|
||||
return_value=TableAccessInfo(
|
||||
access=AccessCheckResult.ALLOWED,
|
||||
rls_predicates=[
|
||||
RLSPredicate(predicate="org_id = 1", group_key=None),
|
||||
],
|
||||
cls_rules={},
|
||||
),
|
||||
)
|
||||
|
||||
# Mock is_feature_enabled
|
||||
mocker.patch(
|
||||
"superset.data_access_rules.utils.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
parsed_statement = SQLStatement("SELECT * FROM t1", "postgresql")
|
||||
|
||||
apply_data_access_rules(database, "examples", "public", parsed_statement)
|
||||
|
||||
# Since we mocked the feature flag check, the function should have processed
|
||||
get_user_rules.assert_called_once()
|
||||
|
||||
|
||||
@with_feature_flags(DATA_ACCESS_RULES=True)
|
||||
@with_config(
|
||||
{
|
||||
"SQLLAB_PAYLOAD_MAX_MB": 50,
|
||||
"DISALLOWED_SQL_FUNCTIONS": {},
|
||||
"SQLLAB_CTAS_NO_LIMIT": False,
|
||||
"SQL_MAX_ROW": 100000,
|
||||
"QUERY_LOGGER": None,
|
||||
"TROUBLESHOOTING_LINK": None,
|
||||
"STATS_LOGGER": MagicMock(),
|
||||
}
|
||||
)
|
||||
def test_execute_sql_statements_with_data_access_rules(
|
||||
mocker: MockerFixture, app
|
||||
) -> None:
|
||||
"""
|
||||
Test that `execute_sql_statements` calls `apply_data_access_rules`
|
||||
when the DATA_ACCESS_RULES feature flag is enabled.
|
||||
"""
|
||||
# Mock apply_data_access_rules to track calls
|
||||
mock_apply_dar = mocker.patch("superset.sql_lab.apply_data_access_rules")
|
||||
|
||||
# Mock the query object and database
|
||||
query = mocker.MagicMock()
|
||||
query.limit = 1
|
||||
query.database = mocker.MagicMock()
|
||||
query.database.cache_timeout = 100
|
||||
query.status = "RUNNING"
|
||||
query.select_as_cta = False
|
||||
query.database.allow_run_async = True
|
||||
query.database.allow_dml = False
|
||||
query.catalog = "examples"
|
||||
query.database.get_default_schema_for_query.return_value = "public"
|
||||
|
||||
# Mock get_query to return our mocked query object
|
||||
mocker.patch("superset.sql_lab.get_query", return_value=query)
|
||||
|
||||
# Mock db.session.refresh
|
||||
mocker.patch("superset.sql_lab.db.session.refresh", return_value=None)
|
||||
|
||||
# Mock the results backend
|
||||
mocker.patch("superset.sql_lab.results_backend", return_value=True)
|
||||
|
||||
# Mock sys.getsizeof to simulate a small payload
|
||||
mocker.patch("sys.getsizeof", return_value=1000)
|
||||
|
||||
# Mock _serialize_payload
|
||||
mocker.patch(
|
||||
"superset.sql_lab._serialize_payload", return_value="serialized_payload"
|
||||
)
|
||||
|
||||
try:
|
||||
execute_sql_statements(
|
||||
query_id=1,
|
||||
rendered_query="SELECT 42 AS answer",
|
||||
return_results=True,
|
||||
store_results=True,
|
||||
start_time=None,
|
||||
expand_data=False,
|
||||
log_params={},
|
||||
)
|
||||
except Exception:
|
||||
pass # We're just testing that apply_data_access_rules is called
|
||||
|
||||
# Verify apply_data_access_rules was called
|
||||
assert mock_apply_dar.called
|
||||
|
||||
Reference in New Issue
Block a user