fix(mcp): remove html.escape to fix ampersand display in chart titles (#37186)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-02-05 20:49:37 -07:00
committed by GitHub
parent 97e5f0631d
commit 01ac966b83
4 changed files with 792 additions and 264 deletions

View File

@@ -21,8 +21,6 @@ Pydantic schemas for chart-related responses
from __future__ import annotations
import html
import re
from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, Literal, Protocol
@@ -50,6 +48,10 @@ from superset.mcp_service.system.schemas import (
TagInfo,
UserInfo,
)
from superset.mcp_service.utils.sanitization import (
sanitize_filter_value,
sanitize_user_input,
)
class ChartLike(Protocol):
@@ -357,113 +359,17 @@ class ColumnRef(BaseModel):
@classmethod
def sanitize_name(cls, v: str) -> str:
"""Sanitize column name to prevent XSS and SQL injection."""
if not v or not v.strip():
raise ValueError("Column name cannot be empty")
# Length check first to prevent ReDoS attacks
if len(v) > 255:
raise ValueError(
f"Column name too long ({len(v)} characters). "
f"Maximum allowed length is 255 characters."
)
# Remove HTML tags and decode entities
sanitized = html.escape(v.strip())
# Check for dangerous HTML tags using substring checks (safe)
dangerous_tags = ["<script", "</script>", "<iframe", "<object", "<embed"]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Column name contains potentially malicious script content"
)
# Check URL schemes with word boundaries to match only actual URLs
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Column name contains potentially malicious URL scheme")
# Basic SQL injection patterns (basic protection)
# Use simple patterns without backtracking
dangerous_patterns = [
r"[;|&$`]", # Dangerous shell characters
r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
r"--", # SQL comment
r"/\*", # SQL comment start (just check for start, not full pattern)
]
for pattern in dangerous_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError(
"Column name contains potentially unsafe characters or SQL keywords"
)
return sanitized
# sanitize_user_input raises ValueError when allow_empty=False (default)
# so the return value is guaranteed to be a non-None str
return sanitize_user_input(
v, "Column name", max_length=255, check_sql_keywords=True
) # type: ignore[return-value]
@field_validator("label")
@classmethod
def sanitize_label(cls, v: str | None) -> str | None:
"""Sanitize display label to prevent XSS attacks."""
if v is None:
return v
# Strip whitespace
v = v.strip()
if not v:
return None
# Length check first to prevent ReDoS attacks
if len(v) > 500:
raise ValueError(
f"Label too long ({len(v)} characters). "
f"Maximum allowed length is 500 characters."
)
# Check for dangerous HTML tags and JavaScript protocols using substring checks
# This avoids ReDoS vulnerabilities from regex patterns
dangerous_tags = [
"<script",
"</script>",
"<iframe",
"</iframe>",
"<object",
"</object>",
"<embed",
"</embed>",
"<link",
"<meta",
]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Label contains potentially malicious content. "
"HTML tags, JavaScript, and event handlers are not allowed "
"in labels."
)
# Check URL schemes and event handlers with word boundaries
dangerous_patterns = [
r"\b(javascript|vbscript|data):", # URL schemes
r"on\w+\s*=", # Event handlers
]
for pattern in dangerous_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError(
"Label contains potentially malicious content. "
"HTML tags, JavaScript, and event handlers are not allowed."
)
# Filter dangerous Unicode characters
v = re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
)
# HTML escape the cleaned content
sanitized = html.escape(v)
return sanitized if sanitized else None
return sanitize_user_input(v, "Label", max_length=500, allow_empty=True)
class AxisConfig(BaseModel):
@@ -496,112 +402,15 @@ class FilterConfig(BaseModel):
@classmethod
def sanitize_column(cls, v: str) -> str:
"""Sanitize filter column name to prevent injection attacks."""
if not v or not v.strip():
raise ValueError("Filter column name cannot be empty")
# Length check first to prevent ReDoS attacks
if len(v) > 255:
raise ValueError(
f"Filter column name too long ({len(v)} characters). "
f"Maximum allowed length is 255 characters."
)
# Remove HTML tags and decode entities
sanitized = html.escape(v.strip())
# Check for dangerous HTML tags using substring checks (safe)
dangerous_tags = ["<script", "</script>"]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Filter column contains potentially malicious script content"
)
# Check URL schemes with word boundaries
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Filter column contains potentially malicious URL scheme")
return sanitized
@staticmethod
def _validate_string_value(v: str) -> None:
"""Validate string filter value for security issues."""
# Check for dangerous HTML tags and SQL procedures
dangerous_substrings = [
"<script",
"</script>",
"<iframe",
"<object",
"<embed",
"xp_cmdshell",
"sp_executesql",
]
v_lower = v.lower()
for substring in dangerous_substrings:
if substring in v_lower:
raise ValueError(
"Filter value contains potentially malicious content. "
"HTML tags and JavaScript are not allowed."
)
# Check URL schemes with word boundaries
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Filter value contains potentially malicious URL scheme")
# SQL injection patterns
sql_patterns = [
r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
r"'\s*OR\s*'",
r"'\s*AND\s*'",
r"--\s*",
r"/\*",
r"UNION\s+SELECT",
]
for pattern in sql_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError(
"Filter value contains potentially malicious SQL patterns."
)
# Check for other dangerous patterns
if re.search(r"[;&|`$()]", v):
raise ValueError(
"Filter value contains potentially unsafe shell characters."
)
if re.search(r"on\w+\s*=", v, re.IGNORECASE):
raise ValueError(
"Filter value contains potentially malicious event handlers."
)
if re.search(r"\\x[0-9a-fA-F]{2}", v):
raise ValueError("Filter value contains hex encoding which is not allowed.")
# sanitize_user_input raises ValueError when allow_empty=False (default)
# so the return value is guaranteed to be a non-None str
return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value]
@field_validator("value")
@classmethod
def sanitize_value(cls, v: str | int | float | bool) -> str | int | float | bool:
"""Sanitize filter value to prevent XSS and SQL injection attacks."""
if isinstance(v, str):
v = v.strip()
# Length check FIRST to prevent ReDoS attacks
if len(v) > 1000:
raise ValueError(
f"Filter value too long ({len(v)} characters). "
f"Maximum allowed length is 1000 characters."
)
# Validate security
cls._validate_string_value(v)
# Filter dangerous Unicode characters
v = re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
)
# HTML escape the cleaned content
return html.escape(v)
return v # Return non-string values as-is
return sanitize_filter_value(v, max_length=1000)
# Actual chart types
@@ -848,6 +657,11 @@ class ListChartsRequest(MetadataCacheControl):
class GenerateChartRequest(QueryCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration")
chart_name: str | None = Field(
None,
description="Custom chart name (optional, auto-generates if not provided)",
max_length=255,
)
save_chart: bool = Field(
default=False,
description="Whether to permanently save the chart in Superset",
@@ -861,6 +675,12 @@ class GenerateChartRequest(QueryCacheControl):
description="List of preview formats to generate",
)
@field_validator("chart_name")
@classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None:
"""Sanitize chart name to prevent XSS attacks."""
return sanitize_user_input(v, "Chart name", max_length=255, allow_empty=True)
@model_validator(mode="after")
def validate_cache_timeout(self) -> "GenerateChartRequest":
"""Validate cache timeout is non-negative."""
@@ -911,62 +731,7 @@ class UpdateChartRequest(QueryCacheControl):
@classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None:
"""Sanitize chart name to prevent XSS attacks."""
if v is None:
return v
# Strip whitespace
v = v.strip()
if not v:
return None
# Length check first to prevent ReDoS attacks
if len(v) > 255:
raise ValueError(
f"Chart name too long ({len(v)} characters). "
f"Maximum allowed length is 255 characters."
)
# Check for dangerous HTML tags using substring checks (safe)
dangerous_tags = [
"<script",
"</script>",
"<iframe",
"</iframe>",
"<object",
"</object>",
"<embed",
"</embed>",
"<link",
"<meta",
]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Chart name contains potentially malicious content. "
"HTML tags and JavaScript are not allowed in chart names."
)
# Check URL schemes with word boundaries
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Chart name contains potentially malicious URL scheme")
# Check for event handlers with simple regex
if re.search(r"on\w+\s*=", v, re.IGNORECASE):
raise ValueError(
"Chart name contains potentially malicious event handlers."
)
# Filter dangerous Unicode characters
v = re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
)
# HTML escape the cleaned content
sanitized = html.escape(v)
return sanitized if sanitized else None
return sanitize_user_input(v, "Chart name", max_length=255, allow_empty=True)
class UpdateChartPreviewRequest(FormDataCacheControl):

View File

@@ -189,9 +189,9 @@ async def generate_chart( # noqa: C901
await ctx.report_progress(2, 5, "Creating chart in database")
from superset.commands.chart.create import CreateChartCommand
# Generate a chart name
chart_name = generate_chart_name(request.config)
await ctx.debug("Generated chart name: chart_name=%s" % (chart_name,))
# Use custom chart name if provided, otherwise auto-generate
chart_name = request.chart_name or generate_chart_name(request.config)
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
# Find the dataset to get its numeric ID
from superset.daos.dataset import DatasetDAO

View File

@@ -0,0 +1,283 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Centralized sanitization utilities for MCP service user inputs.
This module uses the nh3 library (Rust-based HTML sanitizer) to strip malicious
HTML tags and protocols from user inputs. nh3 is faster and safer than manual
regex-based sanitization.
Key features:
- Strips all HTML tags using nh3.clean() with no allowed tags
- Blocks dangerous URL schemes (javascript:, vbscript:, data:)
- Preserves safe text content (e.g., '&' stays as '&', not '&amp;')
- Additional SQL injection protection for database-facing inputs
"""
import html
import re
import nh3
def _strip_html_tags(value: str) -> str:
"""
Strip all HTML tags from the input using nh3.
Decodes all layers of HTML entity encoding BEFORE passing to nh3,
so entity-encoded tags (e.g., ``&lt;script&gt;``) are decoded into
real tags that nh3 can detect and strip. After nh3 removes all tags,
we only restore ``&amp;`` back to ``&`` (not a full html.unescape)
to preserve ampersands in display text without risking XSS from
re-introducing angle brackets or other HTML-significant characters.
Args:
value: The input string that may contain HTML
Returns:
String with all HTML tags removed and ampersands preserved
"""
# Decode all layers of HTML entity encoding to prevent bypass
# via entity-encoded tags (e.g., &lt;script&gt; or &amp;lt;script&amp;gt;)
# The loop terminates when unescape produces no change (idempotent on decoded text).
# Max iterations cap provides defense-in-depth against pathological inputs.
max_iterations = 100
decoded = value
prev = None
iterations = 0
while prev != decoded and iterations < max_iterations:
prev = decoded
decoded = html.unescape(decoded)
iterations += 1
# nh3.clean with tags=set() strips ALL HTML tags from the decoded input
# url_schemes=set() blocks all URL schemes in any remaining attributes
cleaned = nh3.clean(decoded, tags=set(), url_schemes=set())
# Only restore &amp; → & to preserve ampersands in display text (e.g. "A & B").
# Do NOT use html.unescape() here: nh3 may pass through HTML entities from
# the input (e.g. &lt;script&gt;), and a full unescape would re-introduce
# raw angle brackets, creating an XSS vector.
return cleaned.replace("&amp;", "&")
def _check_dangerous_patterns(value: str, field_name: str) -> None:
"""
Check for dangerous patterns that nh3 doesn't catch.
This includes URL schemes in plain text (not in HTML attributes),
event handler patterns, and dangerous Unicode characters.
Args:
value: The input string to check
field_name: Name of the field (for error messages)
Raises:
ValueError: If dangerous patterns are found
"""
# Block dangerous URL schemes in plain text (word boundary check)
if re.search(r"\b(javascript|vbscript|data):", value, re.IGNORECASE):
raise ValueError(f"{field_name} contains potentially malicious URL scheme")
# Block event handler patterns (onclick=, onerror=, etc.)
if re.search(r"on\w+\s*=", value, re.IGNORECASE):
raise ValueError(f"{field_name} contains potentially malicious event handler")
def _check_sql_patterns(value: str, field_name: str) -> None:
"""
Check for SQL injection patterns.
Args:
value: The input string to check
field_name: Name of the field (for error messages)
Raises:
ValueError: If SQL injection patterns are found
"""
# Check for dangerous SQL keywords
if re.search(
r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
value,
re.IGNORECASE,
):
raise ValueError(f"{field_name} contains potentially unsafe SQL keywords")
# Check for shell metacharacters and SQL comments
if re.search(r"[;|&$`]|--", value):
raise ValueError(f"{field_name} contains potentially unsafe characters")
# Check for SQL comment start
if "/*" in value:
raise ValueError(f"{field_name} contains potentially unsafe SQL comment syntax")
def _remove_dangerous_unicode(value: str) -> str:
"""
Remove dangerous Unicode characters (zero-width, control chars).
Args:
value: The input string
Returns:
String with dangerous Unicode characters removed
"""
return re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", value
)
def sanitize_user_input(
value: str | None,
field_name: str,
max_length: int = 255,
check_sql_keywords: bool = False,
allow_empty: bool = False,
) -> str | None:
"""
Centralized sanitization for user-provided text inputs.
Uses nh3 to strip HTML tags and performs additional security checks.
Args:
value: The input string to sanitize
field_name: Name of the field (for error messages)
max_length: Maximum allowed length
check_sql_keywords: Whether to check for SQL injection keywords
allow_empty: Whether to allow empty/None values
Returns:
Sanitized string, or None if allow_empty=True and value is empty
Raises:
ValueError: If value fails security validation
Security checks performed:
- Strips all HTML tags using nh3 (Rust-based sanitizer)
- Blocks JavaScript/VBScript/data URL schemes
- Blocks event handlers (onclick=, onerror=, etc.)
- Removes dangerous Unicode characters (zero-width, control chars)
- SQL keywords and shell metacharacters (when check_sql_keywords=True)
"""
if value is None:
if allow_empty:
return None
raise ValueError(f"{field_name} cannot be empty")
value = value.strip()
if not value:
if allow_empty:
return None
raise ValueError(f"{field_name} cannot be empty")
# Length check first to prevent ReDoS attacks
if len(value) > max_length:
raise ValueError(
f"{field_name} too long ({len(value)} characters). "
f"Maximum allowed length is {max_length} characters."
)
# Strip all HTML tags using nh3
value = _strip_html_tags(value)
# Check for dangerous patterns (URL schemes, event handlers)
_check_dangerous_patterns(value, field_name)
# SQL keyword and shell metacharacter checks (for column names, etc.)
if check_sql_keywords:
_check_sql_patterns(value, field_name)
# Remove dangerous Unicode characters
value = _remove_dangerous_unicode(value)
return value
def sanitize_filter_value(
value: str | int | float | bool,
max_length: int = 1000,
) -> str | int | float | bool:
"""
Sanitize filter values which can be strings or other types.
For non-string values, returns as-is (no sanitization needed).
For strings, uses nh3 to strip HTML and applies security validation.
Args:
value: The filter value (string, int, float, or bool)
max_length: Maximum length for string values
Returns:
Sanitized value
Raises:
ValueError: If string value fails security validation
"""
if not isinstance(value, str):
return value
value = value.strip()
# Length check first
if len(value) > max_length:
raise ValueError(
f"Filter value too long ({len(value)} characters). "
f"Maximum allowed length is {max_length} characters."
)
# Strip all HTML tags using nh3
value = _strip_html_tags(value)
# Check for dangerous patterns
_check_dangerous_patterns(value, "Filter value")
# Check for dangerous SQL procedures (filter-specific)
v_lower = value.lower()
if "xp_cmdshell" in v_lower or "sp_executesql" in v_lower:
raise ValueError("Filter value contains potentially malicious SQL procedures.")
# SQL injection patterns specific to filter values
sql_patterns = [
r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
r"'\s*OR\s*'",
r"'\s*AND\s*'",
r"--\s*",
r"/\*",
r"UNION\s+SELECT",
]
for pattern in sql_patterns:
if re.search(pattern, value, re.IGNORECASE):
raise ValueError(
"Filter value contains potentially malicious SQL patterns."
)
# Check for shell metacharacters that could indicate injection attempts
# Note: We allow '&' alone as it's common in text ("A & B") and is only
# dangerous in shell contexts, not in database queries
if re.search(r"[;|`$()]", value):
raise ValueError("Filter value contains potentially unsafe shell characters.")
# Check for hex encoding
if re.search(r"\\x[0-9a-fA-F]{2}", value):
raise ValueError("Filter value contains hex encoding which is not allowed.")
# Remove dangerous Unicode characters
value = _remove_dangerous_unicode(value)
return value

View File

@@ -0,0 +1,480 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.mcp_service.utils.sanitization import (
_check_dangerous_patterns,
_check_sql_patterns,
_remove_dangerous_unicode,
_strip_html_tags,
sanitize_filter_value,
sanitize_user_input,
)
# --- _strip_html_tags tests ---
def test_strip_html_tags_plain_text():
assert _strip_html_tags("hello world") == "hello world"
def test_strip_html_tags_preserves_ampersand():
assert _strip_html_tags("A & B") == "A & B"
def test_strip_html_tags_preserves_multiple_ampersands():
assert _strip_html_tags("A & B & C") == "A & B & C"
def test_strip_html_tags_strips_bold_tags():
assert _strip_html_tags("<b>hello</b>") == "hello"
def test_strip_html_tags_strips_script_tags():
result = _strip_html_tags("<script>alert(1)</script>")
assert "<script>" not in result
assert "</script>" not in result
def test_strip_html_tags_strips_entity_encoded_script():
"""Entity-encoded tags must be decoded and stripped, not passed through."""
result = _strip_html_tags("&lt;script&gt;alert(1)&lt;/script&gt;")
assert "<script>" not in result
assert "&lt;script&gt;" not in result
def test_strip_html_tags_strips_double_encoded_script():
"""Double-encoded entities must also be decoded and stripped."""
result = _strip_html_tags("&amp;lt;script&amp;gt;alert(1)&amp;lt;/script&amp;gt;")
assert "<script>" not in result
assert "&lt;script&gt;" not in result
def test_strip_html_tags_strips_img_onerror():
result = _strip_html_tags('<img src=x onerror="alert(1)">')
assert "<img" not in result
assert "onerror" not in result
def test_strip_html_tags_strips_div_tags():
assert _strip_html_tags("<div>content</div>") == "content"
def test_strip_html_tags_preserves_less_than_in_text():
"""A bare < not forming a tag should be preserved."""
result = _strip_html_tags("5 < 10")
assert "5" in result
assert "10" in result
def test_strip_html_tags_empty_string():
assert _strip_html_tags("") == ""
def test_strip_html_tags_triple_encoded_script():
"""Triple-encoded entities must also be decoded and stripped."""
result = _strip_html_tags(
"&amp;amp;lt;script&amp;amp;gt;alert(1)&amp;amp;lt;/script&amp;amp;gt;"
)
assert "<script>" not in result
def test_strip_html_tags_mixed_encoded_and_raw():
"""Both raw and entity-encoded tags should be stripped."""
result = _strip_html_tags("<b>bold</b> and &lt;i&gt;italic&lt;/i&gt;")
assert "<b>" not in result
assert "<i>" not in result
assert "bold" in result
assert "italic" in result
def test_strip_html_tags_deep_encoding_terminates():
"""Verify the iterative decode loop terminates on many encoding layers."""
value = "<script>alert(1)</script>"
for _ in range(10):
value = value.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
result = _strip_html_tags(value)
assert "<script>" not in result
def test_strip_html_tags_entity_ampersand():
"""&amp; in input should become & in output."""
assert _strip_html_tags("A &amp; B") == "A & B"
# --- _check_dangerous_patterns tests ---
def test_check_dangerous_patterns_safe_input():
_check_dangerous_patterns("hello world", "test")
def test_check_dangerous_patterns_javascript_scheme():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("javascript:alert(1)", "test")
def test_check_dangerous_patterns_vbscript_scheme():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("vbscript:MsgBox", "test")
def test_check_dangerous_patterns_data_scheme():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("data:text/html,<script>", "test")
def test_check_dangerous_patterns_case_insensitive():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("JAVASCRIPT:alert(1)", "test")
def test_check_dangerous_patterns_onclick():
with pytest.raises(ValueError, match="malicious event handler"):
_check_dangerous_patterns("onclick=alert(1)", "test")
def test_check_dangerous_patterns_onerror():
with pytest.raises(ValueError, match="malicious event handler"):
_check_dangerous_patterns("onerror = alert(1)", "test")
def test_check_dangerous_patterns_onload():
with pytest.raises(ValueError, match="malicious event handler"):
_check_dangerous_patterns("onload=fetch('x')", "test")
# --- _check_sql_patterns tests ---
def test_check_sql_patterns_safe_input():
_check_sql_patterns("revenue_total", "test")
def test_check_sql_patterns_drop_table():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
_check_sql_patterns("DROP TABLE users", "test")
def test_check_sql_patterns_delete():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
_check_sql_patterns("DELETE FROM users", "test")
def test_check_sql_patterns_semicolon():
with pytest.raises(ValueError, match="unsafe characters"):
_check_sql_patterns("value; other", "test")
def test_check_sql_patterns_sql_comment_dash():
with pytest.raises(ValueError, match="unsafe characters"):
_check_sql_patterns("value -- comment", "test")
def test_check_sql_patterns_sql_comment_block():
with pytest.raises(ValueError, match="unsafe SQL comment"):
_check_sql_patterns("value /* comment */", "test")
def test_check_sql_patterns_pipe():
with pytest.raises(ValueError, match="unsafe characters"):
_check_sql_patterns("value | other", "test")
def test_check_sql_patterns_case_insensitive():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
_check_sql_patterns("drop table users", "test")
# --- _remove_dangerous_unicode tests ---
def test_remove_dangerous_unicode_plain_text():
assert _remove_dangerous_unicode("hello") == "hello"
def test_remove_dangerous_unicode_zero_width_space():
assert _remove_dangerous_unicode("he\u200bllo") == "hello"
def test_remove_dangerous_unicode_zero_width_joiner():
assert _remove_dangerous_unicode("he\u200dllo") == "hello"
def test_remove_dangerous_unicode_bom():
assert _remove_dangerous_unicode("\ufeffhello") == "hello"
def test_remove_dangerous_unicode_null_byte():
assert _remove_dangerous_unicode("he\x00llo") == "hello"
def test_remove_dangerous_unicode_preserves_normal_unicode():
assert _remove_dangerous_unicode("café résumé") == "café résumé"
# --- sanitize_user_input tests ---
def test_sanitize_user_input_plain_text():
assert sanitize_user_input("hello", "test") == "hello"
def test_sanitize_user_input_preserves_ampersand():
assert sanitize_user_input("A & B", "test") == "A & B"
def test_sanitize_user_input_strips_html():
assert sanitize_user_input("<b>hello</b>", "test") == "hello"
def test_sanitize_user_input_none_not_allowed():
with pytest.raises(ValueError, match="cannot be empty"):
sanitize_user_input(None, "test")
def test_sanitize_user_input_none_allowed():
assert sanitize_user_input(None, "test", allow_empty=True) is None
def test_sanitize_user_input_empty_string_not_allowed():
with pytest.raises(ValueError, match="cannot be empty"):
sanitize_user_input("", "test")
def test_sanitize_user_input_empty_string_allowed():
assert sanitize_user_input("", "test", allow_empty=True) is None
def test_sanitize_user_input_whitespace_only():
with pytest.raises(ValueError, match="cannot be empty"):
sanitize_user_input(" ", "test")
def test_sanitize_user_input_strips_whitespace():
assert sanitize_user_input(" hello ", "test") == "hello"
def test_sanitize_user_input_too_long():
with pytest.raises(ValueError, match="too long"):
sanitize_user_input("a" * 256, "test", max_length=255)
def test_sanitize_user_input_max_length_ok():
result = sanitize_user_input("a" * 255, "test", max_length=255)
assert result == "a" * 255
def test_sanitize_user_input_blocks_javascript():
with pytest.raises(ValueError, match="malicious URL scheme"):
sanitize_user_input("javascript:alert(1)", "test")
def test_sanitize_user_input_blocks_event_handler():
with pytest.raises(ValueError, match="malicious event handler"):
sanitize_user_input("onclick=alert(1)", "test")
def test_sanitize_user_input_sql_keywords_not_checked_by_default():
result = sanitize_user_input("DROP TABLE", "test")
assert result == "DROP TABLE"
def test_sanitize_user_input_sql_keywords_checked_when_enabled():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
sanitize_user_input("DROP TABLE users", "test", check_sql_keywords=True)
def test_sanitize_user_input_removes_zero_width_chars():
result = sanitize_user_input("hel\u200blo", "test")
assert result == "hello"
def test_sanitize_user_input_xss_entity_encoded():
"""Entity-encoded XSS attempts must be neutralized."""
result = sanitize_user_input("&lt;script&gt;alert(1)&lt;/script&gt;", "test")
assert "<script>" not in result
def test_sanitize_user_input_entity_encoded_javascript():
"""Entity-encoded javascript: scheme should be caught after decoding."""
with pytest.raises(ValueError, match="malicious URL scheme"):
sanitize_user_input("&#106;avascript:alert(1)", "test")
# --- sanitize_filter_value tests ---
def test_sanitize_filter_value_integer():
assert sanitize_filter_value(42) == 42
def test_sanitize_filter_value_float():
assert sanitize_filter_value(3.14) == 3.14
def test_sanitize_filter_value_bool():
assert sanitize_filter_value(True) is True
def test_sanitize_filter_value_plain_string():
assert sanitize_filter_value("hello") == "hello"
def test_sanitize_filter_value_preserves_ampersand():
assert sanitize_filter_value("A & B") == "A & B"
def test_sanitize_filter_value_strips_html():
assert sanitize_filter_value("<b>hello</b>") == "hello"
def test_sanitize_filter_value_too_long():
with pytest.raises(ValueError, match="too long"):
sanitize_filter_value("a" * 1001)
def test_sanitize_filter_value_blocks_javascript():
with pytest.raises(ValueError, match="malicious URL scheme"):
sanitize_filter_value("javascript:alert(1)")
def test_sanitize_filter_value_blocks_xp_cmdshell():
with pytest.raises(ValueError, match="malicious SQL procedures"):
sanitize_filter_value("xp_cmdshell('dir')")
def test_sanitize_filter_value_blocks_sp_executesql():
with pytest.raises(ValueError, match="malicious SQL procedures"):
sanitize_filter_value("sp_executesql @stmt")
def test_sanitize_filter_value_blocks_union_select():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("' UNION SELECT * FROM users")
def test_sanitize_filter_value_blocks_sql_comment():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("value -- drop")
def test_sanitize_filter_value_blocks_shell_semicolon():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("value;rm -rf")
def test_sanitize_filter_value_blocks_shell_pipe():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("value|cat /etc/passwd")
def test_sanitize_filter_value_blocks_backtick():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("`whoami`")
def test_sanitize_filter_value_blocks_hex_encoding():
with pytest.raises(ValueError, match="hex encoding"):
sanitize_filter_value("\\x41\\x42")
def test_sanitize_filter_value_allows_ampersand_alone():
"""Ampersand is safe in filter values (only dangerous in shell contexts)."""
assert sanitize_filter_value("AT&T") == "AT&T"
def test_sanitize_filter_value_removes_zero_width_chars():
result = sanitize_filter_value("hel\u200blo")
assert result == "hello"
def test_sanitize_filter_value_blocks_or_injection():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("' OR '1'='1")
def test_sanitize_filter_value_blocks_and_injection():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("' AND '1'='1")
def test_sanitize_filter_value_blocks_block_comment():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("value /* comment */")
def test_sanitize_filter_value_blocks_semicolon_drop():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("; DROP TABLE users")
def test_sanitize_filter_value_blocks_parentheses():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("$(whoami)")
def test_sanitize_filter_value_blocks_dollar_sign():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("$HOME")
def test_sanitize_filter_value_blocks_event_handler():
with pytest.raises(ValueError, match="malicious event handler"):
sanitize_filter_value("onerror=alert(1)")
def test_sanitize_filter_value_xss_entity_encoded():
"""Entity-encoded XSS in filter values must be neutralized."""
result = sanitize_filter_value("&lt;img src=x onerror=alert(1)&gt;")
assert "<img" not in result
# --- Defense-in-depth: verify html.unescape is not used after nh3 ---
def test_strip_html_tags_does_not_unescape_angle_brackets():
"""Ensure nh3 entity output is not fully unescaped back to raw HTML.
nh3.clean may pass through HTML entities (e.g. &lt;script&gt;) from
the input without stripping them. A full html.unescape() on nh3's
output could reintroduce raw angle brackets, creating an XSS vector.
"""
# Plain text passes through unchanged
result = _strip_html_tags("safe text")
assert result == "safe text"
# Verify ampersand preservation still works
result = _strip_html_tags("A & B")
assert result == "A & B"
# Verify real tags are stripped
result = _strip_html_tags("<script>alert(1)</script>")
assert "<script>" not in result
# Entity-encoded script tags must not become real tags in the output
result = _strip_html_tags("&lt;script&gt;alert(1)&lt;/script&gt;")
assert "<script>" not in result
assert "</script>" not in result
def test_strip_html_tags_img_onerror_entity_bypass():
"""Entity-encoded img/onerror should not survive sanitization."""
result = _strip_html_tags("&lt;img src=x onerror=alert(1)&gt;")
assert "<img" not in result
assert "onerror" not in result