Compare commits

...

16 Commits

Author SHA1 Message Date
Beto Dealmeida
0ce43abe5b Address comments 2026-05-12 10:17:50 -04:00
Beto Dealmeida
439130db54 Add dialects for Exa/Solr 2026-05-12 09:59:55 -04:00
Beto Dealmeida
ec9f2da81e Add to DB without sqlglot dialect 2026-05-12 09:41:30 -04:00
Beto Dealmeida
549e51bdc4 Fix lint 2026-05-12 09:25:57 -04:00
Beto Dealmeida
72ecb20e5c Run ruff-format 2026-05-08 17:27:11 -04:00
Beto Dealmeida
d1cd84931e Increase coverage 2026-05-08 16:40:29 -04:00
Beto Dealmeida
884649e3ed Increase parity 2026-05-08 15:32:10 -04:00
Beto Dealmeida
e9dd9a6107 Simplify function signature 2026-05-08 14:28:21 -04:00
Beto Dealmeida
0fe8293c7f Simplify 2026-05-08 14:12:46 -04:00
Beto Dealmeida
af2d3babec Improvements 2026-05-08 13:49:37 -04:00
Beto Dealmeida
3fe2b2505f feat: new splice RLSMethod 2026-05-08 13:33:42 -04:00
Evan Rusackas
5bde86785f fix(docs): read capability flags from engine specs in database docs generator (#39449)
Co-authored-by: Superset Dev <dev@superset.apache.org>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-08 09:13:08 -07:00
Mehmet Salih Yavuz
69fbbfd7ce fix(table): consolidate visual column options under Visual formatting section (#39856) 2026-05-08 10:43:38 +03:00
Enzo Martellucci
d3784879c2 fix(embedded-sdk): grant fullscreen and clipboard-write by default (#39943) 2026-05-08 09:28:55 +02:00
Vitor Avila
ad5e3170dd fix: OpenSearch dialect identifier delimiters (#39953) 2026-05-07 16:19:27 -03:00
Maxime Beauchemin
aa710672ed fix(ui): remove makeUrl() double-prefix bugs under subdirectory deployment (#39503)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com>
2026-05-07 15:39:38 -03:00
36 changed files with 8702 additions and 3682 deletions

View File

@@ -141,6 +141,47 @@ def eval_node(node):
return "<f-string>"
return None
def static_return_bool(func_node):
"""
Statically resolve a method's return value to a bool when possible.
Returns True/False for functions whose body is (effectively) a single
\`return True\` / \`return False\` — allowing a leading docstring and
ignoring pure-comment/pass statements. Returns None for anything more
complex (conditional returns, computed values, no return, etc.).
Used by \`has_implicit_cancel\` handling: \`diagnose()\` in lib.py calls
the method and checks the return value, so an override that explicitly
returns False must NOT be treated as enabling query cancelation.
"""
returns = []
other_logic = False
docstring_skipped = False
for stmt in func_node.body:
# Skip docstring (only the FIRST expression statement that is a
# string constant — later bare string literals are not docstrings
# and should count as non-trivial logic).
if (not docstring_skipped
and isinstance(stmt, ast.Expr)
and isinstance(stmt.value, ast.Constant)
and isinstance(stmt.value.value, str)):
docstring_skipped = True
continue
if isinstance(stmt, ast.Pass):
continue
if isinstance(stmt, ast.Return):
returns.append(stmt)
continue
# Any other statement (if/for/assign/etc.) means control flow is
# non-trivial; bail out to be conservative.
other_logic = True
break
if other_logic or len(returns) != 1:
return None
val = eval_node(returns[0].value)
return val if isinstance(val, bool) else None
def deep_merge(base, override):
"""Deep merge two dictionaries. Override values take precedence."""
if base is None:
@@ -186,8 +227,55 @@ if not os.path.isdir(specs_dir):
print(json.dumps({"error": f"Directory not found: {specs_dir}", "cwd": os.getcwd()}))
sys.exit(1)
# First pass: collect all class info (name, bases, metadata)
class_info = {} # class_name -> {bases: [], metadata: {}, engine_name: str, filename: str}
# Capability flag attributes with their defaults from BaseEngineSpec
CAP_ATTR_DEFAULTS = {
'supports_dynamic_schema': False,
'supports_catalog': False,
'supports_dynamic_catalog': False,
'disable_ssh_tunneling': False,
'supports_file_upload': True,
'allows_joins': True,
'allows_subqueries': True,
}
# Maps source capability attribute -> output field name used in databases.json.
# When a cap attr is assigned an unevaluable expression (e.g.
# allows_joins = is_feature_enabled("DRUID_JOINS")), the JS layer uses this
# mapping to preserve the corresponding field from the previously-generated
# JSON rather than silently inheriting an incorrect parent default.
CAP_ATTR_TO_OUTPUT_FIELD = {
'allows_joins': 'joins',
'allows_subqueries': 'subqueries',
'supports_dynamic_schema': 'supports_dynamic_schema',
'supports_catalog': 'supports_catalog',
'supports_dynamic_catalog': 'supports_dynamic_catalog',
'disable_ssh_tunneling': 'ssh_tunneling',
'supports_file_upload': 'supports_file_upload',
}
# Methods that indicate a capability when overridden by a non-BaseEngineSpec class.
# Mirrors the has_custom_method checks in superset/db_engine_specs/lib.py.
# cancel_query / has_implicit_cancel -> query_cancelation
# (diagnose() checks cancel_query override OR has_implicit_cancel() == True;
# base has_implicit_cancel returns False, so overriding it is the static
# equivalent of that method returning True. get_cancel_query_id is NOT
# part of the diagnose() heuristic and is intentionally excluded.)
# estimate_statement_cost / estimate_query_cost -> query_cost_estimation
# impersonate_user / update_impersonation_config / get_url_for_impersonation -> user_impersonation
# validate_sql -> sql_validation (not used yet; validation is engine-based)
CAP_METHODS = {
'cancel_query', 'has_implicit_cancel',
'estimate_statement_cost', 'estimate_query_cost',
'impersonate_user', 'update_impersonation_config', 'get_url_for_impersonation',
'validate_sql',
}
# Only the literal BaseEngineSpec is excluded from method-override tracking.
# Intermediate base classes (e.g. PrestoBaseEngineSpec) do count as overrides.
TRUE_BASE_CLASS = 'BaseEngineSpec'
# First pass: collect all class info (name, bases, metadata, cap_attrs, direct_methods)
class_info = {} # class_name -> {bases: [], metadata: {}, engine_name: str, filename: str, ...}
for filename in sorted(os.listdir(specs_dir)):
if not filename.endswith('.py') or filename in ('__init__.py', 'lib.py', 'lint_metadata.py'):
@@ -218,30 +306,54 @@ for filename in sorted(os.listdir(specs_dir)):
# Extract class attributes
engine_name = None
engine_attr = None
metadata = None
cap_attrs = {} # capability flag attributes defined directly in this class
# Cap attrs assigned via expressions we can't statically resolve
# (e.g. is_feature_enabled("FLAG")). Tracked so the JS layer can
# fall back to the previously-generated databases.json value
# rather than inherit a parent default that would be wrong.
unresolved_cap_attrs = set()
direct_methods = set() # capability methods defined directly in this class
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name):
if target.id == 'engine_name':
val = eval_node(item.value)
if isinstance(val, str):
engine_name = val
elif target.id == 'metadata':
metadata = eval_node(item.value)
if not isinstance(target, ast.Name):
continue
if target.id == 'engine_name':
val = eval_node(item.value)
if isinstance(val, str):
engine_name = val
elif target.id == 'engine':
val = eval_node(item.value)
if isinstance(val, str):
engine_attr = val
elif target.id == 'metadata':
metadata = eval_node(item.value)
elif target.id in CAP_ATTR_DEFAULTS:
val = eval_node(item.value)
if isinstance(val, bool):
cap_attrs[target.id] = val
else:
# Unevaluable expression — defer to JS fallback.
unresolved_cap_attrs.add(target.id)
elif isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
if item.name in CAP_METHODS:
# has_implicit_cancel is special: diagnose() uses the
# method's RETURN VALUE, not just its presence. If the
# override statically returns False, treat it as if
# the method weren't overridden so query_cancelation
# matches diagnose(). Unresolvable / True / anything
# else falls through as an override (conservative).
if item.name == 'has_implicit_cancel':
if static_return_bool(item) is False:
continue
direct_methods.add(item.name)
# Check for engine attribute with non-empty value to distinguish
# true base classes from product classes like OceanBaseEngineSpec
has_non_empty_engine = False
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name) and target.id == 'engine':
# Check if engine value is non-empty string
if isinstance(item.value, ast.Constant):
has_non_empty_engine = bool(item.value.value)
break
has_non_empty_engine = engine_attr is not None and bool(engine_attr)
# True base classes: end with BaseEngineSpec AND don't define engine
# or have empty engine (like PostgresBaseEngineSpec with engine = "")
@@ -254,13 +366,18 @@ for filename in sorted(os.listdir(specs_dir)):
'bases': base_names,
'metadata': metadata,
'engine_name': engine_name,
'engine': engine_attr,
'filename': filename,
'is_base_or_mixin': is_true_base,
'cap_attrs': cap_attrs,
'unresolved_cap_attrs': unresolved_cap_attrs,
'direct_methods': direct_methods,
}
except Exception as e:
errors.append(f"{filename}: {str(e)}")
# Second pass: resolve inheritance and build final metadata
# Second pass: resolve inheritance and build final metadata + capability flags
def get_inherited_metadata(class_name, visited=None):
"""Recursively get metadata from parent classes."""
if visited is None:
@@ -286,6 +403,64 @@ def get_inherited_metadata(class_name, visited=None):
return inherited
def get_resolved_caps(class_name, visited=None):
"""
Resolve capability flags and method overrides with inheritance.
Returns (attr_values, unresolved, methods):
- attr_values: {attr: bool} for attrs where the nearest MRO assignment
was a literal bool. Defaults are applied at the call site.
- unresolved: attrs where the nearest MRO assignment was an unevaluable
expression (e.g. is_feature_enabled("FLAG")). The JS layer falls
back to the previously-generated JSON value for these.
- methods: capability methods defined directly in some non-base ancestor,
matching the has_custom_method() logic in db_engine_specs/lib.py.
attr_values and unresolved are disjoint — an attr is in at most one.
"""
if visited is None:
visited = set()
if class_name in visited:
return {}, set(), set()
visited.add(class_name)
info = class_info.get(class_name)
if not info:
return {}, set(), set()
attr_values = {}
unresolved = set()
resolved_methods = set()
# Collect from parents, iterating right-to-left so leftmost bases win
# (matches Python MRO: for class C(A, B), A's attributes take precedence).
for base_name in reversed(info['bases']):
p_vals, p_unres, p_meth = get_resolved_caps(base_name, visited.copy())
# A parent's literal assignments overwrite whatever we inherited so far.
for attr, val in p_vals.items():
attr_values[attr] = val
unresolved.discard(attr)
# A parent's unresolved assignments likewise take precedence.
for attr in p_unres:
unresolved.add(attr)
attr_values.pop(attr, None)
resolved_methods.update(p_meth)
# Apply this class's own assignments (override parents).
for attr, val in info['cap_attrs'].items():
attr_values[attr] = val
unresolved.discard(attr)
for attr in info['unresolved_cap_attrs']:
unresolved.add(attr)
attr_values.pop(attr, None)
# Accumulate method overrides, but skip the literal BaseEngineSpec
# (its implementations are stubs; only non-base overrides count).
if class_name != TRUE_BASE_CLASS:
resolved_methods.update(info['direct_methods'])
return attr_values, unresolved, resolved_methods
for class_name, info in class_info.items():
# Skip base classes and mixins
if info['is_base_or_mixin']:
@@ -310,7 +485,14 @@ for class_name, info in class_info.items():
if final_metadata and isinstance(final_metadata, dict) and display_name:
debug_info["classes_with_metadata"] += 1
databases[display_name] = {
# Resolve capability flags from Python source
attr_values, unresolved_caps, cap_methods = get_resolved_caps(class_name)
cap_attrs = dict(CAP_ATTR_DEFAULTS)
cap_attrs.update(attr_values)
engine_attr = info.get('engine') or ''
entry = {
'engine': display_name.lower().replace(' ', '_'),
'engine_name': display_name,
'module': info['filename'][:-3], # Remove .py extension
@@ -318,19 +500,40 @@ for class_name, info in class_info.items():
'time_grains': {},
'score': 0,
'max_score': 0,
'joins': True,
'subqueries': True,
'supports_dynamic_schema': False,
'supports_catalog': False,
'supports_dynamic_catalog': False,
'ssh_tunneling': False,
'query_cancelation': False,
'supports_file_upload': False,
'user_impersonation': False,
'query_cost_estimation': False,
'sql_validation': False,
# Capability flags read from engine spec class attributes/methods
'joins': cap_attrs['allows_joins'],
'subqueries': cap_attrs['allows_subqueries'],
'supports_dynamic_schema': cap_attrs['supports_dynamic_schema'],
'supports_catalog': cap_attrs['supports_catalog'],
'supports_dynamic_catalog': cap_attrs['supports_dynamic_catalog'],
'ssh_tunneling': not cap_attrs['disable_ssh_tunneling'],
'supports_file_upload': cap_attrs['supports_file_upload'],
# Method-based flags: True only when a non-base class overrides them.
# Matches diagnose() in lib.py: cancel_query override OR
# has_implicit_cancel() returning True (which, given the base
# returns False, is equivalent to overriding has_implicit_cancel).
'query_cancelation': bool({'cancel_query', 'has_implicit_cancel'} & cap_methods),
'query_cost_estimation': bool({'estimate_statement_cost', 'estimate_query_cost'} & cap_methods),
# SQL validation is implemented in external validator classes keyed by engine name
'sql_validation': engine_attr in {'presto', 'postgresql'},
'user_impersonation': bool(
{'impersonate_user', 'update_impersonation_config', 'get_url_for_impersonation'} & cap_methods
),
}
# Tell the JS layer which output fields were populated from the
# BaseEngineSpec default because the source assignment was an
# unevaluable expression; those get overridden from existing JSON.
unresolved_fields = sorted(
CAP_ATTR_TO_OUTPUT_FIELD[attr]
for attr in unresolved_caps
if attr in CAP_ATTR_TO_OUTPUT_FIELD
)
if unresolved_fields:
entry['_unresolved_cap_fields'] = unresolved_fields
databases[display_name] = entry
if errors and not databases:
print(json.dumps({"error": "Parse errors", "details": errors, "debug": debug_info}), file=sys.stderr)
@@ -851,24 +1054,52 @@ function loadExistingData() {
}
}
/**
* Fall back to the previously-generated databases.json for capability flags
* whose source assignment couldn't be statically resolved (e.g.
* `allows_joins = is_feature_enabled("DRUID_JOINS")`). The Python extractor
* flags these via the internal `_unresolved_cap_fields` marker; without this
* fallback those fields would silently inherit the BaseEngineSpec default
* and disagree with runtime behavior. The marker is stripped before output.
*/
function fallbackUnresolvedCaps(newDatabases, existingData) {
for (const [name, db] of Object.entries(newDatabases)) {
const unresolved = db._unresolved_cap_fields;
if (!unresolved || unresolved.length === 0) {
delete db._unresolved_cap_fields;
continue;
}
const existingDb = existingData?.databases?.[name];
if (existingDb) {
for (const field of unresolved) {
if (existingDb[field] !== undefined) {
db[field] = existingDb[field];
}
}
}
delete db._unresolved_cap_fields;
}
return newDatabases;
}
/**
* Merge new documentation with existing diagnostics
* Preserves score, time_grains, and feature flags from existing data
* Preserves score, max_score, and time_grains from existing data (these require
* Flask context to generate and cannot be derived from static source analysis).
* Capability flags (joins, supports_catalog, etc.) are NOT preserved here — they
* are read fresh from the Python engine spec source by extractEngineSpecMetadata(),
* with a separate fallback for expression-based assignments (see fallbackUnresolvedCaps).
*/
function mergeWithExistingDiagnostics(newDatabases, existingData) {
if (!existingData?.databases) return newDatabases;
const diagnosticFields = [
'score', 'max_score', 'time_grains', 'joins', 'subqueries',
'supports_dynamic_schema', 'supports_catalog', 'supports_dynamic_catalog',
'ssh_tunneling', 'query_cancelation', 'supports_file_upload',
'user_impersonation', 'query_cost_estimation', 'sql_validation'
];
// Only preserve fields that require Flask/runtime context to generate
const diagnosticFields = ['score', 'max_score', 'time_grains'];
for (const [name, db] of Object.entries(newDatabases)) {
const existingDb = existingData.databases[name];
if (existingDb && existingDb.score > 0) {
// Preserve diagnostics from existing data
// Preserve score/time_grain diagnostics from existing data
for (const field of diagnosticFields) {
if (existingDb[field] !== undefined) {
db[field] = existingDb[field];
@@ -879,7 +1110,7 @@ function mergeWithExistingDiagnostics(newDatabases, existingData) {
const preserved = Object.values(newDatabases).filter(d => d.score > 0).length;
if (preserved > 0) {
console.log(`Preserved diagnostics for ${preserved} databases from existing data`);
console.log(`Preserved score/time_grains for ${preserved} databases from existing data`);
}
return newDatabases;
@@ -927,6 +1158,12 @@ async function main() {
databases = mergeWithExistingDiagnostics(databases, existingData);
}
// For cap flags assigned via unevaluable expressions (e.g.
// `is_feature_enabled(...)`), prefer the value from a previously-generated
// JSON. Runs regardless of scores since it addresses static-analysis gaps,
// not missing Flask diagnostics. Always strips the internal marker.
databases = fallbackUnresolvedCaps(databases, existingData);
// Extract and merge custom_errors for troubleshooting documentation
const customErrors = extractCustomErrors();
mergeCustomErrors(databases, customErrors);

File diff suppressed because it is too large Load Diff

View File

@@ -142,7 +142,7 @@ druid = ["pydruid>=0.6.5,<0.7"]
duckdb = ["duckdb>=1.4.2,<2", "duckdb-engine>=0.17.0"]
dynamodb = ["pydynamodb>=0.4.2"]
solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.13, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
fastmcp = [

View File

@@ -66,7 +66,7 @@ export type EmbedDashboardParams = {
iframeTitle?: string;
/** additional iframe sandbox attributes ex (allow-top-navigation, allow-popups-to-escape-sandbox) **/
iframeSandboxExtras?: string[];
/** iframe allow attribute for Permissions Policy (e.g., ['clipboard-write', 'fullscreen']) **/
/** Additional Permissions Policy features for the iframe's `allow` attribute (e.g., ['camera', 'microphone']). `fullscreen` and `clipboard-write` are granted by default. **/
iframeAllowExtras?: string[];
/** force a specific refererPolicy to be used in the iframe request **/
referrerPolicy?: ReferrerPolicy;
@@ -233,9 +233,14 @@ export async function embedDashboard({
iframe.src = `${supersetDomain}/embedded/${id}${urlParamsString}`;
iframe.title = iframeTitle;
iframe.style.background = 'transparent';
if (iframeAllowExtras.length > 0) {
iframe.setAttribute('allow', iframeAllowExtras.join('; '));
}
// Permissions Policy features the embedded dashboard relies on. Modern
// browsers gate these APIs on the iframe's `allow` attribute regardless
// of sandbox flags, so we include them by default. Host apps can extend
// the list via `iframeAllowExtras`.
const allowFeatures = Array.from(
new Set(['fullscreen', 'clipboard-write', ...iframeAllowExtras]),
);
iframe.setAttribute('allow', allowFeatures.join('; '));
//@ts-ignore
mountPoint.replaceChildren(iframe);
log('placed the iframe');

View File

@@ -494,6 +494,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'column_config',
@@ -587,18 +593,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'show_cell_bars',
config: {
type: 'CheckboxControl',
label: t('Show cell bars'),
label: t('Show cell bars for all columns'),
renderTrigger: true,
default: true,
description: t(
@@ -612,7 +612,7 @@ const config: ControlPanelConfig = {
name: 'align_pn',
config: {
type: 'CheckboxControl',
label: t('Align +/-'),
label: t('Align +/- for all columns'),
renderTrigger: true,
default: false,
description: t(
@@ -626,7 +626,7 @@ const config: ControlPanelConfig = {
name: 'color_pn',
config: {
type: 'CheckboxControl',
label: t('Add colors to cell bars for +/-'),
label: t('Add colors to cell bars for +/- for all columns'),
renderTrigger: true,
default: true,
description: t(

View File

@@ -552,6 +552,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'column_config',
@@ -648,18 +654,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'show_cell_bars',
config: {
type: 'CheckboxControl',
label: t('Show cell bars'),
label: t('Show cell bars for all columns'),
renderTrigger: true,
default: true,
description: t(
@@ -673,7 +673,7 @@ const config: ControlPanelConfig = {
name: 'align_pn',
config: {
type: 'CheckboxControl',
label: t('Align +/-'),
label: t('Align +/- for all columns'),
renderTrigger: true,
default: false,
description: t(
@@ -687,7 +687,7 @@ const config: ControlPanelConfig = {
name: 'color_pn',
config: {
type: 'CheckboxControl',
label: t('Add colors to cell bars for +/-'),
label: t('Add colors to cell bars for +/- for all columns'),
renderTrigger: true,
default: true,
description: t(

View File

@@ -35,7 +35,6 @@ import { CheckboxChangeEvent } from '@superset-ui/core/components/Checkbox/types
import { useHistory } from 'react-router-dom';
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
import { makeUrl } from 'src/utils/pathUtils';
import Tabs from '@superset-ui/core/components/Tabs';
import {
Button,
@@ -1824,7 +1823,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
onClick={() => {
setLoading(true);
fetchAndSetDB();
redirectURL(makeUrl(`/sqllab?db=true`));
// redirectURL() delegates to history.push; React Router's basename
// already prefixes the application root, so pass a relative path.
redirectURL('/sqllab?db=true');
}}
>
{t('Query data in SQL Lab')}

View File

@@ -24,7 +24,6 @@ import { TableTab } from 'src/views/CRUD/types';
import { t } from '@apache-superset/core/translation';
import { styled } from '@apache-superset/core/theme';
import { navigateTo } from 'src/utils/navigationUtils';
import { makeUrl } from 'src/utils/pathUtils';
import { WelcomeTable } from './types';
const EmptyContainer = styled.div`
@@ -59,7 +58,9 @@ const REDIRECTS = {
create: {
[WelcomeTable.Charts]: '/chart/add',
[WelcomeTable.Dashboards]: '/dashboard/new',
[WelcomeTable.SavedQueries]: makeUrl('/sqllab?new=true'),
// navigateTo() applies the application root internally; keep this
// relative so the prefix isn't added twice.
[WelcomeTable.SavedQueries]: '/sqllab?new=true',
},
viewAll: {
[WelcomeTable.Charts]: '/chart/list',

View File

@@ -44,7 +44,7 @@ import {
TelemetryPixel,
} from '@superset-ui/core/components';
import type { ItemType, MenuItem } from '@superset-ui/core/components/Menu';
import { ensureAppRoot, makeUrl } from 'src/utils/pathUtils';
import { ensureAppRoot } from 'src/utils/pathUtils';
import { isEmbedded } from 'src/dashboard/util/isEmbedded';
import { findPermission } from 'src/utils/findPermission';
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
@@ -213,7 +213,10 @@ const RightMenu = ({
},
{
label: t('SQL query'),
url: makeUrl('/sqllab?new=true'),
// Keep the URL relative so isFrontendRoute() matches and Link navigates
// via React Router; the <Typography.Link> fallback applies ensureAppRoot
// exactly once for non-frontend routes.
url: '/sqllab?new=true',
icon: <Icons.SearchOutlined data-test={`menu-item-${t('SQL query')}`} />,
perm: 'can_sqllab',
view: 'Superset',

View File

@@ -25,11 +25,20 @@ import {
fireEvent,
waitFor,
} from 'spec/helpers/testing-library';
import { MemoryRouter } from 'react-router-dom';
import { MemoryRouter, useLocation } from 'react-router-dom';
import { QueryParamProvider } from 'use-query-params';
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
import * as getBootstrapData from 'src/utils/getBootstrapData';
import SavedQueryList from '.';
// Renders the current router pathname+search so tests can assert navigation.
function LocationDisplay() {
const location = useLocation();
return (
<div data-test="location-display">{`${location.pathname}${location.search}`}</div>
);
}
// Increase default timeout
jest.setTimeout(30000);
@@ -88,6 +97,7 @@ const renderList = (props = {}, storeOverrides = {}) =>
<MemoryRouter>
<QueryParamProvider adapter={ReactRouter5Adapter}>
<SavedQueryList user={mockUser} {...props} />
<LocationDisplay />
</QueryParamProvider>
</MemoryRouter>,
{
@@ -242,4 +252,39 @@ describe('SavedQueryList', () => {
// Verify delete buttons are not shown
expect(screen.queryByTestId('delete-action')).not.toBeInTheDocument();
});
test('"+ Query" button pushes a router-relative path (subdirectory deployment)', async () => {
// Simulate SUPERSET_APP_ROOT=/superset. ensureAppRoot/makeUrl read
// applicationRoot() dynamically, so mocking it here makes the buggy code
// path (makeUrl() around history.push) produce '/superset/sqllab?new=true'
// instead of being a no-op. React Router's <Router basename> prefixes the
// app root on its own, so history.push MUST receive a path without the
// app-root prefix — otherwise navigation lands at /superset/superset/sqllab
// and shows a blank page (sc-103661).
const applicationRootSpy = jest
.spyOn(getBootstrapData, 'applicationRoot')
.mockReturnValue('/superset');
try {
renderList();
await screen.findByTestId('saved_query-list-view');
const queryButton = await screen.findByRole('button', {
name: /query/i,
});
fireEvent.click(queryButton);
await waitFor(() => {
// The MemoryRouter in renderList uses the default ('/') basename, so
// useLocation reflects exactly what history.push received. A correct
// router-relative push produces '/sqllab?new=true'; a buggy push that
// re-applied the app root would produce '/superset/sqllab?new=true'.
const location = screen.getByTestId('location-display').textContent;
expect(location).toBe('/sqllab?new=true');
});
} finally {
applicationRootSpy.mockRestore();
}
});
});

View File

@@ -223,7 +223,9 @@ function SavedQueryList({
name: t('Query'),
buttonStyle: 'primary',
onClick: () => {
history.push(makeUrl('/sqllab?new=true'));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push('/sqllab?new=true');
},
});
@@ -245,7 +247,9 @@ function SavedQueryList({
if (openInNewWindow) {
window.open(makeUrl(`/sqllab?savedQueryId=${id}`));
} else {
history.push(makeUrl(`/sqllab?savedQueryId=${id}`));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push(`/sqllab?savedQueryId=${id}`);
}
};
@@ -338,9 +342,7 @@ function SavedQueryList({
row: {
original: { id, label },
},
}: any) => (
<Link to={makeUrl(`/sqllab?savedQueryId=${id}`)}>{label}</Link>
),
}: any) => <Link to={`/sqllab?savedQueryId=${id}`}>{label}</Link>,
id: 'label',
},
{

View File

@@ -557,6 +557,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# if True, database will be listed as option in the upload file form
supports_file_upload = True
# RLS strategy for this engine spec. Override in engine-specific classes as
# needed (for example ``RLSMethod.AS_PREDICATE`` for engines that don't
# support subquery-based RLS, or ``RLSMethod.AS_PREDICATE_SPLICE`` for
# engines where sqlglot generation is not faithful).
rls_method = RLSMethod.AS_SUBQUERY
# Is the DB engine spec able to change the default schema? This requires implementing # noqa: E501
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False
@@ -618,21 +624,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
else cls.encrypted_extra_sensitive_fields
)
@classmethod
def get_rls_method(cls) -> RLSMethod:
"""
Returns the RLS method to be used for this engine.
There are two ways to insert RLS: either replacing the table with a subquery
that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
safer, but not supported in all databases.
"""
return (
RLSMethod.AS_SUBQUERY
if cls.allows_subqueries and cls.allows_alias_in_select
else RLSMethod.AS_PREDICATE
)
@classmethod
def is_oauth2_enabled(cls) -> bool:
return (

View File

@@ -35,6 +35,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.network import is_hostname_valid, is_port_open
@@ -82,6 +83,7 @@ class CouchbaseEngineSpec(BasicParametersMixin, BaseEngineSpec):
default_driver = "couchbase"
allows_joins = False
allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
sqlalchemy_uri_placeholder = (
"couchbase://user:password@host[:port]?truststorepath=value?ssl=value"
)

View File

@@ -23,6 +23,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@@ -68,6 +69,8 @@ class CrateEngineSpec(BaseEngineSpec):
TimeGrain.YEAR: "DATE_TRUNC('year', {col})",
}
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod
def epoch_to_dttm(cls) -> str:
return "{col} * 1000"

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import (
)
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.core import GenericDataType
from superset.utils.hashing import hash_from_str
from superset.utils.network import is_hostname_valid, is_port_open
@@ -55,6 +56,8 @@ class DatabendBaseEngineSpec(BaseEngineSpec):
time_secondary_columns = True
time_groupby_inline = True
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",

View File

@@ -26,6 +26,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory,
)
from superset.errors import SupersetErrorType
from superset.sql.parse import RLSMethod
# Internal class for defining error message patterns (for translation)
@@ -58,6 +59,8 @@ class DenodoEngineSpec(BaseEngineSpec, BasicParametersMixin):
engine = "denodo"
engine_name = "Denodo"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
default_driver = "psycopg2"
sqlalchemy_uri_placeholder = (
"denodo://user:password@host:port/dbname[?key=value&key=value...]"

View File

@@ -21,12 +21,15 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class DynamoDBEngineSpec(BaseEngineSpec):
engine = "dynamodb"
engine_name = "Amazon DynamoDB"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (
"Amazon DynamoDB is a serverless NoSQL database with SQL via PartiQL."

View File

@@ -28,6 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import RLSMethod
logger = logging.getLogger()
@@ -39,6 +40,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
allows_joins = False
allows_subqueries = True
allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (

View File

@@ -21,7 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import LimitMethod
from superset.sql.parse import LimitMethod, RLSMethod
class FirebirdEngineSpec(BaseEngineSpec):
@@ -53,6 +53,8 @@ class FirebirdEngineSpec(BaseEngineSpec):
# Firebird uses FIRST to limit: `SELECT FIRST 10 * FROM table`
limit_method = LimitMethod.FETCH_MANY
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: (

View File

@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.sql.parse import RLSMethod
from .db2 import Db2EngineSpec
@@ -28,6 +30,8 @@ class IBMiEngineSpec(Db2EngineSpec):
engine_name = "IBM Db2 for i"
max_column_name_length = 128
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod
def epoch_to_dttm(cls) -> str:
return "(DAYS({col}) - DAYS('1970-01-01')) * 86400 + MIDNIGHT_SECONDS({col})"

View File

@@ -28,7 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import LimitMethod
from superset.sql.parse import LimitMethod, RLSMethod
from superset.utils.core import GenericDataType
@@ -40,6 +40,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
allows_joins = True
allows_subqueries = True
allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (

View File

@@ -21,6 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -29,6 +30,8 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
engine = "kylin"
engine_name = "Apache Kylin"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": "Apache Kylin is an open-source OLAP engine for big data.",
"logo": "apache-kylin.png",

View File

@@ -753,6 +753,15 @@ def generate_yaml_docs(output_dir: str | None = None) -> dict[str, dict[str, Any
continue
name = get_name(spec)
# Skip "base" specs (e.g. PostgresBaseEngineSpec) that share an engine_name
# with a real product spec but have no concrete engine value. When multiple
# specs share the same engine_name the one with a non-empty engine string is
# the authoritative product spec; letting a base class overwrite it would
# produce incorrect capability flags.
if not spec.engine and name in all_docs:
continue
doc_data = diagnose(spec)
# Get documentation metadata (prefers spec.metadata over DATABASE_DOCS)
@@ -766,6 +775,7 @@ def generate_yaml_docs(output_dir: str | None = None) -> dict[str, dict[str, Any
doc_data["supports_file_upload"] = spec.supports_file_upload
doc_data["supports_dynamic_schema"] = spec.supports_dynamic_schema
doc_data["supports_catalog"] = spec.supports_catalog
doc_data["supports_dynamic_catalog"] = spec.supports_dynamic_catalog
all_docs[name] = doc_data

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sql.parse import RLSMethod
# Regular expressions to catch custom errors
@@ -227,6 +228,8 @@ class OcientEngineSpec(BaseEngineSpec):
force_column_alias_quotes = True
max_column_name_length = 30
rls_method = RLSMethod.AS_PREDICATE_SPLICE
allows_cte_in_subquery = False
# Ocient does not support cte names starting with underscores
cte_alias = "cte__"

View File

@@ -20,6 +20,7 @@ from sqlalchemy.types import TypeEngine
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class PinotEngineSpec(BaseEngineSpec):
@@ -30,6 +31,7 @@ class PinotEngineSpec(BaseEngineSpec):
allows_joins = False
allows_alias_in_select = False
allows_alias_in_orderby = False
rls_method = RLSMethod.AS_PREDICATE
# pinotdb only sets cursor.description when the response contains
# columnDataTypes, which Pinot omits for zero-row results.

View File

@@ -16,6 +16,7 @@
# under the License.
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -27,6 +28,7 @@ class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
time_groupby_inline = False
allows_joins = False
allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
metadata = {
"description": "Apache Solr is an open-source enterprise search platform.",

View File

@@ -22,6 +22,7 @@ from urllib import parse
from sqlalchemy.engine.url import make_url, URL # noqa: F401
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class TDengineEngineSpec(BaseEngineSpec):
@@ -29,6 +30,8 @@ class TDengineEngineSpec(BaseEngineSpec):
engine_name = "TDengine"
max_column_name_length = 64
default_driver = "taosws"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
sqlalchemy_uri_placeholder = (
"taosws://user:******@host:port/dbname[?key=value&key=value...]"
)

View File

@@ -19,9 +19,7 @@
OpenSearch SQL dialect.
OpenSearch SQL is syntactically close to MySQL but accepts both backticks and
double-quotes as identifier delimiters. Treating ``"`` as an identifier (rather
than a string delimiter, as MySQL does) is what keeps mixed-case column names
from being emitted as string literals after a SQLGlot round-trip.
double-quotes as identifier delimiters.
"""
from __future__ import annotations
@@ -31,4 +29,4 @@ from sqlglot.dialects.mysql import MySQL
class OpenSearch(MySQL):
class Tokenizer(MySQL.Tokenizer):
IDENTIFIERS = ['"', "`"]
IDENTIFIERS = ["`", '"']

View File

@@ -76,7 +76,7 @@ SQLGLOT_DIALECTS = {
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
# "exa": ???
"exa": Dialects.EXASOL,
# "firebird": ???
"firebolt": Firebolt,
"gsheets": Dialects.SQLITE,
@@ -105,7 +105,7 @@ SQLGLOT_DIALECTS = {
"shillelagh": Dialects.SQLITE,
"singlestoredb": SingleStore,
"snowflake": Dialects.SNOWFLAKE,
# "solr": ???
"solr": Dialects.SOLR,
"spark": Dialects.SPARK,
"sqlite": Dialects.SQLITE,
"starrocks": Dialects.STARROCKS,
@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto()
AS_PREDICATE_SPLICE = enum.auto()
class RLSTransformer:
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
statement: str | None = None,
engine: str = "base",
ast: InternalRepresentation | None = None,
source: str | None = None,
):
if ast:
self._parsed = ast
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
# Original SQL substring for this statement, when known. Used by the
# splice-mode RLS path which rewrites this string instead of regenerating
# SQL from the AST. ``None`` means the statement was constructed from an
# AST without an associated source string (splice mode falls back).
self._source_sql: str | None = source if source is not None else statement
# Verbatim SQL to return from ``format()``. Set by string-rewriting
# operations (e.g. splice-mode RLS) that produce a final SQL string and
# need to bypass the dialect generator. Cleared by AST-mutating methods
# since those invalidate this cached text.
self._raw_sql: str | None = None
@classmethod
def split_script(
@@ -531,7 +543,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[InternalRepresentation]],
predicates: dict[Table, list[str]],
method: RLSMethod,
) -> None:
"""
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
statement: str | None = None,
engine: str = "base",
ast: exp.Expression | None = None,
source: str | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine, ast)
super().__init__(statement, engine, ast, source)
@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -626,10 +639,55 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str,
engine: str,
) -> list[SQLStatement]:
asts = [ast for ast in cls._parse(script, engine) if ast]
sources = cls._split_source(script, engine, len(asts))
return [
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
cls(ast=ast, engine=engine, source=source)
for ast, source in zip(asts, sources, strict=False)
]
@classmethod
def _split_source(
cls,
script: str,
engine: str,
expected_count: int,
) -> list[str | None]:
"""
Slice ``script`` into per-statement substrings using top-level semicolon
positions from the tokenizer. Returns a list of length ``expected_count``;
any entry is ``None`` if the slicing didn't yield a usable substring.
The returned substrings preserve the original byte content of the script
for each statement — necessary for splice-mode RLS, which rewrites the
original SQL rather than regenerating from the AST.
"""
none_result: list[str | None] = [None] * expected_count
dialect = SQLGLOT_DIALECTS.get(engine)
try:
tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
except sqlglot.errors.SqlglotError:
return none_result
# Top-level semicolon offsets (depth 0).
boundaries: list[int] = []
depth = 0
for tok in tokens:
if tok.token_type == sqlglot.tokens.TokenType.L_PAREN:
depth += 1
elif tok.token_type == sqlglot.tokens.TokenType.R_PAREN:
depth -= 1
elif tok.token_type == sqlglot.tokens.TokenType.SEMICOLON and depth == 0:
boundaries.append(tok.start)
starts = [0, *(b + 1 for b in boundaries)]
ends = [*boundaries, len(script)]
sources = [script[s:e].strip() for s, e in zip(starts, ends, strict=False)]
sources = [s for s in sources if s]
if len(sources) != expected_count:
return none_result
return list(sources)
@classmethod
def _parse_statement(
cls,
@@ -722,7 +780,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
When a string-rewriting operation (e.g. splice-mode RLS) has cached a
verbatim result in ``_raw_sql``, return it as-is — the whole point of
those operations is to avoid the dialect generator round-trip.
"""
if self._raw_sql is not None:
return self._raw_sql
return Dialect.get_or_raise(self._dialect).generate(
self._parsed,
copy=True,
@@ -808,6 +872,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
"""
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
# If we already have a rewritten SQL string, re-parse it first so further
# AST mutations (like LIMIT injection) preserve prior text-based rewrites.
if self._raw_sql is not None:
self._parsed = self._parse_statement(self._raw_sql, self.engine)
self._source_sql = self._raw_sql
self._raw_sql = None
if method == LimitMethod.FORCE_LIMIT:
self._parsed.args["limit"] = exp.Limit(
expression=exp.Literal(this=str(limit), is_string=False)
@@ -902,7 +973,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]],
predicates: dict[Table, list[str]],
method: RLSMethod,
) -> None:
"""
@@ -910,11 +981,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param predicates: Mapping of fully qualified ``Table`` to raw predicate
SQL strings.
:param method: The method to use for applying the rules.
"""
if not predicates:
return
if method == RLSMethod.AS_PREDICATE_SPLICE:
self._apply_rls_splice(catalog, schema, predicates)
return
parsed_predicates: dict[Table, list[exp.Expression]] = {
table: [self.parse_predicate(predicate) for predicate in table_predicates]
for table, table_predicates in predicates.items()
}
transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
@@ -922,9 +1004,39 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
if method not in transformers:
raise ValueError(f"Invalid RLS method: {method}")
transformer = transformers[method](catalog, schema, predicates)
transformer = transformers[method](catalog, schema, parsed_predicates)
self._parsed = self._parsed.transform(transformer)
def _apply_rls_splice(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[str]],
) -> None:
"""
Apply RLS via text splicing on the original SQL.
Requires the source SQL substring to be available. Raises ``ValueError``
if it isn't — the caller must ensure the statement was constructed from
a source string (the standard ``SQLScript`` path does this).
"""
from superset.sql.rls_splice import apply_rls_splice
if self._source_sql is None:
raise ValueError(
"Splice-mode RLS requires the source SQL string; "
"this SQLStatement was constructed without one."
)
spliced = apply_rls_splice(
self._source_sql,
catalog,
schema,
predicates,
dialect=self._dialect,
)
self._raw_sql = spliced
class KQLSplitState(enum.Enum):
"""

462
superset/sql/rls_splice.py Normal file
View File

@@ -0,0 +1,462 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
RLS predicate injection via text splicing.
Instead of round-tripping through sqlglot's generator (which transpiles
dialect-specific functions like ``LAST_DAY`` into something else), this approach:
1. Parses the SQL with sqlglot — only to understand structure (scope tree).
2. Uses sqlglot's tokenizer to get byte-accurate positions for every token
in the original SQL string.
3. For each ``SELECT`` scope that references a table with an RLS predicate,
finds the exact byte offset to inject at — either the end of an existing
``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING``
/ ``LIMIT`` / the closing paren of a subquery.
4. Splices the predicate text directly into the original string at that
offset — never calling ``.sql()``, so the generator never runs.
Result: everything outside the splice points is the original SQL, byte for
byte. Dialect-specific functions, comments, and formatting are all preserved
exactly.
Known limitations:
- SQL that fails to parse under the chosen dialect raises a ``ParseError``.
A thin dialect subclass is still required for parsing — but only for
parsing, not generation.
- Predicate strings are spliced in as raw SQL. They must come from a trusted
source (the RLS config), not user input.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.tokens import Token, TokenType
if TYPE_CHECKING:
from superset.sql.parse import Table
# Token types that end a WHERE clause / FROM section at the current paren depth,
# indicating where a new predicate must be inserted just before.
_CLAUSE_ENDS = {
TokenType.GROUP_BY,
TokenType.HAVING,
TokenType.ORDER_BY,
TokenType.WINDOW,
TokenType.QUALIFY,
TokenType.LIMIT,
TokenType.FETCH,
TokenType.CLUSTER_BY,
TokenType.DISTRIBUTE_BY,
TokenType.SORT_BY,
TokenType.CONNECT_BY,
TokenType.START_WITH,
TokenType.UNION,
TokenType.INTERSECT,
TokenType.EXCEPT,
}
_JOIN_STARTS = {
TokenType.JOIN,
TokenType.STRAIGHT_JOIN,
TokenType.JOIN_MARKER,
}
def _splice_priority(text: str) -> int:
"""
Priority for applying splices at the same offset.
Insert full SQL fragments (WHERE/ON/predicates) before closing parens so
wrapping splices like ``pred AND (existing)`` compose correctly.
"""
return 1 if text != ")" else 0
def _after_previous_token(tokens: list[Token], index: int) -> int:
"""
Return the offset immediately after the token preceding *index*.
The sqlglot tokenizer strips comments and whitespace from the token stream,
so the previous token's ``end + 1`` is the splice point that lands right
after the last real SQL content — naturally skipping any intervening
comments or whitespace, and never confusing ``--`` or ``/*`` inside string
literals for real comment delimiters.
"""
if index <= 0:
return 0
return tokens[index - 1].end + 1
def _table_from_node(
node: exp.Table,
catalog: str | None,
schema: str | None,
) -> Table:
"""
Build a fully qualified ``Table`` from a sqlglot ``exp.Table`` node, defaulting
unqualified parts to the supplied catalog/schema.
"""
# Imported lazily to avoid a circular import with ``superset.sql.parse``.
from superset.sql.parse import Table
return Table(
table=node.name,
schema=node.db if node.db else schema,
catalog=node.catalog if node.catalog else catalog,
)
def apply_rls_splice(
sql: str,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[str]],
dialect: str | None = None,
) -> str:
"""
Inject RLS predicates into ``sql`` by splicing text at the right positions.
:param sql: The original SQL query. Returned unchanged except at splice points.
:param catalog: The default catalog for non-qualified table names.
:param schema: The default schema for non-qualified table names.
:param predicates: Mapping of ``Table`` to predicate SQL strings. Each entry
maps a fully qualified table to one or more raw predicate strings to
``AND`` together when that table is referenced in a SELECT scope.
:param dialect: The sqlglot dialect used for *parsing only* — to understand
scope structure and locate token positions. The generator is never
called, so this does not affect output formatting.
:return: The query with RLS predicates injected into every relevant SELECT
scope.
"""
if not predicates or not any(predicates.values()):
return sql
resolved_dialect = sqlglot.Dialect.get_or_raise(dialect)
tokens = list(resolved_dialect.tokenize(sql))
tree = sqlglot.parse_one(sql, dialect=dialect)
splices: list[tuple[int, str]] = []
for scope in traverse_scope(tree):
splices.extend(
_splices_for_scope(
sql,
tokens,
scope,
predicates,
catalog,
schema,
dialect,
)
)
# Apply splices in reverse offset order so earlier positions stay valid.
# For equal offsets, apply predicate/WHERE/ON inserts before ")" inserts.
splices.sort(key=lambda item: (item[0], _splice_priority(item[1])), reverse=True)
result = sql
for offset, text in splices:
result = result[:offset] + text + result[offset:]
return result
def _splices_for_scope(
sql: str,
tokens: list[Token],
scope: object,
predicates: dict[Table, list[str]],
catalog: str | None,
schema: str | None,
dialect: str | None,
) -> list[tuple[int, str]]:
"""
Compute all splices for a single SELECT scope.
This mirrors ``RLSAsPredicateTransformer`` semantics:
- predicates for FROM tables are applied to the SELECT WHERE clause as
``pred AND (existing_where)``
- predicates for JOIN tables are applied to each JOIN ON clause as
``pred AND (existing_on)`` (or ``ON pred`` when ON is absent)
"""
from_predicates: list[str] = []
from_table_ends: list[int] = []
join_splices: list[tuple[int, str]] = []
for source in scope.sources.values(): # type: ignore[attr-defined]
source_type, table_end, pred_sql = _classify_source_predicate(
source,
predicates,
catalog,
schema,
dialect,
)
if source_type == "none" or table_end is None or pred_sql is None:
continue
if source_type == "from":
from_predicates.append(pred_sql)
from_table_ends.append(table_end)
continue
join_splice = _find_join_splice(sql, tokens, table_end, pred_sql)
if join_splice:
join_splices.extend(join_splice)
continue
if not from_predicates:
return join_splices
combined_predicates = " AND ".join(dict.fromkeys(from_predicates))
from_splice = _find_where_splice(
sql,
tokens,
max(from_table_ends),
combined_predicates,
)
return [*join_splices, *from_splice]
def _table_end(source: exp.Table) -> int | None:
ident = source.find(exp.Identifier)
if ident and getattr(ident, "_meta", None):
return ident._meta["end"]
return None
def _classify_source_predicate(
source: object,
predicates: dict[Table, list[str]],
catalog: str | None,
schema: str | None,
dialect: str | None,
) -> tuple[str, int | None, str | None]:
"""
Return source kind (from/join/none), table end offset, and predicate SQL.
"""
if not isinstance(source, exp.Table):
return ("none", None, None)
table = _table_from_node(source, catalog, schema)
table_predicates = [
_qualify_predicate(predicate, source, dialect)
for predicate in predicates.get(table, [])
if predicate
]
if not table_predicates:
return ("none", None, None)
table_end = _table_end(source)
if table_end is None:
return ("none", None, None)
pred_sql = " AND ".join(dict.fromkeys(table_predicates))
if isinstance(source.parent, exp.From):
return ("from", table_end, pred_sql)
if isinstance(source.parent, exp.Join):
return ("join", table_end, pred_sql)
return ("none", None, None)
def _qualify_predicate(
predicate: str,
table_node: exp.Table,
dialect: str | None,
) -> str:
"""
Qualify predicate columns with the table alias/name, mirroring
``RLSAsPredicateTransformer``.
"""
parsed = sqlglot.parse_one(predicate, dialect=dialect)
table = table_node.alias_or_name
table_expr = exp.to_identifier(table)
for column in parsed.find_all(exp.Column):
column.set("table", table_expr.copy())
return parsed.sql(dialect=dialect)
def _scan_until_scope_boundary(
tokens: list[Token],
anchor: int,
*,
stop_at_join: bool,
) -> tuple[str, int | None]:
"""
Scan tokens forward from ``anchor`` until a clause/scope boundary.
Returns ``("where", index)`` when a WHERE token is found at depth 0,
``("boundary", index)`` for a non-WHERE boundary token, and
``("eof", None)`` when no boundary token is found.
"""
depth = 0
for i, tok in enumerate(tokens):
if tok.start <= anchor:
continue
if tok.token_type == TokenType.L_PAREN:
depth += 1
continue
if tok.token_type == TokenType.R_PAREN:
if depth == 0:
return ("boundary", i)
depth -= 1
continue
if depth > 0:
continue
if tok.token_type == TokenType.WHERE:
return ("where", i)
if tok.token_type in _CLAUSE_ENDS or (
stop_at_join and tok.token_type in _JOIN_STARTS
):
return ("boundary", i)
return ("eof", None)
def _find_condition_end(
tokens: list[Token],
start_index: int,
*,
stop_at_join: bool,
) -> int:
"""
Find the end offset for a WHERE/ON condition body.
"""
depth = 0
prev_end = tokens[start_index].end
for tok in tokens[start_index + 1 :]:
if tok.token_type == TokenType.L_PAREN:
depth += 1
elif tok.token_type == TokenType.R_PAREN:
if depth == 0:
return prev_end + 1
depth -= 1
elif depth == 0 and (
(stop_at_join and tok.token_type == TokenType.WHERE)
or tok.token_type in _CLAUSE_ENDS
or (stop_at_join and tok.token_type in _JOIN_STARTS)
):
return prev_end + 1
prev_end = tok.end
return prev_end + 1
def _find_where_splice(
sql: str,
tokens: list[Token],
anchor: int,
pred_sql: str,
) -> list[tuple[int, str]]:
"""
Build splices for adding predicate semantics to the SELECT WHERE clause:
``pred`` when absent, ``pred AND (existing)`` when present.
"""
kind, idx = _scan_until_scope_boundary(tokens, anchor, stop_at_join=False)
if kind == "where" and idx is not None:
if idx + 1 >= len(tokens):
return [(tokens[idx].end + 1, f" {pred_sql}")]
body_start = tokens[idx + 1].start
body_end = _find_condition_end(tokens, idx, stop_at_join=False)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if kind == "boundary" and idx is not None:
return [(_after_previous_token(tokens, idx), f" WHERE {pred_sql}")]
return [(len(sql), f" WHERE {pred_sql}")]
def _find_join_splice(
sql: str,
tokens: list[Token],
anchor: int,
pred_sql: str,
) -> list[tuple[int, str]]:
"""
Build splices for adding predicate semantics to a JOIN clause:
``ON pred`` when ON absent, ``ON pred AND (existing_on)`` when present.
"""
on_index, boundary_index = _scan_join_clause(tokens, anchor)
if on_index is not None:
if on_index + 1 >= len(tokens):
return [(tokens[on_index].end + 1, f" {pred_sql}")]
body_start = tokens[on_index + 1].start
body_end = _find_condition_end(tokens, on_index, stop_at_join=True)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if boundary_index is not None:
return [(_after_previous_token(tokens, boundary_index), f" ON {pred_sql}")]
return [(len(sql), f" ON {pred_sql}")]
def _scan_join_clause(
tokens: list[Token],
anchor: int,
) -> tuple[int | None, int | None]:
"""
Find ON and boundary token indexes for a JOIN segment.
"""
depth = 0
on_index: int | None = None
boundary_index: int | None = None
for i, tok in enumerate(tokens):
if tok.start <= anchor:
continue
if tok.token_type == TokenType.L_PAREN:
depth += 1
continue
if tok.token_type == TokenType.R_PAREN:
if depth == 0:
boundary_index = i
break
depth -= 1
continue
if depth > 0:
continue
if tok.token_type == TokenType.ON and on_index is None:
on_index = i
continue
if tok.token_type == TokenType.WHERE:
boundary_index = i
break
if tok.token_type in _JOIN_STARTS or tok.token_type in _CLAUSE_ENDS:
boundary_index = i
break
return on_index, boundary_index

View File

@@ -40,17 +40,19 @@ def apply_rls(
:returns: True if any RLS predicates were actually applied, False otherwise.
"""
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
# safer, but not supported in all databases.
method = database.db_engine_spec.get_rls_method()
# There are three ways to insert RLS:
# - replace the table with a subquery containing the RLS (safest, but not
# supported in all databases)
# - append the RLS to the ``WHERE`` clause via AST transformation
# - splice the RLS into the original SQL string (preserves dialect-specific
# syntax that the sqlglot generator would otherwise transpile)
method = database.db_engine_spec.rls_method
# collect all RLS predicates for all tables in the query
predicates: dict[Table, list[Any]] = {}
predicates: dict[Table, list[str]] = {}
for table in parsed_statement.tables:
table = table.qualify(catalog=catalog, schema=schema)
predicates[table] = [
parsed_statement.parse_predicate(predicate)
raw_predicates = [
predicate
for predicate in get_predicates_for_table(
table,
database,
@@ -58,6 +60,7 @@ def apply_rls(
)
if predicate
]
predicates[table] = raw_predicates
has_predicates = any(predicates.values())
parsed_statement.apply_rls(catalog, schema, predicates, method)

View File

@@ -36,7 +36,7 @@ from sqlalchemy.sql import sqltypes
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2RedirectError
from superset.sql.parse import Table
from superset.sql.parse import RLSMethod, Table
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
@@ -1283,3 +1283,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
error = exc_info.value.error
assert error.extra["redirect_uri"] == fallback_uri
def test_default_rls_method_is_subquery() -> None:
"""Base engine spec defaults to subquery-based RLS."""
assert BaseEngineSpec.rls_method == RLSMethod.AS_SUBQUERY

View File

@@ -209,7 +209,7 @@ class TestApplyRlsReturnValue:
from superset.utils.rls import apply_rls
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
statement = MagicMock()
@@ -237,7 +237,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = []
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
mock_table = MagicMock()
@@ -268,7 +268,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = ["user_id = 42"]
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
mock_table = MagicMock()
@@ -276,8 +276,6 @@ class TestApplyRlsReturnValue:
statement = MagicMock()
statement.tables = [mock_table]
statement.parse_predicate.return_value = MagicMock()
result = apply_rls(
database=database,
catalog=None,
@@ -312,11 +310,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pens.pen_id, pens.is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -333,11 +330,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pens.pen_id, pens.is_green FROM mycat.public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", "mycat"): [predicate]},
{Table("pens", "public", "mycat"): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -351,11 +347,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT p.pen_id, p.is_green FROM public.pens p"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -369,11 +364,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pen_id, is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()

View File

@@ -33,7 +33,8 @@ def test_opensearch_dialect_registered() -> None:
def test_double_quotes_as_identifiers() -> None:
"""
Test that double quotes are treated as identifiers, not string literals.
Test that double quotes are treated as identifiers, not string literals,
and normalized to backticks in output.
"""
sql = 'SELECT "AvgTicketPrice" FROM "flights"'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -42,8 +43,8 @@ def test_double_quotes_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
@@ -69,8 +70,7 @@ WHERE
def test_backticks_as_identifiers() -> None:
"""
Test that backticks work as identifiers (MySQL-style).
Backticks are normalized to double quotes in output.
Test that backticks are accepted as identifiers and preserved on output.
"""
sql = "SELECT `AvgTicketPrice` FROM `flights`"
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -79,15 +79,16 @@ def test_backticks_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
def test_mixed_identifier_quotes() -> None:
"""
Test mixing double quotes and backticks for identifiers.
Test mixing double quotes and backticks for identifiers are all normalized to
backticks on output.
"""
sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -96,12 +97,26 @@ def test_mixed_identifier_quotes() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice" AS "AvgTicketPrice"
FROM "default"."flights"
`AvgTicketPrice` AS `AvgTicketPrice`
FROM `default`.`flights`
""".strip()
)
def test_alias_with_space() -> None:
"""
Test that an alias containing a space (e.g. a metric key like ``my test``)
is preserved as a backtick-quoted identifier through the round-trip.
"""
sql = 'SELECT COUNT(*) AS "my test" FROM `flights`'
ast = sqlglot.parse_one(sql, OpenSearch)
assert (
OpenSearch().generate(expression=ast, pretty=False)
== "SELECT COUNT(*) AS `my test` FROM `flights`"
)
@pytest.mark.parametrize(
"sql, expected",
[
@@ -110,20 +125,20 @@ FROM "default"."flights"
"""
SELECT
COUNT(*)
FROM "flights"
FROM `flights`
WHERE
"Cancelled" = TRUE
`Cancelled` = TRUE
""".strip(),
),
(
'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"',
"""
SELECT
"Carrier",
SUM("AvgTicketPrice")
FROM "flights"
`Carrier`,
SUM(`AvgTicketPrice`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip(),
),
(
@@ -131,9 +146,9 @@ GROUP BY
"""
SELECT
*
FROM "flights"
FROM `flights`
WHERE
"DestCountry" IN ('US', 'CA', 'MX')
`DestCountry` IN ('US', 'CA', 'MX')
""".strip(),
),
],
@@ -165,13 +180,13 @@ GROUP BY "Carrier"
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
`Carrier`,
COUNT(*),
AVG("AvgTicketPrice"),
MAX("FlightDelayMin")
FROM "flights"
AVG(`AvgTicketPrice`),
MAX(`FlightDelayMin`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip()
)
@@ -190,10 +205,10 @@ SELECT
*
FROM (
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
) AS "sub"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
) AS `sub`
""".strip()
)
@@ -212,12 +227,12 @@ def test_order_by_with_quoted_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
ORDER BY
"AvgTicketPrice" DESC,
"Carrier" ASC
`AvgTicketPrice` DESC,
`Carrier` ASC
""".strip()
)
@@ -234,7 +249,7 @@ def test_limit_clause() -> None:
== """
SELECT
*
FROM "flights"
FROM `flights`
LIMIT 10
""".strip()
)

View File

@@ -1704,6 +1704,21 @@ def test_set_limit_value(
assert statement.format() == expected
def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
"""
When a statement has cached verbatim SQL from splice-mode rewrites, setting
limit should reparse that SQL before mutating the AST.
"""
statement = SQLStatement("SELECT * FROM some_table", "postgresql")
statement._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"
statement.set_limit_value(10, LimitMethod.FORCE_LIMIT)
formatted = statement.format()
assert "tenant_id = 42" in formatted
assert "LIMIT 10" in formatted
@pytest.mark.parametrize(
"kql, limit, expected",
[
@@ -2198,7 +2213,7 @@ def test_rls_subquery_transformer(
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
{k: [v] for k, v in rules.items()},
RLSMethod.AS_SUBQUERY,
)
assert statement.format() == expected
@@ -2542,12 +2557,312 @@ def test_rls_predicate_transformer(
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
"SELECT t.foo FROM some_table AS t WHERE t.id = 42",
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
"SELECT t.foo FROM some_table AS t WHERE t.id = 42 "
"AND (bar = 'baz' OR foo = 'qux')",
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
"SELECT * FROM table JOIN other_table ON other_table.id = 42 "
"AND (table.id = other_table.id)",
),
(
"SELECT * FROM table JOIN other_table",
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
"SELECT * FROM table JOIN other_table ON other_table.id = 42",
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id "
"WHERE 1=1",
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
"SELECT * FROM table JOIN other_table ON other_table.id = 42 "
"AND (table.id = other_table.id) WHERE 1=1",
),
],
)
def test_rls_predicate_splice_semantics_match_predicate(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Splice mode should preserve predicate-mode semantics for boolean grouping
and JOIN-vs-WHERE placement.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
# Simple — no WHERE clause to extend.
(
"SELECT LAST_DAY(d) FROM some_table",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42",
),
# Append to an existing WHERE clause.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (status = 'open')",
),
# WHERE precedes GROUP BY: predicate goes before GROUP BY.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (status = 'open') GROUP BY d",
),
# No WHERE, but GROUP BY and ORDER BY are present.
(
"SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42 "
"GROUP BY d ORDER BY d",
),
# JOIN — predicate scoped to one of the tables.
(
"SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
{Table("some_table", "schema1", "catalog1"): "o.tenant_id = 42"},
"SELECT o.id FROM some_table o JOIN locations l "
"ON o.loc_id = l.id WHERE o.tenant_id = 42",
),
# JOIN — different predicate per table, both spliced into one WHERE.
(
"SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
{
Table("events", "schema1", "catalog1"): "events.user_id = 99",
Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42",
},
"SELECT * FROM some_table JOIN events "
"ON events.user_id = 99 AND (some_table.id = events.order_id) "
"WHERE some_table.tenant_id = 42",
),
# Subquery in FROM — splice into the inner SELECT.
(
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
"WHERE some_table.tenant_id = 42) sub",
),
# CTE — splice into the CTE body.
(
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42) SELECT * FROM cte",
),
# Dialect-specific function (LAST_DAY) preserved verbatim.
(
"SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT id, LAST_DAY(created_at) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
),
# Multiline + inline comment preserved exactly.
(
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
),
# Schema-qualified table name (no default schema match) — no predicate.
(
"SELECT t.foo FROM schema2.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
"SELECT t.foo FROM schema2.some_table AS t",
),
],
)
def test_rls_predicate_splice(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``.
Splice mode rewrites the original SQL string instead of re-rendering the
AST through the dialect generator, so byte-level fidelity (including
dialect-specific functions, comments, and whitespace) is preserved.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
def test_rls_predicate_splice_requires_source() -> None:
"""
Splice mode requires the original SQL substring; constructing a statement
purely from an AST should make splice mode raise.
"""
ast = parse_one("SELECT * FROM some_table")
statement = SQLStatement(ast=ast, engine="postgresql")
with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
statement.apply_rls(
"catalog1",
"schema1",
{Table("some_table", "schema1", "catalog1"): ["id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
def test_rls_predicate_splice_preserves_dialect_function() -> None:
"""
Splice mode must NOT round-trip through the sqlglot generator. ``LAST_DAY``
on the postgres dialect would otherwise be transpiled by the generator.
"""
sql = "SELECT LAST_DAY(d) FROM some_table"
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == (
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42"
)
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
"""
Splice mode accepts predicate strings directly — no ``parse_predicate`` is
needed at the call site.
"""
sql = "SELECT * FROM some_table"
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42 AND some_table.active"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == (
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 AND some_table.active"
)
@pytest.mark.parametrize(
"sql, expected",
[
(
"SELECT * FROM some_table -- hi\nGROUP BY id",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"-- hi\nGROUP BY id",
),
(
"SELECT * FROM some_table /* inline */ GROUP BY id",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"/* inline */ GROUP BY id",
),
],
)
def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
"""
Splice mode should insert predicates before comments that precede the next
clause boundary, so comments do not swallow the injected SQL.
"""
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, engine, expected",
[
(
"SELECT * FROM some_table QUALIFY row_number() OVER "
"(PARTITION BY id ORDER BY ts DESC) = 1",
"snowflake",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
),
(
"SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
"postgresql",
"SELECT sum(v) OVER () FROM some_table "
"WHERE some_table.tenant_id = 42 "
"WINDOW w AS (PARTITION BY id)",
),
],
)
def test_rls_predicate_splice_handles_additional_clause_boundaries(
sql: str,
engine: str,
expected: str,
) -> None:
"""
Splice mode should insert WHERE before clause types that can legally follow
FROM/WHERE (for example QUALIFY and WINDOW).
"""
statement = SQLStatement(sql, engine=engine)
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
"""
LIMIT rewrites after splice-mode RLS should retain injected predicates.
"""
statement = SQLStatement("SELECT * FROM some_table", engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)
formatted = statement.format()
assert "some_table.tenant_id = 42" in formatted
assert "LIMIT 101" in formatted
@pytest.mark.parametrize(
"sql, table, expected",
[

View File

@@ -0,0 +1,292 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sqlglot
from sqlglot import Dialect, exp
from superset.sql.parse import SQLStatement, Table
from superset.sql.rls_splice import (
_classify_source_predicate,
_find_condition_end,
_find_join_splice,
_find_where_splice,
_scan_join_clause,
_scan_until_scope_boundary,
_splices_for_scope,
_table_end,
apply_rls_splice,
)
def _tokenize(sql: str) -> list[sqlglot.tokens.Token]:
return list(Dialect.get_or_raise(None).tokenize(sql))
def _token_index(tokens: list[sqlglot.tokens.Token], token_type: object) -> int:
return next(i for i, token in enumerate(tokens) if token.token_type == token_type)
def _token_by_text(
tokens: list[sqlglot.tokens.Token], text: str
) -> sqlglot.tokens.Token:
return next(token for token in tokens if token.text == text)
def test_split_source_returns_none_result_when_tokenize_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _BrokenDialect:
@staticmethod
def tokenize(_: str) -> list[sqlglot.tokens.Token]:
raise sqlglot.errors.SqlglotError("boom")
monkeypatch.setattr(
"superset.sql.parse.Dialect.get_or_raise",
lambda _: _BrokenDialect(),
)
assert SQLStatement._split_source("SELECT 1", "postgresql", 2) == [None, None]
def test_apply_rls_splice_ignores_empty_predicates() -> None:
sql = "SELECT 1"
assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
"""
Regression: the splice point must not be confused by ``--`` appearing
inside a string literal. Earlier ``rfind("--", ...)`` logic mistook this
for an inline comment and inserted the predicate inside the quoted text.
"""
sql = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id"
spliced = apply_rls_splice(
sql,
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
dialect="postgres",
)
assert spliced == (
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"AND (note = '--x') GROUP BY id"
)
def test_table_end_returns_none_without_metadata() -> None:
source = exp.Table(this=exp.Identifier(this="foo"))
assert _table_end(source) is None
def test_classify_source_predicate_returns_none_without_table_metadata() -> None:
source = exp.Table(this=exp.Identifier(this="foo"))
exp.From(this=source)
result = _classify_source_predicate(
source,
{Table("foo"): ["id = 1"]},
None,
None,
None,
)
assert result == ("none", None, None)
def test_classify_source_predicate_returns_none_for_unsupported_parent() -> None:
source = exp.Table(this=exp.Identifier(this="foo"))
source.this.meta["end"] = 3
exp.Alias(this=source, alias=exp.Identifier(this="alias"))
result = _classify_source_predicate(
source,
{Table("foo"): ["id = 1"]},
None,
None,
None,
)
assert result == ("none", None, None)
def test_scan_until_scope_boundary_tracks_parenthesis_depth() -> None:
sql = "SELECT * FROM t WHERE (a = 1)"
tokens = _tokenize(sql)
where_token = _token_by_text(tokens, "WHERE")
assert _scan_until_scope_boundary(
tokens, where_token.start, stop_at_join=False
) == (
"eof",
None,
)
def test_find_condition_end_handles_subquery_closing_paren() -> None:
sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert sql[end] == ")"
def test_find_condition_end_handles_parenthesized_expression() -> None:
sql = "SELECT * FROM t WHERE (a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert end == len(sql)
def test_find_where_splice_handles_trailing_where_keyword() -> None:
sql = "SELECT * FROM t WHERE"
tokens = _tokenize(sql)
splices = _find_where_splice(sql, tokens, anchor=0, pred_sql="t.id = 1")
assert splices == [(len(sql), " t.id = 1")]
def test_find_join_splice_handles_trailing_on_keyword() -> None:
sql = "SELECT * FROM a JOIN b ON"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
assert splices == [(len(sql), " b.id = 1")]
def test_find_join_splice_inserts_on_before_where_boundary() -> None:
sql = "SELECT * FROM a JOIN b WHERE x = 1"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
assert splices == [(sql.index("WHERE") - 1, " ON b.id = 1")]
def test_scan_join_clause_covers_nested_parentheses_and_join_boundary() -> None:
sql = "SELECT * FROM a JOIN b ON (a.id = b.id) JOIN c ON 1 = 1"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
on_index, boundary_index = _scan_join_clause(tokens, b_token.end)
assert on_index is not None
assert boundary_index is not None
assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.JOIN
def test_scan_join_clause_stops_at_outer_closing_paren() -> None:
sql = "SELECT * FROM (SELECT * FROM a JOIN b) sub"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
_, boundary_index = _scan_join_clause(tokens, b_token.end)
assert boundary_index is not None
assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.R_PAREN
def test_splices_for_scope_handles_empty_join_splice_result(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _Scope:
sources = {"x": object()}
sql = "SELECT 1"
tokens = _tokenize(sql)
monkeypatch.setattr(
"superset.sql.rls_splice._classify_source_predicate",
lambda *args, **kwargs: ("join", 0, "x.id = 1"),
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_join_splice",
lambda *args, **kwargs: [],
)
assert (
_splices_for_scope(
sql,
tokens,
_Scope(),
{Table("x"): ["x.id = 1"]},
None,
None,
None,
)
== []
)
def test_splices_for_scope_combines_join_and_from_splices(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _Scope:
sources = {"f": object(), "j": object()}
sql = "SELECT 1"
tokens = _tokenize(sql)
calls = [("from", 3, "f.id = 1"), ("join", 6, "j.id = 2")]
def _fake_classify(*args: object, **kwargs: object) -> tuple[str, int, str]:
return calls.pop(0)
monkeypatch.setattr(
"superset.sql.rls_splice._classify_source_predicate", _fake_classify
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_join_splice",
lambda *args, **kwargs: [(50, " ON j.id = 2")],
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_where_splice",
lambda *args, **kwargs: [(20, " WHERE f.id = 1")],
)
assert _splices_for_scope(
sql,
tokens,
_Scope(),
{Table("f"): ["id = 1"], Table("j"): ["id = 2"]},
None,
None,
None,
) == [(50, " ON j.id = 2"), (20, " WHERE f.id = 1")]
def test_splices_for_scope_join_then_next_source(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _Scope:
sources = {"j": object(), "f": object()}
sql = "SELECT 1"
tokens = _tokenize(sql)
calls = [("join", 3, "j.id = 2"), ("none", None, None)]
def _fake_classify(
*args: object, **kwargs: object
) -> tuple[str, int | None, str | None]:
return calls.pop(0)
monkeypatch.setattr(
"superset.sql.rls_splice._classify_source_predicate", _fake_classify
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_join_splice",
lambda *args, **kwargs: [],
)
assert (
_splices_for_scope(
sql,
tokens,
_Scope(),
{Table("j"): ["id = 2"]},
None,
None,
None,
)
== []
)