diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 6b5ed699672..d02f3c7de02 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -21,8 +21,6 @@ Pydantic schemas for chart-related responses from __future__ import annotations -import html -import re from datetime import datetime, timezone from typing import Annotated, Any, Dict, List, Literal, Protocol @@ -50,6 +48,10 @@ from superset.mcp_service.system.schemas import ( TagInfo, UserInfo, ) +from superset.mcp_service.utils.sanitization import ( + sanitize_filter_value, + sanitize_user_input, +) class ChartLike(Protocol): @@ -357,113 +359,17 @@ class ColumnRef(BaseModel): @classmethod def sanitize_name(cls, v: str) -> str: """Sanitize column name to prevent XSS and SQL injection.""" - if not v or not v.strip(): - raise ValueError("Column name cannot be empty") - - # Length check first to prevent ReDoS attacks - if len(v) > 255: - raise ValueError( - f"Column name too long ({len(v)} characters). " - f"Maximum allowed length is 255 characters." - ) - - # Remove HTML tags and decode entities - sanitized = html.escape(v.strip()) - - # Check for dangerous HTML tags using substring checks (safe) - dangerous_tags = ["", " str | None: """Sanitize display label to prevent XSS attacks.""" - if v is None: - return v - - # Strip whitespace - v = v.strip() - if not v: - return None - - # Length check first to prevent ReDoS attacks - if len(v) > 500: - raise ValueError( - f"Label too long ({len(v)} characters). " - f"Maximum allowed length is 500 characters." - ) - - # Check for dangerous HTML tags and JavaScript protocols using substring checks - # This avoids ReDoS vulnerabilities from regex patterns - dangerous_tags = [ - "", - "", - "", - "", - " str: """Sanitize filter column name to prevent injection attacks.""" - if not v or not v.strip(): - raise ValueError("Filter column name cannot be empty") - - # Length check first to prevent ReDoS attacks - if len(v) > 255: - raise ValueError( - f"Filter column name too long ({len(v)} characters). " - f"Maximum allowed length is 255 characters." - ) - - # Remove HTML tags and decode entities - sanitized = html.escape(v.strip()) - - # Check for dangerous HTML tags using substring checks (safe) - dangerous_tags = [""] - v_lower = v.lower() - for tag in dangerous_tags: - if tag in v_lower: - raise ValueError( - "Filter column contains potentially malicious script content" - ) - - # Check URL schemes with word boundaries - if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE): - raise ValueError("Filter column contains potentially malicious URL scheme") - - return sanitized - - @staticmethod - def _validate_string_value(v: str) -> None: - """Validate string filter value for security issues.""" - # Check for dangerous HTML tags and SQL procedures - dangerous_substrings = [ - "", - " str | int | float | bool: """Sanitize filter value to prevent XSS and SQL injection attacks.""" - if isinstance(v, str): - v = v.strip() - - # Length check FIRST to prevent ReDoS attacks - if len(v) > 1000: - raise ValueError( - f"Filter value too long ({len(v)} characters). " - f"Maximum allowed length is 1000 characters." - ) - - # Validate security - cls._validate_string_value(v) - - # Filter dangerous Unicode characters - v = re.sub( - r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v - ) - - # HTML escape the cleaned content - return html.escape(v) - - return v # Return non-string values as-is + return sanitize_filter_value(v, max_length=1000) # Actual chart types @@ -848,6 +657,11 @@ class ListChartsRequest(MetadataCacheControl): class GenerateChartRequest(QueryCacheControl): dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)") config: ChartConfig = Field(..., description="Chart configuration") + chart_name: str | None = Field( + None, + description="Custom chart name (optional, auto-generates if not provided)", + max_length=255, + ) save_chart: bool = Field( default=False, description="Whether to permanently save the chart in Superset", @@ -861,6 +675,12 @@ class GenerateChartRequest(QueryCacheControl): description="List of preview formats to generate", ) + @field_validator("chart_name") + @classmethod + def sanitize_chart_name(cls, v: str | None) -> str | None: + """Sanitize chart name to prevent XSS attacks.""" + return sanitize_user_input(v, "Chart name", max_length=255, allow_empty=True) + @model_validator(mode="after") def validate_cache_timeout(self) -> "GenerateChartRequest": """Validate cache timeout is non-negative.""" @@ -911,62 +731,7 @@ class UpdateChartRequest(QueryCacheControl): @classmethod def sanitize_chart_name(cls, v: str | None) -> str | None: """Sanitize chart name to prevent XSS attacks.""" - if v is None: - return v - - # Strip whitespace - v = v.strip() - if not v: - return None - - # Length check first to prevent ReDoS attacks - if len(v) > 255: - raise ValueError( - f"Chart name too long ({len(v)} characters). " - f"Maximum allowed length is 255 characters." - ) - - # Check for dangerous HTML tags using substring checks (safe) - dangerous_tags = [ - "", - "", - "", - "", - " str: + """ + Strip all HTML tags from the input using nh3. + + Decodes all layers of HTML entity encoding BEFORE passing to nh3, + so entity-encoded tags (e.g., ``<script>``) are decoded into + real tags that nh3 can detect and strip. After nh3 removes all tags, + we only restore ``&`` back to ``&`` (not a full html.unescape) + to preserve ampersands in display text without risking XSS from + re-introducing angle brackets or other HTML-significant characters. + + Args: + value: The input string that may contain HTML + + Returns: + String with all HTML tags removed and ampersands preserved + """ + # Decode all layers of HTML entity encoding to prevent bypass + # via entity-encoded tags (e.g., <script> or &lt;script&gt;) + # The loop terminates when unescape produces no change (idempotent on decoded text). + # Max iterations cap provides defense-in-depth against pathological inputs. + max_iterations = 100 + decoded = value + prev = None + iterations = 0 + while prev != decoded and iterations < max_iterations: + prev = decoded + decoded = html.unescape(decoded) + iterations += 1 + + # nh3.clean with tags=set() strips ALL HTML tags from the decoded input + # url_schemes=set() blocks all URL schemes in any remaining attributes + cleaned = nh3.clean(decoded, tags=set(), url_schemes=set()) + + # Only restore & → & to preserve ampersands in display text (e.g. "A & B"). + # Do NOT use html.unescape() here: nh3 may pass through HTML entities from + # the input (e.g. <script>), and a full unescape would re-introduce + # raw angle brackets, creating an XSS vector. + return cleaned.replace("&", "&") + + +def _check_dangerous_patterns(value: str, field_name: str) -> None: + """ + Check for dangerous patterns that nh3 doesn't catch. + + This includes URL schemes in plain text (not in HTML attributes), + event handler patterns, and dangerous Unicode characters. + + Args: + value: The input string to check + field_name: Name of the field (for error messages) + + Raises: + ValueError: If dangerous patterns are found + """ + # Block dangerous URL schemes in plain text (word boundary check) + if re.search(r"\b(javascript|vbscript|data):", value, re.IGNORECASE): + raise ValueError(f"{field_name} contains potentially malicious URL scheme") + + # Block event handler patterns (onclick=, onerror=, etc.) + if re.search(r"on\w+\s*=", value, re.IGNORECASE): + raise ValueError(f"{field_name} contains potentially malicious event handler") + + +def _check_sql_patterns(value: str, field_name: str) -> None: + """ + Check for SQL injection patterns. + + Args: + value: The input string to check + field_name: Name of the field (for error messages) + + Raises: + ValueError: If SQL injection patterns are found + """ + # Check for dangerous SQL keywords + if re.search( + r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b", + value, + re.IGNORECASE, + ): + raise ValueError(f"{field_name} contains potentially unsafe SQL keywords") + + # Check for shell metacharacters and SQL comments + if re.search(r"[;|&$`]|--", value): + raise ValueError(f"{field_name} contains potentially unsafe characters") + + # Check for SQL comment start + if "/*" in value: + raise ValueError(f"{field_name} contains potentially unsafe SQL comment syntax") + + +def _remove_dangerous_unicode(value: str) -> str: + """ + Remove dangerous Unicode characters (zero-width, control chars). + + Args: + value: The input string + + Returns: + String with dangerous Unicode characters removed + """ + return re.sub( + r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", value + ) + + +def sanitize_user_input( + value: str | None, + field_name: str, + max_length: int = 255, + check_sql_keywords: bool = False, + allow_empty: bool = False, +) -> str | None: + """ + Centralized sanitization for user-provided text inputs. + + Uses nh3 to strip HTML tags and performs additional security checks. + + Args: + value: The input string to sanitize + field_name: Name of the field (for error messages) + max_length: Maximum allowed length + check_sql_keywords: Whether to check for SQL injection keywords + allow_empty: Whether to allow empty/None values + + Returns: + Sanitized string, or None if allow_empty=True and value is empty + + Raises: + ValueError: If value fails security validation + + Security checks performed: + - Strips all HTML tags using nh3 (Rust-based sanitizer) + - Blocks JavaScript/VBScript/data URL schemes + - Blocks event handlers (onclick=, onerror=, etc.) + - Removes dangerous Unicode characters (zero-width, control chars) + - SQL keywords and shell metacharacters (when check_sql_keywords=True) + """ + if value is None: + if allow_empty: + return None + raise ValueError(f"{field_name} cannot be empty") + + value = value.strip() + + if not value: + if allow_empty: + return None + raise ValueError(f"{field_name} cannot be empty") + + # Length check first to prevent ReDoS attacks + if len(value) > max_length: + raise ValueError( + f"{field_name} too long ({len(value)} characters). " + f"Maximum allowed length is {max_length} characters." + ) + + # Strip all HTML tags using nh3 + value = _strip_html_tags(value) + + # Check for dangerous patterns (URL schemes, event handlers) + _check_dangerous_patterns(value, field_name) + + # SQL keyword and shell metacharacter checks (for column names, etc.) + if check_sql_keywords: + _check_sql_patterns(value, field_name) + + # Remove dangerous Unicode characters + value = _remove_dangerous_unicode(value) + + return value + + +def sanitize_filter_value( + value: str | int | float | bool, + max_length: int = 1000, +) -> str | int | float | bool: + """ + Sanitize filter values which can be strings or other types. + + For non-string values, returns as-is (no sanitization needed). + For strings, uses nh3 to strip HTML and applies security validation. + + Args: + value: The filter value (string, int, float, or bool) + max_length: Maximum length for string values + + Returns: + Sanitized value + + Raises: + ValueError: If string value fails security validation + """ + if not isinstance(value, str): + return value + + value = value.strip() + + # Length check first + if len(value) > max_length: + raise ValueError( + f"Filter value too long ({len(value)} characters). " + f"Maximum allowed length is {max_length} characters." + ) + + # Strip all HTML tags using nh3 + value = _strip_html_tags(value) + + # Check for dangerous patterns + _check_dangerous_patterns(value, "Filter value") + + # Check for dangerous SQL procedures (filter-specific) + v_lower = value.lower() + if "xp_cmdshell" in v_lower or "sp_executesql" in v_lower: + raise ValueError("Filter value contains potentially malicious SQL procedures.") + + # SQL injection patterns specific to filter values + sql_patterns = [ + r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b", + r"'\s*OR\s*'", + r"'\s*AND\s*'", + r"--\s*", + r"/\*", + r"UNION\s+SELECT", + ] + for pattern in sql_patterns: + if re.search(pattern, value, re.IGNORECASE): + raise ValueError( + "Filter value contains potentially malicious SQL patterns." + ) + + # Check for shell metacharacters that could indicate injection attempts + # Note: We allow '&' alone as it's common in text ("A & B") and is only + # dangerous in shell contexts, not in database queries + if re.search(r"[;|`$()]", value): + raise ValueError("Filter value contains potentially unsafe shell characters.") + + # Check for hex encoding + if re.search(r"\\x[0-9a-fA-F]{2}", value): + raise ValueError("Filter value contains hex encoding which is not allowed.") + + # Remove dangerous Unicode characters + value = _remove_dangerous_unicode(value) + + return value diff --git a/tests/unit_tests/mcp_service/utils/test_sanitization.py b/tests/unit_tests/mcp_service/utils/test_sanitization.py new file mode 100644 index 00000000000..330cc2fb7d2 --- /dev/null +++ b/tests/unit_tests/mcp_service/utils/test_sanitization.py @@ -0,0 +1,480 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from superset.mcp_service.utils.sanitization import ( + _check_dangerous_patterns, + _check_sql_patterns, + _remove_dangerous_unicode, + _strip_html_tags, + sanitize_filter_value, + sanitize_user_input, +) + +# --- _strip_html_tags tests --- + + +def test_strip_html_tags_plain_text(): + assert _strip_html_tags("hello world") == "hello world" + + +def test_strip_html_tags_preserves_ampersand(): + assert _strip_html_tags("A & B") == "A & B" + + +def test_strip_html_tags_preserves_multiple_ampersands(): + assert _strip_html_tags("A & B & C") == "A & B & C" + + +def test_strip_html_tags_strips_bold_tags(): + assert _strip_html_tags("hello") == "hello" + + +def test_strip_html_tags_strips_script_tags(): + result = _strip_html_tags("") + assert "" not in result + + +def test_strip_html_tags_strips_entity_encoded_script(): + """Entity-encoded tags must be decoded and stripped, not passed through.""" + result = _strip_html_tags("<script>alert(1)</script>") + assert "" + for _ in range(10): + value = value.replace("&", "&").replace("<", "<").replace(">", ">") + result = _strip_html_tags(value) + assert "") + assert "" not in result + + +def test_strip_html_tags_img_onerror_entity_bypass(): + """Entity-encoded img/onerror should not survive sanitization.""" + result = _strip_html_tags("<img src=x onerror=alert(1)>") + assert "