# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Centralized sanitization utilities for MCP service user inputs.

This module uses the nh3 library (Rust-based HTML sanitizer) to strip malicious
HTML tags and protocols from user inputs. nh3 is faster and safer than manual
regex-based sanitization.

Key features:
- Strips all HTML tags using nh3.clean() with no allowed tags
- Blocks dangerous URL schemes (javascript:, vbscript:, data:)
- Preserves safe text content (e.g., '&' stays as '&', not '&amp;')
- Additional SQL injection protection for database-facing inputs
"""

import html
import re
from typing import Any

import nh3

# Delimiters that mark a span of untrusted, user-controlled text inside LLM
# context. They MUST be non-empty: ``str.replace("", x)`` inserts ``x`` at every
# character position, which would corrupt every escaped value.
# NOTE(review): confirm these tokens match what downstream prompts describe.
LLM_CONTEXT_OPEN_DELIMITER = "<UNTRUSTED-CONTENT>"
LLM_CONTEXT_CLOSE_DELIMITER = "</UNTRUSTED-CONTENT>"
# Inert replacements substituted for delimiter tokens found inside user text.
LLM_CONTEXT_ESCAPED_OPEN_DELIMITER = "[ESCAPED-UNTRUSTED-CONTENT-OPEN]"
LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER = "[ESCAPED-UNTRUSTED-CONTENT-CLOSE]"
# Field names whose values are escaped but NOT wrapped by
# ``sanitize_for_llm_context`` (compared after ``_normalize_field_name``).
LLM_CONTEXT_EXCLUDED_FIELD_NAMES = frozenset(
    {
        "cache_key",
        "database",
        "database_name",
        "schema",
        "schema_name",
        "slug",
        "url",
        "urls",
        "uuid",
    }
)

# SQL injection patterns specific to filter values, compiled once at import
# time instead of being rebuilt on every ``sanitize_filter_value`` call.
_FILTER_SQL_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
        r"'\s*OR\s*'",
        r"'\s*AND\s*'",
        r"--\s*",
        r"/\*",
        r"UNION\s+SELECT",
    )
)


def _normalize_field_name(field_name: str) -> str:
    """Normalize a field name for exclusion matching.

    Surrounding whitespace is trimmed, the name is lower-cased and hyphens are
    replaced by underscores, so ``" Schema-Name "`` matches ``"schema_name"``.
    """
    return field_name.strip().lower().replace("-", "_")


def _escape_llm_context_delimiters(value: str) -> str:
    """Escape delimiter tokens without wrapping the value.

    Every occurrence of the open/close delimiter is replaced by its inert
    bracketed form so user text cannot terminate or fake an untrusted block.
    """
    return value.replace(
        LLM_CONTEXT_OPEN_DELIMITER,
        LLM_CONTEXT_ESCAPED_OPEN_DELIMITER,
    ).replace(
        LLM_CONTEXT_CLOSE_DELIMITER,
        LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER,
    )


def _escape_llm_context_dict_key(key: Any) -> Any:
    """Escape delimiter tokens in string dict keys.

    Non-string keys are returned unchanged.
    """
    if isinstance(key, str):
        return _escape_llm_context_delimiters(key)
    return key


def escape_llm_context_delimiters(value: Any) -> Any:
    """Escape delimiter tokens in operational values that should not be wrapped.

    Recurses through dicts (keys and values), lists and tuples, preserving the
    container type. Values of any other type are returned unchanged.
    """
    if isinstance(value, str):
        return _escape_llm_context_delimiters(value)
    if isinstance(value, dict):
        return {
            _escape_llm_context_dict_key(key): escape_llm_context_delimiters(
                nested_value
            )
            for key, nested_value in value.items()
        }
    if isinstance(value, list):
        return [escape_llm_context_delimiters(item) for item in value]
    if isinstance(value, tuple):
        return tuple(escape_llm_context_delimiters(item) for item in value)
    return value


def _wrap_llm_context_string(value: str) -> str:
    """Wrap an untrusted string with explicit LLM-context delimiters.

    A value that already carries exactly one outer wrapper is not wrapped a
    second time; its inner text is re-escaped instead, so repeated calls are
    stable. Any delimiter token inside the text is always escaped.
    """
    wrapped_prefix = f"{LLM_CONTEXT_OPEN_DELIMITER}\n"
    wrapped_suffix = f"\n{LLM_CONTEXT_CLOSE_DELIMITER}"

    # The length guard rejects the degenerate case where prefix and suffix
    # would share the same single newline (i.e. the value is not really a
    # complete wrapper around some inner text).
    already_wrapped = (
        len(value) >= len(wrapped_prefix) + len(wrapped_suffix)
        and value.startswith(wrapped_prefix)
        and value.endswith(wrapped_suffix)
    )
    if already_wrapped:
        inner_value = value[len(wrapped_prefix) : -len(wrapped_suffix)]
    else:
        inner_value = value

    return (
        f"{wrapped_prefix}"
        f"{_escape_llm_context_delimiters(inner_value)}"
        f"{wrapped_suffix}"
    )


def sanitize_for_llm_context(
    value: Any,
    *,
    field_path: tuple[str, ...] = (),
    excluded_field_names: frozenset[str] | None = None,
) -> Any:
    """
    Recursively wrap user-controlled strings before placing them in LLM context.

    Strings are wrapped in explicit untrusted-content delimiters unless the
    current field name is part of the shared operational exclusion policy.
    Container shapes and non-string values are preserved.

    Args:
        value: The value to sanitize (string, dict, list, tuple or anything else)
        field_path: Path of field names leading to ``value``; its last element
            is the field name tested against the exclusion set
        excluded_field_names: Field names whose values are only escaped, never
            wrapped. Defaults to ``LLM_CONTEXT_EXCLUDED_FIELD_NAMES``

    Returns:
        A structure of the same shape with strings wrapped (or only escaped,
        for excluded fields) and string dict keys escaped
    """
    excluded_names = (
        LLM_CONTEXT_EXCLUDED_FIELD_NAMES
        if excluded_field_names is None
        else excluded_field_names
    )
    # Normalize once so caller-supplied names match regardless of case/hyphens.
    normalized_exclusions = frozenset(
        _normalize_field_name(field_name) for field_name in excluded_names
    )

    def _sanitize(current_value: Any, current_path: tuple[str, ...]) -> Any:
        current_field_name = current_path[-1] if current_path else ""
        # An excluded field keeps its raw (unwrapped) value, including every
        # nested element beneath it; delimiter tokens are still neutralized.
        if current_field_name and (
            _normalize_field_name(current_field_name) in normalized_exclusions
        ):
            return escape_llm_context_delimiters(current_value)

        if isinstance(current_value, str):
            return _wrap_llm_context_string(current_value)
        if isinstance(current_value, dict):
            return {
                _escape_llm_context_dict_key(key): _sanitize(
                    nested_value,
                    (*current_path, str(key)),
                )
                for key, nested_value in current_value.items()
            }
        # Sequence elements use their index as the path component.
        if isinstance(current_value, list):
            return [
                _sanitize(item, (*current_path, str(index)))
                for index, item in enumerate(current_value)
            ]
        if isinstance(current_value, tuple):
            return tuple(
                _sanitize(item, (*current_path, str(index)))
                for index, item in enumerate(current_value)
            )
        return current_value

    return _sanitize(value, field_path)


def _strip_html_tags(value: str) -> str:
    """
    Strip all HTML tags from the input using nh3.

    Decodes all layers of HTML entity encoding BEFORE passing to nh3, so
    entity-encoded tags (e.g., ``&lt;script&gt;``) are decoded into real tags
    that nh3 can detect and strip.

    After nh3 removes all tags, we only restore ``&amp;`` back to ``&`` (not a
    full html.unescape) to preserve ampersands in display text without risking
    XSS from re-introducing angle brackets or other HTML-significant characters.

    Args:
        value: The input string that may contain HTML

    Returns:
        String with all HTML tags removed and ampersands preserved
    """
    # Decode all layers of HTML entity encoding to prevent bypass
    # via entity-encoded tags (e.g., &lt;script&gt; or &amp;lt;script&amp;gt;)
    # The loop terminates when unescape produces no change (idempotent on
    # decoded text). Max iterations cap provides defense-in-depth against
    # pathological inputs.
    max_iterations = 100
    decoded = value
    prev = None
    iterations = 0
    while prev != decoded and iterations < max_iterations:
        prev = decoded
        decoded = html.unescape(decoded)
        iterations += 1

    # nh3.clean with tags=set() strips ALL HTML tags from the decoded input
    # url_schemes=set() blocks all URL schemes in any remaining attributes
    cleaned = nh3.clean(decoded, tags=set(), url_schemes=set())

    # Only restore &amp; → & to preserve ampersands in display text
    # (e.g. "A & B"). Do NOT use html.unescape() here: nh3 may pass through
    # HTML entities from the input (e.g. &lt;script&gt;), and a full unescape
    # would re-introduce raw angle brackets, creating an XSS vector.
    return cleaned.replace("&amp;", "&")


def _check_dangerous_patterns(value: str, field_name: str) -> None:
    """
    Check for dangerous patterns that nh3 doesn't catch.

    This covers URL schemes in plain text (not in HTML attributes) and
    event handler patterns.

    Args:
        value: The input string to check
        field_name: Name of the field (for error messages)

    Raises:
        ValueError: If dangerous patterns are found
    """
    # Block dangerous URL schemes in plain text (word boundary check)
    if re.search(r"\b(javascript|vbscript|data):", value, re.IGNORECASE):
        raise ValueError(f"{field_name} contains potentially malicious URL scheme")

    # Block event handler patterns (onclick=, onerror=, etc.)
    # NOTE(review): the pattern is unanchored, so any "on<word>=" substring
    # matches, including inside a longer word. Kept deliberately strict.
    if re.search(r"on\w+\s*=", value, re.IGNORECASE):
        raise ValueError(f"{field_name} contains potentially malicious event handler")


def _check_sql_patterns(value: str, field_name: str) -> None:
    """
    Check for SQL injection patterns.

    Args:
        value: The input string to check
        field_name: Name of the field (for error messages)

    Raises:
        ValueError: If SQL injection patterns are found
    """
    # Check for dangerous SQL keywords
    if re.search(
        r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
        value,
        re.IGNORECASE,
    ):
        raise ValueError(f"{field_name} contains potentially unsafe SQL keywords")

    # Check for shell metacharacters and SQL comments
    if re.search(r"[;|&$`]|--", value):
        raise ValueError(f"{field_name} contains potentially unsafe characters")

    # Check for SQL comment start
    if "/*" in value:
        raise ValueError(f"{field_name} contains potentially unsafe SQL comment syntax")


def _remove_dangerous_unicode(value: str) -> str:
    """
    Remove dangerous Unicode characters (zero-width, control chars).

    Removes U+200B-U+200D, U+FEFF and the C0 control characters except tab
    (U+0009), line feed (U+000A) and carriage return (U+000D).

    Args:
        value: The input string

    Returns:
        String with dangerous Unicode characters removed
    """
    return re.sub(
        r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", value
    )


def sanitize_user_input_with_changes(
    value: str | None,
    field_name: str,
    max_length: int = 255,
    check_sql_keywords: bool = False,
    allow_empty: bool = False,
) -> tuple[str | None, bool]:
    """
    Sanitize and report whether the value was modified.

    Same security guarantees as ``sanitize_user_input`` — returns both the
    sanitized value and a boolean indicating whether any characters were
    stripped or altered. Callers that need to surface a warning when
    user-provided content is silently removed (e.g. XSS payloads) should
    use this variant instead of ``sanitize_user_input``.

    Args:
        value: The input string to sanitize
        field_name: Name of the field (for error messages)
        max_length: Maximum allowed length
        check_sql_keywords: Whether to check for SQL injection keywords
        allow_empty: Whether to allow empty/None values

    Returns:
        ``(sanitized, was_modified)``. Surrounding whitespace removed by
        ``strip()`` does not count as a modification, and an empty or ``None``
        input is never reported as modified.

    Raises:
        ValueError: If value fails security validation
    """
    original_stripped = value.strip() if isinstance(value, str) else value
    sanitized = sanitize_user_input(
        value,
        field_name,
        max_length=max_length,
        check_sql_keywords=check_sql_keywords,
        allow_empty=allow_empty,
    )
    was_modified = original_stripped != (sanitized or "") and bool(original_stripped)
    return sanitized, was_modified


def sanitize_user_input(
    value: str | None,
    field_name: str,
    max_length: int = 255,
    check_sql_keywords: bool = False,
    allow_empty: bool = False,
) -> str | None:
    """
    Centralized sanitization for user-provided text inputs.

    Uses nh3 to strip HTML tags and performs additional security checks.

    Args:
        value: The input string to sanitize
        field_name: Name of the field (for error messages)
        max_length: Maximum allowed length
        check_sql_keywords: Whether to check for SQL injection keywords
        allow_empty: Whether to allow empty/None values

    Returns:
        Sanitized string, or None if allow_empty=True and value is empty

    Raises:
        ValueError: If value fails security validation

    Security checks performed:
        - Strips all HTML tags using nh3 (Rust-based sanitizer)
        - Removes dangerous Unicode characters (zero-width, control chars)
        - Blocks JavaScript/VBScript/data URL schemes
        - Blocks event handlers (onclick=, onerror=, etc.)
        - SQL keywords and shell metacharacters (when check_sql_keywords=True)
    """
    if value is None:
        if allow_empty:
            return None
        raise ValueError(f"{field_name} cannot be empty")

    value = value.strip()

    if not value:
        if allow_empty:
            return None
        raise ValueError(f"{field_name} cannot be empty")

    # Length check first to prevent ReDoS attacks
    if len(value) > max_length:
        raise ValueError(
            f"{field_name} too long ({len(value)} characters). "
            f"Maximum allowed length is {max_length} characters."
        )

    # Strip all HTML tags using nh3
    value = _strip_html_tags(value)

    # Remove dangerous Unicode characters BEFORE the pattern checks. If the
    # checks ran first, an invisible character could split a blocked token
    # (e.g. "java\u200bscript:") so it passes validation and is then
    # reassembled by the removal step.
    value = _remove_dangerous_unicode(value)

    # Check for dangerous patterns (URL schemes, event handlers)
    _check_dangerous_patterns(value, field_name)

    # SQL keyword and shell metacharacter checks (for column names, etc.)
    if check_sql_keywords:
        _check_sql_patterns(value, field_name)

    return value


def sanitize_filter_value(
    value: str | int | float | bool,
    max_length: int = 1000,
) -> str | int | float | bool:
    """
    Sanitize filter values which can be strings or other types.

    For non-string values, returns as-is (no sanitization needed).
    For strings, uses nh3 to strip HTML and applies security validation.

    Args:
        value: The filter value (string, int, float, or bool)
        max_length: Maximum length for string values

    Returns:
        Sanitized value

    Raises:
        ValueError: If string value fails security validation
    """
    if not isinstance(value, str):
        return value

    value = value.strip()

    # Length check first
    if len(value) > max_length:
        raise ValueError(
            f"Filter value too long ({len(value)} characters). "
            f"Maximum allowed length is {max_length} characters."
        )

    # Strip all HTML tags using nh3
    value = _strip_html_tags(value)

    # Remove dangerous Unicode characters BEFORE validating, so invisible
    # characters cannot be used to split a blocked token and slip past the
    # checks below.
    value = _remove_dangerous_unicode(value)

    # Check for dangerous patterns
    _check_dangerous_patterns(value, "Filter value")

    # Check for dangerous SQL procedures (filter-specific)
    v_lower = value.lower()
    if "xp_cmdshell" in v_lower or "sp_executesql" in v_lower:
        raise ValueError("Filter value contains potentially malicious SQL procedures.")

    # SQL injection patterns specific to filter values
    if any(pattern.search(value) for pattern in _FILTER_SQL_PATTERNS):
        raise ValueError("Filter value contains potentially malicious SQL patterns.")

    # Check for shell metacharacters that could indicate injection attempts
    # Note: We allow '&' alone as it's common in text ("A & B") and is only
    # dangerous in shell contexts, not in database queries
    if re.search(r"[;|`$()]", value):
        raise ValueError("Filter value contains potentially unsafe shell characters.")

    # Check for hex encoding
    if re.search(r"\\x[0-9a-fA-F]{2}", value):
        raise ValueError("Filter value contains hex encoding which is not allowed.")

    return value