mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix(mcp): fix dashboard slug null and execute_sql encoding error (#38710)
This commit is contained in:
@@ -23,6 +23,7 @@ and response conversion logic.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
@@ -883,3 +884,208 @@ class TestExecuteSql:
|
||||
options = call_args[0][1]
|
||||
assert options.cache is not None
|
||||
assert options.cache.force_refresh is True
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_bytes_in_dataframe(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that bytes/memoryview values in DataFrame are sanitized for JSON.
|
||||
|
||||
Regression test: execute_sql fails with 'encoding without a string
|
||||
argument' when queries return binary/bytea data.
|
||||
"""
|
||||
mock_database = _mock_database()
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"name": "test",
|
||||
"utf8_data": b"hello world",
|
||||
"binary_data": b"\x00\x01\x02\xff",
|
||||
},
|
||||
]
|
||||
)
|
||||
mock_database.execute.return_value = QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql="SELECT * FROM files",
|
||||
executed_sql="SELECT * FROM files",
|
||||
data=df,
|
||||
row_count=1,
|
||||
execution_time_ms=5.0,
|
||||
)
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=5.0,
|
||||
is_cached=False,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM files",
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["row_count"] == 1
|
||||
row = data["rows"][0]
|
||||
# UTF-8 decodable bytes should become string
|
||||
assert row["utf8_data"] == "hello world"
|
||||
# Non-UTF-8 bytes should become hex
|
||||
assert row["binary_data"] == "000102ff"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_decimal_in_dataframe(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that Decimal values in DataFrame are converted to float for JSON.
|
||||
|
||||
Regression test: execute_sql fails with 'encoding without a string
|
||||
argument' when queries return Decimal types (common with SUM/AVG).
|
||||
"""
|
||||
mock_database = _mock_database()
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"price": Decimal("19.99"),
|
||||
"total": Decimal("1234567.89"),
|
||||
},
|
||||
]
|
||||
)
|
||||
mock_database.execute.return_value = QueryResult(
|
||||
status=QueryStatus.SUCCESS,
|
||||
statements=[
|
||||
StatementResult(
|
||||
original_sql="SELECT * FROM orders",
|
||||
executed_sql="SELECT * FROM orders",
|
||||
data=df,
|
||||
row_count=1,
|
||||
execution_time_ms=5.0,
|
||||
)
|
||||
],
|
||||
query_id=None,
|
||||
total_execution_time_ms=5.0,
|
||||
is_cached=False,
|
||||
)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM orders",
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["success"] is True
|
||||
assert data["row_count"] == 1
|
||||
row = data["rows"][0]
|
||||
assert row["price"] == 19.99
|
||||
assert row["total"] == 1234567.89
|
||||
assert isinstance(row["price"], float)
|
||||
|
||||
|
||||
class TestSanitizeRowValues:
|
||||
"""Unit tests for _sanitize_row_values helper function."""
|
||||
|
||||
def test_sanitize_utf8_bytes(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
rows = [{"data": b"hello"}]
|
||||
_sanitize_row_values(rows)
|
||||
assert rows[0]["data"] == "hello"
|
||||
|
||||
def test_sanitize_non_utf8_bytes(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
rows = [{"data": b"\x00\xff"}]
|
||||
_sanitize_row_values(rows)
|
||||
assert rows[0]["data"] == "00ff"
|
||||
|
||||
def test_sanitize_memoryview(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
rows = [{"data": memoryview(b"test")}]
|
||||
_sanitize_row_values(rows)
|
||||
assert rows[0]["data"] == "test"
|
||||
|
||||
def test_sanitize_decimal(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
rows = [{"price": Decimal("19.99"), "count": Decimal("42")}]
|
||||
_sanitize_row_values(rows)
|
||||
assert rows[0]["price"] == 19.99
|
||||
assert isinstance(rows[0]["price"], float)
|
||||
assert rows[0]["count"] == 42.0
|
||||
|
||||
def test_sanitize_custom_type_uses_str(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
class CustomType:
|
||||
def __str__(self):
|
||||
return "custom_value"
|
||||
|
||||
rows = [{"data": CustomType()}]
|
||||
_sanitize_row_values(rows)
|
||||
assert rows[0]["data"] == "custom_value"
|
||||
|
||||
def test_preserves_json_serializable_types(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
rows = [
|
||||
{
|
||||
"str_val": "hello",
|
||||
"int_val": 42,
|
||||
"float_val": 3.14,
|
||||
"bool_val": True,
|
||||
"none_val": None,
|
||||
"list_val": [1, 2],
|
||||
"dict_val": {"a": 1},
|
||||
}
|
||||
]
|
||||
original = [dict(row) for row in rows]
|
||||
_sanitize_row_values(rows)
|
||||
assert rows == original
|
||||
|
||||
def test_sanitize_empty_rows(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
_sanitize_row_values(rows)
|
||||
assert rows == []
|
||||
|
||||
def test_sanitize_mixed_types_in_single_row(self):
|
||||
from superset.mcp_service.sql_lab.tool.execute_sql import _sanitize_row_values
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "test",
|
||||
"price": Decimal("9.99"),
|
||||
"blob": b"\x00\x01\x02\xff",
|
||||
}
|
||||
]
|
||||
_sanitize_row_values(rows)
|
||||
assert rows[0]["id"] == 1
|
||||
assert rows[0]["name"] == "test"
|
||||
assert rows[0]["price"] == 9.99
|
||||
assert rows[0]["blob"] == "000102ff"
|
||||
|
||||
Reference in New Issue
Block a user