Compare commits

...

25 Commits

Author SHA1 Message Date
Beto Dealmeida
20b4f33710 Small fixes 2026-05-21 10:41:55 -04:00
Beto Dealmeida
c1b7a2d2ee Bump code coverage to 100% 2026-05-20 12:11:19 -04:00
Beto Dealmeida
0ce43abe5b Address comments 2026-05-12 10:17:50 -04:00
Beto Dealmeida
439130db54 Add dialects for Exa/Solr 2026-05-12 09:59:55 -04:00
Beto Dealmeida
ec9f2da81e Add to DB without sqlglot dialect 2026-05-12 09:41:30 -04:00
Beto Dealmeida
549e51bdc4 Fix lint 2026-05-12 09:25:57 -04:00
Beto Dealmeida
72ecb20e5c Run ruff-format 2026-05-08 17:27:11 -04:00
Beto Dealmeida
d1cd84931e Increase coverage 2026-05-08 16:40:29 -04:00
Beto Dealmeida
884649e3ed Increase parity 2026-05-08 15:32:10 -04:00
Beto Dealmeida
e9dd9a6107 Simplify function signature 2026-05-08 14:28:21 -04:00
Beto Dealmeida
0fe8293c7f Simplify 2026-05-08 14:12:46 -04:00
Beto Dealmeida
af2d3babec Improvements 2026-05-08 13:49:37 -04:00
Beto Dealmeida
3fe2b2505f feat: new splice RLSMethod 2026-05-08 13:33:42 -04:00
Evan Rusackas
5bde86785f fix(docs): read capability flags from engine specs in database docs generator (#39449)
Co-authored-by: Superset Dev <dev@superset.apache.org>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-08 09:13:08 -07:00
Mehmet Salih Yavuz
69fbbfd7ce fix(table): consolidate visual column options under Visual formatting section (#39856) 2026-05-08 10:43:38 +03:00
Enzo Martellucci
d3784879c2 fix(embedded-sdk): grant fullscreen and clipboard-write by default (#39943) 2026-05-08 09:28:55 +02:00
Vitor Avila
ad5e3170dd fix: OpenSearch dialect identifier delimiters (#39953) 2026-05-07 16:19:27 -03:00
Maxime Beauchemin
aa710672ed fix(ui): remove makeUrl() double-prefix bugs under subdirectory deployment (#39503)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com>
2026-05-07 15:39:38 -03:00
Richard Fogaca Nienkotter
8c80caefa3 fix(explore): preserve preview chart name on save (#39908) 2026-05-07 13:08:28 -03:00
Richard Fogaca Nienkotter
8088c5d1de fix(dashboard): match auto-refresh paused-dot outline to icon color (#39909) 2026-05-07 13:07:52 -03:00
Amin Ghadersohi
9b520312a1 fix(mcp): use tiktoken for response-size-guard token estimation (#39912) 2026-05-07 11:51:31 -04:00
Amin Ghadersohi
9ac4711ac8 fix(mcp): prevent DetachedInstanceError in get_chart_preview (#39921) 2026-05-07 11:44:11 -04:00
dependabot[bot]
7593d2a164 chore(deps): bump caniuse-lite from 1.0.30001791 to 1.0.30001792 in /docs (#39933)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 21:57:29 +07:00
dependabot[bot]
d3c44e311e chore(deps): bump aws-actions/amazon-ecr-login from 2.1.4 to 2.1.5 (#39931)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 21:54:59 +07:00
Enzo Martellucci
b5186d1c65 fix(reports): keep body sized so standalone screenshots don't time out (#39944) 2026-05-07 12:26:50 +02:00
55 changed files with 9327 additions and 3738 deletions

View File

@@ -58,7 +58,7 @@ jobs:
- name: Login to Amazon ECR
if: steps.describe-services.outputs.active == 'true'
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Delete ECR image tag
if: steps.describe-services.outputs.active == 'true'

View File

@@ -199,7 +199,7 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Load, tag and push image to ECR
id: push-image
@@ -235,7 +235,7 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Check target image exists in ECR
id: check-image

View File

@@ -70,7 +70,7 @@
"@swc/core": "^1.15.33",
"antd": "^6.3.7",
"baseline-browser-mapping": "^2.10.27",
"caniuse-lite": "^1.0.30001791",
"caniuse-lite": "^1.0.30001792",
"docusaurus-plugin-openapi-docs": "^5.0.2",
"docusaurus-theme-openapi-docs": "^5.0.2",
"js-yaml": "^4.1.1",

View File

@@ -141,6 +141,47 @@ def eval_node(node):
return "<f-string>"
return None
def static_return_bool(func_node):
    """
    Reduce a function definition to a constant bool, if it is that simple.

    The top-level statements of func_node are scanned in order. The first
    bare string-constant expression encountered is treated as a docstring
    and skipped, and every pass statement is skipped. What remains must be
    exactly one return statement; its value is handed to eval_node() and
    the result is returned only when it is a real bool.

    Returns:
        True or False when the single return evaluates to that literal,
        otherwise None (no return, several returns, a second bare string,
        any other statement, or a return value that is not a bool).

    Callers in this script use the result for has_implicit_cancel: only an
    explicit False causes the override to be ignored; None is treated like
    an ordinary override.
    """

    def _is_bare_string(stmt):
        # An expression statement whose value is a str constant.
        return (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        )

    only_return = None
    seen_docstring = False

    for stmt in func_node.body:
        if _is_bare_string(stmt) and not seen_docstring:
            # Only one bare string gets the docstring exemption; any later
            # one falls through to the final else and counts as real logic.
            seen_docstring = True
        elif isinstance(stmt, ast.Pass):
            pass
        elif isinstance(stmt, ast.Return):
            if only_return is not None:
                # More than one top-level return: not statically decidable
                # by this helper.
                return None
            only_return = stmt
        else:
            # if/for/assign/etc. -- be conservative and give up.
            return None

    if only_return is None:
        return None

    result = eval_node(only_return.value)
    if isinstance(result, bool):
        return result
    return None
def deep_merge(base, override):
"""Deep merge two dictionaries. Override values take precedence."""
if base is None:
@@ -186,8 +227,55 @@ if not os.path.isdir(specs_dir):
print(json.dumps({"error": f"Directory not found: {specs_dir}", "cwd": os.getcwd()}))
sys.exit(1)
# First pass: collect all class info (name, bases, metadata)
class_info = {} # class_name -> {bases: [], metadata: {}, engine_name: str, filename: str}
# Capability flag attributes with their defaults from BaseEngineSpec
CAP_ATTR_DEFAULTS = {
'supports_dynamic_schema': False,
'supports_catalog': False,
'supports_dynamic_catalog': False,
'disable_ssh_tunneling': False,
'supports_file_upload': True,
'allows_joins': True,
'allows_subqueries': True,
}
# Maps source capability attribute -> output field name used in databases.json.
# When a cap attr is assigned an unevaluable expression (e.g.
# allows_joins = is_feature_enabled("DRUID_JOINS")), the JS layer uses this
# mapping to preserve the corresponding field from the previously-generated
# JSON rather than silently inheriting an incorrect parent default.
CAP_ATTR_TO_OUTPUT_FIELD = {
'allows_joins': 'joins',
'allows_subqueries': 'subqueries',
'supports_dynamic_schema': 'supports_dynamic_schema',
'supports_catalog': 'supports_catalog',
'supports_dynamic_catalog': 'supports_dynamic_catalog',
'disable_ssh_tunneling': 'ssh_tunneling',
'supports_file_upload': 'supports_file_upload',
}
# Methods that indicate a capability when overridden by a non-BaseEngineSpec class.
# Mirrors the has_custom_method checks in superset/db_engine_specs/lib.py.
# cancel_query / has_implicit_cancel -> query_cancelation
# (diagnose() checks cancel_query override OR has_implicit_cancel() == True;
# base has_implicit_cancel returns False, so overriding it is the static
# equivalent of that method returning True. get_cancel_query_id is NOT
# part of the diagnose() heuristic and is intentionally excluded.)
# estimate_statement_cost / estimate_query_cost -> query_cost_estimation
# impersonate_user / update_impersonation_config / get_url_for_impersonation -> user_impersonation
# validate_sql -> sql_validation (not used yet; validation is engine-based)
CAP_METHODS = {
'cancel_query', 'has_implicit_cancel',
'estimate_statement_cost', 'estimate_query_cost',
'impersonate_user', 'update_impersonation_config', 'get_url_for_impersonation',
'validate_sql',
}
# Only the literal BaseEngineSpec is excluded from method-override tracking.
# Intermediate base classes (e.g. PrestoBaseEngineSpec) do count as overrides.
TRUE_BASE_CLASS = 'BaseEngineSpec'
# First pass: collect all class info (name, bases, metadata, cap_attrs, direct_methods)
class_info = {} # class_name -> {bases: [], metadata: {}, engine_name: str, filename: str, ...}
for filename in sorted(os.listdir(specs_dir)):
if not filename.endswith('.py') or filename in ('__init__.py', 'lib.py', 'lint_metadata.py'):
@@ -218,30 +306,54 @@ for filename in sorted(os.listdir(specs_dir)):
# Extract class attributes
engine_name = None
engine_attr = None
metadata = None
cap_attrs = {} # capability flag attributes defined directly in this class
# Cap attrs assigned via expressions we can't statically resolve
# (e.g. is_feature_enabled("FLAG")). Tracked so the JS layer can
# fall back to the previously-generated databases.json value
# rather than inherit a parent default that would be wrong.
unresolved_cap_attrs = set()
direct_methods = set() # capability methods defined directly in this class
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name):
if target.id == 'engine_name':
val = eval_node(item.value)
if isinstance(val, str):
engine_name = val
elif target.id == 'metadata':
metadata = eval_node(item.value)
if not isinstance(target, ast.Name):
continue
if target.id == 'engine_name':
val = eval_node(item.value)
if isinstance(val, str):
engine_name = val
elif target.id == 'engine':
val = eval_node(item.value)
if isinstance(val, str):
engine_attr = val
elif target.id == 'metadata':
metadata = eval_node(item.value)
elif target.id in CAP_ATTR_DEFAULTS:
val = eval_node(item.value)
if isinstance(val, bool):
cap_attrs[target.id] = val
else:
# Unevaluable expression — defer to JS fallback.
unresolved_cap_attrs.add(target.id)
elif isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
if item.name in CAP_METHODS:
# has_implicit_cancel is special: diagnose() uses the
# method's RETURN VALUE, not just its presence. If the
# override statically returns False, treat it as if
# the method weren't overridden so query_cancelation
# matches diagnose(). Unresolvable / True / anything
# else falls through as an override (conservative).
if item.name == 'has_implicit_cancel':
if static_return_bool(item) is False:
continue
direct_methods.add(item.name)
# Check for engine attribute with non-empty value to distinguish
# true base classes from product classes like OceanBaseEngineSpec
has_non_empty_engine = False
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name) and target.id == 'engine':
# Check if engine value is non-empty string
if isinstance(item.value, ast.Constant):
has_non_empty_engine = bool(item.value.value)
break
has_non_empty_engine = engine_attr is not None and bool(engine_attr)
# True base classes: end with BaseEngineSpec AND don't define engine
# or have empty engine (like PostgresBaseEngineSpec with engine = "")
@@ -254,13 +366,18 @@ for filename in sorted(os.listdir(specs_dir)):
'bases': base_names,
'metadata': metadata,
'engine_name': engine_name,
'engine': engine_attr,
'filename': filename,
'is_base_or_mixin': is_true_base,
'cap_attrs': cap_attrs,
'unresolved_cap_attrs': unresolved_cap_attrs,
'direct_methods': direct_methods,
}
except Exception as e:
errors.append(f"{filename}: {str(e)}")
# Second pass: resolve inheritance and build final metadata
# Second pass: resolve inheritance and build final metadata + capability flags
def get_inherited_metadata(class_name, visited=None):
"""Recursively get metadata from parent classes."""
if visited is None:
@@ -286,6 +403,64 @@ def get_inherited_metadata(class_name, visited=None):
return inherited
def get_resolved_caps(class_name, visited=None):
    """
    Resolve capability attributes and capability methods through inheritance.

    Looks class_name up in the module-level class_info table and walks its
    bases recursively, layering each class's data on top of what was
    collected from its ancestors.

    Args:
        class_name: name of the class to resolve.
        visited: names already seen on the current recursion path; used as
            a cycle guard. The set passed in is mutated (class_name is
            added); each base receives its own copy.

    Returns:
        A 3-tuple (attr_values, unresolved, methods):
        - attr_values: {attr: bool} for attributes whose winning assignment
          was a literal bool. Defaults are NOT filled in here.
        - unresolved: attributes whose winning assignment could not be
          evaluated statically. Within one layer an unresolved assignment
          beats a literal one, so the two results never share a key.
        - methods: union of direct_methods over this class and all of its
          ancestors, except the class named by TRUE_BASE_CLASS. Intended to
          mirror the has_custom_method() checks in db_engine_specs/lib.py.

        Unknown classes and classes already on the path yield three empty
        containers.

    NOTE: bases are applied right-to-left so the leftmost base wins, which
    agrees with Python's MRO for simple hierarchies. It is a depth-first
    overlay, not a C3 linearization, so diamond-shaped hierarchies may
    resolve differently from the runtime.
    """
    seen = set() if visited is None else visited
    if class_name in seen:
        return {}, set(), set()
    seen.add(class_name)

    info = class_info.get(class_name)
    if not info:
        return {}, set(), set()

    literals = {}
    pending = set()
    methods = set()

    def _overlay(layer_literals, layer_pending):
        # Literal assignments in the layer replace anything collected so far
        # and clear a previously-unresolved marker for the same attribute.
        for attr, val in layer_literals.items():
            literals[attr] = val
            pending.discard(attr)
        # Unresolved assignments are applied second, so they take precedence
        # over a literal for the same attribute.
        for attr in layer_pending:
            pending.add(attr)
            literals.pop(attr, None)

    # Rightmost base first, leftmost base last (so the leftmost wins).
    for parent in reversed(info['bases']):
        parent_literals, parent_pending, parent_methods = get_resolved_caps(
            parent, seen.copy()
        )
        _overlay(parent_literals, parent_pending)
        methods.update(parent_methods)

    # The class's own assignments override everything inherited.
    _overlay(info['cap_attrs'], info['unresolved_cap_attrs'])

    # Methods defined on the literal base class are not counted as overrides.
    if class_name != TRUE_BASE_CLASS:
        methods.update(info['direct_methods'])

    return literals, pending, methods
for class_name, info in class_info.items():
# Skip base classes and mixins
if info['is_base_or_mixin']:
@@ -310,7 +485,14 @@ for class_name, info in class_info.items():
if final_metadata and isinstance(final_metadata, dict) and display_name:
debug_info["classes_with_metadata"] += 1
databases[display_name] = {
# Resolve capability flags from Python source
attr_values, unresolved_caps, cap_methods = get_resolved_caps(class_name)
cap_attrs = dict(CAP_ATTR_DEFAULTS)
cap_attrs.update(attr_values)
engine_attr = info.get('engine') or ''
entry = {
'engine': display_name.lower().replace(' ', '_'),
'engine_name': display_name,
'module': info['filename'][:-3], # Remove .py extension
@@ -318,19 +500,40 @@ for class_name, info in class_info.items():
'time_grains': {},
'score': 0,
'max_score': 0,
'joins': True,
'subqueries': True,
'supports_dynamic_schema': False,
'supports_catalog': False,
'supports_dynamic_catalog': False,
'ssh_tunneling': False,
'query_cancelation': False,
'supports_file_upload': False,
'user_impersonation': False,
'query_cost_estimation': False,
'sql_validation': False,
# Capability flags read from engine spec class attributes/methods
'joins': cap_attrs['allows_joins'],
'subqueries': cap_attrs['allows_subqueries'],
'supports_dynamic_schema': cap_attrs['supports_dynamic_schema'],
'supports_catalog': cap_attrs['supports_catalog'],
'supports_dynamic_catalog': cap_attrs['supports_dynamic_catalog'],
'ssh_tunneling': not cap_attrs['disable_ssh_tunneling'],
'supports_file_upload': cap_attrs['supports_file_upload'],
# Method-based flags: True only when a non-base class overrides them.
# Matches diagnose() in lib.py: cancel_query override OR
# has_implicit_cancel() returning True (which, given the base
# returns False, is equivalent to overriding has_implicit_cancel).
'query_cancelation': bool({'cancel_query', 'has_implicit_cancel'} & cap_methods),
'query_cost_estimation': bool({'estimate_statement_cost', 'estimate_query_cost'} & cap_methods),
# SQL validation is implemented in external validator classes keyed by engine name
'sql_validation': engine_attr in {'presto', 'postgresql'},
'user_impersonation': bool(
{'impersonate_user', 'update_impersonation_config', 'get_url_for_impersonation'} & cap_methods
),
}
# Tell the JS layer which output fields were populated from the
# BaseEngineSpec default because the source assignment was an
# unevaluable expression; those get overridden from existing JSON.
unresolved_fields = sorted(
CAP_ATTR_TO_OUTPUT_FIELD[attr]
for attr in unresolved_caps
if attr in CAP_ATTR_TO_OUTPUT_FIELD
)
if unresolved_fields:
entry['_unresolved_cap_fields'] = unresolved_fields
databases[display_name] = entry
if errors and not databases:
print(json.dumps({"error": "Parse errors", "details": errors, "debug": debug_info}), file=sys.stderr)
@@ -851,24 +1054,52 @@ function loadExistingData() {
}
}
/**
 * Restore capability flags that the static extractor could not determine.
 *
 * Each entry in `newDatabases` may carry an internal `_unresolved_cap_fields`
 * array naming output fields whose source assignment was an expression the
 * Python extractor could not evaluate (e.g.
 * `allows_joins = is_feature_enabled("DRUID_JOINS")`). For every such field,
 * the value from the entry of the same name in `existingData.databases` is
 * copied over when that value is defined; otherwise the freshly generated
 * value is left alone.
 *
 * The `_unresolved_cap_fields` marker is removed from every entry, whether or
 * not anything was restored, so it never reaches the generated output.
 *
 * @param {Object} newDatabases - freshly generated entries keyed by database
 *   name; mutated in place.
 * @param {Object|null|undefined} existingData - previously generated data;
 *   may be missing or lack a `databases` key.
 * @returns {Object} the same `newDatabases` object that was passed in.
 */
function fallbackUnresolvedCaps(newDatabases, existingData) {
  const previousDatabases = existingData?.databases;

  Object.entries(newDatabases).forEach(([name, entry]) => {
    const pendingFields = entry._unresolved_cap_fields;
    const hasPending = Boolean(pendingFields) && pendingFields.length !== 0;
    const previousEntry = hasPending ? previousDatabases?.[name] : undefined;

    if (previousEntry) {
      for (const field of pendingFields) {
        const previousValue = previousEntry[field];
        if (previousValue !== undefined) {
          entry[field] = previousValue;
        }
      }
    }

    // Strip the internal marker last so it is gone in every case.
    delete entry._unresolved_cap_fields;
  });

  return newDatabases;
}
/**
* Merge new documentation with existing diagnostics
* Preserves score, time_grains, and feature flags from existing data
* Preserves score, max_score, and time_grains from existing data (these require
* Flask context to generate and cannot be derived from static source analysis).
* Capability flags (joins, supports_catalog, etc.) are NOT preserved here — they
* are read fresh from the Python engine spec source by extractEngineSpecMetadata(),
* with a separate fallback for expression-based assignments (see fallbackUnresolvedCaps).
*/
function mergeWithExistingDiagnostics(newDatabases, existingData) {
if (!existingData?.databases) return newDatabases;
const diagnosticFields = [
'score', 'max_score', 'time_grains', 'joins', 'subqueries',
'supports_dynamic_schema', 'supports_catalog', 'supports_dynamic_catalog',
'ssh_tunneling', 'query_cancelation', 'supports_file_upload',
'user_impersonation', 'query_cost_estimation', 'sql_validation'
];
// Only preserve fields that require Flask/runtime context to generate
const diagnosticFields = ['score', 'max_score', 'time_grains'];
for (const [name, db] of Object.entries(newDatabases)) {
const existingDb = existingData.databases[name];
if (existingDb && existingDb.score > 0) {
// Preserve diagnostics from existing data
// Preserve score/time_grain diagnostics from existing data
for (const field of diagnosticFields) {
if (existingDb[field] !== undefined) {
db[field] = existingDb[field];
@@ -879,7 +1110,7 @@ function mergeWithExistingDiagnostics(newDatabases, existingData) {
const preserved = Object.values(newDatabases).filter(d => d.score > 0).length;
if (preserved > 0) {
console.log(`Preserved diagnostics for ${preserved} databases from existing data`);
console.log(`Preserved score/time_grains for ${preserved} databases from existing data`);
}
return newDatabases;
@@ -927,6 +1158,12 @@ async function main() {
databases = mergeWithExistingDiagnostics(databases, existingData);
}
// For cap flags assigned via unevaluable expressions (e.g.
// `is_feature_enabled(...)`), prefer the value from a previously-generated
// JSON. Runs regardless of scores since it addresses static-analysis gaps,
// not missing Flask diagnostics. Always strips the internal marker.
databases = fallbackUnresolvedCaps(databases, existingData);
// Extract and merge custom_errors for troubleshooting documentation
const customErrors = extractCustomErrors();
mergeCustomErrors(databases, customErrors);

File diff suppressed because it is too large Load Diff

View File

@@ -6035,10 +6035,10 @@ caniuse-api@^3.0.0:
lodash.memoize "^4.1.2"
lodash.uniq "^4.5.0"
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001791:
version "1.0.30001791"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001791.tgz#dfb93d85c40ad380c57123e72e10f3c575786b51"
integrity sha512-yk0l/YSrOnFZk3UROpDLQD9+kC1l4meK/wed583AXrzoarMGJcbRi2Q4RaUYbKxYAsZ8sWmaSa/DsLmdBeI1vQ==
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001792:
version "1.0.30001792"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001792.tgz#ca8bb9be244835a335e2018272ce7223691873c5"
integrity sha512-hVLMUZFgR4JJ6ACt1uEESvQN1/dBVqPAKY0hgrV70eN3391K6juAfTjKZLKvOMsx8PxA7gsY1/tLMMTcfFLLpw==
ccount@^2.0.0:
version "2.0.1"

View File

@@ -142,10 +142,16 @@ druid = ["pydruid>=0.6.5,<0.7"]
duckdb = ["duckdb>=1.4.2,<2", "duckdb-engine>=0.17.0"]
dynamodb = ["pydynamodb>=0.4.2"]
solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.13, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
fastmcp = ["fastmcp>=3.2.4,<4.0"]
fastmcp = [
"fastmcp>=3.2.4,<4.0",
# tiktoken backs the response-size-guard token estimator. Without
# it, the middleware falls back to a coarser character-based
# heuristic that under-counts JSON-heavy MCP responses.
"tiktoken>=0.7.0,<1.0",
]
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
gevent = ["gevent>=23.9.1"]

View File

@@ -183,7 +183,9 @@ idna==3.10
# trio
# url-normalize
isodate==0.7.2
# via apache-superset (pyproject.toml)
# via
# apache-superset (pyproject.toml)
# apache-superset-core
itsdangerous==2.2.0
# via
# flask
@@ -296,6 +298,7 @@ pyarrow==20.0.0
# via
# -r requirements/base.in
# apache-superset (pyproject.toml)
# apache-superset-core
pyasn1==0.6.3
# via
# pyasn1-modules

View File

@@ -442,6 +442,7 @@ isodate==0.7.2
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
isort==6.0.1
# via pylint
itsdangerous==2.2.0
@@ -715,6 +716,7 @@ pyarrow==20.0.0
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
# db-dtypes
# pandas-gbq
pyasn1==0.6.3
@@ -866,6 +868,8 @@ referencing==0.36.2
# jsonschema
# jsonschema-path
# jsonschema-specifications
regex==2026.4.4
# via tiktoken
requests==2.33.0
# via
# -c requirements/base-constraint.txt
@@ -878,6 +882,7 @@ requests==2.33.0
# requests-cache
# requests-oauthlib
# shillelagh
# tiktoken
# trino
requests-cache==1.2.1
# via
@@ -1003,6 +1008,8 @@ tabulate==0.9.0
# via
# -c requirements/base-constraint.txt
# apache-superset
tiktoken==0.12.0
# via apache-superset
tomli-w==1.2.0
# via apache-superset-extensions-cli
tomlkit==0.13.3

View File

@@ -66,7 +66,7 @@ export type EmbedDashboardParams = {
iframeTitle?: string;
/** additional iframe sandbox attributes ex (allow-top-navigation, allow-popups-to-escape-sandbox) **/
iframeSandboxExtras?: string[];
/** iframe allow attribute for Permissions Policy (e.g., ['clipboard-write', 'fullscreen']) **/
/** Additional Permissions Policy features for the iframe's `allow` attribute (e.g., ['camera', 'microphone']). `fullscreen` and `clipboard-write` are granted by default. **/
iframeAllowExtras?: string[];
/** force a specific refererPolicy to be used in the iframe request **/
referrerPolicy?: ReferrerPolicy;
@@ -233,9 +233,14 @@ export async function embedDashboard({
iframe.src = `${supersetDomain}/embedded/${id}${urlParamsString}`;
iframe.title = iframeTitle;
iframe.style.background = 'transparent';
if (iframeAllowExtras.length > 0) {
iframe.setAttribute('allow', iframeAllowExtras.join('; '));
}
// Permissions Policy features the embedded dashboard relies on. Modern
// browsers gate these APIs on the iframe's `allow` attribute regardless
// of sandbox flags, so we include them by default. Host apps can extend
// the list via `iframeAllowExtras`.
const allowFeatures = Array.from(
new Set(['fullscreen', 'clipboard-write', ...iframeAllowExtras]),
);
iframe.setAttribute('allow', allowFeatures.join('; '));
//@ts-ignore
mountPoint.replaceChildren(iframe);
log('placed the iframe');

View File

@@ -494,6 +494,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'column_config',
@@ -587,18 +593,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'show_cell_bars',
config: {
type: 'CheckboxControl',
label: t('Show cell bars'),
label: t('Show cell bars for all columns'),
renderTrigger: true,
default: true,
description: t(
@@ -612,7 +612,7 @@ const config: ControlPanelConfig = {
name: 'align_pn',
config: {
type: 'CheckboxControl',
label: t('Align +/-'),
label: t('Align +/- for all columns'),
renderTrigger: true,
default: false,
description: t(
@@ -626,7 +626,7 @@ const config: ControlPanelConfig = {
name: 'color_pn',
config: {
type: 'CheckboxControl',
label: t('Add colors to cell bars for +/-'),
label: t('Add colors to cell bars for +/- for all columns'),
renderTrigger: true,
default: true,
description: t(

View File

@@ -552,6 +552,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'column_config',
@@ -648,18 +654,12 @@ const config: ControlPanelConfig = {
},
},
],
],
},
{
label: t('Visual formatting'),
expanded: true,
controlSetRows: [
[
{
name: 'show_cell_bars',
config: {
type: 'CheckboxControl',
label: t('Show cell bars'),
label: t('Show cell bars for all columns'),
renderTrigger: true,
default: true,
description: t(
@@ -673,7 +673,7 @@ const config: ControlPanelConfig = {
name: 'align_pn',
config: {
type: 'CheckboxControl',
label: t('Align +/-'),
label: t('Align +/- for all columns'),
renderTrigger: true,
default: false,
description: t(
@@ -687,7 +687,7 @@ const config: ControlPanelConfig = {
name: 'color_pn',
config: {
type: 'CheckboxControl',
label: t('Add colors to cell bars for +/-'),
label: t('Add colors to cell bars for +/- for all columns'),
renderTrigger: true,
default: true,
description: t(

View File

@@ -17,7 +17,8 @@
* under the License.
*/
import { render, screen, act } from 'spec/helpers/testing-library';
import { StatusIndicatorDot } from './StatusIndicatorDot';
import { supersetTheme } from '@apache-superset/core/theme';
import { getStatusConfig, StatusIndicatorDot } from './StatusIndicatorDot';
import { AutoRefreshStatus } from '../../types/autoRefresh';
afterEach(() => {
@@ -62,6 +63,15 @@ test('renders with paused status', () => {
expect(dot).toHaveAttribute('data-status', AutoRefreshStatus.Paused);
});
test('uses the icon color for the paused status outline', () => {
expect(
getStatusConfig(supersetTheme, AutoRefreshStatus.Paused),
).toMatchObject({
needsBorder: true,
outlineColor: 'currentColor',
});
});
test('has correct accessibility attributes', () => {
render(<StatusIndicatorDot status={AutoRefreshStatus.Success} />);
const dot = screen.getByTestId('status-indicator-dot');

View File

@@ -39,9 +39,10 @@ export interface StatusIndicatorDotProps {
interface StatusConfig {
color: string;
needsBorder: boolean;
outlineColor?: string;
}
const getStatusConfig = (
export const getStatusConfig = (
theme: ReturnType<typeof useTheme>,
status: AutoRefreshStatus,
): StatusConfig => {
@@ -75,6 +76,7 @@ const getStatusConfig = (
return {
color: theme.colorBgContainer,
needsBorder: true,
outlineColor: 'currentColor',
};
default:
return {
@@ -136,13 +138,15 @@ export const StatusIndicatorDot: FC<StatusIndicatorDotProps> = ({
width: ${size}px;
height: ${size}px;
border-radius: 50%;
color: ${theme.colorTextSecondary};
background-color: ${statusConfig.color};
transition:
background-color ${theme.motionDurationMid} ease-in-out,
border-color ${theme.motionDurationMid} ease-in-out;
border: ${statusConfig.needsBorder
? `1px solid ${theme.colorBorder}`
: 'none'};
border: ${statusConfig.needsBorder ? '1px solid' : 'none'};
border-color: ${statusConfig.needsBorder
? statusConfig.outlineColor
: 'transparent'};
box-shadow: ${statusConfig.needsBorder
? 'none'
: `0 0 0 2px ${theme.colorBgContainer}`};

View File

@@ -21,6 +21,10 @@ import { VizType } from '@superset-ui/core';
import { hydrateExplore, HYDRATE_EXPLORE } from './hydrateExplore';
import { exploreInitialData } from '../fixtures';
afterEach(() => {
window.history.pushState({}, '', '/');
});
test('creates hydrate action from initial data', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({
@@ -168,6 +172,84 @@ test('creates hydrate action with existing state', () => {
);
});
test('hydrates sliceName from preview form data before saved slice name', () => {
window.history.pushState({}, '', '/explore/?form_data_key=preview-key');
const dispatch = jest.fn();
const getState = jest.fn(() => ({
user: {},
charts: {},
datasources: {},
common: {},
explore: {},
}));
const previewSliceName = 'RENAMED - Bug Evidence';
const savedSliceName = 'Most Populated Countries';
const previewInitialData = {
...exploreInitialData,
form_data: {
...exploreInitialData.form_data,
slice_name: previewSliceName,
},
slice: {
...exploreInitialData.slice!,
slice_name: savedSliceName,
},
};
// @ts-expect-error we only need the fields consumed by hydrateExplore
hydrateExplore(previewInitialData)(dispatch, getState);
expect(dispatch).toHaveBeenCalledWith(
expect.objectContaining({
type: HYDRATE_EXPLORE,
data: expect.objectContaining({
explore: expect.objectContaining({
sliceName: previewSliceName,
}),
}),
}),
);
});
test('hydrates sliceName from saved slice when regular form data has stale name', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({
user: {},
charts: {},
datasources: {},
common: {},
explore: {},
}));
const staleFormDataSliceName = 'Stale Params Name';
const savedSliceName = 'Current Saved Name';
const savedChartInitialData = {
...exploreInitialData,
form_data: {
...exploreInitialData.form_data,
slice_name: staleFormDataSliceName,
},
slice: {
...exploreInitialData.slice!,
slice_name: savedSliceName,
},
};
// @ts-expect-error we only need the fields consumed by hydrateExplore
hydrateExplore(savedChartInitialData)(dispatch, getState);
expect(dispatch).toHaveBeenCalledWith(
expect.objectContaining({
type: HYDRATE_EXPLORE,
data: expect.objectContaining({
explore: expect.objectContaining({
sliceName: savedSliceName,
}),
}),
}),
);
});
test('uses configured default time range if not set', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({

View File

@@ -77,6 +77,12 @@ export const hydrateExplore =
const fallbackSlice = sliceId ? sliceEntities?.slices?.[sliceId] : null;
const initialSlice = slice ?? fallbackSlice;
const initialFormData = form_data ?? initialSlice?.form_data;
const isCachedFormData = getUrlParam(URL_PARAMS.formDataKey) !== null;
const [primarySliceNameSource, fallbackSliceNameSource] = isCachedFormData
? [initialFormData, initialSlice]
: [initialSlice, initialFormData];
const initialSliceName =
primarySliceNameSource?.slice_name ?? fallbackSliceNameSource?.slice_name;
if (!initialFormData.viz_type) {
const defaultVizType = common?.conf.DEFAULT_VIZ_TYPE || VizType.Table;
initialFormData.viz_type =
@@ -183,6 +189,7 @@ export const hydrateExplore =
// because `bootstrapData.controls` is undefined.
controls: initialControls,
form_data: initialFormData,
sliceName: initialSliceName,
slice: initialSlice,
controlsTransferred: explore.controlsTransferred,
standalone: getUrlParam(URL_PARAMS.standalone),

View File

@@ -179,6 +179,33 @@ test('renders the right footer buttons', () => {
).toBeInTheDocument();
});
test('initializes chart name from current Explore slice name', () => {
  const previewSliceName = 'RENAMED - Bug Evidence';
  const savedSliceName = 'Most Populated Countries';

  // Props carry the in-progress (preview) name both in form_data and in the
  // dedicated sliceName prop.
  const propsWithPreviewName = {
    ...defaultProps,
    form_data: {
      ...defaultProps.form_data,
      slice_name: previewSliceName,
    },
    sliceName: previewSliceName,
  };

  // The store still holds the previously saved name on the slice itself.
  const storeWithSavedName = mockStore({
    ...initialState,
    explore: {
      ...initialState.explore,
      slice: {
        ...initialState.explore.slice,
        slice_name: savedSliceName,
      },
    },
  });

  const { getByTestId } = setup(propsWithPreviewName, storeWithSavedName);

  // The name input must be seeded with the preview name, not the saved one.
  expect(getByTestId('new-chart-name')).toHaveValue(previewSliceName);
});
test('does not render a message when overriding', () => {
const { getByRole, queryByRole } = setup();

View File

@@ -35,7 +35,6 @@ import { CheckboxChangeEvent } from '@superset-ui/core/components/Checkbox/types
import { useHistory } from 'react-router-dom';
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
import { makeUrl } from 'src/utils/pathUtils';
import Tabs from '@superset-ui/core/components/Tabs';
import {
Button,
@@ -1824,7 +1823,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
onClick={() => {
setLoading(true);
fetchAndSetDB();
redirectURL(makeUrl(`/sqllab?db=true`));
// redirectURL() delegates to history.push; React Router's basename
// already prefixes the application root, so pass a relative path.
redirectURL('/sqllab?db=true');
}}
>
{t('Query data in SQL Lab')}

View File

@@ -24,7 +24,6 @@ import { TableTab } from 'src/views/CRUD/types';
import { t } from '@apache-superset/core/translation';
import { styled } from '@apache-superset/core/theme';
import { navigateTo } from 'src/utils/navigationUtils';
import { makeUrl } from 'src/utils/pathUtils';
import { WelcomeTable } from './types';
const EmptyContainer = styled.div`
@@ -59,7 +58,9 @@ const REDIRECTS = {
create: {
[WelcomeTable.Charts]: '/chart/add',
[WelcomeTable.Dashboards]: '/dashboard/new',
[WelcomeTable.SavedQueries]: makeUrl('/sqllab?new=true'),
// navigateTo() applies the application root internally; keep this
// relative so the prefix isn't added twice.
[WelcomeTable.SavedQueries]: '/sqllab?new=true',
},
viewAll: {
[WelcomeTable.Charts]: '/chart/list',

View File

@@ -44,7 +44,7 @@ import {
TelemetryPixel,
} from '@superset-ui/core/components';
import type { ItemType, MenuItem } from '@superset-ui/core/components/Menu';
import { ensureAppRoot, makeUrl } from 'src/utils/pathUtils';
import { ensureAppRoot } from 'src/utils/pathUtils';
import { isEmbedded } from 'src/dashboard/util/isEmbedded';
import { findPermission } from 'src/utils/findPermission';
import { isUserAdmin } from 'src/dashboard/util/permissionUtils';
@@ -213,7 +213,10 @@ const RightMenu = ({
},
{
label: t('SQL query'),
url: makeUrl('/sqllab?new=true'),
// Keep the URL relative so isFrontendRoute() matches and Link navigates
// via React Router; the <Typography.Link> fallback applies ensureAppRoot
// exactly once for non-frontend routes.
url: '/sqllab?new=true',
icon: <Icons.SearchOutlined data-test={`menu-item-${t('SQL query')}`} />,
perm: 'can_sqllab',
view: 'Superset',

View File

@@ -25,11 +25,20 @@ import {
fireEvent,
waitFor,
} from 'spec/helpers/testing-library';
import { MemoryRouter } from 'react-router-dom';
import { MemoryRouter, useLocation } from 'react-router-dom';
import { QueryParamProvider } from 'use-query-params';
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
import * as getBootstrapData from 'src/utils/getBootstrapData';
import SavedQueryList from '.';
// Test-only probe: prints the router's current pathname followed by its
// search string so a test can assert where navigation ended up.
function LocationDisplay() {
  const { pathname, search } = useLocation();
  const currentUrl = `${pathname}${search}`;
  return <div data-test="location-display">{currentUrl}</div>;
}
// Increase default timeout
jest.setTimeout(30000);
@@ -88,6 +97,7 @@ const renderList = (props = {}, storeOverrides = {}) =>
<MemoryRouter>
<QueryParamProvider adapter={ReactRouter5Adapter}>
<SavedQueryList user={mockUser} {...props} />
<LocationDisplay />
</QueryParamProvider>
</MemoryRouter>,
{
@@ -242,4 +252,39 @@ describe('SavedQueryList', () => {
// Verify delete buttons are not shown
expect(screen.queryByTestId('delete-action')).not.toBeInTheDocument();
});
test('"+ Query" button pushes a router-relative path (subdirectory deployment)', async () => {
  // Pretend the app is deployed with SUPERSET_APP_ROOT=/superset. Because
  // ensureAppRoot/makeUrl look up applicationRoot() at call time, stubbing it
  // makes a buggy `history.push(makeUrl(...))` yield
  // '/superset/sqllab?new=true' rather than silently behaving like a no-op.
  // React Router's <Router basename> adds the app root by itself, so the
  // value given to history.push MUST NOT already contain it; otherwise the
  // user lands on /superset/superset/sqllab and sees a blank page
  // (sc-103661).
  const appRootStub = jest
    .spyOn(getBootstrapData, 'applicationRoot')
    .mockReturnValue('/superset');

  try {
    renderList();
    await screen.findByTestId('saved_query-list-view');

    const addQueryButton = await screen.findByRole('button', {
      name: /query/i,
    });
    fireEvent.click(addQueryButton);

    await waitFor(() => {
      // renderList's MemoryRouter keeps the default '/' basename, which means
      // useLocation mirrors the exact argument of history.push:
      // '/sqllab?new=true' when the push is router-relative, and
      // '/superset/sqllab?new=true' if the app root was applied again.
      const renderedLocation =
        screen.getByTestId('location-display').textContent;
      expect(renderedLocation).toBe('/sqllab?new=true');
    });
  } finally {
    // Always undo the stub so later tests see the real applicationRoot().
    appRootStub.mockRestore();
  }
});
});

View File

@@ -223,7 +223,9 @@ function SavedQueryList({
name: t('Query'),
buttonStyle: 'primary',
onClick: () => {
history.push(makeUrl('/sqllab?new=true'));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push('/sqllab?new=true');
},
});
@@ -245,7 +247,9 @@ function SavedQueryList({
if (openInNewWindow) {
window.open(makeUrl(`/sqllab?savedQueryId=${id}`));
} else {
history.push(makeUrl(`/sqllab?savedQueryId=${id}`));
// React Router's basename already includes the application root; passing
// a relative path ensures correct navigation under subdirectory deployments.
history.push(`/sqllab?savedQueryId=${id}`);
}
};
@@ -338,9 +342,7 @@ function SavedQueryList({
row: {
original: { id, label },
},
}: any) => (
<Link to={makeUrl(`/sqllab?savedQueryId=${id}`)}>{label}</Link>
),
}: any) => <Link to={`/sqllab?savedQueryId=${id}`}>{label}</Link>,
id: 'label',
},
{

View File

@@ -557,6 +557,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# if True, database will be listed as option in the upload file form
supports_file_upload = True
# RLS strategy for this engine spec. Override in engine-specific classes as
# needed (for example ``RLSMethod.AS_PREDICATE`` for engines that don't
# support subquery-based RLS, or ``RLSMethod.AS_PREDICATE_SPLICE`` for
# engines where sqlglot generation is not faithful).
rls_method = RLSMethod.AS_SUBQUERY
# Is the DB engine spec able to change the default schema? This requires implementing # noqa: E501
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False
@@ -618,21 +624,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
else cls.encrypted_extra_sensitive_fields
)
@classmethod
def get_rls_method(cls) -> RLSMethod:
    """
    Pick the row-level-security (RLS) injection strategy for this engine.

    RLS can be applied either by swapping the table for a subquery that
    already carries the RLS filter, or by appending the filter to the
    ``WHERE`` clause. The subquery form is the safer of the two but not every
    database supports it.

    :returns: ``RLSMethod.AS_SUBQUERY`` when the spec sets both
        ``allows_subqueries`` and ``allows_alias_in_select``; otherwise
        ``RLSMethod.AS_PREDICATE``.
    """
    # Both capability flags must be set before the subquery form is chosen.
    can_use_subquery = cls.allows_subqueries and cls.allows_alias_in_select
    if can_use_subquery:
        return RLSMethod.AS_SUBQUERY
    return RLSMethod.AS_PREDICATE
@classmethod
def is_oauth2_enabled(cls) -> bool:
return (

View File

@@ -35,6 +35,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.network import is_hostname_valid, is_port_open
@@ -82,6 +83,7 @@ class CouchbaseEngineSpec(BasicParametersMixin, BaseEngineSpec):
default_driver = "couchbase"
allows_joins = False
allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
sqlalchemy_uri_placeholder = (
"couchbase://user:password@host[:port]?truststorepath=value?ssl=value"
)

View File

@@ -23,6 +23,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@@ -68,6 +69,8 @@ class CrateEngineSpec(BaseEngineSpec):
TimeGrain.YEAR: "DATE_TRUNC('year', {col})",
}
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod
def epoch_to_dttm(cls) -> str:
return "{col} * 1000"

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import (
)
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.core import GenericDataType
from superset.utils.hashing import hash_from_str
from superset.utils.network import is_hostname_valid, is_port_open
@@ -55,6 +56,8 @@ class DatabendBaseEngineSpec(BaseEngineSpec):
time_secondary_columns = True
time_groupby_inline = True
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",

View File

@@ -26,6 +26,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory,
)
from superset.errors import SupersetErrorType
from superset.sql.parse import RLSMethod
# Internal class for defining error message patterns (for translation)
@@ -58,6 +59,8 @@ class DenodoEngineSpec(BaseEngineSpec, BasicParametersMixin):
engine = "denodo"
engine_name = "Denodo"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
default_driver = "psycopg2"
sqlalchemy_uri_placeholder = (
"denodo://user:password@host:port/dbname[?key=value&key=value...]"

View File

@@ -21,12 +21,15 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class DynamoDBEngineSpec(BaseEngineSpec):
engine = "dynamodb"
engine_name = "Amazon DynamoDB"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (
"Amazon DynamoDB is a serverless NoSQL database with SQL via PartiQL."

View File

@@ -28,6 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import RLSMethod
logger = logging.getLogger()
@@ -39,6 +40,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
allows_joins = False
allows_subqueries = True
allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (

View File

@@ -21,7 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import LimitMethod
from superset.sql.parse import LimitMethod, RLSMethod
class FirebirdEngineSpec(BaseEngineSpec):
@@ -53,6 +53,8 @@ class FirebirdEngineSpec(BaseEngineSpec):
# Firebird uses FIRST to limit: `SELECT FIRST 10 * FROM table`
limit_method = LimitMethod.FETCH_MANY
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: (

View File

@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.sql.parse import RLSMethod
from .db2 import Db2EngineSpec
@@ -28,6 +30,8 @@ class IBMiEngineSpec(Db2EngineSpec):
engine_name = "IBM Db2 for i"
max_column_name_length = 128
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod
def epoch_to_dttm(cls) -> str:
return "(DAYS({col}) - DAYS('1970-01-01')) * 86400 + MIDNIGHT_SECONDS({col})"

View File

@@ -28,7 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import LimitMethod
from superset.sql.parse import LimitMethod, RLSMethod
from superset.utils.core import GenericDataType
@@ -40,6 +40,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
allows_joins = True
allows_subqueries = True
allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (

View File

@@ -21,6 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -29,6 +30,8 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
engine = "kylin"
engine_name = "Apache Kylin"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": "Apache Kylin is an open-source OLAP engine for big data.",
"logo": "apache-kylin.png",

View File

@@ -753,6 +753,15 @@ def generate_yaml_docs(output_dir: str | None = None) -> dict[str, dict[str, Any
continue
name = get_name(spec)
# Skip "base" specs (e.g. PostgresBaseEngineSpec) that share an engine_name
# with a real product spec but have no concrete engine value. When multiple
# specs share the same engine_name the one with a non-empty engine string is
# the authoritative product spec; letting a base class overwrite it would
# produce incorrect capability flags.
if not spec.engine and name in all_docs:
continue
doc_data = diagnose(spec)
# Get documentation metadata (prefers spec.metadata over DATABASE_DOCS)
@@ -766,6 +775,7 @@ def generate_yaml_docs(output_dir: str | None = None) -> dict[str, dict[str, Any
doc_data["supports_file_upload"] = spec.supports_file_upload
doc_data["supports_dynamic_schema"] = spec.supports_dynamic_schema
doc_data["supports_catalog"] = spec.supports_catalog
doc_data["supports_dynamic_catalog"] = spec.supports_dynamic_catalog
all_docs[name] = doc_data

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sql.parse import RLSMethod
# Regular expressions to catch custom errors
@@ -227,6 +228,8 @@ class OcientEngineSpec(BaseEngineSpec):
force_column_alias_quotes = True
max_column_name_length = 30
rls_method = RLSMethod.AS_PREDICATE_SPLICE
allows_cte_in_subquery = False
# Ocient does not support cte names starting with underscores
cte_alias = "cte__"

View File

@@ -20,6 +20,7 @@ from sqlalchemy.types import TypeEngine
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class PinotEngineSpec(BaseEngineSpec):
@@ -30,6 +31,7 @@ class PinotEngineSpec(BaseEngineSpec):
allows_joins = False
allows_alias_in_select = False
allows_alias_in_orderby = False
rls_method = RLSMethod.AS_PREDICATE
# pinotdb only sets cursor.description when the response contains
# columnDataTypes, which Pinot omits for zero-row results.

View File

@@ -16,6 +16,7 @@
# under the License.
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -27,6 +28,7 @@ class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
time_groupby_inline = False
allows_joins = False
allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
metadata = {
"description": "Apache Solr is an open-source enterprise search platform.",

View File

@@ -22,6 +22,7 @@ from urllib import parse
from sqlalchemy.engine.url import make_url, URL # noqa: F401
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class TDengineEngineSpec(BaseEngineSpec):
@@ -29,6 +30,8 @@ class TDengineEngineSpec(BaseEngineSpec):
engine_name = "TDengine"
max_column_name_length = 64
default_driver = "taosws"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
sqlalchemy_uri_placeholder = (
"taosws://user:******@host:port/dbname[?key=value&key=value...]"
)

View File

@@ -28,7 +28,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.chart.ascii_charts import (
generate_ascii_chart,
generate_ascii_table,
@@ -1140,6 +1140,15 @@ async def _get_chart_preview_internal( # noqa: C901
)
chart = find_chart_by_identifier(request.identifier)
# Eagerly refresh all attributes while the session is still
# active. SQLAlchemy expires object attributes after any
# commit; if a downstream operation commits before the strategy
# classes access chart attributes, a DetachedInstanceError will
# be raised. Calling refresh() here ensures all column values
# are loaded into the object's __dict__ upfront.
if chart is not None:
db.session.refresh(chart)
# If not found and looks like a form_data_key, try transient
if (
not chart
@@ -1371,6 +1380,20 @@ async def _get_chart_preview_internal( # noqa: C901
return _sanitize_chart_preview_for_llm_context(result)
except SQLAlchemyError as e:
# Catch DetachedInstanceError and other SQLAlchemy errors that can
# surface when the ORM session expires or commits mid-request.
await ctx.error(
"Chart preview failed due to database session error: "
"identifier=%s, error_type=%s, error=%s"
% (request.identifier, type(e).__name__, str(e))
)
logger.exception("SQLAlchemy error in get_chart_preview: %s", e)
return ChartError(
error="Database session error while generating chart preview. "
"Please retry the request.",
error_type="InternalError",
)
except (
CommandException,
SupersetException,

View File

@@ -41,6 +41,12 @@ from superset.mcp_service.constants import (
DEFAULT_TOKEN_LIMIT,
DEFAULT_WARN_THRESHOLD_PCT,
)
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
format_size_limit_error,
INFO_TOOLS,
truncate_oversized_response,
)
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
@@ -1104,11 +1110,6 @@ class ResponseSizeGuardMiddleware(Middleware):
``content[0].text`` as a JSON string. We parse that string, run the
truncation phases on the resulting dict, then re-wrap the result.
"""
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
truncate_oversized_response,
)
# Unwrap ToolResult so truncation operates on the real payload
extracted = self._extract_payload_from_tool_result(response)
if extracted is not None:
@@ -1191,12 +1192,6 @@ class ResponseSizeGuardMiddleware(Middleware):
# Execute the tool
response = await call_next(context)
# Estimate response token count (guard against huge responses causing OOM)
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
format_size_limit_error,
)
# When the response is a ToolResult, estimate tokens on the actual
# payload inside content[0].text rather than on the ToolResult
# wrapper (which would double-serialize the JSON string).
@@ -1233,8 +1228,6 @@ class ResponseSizeGuardMiddleware(Middleware):
params = getattr(context.message, "params", {}) or {}
# For info tools, try dynamic truncation before blocking
from superset.mcp_service.utils.token_utils import INFO_TOOLS
if tool_name in INFO_TOOLS:
truncated = self._try_truncate_info_response(
tool_name, response, estimated_tokens

View File

@@ -21,6 +21,26 @@ Token counting and response size utilities for MCP service.
This module provides utilities to estimate token counts and generate smart
suggestions when responses exceed configured limits. This prevents large
responses from overwhelming LLM clients like Claude Desktop.
Token counting strategy:
1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is
installed (it is shipped as part of the ``fastmcp`` extra). This is a
real BPE tokenizer trained on a similar vocabulary to Claude's; for
English and JSON-heavy MCP payloads it tracks Claude's tokenizer
within roughly ±10%, which is far more accurate than the legacy
character heuristic.
2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not
importable. The fallback uses a slightly more conservative ratio than
before (3.0 chars/token instead of 3.5) so that JSON-heavy responses
are not under-counted, which previously let oversized payloads slip
past the response-size guard.
The exact-Claude tokenizer is only available via Anthropic's network
``count_tokens`` API; calling it from a synchronous middleware on every
tool result is too slow and adds an external dependency on every
response. ``tiktoken`` is the closest approximation we can ship without
that risk.
"""
from __future__ import annotations
@@ -36,18 +56,63 @@ logger = logging.getLogger(__name__)
# Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]
# Approximate characters per token for estimation
# Claude tokenizer averages ~4 chars per token for English text
# JSON tends to be more verbose, so we use a slightly lower ratio
CHARS_PER_TOKEN = 3.5
# Fallback character-to-token ratio used when tiktoken is unavailable.
# 3.0 is conservative for JSON content (the previous 3.5 under-counted
# JSON-heavy payloads relative to Claude's actual tokenizer, which let
# oversized responses slip past the response-size guard).
CHARS_PER_TOKEN = 3.0
# Encoding used when tiktoken is available. cl100k_base is OpenAI's
# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to
# Claude's and tracks Claude's token counts within roughly ±10% for
# English and JSON-heavy MCP responses.
_TIKTOKEN_ENCODING_NAME = "cl100k_base"
def _load_tiktoken_encoding() -> Any:
    """Return a tiktoken encoding instance, or ``None`` if one cannot be loaded.

    ``tiktoken`` is imported inside the function so this module stays
    importable in environments where the package is not installed. The result
    is meant to be computed once and cached by the caller at module level.

    Returns:
        The encoding object for ``_TIKTOKEN_ENCODING_NAME``, or ``None`` when
        tiktoken is not importable or the encoding cannot be obtained. A
        ``None`` result tells callers to use the character-based fallback.
    """
    try:
        import tiktoken
    except ImportError:
        logger.info(
            "tiktoken not installed; falling back to char-based token "
            "estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra "
            "for accurate counts.",
            CHARS_PER_TOKEN,
        )
        return None

    try:
        return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME)
    except (KeyError, ValueError, OSError) as exc:
        # KeyError/ValueError: the requested encoding name is not registered
        # (e.g. a partial install).
        # OSError: tiktoken may have to fetch and cache the encoding's
        # vocabulary file on first use; in offline or read-only environments
        # that shows up as an OSError subclass (network errors raised by
        # ``requests`` derive from IOError/OSError, as do cache-directory
        # write failures).
        # This function runs at module import time, so any of these must
        # degrade to "no tokenizer" rather than break the import and, with
        # it, every tool call.
        logger.warning(
            "tiktoken encoding '%s' unavailable: %s; falling back to "
            "char-based token estimation",
            _TIKTOKEN_ENCODING_NAME,
            exc,
        )
        return None
# Cached encoding instance (None if tiktoken not importable).
_ENCODING = _load_tiktoken_encoding()
def estimate_token_count(text: str | bytes) -> int:
"""
Estimate the token count for a given text.
Uses a character-based heuristic since we don't have direct access to
the actual tokenizer. This is conservative to avoid underestimating.
Uses tiktoken's ``cl100k_base`` encoding when available for
Claude-aligned accuracy (within ~10%), falling back to a
character-based heuristic otherwise.
Args:
text: The text to estimate tokens for (string or bytes)
@@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int:
if isinstance(text, bytes):
text = text.decode("utf-8", errors="replace")
# Simple heuristic: ~3.5 characters per token for JSON/code
text_length = len(text)
if text_length == 0:
if not text:
return 0
return max(1, int(text_length / CHARS_PER_TOKEN))
if _ENCODING is not None:
try:
return len(_ENCODING.encode(text))
except (ValueError, UnicodeError) as exc:
# Defensive: if tiktoken chokes on a specific input, fall
# back to the char heuristic for this call rather than
# raising — the response size guard must never fail-open.
logger.warning("tiktoken encode failed (%s); using fallback", exc)
return max(1, int(len(text) / CHARS_PER_TOKEN))
def estimate_response_tokens(response: ToolResponse) -> int:

View File

@@ -19,9 +19,7 @@
OpenSearch SQL dialect.
OpenSearch SQL is syntactically close to MySQL but accepts both backticks and
double-quotes as identifier delimiters. Treating ``"`` as an identifier (rather
than a string delimiter, as MySQL does) is what keeps mixed-case column names
from being emitted as string literals after a SQLGlot round-trip.
double-quotes as identifier delimiters.
"""
from __future__ import annotations
@@ -31,4 +29,4 @@ from sqlglot.dialects.mysql import MySQL
class OpenSearch(MySQL):
class Tokenizer(MySQL.Tokenizer):
IDENTIFIERS = ['"', "`"]
IDENTIFIERS = ["`", '"']

View File

@@ -76,7 +76,7 @@ SQLGLOT_DIALECTS = {
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
# "exa": ???
"exa": Dialects.EXASOL,
# "firebird": ???
"firebolt": Firebolt,
"gsheets": Dialects.SQLITE,
@@ -105,7 +105,7 @@ SQLGLOT_DIALECTS = {
"shillelagh": Dialects.SQLITE,
"singlestoredb": SingleStore,
"snowflake": Dialects.SNOWFLAKE,
# "solr": ???
"solr": Dialects.SOLR,
"spark": Dialects.SPARK,
"sqlite": Dialects.SQLITE,
"starrocks": Dialects.STARROCKS,
@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto()
AS_PREDICATE_SPLICE = enum.auto()
class RLSTransformer:
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
statement: str | None = None,
engine: str = "base",
ast: InternalRepresentation | None = None,
source: str | None = None,
):
if ast:
self._parsed = ast
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
# Original SQL substring for this statement, when known. Used by the
# splice-mode RLS path which rewrites this string instead of regenerating
# SQL from the AST. ``None`` means the statement was constructed from an
# AST without an associated source string (splice-mode RLS raises
# ``ValueError`` in that case).
self._source_sql: str | None = source if source is not None else statement
# Verbatim SQL to return from ``format()``. Set by string-rewriting
# operations (e.g. splice-mode RLS) that produce a final SQL string and
# need to bypass the dialect generator. Cleared by AST-mutating methods
# since those invalidate this cached text.
self._raw_sql: str | None = None
@classmethod
def split_script(
@@ -531,7 +543,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[InternalRepresentation]],
predicates: dict[Table, list[str]],
method: RLSMethod,
) -> None:
"""
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
statement: str | None = None,
engine: str = "base",
ast: exp.Expression | None = None,
source: str | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine, ast)
super().__init__(statement, engine, ast, source)
@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -626,10 +639,57 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str,
engine: str,
) -> list[SQLStatement]:
asts = [ast for ast in cls._parse(script, engine) if ast]
sources = cls._split_source(script, engine, len(asts))
return [
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
cls(ast=ast, engine=engine, source=source)
for ast, source in zip(asts, sources, strict=True)
]
@classmethod
def _split_source(
    cls,
    script: str,
    engine: str,
    expected_count: int,
) -> list[str | None]:
    """
    Cut ``script`` into one source substring per statement.

    The script is tokenized with the engine's sqlglot dialect and split at
    every semicolon that is not nested inside parentheses. Each piece is a
    slice of the original text with surrounding whitespace stripped; empty
    pieces are dropped. Splice-mode RLS relies on these slices because it
    rewrites the original SQL text rather than regenerating it from the AST.

    :param script: The full SQL script.
    :param engine: Engine name used to look up the sqlglot dialect.
    :param expected_count: Number of statements the caller parsed.
    :returns: A list of length ``expected_count``. It holds the per-statement
        substrings when exactly ``expected_count`` non-empty pieces were
        found; otherwise (including when tokenization fails) every entry is
        ``None``.
    """
    fallback: list[str | None] = [None] * expected_count

    dialect = SQLGLOT_DIALECTS.get(engine)
    try:
        tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
    except sqlglot.errors.SqlglotError:
        return fallback

    # Collect the offsets of semicolons found outside any parentheses.
    token_types = sqlglot.tokens.TokenType
    cut_points: list[int] = []
    nesting = 0
    for token in tokens:
        kind = token.token_type
        if kind == token_types.L_PAREN:
            nesting += 1
        elif kind == token_types.R_PAREN:
            # Never go below zero: a stray ')' in malformed SQL must not make
            # later top-level semicolons look nested.
            nesting = max(0, nesting - 1)
        elif kind == token_types.SEMICOLON and nesting == 0:
            cut_points.append(token.start)

    # Walk the cut points left to right, taking the text between consecutive
    # semicolons (and from the last semicolon to the end of the script).
    segments: list[str] = []
    cursor = 0
    for cut in (*cut_points, len(script)):
        piece = script[cursor:cut].strip()
        if piece:
            segments.append(piece)
        cursor = cut + 1

    if len(segments) != expected_count:
        return fallback

    return list(segments)
@classmethod
def _parse_statement(
cls,
@@ -722,7 +782,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
When a string-rewriting operation (e.g. splice-mode RLS) has cached a
verbatim result in ``_raw_sql``, return it as-is — the whole point of
those operations is to avoid the dialect generator round-trip.
"""
if self._raw_sql is not None:
return self._raw_sql
return Dialect.get_or_raise(self._dialect).generate(
self._parsed,
copy=True,
@@ -808,6 +874,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
"""
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
# If we already have a rewritten SQL string, re-parse it first so further
# AST mutations (like LIMIT injection) preserve prior text-based rewrites.
if self._raw_sql is not None:
self._parsed = self._parse_statement(self._raw_sql, self.engine)
self._source_sql = self._raw_sql
self._raw_sql = None
if method == LimitMethod.FORCE_LIMIT:
self._parsed.args["limit"] = exp.Limit(
expression=exp.Literal(this=str(limit), is_string=False)
@@ -902,7 +975,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]],
predicates: dict[Table, list[str]],
method: RLSMethod,
) -> None:
"""
@@ -910,11 +983,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param predicates: Mapping of fully qualified ``Table`` to raw predicate
SQL strings.
:param method: The method to use for applying the rules.
"""
if not predicates:
return
if method == RLSMethod.AS_PREDICATE_SPLICE:
self._apply_rls_splice(catalog, schema, predicates)
return
parsed_predicates: dict[Table, list[exp.Expression]] = {
table: [self.parse_predicate(predicate) for predicate in table_predicates]
for table, table_predicates in predicates.items()
}
transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
@@ -922,9 +1006,39 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
if method not in transformers:
raise ValueError(f"Invalid RLS method: {method}")
transformer = transformers[method](catalog, schema, predicates)
transformer = transformers[method](catalog, schema, parsed_predicates)
self._parsed = self._parsed.transform(transformer)
def _apply_rls_splice(
    self,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[str]],
) -> None:
    """
    Inject RLS predicates by splicing text into the statement's source SQL.

    The spliced SQL string is stored on ``self._raw_sql``; the parsed AST is
    not touched here.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param predicates: Mapping of fully qualified ``Table`` to raw predicate
        SQL strings.
    :raises ValueError: If the statement has no source SQL string to splice
        into (``self._source_sql`` is ``None``).
    """
    # Function-scope import, as in the original; presumably this avoids an
    # import cycle with ``superset.sql.rls_splice`` (confirm).
    from superset.sql.rls_splice import apply_rls_splice

    source = self._source_sql
    if source is None:
        raise ValueError(
            "Splice-mode RLS requires the source SQL string; "
            "this SQLStatement was constructed without one."
        )

    self._raw_sql = apply_rls_splice(
        source,
        catalog,
        schema,
        predicates,
        dialect=self._dialect,
    )
class KQLSplitState(enum.Enum):
"""

474
superset/sql/rls_splice.py Normal file
View File

@@ -0,0 +1,474 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
RLS predicate injection via text splicing.
Instead of round-tripping through sqlglot's generator (which transpiles
dialect-specific functions like ``LAST_DAY`` into something else), this approach:
1. Parses the SQL with sqlglot — only to understand structure (scope tree).
2. Uses sqlglot's tokenizer to get exact character positions (offsets into the
original SQL string) for every token.
3. For each ``SELECT`` scope that references a table with an RLS predicate,
finds the exact character offset to inject at — either the end of an existing
``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING``
/ ``LIMIT`` / the closing paren of a subquery.
4. Splices the predicate text directly into the original string at that
offset — never calling ``.sql()``, so the generator never runs.
Result: everything outside the splice points is the original SQL, byte for
byte. Dialect-specific functions, comments, and formatting are all preserved
exactly.
Known limitations:
- SQL that fails to parse under the chosen dialect raises a ``ParseError``.
A thin dialect subclass is still required for parsing — but only for
parsing, not generation.
- Predicate strings are spliced in as raw SQL. They must come from a trusted
source (the RLS config), not user input.
- Predicate **column qualification** (prefixing bare columns with the table
alias) currently round-trips the predicate through the sqlglot generator
via ``_qualify_predicate``. Predicates that contain dialect-specific
functions can therefore still be transpiled by the generator at that step,
even though the surrounding query is preserved byte-for-byte. The
surrounding-query fidelity guarantee does not extend to the predicate
string itself.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.tokens import Token, TokenType
if TYPE_CHECKING:
from superset.sql.parse import Table
# Token types that end a WHERE clause / FROM section at the current paren depth,
# indicating where a new predicate must be inserted just before.
_CLAUSE_ENDS = {
    TokenType.GROUP_BY,
    TokenType.HAVING,
    TokenType.ORDER_BY,
    TokenType.WINDOW,
    TokenType.QUALIFY,
    TokenType.LIMIT,
    TokenType.FETCH,
    TokenType.CLUSTER_BY,
    TokenType.DISTRIBUTE_BY,
    TokenType.SORT_BY,
    TokenType.CONNECT_BY,
    TokenType.START_WITH,
    TokenType.UNION,
    TokenType.INTERSECT,
    TokenType.EXCEPT,
    # A statement terminator also ends the clause. Without it, a trailing
    # ``;`` is treated as part of the condition body / end of input, so the
    # injected text (``WHERE ...``, ``ON ...`` or the closing ``)``) would
    # land *after* the semicolon and produce invalid SQL.
    TokenType.SEMICOLON,
}

# Token types that start a new JOIN segment; when scanning a JOIN's ON clause
# at paren depth 0 these mark the end of the current JOIN.
_JOIN_STARTS = {
    TokenType.JOIN,
    TokenType.STRAIGHT_JOIN,
    TokenType.JOIN_MARKER,
}
def _splice_priority(text: str) -> int:
"""
Priority for applying splices at the same offset.
Insert full SQL fragments (WHERE/ON/predicates) before closing parens so
wrapping splices like ``pred AND (existing)`` compose correctly.
"""
return 1 if text != ")" else 0
def _after_previous_token(tokens: list[Token], index: int) -> int:
"""
Return the offset immediately after the token preceding *index*.
The sqlglot tokenizer strips comments and whitespace from the token stream,
so the previous token's ``end + 1`` is the splice point that lands right
after the last real SQL content — naturally skipping any intervening
comments or whitespace, and never confusing ``--`` or ``/*`` inside string
literals for real comment delimiters.
"""
if index <= 0:
return 0
return tokens[index - 1].end + 1
def _table_from_node(
    node: exp.Table,
    catalog: str | None,
    schema: str | None,
) -> Table:
    """
    Convert a sqlglot ``exp.Table`` node into a fully qualified ``Table``.

    Parts missing from the node (empty ``db`` / ``catalog``) fall back to the
    supplied defaults.

    :param node: The table node taken from the parsed query.
    :param catalog: Default catalog used when the node has none.
    :param schema: Default schema used when the node has none.
    :return: The corresponding ``Table``.
    """
    # Imported lazily to avoid a circular import with ``superset.sql.parse``.
    from superset.sql.parse import Table

    resolved_schema = node.db or schema
    resolved_catalog = node.catalog or catalog
    return Table(
        table=node.name,
        schema=resolved_schema,
        catalog=resolved_catalog,
    )
def apply_rls_splice(
    sql: str,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[str]],
    dialect: str | None = None,
) -> str:
    """
    Splice RLS predicates into ``sql`` at the appropriate text positions.

    :param sql: The original SQL query. Returned unchanged except at splice points.
    :param catalog: The default catalog for non-qualified table names.
    :param schema: The default schema for non-qualified table names.
    :param predicates: Mapping of ``Table`` to predicate SQL strings. Each entry
        maps a fully qualified table to one or more raw predicate strings to
        ``AND`` together when that table is referenced in a SELECT scope.
    :param dialect: The sqlglot dialect used to tokenize and parse ``sql`` so
        that scopes and token positions can be located. The query itself is
        never re-generated from the AST.
    :return: The query with RLS predicates injected into every relevant SELECT
        scope.
    """
    # Nothing to inject when no table carries at least one predicate.
    has_rules = any(predicates.values()) if predicates else False
    if not has_rules:
        return sql

    tokens = list(sqlglot.Dialect.get_or_raise(dialect).tokenize(sql))
    tree = sqlglot.parse_one(sql, dialect=dialect)

    pending: list[tuple[int, str]] = [
        splice
        for scope in traverse_scope(tree)
        for splice in _splices_for_scope(
            sql,
            tokens,
            scope,
            predicates,
            catalog,
            schema,
            dialect,
        )
    ]

    # Work from the highest offset down so that offsets not yet processed stay
    # valid. For equal offsets, predicate/WHERE/ON inserts are applied before
    # ")" inserts.
    ordered = sorted(
        pending,
        key=lambda splice: (splice[0], _splice_priority(splice[1])),
        reverse=True,
    )

    rewritten = sql
    for position, fragment in ordered:
        rewritten = f"{rewritten[:position]}{fragment}{rewritten[position:]}"
    return rewritten
def _splices_for_scope(
    sql: str,
    tokens: list[Token],
    scope: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
    dialect: str | None,
) -> list[tuple[int, str]]:
    """
    Collect every ``(offset, text)`` splice needed for one SELECT scope.

    The behaviour follows ``RLSAsPredicateTransformer``:

    - predicates for FROM tables go into the SELECT's WHERE clause as
      ``pred AND (existing_where)``
    - predicates for JOIN tables go into each JOIN's ON clause as
      ``pred AND (existing_on)`` (or ``ON pred`` when ON is absent)

    :return: JOIN splices followed by the WHERE splice(s), if any.
    """
    where_predicates: list[str] = []
    where_anchors: list[int] = []
    on_splices: list[tuple[int, str]] = []

    for source in scope.sources.values():  # type: ignore[attr-defined]
        kind, anchor, predicate = _classify_source_predicate(
            source,
            predicates,
            catalog,
            schema,
            dialect,
        )
        if kind == "none" or anchor is None or predicate is None:
            continue
        if kind != "from":
            on_splices += _find_join_splice(sql, tokens, anchor, predicate) or []
        else:
            where_predicates.append(predicate)
            where_anchors.append(anchor)

    if where_predicates:
        # ``dict.fromkeys`` drops duplicate predicates while keeping order.
        merged = " AND ".join(dict.fromkeys(where_predicates))
        # Scan for the WHERE position starting after the last FROM table.
        where_splices = _find_where_splice(
            sql,
            tokens,
            max(where_anchors),
            merged,
        )
        return on_splices + list(where_splices)
    return on_splices
def _table_end(source: exp.Table) -> int | None:
    """
    Return the ``end`` value stored in the metadata of the first identifier
    found under ``source``.

    :return: The recorded ``end`` value (presumably the offset of the
        identifier's last character in the original SQL; confirm against
        sqlglot), or ``None`` when there is no identifier or no metadata.
    """
    identifier = source.find(exp.Identifier)
    if not identifier:
        return None
    # Read the private attribute directly and tolerate its absence.
    metadata = getattr(identifier, "_meta", None)
    if metadata is None:
        return None
    return metadata.get("end")
def _classify_source_predicate(
    source: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
    dialect: str | None,
) -> tuple[str, int | None, str | None]:
    """
    Classify a scope source and build the predicate SQL that applies to it.

    :return: A ``(kind, table_end, pred_sql)`` tuple. ``kind`` is ``"from"``
        when the table sits directly under a FROM, ``"join"`` when it sits
        directly under a JOIN, and ``"none"`` (with both other fields
        ``None``) when the source is not a table, has no non-empty
        predicates, has no known end offset, or has any other parent.
    """
    unmatched: tuple[str, int | None, str | None] = ("none", None, None)

    if not isinstance(source, exp.Table):
        return unmatched

    table = _table_from_node(source, catalog, schema)
    qualified = [
        _qualify_predicate(raw, source, dialect)
        for raw in predicates.get(table, [])
        if raw
    ]
    if not qualified:
        return unmatched

    end = _table_end(source)
    if end is None:
        return unmatched

    # De-duplicate while keeping order, then AND everything together.
    pred_sql = " AND ".join(dict.fromkeys(qualified))

    parent = source.parent
    for parent_type, label in ((exp.From, "from"), (exp.Join, "join")):
        if isinstance(parent, parent_type):
            return (label, end, pred_sql)
    return unmatched
def _qualify_predicate(
    predicate: str,
    table_node: exp.Table,
    dialect: str | None,
) -> str:
    """
    Qualify predicate columns with the table alias/name, mirroring
    ``RLSAsPredicateTransformer``, and return SQL that is safe to combine
    with ``AND``.

    Every ``exp.Column`` in the predicate has its ``table`` set to the table's
    alias (or name when there is no alias).

    The returned text is later combined as ``pred AND (existing)`` or joined
    with other predicates using ``" AND "``. ``AND`` binds tighter than ``OR``,
    so a predicate whose top-level operator is a non-``AND`` connector (for
    example ``a = 1 OR b = 2``) is wrapped in parentheses. Without that,
    ``a = 1 OR b = 2 AND (existing)`` would be evaluated as
    ``a = 1 OR (b = 2 AND existing)`` and rows could bypass the filter.
    Predicates without such a top-level connector are rendered unchanged.

    Note: this re-renders the predicate via the sqlglot generator, so the
    splice-mode fidelity guarantee does not extend to the predicate text
    itself. Predicates containing dialect-specific functions may be transpiled
    here even though the surrounding query is preserved byte-for-byte.

    :param predicate: Raw predicate SQL from the RLS configuration.
    :param table_node: The table node the predicate applies to.
    :param dialect: The sqlglot dialect used to parse and render the predicate.
    :return: The qualified predicate SQL.
    """
    parsed = sqlglot.parse_one(predicate, dialect=dialect)
    table_expr = exp.to_identifier(table_node.alias_or_name)

    # Materialize the columns first so the tree is not mutated mid-traversal.
    for column in list(parsed.find_all(exp.Column)):
        column.set("table", table_expr.copy())

    # Keep lower-precedence top-level connectors (OR, XOR, ...) grouped when
    # the result is AND-ed with other conditions.
    if isinstance(parsed, exp.Connector) and not isinstance(parsed, exp.And):
        parsed = exp.Paren(this=parsed)

    return parsed.sql(dialect=dialect)
def _scan_until_scope_boundary(
    tokens: list[Token],
    anchor: int,
    *,
    stop_at_join: bool,
) -> tuple[str, int | None]:
    """
    Walk the tokens that start after ``anchor`` until a clause or scope
    boundary is reached.

    Parenthesized groups opened after the anchor are skipped. An unmatched
    closing paren counts as a boundary.

    :param tokens: The token stream of the original SQL.
    :param anchor: Offset after which scanning starts; tokens whose ``start``
        is at or before it are ignored.
    :param stop_at_join: Whether JOIN-start tokens also count as boundaries.
    :return: ``("where", index)`` when a WHERE token is found at depth 0,
        ``("boundary", index)`` for any other boundary token, and
        ``("eof", None)`` when the tokens run out first.
    """
    nesting = 0
    for index, token in enumerate(tokens):
        if token.start <= anchor:
            continue

        kind = token.token_type
        if kind == TokenType.L_PAREN:
            nesting += 1
        elif kind == TokenType.R_PAREN:
            if nesting == 0:
                # This paren closes the enclosing subquery.
                return ("boundary", index)
            nesting -= 1
        elif nesting == 0:
            if kind == TokenType.WHERE:
                return ("where", index)
            starts_join = stop_at_join and kind in _JOIN_STARTS
            if starts_join or kind in _CLAUSE_ENDS:
                return ("boundary", index)

    return ("eof", None)
def _find_condition_end(
    tokens: list[Token],
    start_index: int,
    *,
    stop_at_join: bool,
) -> int:
    """
    Locate the offset just past the body of a WHERE/ON condition.

    :param tokens: The token stream of the original SQL.
    :param start_index: Index of the WHERE/ON token that opens the condition.
    :param stop_at_join: When true, WHERE and JOIN-start tokens also end the
        condition (used for ON clauses).
    :return: One past the ``end`` of the last token belonging to the condition.
    """
    terminators = set(_CLAUSE_ENDS)
    if stop_at_join:
        terminators |= _JOIN_STARTS | {TokenType.WHERE}

    last_end = tokens[start_index].end
    nesting = 0
    for token in tokens[start_index + 1 :]:
        kind = token.token_type
        if kind == TokenType.L_PAREN:
            nesting += 1
        elif kind == TokenType.R_PAREN:
            if nesting == 0:
                # Unmatched closing paren: the enclosing subquery ends here.
                break
            nesting -= 1
        elif nesting == 0 and kind in terminators:
            break
        last_end = token.end

    return last_end + 1
def _find_where_splice(
    sql: str,
    tokens: list[Token],
    anchor: int,
    pred_sql: str,
) -> list[tuple[int, str]]:
    """
    Produce the splices that add ``pred_sql`` to a SELECT's WHERE clause.

    An existing WHERE body is wrapped as ``pred AND (existing)``. When there
    is no WHERE, `` WHERE pred`` is inserted before the next boundary token or
    at the end of the SQL.

    :param sql: The original SQL.
    :param tokens: The token stream of ``sql``.
    :param anchor: Offset after which the scan for WHERE starts.
    :param pred_sql: The predicate text to inject.
    :return: A list of ``(offset, text)`` splices.
    """
    kind, idx = _scan_until_scope_boundary(tokens, anchor, stop_at_join=False)

    if idx is None or kind not in ("where", "boundary"):
        # Ran off the end of the statement: append a new WHERE clause.
        return [(len(sql), f" WHERE {pred_sql}")]

    if kind == "boundary":
        return [(_after_previous_token(tokens, idx), f" WHERE {pred_sql}")]

    if idx + 1 >= len(tokens):
        # WHERE is the last token, so there is no body to wrap.
        return [(tokens[idx].end + 1, f" {pred_sql}")]

    open_at = tokens[idx + 1].start
    close_at = _find_condition_end(tokens, idx, stop_at_join=False)
    return [
        (open_at, f"{pred_sql} AND ("),
        (close_at, ")"),
    ]
def _find_join_splice(
    sql: str,
    tokens: list[Token],
    anchor: int,
    pred_sql: str,
) -> list[tuple[int, str]]:
    """
    Produce the splices that add ``pred_sql`` to a JOIN clause.

    An existing ON body is wrapped as ``pred AND (existing_on)``. When there
    is no ON, `` ON pred`` is inserted before the next boundary token or at
    the end of the SQL.

    :param sql: The original SQL.
    :param tokens: The token stream of ``sql``.
    :param anchor: Offset after which the scan of the JOIN segment starts.
    :param pred_sql: The predicate text to inject.
    :return: A list of ``(offset, text)`` splices.
    """
    on_index, boundary_index = _scan_join_clause(tokens, anchor)

    if on_index is None:
        if boundary_index is None:
            position = len(sql)
        else:
            position = _after_previous_token(tokens, boundary_index)
        return [(position, f" ON {pred_sql}")]

    if on_index + 1 >= len(tokens):
        # ON is the last token, so there is no body to wrap.
        return [(tokens[on_index].end + 1, f" {pred_sql}")]

    open_at = tokens[on_index + 1].start
    close_at = _find_condition_end(tokens, on_index, stop_at_join=True)
    return [
        (open_at, f"{pred_sql} AND ("),
        (close_at, ")"),
    ]
def _scan_join_clause(
    tokens: list[Token],
    anchor: int,
) -> tuple[int | None, int | None]:
    """
    Locate the ON token and the terminating boundary token of a JOIN segment.

    Scanning covers tokens that start after ``anchor`` and skips parenthesized
    groups. It stops at the first unmatched closing paren, WHERE, JOIN-start
    or clause-end token found at depth 0.

    :param tokens: The token stream of the original SQL.
    :param anchor: Offset after which scanning starts.
    :return: ``(on_index, boundary_index)``; either is ``None`` when the
        corresponding token was not found before the scan ended.
    """
    nesting = 0
    on_index: int | None = None

    for index, token in enumerate(tokens):
        if token.start <= anchor:
            continue

        kind = token.token_type
        if kind == TokenType.L_PAREN:
            nesting += 1
            continue
        if kind == TokenType.R_PAREN:
            if nesting == 0:
                return on_index, index
            nesting -= 1
            continue
        if nesting > 0:
            continue

        # Only the first ON at depth 0 is recorded.
        if kind == TokenType.ON and on_index is None:
            on_index = index
            continue
        if (
            kind == TokenType.WHERE
            or kind in _JOIN_STARTS
            or kind in _CLAUSE_ENDS
        ):
            return on_index, index

    return on_index, None

View File

@@ -45,6 +45,13 @@
color: #000;
}
{% endif %}
{% if standalone_mode %}
/* Keep body sized so screenshot waits don't see it as hidden before React mounts. */
html, body.standalone {
min-height: 100vh;
margin: 0;
}
{% endif %}
</style>
{% if dark_theme_bg and entry != 'embedded' %}

View File

@@ -40,17 +40,19 @@ def apply_rls(
:returns: True if any RLS predicates were actually applied, False otherwise.
"""
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
# safer, but not supported in all databases.
method = database.db_engine_spec.get_rls_method()
# There are three ways to insert RLS:
# - replace the table with a subquery containing the RLS (safest, but not
# supported in all databases)
# - append the RLS to the ``WHERE`` clause via AST transformation
# - splice the RLS into the original SQL string (preserves dialect-specific
# syntax that the sqlglot generator would otherwise transpile)
method = database.db_engine_spec.rls_method
# collect all RLS predicates for all tables in the query
predicates: dict[Table, list[Any]] = {}
predicates: dict[Table, list[str]] = {}
for table in parsed_statement.tables:
table = table.qualify(catalog=catalog, schema=schema)
predicates[table] = [
parsed_statement.parse_predicate(predicate)
raw_predicates = [
predicate
for predicate in get_predicates_for_table(
table,
database,
@@ -58,6 +60,7 @@ def apply_rls(
)
if predicate
]
predicates[table] = raw_predicates
has_predicates = any(predicates.values())
parsed_statement.apply_rls(catalog, schema, predicates, method)

View File

@@ -36,7 +36,7 @@ from sqlalchemy.sql import sqltypes
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2RedirectError
from superset.sql.parse import Table
from superset.sql.parse import RLSMethod, Table
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
@@ -1283,3 +1283,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
error = exc_info.value.error
assert error.extra["redirect_uri"] == fallback_uri
def test_default_rls_method_is_subquery() -> None:
    """The base engine spec uses subquery-based RLS unless overridden."""
    default_method = BaseEngineSpec.rls_method
    assert default_method == RLSMethod.AS_SUBQUERY

View File

@@ -68,3 +68,50 @@ def test_spa_template_includes_css_bundles():
"spa.html must call css_bundle for the page entry to load "
"entry-specific extracted CSS in production builds"
)
def test_spa_template_standalone_body_has_min_height():
    """
    The standalone body gets a ``min-height`` rule, and only in standalone mode.

    Standalone body must be measurable so screenshot waits don't time out.
    """
    from jinja2 import DictLoader, Environment

    with open(join(SUPERSET_DIR, "templates", "superset", "spa.html")) as f:
        spa_source = f.read()

    # Stub out includes/imports that are not relevant for this test.
    stub_templates = {
        "spa.html": spa_source,
        "appbuilder/general/lib.html": "",
        "superset/partials/asset_bundle.html": (
            "{% macro css_bundle(prefix, entry) %}{% endmacro %}"
            "{% macro js_bundle(prefix, entry) %}{% endmacro %}"
        ),
        "superset/macros.html": ("{% macro get_nonce() %}{% endmacro %}"),
        "tail_js_custom_extra.html": "",
        "head_custom_extra.html": "",
    }
    env = Environment(loader=DictLoader(stub_templates))  # noqa: S701

    appbuilder = Mock()
    appbuilder.app.config = {"FAVICONS": []}

    def render(standalone_mode: bool) -> str:
        """Render ``spa.html`` with the given standalone flag."""
        template = env.get_template("spa.html")
        return template.render(
            appbuilder=appbuilder,
            assets_prefix="",
            bootstrap_data="{}",
            entry="spa",
            standalone_mode=standalone_mode,
            theme_tokens={},
            spinner_svg=None,
        )

    with_standalone = render(standalone_mode=True)
    assert "body.standalone" in with_standalone
    assert "min-height: 100vh" in with_standalone

    without_standalone = render(standalone_mode=False)
    assert "body.standalone" not in without_standalone

View File

@@ -595,3 +595,191 @@ Market Share
"""
# These demonstrate the expected ASCII formats for different chart types
class TestDetachedInstanceError:
"""Tests that DetachedInstanceError is handled gracefully.
When the SQLAlchemy session commits mid-request, ORM objects expire and
become detached. Accessing lazy attributes on a detached Slice raises
DetachedInstanceError. The tool must:
1. Call db.session.refresh() immediately after loading the chart so all
column values are loaded upfront before any downstream operation.
2. Catch SQLAlchemyError (the base class) and return a ChartError
instead of propagating the exception.
"""
@pytest.mark.asyncio
async def test_session_refresh_called_after_chart_load(self):
"""db.session.refresh() is invoked right after find_chart_by_identifier."""
import importlib
from contextlib import nullcontext
from unittest.mock import MagicMock, patch
from superset.mcp_service.chart.schemas import URLPreview
from superset.utils import json
get_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_preview"
)
mock_chart = MagicMock()
mock_chart.id = 42
mock_chart.slice_name = "Sales Chart"
mock_chart.viz_type = "table"
mock_chart.datasource_id = 1
mock_chart.datasource_type = "table"
mock_chart.params = "{}"
refresh_calls: list[object] = []
def _fake_refresh(obj: object) -> None:
refresh_calls.append(obj)
url_preview = URLPreview(
preview_url="http://localhost/explore/?slice_id=42",
width=800,
height=600,
)
with (
patch.object(
get_chart_preview_module,
"find_chart_by_identifier",
return_value=mock_chart,
),
patch.object(
get_chart_preview_module.db,
"session",
**{"refresh.side_effect": _fake_refresh},
),
patch.object(
get_chart_preview_module,
"validate_chart_dataset",
return_value=MagicMock(is_valid=True, warnings=[]),
),
patch.object(
get_chart_preview_module.event_logger,
"log_context",
return_value=nullcontext(),
),
# Return a real URLPreview so Pydantic model validation succeeds
patch.object(
get_chart_preview_module.PreviewFormatGenerator,
"generate",
return_value=url_preview,
),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost",
),
):
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import GetChartPreviewRequest
with patch("superset.mcp_service.auth.get_user_from_request") as mu:
mu.return_value = MagicMock(id=1, username="admin")
with patch(
"superset.mcp_service.auth.check_tool_permission", return_value=True
):
async with Client(mcp) as client:
response = await client.call_tool(
"get_chart_preview",
{
"request": GetChartPreviewRequest(
identifier=42, format="url"
).model_dump()
},
)
data = json.loads(response.content[0].text)
# The tool should succeed — not return a ChartError
assert "error_type" not in data, (
f"Expected ChartPreview but got ChartError: {data.get('error')}"
)
assert data.get("chart_id") == 42
assert len(refresh_calls) == 1, (
"db.session.refresh() should be called once after loading the chart"
)
assert refresh_calls[0] is mock_chart
@pytest.mark.asyncio
async def test_detached_instance_error_returns_chart_error(self):
"""DetachedInstanceError during preview generation returns ChartError."""
import importlib
from contextlib import nullcontext
from unittest.mock import MagicMock, patch
from sqlalchemy.orm.exc import DetachedInstanceError
get_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_preview"
)
mock_chart = MagicMock()
mock_chart.id = 7
mock_chart.slice_name = "Broken Chart"
mock_chart.viz_type = "bar"
mock_chart.datasource_id = 3
mock_chart.datasource_type = "table"
mock_chart.params = "{}"
with (
patch.object(
get_chart_preview_module,
"find_chart_by_identifier",
return_value=mock_chart,
),
patch.object(
get_chart_preview_module.db,
"session",
**{"refresh.return_value": None},
),
patch.object(
get_chart_preview_module,
"validate_chart_dataset",
return_value=MagicMock(is_valid=True, warnings=[]),
),
patch.object(
get_chart_preview_module.event_logger,
"log_context",
return_value=nullcontext(),
),
# Simulate the session expiring inside the strategy
patch.object(
get_chart_preview_module.PreviewFormatGenerator,
"generate",
side_effect=DetachedInstanceError(),
),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost",
),
):
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import GetChartPreviewRequest
from superset.utils import json
with patch("superset.mcp_service.auth.get_user_from_request") as mu:
mu.return_value = MagicMock(id=1, username="admin")
with patch(
"superset.mcp_service.auth.check_tool_permission", return_value=True
):
async with Client(mcp) as client:
response = await client.call_tool(
"get_chart_preview",
{
"request": GetChartPreviewRequest(
identifier=7, format="ascii"
).model_dump()
},
)
data = json.loads(response.content[0].text)
assert data["error_type"] == "InternalError"
assert "session" in data["error"].lower() or "retry" in data["error"].lower()

View File

@@ -146,7 +146,13 @@ class TestResponseSizeGuardMiddleware:
@pytest.mark.asyncio
async def test_logs_warning_at_threshold(self) -> None:
"""Should log warning when approaching limit."""
"""Should log warning when approaching limit.
Mocks the token estimator to return a specific value above the
warn threshold but below the hard limit, decoupling the test
from whichever tokenizer (tiktoken or char heuristic) happens
to be loaded.
"""
middleware = ResponseSizeGuardMiddleware(
token_limit=1000, warn_threshold_pct=80
)
@@ -155,18 +161,21 @@ class TestResponseSizeGuardMiddleware:
context.message.name = "list_charts"
context.message.params = {}
# Response at ~85% of limit (should trigger warning but not block)
response = {"data": "x" * 2900} # ~828 tokens at 3.5 chars/token
response = {"data": "approaching the limit"}
call_next = AsyncMock(return_value=response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch(
"superset.mcp_service.middleware.estimate_response_tokens",
return_value=850,
),
patch("superset.mcp_service.middleware.logger") as mock_logger,
):
result = await middleware.on_call_tool(context, call_next)
# Should return response (not blocked)
# Should return response (not blocked at 85% of limit)
assert result == response
# Should log warning
mock_logger.warning.assert_called()

View File

@@ -20,9 +20,11 @@ Unit tests for MCP service token utilities.
"""
from typing import Any, List
from unittest.mock import patch
from pydantic import BaseModel
from superset.mcp_service.utils import token_utils
from superset.mcp_service.utils.token_utils import (
_replace_collections_with_summaries,
_summarize_large_dicts,
@@ -45,29 +47,65 @@ class TestEstimateTokenCount:
"""Test estimate_token_count function."""
def test_estimate_string(self) -> None:
"""Should estimate tokens for a string."""
"""Should produce a positive non-zero estimate for a normal string.
We don't assert on a specific number because the result depends on
which tokenizer is loaded (tiktoken when available, char heuristic
otherwise).
"""
text = "Hello world"
result = estimate_token_count(text)
expected = int(len(text) / CHARS_PER_TOKEN)
assert result == expected
assert result > 0
def test_estimate_bytes(self) -> None:
"""Should estimate tokens for bytes."""
text = b"Hello world"
result = estimate_token_count(text)
expected = int(len(text) / CHARS_PER_TOKEN)
assert result == expected
"""Bytes input should be decoded and produce the same count as the
equivalent string."""
text = "Hello world"
assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text)
def test_empty_string(self) -> None:
"""Should return 0 for empty string."""
"""Should return 0 for empty string and empty bytes."""
assert estimate_token_count("") == 0
assert estimate_token_count(b"") == 0
def test_json_like_content(self) -> None:
"""Should estimate tokens for JSON-like content."""
"""JSON content should produce a positive estimate."""
json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
result = estimate_token_count(json_str)
assert result > 0
assert result == int(len(json_str) / CHARS_PER_TOKEN)
assert estimate_token_count(json_str) > 0
def test_long_text_roughly_scales_with_length(self) -> None:
"""A doubled string should produce roughly double the token count
(within ±10%)."""
small = "the quick brown fox jumps over the lazy dog. " * 20
large = small * 2
small_n = estimate_token_count(small)
large_n = estimate_token_count(large)
# Within 10% of 2x — both tokenizers (tiktoken and the char
# fallback) preserve length monotonicity.
assert 1.8 * small_n <= large_n <= 2.2 * small_n
def test_fallback_uses_chars_per_token_when_tiktoken_unavailable(
self,
) -> None:
"""When the tiktoken encoding is None (not installed), the
function falls back to len/CHARS_PER_TOKEN math."""
text = "x" * 100
with patch.object(token_utils, "_ENCODING", None):
result = estimate_token_count(text)
assert result == int(100 / CHARS_PER_TOKEN)
def test_fallback_when_tiktoken_encode_raises(self) -> None:
"""A misbehaving encoding should fall back to the char heuristic
rather than raise — the size guard must never fail-open."""
class BoomEncoding:
def encode(self, text: str) -> list[int]:
raise ValueError("simulated tiktoken failure")
text = "abc" * 50
with patch.object(token_utils, "_ENCODING", BoomEncoding()):
result = estimate_token_count(text)
assert result == int(len(text) / CHARS_PER_TOKEN)
class TestEstimateResponseTokens:

View File

@@ -209,7 +209,7 @@ class TestApplyRlsReturnValue:
from superset.utils.rls import apply_rls
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
statement = MagicMock()
@@ -237,7 +237,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = []
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
mock_table = MagicMock()
@@ -268,7 +268,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = ["user_id = 42"]
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
mock_table = MagicMock()
@@ -276,8 +276,6 @@ class TestApplyRlsReturnValue:
statement = MagicMock()
statement.tables = [mock_table]
statement.parse_predicate.return_value = MagicMock()
result = apply_rls(
database=database,
catalog=None,
@@ -312,11 +310,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pens.pen_id, pens.is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -333,11 +330,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pens.pen_id, pens.is_green FROM mycat.public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", "mycat"): [predicate]},
{Table("pens", "public", "mycat"): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -351,11 +347,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT p.pen_id, p.is_green FROM public.pens p"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -369,11 +364,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pen_id, is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()

View File

@@ -33,7 +33,8 @@ def test_opensearch_dialect_registered() -> None:
def test_double_quotes_as_identifiers() -> None:
"""
Test that double quotes are treated as identifiers, not string literals.
Test that double quotes are treated as identifiers, not string literals,
and normalized to backticks in output.
"""
sql = 'SELECT "AvgTicketPrice" FROM "flights"'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -42,8 +43,8 @@ def test_double_quotes_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
@@ -69,8 +70,7 @@ WHERE
def test_backticks_as_identifiers() -> None:
"""
Test that backticks work as identifiers (MySQL-style).
Backticks are normalized to double quotes in output.
Test that backticks are accepted as identifiers and preserved on output.
"""
sql = "SELECT `AvgTicketPrice` FROM `flights`"
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -79,15 +79,16 @@ def test_backticks_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
def test_mixed_identifier_quotes() -> None:
"""
Test mixing double quotes and backticks for identifiers.
Test mixing double quotes and backticks for identifiers are all normalized to
backticks on output.
"""
sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -96,12 +97,26 @@ def test_mixed_identifier_quotes() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice" AS "AvgTicketPrice"
FROM "default"."flights"
`AvgTicketPrice` AS `AvgTicketPrice`
FROM `default`.`flights`
""".strip()
)
def test_alias_with_space() -> None:
    """
    A double-quoted alias containing a space (e.g. a metric key like
    ``my test``) must be emitted as a backtick-quoted identifier when the
    statement is regenerated.
    """
    parsed = sqlglot.parse_one(
        'SELECT COUNT(*) AS "my test" FROM `flights`',
        OpenSearch,
    )
    rendered = OpenSearch().generate(expression=parsed, pretty=False)
    assert rendered == "SELECT COUNT(*) AS `my test` FROM `flights`"
@pytest.mark.parametrize(
"sql, expected",
[
@@ -110,20 +125,20 @@ FROM "default"."flights"
"""
SELECT
COUNT(*)
FROM "flights"
FROM `flights`
WHERE
"Cancelled" = TRUE
`Cancelled` = TRUE
""".strip(),
),
(
'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"',
"""
SELECT
"Carrier",
SUM("AvgTicketPrice")
FROM "flights"
`Carrier`,
SUM(`AvgTicketPrice`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip(),
),
(
@@ -131,9 +146,9 @@ GROUP BY
"""
SELECT
*
FROM "flights"
FROM `flights`
WHERE
"DestCountry" IN ('US', 'CA', 'MX')
`DestCountry` IN ('US', 'CA', 'MX')
""".strip(),
),
],
@@ -165,13 +180,13 @@ GROUP BY "Carrier"
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
`Carrier`,
COUNT(*),
AVG("AvgTicketPrice"),
MAX("FlightDelayMin")
FROM "flights"
AVG(`AvgTicketPrice`),
MAX(`FlightDelayMin`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip()
)
@@ -190,10 +205,10 @@ SELECT
*
FROM (
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
) AS "sub"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
) AS `sub`
""".strip()
)
@@ -212,12 +227,12 @@ def test_order_by_with_quoted_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
ORDER BY
"AvgTicketPrice" DESC,
"Carrier" ASC
`AvgTicketPrice` DESC,
`Carrier` ASC
""".strip()
)
@@ -234,7 +249,7 @@ def test_limit_clause() -> None:
== """
SELECT
*
FROM "flights"
FROM `flights`
LIMIT 10
""".strip()
)

View File

@@ -1704,6 +1704,21 @@ def test_set_limit_value(
assert statement.format() == expected
def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
    """
    Applying a limit to a statement that carries cached verbatim SQL (as left
    by splice-mode rewrites) should reparse that SQL before mutating the AST,
    so the cached text survives alongside the new LIMIT.
    """
    statement = SQLStatement("SELECT * FROM some_table", "postgresql")
    # Stand in for a splice-mode rewrite by setting the cached SQL directly.
    statement._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"

    statement.set_limit_value(10, LimitMethod.FORCE_LIMIT)

    rendered = statement.format()
    for fragment in ("tenant_id = 42", "LIMIT 10"):
        assert fragment in rendered
@pytest.mark.parametrize(
"kql, limit, expected",
[
@@ -2198,7 +2213,7 @@ def test_rls_subquery_transformer(
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
{k: [v] for k, v in rules.items()},
RLSMethod.AS_SUBQUERY,
)
assert statement.format() == expected
@@ -2542,12 +2557,337 @@ def test_rls_predicate_transformer(
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # No WHERE clause: the predicate becomes the WHERE clause.
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42",
        ),
        # Existing OR condition is parenthesized before being ANDed.
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42 "
            "AND (bar = 'baz' OR foo = 'qux')",
        ),
        # Predicate on the joined table goes into the existing ON condition.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id)",
        ),
        # JOIN without ON: an ON clause is created for the predicate.
        (
            "SELECT * FROM table JOIN other_table",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42",
        ),
        # A trailing WHERE is left untouched when the predicate targets the JOIN.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id "
            "WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id) WHERE 1=1",
        ),
    ],
)
def test_rls_predicate_splice_semantics_match_predicate(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Splice mode should keep the boolean grouping and the JOIN-vs-WHERE
    placement that predicate mode uses.
    """
    # ``apply_rls`` takes a list of predicates per table; each case has one.
    predicates = {table: [predicate] for table, predicate in rules.items()}
    statement = SQLStatement(sql)
    statement.apply_rls(
        "catalog1",
        "schema1",
        predicates,
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # Simple — no WHERE clause to extend.
        (
            "SELECT LAST_DAY(d) FROM some_table",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42",
        ),
        # Existing WHERE clause: the predicate is ANDed in front of the
        # original condition, which gets wrapped in parentheses.
        (
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table "
            "WHERE some_table.tenant_id = 42 AND (status = 'open')",
        ),
        # WHERE precedes GROUP BY: predicate goes before GROUP BY.
        (
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table "
            "WHERE some_table.tenant_id = 42 AND (status = 'open') GROUP BY d",
        ),
        # No WHERE, but GROUP BY and ORDER BY are present.
        (
            "SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42 "
            "GROUP BY d ORDER BY d",
        ),
        # JOIN — predicate scoped to the FROM table only; it lands in a new
        # WHERE after the untouched ON condition.
        (
            "SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
            {Table("some_table", "schema1", "catalog1"): "o.tenant_id = 42"},
            "SELECT o.id FROM some_table o JOIN locations l "
            "ON o.loc_id = l.id WHERE o.tenant_id = 42",
        ),
        # JOIN — different predicate per table: the joined table's predicate
        # goes into its ON condition, the FROM table's into a new WHERE.
        (
            "SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
            {
                Table("events", "schema1", "catalog1"): "events.user_id = 99",
                Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42",
            },
            "SELECT * FROM some_table JOIN events "
            "ON events.user_id = 99 AND (some_table.id = events.order_id) "
            "WHERE some_table.tenant_id = 42",
        ),
        # Subquery in FROM — splice into the inner SELECT.
        (
            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
            "WHERE some_table.tenant_id = 42) sub",
        ),
        # CTE — splice into the CTE body.
        (
            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
            "WHERE some_table.tenant_id = 42) SELECT * FROM cte",
        ),
        # Dialect-specific function (LAST_DAY) preserved verbatim.
        (
            "SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT id, LAST_DAY(created_at) FROM some_table "
            "WHERE some_table.tenant_id = 42 AND (region = 'US')",
        ),
        # Multiline + inline comment preserved exactly.
        (
            "SELECT LAST_DAY(created_at) -- last day of month\n"
            "FROM some_table\n"
            "WHERE region = 'US'",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(created_at) -- last day of month\n"
            "FROM some_table\n"
            "WHERE some_table.tenant_id = 42 AND (region = 'US')",
        ),
        # Schema-qualified table name (no default schema match) — no predicate.
        (
            "SELECT t.foo FROM schema2.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM schema2.some_table AS t",
        ),
    ],
)
def test_rls_predicate_splice(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``.

    Splice mode rewrites the original SQL string instead of re-rendering the
    AST through the dialect generator, so byte-level fidelity (including
    dialect-specific functions, comments, and whitespace) is preserved.
    """
    # ``apply_rls`` takes a list of predicates per table; each case has one.
    predicates = {table: [predicate] for table, predicate in rules.items()}
    statement = SQLStatement(sql)
    statement.apply_rls(
        "catalog1",
        "schema1",
        predicates,
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == expected
def test_rls_predicate_splice_requires_source() -> None:
    """
    A statement built purely from an AST has no original SQL text, so asking
    for splice-mode RLS on it should raise ``ValueError``.
    """
    statement = SQLStatement(
        ast=parse_one("SELECT * FROM some_table"),
        engine="postgresql",
    )
    with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
        statement.apply_rls(
            "catalog1",
            "schema1",
            {Table("some_table", "schema1", "catalog1"): ["id = 42"]},
            RLSMethod.AS_PREDICATE_SPLICE,
        )
def test_rls_predicate_splice_preserves_dialect_function() -> None:
    """
    Splice mode must not round-trip through the sqlglot generator: on the
    postgres dialect ``LAST_DAY`` would otherwise be transpiled, so it has to
    appear unchanged in the output.
    """
    statement = SQLStatement("SELECT LAST_DAY(d) FROM some_table", engine="postgresql")
    predicates = {Table("some_table"): ["some_table.tenant_id = 42"]}

    statement.apply_rls(None, None, predicates, RLSMethod.AS_PREDICATE_SPLICE)

    expected = "SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42"
    assert statement.format() == expected
def test_rls_predicate_splice_combines_multiple_predicates() -> None:
    """
    Several predicates configured for one table should be joined with AND
    into a single injected condition.
    """
    predicates = {
        Table("some_table"): [
            "some_table.tenant_id = 42",
            "some_table.region = 'US'",
        ],
    }
    statement = SQLStatement(
        "SELECT * FROM some_table WHERE status = 'open'",
        engine="postgresql",
    )

    statement.apply_rls(None, None, predicates, RLSMethod.AS_PREDICATE_SPLICE)

    expected = (
        "SELECT * FROM some_table "
        "WHERE some_table.tenant_id = 42 AND some_table.region = 'US' "
        "AND (status = 'open')"
    )
    assert statement.format() == expected
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
    """
    Predicates can be handed to splice mode as plain strings; the caller does
    not need to run them through ``parse_predicate`` first.
    """
    raw_predicate = "some_table.tenant_id = 42 AND some_table.active"
    statement = SQLStatement("SELECT * FROM some_table", engine="postgresql")

    statement.apply_rls(
        None,
        None,
        {Table("some_table"): [raw_predicate]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )

    expected = (
        "SELECT * FROM some_table WHERE some_table.tenant_id = 42 AND some_table.active"
    )
    assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, expected",
    [
        # Line comment between FROM and GROUP BY.
        (
            "SELECT * FROM some_table -- hi\nGROUP BY id",
            "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
            "-- hi\nGROUP BY id",
        ),
        # Block comment between FROM and GROUP BY.
        (
            "SELECT * FROM some_table /* inline */ GROUP BY id",
            "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
            "/* inline */ GROUP BY id",
        ),
    ],
)
def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
    """
    The predicate should be inserted ahead of any comment that sits before
    the next clause boundary, so the comment cannot swallow the injected SQL.
    """
    predicates = {Table("some_table"): ["some_table.tenant_id = 42"]}
    statement = SQLStatement(sql, engine="postgresql")

    statement.apply_rls(None, None, predicates, RLSMethod.AS_PREDICATE_SPLICE)

    assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        # QUALIFY directly after FROM.
        (
            "SELECT * FROM some_table QUALIFY row_number() OVER "
            "(PARTITION BY id ORDER BY ts DESC) = 1",
            "snowflake",
            "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
            "QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
        ),
        # WINDOW directly after FROM.
        (
            "SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
            "postgresql",
            "SELECT sum(v) OVER () FROM some_table "
            "WHERE some_table.tenant_id = 42 "
            "WINDOW w AS (PARTITION BY id)",
        ),
    ],
)
def test_rls_predicate_splice_handles_additional_clause_boundaries(
    sql: str,
    engine: str,
    expected: str,
) -> None:
    """
    WHERE should be inserted before clause types that may legally follow
    FROM/WHERE, such as QUALIFY and WINDOW.
    """
    predicates = {Table("some_table"): ["some_table.tenant_id = 42"]}
    statement = SQLStatement(sql, engine=engine)

    statement.apply_rls(None, None, predicates, RLSMethod.AS_PREDICATE_SPLICE)

    assert statement.format() == expected
def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
    """
    Rewriting the LIMIT after splice-mode RLS has run should not drop the
    injected predicate.
    """
    statement = SQLStatement("SELECT * FROM some_table", engine="postgresql")

    statement.apply_rls(
        None,
        None,
        {Table("some_table"): ["tenant_id = 42"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)

    rendered = statement.format()
    # The predicate is expected back in table-qualified form.
    for fragment in ("some_table.tenant_id = 42", "LIMIT 101"):
        assert fragment in rendered
@pytest.mark.parametrize(
"sql, table, expected",
[

View File

@@ -0,0 +1,298 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sqlglot
from sqlglot import Dialect, exp
from superset.sql.parse import SQLStatement, Table
from superset.sql.rls_splice import (
_after_previous_token,
_classify_source_predicate,
_find_condition_end,
_find_join_splice,
_find_where_splice,
_scan_join_clause,
_scan_until_scope_boundary,
_splices_for_scope,
_table_end,
apply_rls_splice,
)
def _tokenize(sql: str) -> list[sqlglot.tokens.Token]:
    """Tokenize ``sql`` with the dialect resolved from ``None``; return a list."""
    base_dialect = Dialect.get_or_raise(None)
    return list(base_dialect.tokenize(sql))
def _token_index(tokens: list[sqlglot.tokens.Token], token_type: object) -> int:
    """
    Return the position of the first token whose ``token_type`` equals
    ``token_type``.

    ``next`` is called without a default, so a missing token type surfaces as
    ``StopIteration``.
    """
    positions = (
        position
        for position, candidate in enumerate(tokens)
        if candidate.token_type == token_type
    )
    return next(positions)
def _token_by_text(
    tokens: list[sqlglot.tokens.Token], text: str
) -> sqlglot.tokens.Token:
    """
    Return the first token whose ``text`` equals ``text``.

    ``next`` is called without a default, so a missing token surfaces as
    ``StopIteration``.
    """
    matches = (candidate for candidate in tokens if candidate.text == text)
    return next(matches)
def test_split_source_returns_none_result_when_tokenize_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    If the dialect's tokenizer raises a sqlglot error, ``_split_source``
    should return ``None`` placeholders (two here) instead of propagating it.
    """

    class _RaisingDialect:
        @staticmethod
        def tokenize(_: str) -> list[sqlglot.tokens.Token]:
            raise sqlglot.errors.SqlglotError("boom")

    # Make every dialect lookup inside ``superset.sql.parse`` return the stub.
    monkeypatch.setattr(
        "superset.sql.parse.Dialect.get_or_raise",
        lambda _: _RaisingDialect(),
    )

    pieces = SQLStatement._split_source("SELECT 1", "postgresql", 2)
    assert pieces == [None, None]
def test_apply_rls_splice_ignores_empty_predicates() -> None:
    """A table mapped to an empty predicate list leaves the SQL unchanged."""
    original = "SELECT 1"
    no_predicates = {Table("foo"): []}
    assert apply_rls_splice(original, None, None, no_predicates) == original
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
    """
    Regression: a ``--`` inside a string literal must not be mistaken for an
    inline comment when choosing the splice point. Earlier
    ``rfind("--", ...)`` logic did exactly that and inserted the predicate
    inside the quoted text.
    """
    predicates = {Table("some_table"): ["some_table.tenant_id = 42"]}
    result = apply_rls_splice(
        "SELECT * FROM some_table WHERE note = '--x' GROUP BY id",
        None,
        None,
        predicates,
        dialect="postgres",
    )
    expected = (
        "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
        "AND (note = '--x') GROUP BY id"
    )
    assert result == expected
def test_table_end_returns_none_without_metadata() -> None:
    """``_table_end`` yields ``None`` for a hand-built table with no metadata."""
    bare_table = exp.Table(this=exp.Identifier(this="foo"))
    assert _table_end(bare_table) is None
def test_classify_source_predicate_returns_none_without_table_metadata() -> None:
    """
    A table under a FROM node but without position metadata should be
    classified as ``("none", None, None)`` even when a predicate targets it.
    """
    table_node = exp.Table(this=exp.Identifier(this="foo"))
    # Built only for its effect on ``table_node`` — presumably this makes the
    # FROM node its parent; the FROM object itself is not used.
    exp.From(this=table_node)

    predicates = {Table("foo"): ["id = 1"]}
    outcome = _classify_source_predicate(table_node, predicates, None, None, None)
    assert outcome == ("none", None, None)
def test_classify_source_predicate_returns_none_for_unsupported_parent() -> None:
    """
    A table that has end metadata but sits under an ``Alias`` node should be
    classified as ``("none", None, None)``.
    """
    table_node = exp.Table(this=exp.Identifier(this="foo"))
    table_node.this.meta["end"] = 3
    # Built only for its effect on ``table_node`` — presumably this makes the
    # Alias node its parent; the Alias object itself is not used.
    exp.Alias(this=table_node, alias=exp.Identifier(this="alias"))

    predicates = {Table("foo"): ["id = 1"]}
    outcome = _classify_source_predicate(table_node, predicates, None, None, None)
    assert outcome == ("none", None, None)
def test_after_previous_token_returns_zero_at_stream_start() -> None:
    """Index 0 has no previous token, so the reported offset is 0."""
    stream = _tokenize("SELECT 1")
    assert _after_previous_token(stream, 0) == 0
def test_scan_until_scope_boundary_tracks_parenthesis_depth() -> None:
    """
    Scanning from a WHERE whose condition is fully parenthesized should run
    to the end of the input and report ``("eof", None)``.
    """
    stream = _tokenize("SELECT * FROM t WHERE (a = 1)")
    where_start = _token_by_text(stream, "WHERE").start

    outcome = _scan_until_scope_boundary(stream, where_start, stop_at_join=False)

    assert outcome == ("eof", None)
def test_find_condition_end_handles_subquery_closing_paren() -> None:
    """
    For a WHERE inside a subquery, the condition end should be the offset of
    the closing parenthesis.
    """
    query = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
    stream = _tokenize(query)
    where_position = _token_index(stream, sqlglot.tokens.TokenType.WHERE)

    condition_end = _find_condition_end(stream, where_position, stop_at_join=False)

    assert query[condition_end] == ")"
def test_find_condition_end_handles_parenthesized_expression() -> None:
    """
    A parenthesized condition at the end of the statement should put the
    condition end at the end of the SQL string.
    """
    query = "SELECT * FROM t WHERE (a = 1)"
    stream = _tokenize(query)
    where_position = _token_index(stream, sqlglot.tokens.TokenType.WHERE)

    condition_end = _find_condition_end(stream, where_position, stop_at_join=False)

    assert condition_end == len(query)
def test_find_where_splice_handles_trailing_where_keyword() -> None:
    """
    With a WHERE keyword and no condition after it, the predicate should be
    spliced at the end of the SQL string, preceded by a space.
    """
    query = "SELECT * FROM t WHERE"
    stream = _tokenize(query)

    found = _find_where_splice(query, stream, anchor=0, pred_sql="t.id = 1")

    assert found == [(len(query), " t.id = 1")]
def test_find_join_splice_handles_trailing_on_keyword() -> None:
    """
    With an ON keyword and no condition after it, the predicate should be
    spliced at the end of the SQL string, preceded by a space.
    """
    query = "SELECT * FROM a JOIN b ON"
    stream = _tokenize(query)
    joined_table_end = _token_by_text(stream, "b").end

    found = _find_join_splice(query, stream, joined_table_end, "b.id = 1")

    assert found == [(len(query), " b.id = 1")]
def test_find_join_splice_inserts_on_before_where_boundary() -> None:
    """
    A JOIN without ON that is followed by WHERE should get a new ``ON``
    clause spliced one character before the WHERE keyword.
    """
    query = "SELECT * FROM a JOIN b WHERE x = 1"
    stream = _tokenize(query)
    joined_table_end = _token_by_text(stream, "b").end

    found = _find_join_splice(query, stream, joined_table_end, "b.id = 1")

    expected_offset = query.index("WHERE") - 1
    assert found == [(expected_offset, " ON b.id = 1")]
def test_scan_join_clause_covers_nested_parentheses_and_join_boundary() -> None:
    """
    Scanning a JOIN with a parenthesized ON condition followed by a second
    JOIN should find the ON and report the second JOIN token as the boundary.
    """
    stream = _tokenize("SELECT * FROM a JOIN b ON (a.id = b.id) JOIN c ON 1 = 1")
    joined_table_end = _token_by_text(stream, "b").end

    on_position, boundary_position = _scan_join_clause(stream, joined_table_end)

    assert on_position is not None
    assert boundary_position is not None
    boundary_type = stream[boundary_position].token_type
    assert boundary_type == sqlglot.tokens.TokenType.JOIN
def test_scan_join_clause_stops_at_outer_closing_paren() -> None:
    """
    For a JOIN inside a subquery, the reported boundary should be the
    closing parenthesis of that subquery.
    """
    stream = _tokenize("SELECT * FROM (SELECT * FROM a JOIN b) sub")
    joined_table_end = _token_by_text(stream, "b").end

    _, boundary_position = _scan_join_clause(stream, joined_table_end)

    assert boundary_position is not None
    boundary_type = stream[boundary_position].token_type
    assert boundary_type == sqlglot.tokens.TokenType.R_PAREN
def test_splices_for_scope_handles_empty_join_splice_result(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    A source classified as a join for which ``_find_join_splice`` returns
    nothing should give an empty splice list for the scope.
    """

    class _Scope:
        sources = {"x": object()}

    query = "SELECT 1"
    stream = _tokenize(query)

    # Force the "join" classification and an empty join splice result.
    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate",
        lambda *args, **kwargs: ("join", 0, "x.id = 1"),
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [],
    )

    found = _splices_for_scope(
        query,
        stream,
        _Scope(),
        {Table("x"): ["x.id = 1"]},
        None,
        None,
        None,
    )
    assert found == []
def test_splices_for_scope_combines_join_and_from_splices(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    With one source classified as "from" and another as "join", the result
    should list the join splice first and the WHERE splice second.
    """

    class _Scope:
        sources = {"f": object(), "j": object()}

    query = "SELECT 1"
    stream = _tokenize(query)
    # Classifications handed out in call order, one per source.
    scripted = [("from", 3, "f.id = 1"), ("join", 6, "j.id = 2")]

    def _fake_classify(*args: object, **kwargs: object) -> tuple[str, int, str]:
        return scripted.pop(0)

    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate", _fake_classify
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [(50, " ON j.id = 2")],
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_where_splice",
        lambda *args, **kwargs: [(20, " WHERE f.id = 1")],
    )

    found = _splices_for_scope(
        query,
        stream,
        _Scope(),
        {Table("f"): ["id = 1"], Table("j"): ["id = 2"]},
        None,
        None,
        None,
    )
    assert found == [(50, " ON j.id = 2"), (20, " WHERE f.id = 1")]
def test_splices_for_scope_join_then_next_source(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    A "join" source with no join splice followed by a "none" source should
    give an empty splice list for the scope.
    """

    class _Scope:
        sources = {"j": object(), "f": object()}

    query = "SELECT 1"
    stream = _tokenize(query)
    # Classifications handed out in call order, one per source.
    scripted = [("join", 3, "j.id = 2"), ("none", None, None)]

    def _fake_classify(
        *args: object, **kwargs: object
    ) -> tuple[str, int | None, str | None]:
        return scripted.pop(0)

    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate", _fake_classify
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [],
    )

    found = _splices_for_scope(
        query,
        stream,
        _Scope(),
        {Table("j"): ["id = 2"]},
        None,
        None,
        None,
    )
    assert found == []