fix(mcp): prevent DATE_TRUNC on non-temporal columns in chart generation (#37433)

This commit is contained in:
Amin Ghadersohi
2026-02-05 10:24:31 -07:00
committed by GitHub
parent a9dca529c1
commit 4147d877fc
7 changed files with 449 additions and 31 deletions

View File

@@ -17,14 +17,17 @@
"""Tests for chart utilities module"""
from typing import Any
from unittest.mock import patch
import pytest
from superset.mcp_service.chart.chart_utils import (
configure_temporal_handling,
create_metric_object,
generate_chart_name,
generate_explore_link,
is_column_truly_temporal,
map_config_to_form_data,
map_filter_operator,
map_table_config,
@@ -38,6 +41,7 @@ from superset.mcp_service.chart.schemas import (
TableChartConfig,
XYChartConfig,
)
from superset.utils.core import GenericDataType
class TestCreateMetricObject:
@@ -423,7 +427,9 @@ class TestGenerateExploreLink:
mock_get_base_url.return_value = "https://superset.example.com"
form_data = {"viz_type": "table", "metrics": ["count"]}
result = generate_explore_link("123", form_data)
# Mock dataset not found to trigger fallback URL
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
result = generate_explore_link("123", form_data)
# Should use the configured base URL - use urlparse to avoid CodeQL warning
parsed_url = urlparse(result)
@@ -471,13 +477,15 @@ class TestGenerateExploreLink:
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> None:
"""Test generate_explore_link handles exceptions gracefully"""
"""Test generate_explore_link handles SQLAlchemy exceptions gracefully"""
from sqlalchemy.exc import SQLAlchemyError
mock_get_base_url.return_value = "http://localhost:9001"
# Mock exception during form_data creation
# Mock SQLAlchemy exception during dataset lookup
with patch(
"superset.daos.dataset.DatasetDAO.find_by_id",
side_effect=Exception("DB error"),
side_effect=SQLAlchemyError("DB error"),
):
result = generate_explore_link("123", {"viz_type": "table"})
@@ -605,3 +613,268 @@ class TestCriticalBugFixes:
form_data = map_xy_config(config)
assert form_data["viz_type"] == expected_viz_type
class TestIsColumnTrulyTemporal:
    """Tests for is_column_truly_temporal driven by mocked datasets.

    Each test patches ``superset.daos.dataset.DatasetDAO`` and controls what
    ``find_by_id`` returns (or raises), then checks the boolean result.
    """

    def _create_mock_dataset(
        self,
        column_name: str,
        column_type: str,
        generic_type: GenericDataType,
    ):
        """Return a mock dataset with one column and a stubbed db_engine_spec.

        The engine spec's ``get_column_spec`` returns a ``ColumnSpec`` carrying
        the requested ``generic_type``; its ``is_dttm`` flag is always False.
        """
        from unittest.mock import MagicMock

        from superset.utils.core import ColumnSpec

        column = MagicMock(column_name=column_name, type=column_type)

        engine_spec = MagicMock()
        engine_spec.get_column_spec.return_value = ColumnSpec(
            sqla_type=MagicMock(), generic_type=generic_type, is_dttm=False
        )

        database = MagicMock(db_engine_spec=engine_spec)
        return MagicMock(columns=[column], database=database)

    def test_returns_true_when_no_dataset_id(self) -> None:
        """A dataset_id of None yields the default answer, True."""
        assert is_column_truly_temporal("year", None) is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_when_dataset_not_found(self, mock_dao) -> None:
        """When the DAO finds no dataset the result is True."""
        mock_dao.find_by_id.return_value = None

        assert is_column_truly_temporal("year", 123) is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_false_for_numeric_column(self, mock_dao) -> None:
        """A BIGINT column with NUMERIC generic type is not temporal."""
        mock_dao.find_by_id.return_value = self._create_mock_dataset(
            "year", "BIGINT", GenericDataType.NUMERIC
        )

        assert is_column_truly_temporal("year", 123) is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_false_for_integer_column(self, mock_dao) -> None:
        """An INTEGER column with NUMERIC generic type is not temporal."""
        mock_dao.find_by_id.return_value = self._create_mock_dataset(
            "month", "INTEGER", GenericDataType.NUMERIC
        )

        assert is_column_truly_temporal("month", 123) is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_for_temporal_column(self, mock_dao) -> None:
        """A TIMESTAMP column with TEMPORAL generic type is temporal."""
        mock_dao.find_by_id.return_value = self._create_mock_dataset(
            "created_at", "TIMESTAMP", GenericDataType.TEMPORAL
        )

        assert is_column_truly_temporal("created_at", 123) is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_for_date_column(self, mock_dao) -> None:
        """A DATE column with TEMPORAL generic type is temporal."""
        mock_dao.find_by_id.return_value = self._create_mock_dataset(
            "order_date", "DATE", GenericDataType.TEMPORAL
        )

        assert is_column_truly_temporal("order_date", 123) is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_case_insensitive_column_name_lookup(self, mock_dao) -> None:
        """Looking up "year" matches a dataset column named "Year"."""
        mock_dao.find_by_id.return_value = self._create_mock_dataset(
            "Year", "BIGINT", GenericDataType.NUMERIC
        )

        assert is_column_truly_temporal("year", 123) is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_on_value_error(self, mock_dao) -> None:
        """A ValueError from the DAO lookup yields the default, True."""
        mock_dao.find_by_id.side_effect = ValueError("Invalid ID")

        assert is_column_truly_temporal("year", 123) is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_on_attribute_error(self, mock_dao) -> None:
        """An AttributeError from the DAO lookup yields the default, True."""
        mock_dao.find_by_id.side_effect = AttributeError("Missing attribute")

        assert is_column_truly_temporal("year", 123) is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_handles_uuid_dataset_id(self, mock_dao) -> None:
        """A string dataset_id is looked up with id_column="uuid"."""
        mock_dao.find_by_id.return_value = self._create_mock_dataset(
            "year", "BIGINT", GenericDataType.NUMERIC
        )

        assert is_column_truly_temporal("year", "abc-123-uuid") is False
        mock_dao.find_by_id.assert_called_with("abc-123-uuid", id_column="uuid")

    @patch("superset.daos.dataset.DatasetDAO")
    def test_falls_back_to_is_dttm_when_no_column_spec(self, mock_dao) -> None:
        """With get_column_spec returning None, the is_dttm flag decides."""
        from unittest.mock import MagicMock

        engine_spec = MagicMock()
        engine_spec.get_column_spec.return_value = None

        column = MagicMock(column_name="year", type="UNKNOWN_TYPE", is_dttm=False)
        mock_dao.find_by_id.return_value = MagicMock(
            columns=[column],
            database=MagicMock(db_engine_spec=engine_spec),
        )

        assert is_column_truly_temporal("year", 123) is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_falls_back_to_is_dttm_when_no_type(self, mock_dao) -> None:
        """With a column whose type is None, the is_dttm flag decides."""
        from unittest.mock import MagicMock

        column = MagicMock(column_name="year", type=None, is_dttm=True)
        mock_dao.find_by_id.return_value = MagicMock(columns=[column])

        assert is_column_truly_temporal("year", 123) is True
class TestConfigureTemporalHandling:
    """Tests for configure_temporal_handling mutating a form_data dict."""

    def test_temporal_column_with_time_grain(self) -> None:
        """A temporal x-axis with a time grain stores it as time_grain_sqla."""
        settings: dict[str, Any] = {}

        configure_temporal_handling(settings, x_is_temporal=True, time_grain="P1M")

        assert settings["time_grain_sqla"] == "P1M"

    def test_temporal_column_without_time_grain(self) -> None:
        """A temporal x-axis with no time grain leaves time_grain_sqla unset."""
        settings: dict[str, Any] = {}

        configure_temporal_handling(settings, x_is_temporal=True, time_grain=None)

        assert "time_grain_sqla" not in settings

    def test_non_temporal_column_sets_categorical_config(self) -> None:
        """A non-temporal x-axis gets name sorting and nulled time fields."""
        settings: dict[str, Any] = {}

        configure_temporal_handling(settings, x_is_temporal=False, time_grain=None)

        assert settings["x_axis_sort_series_type"] == "name"
        assert settings["x_axis_sort_series_ascending"] is True
        assert settings["time_grain_sqla"] is None
        assert settings["granularity_sqla"] is None

    def test_non_temporal_column_ignores_time_grain(self) -> None:
        """A non-temporal x-axis discards a supplied time grain."""
        settings: dict[str, Any] = {}

        configure_temporal_handling(settings, x_is_temporal=False, time_grain="P1M")

        # The categorical config wins: the requested grain must not be stored.
        assert settings["time_grain_sqla"] is None
        assert settings["x_axis_sort_series_type"] == "name"
class TestMapXYConfigWithNonTemporalColumn:
    """Tests for map_xy_config when the x column is or is not temporal.

    ``is_column_truly_temporal`` is patched in chart_utils so each test
    dictates the temporal verdict directly (the DATE_TRUNC fix).
    """

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_non_temporal_column_disables_time_grain(self, mock_is_temporal) -> None:
        """A non-temporal x column produces the categorical configuration."""
        mock_is_temporal.return_value = False
        chart_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="year"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
        )

        form_data = map_xy_config(chart_config, dataset_id=123)

        assert form_data["x_axis"] == "year"
        assert form_data["x_axis_sort_series_type"] == "name"
        assert form_data["x_axis_sort_series_ascending"] is True
        assert form_data["time_grain_sqla"] is None
        assert form_data["granularity_sqla"] is None

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_temporal_column_allows_time_grain(self, mock_is_temporal) -> None:
        """A temporal x column keeps the requested time grain."""
        mock_is_temporal.return_value = True
        chart_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="created_at"),
            y=[ColumnRef(name="count", aggregate="COUNT")],
            kind="line",
            time_grain="P1W",
        )

        form_data = map_xy_config(chart_config, dataset_id=123)

        assert form_data["x_axis"] == "created_at"
        assert form_data["time_grain_sqla"] == "P1W"
        assert "x_axis_sort_series_type" not in form_data

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_non_temporal_ignores_time_grain_param(self, mock_is_temporal) -> None:
        """A non-temporal x column drops a time grain given in the config."""
        mock_is_temporal.return_value = False
        chart_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="year"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
            time_grain="P1M",  # Expected to be ignored for a non-temporal column
        )

        form_data = map_xy_config(chart_config, dataset_id=123)

        # time_grain_sqla must come back as None rather than "P1M".
        assert form_data["time_grain_sqla"] is None
        assert form_data["x_axis_sort_series_type"] == "name"

View File

@@ -68,9 +68,30 @@ def mock_webdriver_baseurl(app_context):
def _mock_dataset(id: int = 1) -> Mock:
    """Create a mock dataset object with columns and db_engine_spec.

    The returned mock has a single column named "date" of type "TIMESTAMP",
    and a ``database.db_engine_spec`` whose ``get_column_spec`` returns a
    ``ColumnSpec`` with ``GenericDataType.TEMPORAL`` and ``is_dttm=True``.

    Args:
        id: Value assigned to the mock's ``id`` attribute. The parameter name
            shadows the builtin but is kept so keyword callers keep working.

    Returns:
        A ``Mock`` with ``id``, ``columns`` and ``database`` attributes set.
    """
    from superset.utils.core import ColumnSpec, GenericDataType

    # Create mock column that appears temporal
    mock_column = Mock()
    mock_column.column_name = "date"
    mock_column.type = "TIMESTAMP"

    # Create mock db_engine_spec reporting the column as temporal
    mock_db_engine_spec = Mock()
    mock_column_spec = ColumnSpec(
        sqla_type=Mock(), generic_type=GenericDataType.TEMPORAL, is_dttm=True
    )
    mock_db_engine_spec.get_column_spec.return_value = mock_column_spec

    # Create mock database
    mock_database = Mock()
    mock_database.db_engine_spec = mock_db_engine_spec

    # Create dataset with all required attributes
    dataset = Mock()
    dataset.id = id
    dataset.columns = [mock_column]
    dataset.database = mock_database
    return dataset