Compare commits

...

19 Commits

Author SHA1 Message Date
Beto Dealmeida
40787496e2 Fix DB toggle 2025-12-18 10:40:42 -05:00
Beto Dealmeida
b180cc4fca Fix icons 2025-12-18 10:24:27 -05:00
Beto Dealmeida
78e4f55ede Add icons 2025-12-18 09:19:59 -05:00
Beto Dealmeida
853d57337d Fix role name 2025-12-17 22:30:01 -05:00
Beto Dealmeida
3ad694e5a9 Add name and description 2025-12-17 21:52:32 -05:00
Beto Dealmeida
b66729ad08 RLS 2025-12-17 21:36:04 -05:00
Beto Dealmeida
c5d2329297 CLS implementation 2025-12-17 21:03:21 -05:00
Beto Dealmeida
ca7635dfc2 UI progress 2025-12-17 19:34:52 -05:00
Beto Dealmeida
0aaa13ab79 Tree 2025-12-17 18:36:33 -05:00
Beto Dealmeida
902509b1f0 Explore working 2025-12-17 17:29:18 -05:00
Beto Dealmeida
2e7df4614c RLS and CLS working 2025-12-17 17:08:58 -05:00
Beto Dealmeida
3554325104 Testing 2025-12-17 16:19:44 -05:00
Beto Dealmeida
b469b01e0f Initial UI 2025-12-17 12:29:44 -05:00
Beto Dealmeida
4d9378a818 Filter hidden columns 2025-12-17 11:41:09 -05:00
Beto Dealmeida
0141bdd2b0 Integrate into chart/explore query execution 2025-12-17 11:17:41 -05:00
Beto Dealmeida
808ba668ff Integrate apply_data_access_rules() into SQL Lab query execution 2025-12-17 11:02:25 -05:00
Beto Dealmeida
e9fc7c6f6c Initial rules 2025-12-17 10:50:40 -05:00
Beto Dealmeida
5c61c40704 Support filters 2025-12-16 11:31:35 -05:00
Beto Dealmeida
57a210f7d6 feat: column-level security 2025-12-15 17:01:11 -05:00
38 changed files with 8248 additions and 55 deletions

View File

@@ -105,7 +105,7 @@ class CeleryConfig:
CELERY_CONFIG = CeleryConfig
FEATURE_FLAGS = {"ALERT_REPORTS": True}
FEATURE_FLAGS = {"ALERT_REPORTS": True, "DATA_ACCESS_RULES": True}
ALERT_REPORTS_NOTIFICATION_DRY_RUN = True
WEBDRIVER_BASEURL = f"http://superset_app{os.environ.get('SUPERSET_APP_ROOT', '/')}/" # When using docker compose baseurl should be http://superset_nginx{ENV{BASEPATH}}/ # noqa: E501
# The base URL for the email report hyperlinks.

View File

@@ -134,6 +134,8 @@ export function mapRows<T extends object>(
) {
return rows.map(row => {
prepareRow(row);
return { rowId: row.id, ...row.original, ...row.getRowProps() };
// Spread getRowProps first so data properties from row.original take precedence
// This prevents HTML attributes like `role: "row"` from overwriting data properties
return { ...row.getRowProps(), rowId: row.id, ...row.original };
});
}

View File

@@ -34,6 +34,7 @@ export enum FeatureFlag {
CssTemplates = 'CSS_TEMPLATES',
DashboardVirtualization = 'DASHBOARD_VIRTUALIZATION',
DashboardRbac = 'DASHBOARD_RBAC',
DataAccessRules = 'DATA_ACCESS_RULES',
DatapanelClosedByDefault = 'DATAPANEL_CLOSED_BY_DEFAULT',
DateRangeTimeshiftsEnabled = 'DATE_RANGE_TIMESHIFTS_ENABLED',
/** @deprecated */

View File

@@ -0,0 +1,496 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { SupersetClient, t } from '@superset-ui/core';
import { css, styled } from '@apache-superset/core/ui';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { ModalTitleWithIcon } from 'src/components/ModalTitleWithIcon';
import {
Modal,
AsyncSelect,
InfoTooltip,
Input,
Collapse,
} from '@superset-ui/core/components';
import rison from 'rison';
import { useSingleViewResource } from 'src/views/CRUD/hooks';
import { DataAccessRuleObject } from './types';
import PermissionsTree from './PermissionsTree';
import type { PermissionsPayload } from './PermissionsTree/types';
// Wide modal shell: caps width at 1200px while letting the content set the
// minimum, and keeps the footer buttons on a single line.
const StyledModal = styled(Modal)`
  max-width: 1200px;
  min-width: min-content;
  width: 100%;
  .ant-modal-footer {
    white-space: nowrap;
  }
`;
// Column layout for the modal body; also styles field labels and the small
// info-tooltip icon that trails them.
const StyledSectionContainer = styled.div`
  ${({ theme }) => css`
    display: flex;
    flex-direction: column;
    padding: ${theme.sizeUnit * 3}px ${theme.sizeUnit * 4}px
      ${theme.sizeUnit * 2}px;
    label,
    .control-label {
      display: flex;
      font-size: ${theme.fontSizeSM}px;
      color: ${theme.colorTextLabel};
      align-items: center;
    }
    .info-solid-small {
      vertical-align: middle;
      padding-bottom: ${theme.sizeUnit / 2}px;
    }
  `}
`;
// Wrapper for one labelled form field: stacks label over input, stretches
// inputs to fill the row, and colours the required-field asterisk.
const StyledInputContainer = styled.div`
  ${({ theme }) => css`
    display: flex;
    flex-direction: column;
    margin: ${theme.sizeUnit}px;
    margin-bottom: ${theme.sizeUnit * 4}px;
    .input-container {
      display: flex;
      align-items: center;
      > div {
        width: 100%;
      }
    }
    input,
    textarea {
      flex: 1 1 auto;
    }
    .required {
      margin-left: ${theme.sizeUnit / 2}px;
      color: ${theme.colorErrorText};
    }
  `}
`;
// Monospace textarea for the raw JSON rule editor; user can resize vertically.
const StyledTextArea = styled(Input.TextArea)`
  resize: vertical;
  margin-top: ${({ theme }) => theme.sizeUnit}px;
  font-family: monospace;
`;
// Transparent collapse used for the "Advanced: Edit JSON directly" panel;
// overrides antd's default header/content padding.
const StyledCollapse = styled(Collapse)`
  margin-top: ${({ theme }) => theme.sizeUnit * 2}px;
  background: transparent;
  .ant-collapse-header {
    padding: ${({ theme }) => theme.sizeUnit}px 0 !important;
  }
  .ant-collapse-content-box {
    padding: ${({ theme }) => theme.sizeUnit}px 0 !important;
  }
`;
// Props for DataAccessRuleModal.
export interface DataAccessRuleModalProps {
  // Rule to edit, or null to create a new one.
  rule: DataAccessRuleObject | null;
  addSuccessToast: (msg: string) => void;
  addDangerToast: (msg: string) => void;
  // Optional callback after a rule is added.
  // NOTE(review): declared but not invoked in the visible component body —
  // confirm whether callers rely on it.
  onAdd?: (rule?: DataAccessRuleObject) => void;
  // Invoked when the modal should close.
  onHide: () => void;
  // Controls modal visibility.
  show: boolean;
}
// Template for a brand-new rule: empty name/description, no role selected
// yet (role_id 0 is a placeholder), and a pretty-printed empty
// allowed/denied JSON document for the advanced editor.
const DEFAULT_RULE: DataAccessRuleObject = {
  name: '',
  description: '',
  role_id: 0,
  rule: JSON.stringify(
    {
      allowed: [],
      denied: [],
    },
    null,
    2,
  ),
};
// Shape of an option in the role AsyncSelect (role id + display name).
type SelectValue = {
  value: number;
  label: string;
};
/**
 * Modal for creating or editing a single Data Access Rule (DAR).
 *
 * A rule ties a role to a JSON document with `allowed` and `denied`
 * entries. Permissions can be edited visually via <PermissionsTree/> or as
 * raw JSON in the "Advanced" collapse panel; the two representations are
 * synchronized whenever the panel is opened or closed.
 */
function DataAccessRuleModal(props: DataAccessRuleModalProps) {
  const { rule, addDangerToast, addSuccessToast, onHide, show } = props;
  // Rule currently being edited; the `rule` field is kept as a JSON string
  // so it can back the advanced textarea directly.
  const [currentRule, setCurrentRule] = useState<DataAccessRuleObject>({
    ...DEFAULT_RULE,
  });
  const [selectedRole, setSelectedRole] = useState<SelectValue | null>(null);
  const [disableSave, setDisableSave] = useState<boolean>(true);
  const [jsonError, setJsonError] = useState<string | null>(null);
  // Structured mirror of the rule JSON, consumed by PermissionsTree.
  const [permissionsPayload, setPermissionsPayload] =
    useState<PermissionsPayload>({ allowed: [], denied: [] });
  // Active collapse keys; contains 'advanced' when the JSON editor is open.
  const [showAdvanced, setShowAdvanced] = useState<string[]>([]);
  const isEditMode = rule !== null;
  const {
    state: { loading, resource, error: fetchError },
    fetchResource,
    createResource,
    updateResource,
    clearError,
  } = useSingleViewResource<DataAccessRuleObject>(
    'dar',
    t('data access rule'),
    addDangerToast,
  );
  // Shallow-merge a single field into the rule being edited.
  const updateRuleState = (name: string, value: any) => {
    setCurrentRule(currentRuleData => ({
      ...currentRuleData,
      [name]: value,
    }));
  };
  // Validate form: a role is required; when the advanced editor is open,
  // its contents must also parse to a JSON object.
  const validate = useCallback(() => {
    // Check role is selected
    if (!selectedRole?.value) {
      setDisableSave(true);
      return;
    }
    // If advanced mode is open, validate JSON
    if (showAdvanced.includes('advanced')) {
      try {
        const parsed = JSON.parse(currentRule.rule);
        if (typeof parsed !== 'object' || parsed === null) {
          setJsonError(t('Rule must be a JSON object'));
          setDisableSave(true);
          return;
        }
        setJsonError(null);
      } catch {
        setJsonError(t('Invalid JSON'));
        setDisableSave(true);
        return;
      }
    }
    setDisableSave(false);
  }, [currentRule.rule, selectedRole, showAdvanced]);
  // Initialize: reset the form when adding, fetch the rule when editing.
  useEffect(() => {
    if (!isEditMode) {
      setCurrentRule({ ...DEFAULT_RULE });
      setSelectedRole(null);
      setPermissionsPayload({ allowed: [], denied: [] });
      setShowAdvanced([]);
    } else if (rule?.id != null && !loading && !fetchError) {
      // FIX: the previous `rule?.id !== null` check let `undefined` through,
      // which would call fetchResource(undefined); `!= null` rejects both
      // null and undefined.
      fetchResource(rule.id as number);
    }
  }, [rule]);
  // When the fetched resource arrives, hydrate the form, the permissions
  // payload, and the role selector from it.
  useEffect(() => {
    if (resource) {
      const ruleStr =
        typeof resource.rule === 'string'
          ? resource.rule
          : JSON.stringify(resource.rule, null, 2);
      setCurrentRule({
        ...resource,
        id: rule?.id,
        rule: ruleStr,
      });
      // Parse rule into permissions payload
      try {
        const parsed =
          typeof resource.rule === 'string'
            ? JSON.parse(resource.rule)
            : resource.rule;
        setPermissionsPayload({
          allowed: parsed.allowed || [],
          denied: parsed.denied || [],
        });
      } catch {
        setPermissionsPayload({ allowed: [], denied: [] });
      }
      if (resource.role) {
        setSelectedRole({
          value: resource.role.id,
          label: resource.role.name,
        });
      }
    }
  }, [resource]);
  // Re-run validation whenever any of its inputs change.
  useEffect(() => {
    validate();
  }, [validate]);
  const onRuleChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
    updateRuleState('rule', event.target.value);
  };
  const onRoleChange = (value: SelectValue | null) => {
    setSelectedRole(value);
    if (value) {
      updateRuleState('role_id', value.value);
    }
  };
  // Tree edits update both the structured payload and its JSON mirror.
  const onPermissionsChange = useCallback((payload: PermissionsPayload) => {
    setPermissionsPayload(payload);
    // Update JSON representation
    const ruleJson = JSON.stringify(payload, null, 2);
    updateRuleState('rule', ruleJson);
  }, []);
  // Keep JSON editor and tree in sync when toggling the Advanced panel.
  const onAdvancedChange = (keys: string | string[]) => {
    const keyArray = Array.isArray(keys) ? keys : [keys];
    setShowAdvanced(keyArray);
    // When opening advanced, sync JSON from permissions
    if (keyArray.includes('advanced')) {
      const ruleJson = JSON.stringify(permissionsPayload, null, 2);
      updateRuleState('rule', ruleJson);
    }
    // When closing advanced, sync permissions from JSON
    if (!keyArray.includes('advanced') && showAdvanced.includes('advanced')) {
      try {
        const parsed = JSON.parse(currentRule.rule);
        setPermissionsPayload({
          allowed: parsed.allowed || [],
          denied: parsed.denied || [],
        });
      } catch {
        // Keep existing payload if JSON is invalid
      }
    }
  };
  // Reset all local state and close the modal.
  const hide = () => {
    clearError();
    setCurrentRule({ ...DEFAULT_RULE });
    setSelectedRole(null);
    setJsonError(null);
    setPermissionsPayload({ allowed: [], denied: [] });
    setShowAdvanced([]);
    onHide();
  };
  // Create or update the rule, toast on success, then close.
  const onSave = () => {
    const data = {
      name: currentRule.name || null,
      description: currentRule.description || null,
      role_id: selectedRole?.value,
      rule: currentRule.rule,
    };
    if (isEditMode && currentRule.id) {
      updateResource(currentRule.id, data).then(response => {
        if (!response) {
          return;
        }
        addSuccessToast(t('Rule updated'));
        hide();
      });
    } else {
      createResource(data).then(response => {
        if (!response) return;
        addSuccessToast(t('Rule added'));
        hide();
      });
    }
  };
  // Paginated option loader for the role AsyncSelect.
  const loadRoleOptions = useMemo(
    () =>
      (input = '', page: number, pageSize: number) => {
        const query = rison.encode({
          filter: input,
          page,
          page_size: pageSize,
        });
        return SupersetClient.get({
          endpoint: `/api/v1/dar/related/role?q=${query}`,
        }).then(response => {
          const list = response.json.result.map(
            (item: { value: number; text: string }) => ({
              label: item.text,
              value: item.value,
            }),
          );
          return { data: list, totalCount: response.json.count };
        });
      },
    [],
  );
  return (
    <StyledModal
      className="no-content-padding"
      responsive
      show={show}
      onHide={hide}
      primaryButtonName={isEditMode ? t('Save') : t('Add')}
      disablePrimaryButton={disableSave}
      onHandledPrimaryAction={onSave}
      width="50%"
      maxWidth="800px"
      title={
        <ModalTitleWithIcon
          isEditMode={isEditMode}
          title={isEditMode ? t('Edit Data Access Rule') : t('Add Data Access Rule')}
          data-test="dar-modal-title"
        />
      }
    >
      <StyledSectionContainer>
        <div className="main-section">
          <StyledInputContainer>
            <div className="control-label">
              {t('Role')} <span className="required">*</span>
              <InfoTooltip
                tooltip={t(
                  'Select the role this rule applies to. Each role can have multiple data access rules.',
                )}
              />
            </div>
            <div className="input-container">
              <AsyncSelect
                ariaLabel={t('Role')}
                onChange={onRoleChange}
                value={selectedRole}
                options={loadRoleOptions}
                data-test="role-select"
              />
            </div>
          </StyledInputContainer>
          <StyledInputContainer>
            <div className="control-label">
              {t('Name')}
              <InfoTooltip
                tooltip={t('Optional name to help identify this rule.')}
              />
            </div>
            <div className="input-container">
              <Input
                name="name"
                value={currentRule.name || ''}
                onChange={e => updateRuleState('name', e.target.value)}
                placeholder={t('e.g., Sales team access')}
                data-test="rule-name"
              />
            </div>
          </StyledInputContainer>
          <StyledInputContainer>
            <div className="control-label">
              {t('Description')}
              <InfoTooltip
                tooltip={t('Optional description of what this rule grants or restricts.')}
              />
            </div>
            <div className="input-container">
              <Input.TextArea
                name="description"
                value={currentRule.description || ''}
                onChange={e => updateRuleState('description', e.target.value)}
                placeholder={t('Describe the purpose of this rule...')}
                rows={2}
                data-test="rule-description"
              />
            </div>
          </StyledInputContainer>
          <StyledInputContainer>
            <div className="control-label">
              {t('Table Permissions')}
              <InfoTooltip
                tooltip={t(
                  'Select databases, schemas, and tables to allow or deny access. Click the icons to toggle between Allow (green), Deny (red), and Inherit (gray).',
                )}
              />
            </div>
            <PermissionsTree
              value={permissionsPayload}
              onChange={onPermissionsChange}
            />
          </StyledInputContainer>
          <StyledCollapse
            activeKey={showAdvanced}
            onChange={onAdvancedChange}
            ghost
            items={[
              {
                key: 'advanced',
                label: t('Advanced: Edit JSON directly'),
                children: (
                  <>
                    <StyledInputContainer>
                      <div className="control-label">
                        {t('Rule (JSON)')}
                        <InfoTooltip
                          tooltip={t(
                            `Define the access rule as a JSON document with "allowed" and "denied" arrays. Each entry specifies database, catalog, schema, table, and optional RLS/CLS configurations.`,
                          )}
                        />
                      </div>
                      <div className="input-container">
                        <StyledTextArea
                          rows={15}
                          name="rule"
                          value={currentRule.rule}
                          onChange={onRuleChange}
                          status={jsonError ? 'error' : undefined}
                          data-test="rule-json"
                        />
                      </div>
                      {jsonError && (
                        <div style={{ color: 'red', marginTop: '4px' }}>
                          {jsonError}
                        </div>
                      )}
                    </StyledInputContainer>
                  </>
                ),
              },
            ]}
          />
        </div>
      </StyledSectionContainer>
    </StyledModal>
  );
}
export default DataAccessRuleModal;

View File

@@ -0,0 +1,144 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { SupersetClient } from '@superset-ui/core';
import rison from 'rison';
// Minimal database record returned by the database list endpoint.
export interface DatabaseInfo {
  id: number;
  database_name: string;
}
// Generic paginated result: one page of items plus a total count.
export interface FetchResult<T> {
  result: T[];
  count: number;
}
/**
 * List the databases (id + name) visible to the current user.
 * Fetches up to 1000 rows in a single request.
 */
export async function fetchDatabases(): Promise<DatabaseInfo[]> {
  const q = rison.encode({
    columns: ['id', 'database_name'],
    page_size: 1000,
  });
  const { json } = await SupersetClient.get({
    endpoint: `/api/v1/database/?q=${q}`,
  });
  return json.result;
}
/**
 * Fetch catalog names for a database, with optional substring filter and
 * pagination. When the API response omits `count`, the page length is used
 * as the total.
 */
export async function fetchCatalogs(
  databaseId: number,
  filter?: string,
  page?: number,
  pageSize?: number,
): Promise<FetchResult<string>> {
  // Only include parameters the caller actually supplied.
  const params: Record<string, unknown> = {
    ...(filter ? { filter } : {}),
    ...(page !== undefined ? { page } : {}),
    ...(pageSize !== undefined ? { page_size: pageSize } : {}),
  };
  const { json } = await SupersetClient.get({
    endpoint: `/api/v1/database/${databaseId}/catalogs/?q=${rison.encode(params)}`,
  });
  return {
    result: json.result,
    count: json.count ?? json.result.length,
  };
}
/**
 * Fetch schema names for a database, optionally scoped to a catalog, with
 * optional substring filter and pagination. When the API response omits
 * `count`, the page length is used as the total.
 */
export async function fetchSchemas(
  databaseId: number,
  catalog?: string,
  filter?: string,
  page?: number,
  pageSize?: number,
): Promise<FetchResult<string>> {
  // Only include parameters the caller actually supplied.
  const params: Record<string, unknown> = {
    ...(catalog ? { catalog } : {}),
    ...(filter ? { filter } : {}),
    ...(page !== undefined ? { page } : {}),
    ...(pageSize !== undefined ? { page_size: pageSize } : {}),
  };
  const { json } = await SupersetClient.get({
    endpoint: `/api/v1/database/${databaseId}/schemas/?q=${rison.encode(params)}`,
  });
  return {
    result: json.result,
    count: json.count ?? json.result.length,
  };
}
// A table-like object as returned by the /tables/ endpoint; `value` is the
// table name.
export interface TableInfo {
  value: string;
  type: 'table' | 'view' | 'materialized_view';
  extra?: Record<string, unknown>;
}
/**
 * Fetch tables (and views) for a schema, with optional catalog, substring
 * filter, and pagination.
 *
 * Returns one page of tables plus the total count. For consistency with
 * fetchCatalogs/fetchSchemas, falls back to the page length when the API
 * response omits `count` (previously this returned undefined).
 */
export async function fetchTables(
  databaseId: number,
  schemaName: string,
  catalogName?: string,
  filter?: string,
  page?: number,
  pageSize?: number,
): Promise<FetchResult<TableInfo>> {
  const params: Record<string, unknown> = {
    schema_name: schemaName,
  };
  if (catalogName) params.catalog_name = catalogName;
  if (filter) params.filter = filter;
  if (page !== undefined) params.page = page;
  if (pageSize !== undefined) params.page_size = pageSize;
  const query = rison.encode(params);
  const response = await SupersetClient.get({
    endpoint: `/api/v1/database/${databaseId}/tables/?q=${query}`,
  });
  return {
    result: response.json.result,
    // Consistency fix: mirror the `?? result.length` fallback used by the
    // sibling fetchers instead of propagating undefined.
    count: response.json.count ?? response.json.result.length,
  };
}
// Column metadata as returned by the table_metadata endpoint.
export interface ColumnInfo {
  name: string;
  type: string;
  nullable?: boolean;
  default?: string;
}
/**
 * Fetch column metadata for one table via the table_metadata endpoint.
 * Schema and catalog are included in the query string only when provided.
 * Returns an empty list when the response carries no `columns` field.
 */
export async function fetchColumns(
  databaseId: number,
  tableName: string,
  schemaName?: string,
  catalogName?: string,
): Promise<ColumnInfo[]> {
  const search = new URLSearchParams({ name: tableName });
  if (schemaName) search.set('schema', schemaName);
  if (catalogName) search.set('catalog', catalogName);
  const { json } = await SupersetClient.get({
    endpoint: `/api/v1/database/${databaseId}/table_metadata/?${search.toString()}`,
  });
  return json.columns || [];
}

View File

@@ -0,0 +1,832 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { t } from '@superset-ui/core';
import { css, styled } from '@apache-superset/core/ui';
import { Tree, Tooltip, Spin, Button, Select, Input } from 'antd';
import {
CheckCircleOutlined,
CloseCircleOutlined,
MinusCircleOutlined,
DatabaseOutlined,
FolderOutlined,
TableOutlined,
ColumnHeightOutlined,
NumberOutlined,
EyeInvisibleOutlined,
StopOutlined,
StarOutlined,
} from '@ant-design/icons';
import type { TreeProps } from 'antd';
import type {
PermissionNode,
TreeState,
PermissionsPayload,
CLSAction,
RLSRule,
} from './types';
import {
fetchDatabases,
fetchCatalogs,
fetchSchemas,
fetchTables,
fetchColumns,
type DatabaseInfo,
} from './api';
import {
makeKey,
makeClsKey,
getEffectiveState,
getExplicitState,
cyclePermissionState,
updateTreeData,
generatePermissionsPayload,
loadPermissionsFromPayload,
countDescendantPermissions,
} from './utils';
const DEFAULT_PAGE_SIZE = 50;
// Layout and state-colour styles for the permissions tree: a scrollable
// bordered container, allow/deny/inherit icon colours, per-column CLS
// selects, RLS input rows, and the allowed/denied descendant counters.
const StyledContainer = styled.div`
  ${({ theme }) => css`
    .search-input {
      margin-bottom: ${theme.sizeUnit * 2}px;
    }
    .tree-container {
      border: 1px solid ${theme.colorBorder};
      border-radius: ${theme.borderRadius}px;
      padding: ${theme.sizeUnit * 2}px;
      max-height: 500px;
      overflow: auto;
    }
    .ant-tree-title {
      display: flex;
      align-items: center;
      gap: ${theme.sizeUnit}px;
    }
    .permission-icon {
      cursor: pointer;
      font-size: 16px;
    }
    .permission-icon-allow {
      color: ${theme.colorSuccess};
    }
    .permission-icon-deny {
      color: ${theme.colorError};
    }
    .permission-icon-inherit {
      color: ${theme.colorTextSecondary};
    }
    .node-title {
      display: flex;
      align-items: center;
      gap: ${theme.sizeUnit}px;
    }
    .load-more-btn {
      margin-left: ${theme.sizeUnit * 3}px;
      font-size: 12px;
    }
    .node-count {
      font-size: 11px;
      color: ${theme.colorTextSecondary};
      margin-left: ${theme.sizeUnit}px;
    }
    .cls-select {
      width: 90px;
      font-size: 12px;
    }
    .column-node {
      display: flex;
      align-items: center;
      justify-content: space-between;
      width: 100%;
      max-width: 300px;
    }
    .column-name {
      display: flex;
      align-items: center;
      gap: ${theme.sizeUnit}px;
    }
    .column-type {
      font-size: 11px;
      color: ${theme.colorTextSecondary};
    }
    .rls-container {
      display: flex;
      flex-direction: column;
      gap: ${theme.sizeUnit}px;
      padding: ${theme.sizeUnit}px 0;
      margin-bottom: ${theme.sizeUnit}px;
      border-bottom: 1px dashed ${theme.colorBorder};
    }
    .rls-row {
      display: flex;
      align-items: center;
      gap: ${theme.sizeUnit}px;
    }
    .rls-label {
      font-size: 11px;
      color: ${theme.colorTextSecondary};
      min-width: 70px;
    }
    .rls-input {
      flex: 1;
      max-width: 250px;
      font-size: 12px;
    }
    .count-allowed {
      color: ${theme.colorSuccess};
      font-weight: ${theme.fontWeightSemiBold};
    }
    .count-denied {
      color: ${theme.colorError};
      font-weight: ${theme.fontWeightSemiBold};
    }
  `}
`;
// Props for PermissionsTree.
export interface PermissionsTreeProps {
  // Current permissions payload; used to seed the tree's internal state.
  value?: PermissionsPayload;
  // Called with a freshly generated payload whenever the user edits.
  onChange?: (payload: PermissionsPayload) => void;
  // Page size used when lazily loading catalogs/schemas/tables.
  pageSize?: number;
}
function PermissionsTree({
value,
onChange,
pageSize = DEFAULT_PAGE_SIZE,
}: PermissionsTreeProps) {
const [treeState, setTreeState] = useState<TreeState>({
expandedKeys: [],
loadedKeys: [],
treeData: [],
permissionStates: {},
clsRules: {},
rlsRules: {},
});
const [loading, setLoading] = useState(false);
const [databases, setDatabases] = useState<DatabaseInfo[]>([]);
// Map of database ID -> name and name -> ID
const databaseMaps = useMemo(() => {
const idToName = new Map<number, string>();
const nameToId = new Map<string, number>();
databases.forEach(db => {
idToName.set(db.id, db.database_name);
nameToId.set(db.database_name, db.id);
});
return { idToName, nameToId };
}, [databases]);
// Track if we're in the middle of an internal update to avoid circular updates
const isInternalUpdateRef = useRef(false);
// Load initial databases
useEffect(() => {
loadDatabases();
}, []);
// Load initial value (only on mount or when value changes externally)
useEffect(() => {
if (isInternalUpdateRef.current) {
isInternalUpdateRef.current = false;
return;
}
if (value && databaseMaps.nameToId.size > 0) {
const { states, clsRules, rlsRules } = loadPermissionsFromPayload(
value,
databaseMaps.nameToId,
);
setTreeState(prev => ({
...prev,
permissionStates: states,
clsRules,
rlsRules,
}));
}
}, [value, databaseMaps.nameToId]);
const loadDatabases = async () => {
setLoading(true);
try {
const dbs = await fetchDatabases();
setDatabases(dbs);
const treeData: PermissionNode[] = dbs.map(db => ({
key: makeKey({ databaseId: db.id }),
title: db.database_name,
nodeType: 'database',
databaseId: db.id,
databaseName: db.database_name,
children: [],
isLeaf: false,
}));
setTreeState(prev => ({
...prev,
treeData,
}));
} finally {
setLoading(false);
}
};
const loadChildren = useCallback(
async (node: PermissionNode): Promise<PermissionNode[]> => {
const { databaseId, catalogName, schemaName, nodeType } = node;
if (!databaseId) return [];
// Database level -> load catalogs or schemas
if (nodeType === 'database') {
// First try to load catalogs
try {
const catalogs = await fetchCatalogs(databaseId, undefined, 0, 1);
if (catalogs.result.length > 0) {
// Database has catalogs
const catalogsResult = await fetchCatalogs(
databaseId,
undefined,
0,
pageSize,
);
return catalogsResult.result.map(cat => ({
key: makeKey({ databaseId, catalogName: cat }),
title: cat,
nodeType: 'catalog' as const,
databaseId,
databaseName: node.databaseName,
catalogName: cat,
children: [],
isLeaf: false,
hasMore: catalogsResult.count > pageSize,
totalCount: catalogsResult.count,
}));
}
} catch {
// Database doesn't support catalogs, load schemas directly
}
// Load schemas directly
const schemasResult = await fetchSchemas(
databaseId,
undefined,
undefined,
0,
pageSize,
);
return schemasResult.result.map(schema => ({
key: makeKey({ databaseId, schemaName: schema }),
title: schema,
nodeType: 'schema' as const,
databaseId,
databaseName: node.databaseName,
schemaName: schema,
children: [],
isLeaf: false,
hasMore: schemasResult.count > pageSize,
totalCount: schemasResult.count,
}));
}
// Catalog level -> load schemas
if (nodeType === 'catalog' && catalogName) {
const schemasResult = await fetchSchemas(
databaseId,
catalogName,
undefined,
0,
pageSize,
);
return schemasResult.result.map(schema => ({
key: makeKey({ databaseId, catalogName, schemaName: schema }),
title: schema,
nodeType: 'schema' as const,
databaseId,
databaseName: node.databaseName,
catalogName,
schemaName: schema,
children: [],
isLeaf: false,
hasMore: schemasResult.count > pageSize,
totalCount: schemasResult.count,
}));
}
// Schema level -> load tables
if (nodeType === 'schema' && schemaName) {
const tablesResult = await fetchTables(
databaseId,
schemaName,
catalogName,
undefined,
0,
pageSize,
);
return tablesResult.result.map(table => ({
key: makeKey({
databaseId,
catalogName,
schemaName,
tableName: table.value,
}),
title: table.value,
nodeType: 'table' as const,
databaseId,
databaseName: node.databaseName,
catalogName,
schemaName,
tableName: table.value,
isLeaf: false, // Tables are expandable to show columns
children: [],
}));
}
// Table level -> load columns
if (nodeType === 'table' && node.tableName) {
const columns = await fetchColumns(
databaseId,
node.tableName,
schemaName,
catalogName,
);
const tableKey = node.key as string;
return columns.map(col => ({
key: `${tableKey}|col:${col.name}`,
title: col.name,
nodeType: 'column' as const,
databaseId,
databaseName: node.databaseName,
catalogName,
schemaName,
tableName: node.tableName,
columnName: col.name,
isLeaf: true,
}));
}
return [];
},
[pageSize],
);
const onLoadData = async (node: PermissionNode): Promise<void> => {
if (node.children && node.children.length > 0) {
return;
}
const children = await loadChildren(node);
const updatedTreeData = updateTreeData(
treeState.treeData,
node.key as string,
children,
);
setTreeState(prev => ({
...prev,
treeData: updatedTreeData,
loadedKeys: [...prev.loadedKeys, node.key as string],
}));
};
const onExpand: TreeProps['onExpand'] = expandedKeysValue => {
setTreeState(prev => ({
...prev,
expandedKeys: expandedKeysValue as string[],
}));
};
const handlePermissionClick = (
nodeKey: string,
event: React.MouseEvent,
) => {
event.stopPropagation();
const newStates = cyclePermissionState(nodeKey, treeState.permissionStates);
// Mark as internal update to prevent circular updates
isInternalUpdateRef.current = true;
setTreeState(prev => ({
...prev,
permissionStates: newStates,
}));
// Notify parent of changes
if (onChange && databaseMaps.idToName.size > 0) {
const payload = generatePermissionsPayload(
newStates,
databaseMaps.idToName,
treeState.clsRules,
treeState.rlsRules,
);
onChange(payload);
}
};
const handleRlsChange = (
tableKey: string,
field: 'predicate' | 'groupKey',
value: string,
) => {
// Mark as internal update to prevent circular updates
isInternalUpdateRef.current = true;
setTreeState(prev => {
const currentRls = prev.rlsRules[tableKey] || { predicate: '' };
const newRls: RLSRule = { ...currentRls, [field]: value };
// Remove the rule if both fields are empty
const newRlsRules = { ...prev.rlsRules };
if (!newRls.predicate && !newRls.groupKey) {
delete newRlsRules[tableKey];
} else {
newRlsRules[tableKey] = newRls;
}
// Notify parent of changes
if (onChange && databaseMaps.idToName.size > 0) {
const payload = generatePermissionsPayload(
prev.permissionStates,
databaseMaps.idToName,
prev.clsRules,
newRlsRules,
);
onChange(payload);
}
return {
...prev,
rlsRules: newRlsRules,
};
});
};
const handleClsChange = (
tableKey: string,
columnName: string,
action: CLSAction | undefined,
) => {
const clsKey = makeClsKey(tableKey, columnName);
// Mark as internal update to prevent circular updates
isInternalUpdateRef.current = true;
setTreeState(prev => {
const newClsRules = { ...prev.clsRules };
if (action) {
newClsRules[clsKey] = action;
} else {
delete newClsRules[clsKey];
}
// Notify parent of changes
if (onChange && databaseMaps.idToName.size > 0) {
const payload = generatePermissionsPayload(
prev.permissionStates,
databaseMaps.idToName,
newClsRules,
prev.rlsRules,
);
onChange(payload);
}
return {
...prev,
clsRules: newClsRules,
};
});
};
const getStateIcon = (nodeKey: string) => {
const explicitState = getExplicitState(nodeKey, treeState.permissionStates);
const effectiveState = getEffectiveState(
nodeKey,
treeState.permissionStates,
);
const getTooltipText = () => {
if (explicitState === 'allow') return t('Allowed (click to deny)');
if (explicitState === 'deny') return t('Denied (click to allow)');
if (effectiveState === 'allow')
return t('Inherits Allow (click to deny)');
return t('Inherits Deny (click to allow)');
};
if (explicitState === 'allow') {
return (
<Tooltip title={getTooltipText()}>
<CheckCircleOutlined
className="permission-icon permission-icon-allow"
onClick={e => handlePermissionClick(nodeKey, e)}
/>
</Tooltip>
);
}
if (explicitState === 'deny') {
return (
<Tooltip title={getTooltipText()}>
<CloseCircleOutlined
className="permission-icon permission-icon-deny"
onClick={e => handlePermissionClick(nodeKey, e)}
/>
</Tooltip>
);
}
return (
<Tooltip title={getTooltipText()}>
<MinusCircleOutlined
className="permission-icon permission-icon-inherit"
onClick={e => handlePermissionClick(nodeKey, e)}
/>
</Tooltip>
);
};
const getNodeIcon = (nodeType: string) => {
switch (nodeType) {
case 'database':
return <DatabaseOutlined />;
case 'catalog':
case 'schema':
return <FolderOutlined />;
case 'table':
return <TableOutlined />;
case 'column':
return <ColumnHeightOutlined />;
default:
return null;
}
};
const clsOptions = [
{ value: '', label: ' ' },
{
value: 'hash',
label: (
<span>
<NumberOutlined /> {t('Hash')}
</span>
),
},
{
value: 'mask',
label: (
<span>
<StarOutlined /> {t('Mask')}
</span>
),
},
{
value: 'nullify',
label: (
<span>
<StopOutlined /> {t('Null')}
</span>
),
},
{
value: 'hide',
label: (
<span>
<EyeInvisibleOutlined /> {t('Hide')}
</span>
),
},
];
// Count CLS rules for a given table
const countClsRules = (tableKey: string): number => {
const prefix = `${tableKey}::`;
return Object.keys(treeState.clsRules).filter(key => key.startsWith(prefix)).length;
};
// Check if a table has RLS rules
const hasRlsRules = (tableKey: string): boolean => {
const rls = treeState.rlsRules[tableKey];
return !!(rls && (rls.predicate || rls.groupKey));
};
// Custom title renderer for every tree node:
//  - column nodes: column name + a CLS action dropdown
//  - expanded table nodes: inline RLS predicate / group-key inputs
//  - everything else: state/type icons plus count and RLS/CLS badges
const titleRender = (node: PermissionNode) => {
  const nodeKey = node.key as string;
  const isExpanded = treeState.expandedKeys.includes(nodeKey);
  const title = node.title as string;
  // Column nodes get special rendering with CLS dropdown
  if (node.nodeType === 'column' && node.tableName && node.columnName) {
    // Get the table key by removing the column part ("|col:<name>" suffix)
    const tableKey = nodeKey.substring(0, nodeKey.lastIndexOf('|col:'));
    const clsKey = makeClsKey(tableKey, node.columnName);
    // '' (falsy) renders the Select as "no action chosen".
    const currentAction = treeState.clsRules[clsKey] || '';
    return (
      <div className="column-node">
        <span className="column-name">
          {getNodeIcon(node.nodeType)}
          <span>{title}</span>
        </span>
        <Select
          className="cls-select"
          size="small"
          value={currentAction}
          onChange={(val: string) =>
            handleClsChange(
              tableKey,
              node.columnName as string,
              // Empty selection clears the rule (undefined action).
              (val || undefined) as CLSAction | undefined,
            )
          }
          // Keep clicks on the widget from toggling the tree node.
          onClick={e => e.stopPropagation()}
          options={clsOptions}
        />
      </div>
    );
  }
  // Table nodes get RLS inputs when expanded
  if (node.nodeType === 'table' && isExpanded) {
    const currentRls = treeState.rlsRules[nodeKey] || {
      predicate: '',
      groupKey: '',
    };
    // Check if table has permissions, CLS, or RLS
    const tableCounts = countDescendantPermissions(nodeKey, treeState.permissionStates);
    const tableExplicitState = getExplicitState(nodeKey, treeState.permissionStates);
    const clsCount = countClsRules(nodeKey);
    const hasRls = hasRlsRules(nodeKey);
    // Bold the title when anything is configured on or below this table.
    const tableHasConfig =
      tableExplicitState === 'allow' ||
      tableExplicitState === 'deny' ||
      (tableCounts && (tableCounts.allowed > 0 || tableCounts.denied > 0)) ||
      clsCount > 0 ||
      hasRls;
    return (
      <div>
        <div className="node-title">
          {getStateIcon(nodeKey)}
          {getNodeIcon(node.nodeType)}
          <span style={tableHasConfig ? { fontWeight: 600 } : undefined}>{title}</span>
          {(hasRls || clsCount > 0) && (
            <span className="node-count">
              ({[
                hasRls ? 'RLS' : null,
                clsCount > 0 ? `CLS: ${clsCount}` : null,
              ].filter(Boolean).join(', ')})
            </span>
          )}
        </div>
        <div className="rls-container">
          <div className="rls-row">
            <span className="rls-label">{t('RLS Predicate')}</span>
            <Input
              className="rls-input"
              size="small"
              placeholder={t('e.g., org_id = {{current_user_id}}')}
              value={currentRls.predicate}
              onChange={e => handleRlsChange(nodeKey, 'predicate', e.target.value)}
              onClick={e => e.stopPropagation()}
            />
          </div>
          <div className="rls-row">
            <span className="rls-label">{t('Group Key')}</span>
            <Input
              className="rls-input"
              size="small"
              placeholder={t('Optional grouping key')}
              value={currentRls.groupKey || ''}
              onChange={e => handleRlsChange(nodeKey, 'groupKey', e.target.value)}
              onClick={e => e.stopPropagation()}
            />
          </div>
        </div>
      </div>
    );
  }
  // Non-column nodes: check if leaf (columns are leafs, tables are not anymore)
  const isLeaf = node.nodeType === 'column';
  // Count descendants with custom permissions
  const counts = !isLeaf
    ? countDescendantPermissions(nodeKey, treeState.permissionStates)
    : null;
  const hasCustomRules = counts && (counts.allowed > 0 || counts.denied > 0);
  // Check if this node itself has explicit permissions
  const explicitState = getExplicitState(nodeKey, treeState.permissionStates);
  const hasExplicitPermission = explicitState === 'allow' || explicitState === 'deny';
  // For tables, also check CLS/RLS
  const isTable = node.nodeType === 'table';
  const clsCount = isTable ? countClsRules(nodeKey) : 0;
  const hasRls = isTable ? hasRlsRules(nodeKey) : false;
  // Bold if node or any children have permissions, or if table has CLS/RLS
  const shouldBeBold = hasExplicitPermission || hasCustomRules || clsCount > 0 || hasRls;
  return (
    <div className="node-title">
      {getStateIcon(nodeKey)}
      {getNodeIcon(node.nodeType)}
      <span style={shouldBeBold ? { fontWeight: 600 } : undefined}>{title}</span>
      {/* Collapsed summary: allowed/denied descendant counts. */}
      {hasCustomRules && !isExpanded && (
        <span className="node-count">
          (
          <span className="count-allowed">{counts.allowed}</span>
          {' / '}
          <span className="count-denied">{counts.denied}</span>
          )
        </span>
      )}
      {/* Collapsed summary: RLS/CLS badges for tables. */}
      {isTable && (hasRls || clsCount > 0) && !isExpanded && (
        <span className="node-count">
          ({[
            hasRls ? 'RLS' : null,
            clsCount > 0 ? `CLS: ${clsCount}` : null,
          ].filter(Boolean).join(', ')})
        </span>
      )}
      {/* "[loaded/total]" hint when the server has more children than shown. */}
      {node.hasMore && node.totalCount && (
        <span className="node-count" style={{ marginLeft: 4 }}>
          [{node.children?.length || 0}/{node.totalCount}]
        </span>
      )}
    </div>
  );
};
// Wipe all permission, CLS and RLS state and report an empty payload upstream.
const clearAll = () => {
  // Flag the change as internal so sync effects don't echo it back as external.
  isInternalUpdateRef.current = true;
  setTreeState(prev => ({
    ...prev,
    permissionStates: {},
    clsRules: {},
    rlsRules: {},
  }));
  if (onChange) {
    onChange({ allowed: [], denied: [] });
  }
};
return (
<StyledContainer>
<div style={{ marginBottom: 8 }}>
<Button size="small" onClick={clearAll}>
{t('Reset All')}
</Button>
</div>
<Spin spinning={loading}>
<div className="tree-container">
<Tree
loadData={onLoadData as TreeProps['loadData']}
treeData={treeState.treeData}
expandedKeys={treeState.expandedKeys}
onExpand={onExpand}
titleRender={titleRender as TreeProps['titleRender']}
showLine={{ showLeafIcon: false }}
blockNode
/>
</div>
</Spin>
</StyledContainer>
);
}
export default PermissionsTree;

View File

@@ -0,0 +1,73 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import type { DataNode } from 'antd/es/tree';
/** Explicit state set on a node; 'inherit' defers to the nearest ancestor. */
export type PermissionState = 'allow' | 'deny' | 'inherit';
/** Hierarchy levels represented in the permissions tree. */
export type NodeType = 'database' | 'catalog' | 'schema' | 'table' | 'column';
/** Column-level security action applied to a single column. */
export type CLSAction = 'hash' | 'mask' | 'nullify' | 'hide';
/** Tree node enriched with its location in the database hierarchy. */
export interface PermissionNode extends DataNode {
  key: string;
  title: string;
  nodeType: NodeType;
  databaseId?: number;
  databaseName?: string;
  catalogName?: string;
  schemaName?: string;
  tableName?: string;
  columnName?: string;
  children?: PermissionNode[];
  isLeaf?: boolean;
  // Pagination hints: more children exist than are currently loaded.
  hasMore?: boolean;
  totalCount?: number;
}
/** Row-level security rule attached to a table. */
export interface RLSRule {
  predicate: string;
  groupKey?: string;
}
/** Full client-side state of the permissions tree editor. */
export interface TreeState {
  expandedKeys: string[];
  loadedKeys: string[];
  treeData: PermissionNode[];
  permissionStates: Record<string, PermissionState>;
  clsRules: Record<string, CLSAction>; // key format: "tableKey::columnName" -> action
  rlsRules: Record<string, RLSRule>; // key format: tableKey -> RLSRule
}
/** Payload shape exchanged with the rules API. */
export interface PermissionsPayload {
  allowed: PermissionEntry[];
  denied: PermissionEntry[];
}
/** One allow/deny entry; rls/cls are only meaningful on table entries. */
export interface PermissionEntry {
  database: string;
  catalog?: string;
  schema?: string;
  table?: string;
  rls?: {
    predicate: string;
    group_key?: string;
  };
  cls?: Record<string, string>; // column name -> action
}

View File

@@ -0,0 +1,548 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import type {
PermissionNode,
PermissionState,
PermissionsPayload,
PermissionEntry,
NodeType,
CLSAction,
RLSRule,
} from './types';
// Key format: db:{id}|cat:{name}|schema:{name}|table:{name}
// Each part is optional based on the level
/**
 * Build a hierarchical node key of the form
 * "db:{id}|cat:{name}|schema:{name}|table:{name}", including only the
 * levels that are provided.
 */
export function makeKey(parts: {
  databaseId: number;
  catalogName?: string;
  schemaName?: string;
  tableName?: string;
}): string {
  const segments = [`db:${parts.databaseId}`];
  if (parts.catalogName) segments.push(`cat:${parts.catalogName}`);
  if (parts.schemaName) segments.push(`schema:${parts.schemaName}`);
  if (parts.tableName) segments.push(`table:${parts.tableName}`);
  return segments.join('|');
}
/**
 * Decompose a node key back into its hierarchy parts, deriving the node
 * type from the deepest segment present (database by default).
 */
export function parseKey(key: string): {
  databaseId: number;
  catalogName?: string;
  schemaName?: string;
  tableName?: string;
  nodeType: NodeType;
} {
  const parsed: ReturnType<typeof parseKey> = {
    databaseId: 0,
    nodeType: 'database',
  };
  key.split('|').forEach(segment => {
    const sep = segment.indexOf(':');
    if (sep === -1) return;
    const value = segment.slice(sep + 1);
    switch (segment.slice(0, sep)) {
      case 'db':
        parsed.databaseId = parseInt(value, 10);
        break;
      case 'cat':
        parsed.catalogName = value;
        parsed.nodeType = 'catalog';
        break;
      case 'schema':
        parsed.schemaName = value;
        parsed.nodeType = 'schema';
        break;
      case 'table':
        parsed.tableName = value;
        parsed.nodeType = 'table';
        break;
      default:
        break;
    }
  });
  return parsed;
}
/** Parent key = everything before the last segment; root keys return null. */
export function getParentKey(key: string): string | null {
  const idx = key.lastIndexOf('|');
  return idx === -1 ? null : key.slice(0, idx);
}
export function getEffectiveState(
nodeKey: string,
permissionStates: Record<string, PermissionState>,
): PermissionState {
// Check if node has explicit state
const explicitState = permissionStates[nodeKey];
if (explicitState && explicitState !== 'inherit') {
return explicitState;
}
// Walk up ancestors to find effective state
let currentKey: string | null = nodeKey;
while (currentKey) {
const parentKey = getParentKey(currentKey);
if (!parentKey) break;
const parentState = permissionStates[parentKey];
if (parentState && parentState !== 'inherit') {
return parentState;
}
currentKey = parentKey;
}
// Default is deny (implicit)
return 'deny';
}
export function getExplicitState(
nodeKey: string,
permissionStates: Record<string, PermissionState>,
): PermissionState {
return permissionStates[nodeKey] || 'inherit';
}
export function cyclePermissionState(
nodeKey: string,
permissionStates: Record<string, PermissionState>,
): Record<string, PermissionState> {
const newStates = { ...permissionStates };
const currentState = newStates[nodeKey] || 'inherit';
const parentKey = getParentKey(nodeKey);
let nextState: PermissionState;
if (!parentKey) {
// Root element (database) - only toggle between inherit and allow
// (deny is the default state, so there's no point in explicitly denying)
if (currentState === 'inherit') {
nextState = 'allow';
} else {
nextState = 'inherit';
}
} else {
// Child element - toggle between inherit and opposite of parent's effective state
const parentEffective = getEffectiveState(parentKey, permissionStates);
const oppositeState = parentEffective === 'allow' ? 'deny' : 'allow';
if (currentState === 'inherit') {
nextState = oppositeState;
} else {
nextState = 'inherit';
}
}
// Apply the new state
if (nextState === 'inherit') {
delete newStates[nodeKey];
} else {
newStates[nodeKey] = nextState;
}
// Clear all descendant states when setting a non-inherit state on a parent
// This ensures children return to "inherit" when parent is set
if (nextState !== 'inherit') {
clearDescendantStates(nodeKey, newStates);
}
return newStates;
}
/**
 * Remove (in place) every stored state belonging to a descendant of the
 * given node key.
 */
function clearDescendantStates(
  parentKey: string,
  states: Record<string, PermissionState>,
): void {
  Object.keys(states)
    .filter(key => isDescendantOf(key, parentKey))
    .forEach(key => {
      delete states[key];
    });
}
/**
 * True when childKey lies strictly below ancestorKey in the hierarchy,
 * i.e. it extends the ancestor key with at least one more '|'-segment.
 */
function isDescendantOf(childKey: string, ancestorKey: string): boolean {
  const prefix = `${ancestorKey}|`;
  return childKey.slice(0, prefix.length) === prefix;
}
/**
 * Tally explicit allow/deny states among the DIRECT children of a node
 * (grandchildren and deeper are ignored).
 */
export function countChildPermissions(
  parentKey: string,
  permissionStates: Record<string, PermissionState>,
): { allowed: number; denied: number } {
  const totals = { allowed: 0, denied: 0 };
  Object.entries(permissionStates).forEach(([key, state]) => {
    if (!isDirectChildOf(key, parentKey)) return;
    if (state === 'allow') {
      totals.allowed += 1;
    } else if (state === 'deny') {
      totals.denied += 1;
    }
  });
  return totals;
}
/**
 * Tally explicit allow/deny states among ALL descendants of a node,
 * at any depth.
 */
export function countDescendantPermissions(
  parentKey: string,
  permissionStates: Record<string, PermissionState>,
): { allowed: number; denied: number } {
  const totals = { allowed: 0, denied: 0 };
  Object.entries(permissionStates).forEach(([key, state]) => {
    if (!isDescendantOf(key, parentKey)) return;
    if (state === 'allow') {
      totals.allowed += 1;
    } else if (state === 'deny') {
      totals.denied += 1;
    }
  });
  return totals;
}
/**
 * True when childKey is exactly one level below parentKey
 * (extends it with a single segment and no further '|').
 */
function isDirectChildOf(childKey: string, parentKey: string): boolean {
  const prefix = `${parentKey}|`;
  if (!childKey.startsWith(prefix)) {
    return false;
  }
  return childKey.indexOf('|', prefix.length) === -1;
}
/**
 * Immutably rebuild the tree, attaching `children` to the node whose key
 * matches; untouched subtrees are reused as-is.
 */
export function updateTreeData(
  list: PermissionNode[],
  key: string,
  children: PermissionNode[],
): PermissionNode[] {
  return list.map(node => {
    if (node.key === key) {
      return { ...node, children };
    }
    return node.children
      ? { ...node, children: updateTreeData(node.children, key, children) }
      : node;
  });
}
/** Depth-first search for the node with the given key; null if absent. */
export function findNodeByKey(
  data: PermissionNode[],
  key: string,
): PermissionNode | null {
  for (const node of data) {
    if (node.key === key) {
      return node;
    }
    const match = node.children ? findNodeByKey(node.children, key) : null;
    if (match) {
      return match;
    }
  }
  return null;
}
/**
 * Compose the CLS rule key "{tableKey}::{columnName}" for a column of a
 * given table.
 */
export function makeClsKey(tableKey: string, columnName: string): string {
  return [tableKey, columnName].join('::');
}
/**
 * Split a CLS key back into its table key and column name; returns null
 * when the '::' separator is missing.
 */
export function parseClsKey(clsKey: string): {
  tableKey: string;
  columnName: string;
} | null {
  const sep = clsKey.indexOf('::');
  if (sep < 0) {
    return null;
  }
  return {
    tableKey: clsKey.slice(0, sep),
    columnName: clsKey.slice(sep + 2),
  };
}
/**
 * Extract the CLS rules belonging to one table as a columnName -> action
 * map, stripping the "{tableKey}::" prefix from each matching key.
 */
export function getTableClsRules(
  tableKey: string,
  clsRules: Record<string, CLSAction>,
): Record<string, CLSAction> {
  const prefix = `${tableKey}::`;
  const result: Record<string, CLSAction> = {};
  Object.entries(clsRules).forEach(([key, action]) => {
    if (!key.startsWith(prefix)) return;
    result[key.slice(prefix.length)] = action;
  });
  return result;
}
/**
 * Convert the in-memory editor state into the API payload of allowed and
 * denied permission entries, folding per-table RLS and CLS rules into the
 * corresponding table entries.
 *
 * @param permissionStates explicit allow/deny/inherit states keyed by node key
 * @param databases map of database id -> database name used to translate keys
 * @param clsRules column rules keyed as "{tableKey}::{columnName}"
 * @param rlsRules row rules keyed by table key
 */
export function generatePermissionsPayload(
  permissionStates: Record<string, PermissionState>,
  databases: Map<number, string>,
  clsRules: Record<string, CLSAction> = {},
  rlsRules: Record<string, RLSRule> = {},
): PermissionsPayload {
  const allowed: PermissionEntry[] = [];
  const denied: PermissionEntry[] = [];
  // Clean up states - remove redundant entries
  const cleanedStates = cleanupStates(permissionStates);
  // Group CLS rules by table key
  const clsByTable: Record<string, Record<string, string>> = {};
  Object.entries(clsRules).forEach(([clsKey, action]) => {
    const parsed = parseClsKey(clsKey);
    if (!parsed) return;
    if (!clsByTable[parsed.tableKey]) {
      clsByTable[parsed.tableKey] = {};
    }
    clsByTable[parsed.tableKey][parsed.columnName] = action;
  });
  // First pass: one entry per explicitly allowed/denied node.
  Object.entries(cleanedStates).forEach(([key, state]) => {
    const parsed = parseKey(key);
    const databaseName = databases.get(parsed.databaseId);
    // Skip entries whose database is unknown to the caller's map.
    if (!databaseName) return;
    const entry: PermissionEntry = {
      database: databaseName,
    };
    if (parsed.catalogName) entry.catalog = parsed.catalogName;
    if (parsed.schemaName) entry.schema = parsed.schemaName;
    if (parsed.tableName) entry.table = parsed.tableName;
    // Add RLS rules if this is a table entry and has RLS rules
    if (parsed.tableName && rlsRules[key]) {
      const rls = rlsRules[key];
      // Only a non-empty predicate produces an RLS clause.
      if (rls.predicate) {
        entry.rls = {
          predicate: rls.predicate,
        };
        if (rls.groupKey) {
          entry.rls.group_key = rls.groupKey;
        }
      }
    }
    // Add CLS rules if this is a table entry and has CLS rules
    if (parsed.tableName && clsByTable[key]) {
      entry.cls = clsByTable[key];
    }
    if (state === 'allow') {
      allowed.push(entry);
    } else if (state === 'deny') {
      denied.push(entry);
    }
  });
  // Also add RLS/CLS rules for tables that don't have explicit permission states
  // but do have RLS or CLS rules (they inherit permission from parent)
  const tablesWithRules = new Set([
    ...Object.keys(clsByTable),
    ...Object.keys(rlsRules),
  ]);
  tablesWithRules.forEach(tableKey => {
    // Check if we already added this table
    const parsed = parseKey(tableKey);
    const databaseName = databases.get(parsed.databaseId);
    if (!databaseName) return;
    // Only add if not already in allowed list
    const alreadyInAllowed = allowed.some(
      entry =>
        entry.database === databaseName &&
        entry.catalog === parsed.catalogName &&
        entry.schema === parsed.schemaName &&
        entry.table === parsed.tableName,
    );
    const hasClsRules =
      clsByTable[tableKey] && Object.keys(clsByTable[tableKey]).length > 0;
    const hasRlsRules =
      rlsRules[tableKey] && rlsRules[tableKey].predicate;
    if (!alreadyInAllowed && (hasClsRules || hasRlsRules)) {
      // Check if the table has effective allow state (inherited)
      const effectiveState = getEffectiveState(tableKey, permissionStates);
      if (effectiveState === 'allow') {
        const entry: PermissionEntry = {
          database: databaseName,
          table: parsed.tableName,
        };
        if (parsed.catalogName) entry.catalog = parsed.catalogName;
        if (parsed.schemaName) entry.schema = parsed.schemaName;
        if (hasRlsRules) {
          const rls = rlsRules[tableKey];
          entry.rls = { predicate: rls.predicate };
          if (rls.groupKey) {
            entry.rls.group_key = rls.groupKey;
          }
        }
        if (hasClsRules) {
          entry.cls = clsByTable[tableKey];
        }
        allowed.push(entry);
      }
    }
  });
  return { allowed, denied };
}
function cleanupStates(
states: Record<string, PermissionState>,
): Record<string, PermissionState> {
const cleaned: Record<string, PermissionState> = {};
Object.entries(states).forEach(([key, state]) => {
if (state === 'inherit') return;
// Check ancestor state
let ancestorState: PermissionState | null = null;
let currentKey: string | null = key;
while (currentKey) {
const parentKey = getParentKey(currentKey);
if (!parentKey) break;
if (states[parentKey] && states[parentKey] !== 'inherit') {
ancestorState = states[parentKey];
break;
}
currentKey = parentKey;
}
// Only keep state if it differs from ancestor
if (ancestorState) {
if (ancestorState !== state) {
cleaned[key] = state;
}
} else {
// No ancestor - only keep 'allow' (deny is default)
if (state === 'allow') {
cleaned[key] = state;
}
}
});
return cleaned;
}
export function loadPermissionsFromPayload(
payload: PermissionsPayload,
databases: Map<string, number>,
): {
states: Record<string, PermissionState>;
clsRules: Record<string, CLSAction>;
rlsRules: Record<string, RLSRule>;
} {
const states: Record<string, PermissionState> = {};
const clsRules: Record<string, CLSAction> = {};
const rlsRules: Record<string, RLSRule> = {};
payload.allowed.forEach(entry => {
const databaseId = databases.get(entry.database);
if (databaseId === undefined) return;
const key = makeKey({
databaseId,
catalogName: entry.catalog,
schemaName: entry.schema,
tableName: entry.table,
});
states[key] = 'allow';
// Load RLS rules if present
if (entry.rls && entry.table) {
rlsRules[key] = {
predicate: entry.rls.predicate,
groupKey: entry.rls.group_key,
};
}
// Load CLS rules if present
if (entry.cls && entry.table) {
Object.entries(entry.cls).forEach(([columnName, action]) => {
const clsKey = makeClsKey(key, columnName);
clsRules[clsKey] = action as CLSAction;
});
}
});
payload.denied.forEach(entry => {
const databaseId = databases.get(entry.database);
if (databaseId === undefined) return;
const key = makeKey({
databaseId,
catalogName: entry.catalog,
schemaName: entry.schema,
tableName: entry.table,
});
states[key] = 'deny';
});
return { states, clsRules, rlsRules };
}

View File

@@ -0,0 +1,38 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Minimal role reference as returned by the API. */
export type RoleObject = {
  id: number;
  name: string;
};
/** A data access rule row exchanged with the /api/v1/dar endpoints. */
export type DataAccessRuleObject = {
  id?: number; // absent for rules not yet persisted
  name?: string;
  description?: string;
  role_id: number;
  role?: RoleObject;
  rule: string; // serialized rule document (JSON per the backend module docs)
  changed_on_delta_humanized?: string;
  changed_by?: {
    id: number;
    first_name: string;
    last_name: string;
  };
};

View File

@@ -0,0 +1,407 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { t, SupersetClient } from '@superset-ui/core';
import { useMemo, useState } from 'react';
import { ConfirmStatusChange, Tooltip } from '@superset-ui/core/components';
import {
ModifiedInfo,
ListView,
ListViewFilterOperator as FilterOperator,
type ListViewProps,
type ListViewFilters,
type ListViewFetchDataConfig as FetchDataConfig,
} from 'src/components';
import { Icons } from '@superset-ui/core/components/Icons';
import withToasts from 'src/components/MessageToasts/withToasts';
import SubMenu, { SubMenuProps } from 'src/features/home/SubMenu';
import rison from 'rison';
import { useListViewResource } from 'src/views/CRUD/hooks';
import DataAccessRuleModal from 'src/features/dataAccessRules/DataAccessRuleModal';
import { DataAccessRuleObject } from 'src/features/dataAccessRules/types';
import { createErrorHandler, createFetchRelated } from 'src/views/CRUD/utils';
import { QueryObjectColumns } from 'src/views/CRUD/types';
/** Props for the list page: toast callbacks (injected by withToasts) plus the current user. */
interface DataAccessRulesListProps {
  addDangerToast: (msg: string) => void;
  addSuccessToast: (msg: string) => void;
  // Current user; forwarded to related-field fetchers for the filters.
  user: {
    userId: string | number;
    firstName: string;
    lastName: string;
  };
}
/**
 * CRUD list page for Data Access Rules: a paginated ListView with role and
 * modified-by filters, per-row edit/delete actions, bulk delete, and a
 * modal for creating or editing a rule.
 */
function DataAccessRulesList(props: DataAccessRulesListProps) {
  const { addDangerToast, addSuccessToast, user } = props;
  // Modal visibility and the rule being edited (null = create a new rule).
  const [ruleModalOpen, setRuleModalOpen] = useState<boolean>(false);
  const [currentRule, setCurrentRule] =
    useState<DataAccessRuleObject | null>(null);
  const {
    state: {
      loading,
      resourceCount: rulesCount,
      resourceCollection: rules,
      bulkSelectEnabled,
    },
    hasPerm,
    fetchData,
    refreshData,
    toggleBulkSelect,
  } = useListViewResource<DataAccessRuleObject>(
    'dar',
    t('Data Access Rule'),
    addDangerToast,
    true,
    undefined,
    undefined,
    true,
  );
  // Open the modal for an existing rule, or for a new one when null.
  function handleRuleEdit(rule: DataAccessRuleObject | null) {
    setCurrentRule(rule);
    setRuleModalOpen(true);
  }
  // Delete a single rule, then refresh the list and toast the outcome.
  // NOTE(review): the refreshData/toast parameters shadow the identically
  // named closure values; callers currently pass those same values in.
  function handleRuleDelete(
    { id, name: ruleName }: DataAccessRuleObject,
    refreshData: (arg0?: FetchDataConfig | null) => void,
    addSuccessToast: (arg0: string) => void,
    addDangerToast: (arg0: string) => void,
  ) {
    // Fall back to a synthetic label when the rule has no name.
    const name = ruleName || `Rule ${id}`;
    return SupersetClient.delete({
      endpoint: `/api/v1/dar/${id}`,
    }).then(
      () => {
        refreshData();
        addSuccessToast(t('Deleted %s', name));
      },
      createErrorHandler(errMsg =>
        addDangerToast(t('There was an issue deleting %s: %s', name, errMsg)),
      ),
    );
  }
  // Bulk-delete the selected rules in one rison-encoded request.
  function handleBulkRulesDelete(rulesToDelete: DataAccessRuleObject[]) {
    const ids = rulesToDelete.map(({ id }) => id);
    return SupersetClient.delete({
      endpoint: `/api/v1/dar/?q=${rison.encode(ids)}`,
    }).then(
      () => {
        refreshData();
        addSuccessToast(t(`Deleted`));
      },
      createErrorHandler(errMsg =>
        addDangerToast(t('There was an issue deleting rules: %s', errMsg)),
      ),
    );
  }
  // Close the modal and re-fetch so edits/creations appear immediately.
  function handleRuleModalHide() {
    setCurrentRule(null);
    setRuleModalOpen(false);
    refreshData();
  }
  // Edit and write map to the same permission; export is checked separately.
  const canWrite = hasPerm('can_write');
  const canEdit = hasPerm('can_write');
  const canExport = hasPerm('can_export');
  // ListView column definitions (memoized; re-created when perms change).
  const columns = useMemo(
    () => [
      {
        Cell: ({
          row: {
            original: { name },
          },
        }: {
          row: { original: DataAccessRuleObject };
        }) => name || '-',
        accessor: 'name',
        Header: t('Name'),
        size: 'lg',
        id: 'name',
      },
      {
        Cell: ({
          row: {
            original: { description },
          },
        }: {
          row: { original: DataAccessRuleObject };
        }) => {
          if (!description) return '-';
          // Truncate long descriptions; the full text lives in the tooltip.
          const truncated =
            description.length > 100
              ? `${description.substring(0, 100)}...`
              : description;
          return (
            <Tooltip id="desc-tooltip" title={description} placement="top">
              <span>{truncated}</span>
            </Tooltip>
          );
        },
        accessor: 'description',
        Header: t('Description'),
        size: 'xl',
        id: 'description',
        disableSortBy: true,
      },
      {
        Cell: ({
          row: {
            original: { role },
          },
        }: {
          row: { original: DataAccessRuleObject };
        }) => role?.name || '-',
        accessor: 'role',
        Header: t('Role'),
        size: 'lg',
        id: 'role',
        disableSortBy: true,
      },
      {
        Cell: ({
          row: {
            original: {
              changed_on_delta_humanized: changedOn,
              changed_by: changedBy,
            },
          },
        }: {
          row: { original: DataAccessRuleObject };
        }) => <ModifiedInfo date={changedOn} user={changedBy} />,
        Header: t('Last modified'),
        accessor: 'changed_on_delta_humanized',
        size: 'xl',
        id: 'changed_on_delta_humanized',
      },
      {
        // Per-row action buttons (delete with confirmation, edit).
        Cell: ({ row: { original } }: { row: { original: DataAccessRuleObject } }) => {
          const handleDelete = () =>
            handleRuleDelete(
              original,
              refreshData,
              addSuccessToast,
              addDangerToast,
            );
          const handleEdit = () => handleRuleEdit(original);
          return (
            <div className="actions">
              {canWrite && (
                <ConfirmStatusChange
                  title={t('Please confirm')}
                  description={
                    <>
                      {t('Are you sure you want to delete')}{' '}
                      <b>{original.name || `Rule ${original.id}`}</b>
                    </>
                  }
                  onConfirm={handleDelete}
                >
                  {confirmDelete => (
                    <Tooltip
                      id="delete-action-tooltip"
                      title={t('Delete')}
                      placement="bottom"
                    >
                      <span
                        role="button"
                        tabIndex={0}
                        className="action-button"
                        onClick={confirmDelete}
                      >
                        <Icons.DeleteOutlined
                          data-test="dar-list-trash-icon"
                          iconSize="l"
                        />
                      </span>
                    </Tooltip>
                  )}
                </ConfirmStatusChange>
              )}
              {canEdit && (
                <Tooltip
                  id="edit-action-tooltip"
                  title={t('Edit')}
                  placement="bottom"
                >
                  <span
                    role="button"
                    tabIndex={0}
                    className="action-button"
                    onClick={handleEdit}
                  >
                    <Icons.EditOutlined data-test="edit-alt" iconSize="l" />
                  </span>
                </Tooltip>
              )}
            </div>
          );
        },
        Header: t('Actions'),
        id: 'actions',
        hidden: !canEdit && !canWrite && !canExport,
        disableSortBy: true,
        size: 'lg',
      },
      {
        // Hidden column so the changed_by filter can target it.
        accessor: QueryObjectColumns.ChangedBy,
        hidden: true,
        id: QueryObjectColumns.ChangedBy,
      },
    ],
    [
      user.userId,
      canEdit,
      canWrite,
      canExport,
      hasPerm,
      refreshData,
      addDangerToast,
      addSuccessToast,
    ],
  );
  // Shown when the list is empty; offers a create button for editors.
  const emptyState = {
    title: t('No Data Access Rules yet'),
    image: 'filter-results.svg',
    buttonAction: () => handleRuleEdit(null),
    buttonIcon: canEdit ? (
      <Icons.PlusOutlined iconSize="m" data-test="add-rule-empty" />
    ) : undefined,
    buttonText: canEdit ? t('Rule') : null,
  };
  // Server-side filters: by role and by last modifier.
  const filters: ListViewFilters = useMemo(
    () => [
      {
        Header: t('Role'),
        key: 'role',
        id: 'role',
        input: 'select',
        operator: FilterOperator.RelationOneMany,
        unfilteredLabel: t('All'),
        fetchSelects: createFetchRelated(
          'dar',
          'role',
          createErrorHandler(errMsg =>
            t('An error occurred while fetching roles: %s', errMsg),
          ),
          user,
        ),
        paginate: true,
      },
      {
        Header: t('Modified by'),
        key: 'changed_by',
        id: 'changed_by',
        input: 'select',
        operator: FilterOperator.RelationOneMany,
        unfilteredLabel: t('All'),
        fetchSelects: createFetchRelated(
          'dar',
          'changed_by',
          createErrorHandler(errMsg =>
            t('An error occurred while fetching users: %s', errMsg),
          ),
          user,
        ),
        paginate: true,
      },
    ],
    [user],
  );
  // Newest modifications first by default.
  const initialSort = [{ id: 'changed_on_delta_humanized', desc: true }];
  const PAGE_SIZE = 25;
  // Header buttons: bulk select toggle and "new rule" (writers only).
  const subMenuButtons: SubMenuProps['buttons'] = [];
  if (canWrite) {
    subMenuButtons.push({
      name: t('Bulk select'),
      buttonStyle: 'secondary',
      'data-test': 'bulk-select',
      onClick: toggleBulkSelect,
    });
    subMenuButtons.push({
      name: t('Rule'),
      icon: <Icons.PlusOutlined iconSize="m" data-test="add-rule" />,
      buttonStyle: 'primary',
      onClick: () => handleRuleEdit(null),
    });
  }
  return (
    <>
      <SubMenu name={t('Data Access Rules')} buttons={subMenuButtons} />
      <ConfirmStatusChange
        title={t('Please confirm')}
        description={t('Are you sure you want to delete the selected rules?')}
        onConfirm={handleBulkRulesDelete}
      >
        {confirmDelete => {
          const bulkActions: ListViewProps['bulkActions'] = [];
          if (canWrite) {
            bulkActions.push({
              key: 'delete',
              name: t('Delete'),
              type: 'danger',
              onSelect: confirmDelete,
            });
          }
          return (
            <>
              <DataAccessRuleModal
                rule={currentRule}
                addDangerToast={addDangerToast}
                onHide={handleRuleModalHide}
                addSuccessToast={addSuccessToast}
                show={ruleModalOpen}
              />
              <ListView<DataAccessRuleObject>
                className="dar-list-view"
                bulkActions={bulkActions}
                bulkSelectEnabled={bulkSelectEnabled}
                disableBulkSelect={toggleBulkSelect}
                columns={columns}
                count={rulesCount}
                data={rules}
                emptyState={emptyState}
                fetchData={fetchData}
                filters={filters}
                initialSort={initialSort}
                loading={loading}
                addDangerToast={addDangerToast}
                addSuccessToast={addSuccessToast}
                refreshData={() => {}}
                pageSize={PAGE_SIZE}
              />
            </>
          );
        }}
      </ConfirmStatusChange>
    </>
  );
}
export default withToasts(DataAccessRulesList);

View File

@@ -138,6 +138,13 @@ const RowLevelSecurityList = lazy(
),
);
const DataAccessRulesList = lazy(
() =>
import(
/* webpackChunkName: "DataAccessRulesList" */ 'src/pages/DataAccessRulesList'
),
);
const RolesList = lazy(
() => import(/* webpackChunkName: "RolesList" */ 'src/pages/RolesList'),
);
@@ -315,6 +322,13 @@ if (isFeatureEnabled(FeatureFlag.TaggingSystem)) {
});
}
if (isFeatureEnabled(FeatureFlag.DataAccessRules)) {
routes.push({
path: '/dataaccessrules/list/',
Component: DataAccessRulesList,
});
}
const user = getBootstrapData()?.user;
const authRegistrationEnabled =
getBootstrapData()?.common.conf.AUTH_USER_REGISTRATION;

View File

@@ -46,11 +46,17 @@ class TablesDatabaseCommand(BaseCommand):
catalog_name: str | None,
schema_name: str,
force: bool,
filter_str: str | None = None,
page: int | None = None,
page_size: int | None = None,
):
self._db_id = db_id
self._catalog_name = catalog_name
self._schema_name = schema_name
self._force = force
self._filter_str = filter_str
self._page = page
self._page_size = page_size
def run(self) -> dict[str, Any]:
self.validate()
@@ -161,8 +167,24 @@ class TablesDatabaseCommand(BaseCommand):
key=lambda item: item["value"],
)
# Apply filter if provided
if self._filter_str:
filter_lower = self._filter_str.lower()
options = [
opt for opt in options if filter_lower in opt["value"].lower()
]
# Get total count before pagination
total_count = len(options)
# Apply pagination if provided
if self._page is not None and self._page_size is not None:
start = self._page * self._page_size
end = start + self._page_size
options = options[start:end]
payload = {
"count": len(tables) + len(views) + len(materialized_views),
"count": total_count,
"result": options,
}
return payload

View File

@@ -581,6 +581,10 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the
# query, and might break queries and/or allow users to bypass RLS. Use with care!
"RLS_IN_SQLLAB": False,
# Enable the new Data Access Rules system for table-level access control,
# row-level security (RLS), and column-level security (CLS). This replaces
# the FAB-based permission system with a more flexible JSON-based rule system.
"DATA_ACCESS_RULES": False,
# Try to optimize SQL queries — for now only predicate pushdown is supported.
"OPTIMIZE_SQL": False,
# When impersonating a user, use the email prefix instead of the username

View File

@@ -113,6 +113,7 @@ from superset.superset_typing import (
)
from superset.utils import core as utils, json
from superset.utils.backports import StrEnum
from superset.data_access_rules.utils import get_hidden_columns_for_table
config = current_app.config # Backward compatibility for tests
metadata = Model.metadata # pylint: disable=no-member
@@ -438,6 +439,21 @@ class BaseDatasource(
@property
def data(self) -> ExplorableData:
"""Data representation of the datasource sent to the frontend"""
# Filter hidden columns based on CLS rules
columns_data = [o.data for o in self.columns]
if is_feature_enabled("DATA_ACCESS_RULES") and hasattr(self, "database"):
try:
table = Table(self.datasource_name, self.schema, self.catalog)
hidden_columns = get_hidden_columns_for_table(table, self.database)
if hidden_columns:
columns_data = [
c for c in columns_data
if c.get("column_name") not in hidden_columns
]
except Exception: # pylint: disable=broad-except
# Don't fail if CLS check fails, just return all columns
pass
return {
# simple fields
"id": self.id,
@@ -462,7 +478,7 @@ class BaseDatasource(
# sqla-specific
"sql": self.sql,
# one to many
"columns": [o.data for o in self.columns],
"columns": columns_data,
"metrics": [o.data for o in self.metrics],
"folders": self.folders,
# TODO deprecate, move logic to JS

View File

@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Access Rules module.
This module provides a new approach to data access control in Superset,
supporting:
- Table-level access control (allow/deny patterns)
- Row-level security (RLS) with predicates
- Column-level security (CLS) with masking/hiding options
Unlike the FAB-based permission system, rules are stored as JSON documents
and can reference tables directly without requiring a priori permission creation.
"""

View File

@@ -0,0 +1,235 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Access Rules REST API.
This module provides the REST API for managing Data Access Rules,
including CRUD operations and a group_keys discovery endpoint.
"""
import logging
from flask import request, Response
from flask_appbuilder.api import expose, protect, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.data_access_rules.models import DataAccessRule
from superset.data_access_rules.schemas import (
DataAccessRulePostSchema,
DataAccessRulePutSchema,
DataAccessRuleShowSchema,
)
from superset.data_access_rules.utils import get_all_group_keys
from superset.extensions import event_logger
from superset.views.base_api import (
BaseSupersetModelRestApi,
requires_json,
statsd_metrics,
)
from superset.views.filters import BaseFilterRelatedRoles, BaseFilterRelatedUsers
logger = logging.getLogger(__name__)
class DataAccessRulesRestApi(BaseSupersetModelRestApi):
    """REST API for Data Access Rules.

    Exposes the standard FAB CRUD endpoints for the DataAccessRule model
    under the ``dar`` resource name, plus a custom ``/group_keys/``
    discovery endpoint and a custom PUT handler (defined below).
    """

    datamodel = SQLAInterface(DataAccessRule)
    # Standard CRUD routes, the related-field lookup route, and the custom
    # group_keys route implemented at the bottom of this class.
    include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
        RouteMethod.RELATED,
        "group_keys",
    }
    resource_name = "dar"
    class_permission_name = "DataAccessRule"
    openapi_spec_tag = "Data Access Rules"
    method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
    allow_browser_login = True

    # Columns serialized in list responses (dotted names traverse relations).
    list_columns = [
        "id",
        "name",
        "description",
        "role_id",
        "role.id",
        "role.name",
        "rule",
        "changed_on_delta_humanized",
        "changed_by.first_name",
        "changed_by.last_name",
        "changed_by.id",
    ]
    # Columns the list endpoint may be ordered by.
    order_columns = [
        "id",
        "name",
        "role_id",
        "changed_on_delta_humanized",
    ]
    # Writable columns for POST.
    add_columns = [
        "name",
        "description",
        "role_id",
        "rule",
    ]
    # Writable columns for PUT.
    edit_columns = [
        "name",
        "description",
        "role_id",
        "rule",
    ]
    # Columns serialized when fetching a single rule.
    show_columns = [
        "id",
        "name",
        "description",
        "role_id",
        "role.name",
        "role.id",
        "rule",
        "created_on",
        "changed_on",
        "created_by.first_name",
        "created_by.last_name",
        "changed_by.first_name",
        "changed_by.last_name",
    ]
    search_columns = ["role", "changed_by"]
    # Relations reachable through the RELATED endpoint, with their filters.
    allowed_rel_fields = {"role", "changed_by"}
    base_related_field_filters = {
        "role": [["id", BaseFilterRelatedRoles, lambda: []]],
        "changed_by": [["id", BaseFilterRelatedUsers, lambda: []]],
    }
    add_model_schema = DataAccessRulePostSchema()
    edit_model_schema = DataAccessRulePutSchema()
    show_model_schema = DataAccessRuleShowSchema()
    # Short OpenAPI summaries for the generated endpoints.
    openapi_spec_methods = {
        "get": {"get": {"summary": "Get a data access rule"}},
        "get_list": {"get": {"summary": "Get a list of data access rules"}},
        "post": {"post": {"summary": "Create a data access rule"}},
        "put": {"put": {"summary": "Update a data access rule"}},
        "delete": {"delete": {"summary": "Delete a data access rule"}},
    }

    # Custom PUT: validates the payload with the edit schema, then applies a
    # partial update by setting only the fields present in the request body.
    @expose("/<int:pk>", methods=("PUT",))
    @protect()
    @safe
    @statsd_metrics
    @requires_json
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
        log_to_statsd=False,
    )
    def put(self, pk: int) -> Response:
        """Update a data access rule.
        ---
        put:
          summary: Update a data access rule
          parameters:
          - in: path
            schema:
              type: integer
            name: pk
            description: The rule pk
          requestBody:
            description: Data access rule schema
            required: true
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/DataAccessRulePutSchema'
          responses:
            200:
              description: Rule updated
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      id:
                        type: number
                      result:
                        $ref: '#/components/schemas/DataAccessRulePutSchema'
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            item = self.edit_model_schema.load(request.json)
        except ValidationError as error:
            # Schema-level validation failure -> 400 with field messages.
            return self.response_400(message=error.messages)

        # Get existing rule
        existing = self.datamodel.get(pk)
        if not existing:
            return self.response_404()

        # Update fields
        # NOTE(review): ``item`` is a plain dict here (the PUT schema has no
        # post_load), so setattr copies only the submitted fields.
        for key, value in item.items():
            setattr(existing, key, value)

        try:
            self.datamodel.edit(existing)
            return self.response(200, id=existing.id, result=item)
        except Exception as ex:  # pylint: disable=broad-except
            # Boundary handler: log the failure and surface it as a 422.
            logger.error("Error updating data access rule: %s", str(ex), exc_info=True)
            return self.response_422(message=str(ex))

    @expose("/group_keys/", methods=("GET",))
    @protect()
    @safe
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.group_keys",
        log_to_statsd=False,
    )
    def group_keys(self) -> Response:
        """
        Get all distinct group_keys used in RLS rules.

        This endpoint is useful for UI discoverability - showing users
        what group_keys already exist so they can reuse them.
        ---
        get:
          summary: Get all distinct RLS group keys
          description: >-
            Returns a list of all unique group_key values used in RLS rules
            across all Data Access Rules. This helps users discover existing
            keys for consistent rule grouping.
          responses:
            200:
              description: List of group keys
              content:
                application/json:
                  schema:
                    $ref: '#/components/schemas/GroupKeysResponseSchema'
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
        group_keys = get_all_group_keys()
        return self.response(200, result=sorted(group_keys))

View File

@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Access Rules models.
This module defines the DataAccessRule model for storing access rules
as JSON documents associated with roles.
Example rule document structure:
{
"allowed": [
{
"database": "sales",
"schema": "orders",
"table": "ord_main"
},
{
"database": "logs",
"catalog": "public"
},
{
"database": "sales",
"schema": "orders",
"table": "prices",
"rls": {
"predicate": "org = 495",
"group_key": "org_filter"
}
},
{
"database": "sales",
"schema": "orders",
"table": "user_info",
"cls": {
"name": "mask",
"age": "nullify",
"email": "hash",
"lastname": "hide"
}
}
],
"denied": [
{
"database": "logs",
"catalog": "public",
"schema": "pii"
}
]
}
"""
from __future__ import annotations

import json
from typing import Any

from flask_appbuilder import Model
from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship

from superset import security_manager
from superset.models.helpers import AuditMixinNullable
class DataAccessRule(Model, AuditMixinNullable):
    """
    Data access rule associated with a role.

    Each rule is a JSON document that describes what databases, catalogs,
    schemas, and tables a role can access, along with optional RLS predicates
    and CLS column restrictions.
    """

    __tablename__ = "data_access_rules"

    id = Column(Integer, primary_key=True)
    # Optional human-readable label and description for the rule.
    name = Column(String(250), nullable=True)
    description = Column(Text, nullable=True)
    # Every rule is attached to exactly one role.
    role_id = Column(Integer, ForeignKey("ab_role.id"), nullable=False)
    # The rule document, stored as a JSON string (see module docstring).
    rule = Column(Text, nullable=False)

    role = relationship(
        security_manager.role_model,
        backref="data_access_rules",
        foreign_keys=[role_id],
    )

    def __repr__(self) -> str:
        return (
            f"<DataAccessRule(id={self.id}, name={self.name!r}, "
            f"role_id={self.role_id})>"
        )

    @property
    def rule_dict(self) -> dict[str, Any]:
        """Parse the rule JSON string into a dictionary.

        Returns an empty dict when the rule is empty or not valid JSON, so
        callers never need to handle a parse error themselves.
        """
        try:
            return json.loads(self.rule) if self.rule else {}
        except json.JSONDecodeError:
            return {}

    @rule_dict.setter
    def rule_dict(self, value: dict[str, Any]) -> None:
        """Serialize a dictionary to JSON for storage."""
        self.rule = json.dumps(value)

View File

@@ -0,0 +1,249 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Access Rules schemas for API serialization/deserialization.
"""
from marshmallow import fields, post_load, Schema, validates_schema, ValidationError
from superset.dashboards.schemas import UserSchema
from superset.data_access_rules.models import DataAccessRule
# Field descriptions for OpenAPI documentation
rule_description = """
A JSON document describing the access rule. The document should have two optional keys:
- `allowed`: List of entries describing what is allowed
- `denied`: List of entries describing what is denied
Each entry can specify:
- `database` (required): The database name
- `catalog` (optional): The catalog name
- `schema` (optional): The schema name
- `table` (optional): The table name
- `rls` (optional): Row-level security config with `predicate` and optional `group_key`
- `cls` (optional): Column-level security config mapping column names to actions
Example:
{
"allowed": [
{"database": "sales", "schema": "orders"},
{"database": "sales", "schema": "orders", "table": "prices",
"rls": {"predicate": "org_id = 123", "group_key": "org"}},
{"database": "sales", "schema": "users", "table": "info",
"cls": {"email": "mask", "ssn": "hide"}}
],
"denied": [
{"database": "sales", "schema": "internal"}
]
}
CLS actions: "hash", "nullify", "mask", "hide"
"""
class RoleSchema(Schema):
    """Schema for role information.

    Only the identifying fields of a role are exposed when a rule's
    associated role is nested in API responses.
    """

    name = fields.String()
    id = fields.Integer()
class DataAccessRuleListSchema(Schema):
    """Schema for listing data access rules."""

    id = fields.Integer(metadata={"description": "Unique ID of the rule"})
    name = fields.String(metadata={"description": "Name of the rule"})
    description = fields.String(metadata={"description": "Description of the rule"})
    role_id = fields.Integer(metadata={"description": "ID of the associated role"})
    role = fields.Nested(RoleSchema)
    rule = fields.String(metadata={"description": rule_description})
    # Computed via get_changed_on_delta_humanized() below.
    changed_on_delta_humanized = fields.Method("get_changed_on_delta_humanized")
    changed_by = fields.Nested(UserSchema(exclude=["username"]))

    def get_changed_on_delta_humanized(self, obj: DataAccessRule) -> str:
        """Serialize the humanized "changed N ago" string for *obj*.

        NOTE(review): this invokes ``changed_on_delta_humanized()`` as a
        method; confirm the audit mixin exposes it as a callable rather
        than a property on this model.
        """
        return obj.changed_on_delta_humanized()
class DataAccessRuleShowSchema(Schema):
    """Schema for showing a single data access rule.

    Superset of the list schema: adds audit timestamps and the creating
    user in addition to the last editor.
    """

    id = fields.Integer(metadata={"description": "Unique ID of the rule"})
    name = fields.String(metadata={"description": "Name of the rule"})
    description = fields.String(metadata={"description": "Description of the rule"})
    role_id = fields.Integer(metadata={"description": "ID of the associated role"})
    role = fields.Nested(RoleSchema)
    rule = fields.String(metadata={"description": rule_description})
    created_on = fields.DateTime()
    changed_on = fields.DateTime()
    created_by = fields.Nested(UserSchema(exclude=["username"]))
    changed_by = fields.Nested(UserSchema(exclude=["username"]))
class DataAccessRulePostSchema(Schema):
    """Schema for creating a data access rule.

    Validates that ``rule`` is a well-formed JSON rule document (see
    ``rule_description``) before converting the payload into a
    DataAccessRule instance via ``make_object``.
    """

    name = fields.String(
        metadata={"description": "Name for this rule (optional)"},
        required=False,
        allow_none=True,
    )
    description = fields.String(
        metadata={"description": "Description of the rule (optional)"},
        required=False,
        allow_none=True,
    )
    role_id = fields.Integer(
        metadata={"description": "ID of the role this rule applies to"},
        required=True,
        allow_none=False,
    )
    rule = fields.String(
        metadata={"description": rule_description},
        required=True,
        allow_none=False,
    )

    @validates_schema
    def validate_rule_json(self, data: dict, **kwargs: dict) -> None:
        """Validate that the rule field contains valid JSON.

        Raises:
            ValidationError: if the rule is not valid JSON, is not an
                object, has malformed ``allowed``/``denied`` lists, or
                contains an invalid CLS configuration.
        """
        import json

        if rule := data.get("rule"):
            try:
                parsed = json.loads(rule)
            except json.JSONDecodeError as ex:
                raise ValidationError(f"Invalid JSON: {ex}", field_name="rule") from ex
            if not isinstance(parsed, dict):
                raise ValidationError(
                    "Rule must be a JSON object", field_name="rule"
                )
            # Validate structure
            allowed = parsed.get("allowed", [])
            denied = parsed.get("denied", [])
            if not isinstance(allowed, list):
                raise ValidationError("'allowed' must be a list", field_name="rule")
            if not isinstance(denied, list):
                raise ValidationError("'denied' must be a list", field_name="rule")
            # Validate entries
            valid_actions = {"hash", "nullify", "mask", "hide"}
            for entry in allowed + denied:
                if not isinstance(entry, dict):
                    raise ValidationError(
                        "Each entry must be an object", field_name="rule"
                    )
                if "database" not in entry:
                    raise ValidationError(
                        "Each entry must have a 'database' field",
                        field_name="rule",
                    )
                # Validate CLS actions if present
                if cls_config := entry.get("cls"):
                    if not isinstance(cls_config, dict):
                        raise ValidationError(
                            "'cls' must be an object mapping columns to actions",
                            field_name="rule",
                        )
                    for col, action in cls_config.items():
                        # A non-string action (e.g. a number) would otherwise
                        # crash with an AttributeError on .lower() — an HTTP
                        # 500 instead of a 400 validation error.
                        if (
                            not isinstance(action, str)
                            or action.lower() not in valid_actions
                        ):
                            raise ValidationError(
                                f"Invalid CLS action '{action}' for column '{col}'. "
                                f"Valid actions: {valid_actions}",
                                field_name="rule",
                            )

    @post_load
    def make_object(self, data: dict, **kwargs: dict) -> DataAccessRule:
        """Convert validated data to a DataAccessRule instance."""
        return DataAccessRule(**data)
class DataAccessRulePutSchema(Schema):
    """Schema for updating a data access rule.

    All fields are optional (partial update); when ``rule`` is provided it
    is validated exactly like in the POST schema. Unlike the POST schema,
    this one has no post_load, so ``load()`` yields a plain dict.
    """

    name = fields.String(
        metadata={"description": "Name for this rule (optional)"},
        required=False,
        allow_none=True,
    )
    description = fields.String(
        metadata={"description": "Description of the rule (optional)"},
        required=False,
        allow_none=True,
    )
    role_id = fields.Integer(
        metadata={"description": "ID of the role this rule applies to"},
        required=False,
        allow_none=False,
    )
    rule = fields.String(
        metadata={"description": rule_description},
        required=False,
        allow_none=False,
    )

    @validates_schema
    def validate_rule_json(self, data: dict, **kwargs: dict) -> None:
        """Validate that the rule field contains valid JSON if provided.

        Raises:
            ValidationError: if the rule is not valid JSON, is not an
                object, has malformed ``allowed``/``denied`` lists, or
                contains an invalid CLS configuration.
        """
        import json

        if rule := data.get("rule"):
            try:
                parsed = json.loads(rule)
            except json.JSONDecodeError as ex:
                raise ValidationError(f"Invalid JSON: {ex}", field_name="rule") from ex
            if not isinstance(parsed, dict):
                raise ValidationError(
                    "Rule must be a JSON object", field_name="rule"
                )
            # Same validation as POST schema
            allowed = parsed.get("allowed", [])
            denied = parsed.get("denied", [])
            if not isinstance(allowed, list):
                raise ValidationError("'allowed' must be a list", field_name="rule")
            if not isinstance(denied, list):
                raise ValidationError("'denied' must be a list", field_name="rule")
            valid_actions = {"hash", "nullify", "mask", "hide"}
            for entry in allowed + denied:
                if not isinstance(entry, dict):
                    raise ValidationError(
                        "Each entry must be an object", field_name="rule"
                    )
                if "database" not in entry:
                    raise ValidationError(
                        "Each entry must have a 'database' field",
                        field_name="rule",
                    )
                if cls_config := entry.get("cls"):
                    if not isinstance(cls_config, dict):
                        raise ValidationError(
                            "'cls' must be an object mapping columns to actions",
                            field_name="rule",
                        )
                    for col, action in cls_config.items():
                        # Guard against non-string actions, which would
                        # otherwise raise AttributeError (HTTP 500) instead
                        # of a proper 400 validation error.
                        if (
                            not isinstance(action, str)
                            or action.lower() not in valid_actions
                        ):
                            raise ValidationError(
                                f"Invalid CLS action '{action}' for column '{col}'. "
                                f"Valid actions: {valid_actions}",
                                field_name="rule",
                            )
class GroupKeysResponseSchema(Schema):
    """Schema for the group_keys endpoint response.

    The endpoint returns the sorted, de-duplicated ``group_key`` values
    found across all RLS rules.
    """

    result = fields.List(
        fields.String(),
        metadata={"description": "List of unique group_key values used in RLS rules"},
    )

View File

@@ -0,0 +1,896 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Access Rules utility functions.
This module provides functions for:
- Checking if a user has access to a table
- Collecting RLS predicates for a table
- Collecting CLS rules for a table
- Applying RLS and CLS to SQL queries
"""
from __future__ import annotations
import logging
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Any, TYPE_CHECKING
from flask import g
from superset import db, is_feature_enabled, security_manager
from superset.sql.parse import CLSAction, Table
if TYPE_CHECKING:
from superset.data_access_rules.models import DataAccessRule
from superset.models.core import Database
from superset.sql.parse import BaseSQLStatement
logger = logging.getLogger(__name__)
class AccessCheckResult(Enum):
    """Result of an access check."""

    # An allowed rule entry matched and won the precedence evaluation.
    ALLOWED = "allowed"
    # A denied rule entry matched and won (deny wins ties).
    DENIED = "denied"
    # No rule entry matched at all; callers decide the fallback behavior.
    NO_RULE = "no_rule"
@dataclass
class RLSPredicate:
    """An RLS predicate with optional group_key."""

    # Raw SQL predicate string, e.g. "org = 495".
    predicate: str
    # Predicates sharing a group_key are OR'd together; ungrouped
    # predicates are AND'd (see get_rls_predicates_for_table).
    group_key: str | None = None
@dataclass
class TableAccessInfo:
    """Information about access to a specific table."""

    # Outcome of the rule evaluation: allowed, denied, or no matching rule.
    access: AccessCheckResult
    # RLS predicates to apply; empty unless access is ALLOWED.
    rls_predicates: list[RLSPredicate]
    # Column name -> CLS action; empty unless access is ALLOWED.
    cls_rules: dict[str, CLSAction]
def get_user_rules() -> list[DataAccessRule]:
    """
    Get all data access rules for the current user's roles.

    Returns:
        List of DataAccessRule objects for the current user's roles.
    """
    # Imported locally to avoid a circular import at module load time.
    from superset.data_access_rules.models import DataAccessRule

    # No authenticated user in the request context -> no rules apply.
    current_user = getattr(g, "user", None)
    if not current_user:
        return []

    role_ids = [role.id for role in security_manager.get_user_roles()]
    if not role_ids:
        return []

    query = db.session.query(DataAccessRule).filter(
        DataAccessRule.role_id.in_(role_ids)
    )
    return query.all()
def _matches_rule_entry(
entry: dict[str, Any],
database_name: str,
catalog: str | None,
schema: str | None,
table_name: str | None,
) -> bool:
"""
Check if a rule entry matches the given database/catalog/schema/table.
The rule entry can specify any level of the hierarchy:
- database only: matches all catalogs/schemas/tables in that database
- database + catalog: matches all schemas/tables in that catalog
- database + schema: matches all tables in that schema (for DBs without catalogs)
- database + catalog + schema: matches all tables in that schema
- database + schema + table: matches the specific table (for DBs without catalogs)
- database + catalog + schema + table: matches the specific table
Args:
entry: The rule entry dict
database_name: The database name to check
catalog: The catalog to check (None if DB doesn't support catalogs)
schema: The schema to check
table_name: The table name to check
Returns:
True if the entry matches, False otherwise.
"""
# Database must always match
if entry.get("database") != database_name:
return False
entry_catalog = entry.get("catalog")
entry_schema = entry.get("schema")
entry_table = entry.get("table")
# If the entry specifies a catalog, it must match (or catalog must be None/default)
if entry_catalog is not None:
if catalog is not None and entry_catalog != catalog:
return False
# If the entry specifies a schema, it must match
if entry_schema is not None:
if schema is not None and entry_schema != schema:
return False
# If the entry specifies a table, it must match
if entry_table is not None:
if table_name is not None and entry_table != table_name:
return False
# Check specificity: entry must be at least as specific as the query
# If querying a specific table, entry must specify that table or be broader
if table_name is not None and entry_table is not None and entry_table != table_name:
return False
return True
def _is_more_specific(entry: dict[str, Any], other: dict[str, Any]) -> bool:
"""
Check if 'entry' is more specific than 'other'.
More specific means it specifies more levels of the hierarchy.
"""
entry_specificity = sum(
[
entry.get("catalog") is not None,
entry.get("schema") is not None,
entry.get("table") is not None,
]
)
other_specificity = sum(
[
other.get("catalog") is not None,
other.get("schema") is not None,
other.get("table") is not None,
]
)
return entry_specificity > other_specificity
def check_table_access(
    database_name: str,
    table: Table,
    rules: list[DataAccessRule] | None = None,
) -> TableAccessInfo:
    """
    Check if the current user has access to a specific table.

    The function evaluates all rules for the user's roles and determines:

    1. Whether access is allowed, denied, or no rule applies
    2. Any RLS predicates that should be applied
    3. Any CLS rules for column masking/hiding

    Denied rules take precedence over allowed rules when at the same specificity level.
    More specific rules take precedence over less specific rules.

    Args:
        database_name: The database name
        table: The Table object with catalog, schema, and table name
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        TableAccessInfo with access result, RLS predicates, and CLS rules.
    """
    if rules is None:
        rules = get_user_rules()

    if not rules:
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    # Collect all matching rules
    allowed_entries: list[dict[str, Any]] = []
    denied_entries: list[dict[str, Any]] = []

    for rule in rules:
        rule_dict = rule.rule_dict

        # Check allowed entries
        for entry in rule_dict.get("allowed", []):
            if _matches_rule_entry(
                entry, database_name, table.catalog, table.schema, table.table
            ):
                allowed_entries.append(entry)

        # Check denied entries
        for entry in rule_dict.get("denied", []):
            if _matches_rule_entry(
                entry, database_name, table.catalog, table.schema, table.table
            ):
                denied_entries.append(entry)

    # If no rules match, return NO_RULE
    if not allowed_entries and not denied_entries:
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    # Find the most specific denied entry
    most_specific_denied = None
    for entry in denied_entries:
        if most_specific_denied is None or _is_more_specific(
            entry, most_specific_denied
        ):
            most_specific_denied = entry

    # Find the most specific allowed entry
    most_specific_allowed = None
    for entry in allowed_entries:
        if most_specific_allowed is None or _is_more_specific(
            entry, most_specific_allowed
        ):
            most_specific_allowed = entry

    # Determine access: deny wins at same specificity, more specific wins otherwise
    if most_specific_denied is not None and most_specific_allowed is not None:
        if _is_more_specific(most_specific_denied, most_specific_allowed):
            return TableAccessInfo(
                access=AccessCheckResult.DENIED,
                rls_predicates=[],
                cls_rules={},
            )
        elif _is_more_specific(most_specific_allowed, most_specific_denied):
            # Access allowed, collect RLS and CLS from matching entries
            pass
        else:
            # Same specificity: denied wins
            return TableAccessInfo(
                access=AccessCheckResult.DENIED,
                rls_predicates=[],
                cls_rules={},
            )
    elif most_specific_denied is not None:
        return TableAccessInfo(
            access=AccessCheckResult.DENIED,
            rls_predicates=[],
            cls_rules={},
        )
    elif most_specific_allowed is None:
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    # Reaching this point means an allowed entry exists and no denied entry
    # outranked it, so access is ALLOWED; now gather its RLS/CLS payload.

    # Collect RLS predicates from all matching allowed entries
    # (RLS is cumulative - all predicates are applied)
    rls_predicates: list[RLSPredicate] = []
    for entry in allowed_entries:
        rls_config = entry.get("rls")
        if rls_config and "predicate" in rls_config:
            rls_predicates.append(
                RLSPredicate(
                    predicate=rls_config["predicate"],
                    group_key=rls_config.get("group_key"),
                )
            )

    # Collect CLS rules from all matching allowed entries
    # (CLS is cumulative - strictest action wins per column)
    cls_rules: dict[str, CLSAction] = {}
    # Higher number = stricter action; used to resolve per-column conflicts.
    cls_precedence = {
        CLSAction.HIDE: 4,
        CLSAction.NULLIFY: 3,
        CLSAction.MASK: 2,
        CLSAction.HASH: 1,
    }
    action_map = {
        "hide": CLSAction.HIDE,
        "nullify": CLSAction.NULLIFY,
        "mask": CLSAction.MASK,
        "hash": CLSAction.HASH,
    }
    for entry in allowed_entries:
        cls_config = entry.get("cls", {})
        for column, action_str in cls_config.items():
            action = action_map.get(action_str.lower())
            if action is None:
                # Unknown action strings are skipped rather than failing the
                # whole access check.
                logger.warning("Unknown CLS action: %s", action_str)
                continue
            existing = cls_rules.get(column)
            if existing is None or cls_precedence[action] > cls_precedence[existing]:
                cls_rules[column] = action

    return TableAccessInfo(
        access=AccessCheckResult.ALLOWED,
        rls_predicates=rls_predicates,
        cls_rules=cls_rules,
    )
def get_rls_predicates_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> list[str]:
    """
    Get the RLS predicates for a table using the new Data Access Rules system.

    This function collects all RLS predicates from matching rules and combines them
    using the group_key logic:
    - Predicates without group_key are ANDed together
    - Predicates with the same group_key are ORed together
    - Groups are ANDed together

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        List of SQL predicate strings to be ANDed together.
    """
    info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )
    if info.access != AccessCheckResult.ALLOWED or not info.rls_predicates:
        return []

    # Split predicates into the ungrouped set and per-group buckets,
    # wrapping each in parentheses to preserve operator precedence.
    grouped: dict[str, list[str]] = defaultdict(list)
    combined: list[str] = []
    for rls in info.rls_predicates:
        wrapped = f"({rls.predicate})"
        if rls.group_key:
            grouped[rls.group_key].append(wrapped)
        else:
            combined.append(wrapped)

    # Each group collapses to a single OR'd predicate; all results are
    # implicitly ANDed by the caller.
    for members in grouped.values():
        combined.append(
            members[0] if len(members) == 1 else f"({' OR '.join(members)})"
        )

    return combined
def get_cls_rules_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> dict[str, CLSAction]:
    """
    Get the CLS rules for a table using the new Data Access Rules system.

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Dict mapping column names to CLSAction values.
    """
    info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )
    # CLS only applies when access is affirmatively allowed.
    return info.cls_rules if info.access == AccessCheckResult.ALLOWED else {}
def get_hidden_columns_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> set[str]:
    """
    Get the set of column names that should be hidden for a table.

    This function checks the CLS rules for the current user and returns
    the names of columns that have the "hide" action applied.

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Set of column names that should be hidden.
    """
    cls_rules = get_cls_rules_for_table(table, database, rules)
    # Only the "hide" action removes a column entirely; hash/nullify/mask
    # keep the column visible but transform its values.
    return {
        column_name
        for column_name, action in cls_rules.items()
        if action == CLSAction.HIDE
    }
def filter_columns_by_cls(
    columns: list[dict[str, Any]],
    table: Table,
    database: Database,
    column_name_key: str = "column_name",
) -> list[dict[str, Any]]:
    """
    Filter a list of column dictionaries to exclude hidden columns.

    This function is useful for filtering column metadata returned by
    database reflection or dataset APIs.

    Args:
        columns: List of column dictionaries
        table: The fully qualified Table object
        database: The Database object
        column_name_key: The key in the column dict that contains the column name

    Returns:
        Filtered list of columns with hidden columns removed.
    """
    # Feature disabled -> pass the metadata through untouched.
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return columns

    hidden = get_hidden_columns_for_table(table, database)
    if not hidden:
        return columns

    return [entry for entry in columns if entry.get(column_name_key) not in hidden]
def _build_sqlglot_schema(
    database: Database,
    tables: list[Table],
) -> dict[str, Any]:
    """
    Build the nested schema mapping sqlglot's ``qualify()`` needs to expand
    ``SELECT *``.

    sqlglot expects ``{catalog: {schema: {table: {col: type}}}}`` when a
    catalog is present, ``{schema: {table: {col: type}}}`` with only a schema,
    or ``{table: {col: type}}`` for bare table names.

    Args:
        database: The Database object used for column reflection
        tables: Fully qualified tables to reflect

    Returns:
        Nested schema dict; tables whose columns cannot be reflected are
        skipped with a warning (best-effort).
    """
    table_schemas: dict[str, Any] = {}
    for table in tables:
        try:
            col_types = {
                col["column_name"]: str(col.get("type", "VARCHAR"))
                for col in database.get_columns(table)
            }
        except Exception as ex:  # pylint: disable=broad-except
            # Best-effort: a failed reflection should not abort rule
            # application for the remaining tables.
            logger.warning(
                "Could not fetch schema for table %s: %s",
                table,
                ex,
            )
            continue

        if table.catalog:
            catalog_entry = table_schemas.setdefault(table.catalog, {})
            if table.schema:
                catalog_entry.setdefault(table.schema, {})[table.table] = col_types
            else:
                catalog_entry[table.table] = col_types
        elif table.schema:
            table_schemas.setdefault(table.schema, {})[table.table] = col_types
        else:
            table_schemas[table.table] = col_types

    return table_schemas


def apply_data_access_rules(
    database: Database,
    catalog: str | None,
    schema: str,
    parsed_statement: BaseSQLStatement[Any],
) -> None:
    """
    Apply Data Access Rules (RLS and CLS) to a parsed SQL statement in place.

    This function:
    1. Checks if the DATA_ACCESS_RULES feature is enabled
    2. For each table in the query, checks access and collects RLS/CLS rules
    3. Applies CLS rules, then RLS predicates, via the existing infrastructure

    Args:
        database: The Database object
        catalog: The default catalog for the query
        schema: The default schema for the query
        parsed_statement: The parsed SQL statement to modify in place
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return

    # Lazy import to avoid a circular dependency with superset.sql.parse.
    from superset.sql.parse import CLSRules

    rules = get_user_rules()
    if not rules:
        return

    # Get the RLS method for this database.
    method = database.db_engine_spec.get_rls_method()

    # Collect RLS predicates and CLS rules for all tables in the statement.
    rls_predicates: dict[Table, list[Any]] = {}
    cls_rules: CLSRules = {}
    for table in parsed_statement.tables:
        qualified_table = table.qualify(catalog=catalog, schema=schema)

        # Check access first.
        access_info = check_table_access(
            database_name=database.database_name,
            table=qualified_table,
            rules=rules,
        )
        if access_info.access == AccessCheckResult.DENIED:
            # TODO: How should we handle denied access mid-query?
            # For now, log a warning. In the future, we might raise an exception.
            logger.warning(
                "Access denied to table %s for user %s",
                qualified_table,
                getattr(g, "user", "unknown"),
            )
            continue

        # Collect RLS predicates, parsed into AST nodes.
        predicates = get_rls_predicates_for_table(qualified_table, database, rules)
        if predicates:
            rls_predicates[qualified_table] = [
                parsed_statement.parse_predicate(pred) for pred in predicates if pred
            ]

        # Collect CLS rules.
        table_cls = get_cls_rules_for_table(qualified_table, database, rules)
        if table_cls:
            cls_rules[qualified_table] = table_cls

    # Apply CLS first (before RLS) so that hidden columns are removed
    # before RLS wraps the query in a subquery.
    if cls_rules:
        table_schemas = _build_sqlglot_schema(database, list(cls_rules))
        parsed_statement.apply_cls(
            cls_rules,
            schema=table_schemas if table_schemas else None,
        )

    # Apply RLS after CLS - RLS wraps the query in a subquery with SELECT *
    # which will pick up the already-transformed columns from CLS.
    if rls_predicates:
        parsed_statement.apply_rls(catalog, schema, rls_predicates, method)
def get_allowed_tables(
    database_name: str,
    schema: str | None = None,
    catalog: str | None = None,
) -> tuple[set[str], bool]:
    """
    Get all table names that the current user has access to via Data Access Rules
    for a specific database and schema.

    Args:
        database_name: The database name to check
        schema: Optional schema name to filter by
        catalog: Optional catalog name to filter by

    Returns:
        Tuple of ``(table_names, schema_level_access)``. When
        ``schema_level_access`` is True the set may be empty but all tables in
        the schema are allowed.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set(), False
    rules = get_user_rules()
    if not rules:
        return set(), False
    table_names: set[str] = set()
    schema_level_access = False
    for rule in rules:
        rule_dict = rule.rule_dict
        # Collect tables from allowed entries
        for entry in rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue
            # Catalog comparison only happens when BOTH sides specify one; an
            # entry without a catalog matches any requested catalog.
            entry_catalog = entry.get("catalog")
            if catalog is not None and entry_catalog is not None:
                if entry_catalog != catalog:
                    continue
            # Same "both sides present" semantics for the schema filter.
            entry_schema = entry.get("schema")
            if schema is not None and entry_schema is not None:
                if entry_schema != schema:
                    continue
            # An entry naming a table grants exactly that table.
            if table := entry.get("table"):
                table_names.add(table)
            elif entry_schema == schema or (entry_schema is None and schema is None):
                # Schema-level or database-level access without table means all tables
                # NOTE(review): a database-level entry (entry_schema is None)
                # does NOT set schema_level_access when a specific ``schema``
                # argument is passed — confirm this is intentional.
                schema_level_access = True
    return table_names, schema_level_access
def get_allowed_schemas(database_name: str, catalog: str | None = None) -> set[str]:
    """
    Get all schema names that the current user has access to via Data Access Rules
    for a specific database.

    Args:
        database_name: The database name to check
        catalog: Optional catalog name to filter by

    Returns:
        Set of schema names the user has access to. The special marker ``"*"``
        is included when an entry grants database-level access (no schema),
        which callers must interpret as "all schemas".
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()
    rules = get_user_rules()
    if not rules:
        return set()
    schema_names: set[str] = set()
    for rule in rules:
        rule_dict = rule.rule_dict
        # Collect schemas from allowed entries.
        for entry in rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue
            # Catalog comparison only happens when BOTH sides specify one; an
            # entry without a catalog matches any requested catalog.
            entry_catalog = entry.get("catalog")
            if catalog is not None and entry_catalog is not None:
                if entry_catalog != catalog:
                    continue
            if schema := entry.get("schema"):
                schema_names.add(schema)
            else:
                # Database-level access without schema means all schemas; "*"
                # is a special marker the caller checks for.
                # (The previous ``elif entry.get("database") == database_name``
                # was always true here, since non-matching databases were
                # skipped above.)
                schema_names.add("*")
    return schema_names
def get_allowed_databases() -> set[str]:
    """
    Get all database names the current user can access via Data Access Rules.

    Used to populate database selectors in SQL Lab and elsewhere.

    Returns:
        Set of database names the user has access to; empty when the feature
        is disabled or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    rules = get_user_rules()
    if not rules:
        return set()

    # Every allowed entry names a database; collect the distinct ones.
    return {
        database
        for rule in rules
        for entry in rule.rule_dict.get("allowed", [])
        if (database := entry.get("database"))
    }
@dataclass
class AllowedTable:
    """A table allowed by DAR (Data Access Rules), with its database context."""

    database: str  # name of the database the table belongs to
    table: str  # table name
    schema: str | None = None  # schema name, if the entry specified one
    catalog: str | None = None  # catalog name, if the entry specified one
@dataclass
class AllowedEntry:
    """
    An allowed entry from DAR at any level of the hierarchy.

    Fields may be None to indicate "all" at that level:
    - database only: all catalogs/schemas/tables in that database
    - database + catalog: all schemas/tables in that catalog
    - database + schema: all tables in that schema (for DBs without catalogs)
    - database + catalog + schema: all tables in that schema
    - database + schema + table: specific table (for DBs without catalogs)
    - database + catalog + schema + table: specific table
    """

    database: str  # always present; entries without a database are skipped
    catalog: str | None = None
    schema: str | None = None
    table: str | None = None
def get_all_allowed_entries() -> list[AllowedEntry]:
    """
    Get every access entry the current user has via Data Access Rules, across
    all databases.

    Entries are returned at all hierarchy levels (database, schema, table) so
    callers can build whatever filters their use case needs.

    Returns:
        List of AllowedEntry objects representing allowed access; empty when
        the feature is disabled or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    rules = get_user_rules()
    if not rules:
        return []

    # Entries without a database name are malformed and dropped.
    return [
        AllowedEntry(
            database=entry["database"],
            catalog=entry.get("catalog"),
            schema=entry.get("schema"),
            table=entry.get("table"),
        )
        for rule in rules
        for entry in rule.rule_dict.get("allowed", [])
        if entry.get("database")
    ]
def get_all_allowed_tables() -> list[AllowedTable]:
    """
    Get every specific table the current user can access via Data Access
    Rules, across all databases.

    Used for dataset filtering, where all individually-named tables are
    needed. Database- and schema-level entries are skipped, since their tables
    cannot be enumerated without querying the database.

    Returns:
        List of AllowedTable objects representing allowed tables; empty when
        the feature is disabled or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    rules = get_user_rules()
    if not rules:
        return []

    # Keep only entries that name both a database and a concrete table.
    return [
        AllowedTable(
            database=entry["database"],
            table=entry["table"],
            schema=entry.get("schema"),
            catalog=entry.get("catalog"),
        )
        for rule in rules
        for entry in rule.rule_dict.get("allowed", [])
        if entry.get("database") and entry.get("table")
    ]
def get_all_group_keys(
    database_name: str | None = None,
    table: Table | None = None,
) -> set[str]:
    """
    Collect every distinct ``group_key`` referenced by RLS configs in DAR
    rules.

    Exposed for UI discoverability: showing users the group_keys that already
    exist lets them reuse them for consistent rule grouping.

    Args:
        database_name: Optional filter by database
        table: Optional Table object to filter by catalog/schema/table

    Returns:
        Set of unique group_key values.
    """
    # Lazy import to avoid a circular dependency.
    from superset.data_access_rules.models import DataAccessRule

    group_keys: set[str] = set()
    for rule in db.session.query(DataAccessRule).all():
        for entry in rule.rule_dict.get("allowed", []):
            # Skip entries that fail the optional database filter.
            if database_name and entry.get("database") != database_name:
                continue
            # Skip entries that fail any populated part of the table filter.
            if table is not None:
                mismatch = (
                    (table.catalog and entry.get("catalog") != table.catalog)
                    or (table.schema and entry.get("schema") != table.schema)
                    or (table.table and entry.get("table") != table.table)
                )
                if mismatch:
                    continue
            if group_key := entry.get("rls", {}).get("group_key"):
                group_keys.add(group_key)
    return group_keys

View File

@@ -717,16 +717,37 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
if not database:
return self.response_404()
try:
params = kwargs["rison"]
catalogs = database.get_all_catalog_names(
cache=database.catalog_cache_enabled,
cache_timeout=database.catalog_cache_timeout or None,
force=kwargs["rison"].get("force", False),
force=params.get("force", False),
)
catalogs = security_manager.get_catalogs_accessible_by_user(
database,
catalogs,
)
return self.response(200, result=list(catalogs))
# Convert to list and sort
catalogs = sorted(catalogs)
# Apply filter if provided
filter_str = params.get("filter", "").lower()
if filter_str:
catalogs = [c for c in catalogs if filter_str in c.lower()]
# Get total count before pagination
total_count = len(catalogs)
# Apply pagination if provided
page = params.get("page")
page_size = params.get("page_size")
if page is not None and page_size is not None:
start = page * page_size
end = start + page_size
catalogs = catalogs[start:end]
return self.response(200, result=catalogs, count=total_count)
except OperationalError:
return self.response(
500,
@@ -797,21 +818,38 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
)
if params.get("upload_allowed"):
if not database.allow_file_upload:
return self.response(200, result=[])
return self.response(200, result=[], count=0)
if allowed_schemas := database.get_schema_access_for_file_upload():
# some databases might return the list of schemas in uppercase,
# while the list of allowed schemas is manually inputted so
# could be lowercase
allowed_schemas = {schema.lower() for schema in allowed_schemas}
return self.response(
200,
result=[
schema
for schema in schemas
if schema.lower() in allowed_schemas
],
)
return self.response(200, result=list(schemas))
schemas = [
schema
for schema in schemas
if schema.lower() in allowed_schemas
]
# Convert to list and sort
schemas = sorted(schemas)
# Apply filter if provided
filter_str = params.get("filter", "").lower()
if filter_str:
schemas = [s for s in schemas if filter_str in s.lower()]
# Get total count before pagination
total_count = len(schemas)
# Apply pagination if provided
page = params.get("page")
page_size = params.get("page_size")
if page is not None and page_size is not None:
start = page * page_size
end = start + page_size
schemas = schemas[start:end]
return self.response(200, result=schemas, count=total_count)
except OperationalError:
return self.response(
500, message="There was an error connecting to the database"
@@ -874,11 +912,23 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
force = kwargs["rison"].get("force", False)
catalog_name = kwargs["rison"].get("catalog_name")
schema_name = kwargs["rison"].get("schema_name", "")
params = kwargs["rison"]
force = params.get("force", False)
catalog_name = params.get("catalog_name")
schema_name = params.get("schema_name", "")
filter_str = params.get("filter")
page = params.get("page")
page_size = params.get("page_size")
command = TablesDatabaseCommand(pk, catalog_name, schema_name, force)
command = TablesDatabaseCommand(
pk,
catalog_name,
schema_name,
force,
filter_str=filter_str,
page=page,
page_size=page_size,
)
payload = command.run()
return self.response(200, **payload)

View File

@@ -23,11 +23,22 @@ from sqlalchemy.orm import Query
from sqlalchemy.sql.expression import cast
from sqlalchemy.sql.sqltypes import JSON
from superset import security_manager
from superset import is_feature_enabled, security_manager
from superset.models.core import Database
from superset.views.base import BaseFilter
def get_dar_allowed_databases() -> set[str]:
"""Get databases allowed by Data Access Rules for the current user."""
if not is_feature_enabled("DATA_ACCESS_RULES"):
return set()
# Lazy import to avoid circular dependency
from superset.data_access_rules.utils import get_allowed_databases
return get_allowed_databases()
def can_access_databases(view_menu_name: str) -> set[str]:
"""
Return names of databases available in `view_menu_name`.
@@ -62,10 +73,15 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods
catalog_access_databases = can_access_databases("catalog_access")
schema_access_databases = can_access_databases("schema_access")
datasource_access_databases = can_access_databases("datasource_access")
# Include databases from Data Access Rules
dar_databases = get_dar_allowed_databases()
database_names = sorted(
catalog_access_databases
| schema_access_databases
| datasource_access_databases
| dar_databases
)
return query.filter(

View File

@@ -66,12 +66,20 @@ database_schemas_query_schema = {
"force": {"type": "boolean"},
"upload_allowed": {"type": "boolean"},
"catalog": {"type": "string"},
"filter": {"type": "string"},
"page": {"type": "integer", "minimum": 0},
"page_size": {"type": "integer", "minimum": 1, "maximum": 1000},
},
}
database_catalogs_query_schema = {
"type": "object",
"properties": {"force": {"type": "boolean"}},
"properties": {
"force": {"type": "boolean"},
"filter": {"type": "string"},
"page": {"type": "integer", "minimum": 0},
"page_size": {"type": "integer", "minimum": 1, "maximum": 1000},
},
}
database_tables_query_schema = {
@@ -80,6 +88,9 @@ database_tables_query_schema = {
"force": {"type": "boolean"},
"schema_name": {"type": "string"},
"catalog_name": {"type": "string"},
"filter": {"type": "string"},
"page": {"type": "integer", "minimum": 0},
"page_size": {"type": "integer", "minimum": 1, "maximum": 1000},
},
"required": ["schema_name"],
}

View File

@@ -73,6 +73,12 @@ def get_table_metadata(database: Any, table: Table) -> TableMetadataResponse:
"""
keys = []
columns = database.get_columns(table)
# Filter out columns hidden by CLS rules (lazy import to avoid circular dependency)
from superset.data_access_rules.utils import filter_columns_by_cls
columns = filter_columns_by_cls(columns, table, database)
primary_key = database.get_pk_constraint(table)
if primary_key and primary_key.get("constrained_columns"):
primary_key["column_names"] = primary_key.pop("constrained_columns")

View File

@@ -161,6 +161,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
from superset.dashboards.api import DashboardRestApi
from superset.dashboards.filter_state.api import DashboardFilterStateRestApi
from superset.dashboards.permalink.api import DashboardPermalinkRestApi
from superset.data_access_rules.api import DataAccessRulesRestApi
from superset.databases.api import DatabaseRestApi
from superset.datasets.api import DatasetRestApi
from superset.datasets.columns.api import DatasetColumnsRestApi
@@ -213,6 +214,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
TabStateView,
)
from superset.views.sqla import (
DataAccessRulesView,
RowLevelSecurityView,
TableModelView,
)
@@ -264,6 +266,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
appbuilder.add_api(ReportScheduleRestApi)
appbuilder.add_api(ReportExecutionLogRestApi)
appbuilder.add_api(RLSRestApi)
appbuilder.add_api(DataAccessRulesRestApi)
appbuilder.add_api(SavedQueryRestApi)
appbuilder.add_api(TagRestApi)
appbuilder.add_api(SqlLabRestApi)
@@ -518,6 +521,17 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
icon="fa-lock",
)
if feature_flag_manager.is_feature_enabled("DATA_ACCESS_RULES"):
appbuilder.add_view(
DataAccessRulesView,
"Data Access Rules",
href="DataAccessRulesView.list",
label=_("Data Access Rules"),
category="Security",
category_label=_("Security"),
icon="fa-shield",
)
def init_core_dependencies(self) -> None:
"""Initialize core dependency injection for direct import patterns."""
from superset.core.api.core_api_injection import (

View File

@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_data_access_rules_table
Revision ID: a352d7609189
Revises: a9c01ec10479
Create Date: 2025-12-17 10:00:00.000000
"""
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
from superset.migrations.shared.utils import (
create_fks_for_table,
create_table,
drop_table,
)
# revision identifiers, used by Alembic.
revision = "a352d7609189"
down_revision = "a9c01ec10479"
def upgrade():
    """Create the ``data_access_rules`` table and its foreign keys."""
    create_table(
        "data_access_rules",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("role_id", sa.Integer(), nullable=False),
        sa.Column(
            "rule",
            # MEDIUMTEXT on MySQL so large JSON rules fit.
            sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"),
            nullable=False,
        ),
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )

    # Role FK cascades so deleting a role removes its rules.
    create_fks_for_table(
        "fk_data_access_rules_role_id_ab_role",
        "data_access_rules",
        "ab_role",
        ["role_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # Audit-column FKs to the user table (no cascade).
    for fk_name, fk_column in (
        ("fk_data_access_rules_created_by_fk_ab_user", "created_by_fk"),
        ("fk_data_access_rules_changed_by_fk_ab_user", "changed_by_fk"),
    ):
        create_fks_for_table(
            fk_name,
            "data_access_rules",
            "ab_user",
            [fk_column],
            ["id"],
        )
def downgrade():
    # Dropping the table also removes its foreign key constraints.
    drop_table("data_access_rules")

View File

@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add name and description to data_access_rules
Revision ID: b463d8709290
Revises: a352d7609189
Create Date: 2025-12-17 12:00:00.000000
"""
import sqlalchemy as sa
from superset.migrations.shared.utils import add_columns, drop_columns
# revision identifiers, used by Alembic.
revision = "b463d8709290"
down_revision = "a352d7609189"
def upgrade():
    # Both columns are nullable, so existing rows need no backfill.
    add_columns(
        "data_access_rules",
        sa.Column("name", sa.String(250), nullable=True),
        sa.Column("description", sa.Text(), nullable=True),
    )
def downgrade():
    # Remove the columns added by this revision's upgrade().
    drop_columns("data_access_rules", "name", "description")

View File

@@ -120,6 +120,7 @@ from superset.utils.core import (
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
from superset.utils.dates import datetime_to_epoch
from superset.utils.rls import apply_rls
from superset.data_access_rules.utils import apply_data_access_rules
class ValidationResultDict(TypedDict):
@@ -1049,6 +1050,22 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
)
sql = self._apply_cte(sql, sqlaq.cte)
# Apply Data Access Rules (RLS and CLS) if enabled
if is_feature_enabled("DATA_ACCESS_RULES"):
try:
default_schema = self.database.get_default_schema(self.catalog)
parsed_script = SQLScript(sql, engine=self.db_engine_spec.engine)
for statement in parsed_script.statements:
apply_data_access_rules(
self.database,
self.catalog,
self.schema or default_schema or "",
statement,
)
sql = parsed_script.format()
except Exception as ex:
logger.warning("Failed to apply Data Access Rules: %s", ex)
if mutate:
sql = self.database.mutate_sql_based_on_config(sql)
return QueryStringExtended(
@@ -2051,6 +2068,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
# Log the error but don't fail - RLS application is best-effort
logger.warning("Failed to apply RLS to virtual dataset SQL: %s", ex)
# Apply Data Access Rules to virtual dataset SQL
if is_feature_enabled("DATA_ACCESS_RULES") and parsed_script.statements:
default_schema = self.database.get_default_schema(self.catalog)
try:
for statement in parsed_script.statements:
apply_data_access_rules(
self.database,
self.catalog,
self.schema or default_schema or "",
statement,
)
from_sql = parsed_script.format()
except Exception as ex:
logger.warning(
"Failed to apply Data Access Rules to virtual dataset SQL: %s", ex
)
cte = self.db_engine_spec.get_cte_query(from_sql)
from_clause = (
sa.table(self.db_engine_spec.cte_alias)

View File

@@ -976,6 +976,19 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
}
)
# Data Access Rules
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled
if is_feature_enabled("DATA_ACCESS_RULES"):
from superset.data_access_rules.utils import get_allowed_schemas
dar_schemas = get_allowed_schemas(database.database_name, catalog)
if "*" in dar_schemas:
# Database-level access means all schemas
return schemas
accessible_schemas.update(dar_schemas)
return schemas & accessible_schemas
def get_catalogs_accessible_by_user(
@@ -1091,6 +1104,24 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
}
# Check Data Access Rules
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled
if is_feature_enabled("DATA_ACCESS_RULES"):
from superset.data_access_rules.utils import get_allowed_tables
dar_tables, schema_level_access = get_allowed_tables(
database.database_name, schema, catalog
)
if schema_level_access:
# Schema-level access means all tables in the schema
return datasource_names
# Add DAR tables to accessible datasources
for table_name in dar_tables:
user_datasources.add(DatasourceName(table_name, schema, catalog))
return [
datasource
for datasource in datasource_names
@@ -2410,6 +2441,20 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# access to any datasource is sufficient
break
else:
# Check Data Access Rules before denying
if is_feature_enabled("DATA_ACCESS_RULES"):
from superset.data_access_rules.utils import (
AccessCheckResult,
check_table_access,
)
access_info = check_table_access(
database_name=database.database_name,
table=table_,
)
if access_info.access == AccessCheckResult.ALLOWED:
continue
denied.add(table_)
if denied:
@@ -2444,10 +2489,33 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
assert datasource
# Check DAR access for datasource
dar_allowed = False
if is_feature_enabled("DATA_ACCESS_RULES") and hasattr(
datasource, "database"
):
from superset.data_access_rules.utils import (
AccessCheckResult,
check_table_access,
)
from superset.sql.parse import Table
dar_table = Table(
table=datasource.table_name,
schema=datasource.schema,
catalog=getattr(datasource, "catalog", None),
)
dar_access_info = check_table_access(
database_name=datasource.database.database_name,
table=dar_table,
)
dar_allowed = dar_access_info.access == AccessCheckResult.ALLOWED
if not (
self.can_access_schema(datasource)
or self.can_access("datasource_access", datasource.perm or "")
or self.is_owner(datasource)
or dar_allowed
or (
# Grant access to the datasource only if dashboard RBAC is enabled
# or the user is an embedded guest user with access to the dashboard

View File

@@ -129,6 +129,145 @@ class LimitMethod(enum.Enum):
FETCH_MANY = enum.auto()
class CLSAction(enum.Enum):
    """
    Column-Level Security actions.

    These actions determine how sensitive columns are transformed in queries.
    Member values come from ``enum.auto()`` and carry no meaning of their own;
    strictness ordering is defined separately (see the precedence mapping in
    this module — TODO confirm it stays in sync with these members).
    """

    HASH = enum.auto()  # Pseudonymization via hashing
    NULLIFY = enum.auto()  # Replace with NULL
    HIDE = enum.auto()  # Remove from results entirely
    MASK = enum.auto()  # Replace with '****'
@dataclass(eq=True, frozen=True)
class Table:
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    Instances are frozen and used as dictionary keys elsewhere in this module.
    NOTE(review): the manually defined ``__eq__`` below takes precedence over
    the dataclass-generated one; confirm that hashing (field-based) and this
    string-based equality stay consistent for all inputs.
    """

    table: str
    schema: str | None = None  # None means the schema is unqualified
    catalog: str | None = None  # None means the catalog is unqualified

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Should not be used for SQL generation, only for logging and debugging,
        since the quoting is not engine-specific. Each part is percent-encoded
        (including literal dots, as ``%2E``) so the joining dots are
        unambiguous separators.
        """
        return ".".join(
            urllib.parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )

    def __eq__(self, other: Any) -> bool:
        # Compare by string form so a Table can be compared against plain
        # strings and against Tables constructed with equivalent parts.
        return str(self) == str(other)

    def qualify(
        self,
        *,
        catalog: str | None = None,
        schema: str | None = None,
    ) -> Table:
        """
        Return a new Table with the given schema and/or catalog, if not
        already set. Existing parts always win over the supplied defaults.
        """
        return Table(
            table=self.table,
            schema=self.schema or schema,
            catalog=self.catalog or catalog,
        )
# Type alias for CLS rules: {Table: {column_name: action}}
CLSRules = dict[Table, dict[str, CLSAction]]
# CLS action precedence: higher value = stricter (less information revealed)
# HIDE > NULLIFY > MASK > HASH
CLS_ACTION_PRECEDENCE: dict[CLSAction, int] = {
CLSAction.HASH: 1,
CLSAction.MASK: 2,
CLSAction.NULLIFY: 3,
CLSAction.HIDE: 4,
}
def merge_cls_rules(*rules_list: CLSRules) -> CLSRules:
    """
    Merge multiple CLS rule sets into one, keeping the stricter action when
    the same table/column appears more than once.

    Precedence (strictest to least strict): HIDE > NULLIFY > MASK > HASH.

    Args:
        *rules_list: Variable number of CLSRules dicts to merge

    Returns:
        A merged CLSRules dict with the strictest action for each table/column

    Example: merging ``{t: {"c": HASH}}`` with ``{t: {"c": NULLIFY}}`` yields
    ``{t: {"c": NULLIFY}}``, since NULLIFY is stricter than HASH.
    """
    merged: CLSRules = {}
    for ruleset in rules_list:
        for table, columns in ruleset.items():
            target = merged.setdefault(table, {})
            for column, action in columns.items():
                current = target.get(column)
                # Adopt the new action when the column is unseen or the new
                # action has higher (stricter) precedence.
                if current is None or (
                    CLS_ACTION_PRECEDENCE[action] > CLS_ACTION_PRECEDENCE[current]
                ):
                    target[column] = action
    return merged
# Hash function patterns by dialect. The placeholder {} will be replaced with the
# column. Some databases need casting for non-string types, so we cast to string/text.
# The fallback uses a literal since there's no universal hash function across all
# databases.
CLS_HASH_FUNCTIONS: dict[Dialects | type[Dialect] | None, str] = {
None: "'[HASHED]'", # Universal fallback - no hash function available
Dialects.DIALECT: "MD5(CAST({} AS VARCHAR))", # Generic SQL with MD5
Dialects.POSTGRES: "MD5(CAST({} AS TEXT))",
Dialects.MYSQL: "MD5(CAST({} AS CHAR))",
Dialects.BIGQUERY: "TO_HEX(MD5(CAST({} AS STRING)))",
Dialects.SNOWFLAKE: "MD5(TO_VARCHAR({}))",
Dialects.REDSHIFT: "MD5(CAST({} AS VARCHAR))",
Dialects.PRESTO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
Dialects.TRINO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
Dialects.SQLITE: "HEX({})", # SQLite doesn't have MD5, use HEX as placeholder
Dialects.DUCKDB: "MD5(CAST({} AS VARCHAR))",
Dialects.ORACLE: "STANDARD_HASH(TO_CHAR({}), 'MD5')",
Dialects.TSQL: (
"CONVERT(VARCHAR(32), HASHBYTES('MD5', CAST({} AS VARCHAR(MAX))), 2)"
),
Dialects.HIVE: "MD5(CAST({} AS STRING))",
Dialects.SPARK: "MD5(CAST({} AS STRING))",
Dialects.CLICKHOUSE: "MD5(toString({}))",
Dialects.DATABRICKS: "MD5(CAST({} AS STRING))",
Dialects.DORIS: "MD5(CAST({} AS VARCHAR))",
Dialects.STARROCKS: "MD5(CAST({} AS VARCHAR))",
Dialects.DRILL: "MD5(CAST({} AS VARCHAR))",
Dialects.DRUID: "MD5(CAST({} AS VARCHAR))",
Dialects.TERADATA: "HASH_MD5(CAST({} AS VARCHAR(10000)))",
Dialects.RISINGWAVE: "MD5(CAST({} AS VARCHAR))",
}
class CTASMethod(enum.Enum):
    # How CREATE TABLE AS SELECT materializes results: a table or a view.
    TABLE = enum.auto()
    VIEW = enum.auto()
@@ -279,47 +418,530 @@ class RLSAsSubqueryTransformer(RLSTransformer):
return node
@dataclass(eq=True, frozen=True)
class Table:
class CLSTransformer:
"""
A fully qualified SQL table conforming to [[catalog.]schema.]table.
AST transformer to apply Column-Level Security rules.
This transformer modifies SELECT expressions and predicates to apply CLS actions:
- HASH: Replace column with hash function (database-specific)
- NULLIFY: Replace with NULL AS column_name (SELECT) or FALSE (predicates)
- HIDE: Remove column from SELECT entirely, FALSE in predicates
- MASK: Replace column with '****' AS column_name (SELECT) or FALSE (predicates)
Example:
Given rules: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
Query: SELECT id, salary, name FROM my_table WHERE id = 1
Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table
WHERE MD5(CAST(id AS TEXT)) = 1
For predicates, HASH transforms the column to ensure filtered results also respect
the security policy. NULLIFY/MASK/HIDE transform to FALSE to prevent information
leakage through filtering.
"""
table: str
schema: str | None = None
catalog: str | None = None
def __str__(self) -> str:
"""
Return the fully qualified SQL table name.
Should not be used for SQL generation, only for logging and debugging, since the
quoting is not engine-specific.
"""
return ".".join(
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
def __eq__(self, other: Any) -> bool:
return str(self) == str(other)
def qualify(
def __init__(
self,
*,
catalog: str | None = None,
rules: CLSRules,
dialect: Dialects | type[Dialect] | None,
) -> None:
self.rules = self._normalize_rules(rules)
self.dialect = dialect
self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None])
def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]:
"""
Normalize table and column names to lowercase for case-insensitive matching.
"""
return {
Table(
table=table.table.lower(),
schema=table.schema.lower() if table.schema else None,
catalog=table.catalog.lower() if table.catalog else None,
): {col.lower(): action for col, action in cols.items()}
for table, cols in rules.items()
}
def _get_action(
self,
table_name: str | None,
column_name: str,
schema: str | None = None,
) -> Table:
catalog: str | None = None,
) -> CLSAction | None:
"""
Return a new Table with the given schema and/or catalog, if not already set.
Get the CLS action for a column, if any.
Matching logic:
1. First try exact match with schema/catalog if provided
2. Fallback to table name match - if table names match, apply the rule
regardless of schema/catalog (since query may not have schema info)
"""
return Table(
table=self.table,
schema=self.schema or schema,
catalog=self.catalog or catalog,
if not table_name:
return None
# Create a normalized Table for lookup
lookup_table = Table(
table=table_name.lower(),
schema=schema.lower() if schema else None,
catalog=catalog.lower() if catalog else None,
)
# First try exact match with schema/catalog
table_rules = self.rules.get(lookup_table)
if table_rules:
return table_rules.get(column_name.lower())
# Fallback: match by table name only
# This handles cases where the rule has schema/catalog but the query doesn't
for rule_table, cols in self.rules.items():
if rule_table.table == lookup_table.table:
action = cols.get(column_name.lower())
if action:
return action
return None
def _create_hash_expression(
self,
column: exp.Column,
alias: str,
) -> exp.Expression:
"""
Create a hash expression for a column.
"""
# Generate the column SQL without any alias
col_sql = column.sql(dialect=self.dialect)
hash_sql = self.hash_pattern.format(col_sql)
hash_expr = sqlglot.parse_one(hash_sql, dialect=self.dialect)
return exp.Alias(
this=hash_expr,
alias=exp.Identifier(this=alias),
)
def _create_null_expression(self, alias: str) -> exp.Expression:
"""
Create a NULL AS alias expression.
"""
return exp.Alias(
this=exp.Null(),
alias=exp.Identifier(this=alias),
)
def _create_mask_expression(
self,
column: exp.Column,
alias: str,
) -> exp.Expression:
"""
Create a CASE expression that masks non-NULL values while preserving NULLs.
Generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END AS alias
This preserves the semantic meaning of NULL (no value) vs masked (hidden value).
"""
return exp.Alias(
this=exp.Case(
ifs=[
exp.If(
this=exp.Is(this=column.copy(), expression=exp.Null()),
true=exp.Null(),
)
],
default=exp.Literal(this="****", is_string=True),
),
alias=exp.Identifier(this=alias),
)
def _create_hash_expression_no_alias(
self,
column: exp.Column,
) -> exp.Expression:
"""
Create a hash expression for a column without an alias.
Used for transforming columns in predicates (WHERE, ON, etc.).
"""
col_sql = column.sql(dialect=self.dialect)
hash_sql = self.hash_pattern.format(col_sql)
return sqlglot.parse_one(hash_sql, dialect=self.dialect)
def _get_column_alias(self, expr: exp.Expression) -> str:
"""
Get the alias for a column expression.
"""
if isinstance(expr, exp.Alias):
return expr.alias
if isinstance(expr, exp.Column):
return expr.name
return expr.sql(dialect=self.dialect)
def _get_table_for_column(
self,
column: exp.Column,
scope_tables: dict[str, str],
) -> str | None:
"""
Resolve which table a column belongs to.
Args:
column: The column expression
scope_tables: Map of alias/name to actual table name
Returns:
The table name or None if cannot be resolved
"""
if column.table:
# Column is qualified with table name/alias
return scope_tables.get(column.table.lower(), column.table)
# For unqualified columns, if there's only one table in scope,
# we can infer the column belongs to that table
if len(scope_tables) == 1:
return next(iter(scope_tables.values()))
# With multiple tables, check if any table in rules has this column
# This is a best-effort match for unqualified columns
col_lower = column.name.lower()
for table_name in scope_tables.values():
# Look for a rule matching this table
for rule_table, cols in self.rules.items():
if rule_table.table == table_name.lower() and col_lower in cols:
return table_name
return None
def _extract_scope_tables(self, select: exp.Select) -> dict[str, str]:
"""
Extract table names and aliases from a SELECT statement's FROM clause.
Returns a dict mapping alias (or table name if no alias) to actual table name.
"""
tables: dict[str, str] = {}
if from_clause := select.args.get("from"):
for table in from_clause.find_all(exp.Table):
table_name = table.name
alias = table.alias if table.alias else table_name
tables[alias.lower()] = table_name
for join in select.args.get("joins") or []:
for table in join.find_all(exp.Table):
table_name = table.name
alias = table.alias if table.alias else table_name
tables[alias.lower()] = table_name
return tables
def _transform_nested_column(
self,
column: exp.Column,
scope_tables: dict[str, str],
) -> exp.Expression:
"""
Transform a nested column reference within a SELECT expression.
This handles columns inside CASE expressions, function arguments, etc.
Unlike top-level columns, nested columns use NULL for blocking instead
of FALSE (which works better in non-predicate contexts).
- HASH: Replace with hash function
- NULLIFY/MASK/HIDE: Replace with NULL (blocks computation safely)
"""
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return column
if action == CLSAction.HASH:
return self._create_hash_expression_no_alias(column)
# NULLIFY/MASK/HIDE: Return NULL to safely block any computation
# NULL propagates through expressions: UPPER(NULL)→NULL, 1+NULL→NULL, etc.
return exp.Null()
def _transform_expression(
self,
expr: exp.Expression,
scope_tables: dict[str, str],
) -> exp.Expression | None:
"""
Transform a single SELECT expression based on CLS rules.
For simple column references: apply full transformation with alias.
For complex expressions: transform all nested column references.
Returns:
- Transformed expression
- None if a top-level column should be hidden
"""
# Get the underlying column (handle aliases)
column = expr.this if isinstance(expr, exp.Alias) else expr
alias = self._get_column_alias(expr)
if isinstance(column, exp.Column):
# Simple column reference - apply full transformation with alias
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return expr
if action == CLSAction.HIDE:
return None
if action == CLSAction.HASH:
return self._create_hash_expression(column, alias)
if action == CLSAction.NULLIFY:
return self._create_null_expression(alias)
# action == CLSAction.MASK
return self._create_mask_expression(column, alias)
# Complex expression (CASE, function, arithmetic, etc.)
# Transform ALL nested column references within it
def transform_nested(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Column):
return self._transform_nested_column(node, scope_tables)
return node
return expr.transform(transform_nested)
def _transform_star(
self,
star: exp.Star,
scope_tables: dict[str, str],
) -> list[exp.Expression]:
"""
Transform SELECT * by expanding hidden columns conceptually.
Since we don't have schema information, we cannot truly expand *.
We return the star as-is but log a warning.
"""
# Without schema information, we cannot expand SELECT *
# In a real implementation, you would need to query the database schema
logger.warning(
"CLS cannot fully process SELECT * without schema information. "
"Consider using explicit column lists for queries with CLS rules."
)
return [star]
def _transform_non_select_column(
self,
column: exp.Column,
scope_tables: dict[str, str],
) -> exp.Expression:
"""
Transform a column reference outside of SELECT list.
This is the SINGLE transformation function for ALL column references
outside the SELECT list (WHERE, HAVING, ON, GROUP BY, ORDER BY,
window functions, CASE expressions, function arguments, etc.)
- HASH: Replace with hash function
- NULLIFY/MASK/HIDE: Replace with FALSE (blocks predicates, marked for
removal in GROUP BY/ORDER BY)
"""
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return column
if action == CLSAction.HASH:
return self._create_hash_expression_no_alias(column)
# NULLIFY/MASK/HIDE: Return FALSE to block usage
# For predicates: FALSE blocks the filter
# For GROUP BY/ORDER BY: Will be cleaned up in post-processing
return exp.false()
@staticmethod
def _is_blocked(node: exp.Expression) -> bool:
"""Check if an expression is a blocked column (FALSE or NULL sentinel)."""
# FALSE is used for blocked columns in predicates (Phase 2)
# NULL is used for blocked columns in nested expressions (Phase 1)
if isinstance(node, exp.Boolean) and not node.this:
return True
if isinstance(node, exp.Null):
return True
return False
    def _transform_all_non_select_columns(
        self,
        select: exp.Select,
        scope_tables: dict[str, str],
    ) -> None:
        """
        Transform ALL column references outside the SELECT list, in place.

        This walks the full expression tree so that no reference escapes the
        CLS policy, covering:

        - WHERE clauses
        - HAVING clauses
        - JOIN ON conditions
        - GROUP BY clauses
        - ORDER BY clauses
        - window function PARTITION BY / ORDER BY
        - CASE expressions, function arguments, any other nesting

        Blocked columns become FALSE (see ``_transform_non_select_column``);
        those sentinels are then pruned from GROUP BY, ORDER BY and window
        clauses, where a constant is useless (and may be invalid SQL).

        This is the security-critical function that ensures NO column
        reference is missed.

        :param select: the SELECT statement to rewrite (mutated in place)
        :param scope_tables: alias/name -> table map from the FROM/JOIN clauses
        """
        def transform_column(node: exp.Expression) -> exp.Expression:
            # Leaf rewrite applied to every node; only Column nodes change.
            if isinstance(node, exp.Column):
                return self._transform_non_select_column(node, scope_tables)
            return node
        # Transform WHERE
        if where := select.args.get("where"):
            transformed = where.this.transform(transform_column)
            select.set("where", exp.Where(this=transformed))
        # Transform HAVING
        if having := select.args.get("having"):
            transformed = having.this.transform(transform_column)
            select.set("having", exp.Having(this=transformed))
        # Transform all JOINs (ON conditions)
        for join in select.args.get("joins") or []:
            if on := join.args.get("on"):
                transformed = on.transform(transform_column)
                join.set("on", transformed)
        # Transform GROUP BY and remove blocked (FALSE) expressions
        if group := select.args.get("group"):
            new_exprs = []
            for expr in group.expressions:
                transformed = expr.transform(transform_column)
                if not self._is_blocked(transformed):
                    new_exprs.append(transformed)
            if new_exprs:
                group.set("expressions", new_exprs)
            else:
                # A GROUP BY with no surviving expressions must be dropped.
                select.set("group", None)
        # Transform ORDER BY and remove blocked (FALSE) expressions
        if order := select.args.get("order"):
            new_exprs = []
            for ordered in order.expressions:
                transformed = ordered.transform(transform_column)
                # Check the inner expression (Ordered wraps the actual expr)
                inner = transformed.this if isinstance(transformed, exp.Ordered) else transformed
                if not self._is_blocked(inner):
                    new_exprs.append(transformed)
            if new_exprs:
                order.set("expressions", new_exprs)
            else:
                select.set("order", None)
        # Transform Window functions within SELECT expressions
        # Window functions have their own PARTITION BY and ORDER BY clauses
        for expr in select.args.get("expressions", []):
            for window in expr.find_all(exp.Window):
                # Transform PARTITION BY
                if partition_by := window.args.get("partition_by"):
                    new_partition = []
                    for part_expr in partition_by:
                        transformed = part_expr.transform(transform_column)
                        if not self._is_blocked(transformed):
                            new_partition.append(transformed)
                    window.set("partition_by", new_partition if new_partition else None)
                # Transform ORDER BY within window
                if window_order := window.args.get("order"):
                    new_order_exprs = []
                    for ordered in window_order.expressions:
                        transformed = ordered.transform(transform_column)
                        inner = (
                            transformed.this
                            if isinstance(transformed, exp.Ordered)
                            else transformed
                        )
                        if not self._is_blocked(inner):
                            new_order_exprs.append(transformed)
                    if new_order_exprs:
                        window_order.set("expressions", new_order_exprs)
                    else:
                        window.set("order", None)
def transform_select(self, select: exp.Select) -> exp.Select:
"""
Transform a SELECT statement by applying CLS rules.
This is the main entry point for CLS transformation. It:
1. Extracts table scope for column resolution
2. Transforms SELECT list expressions (with HIDE removal and aliases)
3. Transforms ALL other column references in the query
"""
scope_tables = self._extract_scope_tables(select)
# Phase 1: Transform SELECT list expressions
# This handles HASH/NULLIFY/MASK with aliases, and removes HIDE columns
expressions = select.args.get("expressions", [])
new_expressions: list[exp.Expression] = []
for expr in expressions:
if isinstance(expr, exp.Star):
new_expressions.extend(self._transform_star(expr, scope_tables))
else:
transformed = self._transform_expression(expr, scope_tables)
if transformed is not None:
new_expressions.append(transformed)
select.set("expressions", new_expressions)
# Phase 2: Transform ALL other column references
# This is the security-critical phase that catches every column reference
self._transform_all_non_select_columns(select, scope_tables)
return select
def __call__(self, node: exp.Expression) -> exp.Expression:
"""
Transform callback for sqlglot's transform method.
"""
if isinstance(node, exp.Select):
return self.transform_select(node)
return node
def apply_cls(
    sql: str,
    rules: CLSRules,
    engine: str = "base",
    schema: dict[str, dict[str, str]] | None = None,
) -> str:
    """
    Apply Column-Level Security rules to a SQL query.

    Sensitive columns are rewritten both in SELECT expressions and in
    predicates (WHERE, ON, HAVING):

    - HASH: pseudonymize with the engine-specific hash function
    - NULLIFY: NULL in the SELECT list, FALSE in predicates
    - HIDE: dropped from the SELECT list, FALSE in predicates
    - MASK: '****' in the SELECT list, FALSE in predicates

    :param sql: the SQL query to transform
    :param rules: mapping of Table objects (optionally schema/catalog
        qualified) to per-column actions, e.g.
        ``{Table("my_table"): {"id": CLSAction.HASH}}``
    :param engine: database engine, used to pick the hash dialect
    :param schema: optional column schema, needed to disambiguate JOINs;
        format ``{"table": {"column": "TYPE", ...}, ...}``
    :returns: the transformed SQL query
    """
    if not rules:
        # Nothing to apply — skip the parse/format round-trip entirely.
        return sql

    parsed = SQLStatement(sql, engine)
    parsed.apply_cls(rules, schema=schema)
    return parsed.format(comments=True)
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
@@ -887,6 +1509,46 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
transformer = transformers[method](catalog, schema, predicates)
self._parsed = self._parsed.transform(transformer)
def apply_cls(
self,
rules: CLSRules,
schema: dict[str, dict[str, str]] | None = None,
) -> None:
"""
Apply Column-Level Security rules to the statement inplace.
CLS rules transform sensitive columns in SELECT statements and predicates:
- HASH: Pseudonymize using database-specific hash function (both SELECT and predicates)
- NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering
- HIDE: Remove from SELECT results, FALSE in predicates to block filtering
- MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering
:param rules: CLS rules mapping Table objects to column actions
Example: {Table("my_table"): {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
:param schema: Optional schema for column qualification. Required for JOINs
with ambiguous column names. Format: {"table": {"column": "TYPE", ...}}
"""
if not rules:
return
# Always attempt to qualify columns for better CLS resolution.
# With schema: full qualification of all columns.
# Without schema: qualifies single-table queries, partial for JOINs.
from sqlglot.optimizer.qualify import qualify
# Only expand stars if schema is provided (from DAR with feature flag enabled)
# to avoid potential errors in other contexts
self._parsed = qualify(
self._parsed,
schema=schema,
dialect=self._dialect,
validate_qualify_columns=False,
expand_stars=bool(schema),
)
transformer = CLSTransformer(rules, self._dialect)
self._parsed = self._parsed.transform(transformer)
class KQLSplitState(enum.Enum):
"""

View File

@@ -67,6 +67,7 @@ from superset.utils.core import (
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
from superset.utils.rls import apply_rls
from superset.data_access_rules.utils import apply_data_access_rules
if TYPE_CHECKING:
from superset.models.core import Database
@@ -419,6 +420,13 @@ def execute_sql_statements( # noqa: C901
for statement in parsed_script.statements:
apply_rls(query.database, query.catalog, default_schema, statement)
if is_feature_enabled("DATA_ACCESS_RULES"):
default_schema = query.database.get_default_schema_for_query(query)
for statement in parsed_script.statements:
apply_data_access_rules(
query.database, query.catalog, default_schema, statement
)
if query.select_as_cta:
# CTAS is valid when the last statement is a SELECT, while CVAS is valid when
# there is only a single statement which must be a SELECT.

View File

@@ -17,10 +17,73 @@
from typing import Any
from flask_appbuilder import Model
from sqlalchemy import or_
from sqlalchemy import and_, or_
from sqlalchemy.sql.elements import BooleanClauseList
def get_dar_dataset_filters(base_model: type[Model]) -> list[Any]:
    """
    Get SQLAlchemy filters for DAR-allowed datasets.

    Handles hierarchical permissions:
    - Database-level: allows all tables in the database
    - Catalog-level: allows all tables in the catalog
    - Schema-level: allows all tables in the schema
    - Table-level: allows only the specific table

    Returns an empty list when the feature flag is off, when the user has no
    allowed entries, or when resolution fails — failing closed, so an error
    never grants extra access.
    """
    # pylint: disable=import-outside-toplevel
    # Imported lazily to avoid circular imports at module load time.
    import logging

    from superset import is_feature_enabled
    from superset.connectors.sqla.models import Database

    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    try:
        from superset.data_access_rules.utils import get_all_allowed_entries

        allowed_entries = get_all_allowed_entries()
        if not allowed_entries:
            return []

        # One OR-branch per allowed entry, at its hierarchy level: collect the
        # applicable conditions and combine them with a single and_() instead
        # of repeatedly re-nesting.
        filters = []
        for entry in allowed_entries:
            # The database condition is always required; narrower levels only
            # constrain the filter when present on the entry.
            conditions = [Database.database_name == entry.database]
            if entry.catalog is not None:
                conditions.append(base_model.catalog == entry.catalog)
            if entry.schema is not None:
                conditions.append(base_model.schema == entry.schema)
            if entry.table is not None:
                conditions.append(base_model.table_name == entry.table)
            filters.append(and_(*conditions))
        return filters
    except Exception as ex:  # pylint: disable=broad-except
        # Deliberate best-effort: a failure here must not break dataset
        # listing, so log it and grant no DAR-based access.
        logging.getLogger(__name__).warning(
            "Error getting DAR dataset filters: %s", ex
        )
        return []
def get_dataset_access_filters(
base_model: type[Model],
*args: Any,
@@ -34,10 +97,14 @@ def get_dataset_access_filters(
schema_perms = security_manager.user_view_menu_names("schema_access")
catalog_perms = security_manager.user_view_menu_names("catalog_access")
# Get DAR-based table filters
dar_filters = get_dar_dataset_filters(base_model)
return or_(
Database.id.in_(database_ids),
base_model.perm.in_(perms),
base_model.catalog_perm.in_(catalog_perms),
base_model.schema_perm.in_(schema_perms),
*dar_filters,
*args,
)

View File

@@ -38,6 +38,17 @@ class RowLevelSecurityView(BaseSupersetView):
return super().render_app_template()
class DataAccessRulesView(BaseSupersetView):
    """Serves the single-page frontend for managing Data Access Rules."""

    route_base = "/dataaccessrules"
    class_permission_name = "DataAccessRule"

    @expose("/list/")
    @has_access
    @permission_name("read")
    def list(self) -> FlaskResponse:
        # The SPA handles all rendering; this returns the app template shell.
        return super().render_app_template()
class TableModelView(BaseSupersetView):
class_permission_name = "Dataset"
method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,261 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for Data Access Rules schemas.
"""
import json
import pytest
from marshmallow import ValidationError
from superset.data_access_rules.schemas import (
DataAccessRulePostSchema,
DataAccessRulePutSchema,
)
def test_post_schema_valid_rule():
    """A well-formed rule payload is accepted by the POST schema."""
    payload = {
        "role_id": 1,
        "rule": json.dumps(
            {"allowed": [{"database": "mydb", "schema": "public"}], "denied": []}
        ),
    }
    loaded = DataAccessRulePostSchema().load(payload)
    assert loaded["role_id"] == 1
    assert "allowed" in json.loads(loaded["rule"])
def test_post_schema_complex_rule():
    """A rule combining RLS predicates and CLS actions is accepted."""
    rule = {
        "allowed": [
            {"database": "mydb", "schema": "public"},
            {
                "database": "mydb",
                "schema": "orders",
                "table": "items",
                "rls": {"predicate": "org_id = 123", "group_key": "org"},
            },
            {
                "database": "mydb",
                "schema": "users",
                "table": "info",
                "cls": {"email": "mask", "ssn": "hide", "name": "hash"},
            },
        ],
        "denied": [{"database": "mydb", "schema": "internal"}],
    }
    loaded = DataAccessRulePostSchema().load(
        {"role_id": 1, "rule": json.dumps(rule)}
    )
    assert loaded["role_id"] == 1
def test_post_schema_invalid_json():
    """A rule that is not parseable JSON is rejected."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load({"role_id": 1, "rule": "not valid json"})
    assert "Invalid JSON" in str(err.value)
def test_post_schema_rule_not_object():
    """A JSON array (not an object) at the top level is rejected."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(
            {"role_id": 1, "rule": json.dumps(["not", "an", "object"])}
        )
    assert "must be a JSON object" in str(err.value)
def test_post_schema_allowed_not_list():
    """A non-list 'allowed' value is rejected."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(
            {"role_id": 1, "rule": json.dumps({"allowed": "not a list", "denied": []})}
        )
    assert "'allowed' must be a list" in str(err.value)
def test_post_schema_denied_not_list():
    """A non-list 'denied' value is rejected."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(
            {"role_id": 1, "rule": json.dumps({"allowed": [], "denied": "not a list"})}
        )
    assert "'denied' must be a list" in str(err.value)
def test_post_schema_entry_not_object():
    """An entry that is not a JSON object is rejected."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(
            {"role_id": 1, "rule": json.dumps({"allowed": ["not an object"], "denied": []})}
        )
    assert "must be an object" in str(err.value)
def test_post_schema_entry_missing_database():
    """An entry without the mandatory 'database' field is rejected."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(
            {"role_id": 1, "rule": json.dumps({"allowed": [{"schema": "public"}], "denied": []})}
        )
    assert "'database' field" in str(err.value)
def test_post_schema_invalid_cls_action():
    """An unknown CLS action name is rejected."""
    rule = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "cls": {"email": "invalid_action"},
            }
        ],
        "denied": [],
    }
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load({"role_id": 1, "rule": json.dumps(rule)})
    assert "Invalid CLS action" in str(err.value)
def test_post_schema_missing_role_id():
    """Omitting role_id fails validation."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load(
            {"rule": json.dumps({"allowed": [], "denied": []})}
        )
    assert "role_id" in str(err.value)
def test_post_schema_missing_rule():
    """Omitting the rule itself fails validation."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePostSchema().load({"role_id": 1})
    assert "rule" in str(err.value)
def test_put_schema_partial_update():
    """The PUT schema accepts updates to role_id or rule independently."""
    schema = DataAccessRulePutSchema()

    # role_id alone is a valid partial update.
    assert schema.load({"role_id": 2}) == {"role_id": 2}

    # So is the rule alone.
    loaded = schema.load(
        {"rule": json.dumps({"allowed": [{"database": "newdb"}], "denied": []})}
    )
    assert "rule" in loaded
def test_put_schema_validates_rule_if_provided():
    """The PUT schema still validates the rule when one is supplied."""
    with pytest.raises(ValidationError) as err:
        DataAccessRulePutSchema().load({"rule": "invalid json"})
    assert "Invalid JSON" in str(err.value)
def test_post_schema_empty_allowed_denied():
    """Empty 'allowed' and 'denied' lists are a valid rule."""
    loaded = DataAccessRulePostSchema().load(
        {"role_id": 1, "rule": json.dumps({"allowed": [], "denied": []})}
    )
    assert loaded["role_id"] == 1
def test_post_schema_cls_all_valid_actions():
    """Every CLS action is accepted, regardless of case."""
    actions = {
        "col1": "hash",
        "col2": "HASH",  # Case insensitive
        "col3": "nullify",
        "col4": "NULLIFY",
        "col5": "mask",
        "col6": "MASK",
        "col7": "hide",
        "col8": "HIDE",
    }
    rule = {
        "allowed": [{"database": "mydb", "schema": "public", "cls": actions}],
        "denied": [],
    }
    loaded = DataAccessRulePostSchema().load(
        {"role_id": 1, "rule": json.dumps(rule)}
    )
    assert loaded["role_id"] == 1

View File

@@ -0,0 +1,768 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for Data Access Rules utility functions.
"""
from unittest.mock import MagicMock, patch
from superset.data_access_rules.models import DataAccessRule
from superset.data_access_rules.utils import (
_is_more_specific,
_matches_rule_entry,
AccessCheckResult,
check_table_access,
get_all_group_keys,
get_cls_rules_for_table,
get_rls_predicates_for_table,
)
from superset.sql.parse import CLSAction, Table
# Tests for _matches_rule_entry
def test_matches_rule_entry_database_only():
    """A database-only entry matches anything inside that database."""
    rule_entry = {"database": "mydb"}
    assert _matches_rule_entry(rule_entry, "mydb", None, None, None) is True
    assert _matches_rule_entry(rule_entry, "mydb", "catalog1", "schema1", "table1") is True
    assert _matches_rule_entry(rule_entry, "otherdb", None, None, None) is False
def test_matches_rule_entry_database_and_catalog():
    """An entry pinned to a catalog only matches within that catalog."""
    rule_entry = {"database": "mydb", "catalog": "cat1"}
    assert _matches_rule_entry(rule_entry, "mydb", "cat1", None, None) is True
    assert _matches_rule_entry(rule_entry, "mydb", "cat1", "schema1", "table1") is True
    assert _matches_rule_entry(rule_entry, "mydb", "cat2", None, None) is False
    assert _matches_rule_entry(rule_entry, "otherdb", "cat1", None, None) is False
def test_matches_rule_entry_database_and_schema():
    """An entry pinned to a schema (no catalog) only matches that schema."""
    rule_entry = {"database": "mydb", "schema": "public"}
    assert _matches_rule_entry(rule_entry, "mydb", None, "public", None) is True
    assert _matches_rule_entry(rule_entry, "mydb", None, "public", "table1") is True
    assert _matches_rule_entry(rule_entry, "mydb", None, "other", None) is False
def test_matches_rule_entry_full_table():
    """A fully-specified entry matches only the exact table."""
    rule_entry = {"database": "mydb", "schema": "public", "table": "users"}
    assert _matches_rule_entry(rule_entry, "mydb", None, "public", "users") is True
    assert _matches_rule_entry(rule_entry, "mydb", None, "public", "orders") is False
    assert _matches_rule_entry(rule_entry, "mydb", None, "other", "users") is False
def test_matches_rule_entry_with_catalog():
    """Catalog mismatches break an otherwise fully matching path."""
    rule_entry = {
        "database": "mydb",
        "catalog": "main",
        "schema": "public",
        "table": "users",
    }
    assert _matches_rule_entry(rule_entry, "mydb", "main", "public", "users") is True
    assert _matches_rule_entry(rule_entry, "mydb", "other", "public", "users") is False
# Tests for _is_more_specific
def test_is_more_specific():
"""Test specificity comparison between entries."""
db_only = {"database": "mydb"}
db_schema = {"database": "mydb", "schema": "public"}
db_table = {"database": "mydb", "schema": "public", "table": "users"}
db_catalog = {"database": "mydb", "catalog": "main"}
db_catalog_schema = {"database": "mydb", "catalog": "main", "schema": "public"}
# More specific should win
assert _is_more_specific(db_schema, db_only) is True
assert _is_more_specific(db_table, db_schema) is True
assert _is_more_specific(db_table, db_only) is True
# Less specific should lose
assert _is_more_specific(db_only, db_schema) is False
assert _is_more_specific(db_schema, db_table) is False
# Same specificity
assert _is_more_specific(db_schema, db_catalog) is False
assert _is_more_specific(db_catalog, db_schema) is False
# Tests for check_table_access
def test_check_table_access_no_rules():
"""Test access check when no rules are provided."""
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[])
assert result.access == AccessCheckResult.NO_RULE
assert result.rls_predicates == []
assert result.cls_rules == {}
def test_check_table_access_allowed():
    """A table covered by an 'allowed' entry is granted access."""
    rule = MagicMock(spec=DataAccessRule)
    rule.rule_dict = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [],
    }
    outcome = check_table_access(
        "mydb", Table(table="users", schema="public", catalog=None), rules=[rule]
    )
    assert outcome.access == AccessCheckResult.ALLOWED
    assert outcome.rls_predicates == []
    assert outcome.cls_rules == {}
def test_check_table_access_denied():
    """A table covered only by a deny entry is reported as DENIED."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [],
        "denied": [{"database": "mydb", "schema": "public"}],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.DENIED
def test_check_table_access_denied_more_specific():
    """A schema-level deny overrides a broader database-level allow."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [{"database": "mydb"}],  # Less specific
        "denied": [{"database": "mydb", "schema": "secret"}],  # More specific
    }
    # A table outside the denied schema falls back to the broad allow.
    public_table = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", public_table, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    # A table inside the denied schema is blocked.
    secret_table = Table(table="data", schema="secret", catalog=None)
    outcome = check_table_access("mydb", secret_table, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.DENIED
def test_check_table_access_allowed_more_specific():
    """A table-level allow overrides a broader schema-level deny."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [{"database": "mydb", "schema": "public", "table": "users"}],
        "denied": [{"database": "mydb", "schema": "public"}],
    }
    # The explicitly allowed table wins over the schema-wide deny.
    users_table = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", users_table, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    # Sibling tables in the schema remain denied.
    orders_table = Table(table="orders", schema="public", catalog=None)
    outcome = check_table_access("mydb", orders_table, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.DENIED
def test_check_table_access_same_specificity_deny_wins():
    """When allow and deny entries tie on specificity, deny takes precedence."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [{"database": "mydb", "schema": "public"}],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.DENIED
def test_check_table_access_with_rls():
    """An allow entry carrying an RLS clause surfaces its predicate."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "rls": {"predicate": "org_id = 123", "group_key": "org"},
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    # Exactly one predicate collected, with its group key preserved.
    (predicate,) = outcome.rls_predicates
    assert predicate.predicate == "org_id = 123"
    assert predicate.group_key == "org"
def test_check_table_access_multiple_rls():
    """RLS predicates contributed by separate rules are all collected."""
    first_rule = MagicMock(spec=DataAccessRule)
    first_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 123", "group_key": "org"},
            }
        ],
        "denied": [],
    }
    second_rule = MagicMock(spec=DataAccessRule)
    second_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "region = 'US'"},
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[first_rule, second_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    assert len(outcome.rls_predicates) == 2
    # Both rules' predicates are present regardless of order.
    collected = {p.predicate for p in outcome.rls_predicates}
    assert "org_id = 123" in collected
    assert "region = 'US'" in collected
def test_check_table_access_with_cls():
    """A CLS clause on the matching allow entry maps columns to actions."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide", "name": "hash"},
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[mock_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    # String actions from the rule are parsed into CLSAction enum members.
    expected = {
        "email": CLSAction.MASK,
        "ssn": CLSAction.HIDE,
        "name": CLSAction.HASH,
    }
    assert outcome.cls_rules == expected
def test_check_table_access_cls_strictest_wins():
    """When two rules set CLS for the same column, the strictest action wins."""
    lenient_rule = MagicMock(spec=DataAccessRule)
    lenient_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "cls": {"email": "mask"},  # Less strict
            }
        ],
        "denied": [],
    }
    strict_rule = MagicMock(spec=DataAccessRule)
    strict_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "cls": {"email": "hide"},  # More strict - should win
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[lenient_rule, strict_rule])
    assert outcome.access == AccessCheckResult.ALLOWED
    assert outcome.cls_rules["email"] == CLSAction.HIDE
# Tests for get_rls_predicates_for_table
def test_get_rls_predicates_for_table_no_predicates():
    """An allow entry without an RLS clause yields no predicates."""
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [],
    }
    assert get_rls_predicates_for_table(target, mock_db, rules=[mock_rule]) == []
def test_get_rls_predicates_for_table_with_predicates():
    """An RLS clause is returned as a parenthesized predicate string."""
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 123"},
            }
        ],
        "denied": [],
    }
    preds = get_rls_predicates_for_table(target, mock_db, rules=[mock_rule])
    # The raw predicate is wrapped in parentheses for safe composition.
    assert preds == ["(org_id = 123)"]
def test_get_rls_predicates_for_table_with_group_key():
    """Predicates sharing a group_key across rules are combined with OR."""
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)
    first_rule = MagicMock(spec=DataAccessRule)
    first_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 1", "group_key": "org"},
            }
        ],
        "denied": [],
    }
    second_rule = MagicMock(spec=DataAccessRule)
    second_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 2", "group_key": "org"},
            }
        ],
        "denied": [],
    }
    preds = get_rls_predicates_for_table(target, mock_db, rules=[first_rule, second_rule])
    # A single combined predicate containing both clauses joined with OR.
    assert len(preds) == 1
    combined = preds[0]
    assert "(org_id = 1)" in combined
    assert "(org_id = 2)" in combined
    assert " OR " in combined
def test_get_rls_predicates_for_table_mixed_group_keys():
    """Grouped predicates are ORed together; ungrouped ones stay separate."""
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 1", "group_key": "org"},
            },
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 2", "group_key": "org"},
            },
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "region = 'US'"},  # No group_key
            },
        ],
        "denied": [],
    }
    preds = get_rls_predicates_for_table(target, mock_db, rules=[mock_rule])
    # One standalone predicate plus one ORed "org" group predicate.
    assert len(preds) == 2
    assert any("region = 'US'" in p for p in preds)
    assert any("org_id = 1" in p and "org_id = 2" in p for p in preds)
# Tests for get_cls_rules_for_table
def test_get_cls_rules_for_table_no_rules():
    """An allow entry without a CLS clause yields an empty mapping."""
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [],
    }
    assert get_cls_rules_for_table(target, mock_db, rules=[mock_rule]) == {}
def test_get_cls_rules_for_table_with_rules():
    """CLS actions on the matching allow entry are returned as enum values."""
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide"},
            }
        ],
        "denied": [],
    }
    cls_map = get_cls_rules_for_table(target, mock_db, rules=[mock_rule])
    assert cls_map == {"email": CLSAction.MASK, "ssn": CLSAction.HIDE}
# Tests for get_all_group_keys
def test_get_all_group_keys_empty(app_context: None):
    """With no stored rules there are no group keys."""
    with patch("superset.data_access_rules.utils.db") as mock_db:
        mock_db.session.query.return_value.all.return_value = []
        assert get_all_group_keys() == set()
def test_get_all_group_keys_with_keys(app_context: None):
    """Group keys are gathered across all rules; entries without one are skipped."""
    first_rule = MagicMock(spec=DataAccessRule)
    first_rule.rule_dict = {
        "allowed": [
            {"database": "mydb", "rls": {"predicate": "x=1", "group_key": "key1"}},
            {"database": "mydb", "rls": {"predicate": "x=2", "group_key": "key2"}},
        ],
        "denied": [],
    }
    second_rule = MagicMock(spec=DataAccessRule)
    second_rule.rule_dict = {
        "allowed": [
            {"database": "mydb", "rls": {"predicate": "x=3", "group_key": "key1"}},
            {"database": "mydb", "rls": {"predicate": "x=4"}},  # No group_key
        ],
        "denied": [],
    }
    with patch("superset.data_access_rules.utils.db") as mock_db:
        mock_db.session.query.return_value.all.return_value = [first_rule, second_rule]
        found = get_all_group_keys()
    # Duplicate "key1" collapses; the key-less entry contributes nothing.
    assert found == {"key1", "key2"}
def test_get_all_group_keys_filtered_by_database(app_context: None):
    """Only keys from entries matching the requested database are returned."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {"database": "db1", "rls": {"predicate": "x=1", "group_key": "key1"}},
            {"database": "db2", "rls": {"predicate": "x=2", "group_key": "key2"}},
        ],
        "denied": [],
    }
    with patch("superset.data_access_rules.utils.db") as mock_db:
        mock_db.session.query.return_value.all.return_value = [mock_rule]
        found = get_all_group_keys(database_name="db1")
    assert found == {"key1"}
def test_get_all_group_keys_filtered_by_table(app_context: None):
    """Only keys from entries matching the requested table are returned."""
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "db1",
                "schema": "public",
                "table": "users",
                "rls": {"predicate": "x=1", "group_key": "key1"},
            },
            {
                "database": "db1",
                "schema": "public",
                "table": "orders",
                "rls": {"predicate": "x=2", "group_key": "key2"},
            },
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    with patch("superset.data_access_rules.utils.db") as mock_db:
        mock_db.session.query.return_value.all.return_value = [mock_rule]
        found = get_all_group_keys(database_name="db1", table=target)
    # The "orders" entry does not match the "users" table filter.
    assert found == {"key1"}
# Tests for get_hidden_columns_for_table
def test_get_hidden_columns_for_table_no_hidden(app_context: None):
    """Mask/hash CLS actions do not mark any column as hidden."""
    from superset.data_access_rules.utils import get_hidden_columns_for_table
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "phone": "hash"},  # No "hide" actions
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    hidden_cols = get_hidden_columns_for_table(target, mock_db, rules=[mock_rule])
    assert hidden_cols == set()
def test_get_hidden_columns_for_table_with_hidden(app_context: None):
    """Columns under a "hide" CLS action are reported as hidden."""
    from superset.data_access_rules.utils import get_hidden_columns_for_table
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide", "password": "hide"},
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    hidden_cols = get_hidden_columns_for_table(target, mock_db, rules=[mock_rule])
    # Only the "hide" columns appear; masked columns stay visible.
    assert hidden_cols == {"ssn", "password"}
def test_get_hidden_columns_for_table_denied_access(app_context: None):
    """A denied table produces no hidden-column set (access is blocked outright)."""
    from superset.data_access_rules.utils import get_hidden_columns_for_table
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [],
        "denied": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
            }
        ],
    }
    target = Table(table="users", schema="public", catalog=None)
    hidden_cols = get_hidden_columns_for_table(target, mock_db, rules=[mock_rule])
    # Denied access means no CLS rules are collected at all.
    assert hidden_cols == set()
# Tests for filter_columns_by_cls
def test_filter_columns_by_cls_no_hidden(app_context: None):
    """With no "hide" CLS actions, the column list passes through unchanged."""
    from superset.data_access_rules.utils import filter_columns_by_cls
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    cols = [
        {"column_name": "id", "type": "INTEGER"},
        {"column_name": "name", "type": "VARCHAR"},
        {"column_name": "email", "type": "VARCHAR"},
    ]
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {"database": "mydb", "schema": "public", "table": "users"}
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=True,
    ):
        with patch(
            "superset.data_access_rules.utils.get_user_rules",
            return_value=[mock_rule],
        ):
            visible = filter_columns_by_cls(cols, target, mock_db)
    assert len(visible) == 3
    assert visible == cols
def test_filter_columns_by_cls_with_hidden(app_context: None):
    """Columns under a "hide" CLS action are removed from the column list."""
    from superset.data_access_rules.utils import filter_columns_by_cls
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    cols = [
        {"column_name": "id", "type": "INTEGER"},
        {"column_name": "name", "type": "VARCHAR"},
        {"column_name": "email", "type": "VARCHAR"},
        {"column_name": "ssn", "type": "VARCHAR"},
    ]
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"ssn": "hide"},
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=True,
    ):
        with patch(
            "superset.data_access_rules.utils.get_user_rules",
            return_value=[mock_rule],
        ):
            visible = filter_columns_by_cls(cols, target, mock_db)
    # Exactly the hidden column is dropped; everything else survives.
    assert len(visible) == 3
    names = [c["column_name"] for c in visible]
    assert "ssn" not in names
    assert "id" in names
    assert "name" in names
    assert "email" in names
def test_filter_columns_by_cls_feature_disabled(app_context: None):
    """With the feature flag off, no filtering happens at all."""
    from superset.data_access_rules.utils import filter_columns_by_cls
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    cols = [
        {"column_name": "id", "type": "INTEGER"},
        {"column_name": "ssn", "type": "VARCHAR"},
    ]
    target = Table(table="users", schema="public", catalog=None)
    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=False,
    ):
        # Even if there would be hidden columns, they are not filtered
        visible = filter_columns_by_cls(cols, target, mock_db)
    assert len(visible) == 2
    assert visible == cols
def test_filter_columns_by_cls_custom_key(app_context: None):
    """The key used to read a column's name is configurable via column_name_key."""
    from superset.data_access_rules.utils import filter_columns_by_cls
    mock_db = MagicMock()
    mock_db.database_name = "mydb"
    # Columns keyed by "name" (like SQL Lab table metadata) instead of "column_name".
    cols = [
        {"name": "id", "type": "INTEGER"},
        {"name": "ssn", "type": "VARCHAR"},
    ]
    mock_rule = MagicMock(spec=DataAccessRule)
    mock_rule.rule_dict = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"ssn": "hide"},
            }
        ],
        "denied": [],
    }
    target = Table(table="users", schema="public", catalog=None)
    with patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=True,
    ):
        with patch(
            "superset.data_access_rules.utils.get_user_rules",
            return_value=[mock_rule],
        ):
            visible = filter_columns_by_cls(
                cols, target, mock_db, column_name_key="name"
            )
    assert len(visible) == 1
    assert visible[0]["name"] == "id"

View File

@@ -1703,3 +1703,122 @@ def test_adhoc_column_with_spaces_in_full_query(database: Database) -> None:
# Verify SELECT and FROM clauses are present
assert "SELECT" in sql
assert "FROM" in sql
def test_get_query_str_extended_with_data_access_rules(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Test that get_query_str_extended calls apply_data_access_rules when enabled.
    This test mocks the get_sqla_query to return a simple SQL query, then verifies
    that apply_data_access_rules is called when the feature flag is enabled.
    """
    # Function-level imports keep heavyweight SQLA models out of module import time.
    from unittest.mock import MagicMock
    import sqlalchemy as sa
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.models.helpers import SqlaQuery
    # Minimal physical dataset with a single column "a".
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[TableColumn(column_name="a")],
    )
    # Create a mock SqlaQuery result (fields must be in order per NamedTuple)
    mock_sqla_query = SqlaQuery(
        applied_template_filters=[],
        applied_filter_columns=[],
        rejected_filter_columns=[],
        cte=None,
        extra_cache_keys=[],
        labels_expected=["a"],
        prequeries=[],
        sqla_query=sa.select(sa.column("a")).select_from(sa.table("t")),
    )
    # Mock get_sqla_query to return our simple query
    mocker.patch.object(table, "get_sqla_query", return_value=mock_sqla_query)
    # Mock apply_data_access_rules so we can assert it was invoked.
    mock_apply_dar = mocker.patch(
        "superset.models.helpers.apply_data_access_rules"
    )
    # Mock is_feature_enabled to return True for DATA_ACCESS_RULES
    # (patched broadly, so every flag checked in helpers reads as enabled).
    mocker.patch(
        "superset.models.helpers.is_feature_enabled",
        return_value=True,
    )
    # Mock database.get_default_schema
    mocker.patch.object(
        database, "get_default_schema", return_value="public"
    )
    # Bare-bones query object: no time range, one column, no metrics/filters.
    query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "columns": ["a"],
        "metrics": [],
        "orderby": [],
        "row_limit": 100,
        "filter": [],
    }
    result = table.get_query_str_extended(query_obj)
    # Verify we got a result
    assert result is not None
    assert result.sql is not None
    # Verify apply_data_access_rules was called
    assert mock_apply_dar.called
def test_get_from_clause_with_data_access_rules(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Test that get_from_clause calls apply_data_access_rules for virtual datasets.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    # Create a virtual dataset (has SQL defined)
    table = SqlaTable(
        database=database,
        schema=None,
        table_name="virtual_table",
        sql="SELECT * FROM base_table",
        columns=[TableColumn(column_name="a")],
    )
    # Mock apply_data_access_rules so we can assert it was invoked.
    mock_apply_dar = mocker.patch(
        "superset.models.helpers.apply_data_access_rules"
    )
    # Mock is_feature_enabled to return True for DATA_ACCESS_RULES
    mocker.patch(
        "superset.models.helpers.is_feature_enabled",
        return_value=True,
    )
    # Mock database.get_default_schema
    mocker.patch.object(
        database, "get_default_schema", return_value="public"
    )
    # Any downstream failure after the call of interest is deliberately ignored.
    try:
        table.get_from_clause()
    except Exception:
        pass  # We're just testing that apply_data_access_rules is called
    # Verify apply_data_access_rules was called
    assert mock_apply_dar.called

File diff suppressed because it is too large Load Diff

View File

@@ -36,7 +36,9 @@ from superset.sql_lab import (
get_sql_results,
)
from superset.utils.rls import apply_rls, get_predicates_for_table
from superset.data_access_rules.utils import apply_data_access_rules
from tests.conftest import with_config
from tests.unit_tests.conftest import with_feature_flags
from tests.unit_tests.models.core_test import oauth2_client_info
@@ -301,3 +303,119 @@ def test_get_predicates_for_table(mocker: MockerFixture) -> None:
table = Table("t1", "public", "examples")
assert get_predicates_for_table(table, database, "examples") == ["c1 = 1"]
def test_apply_data_access_rules(mocker: MockerFixture) -> None:
    """
    Test the ``apply_data_access_rules`` helper function.

    All collaborators (rule lookup, access check, feature flag) are patched so
    the helper's orchestration can be verified in isolation.
    """
    from superset.data_access_rules.utils import (
        AccessCheckResult,
        RLSPredicate,
        TableAccessInfo,
    )
    # Fake database with just the attributes the helper reads.
    database = mocker.MagicMock()
    database.database_name = "test_db"
    database.get_default_schema_for_query.return_value = "public"
    database.get_default_catalog.return_value = "examples"
    database.db_engine_spec = PostgresEngineSpec
    # Mock get_user_rules to return a rule with RLS predicates
    get_user_rules = mocker.patch(
        "superset.data_access_rules.utils.get_user_rules",
        return_value=[],
    )
    # Mock check_table_access to return allowed access with RLS predicates
    mocker.patch(
        "superset.data_access_rules.utils.check_table_access",
        return_value=TableAccessInfo(
            access=AccessCheckResult.ALLOWED,
            rls_predicates=[
                RLSPredicate(predicate="org_id = 1", group_key=None),
            ],
            cls_rules={},
        ),
    )
    # Mock is_feature_enabled
    mocker.patch(
        "superset.data_access_rules.utils.is_feature_enabled",
        return_value=True,
    )
    parsed_statement = SQLStatement("SELECT * FROM t1", "postgresql")
    apply_data_access_rules(database, "examples", "public", parsed_statement)
    # Since we mocked the feature flag check, the function should have processed
    get_user_rules.assert_called_once()
@with_feature_flags(DATA_ACCESS_RULES=True)
@with_config(
    {
        "SQLLAB_PAYLOAD_MAX_MB": 50,
        "DISALLOWED_SQL_FUNCTIONS": {},
        "SQLLAB_CTAS_NO_LIMIT": False,
        "SQL_MAX_ROW": 100000,
        "QUERY_LOGGER": None,
        "TROUBLESHOOTING_LINK": None,
        "STATS_LOGGER": MagicMock(),
    }
)
def test_execute_sql_statements_with_data_access_rules(
    mocker: MockerFixture, app
) -> None:
    """
    Test that `execute_sql_statements` calls `apply_data_access_rules`
    when the DATA_ACCESS_RULES feature flag is enabled.
    """
    # Mock apply_data_access_rules to track calls
    mock_apply_dar = mocker.patch("superset.sql_lab.apply_data_access_rules")
    # Mock the query object and database with just the attributes
    # execute_sql_statements reads before reaching the call of interest.
    query = mocker.MagicMock()
    query.limit = 1
    query.database = mocker.MagicMock()
    query.database.cache_timeout = 100
    query.status = "RUNNING"
    query.select_as_cta = False
    query.database.allow_run_async = True
    query.database.allow_dml = False
    query.catalog = "examples"
    query.database.get_default_schema_for_query.return_value = "public"
    # Mock get_query to return our mocked query object
    mocker.patch("superset.sql_lab.get_query", return_value=query)
    # Mock db.session.refresh
    mocker.patch("superset.sql_lab.db.session.refresh", return_value=None)
    # Mock the results backend
    mocker.patch("superset.sql_lab.results_backend", return_value=True)
    # Mock sys.getsizeof to simulate a small payload
    mocker.patch("sys.getsizeof", return_value=1000)
    # Mock _serialize_payload
    mocker.patch(
        "superset.sql_lab._serialize_payload", return_value="serialized_payload"
    )
    # Execution may fail further downstream; that is acceptable here because
    # the assertion only cares that apply_data_access_rules was reached.
    try:
        execute_sql_statements(
            query_id=1,
            rendered_query="SELECT 42 AS answer",
            return_results=True,
            store_results=True,
            start_time=None,
            expand_data=False,
            log_params={},
        )
    except Exception:
        pass  # We're just testing that apply_data_access_rules is called
    # Verify apply_data_access_rules was called
    assert mock_apply_dar.called