Files
superset2/tests/unit_tests/commands/databases/validate_sql_test.py

280 lines
9.1 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.database.exceptions import (
ValidatorSQL400Error,
ValidatorSQLError,
)
from superset.commands.database.validate_sql import ValidateSQLCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetSyntaxErrorException,
SupersetTemplateException,
)
@pytest.fixture
def mock_database(mocker: MockerFixture) -> MagicMock:
    """Provide a fake PostgreSQL database and route ``DatabaseDAO`` lookups to it.

    ``DatabaseDAO`` is patched inside the ``validate_sql`` command module so that
    ``DatabaseDAO.find_by_id`` returns this mock (id ``1``, engine
    ``"postgresql"``).
    """
    fake_db = mocker.MagicMock()
    fake_db.id = 1
    fake_db.db_engine_spec.engine = "postgresql"

    dao_mock = mocker.patch(
        "superset.commands.database.validate_sql.DatabaseDAO"
    )
    dao_mock.find_by_id.return_value = fake_db
    return fake_db
@pytest.fixture
def mock_validator(mocker: MockerFixture) -> MagicMock:
    """Provide a fake SQL validator whose ``validate`` reports no errors.

    ``get_validator_by_name`` is patched inside the ``validate_sql`` command
    module so every lookup yields this mock, named ``"PostgreSQLValidator"``.
    """
    fake_validator = mocker.MagicMock()
    fake_validator.name = "PostgreSQLValidator"
    fake_validator.validate.return_value = []

    lookup = mocker.patch(
        "superset.commands.database.validate_sql.get_validator_by_name"
    )
    lookup.return_value = fake_validator
    return fake_validator
@pytest.fixture
def mock_config(mocker: MockerFixture) -> dict[str, Any]:
    """Swap ``app.config`` in the ``validate_sql`` command module for a plain dict.

    The dict maps the ``postgresql`` engine to ``PostgreSQLValidator`` and sets
    ``SQLLAB_VALIDATION_TIMEOUT`` to 30. It is returned so tests can inspect it.
    """
    validators_by_engine = {"postgresql": "PostgreSQLValidator"}
    fake_config = {
        "SQL_VALIDATORS_BY_ENGINE": validators_by_engine,
        "SQLLAB_VALIDATION_TIMEOUT": 30,
    }
    mocker.patch(
        "superset.commands.database.validate_sql.app.config", fake_config
    )
    return fake_config
@pytest.fixture
def mock_template_processor(
    mocker: MockerFixture, mock_database: MagicMock
) -> MagicMock:
    """Provide a fake Jinja template processor for the command under test.

    ``get_template_processor`` is patched inside the ``validate_sql`` command
    module so that it returns this mock.
    """
    # ``mock_database`` is requested only as a dependency: it guarantees the
    # database fixture (and its ``DatabaseDAO`` patch) is active as well.
    fake_processor = mocker.MagicMock()
    factory = mocker.patch(
        "superset.commands.database.validate_sql.get_template_processor"
    )
    factory.return_value = fake_processor
    return fake_processor
def test_validate_sql_with_jinja_templates(
    mock_database: MagicMock,
    mock_validator: MagicMock,
    mock_template_processor: MagicMock,
    mock_config: dict[str, Any],
) -> None:
    """Test that Jinja templates are rendered before SQL validation.

    With no template params supplied, the raw SQL is passed to the template
    processor on its own, and the validator must then receive the *rendered*
    SQL returned by the processor, not the raw Jinja source.
    """
    sql_with_jinja = (
        "SELECT *\n"
        "FROM birth_names\n"
        "WHERE 1=1\n"
        "{% if city_filter is defined %}\n"
        "AND city = '{{ city_filter }}'\n"
        "{% endif %}\n"
        "LIMIT {{ limit | default(100) }}"
    )
    rendered_sql = "SELECT *\nFROM birth_names\nWHERE 1=1\nLIMIT 100"
    mock_template_processor.process_template.return_value = rendered_sql

    data = {"sql": sql_with_jinja, "schema": "public", "template_params": {}}
    command = ValidateSQLCommand(model_id=1, data=data)
    result = command.run()

    mock_template_processor.process_template.assert_called_once_with(sql_with_jinja)
    mock_validator.validate.assert_called_once()
    # The SQL may be passed positionally or by keyword. Either way it must be
    # the rendered output, which proves rendering happens before validation.
    validate_call = mock_validator.validate.call_args
    assert rendered_sql in (*validate_call.args, *validate_call.kwargs.values())
    assert result == []
def test_validate_sql_with_jinja_templates_and_params(
    mock_database: MagicMock,
    mock_validator: MagicMock,
    mock_template_processor: MagicMock,
    mock_config: dict[str, Any],
) -> None:
    """Test that Jinja templates are rendered with parameters before SQL validation.

    The request's ``template_params`` must be forwarded to the template processor
    as keyword arguments, and the validator must receive the rendered SQL.
    """
    sql_with_jinja = (
        "SELECT *\n"
        "FROM birth_names\n"
        "WHERE 1=1\n"
        "{% if city_filter is defined %}\n"
        "AND city = '{{ city_filter }}'\n"
        "{% endif %}\n"
        "LIMIT {{ limit }}"
    )
    template_params = {"city_filter": "New York", "limit": 50}
    rendered_sql = (
        "SELECT *\nFROM birth_names\nWHERE 1=1\n AND city = 'New York'\nLIMIT 50"
    )
    mock_template_processor.process_template.return_value = rendered_sql

    data = {
        "sql": sql_with_jinja,
        "schema": "public",
        "template_params": template_params,
    }
    command = ValidateSQLCommand(model_id=1, data=data)
    result = command.run()

    mock_template_processor.process_template.assert_called_once_with(
        sql_with_jinja, **template_params
    )
    mock_validator.validate.assert_called_once()
    # Accept the SQL positionally or by keyword, but it must be the rendered
    # text so we know the params were applied before validation.
    validate_call = mock_validator.validate.call_args
    assert rendered_sql in (*validate_call.args, *validate_call.kwargs.values())
    assert result == []
def test_validate_sql_without_jinja_templates(
    mock_database: MagicMock,
    mock_validator: MagicMock,
    mock_template_processor: MagicMock,
    mock_config: dict[str, Any],
) -> None:
    """Test that regular SQL without Jinja templates still works.

    Plain SQL still goes through the template processor, exactly once and with
    no extra params. The processor's (unchanged) output is what the validator
    receives.
    """
    simple_sql = "SELECT * FROM birth_names LIMIT 100"
    mock_template_processor.process_template.return_value = simple_sql

    data = {"sql": simple_sql, "schema": "public", "template_params": {}}
    command = ValidateSQLCommand(model_id=1, data=data)
    result = command.run()

    # Pin the exact call, consistent with the Jinja test that uses empty params.
    mock_template_processor.process_template.assert_called_once_with(simple_sql)
    mock_validator.validate.assert_called_once()
    validate_call = mock_validator.validate.call_args
    assert simple_sql in (*validate_call.args, *validate_call.kwargs.values())
    assert result == []
def test_validate_sql_template_syntax_error(
    mock_database: MagicMock,
    mock_validator: MagicMock,
    mock_template_processor: MagicMock,
    mock_config: dict[str, Any],
) -> None:
    """
    Check that a Jinja syntax / undefined-variable failure reaches the client.

    The template processor is made to raise a ``SupersetSyntaxErrorException``
    wrapping a ``SupersetError`` (whose ``extra`` carries the template and line).
    The command must turn it into a ``ValidatorSQL400Error`` whose message keeps
    the original error text, and the validator must never be called.
    """
    undefined_var_error = SupersetError(
        message="Jinja2 template error (UndefinedError): 'city_filter' is undefined",
        error_type=SupersetErrorType.GENERIC_COMMAND_ERROR,
        level=ErrorLevel.ERROR,
        extra={"template": "SELECT * FROM...", "line": 3},
    )
    processor = mock_template_processor.process_template
    processor.side_effect = SupersetSyntaxErrorException([undefined_var_error])

    sql_with_undefined_var = (
        "SELECT *\n"
        "FROM birth_names\n"
        "WHERE city = '{{ city_filter }}'\n"
        "LIMIT 100"
    )
    command = ValidateSQLCommand(
        model_id=1,
        data={
            "sql": sql_with_undefined_var,
            "schema": "public",
            "template_params": {},
        },
    )

    with pytest.raises(ValidatorSQL400Error) as exc_info:
        command.run()

    message = exc_info.value.error.message
    assert message is not None
    assert "'city_filter' is undefined" in message
    mock_validator.validate.assert_not_called()
def test_validate_sql_template_processing_error(
    mock_database: MagicMock,
    mock_validator: MagicMock,
    mock_template_processor: MagicMock,
    mock_config: dict[str, Any],
) -> None:
    """
    Check that an internal templating failure reaches the client as a 400.

    When rendering raises ``SupersetTemplateException``, the command must raise
    ``ValidatorSQL400Error`` whose message mentions both the generic
    "Template processing failed" prefix and the underlying cause. The
    validator must not run.
    """
    recursion_failure = SupersetTemplateException(
        "Infinite recursion detected in template"
    )
    mock_template_processor.process_template.side_effect = recursion_failure

    command = ValidateSQLCommand(
        model_id=1,
        data={
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "schema": "public",
            "template_params": {},
        },
    )

    with pytest.raises(ValidatorSQL400Error) as exc_info:
        command.run()

    message = exc_info.value.error.message
    assert message is not None
    for fragment in ("Template processing failed", "Infinite recursion"):
        assert fragment in message
    mock_validator.validate.assert_not_called()
def test_validate_sql_generic_exception(
    mock_database: MagicMock,
    mock_validator: MagicMock,
    mock_template_processor: MagicMock,
    mock_config: dict[str, Any],
) -> None:
    """
    Check that a non-template exception is still handled gracefully.

    A ``RuntimeError`` raised during rendering must come back as a
    ``ValidatorSQLError`` whose message names the validator and includes the
    original error text. The validator itself is never invoked.
    """
    unexpected = RuntimeError("Unexpected error occurred")
    mock_template_processor.process_template.side_effect = unexpected

    command = ValidateSQLCommand(
        model_id=1,
        data={
            "sql": "SELECT * FROM birth_names",
            "schema": "public",
            "template_params": {},
        },
    )

    with pytest.raises(ValidatorSQLError) as exc_info:
        command.run()

    message = exc_info.value.error.message
    assert message is not None
    for fragment in ("PostgreSQLValidator", "Unexpected error occurred"):
        assert fragment in message
    mock_validator.validate.assert_not_called()