fix(mcp): remove html.escape to fix ampersand display in chart titles (#37186)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-02-05 20:49:37 -07:00
committed by GitHub
parent 97e5f0631d
commit 01ac966b83
4 changed files with 792 additions and 264 deletions

View File

@@ -21,8 +21,6 @@ Pydantic schemas for chart-related responses
from __future__ import annotations from __future__ import annotations
import html
import re
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, Literal, Protocol from typing import Annotated, Any, Dict, List, Literal, Protocol
@@ -50,6 +48,10 @@ from superset.mcp_service.system.schemas import (
TagInfo, TagInfo,
UserInfo, UserInfo,
) )
from superset.mcp_service.utils.sanitization import (
sanitize_filter_value,
sanitize_user_input,
)
class ChartLike(Protocol): class ChartLike(Protocol):
@@ -357,113 +359,17 @@ class ColumnRef(BaseModel):
@classmethod @classmethod
def sanitize_name(cls, v: str) -> str: def sanitize_name(cls, v: str) -> str:
"""Sanitize column name to prevent XSS and SQL injection.""" """Sanitize column name to prevent XSS and SQL injection."""
if not v or not v.strip(): # sanitize_user_input raises ValueError when allow_empty=False (default)
raise ValueError("Column name cannot be empty") # so the return value is guaranteed to be a non-None str
return sanitize_user_input(
# Length check first to prevent ReDoS attacks v, "Column name", max_length=255, check_sql_keywords=True
if len(v) > 255: ) # type: ignore[return-value]
raise ValueError(
f"Column name too long ({len(v)} characters). "
f"Maximum allowed length is 255 characters."
)
# Remove HTML tags and decode entities
sanitized = html.escape(v.strip())
# Check for dangerous HTML tags using substring checks (safe)
dangerous_tags = ["<script", "</script>", "<iframe", "<object", "<embed"]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Column name contains potentially malicious script content"
)
# Check URL schemes with word boundaries to match only actual URLs
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Column name contains potentially malicious URL scheme")
# Basic SQL injection patterns (basic protection)
# Use simple patterns without backtracking
dangerous_patterns = [
r"[;|&$`]", # Dangerous shell characters
r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
r"--", # SQL comment
r"/\*", # SQL comment start (just check for start, not full pattern)
]
for pattern in dangerous_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError(
"Column name contains potentially unsafe characters or SQL keywords"
)
return sanitized
@field_validator("label") @field_validator("label")
@classmethod @classmethod
def sanitize_label(cls, v: str | None) -> str | None: def sanitize_label(cls, v: str | None) -> str | None:
"""Sanitize display label to prevent XSS attacks.""" """Sanitize display label to prevent XSS attacks."""
if v is None: return sanitize_user_input(v, "Label", max_length=500, allow_empty=True)
return v
# Strip whitespace
v = v.strip()
if not v:
return None
# Length check first to prevent ReDoS attacks
if len(v) > 500:
raise ValueError(
f"Label too long ({len(v)} characters). "
f"Maximum allowed length is 500 characters."
)
# Check for dangerous HTML tags and JavaScript protocols using substring checks
# This avoids ReDoS vulnerabilities from regex patterns
dangerous_tags = [
"<script",
"</script>",
"<iframe",
"</iframe>",
"<object",
"</object>",
"<embed",
"</embed>",
"<link",
"<meta",
]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Label contains potentially malicious content. "
"HTML tags, JavaScript, and event handlers are not allowed "
"in labels."
)
# Check URL schemes and event handlers with word boundaries
dangerous_patterns = [
r"\b(javascript|vbscript|data):", # URL schemes
r"on\w+\s*=", # Event handlers
]
for pattern in dangerous_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError(
"Label contains potentially malicious content. "
"HTML tags, JavaScript, and event handlers are not allowed."
)
# Filter dangerous Unicode characters
v = re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
)
# HTML escape the cleaned content
sanitized = html.escape(v)
return sanitized if sanitized else None
class AxisConfig(BaseModel): class AxisConfig(BaseModel):
@@ -496,112 +402,15 @@ class FilterConfig(BaseModel):
@classmethod @classmethod
def sanitize_column(cls, v: str) -> str: def sanitize_column(cls, v: str) -> str:
"""Sanitize filter column name to prevent injection attacks.""" """Sanitize filter column name to prevent injection attacks."""
if not v or not v.strip(): # sanitize_user_input raises ValueError when allow_empty=False (default)
raise ValueError("Filter column name cannot be empty") # so the return value is guaranteed to be a non-None str
return sanitize_user_input(v, "Filter column", max_length=255) # type: ignore[return-value]
# Length check first to prevent ReDoS attacks
if len(v) > 255:
raise ValueError(
f"Filter column name too long ({len(v)} characters). "
f"Maximum allowed length is 255 characters."
)
# Remove HTML tags and decode entities
sanitized = html.escape(v.strip())
# Check for dangerous HTML tags using substring checks (safe)
dangerous_tags = ["<script", "</script>"]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Filter column contains potentially malicious script content"
)
# Check URL schemes with word boundaries
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Filter column contains potentially malicious URL scheme")
return sanitized
@staticmethod
def _validate_string_value(v: str) -> None:
"""Validate string filter value for security issues."""
# Check for dangerous HTML tags and SQL procedures
dangerous_substrings = [
"<script",
"</script>",
"<iframe",
"<object",
"<embed",
"xp_cmdshell",
"sp_executesql",
]
v_lower = v.lower()
for substring in dangerous_substrings:
if substring in v_lower:
raise ValueError(
"Filter value contains potentially malicious content. "
"HTML tags and JavaScript are not allowed."
)
# Check URL schemes with word boundaries
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Filter value contains potentially malicious URL scheme")
# SQL injection patterns
sql_patterns = [
r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
r"'\s*OR\s*'",
r"'\s*AND\s*'",
r"--\s*",
r"/\*",
r"UNION\s+SELECT",
]
for pattern in sql_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError(
"Filter value contains potentially malicious SQL patterns."
)
# Check for other dangerous patterns
if re.search(r"[;&|`$()]", v):
raise ValueError(
"Filter value contains potentially unsafe shell characters."
)
if re.search(r"on\w+\s*=", v, re.IGNORECASE):
raise ValueError(
"Filter value contains potentially malicious event handlers."
)
if re.search(r"\\x[0-9a-fA-F]{2}", v):
raise ValueError("Filter value contains hex encoding which is not allowed.")
@field_validator("value") @field_validator("value")
@classmethod @classmethod
def sanitize_value(cls, v: str | int | float | bool) -> str | int | float | bool: def sanitize_value(cls, v: str | int | float | bool) -> str | int | float | bool:
"""Sanitize filter value to prevent XSS and SQL injection attacks.""" """Sanitize filter value to prevent XSS and SQL injection attacks."""
if isinstance(v, str): return sanitize_filter_value(v, max_length=1000)
v = v.strip()
# Length check FIRST to prevent ReDoS attacks
if len(v) > 1000:
raise ValueError(
f"Filter value too long ({len(v)} characters). "
f"Maximum allowed length is 1000 characters."
)
# Validate security
cls._validate_string_value(v)
# Filter dangerous Unicode characters
v = re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
)
# HTML escape the cleaned content
return html.escape(v)
return v # Return non-string values as-is
# Actual chart types # Actual chart types
@@ -848,6 +657,11 @@ class ListChartsRequest(MetadataCacheControl):
class GenerateChartRequest(QueryCacheControl): class GenerateChartRequest(QueryCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)") dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration") config: ChartConfig = Field(..., description="Chart configuration")
chart_name: str | None = Field(
None,
description="Custom chart name (optional, auto-generates if not provided)",
max_length=255,
)
save_chart: bool = Field( save_chart: bool = Field(
default=False, default=False,
description="Whether to permanently save the chart in Superset", description="Whether to permanently save the chart in Superset",
@@ -861,6 +675,12 @@ class GenerateChartRequest(QueryCacheControl):
description="List of preview formats to generate", description="List of preview formats to generate",
) )
@field_validator("chart_name")
@classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None:
"""Sanitize chart name to prevent XSS attacks."""
return sanitize_user_input(v, "Chart name", max_length=255, allow_empty=True)
@model_validator(mode="after") @model_validator(mode="after")
def validate_cache_timeout(self) -> "GenerateChartRequest": def validate_cache_timeout(self) -> "GenerateChartRequest":
"""Validate cache timeout is non-negative.""" """Validate cache timeout is non-negative."""
@@ -911,62 +731,7 @@ class UpdateChartRequest(QueryCacheControl):
@classmethod @classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None: def sanitize_chart_name(cls, v: str | None) -> str | None:
"""Sanitize chart name to prevent XSS attacks.""" """Sanitize chart name to prevent XSS attacks."""
if v is None: return sanitize_user_input(v, "Chart name", max_length=255, allow_empty=True)
return v
# Strip whitespace
v = v.strip()
if not v:
return None
# Length check first to prevent ReDoS attacks
if len(v) > 255:
raise ValueError(
f"Chart name too long ({len(v)} characters). "
f"Maximum allowed length is 255 characters."
)
# Check for dangerous HTML tags using substring checks (safe)
dangerous_tags = [
"<script",
"</script>",
"<iframe",
"</iframe>",
"<object",
"</object>",
"<embed",
"</embed>",
"<link",
"<meta",
]
v_lower = v.lower()
for tag in dangerous_tags:
if tag in v_lower:
raise ValueError(
"Chart name contains potentially malicious content. "
"HTML tags and JavaScript are not allowed in chart names."
)
# Check URL schemes with word boundaries
if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
raise ValueError("Chart name contains potentially malicious URL scheme")
# Check for event handlers with simple regex
if re.search(r"on\w+\s*=", v, re.IGNORECASE):
raise ValueError(
"Chart name contains potentially malicious event handlers."
)
# Filter dangerous Unicode characters
v = re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
)
# HTML escape the cleaned content
sanitized = html.escape(v)
return sanitized if sanitized else None
class UpdateChartPreviewRequest(FormDataCacheControl): class UpdateChartPreviewRequest(FormDataCacheControl):

View File

@@ -189,9 +189,9 @@ async def generate_chart( # noqa: C901
await ctx.report_progress(2, 5, "Creating chart in database") await ctx.report_progress(2, 5, "Creating chart in database")
from superset.commands.chart.create import CreateChartCommand from superset.commands.chart.create import CreateChartCommand
# Generate a chart name # Use custom chart name if provided, otherwise auto-generate
chart_name = generate_chart_name(request.config) chart_name = request.chart_name or generate_chart_name(request.config)
await ctx.debug("Generated chart name: chart_name=%s" % (chart_name,)) await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
# Find the dataset to get its numeric ID # Find the dataset to get its numeric ID
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO

View File

@@ -0,0 +1,283 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Centralized sanitization utilities for MCP service user inputs.
This module uses the nh3 library (Rust-based HTML sanitizer) to strip malicious
HTML tags and protocols from user inputs. nh3 is faster and safer than manual
regex-based sanitization.
Key features:
- Strips all HTML tags using nh3.clean() with no allowed tags
- Blocks dangerous URL schemes (javascript:, vbscript:, data:)
- Preserves safe text content (e.g., '&' stays as '&', not '&amp;')
- Additional SQL injection protection for database-facing inputs
"""
import html
import re
import nh3
def _strip_html_tags(value: str) -> str:
"""
Strip all HTML tags from the input using nh3.
Decodes all layers of HTML entity encoding BEFORE passing to nh3,
so entity-encoded tags (e.g., ``&lt;script&gt;``) are decoded into
real tags that nh3 can detect and strip. After nh3 removes all tags,
we only restore ``&amp;`` back to ``&`` (not a full html.unescape)
to preserve ampersands in display text without risking XSS from
re-introducing angle brackets or other HTML-significant characters.
Args:
value: The input string that may contain HTML
Returns:
String with all HTML tags removed and ampersands preserved
"""
# Decode all layers of HTML entity encoding to prevent bypass
# via entity-encoded tags (e.g., &lt;script&gt; or &amp;lt;script&amp;gt;)
# The loop terminates when unescape produces no change (idempotent on decoded text).
# Max iterations cap provides defense-in-depth against pathological inputs.
max_iterations = 100
decoded = value
prev = None
iterations = 0
while prev != decoded and iterations < max_iterations:
prev = decoded
decoded = html.unescape(decoded)
iterations += 1
# nh3.clean with tags=set() strips ALL HTML tags from the decoded input
# url_schemes=set() blocks all URL schemes in any remaining attributes
cleaned = nh3.clean(decoded, tags=set(), url_schemes=set())
# Only restore &amp; → & to preserve ampersands in display text (e.g. "A & B").
# Do NOT use html.unescape() here: nh3 may pass through HTML entities from
# the input (e.g. &lt;script&gt;), and a full unescape would re-introduce
# raw angle brackets, creating an XSS vector.
return cleaned.replace("&amp;", "&")
def _check_dangerous_patterns(value: str, field_name: str) -> None:
"""
Check for dangerous patterns that nh3 doesn't catch.
This includes URL schemes in plain text (not in HTML attributes),
event handler patterns, and dangerous Unicode characters.
Args:
value: The input string to check
field_name: Name of the field (for error messages)
Raises:
ValueError: If dangerous patterns are found
"""
# Block dangerous URL schemes in plain text (word boundary check)
if re.search(r"\b(javascript|vbscript|data):", value, re.IGNORECASE):
raise ValueError(f"{field_name} contains potentially malicious URL scheme")
# Block event handler patterns (onclick=, onerror=, etc.)
if re.search(r"on\w+\s*=", value, re.IGNORECASE):
raise ValueError(f"{field_name} contains potentially malicious event handler")
def _check_sql_patterns(value: str, field_name: str) -> None:
"""
Check for SQL injection patterns.
Args:
value: The input string to check
field_name: Name of the field (for error messages)
Raises:
ValueError: If SQL injection patterns are found
"""
# Check for dangerous SQL keywords
if re.search(
r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
value,
re.IGNORECASE,
):
raise ValueError(f"{field_name} contains potentially unsafe SQL keywords")
# Check for shell metacharacters and SQL comments
if re.search(r"[;|&$`]|--", value):
raise ValueError(f"{field_name} contains potentially unsafe characters")
# Check for SQL comment start
if "/*" in value:
raise ValueError(f"{field_name} contains potentially unsafe SQL comment syntax")
def _remove_dangerous_unicode(value: str) -> str:
"""
Remove dangerous Unicode characters (zero-width, control chars).
Args:
value: The input string
Returns:
String with dangerous Unicode characters removed
"""
return re.sub(
r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", value
)
def sanitize_user_input(
value: str | None,
field_name: str,
max_length: int = 255,
check_sql_keywords: bool = False,
allow_empty: bool = False,
) -> str | None:
"""
Centralized sanitization for user-provided text inputs.
Uses nh3 to strip HTML tags and performs additional security checks.
Args:
value: The input string to sanitize
field_name: Name of the field (for error messages)
max_length: Maximum allowed length
check_sql_keywords: Whether to check for SQL injection keywords
allow_empty: Whether to allow empty/None values
Returns:
Sanitized string, or None if allow_empty=True and value is empty
Raises:
ValueError: If value fails security validation
Security checks performed:
- Strips all HTML tags using nh3 (Rust-based sanitizer)
- Blocks JavaScript/VBScript/data URL schemes
- Blocks event handlers (onclick=, onerror=, etc.)
- Removes dangerous Unicode characters (zero-width, control chars)
- SQL keywords and shell metacharacters (when check_sql_keywords=True)
"""
if value is None:
if allow_empty:
return None
raise ValueError(f"{field_name} cannot be empty")
value = value.strip()
if not value:
if allow_empty:
return None
raise ValueError(f"{field_name} cannot be empty")
# Length check first to prevent ReDoS attacks
if len(value) > max_length:
raise ValueError(
f"{field_name} too long ({len(value)} characters). "
f"Maximum allowed length is {max_length} characters."
)
# Strip all HTML tags using nh3
value = _strip_html_tags(value)
# Check for dangerous patterns (URL schemes, event handlers)
_check_dangerous_patterns(value, field_name)
# SQL keyword and shell metacharacter checks (for column names, etc.)
if check_sql_keywords:
_check_sql_patterns(value, field_name)
# Remove dangerous Unicode characters
value = _remove_dangerous_unicode(value)
return value
def sanitize_filter_value(
value: str | int | float | bool,
max_length: int = 1000,
) -> str | int | float | bool:
"""
Sanitize filter values which can be strings or other types.
For non-string values, returns as-is (no sanitization needed).
For strings, uses nh3 to strip HTML and applies security validation.
Args:
value: The filter value (string, int, float, or bool)
max_length: Maximum length for string values
Returns:
Sanitized value
Raises:
ValueError: If string value fails security validation
"""
if not isinstance(value, str):
return value
value = value.strip()
# Length check first
if len(value) > max_length:
raise ValueError(
f"Filter value too long ({len(value)} characters). "
f"Maximum allowed length is {max_length} characters."
)
# Strip all HTML tags using nh3
value = _strip_html_tags(value)
# Check for dangerous patterns
_check_dangerous_patterns(value, "Filter value")
# Check for dangerous SQL procedures (filter-specific)
v_lower = value.lower()
if "xp_cmdshell" in v_lower or "sp_executesql" in v_lower:
raise ValueError("Filter value contains potentially malicious SQL procedures.")
# SQL injection patterns specific to filter values
sql_patterns = [
r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
r"'\s*OR\s*'",
r"'\s*AND\s*'",
r"--\s*",
r"/\*",
r"UNION\s+SELECT",
]
for pattern in sql_patterns:
if re.search(pattern, value, re.IGNORECASE):
raise ValueError(
"Filter value contains potentially malicious SQL patterns."
)
# Check for shell metacharacters that could indicate injection attempts
# Note: We allow '&' alone as it's common in text ("A & B") and is only
# dangerous in shell contexts, not in database queries
if re.search(r"[;|`$()]", value):
raise ValueError("Filter value contains potentially unsafe shell characters.")
# Check for hex encoding
if re.search(r"\\x[0-9a-fA-F]{2}", value):
raise ValueError("Filter value contains hex encoding which is not allowed.")
# Remove dangerous Unicode characters
value = _remove_dangerous_unicode(value)
return value

View File

@@ -0,0 +1,480 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.mcp_service.utils.sanitization import (
_check_dangerous_patterns,
_check_sql_patterns,
_remove_dangerous_unicode,
_strip_html_tags,
sanitize_filter_value,
sanitize_user_input,
)
# --- _strip_html_tags tests ---
def test_strip_html_tags_plain_text():
assert _strip_html_tags("hello world") == "hello world"
def test_strip_html_tags_preserves_ampersand():
assert _strip_html_tags("A & B") == "A & B"
def test_strip_html_tags_preserves_multiple_ampersands():
assert _strip_html_tags("A & B & C") == "A & B & C"
def test_strip_html_tags_strips_bold_tags():
assert _strip_html_tags("<b>hello</b>") == "hello"
def test_strip_html_tags_strips_script_tags():
result = _strip_html_tags("<script>alert(1)</script>")
assert "<script>" not in result
assert "</script>" not in result
def test_strip_html_tags_strips_entity_encoded_script():
"""Entity-encoded tags must be decoded and stripped, not passed through."""
result = _strip_html_tags("&lt;script&gt;alert(1)&lt;/script&gt;")
assert "<script>" not in result
assert "&lt;script&gt;" not in result
def test_strip_html_tags_strips_double_encoded_script():
"""Double-encoded entities must also be decoded and stripped."""
result = _strip_html_tags("&amp;lt;script&amp;gt;alert(1)&amp;lt;/script&amp;gt;")
assert "<script>" not in result
assert "&lt;script&gt;" not in result
def test_strip_html_tags_strips_img_onerror():
result = _strip_html_tags('<img src=x onerror="alert(1)">')
assert "<img" not in result
assert "onerror" not in result
def test_strip_html_tags_strips_div_tags():
assert _strip_html_tags("<div>content</div>") == "content"
def test_strip_html_tags_preserves_less_than_in_text():
"""A bare < not forming a tag should be preserved."""
result = _strip_html_tags("5 < 10")
assert "5" in result
assert "10" in result
def test_strip_html_tags_empty_string():
assert _strip_html_tags("") == ""
def test_strip_html_tags_triple_encoded_script():
"""Triple-encoded entities must also be decoded and stripped."""
result = _strip_html_tags(
"&amp;amp;lt;script&amp;amp;gt;alert(1)&amp;amp;lt;/script&amp;amp;gt;"
)
assert "<script>" not in result
def test_strip_html_tags_mixed_encoded_and_raw():
"""Both raw and entity-encoded tags should be stripped."""
result = _strip_html_tags("<b>bold</b> and &lt;i&gt;italic&lt;/i&gt;")
assert "<b>" not in result
assert "<i>" not in result
assert "bold" in result
assert "italic" in result
def test_strip_html_tags_deep_encoding_terminates():
"""Verify the iterative decode loop terminates on many encoding layers."""
value = "<script>alert(1)</script>"
for _ in range(10):
value = value.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
result = _strip_html_tags(value)
assert "<script>" not in result
def test_strip_html_tags_entity_ampersand():
"""&amp; in input should become & in output."""
assert _strip_html_tags("A &amp; B") == "A & B"
# --- _check_dangerous_patterns tests ---
def test_check_dangerous_patterns_safe_input():
_check_dangerous_patterns("hello world", "test")
def test_check_dangerous_patterns_javascript_scheme():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("javascript:alert(1)", "test")
def test_check_dangerous_patterns_vbscript_scheme():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("vbscript:MsgBox", "test")
def test_check_dangerous_patterns_data_scheme():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("data:text/html,<script>", "test")
def test_check_dangerous_patterns_case_insensitive():
with pytest.raises(ValueError, match="malicious URL scheme"):
_check_dangerous_patterns("JAVASCRIPT:alert(1)", "test")
def test_check_dangerous_patterns_onclick():
with pytest.raises(ValueError, match="malicious event handler"):
_check_dangerous_patterns("onclick=alert(1)", "test")
def test_check_dangerous_patterns_onerror():
with pytest.raises(ValueError, match="malicious event handler"):
_check_dangerous_patterns("onerror = alert(1)", "test")
def test_check_dangerous_patterns_onload():
with pytest.raises(ValueError, match="malicious event handler"):
_check_dangerous_patterns("onload=fetch('x')", "test")
# --- _check_sql_patterns tests ---
def test_check_sql_patterns_safe_input():
_check_sql_patterns("revenue_total", "test")
def test_check_sql_patterns_drop_table():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
_check_sql_patterns("DROP TABLE users", "test")
def test_check_sql_patterns_delete():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
_check_sql_patterns("DELETE FROM users", "test")
def test_check_sql_patterns_semicolon():
with pytest.raises(ValueError, match="unsafe characters"):
_check_sql_patterns("value; other", "test")
def test_check_sql_patterns_sql_comment_dash():
with pytest.raises(ValueError, match="unsafe characters"):
_check_sql_patterns("value -- comment", "test")
def test_check_sql_patterns_sql_comment_block():
with pytest.raises(ValueError, match="unsafe SQL comment"):
_check_sql_patterns("value /* comment */", "test")
def test_check_sql_patterns_pipe():
with pytest.raises(ValueError, match="unsafe characters"):
_check_sql_patterns("value | other", "test")
def test_check_sql_patterns_case_insensitive():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
_check_sql_patterns("drop table users", "test")
# --- _remove_dangerous_unicode tests ---
def test_remove_dangerous_unicode_plain_text():
assert _remove_dangerous_unicode("hello") == "hello"
def test_remove_dangerous_unicode_zero_width_space():
assert _remove_dangerous_unicode("he\u200bllo") == "hello"
def test_remove_dangerous_unicode_zero_width_joiner():
assert _remove_dangerous_unicode("he\u200dllo") == "hello"
def test_remove_dangerous_unicode_bom():
assert _remove_dangerous_unicode("\ufeffhello") == "hello"
def test_remove_dangerous_unicode_null_byte():
assert _remove_dangerous_unicode("he\x00llo") == "hello"
def test_remove_dangerous_unicode_preserves_normal_unicode():
assert _remove_dangerous_unicode("café résumé") == "café résumé"
# --- sanitize_user_input tests ---
def test_sanitize_user_input_plain_text():
assert sanitize_user_input("hello", "test") == "hello"
def test_sanitize_user_input_preserves_ampersand():
assert sanitize_user_input("A & B", "test") == "A & B"
def test_sanitize_user_input_strips_html():
assert sanitize_user_input("<b>hello</b>", "test") == "hello"
def test_sanitize_user_input_none_not_allowed():
with pytest.raises(ValueError, match="cannot be empty"):
sanitize_user_input(None, "test")
def test_sanitize_user_input_none_allowed():
assert sanitize_user_input(None, "test", allow_empty=True) is None
def test_sanitize_user_input_empty_string_not_allowed():
with pytest.raises(ValueError, match="cannot be empty"):
sanitize_user_input("", "test")
def test_sanitize_user_input_empty_string_allowed():
assert sanitize_user_input("", "test", allow_empty=True) is None
def test_sanitize_user_input_whitespace_only():
with pytest.raises(ValueError, match="cannot be empty"):
sanitize_user_input(" ", "test")
def test_sanitize_user_input_strips_whitespace():
assert sanitize_user_input(" hello ", "test") == "hello"
def test_sanitize_user_input_too_long():
with pytest.raises(ValueError, match="too long"):
sanitize_user_input("a" * 256, "test", max_length=255)
def test_sanitize_user_input_max_length_ok():
result = sanitize_user_input("a" * 255, "test", max_length=255)
assert result == "a" * 255
def test_sanitize_user_input_blocks_javascript():
with pytest.raises(ValueError, match="malicious URL scheme"):
sanitize_user_input("javascript:alert(1)", "test")
def test_sanitize_user_input_blocks_event_handler():
with pytest.raises(ValueError, match="malicious event handler"):
sanitize_user_input("onclick=alert(1)", "test")
def test_sanitize_user_input_sql_keywords_not_checked_by_default():
result = sanitize_user_input("DROP TABLE", "test")
assert result == "DROP TABLE"
def test_sanitize_user_input_sql_keywords_checked_when_enabled():
with pytest.raises(ValueError, match="unsafe SQL keywords"):
sanitize_user_input("DROP TABLE users", "test", check_sql_keywords=True)
def test_sanitize_user_input_removes_zero_width_chars():
result = sanitize_user_input("hel\u200blo", "test")
assert result == "hello"
def test_sanitize_user_input_xss_entity_encoded():
"""Entity-encoded XSS attempts must be neutralized."""
result = sanitize_user_input("&lt;script&gt;alert(1)&lt;/script&gt;", "test")
assert "<script>" not in result
def test_sanitize_user_input_entity_encoded_javascript():
"""Entity-encoded javascript: scheme should be caught after decoding."""
with pytest.raises(ValueError, match="malicious URL scheme"):
sanitize_user_input("&#106;avascript:alert(1)", "test")
# --- sanitize_filter_value tests ---
def test_sanitize_filter_value_integer():
assert sanitize_filter_value(42) == 42
def test_sanitize_filter_value_float():
assert sanitize_filter_value(3.14) == 3.14
def test_sanitize_filter_value_bool():
assert sanitize_filter_value(True) is True
def test_sanitize_filter_value_plain_string():
assert sanitize_filter_value("hello") == "hello"
def test_sanitize_filter_value_preserves_ampersand():
assert sanitize_filter_value("A & B") == "A & B"
def test_sanitize_filter_value_strips_html():
assert sanitize_filter_value("<b>hello</b>") == "hello"
def test_sanitize_filter_value_too_long():
with pytest.raises(ValueError, match="too long"):
sanitize_filter_value("a" * 1001)
def test_sanitize_filter_value_blocks_javascript():
with pytest.raises(ValueError, match="malicious URL scheme"):
sanitize_filter_value("javascript:alert(1)")
def test_sanitize_filter_value_blocks_xp_cmdshell():
with pytest.raises(ValueError, match="malicious SQL procedures"):
sanitize_filter_value("xp_cmdshell('dir')")
def test_sanitize_filter_value_blocks_sp_executesql():
with pytest.raises(ValueError, match="malicious SQL procedures"):
sanitize_filter_value("sp_executesql @stmt")
def test_sanitize_filter_value_blocks_union_select():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("' UNION SELECT * FROM users")
def test_sanitize_filter_value_blocks_sql_comment():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("value -- drop")
def test_sanitize_filter_value_blocks_shell_semicolon():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("value;rm -rf")
def test_sanitize_filter_value_blocks_shell_pipe():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("value|cat /etc/passwd")
def test_sanitize_filter_value_blocks_backtick():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("`whoami`")
def test_sanitize_filter_value_blocks_hex_encoding():
with pytest.raises(ValueError, match="hex encoding"):
sanitize_filter_value("\\x41\\x42")
def test_sanitize_filter_value_allows_ampersand_alone():
"""Ampersand is safe in filter values (only dangerous in shell contexts)."""
assert sanitize_filter_value("AT&T") == "AT&T"
def test_sanitize_filter_value_removes_zero_width_chars():
result = sanitize_filter_value("hel\u200blo")
assert result == "hello"
def test_sanitize_filter_value_blocks_or_injection():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("' OR '1'='1")
def test_sanitize_filter_value_blocks_and_injection():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("' AND '1'='1")
def test_sanitize_filter_value_blocks_block_comment():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("value /* comment */")
def test_sanitize_filter_value_blocks_semicolon_drop():
with pytest.raises(ValueError, match="malicious SQL patterns"):
sanitize_filter_value("; DROP TABLE users")
def test_sanitize_filter_value_blocks_parentheses():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("$(whoami)")
def test_sanitize_filter_value_blocks_dollar_sign():
with pytest.raises(ValueError, match="unsafe shell characters"):
sanitize_filter_value("$HOME")
def test_sanitize_filter_value_blocks_event_handler():
with pytest.raises(ValueError, match="malicious event handler"):
sanitize_filter_value("onerror=alert(1)")
def test_sanitize_filter_value_xss_entity_encoded():
"""Entity-encoded XSS in filter values must be neutralized."""
result = sanitize_filter_value("&lt;img src=x onerror=alert(1)&gt;")
assert "<img" not in result
# --- Defense-in-depth: verify html.unescape is not used after nh3 ---
def test_strip_html_tags_does_not_unescape_angle_brackets():
"""Ensure nh3 entity output is not fully unescaped back to raw HTML.
nh3.clean may pass through HTML entities (e.g. &lt;script&gt;) from
the input without stripping them. A full html.unescape() on nh3's
output could reintroduce raw angle brackets, creating an XSS vector.
"""
# Plain text passes through unchanged
result = _strip_html_tags("safe text")
assert result == "safe text"
# Verify ampersand preservation still works
result = _strip_html_tags("A & B")
assert result == "A & B"
# Verify real tags are stripped
result = _strip_html_tags("<script>alert(1)</script>")
assert "<script>" not in result
# Entity-encoded script tags must not become real tags in the output
result = _strip_html_tags("&lt;script&gt;alert(1)&lt;/script&gt;")
assert "<script>" not in result
assert "</script>" not in result
def test_strip_html_tags_img_onerror_entity_bypass():
"""Entity-encoded img/onerror should not survive sanitization."""
result = _strip_html_tags("&lt;img src=x onerror=alert(1)&gt;")
assert "<img" not in result
assert "onerror" not in result