# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Shared fixtures and helpers for SQL execution tests.

This module provides common mocks, fixtures, and helper functions used across
test_celery_task.py and test_executor.py to reduce code duplication.
"""

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock

import pandas as pd
import pytest
from flask import current_app
from pytest_mock import MockerFixture

from superset.common.db_query_status import QueryStatus as QueryStatusEnum
from superset.models.core import Database

# =============================================================================
# Core Fixtures
# =============================================================================


@pytest.fixture(autouse=True)
def mock_db_session(mocker: MockerFixture) -> MagicMock:
    """
    Mock database session for all tests to avoid foreign key constraints.

    The same MagicMock replaces ``db.session`` in both the executor and the
    celery_task modules, so a test can configure or inspect a single object.

    Args:
        mocker: pytest-mock fixture

    Returns:
        The MagicMock installed as the session in both modules
    """
    mock_session = MagicMock()
    mocker.patch("superset.sql.execution.executor.db.session", mock_session)
    mocker.patch("superset.sql.execution.celery_task.db.session", mock_session)
    return mock_session


@pytest.fixture
def mock_query() -> MagicMock:
    """
    Create a mock Query model.

    Returns:
        MagicMock with the attributes of a pending query (id 123, database 1,
        schema "public") and a nested mock ``database`` attribute
    """
    query = MagicMock()
    query.id = 123
    query.database_id = 1
    query.sql = "SELECT * FROM users"
    query.status = QueryStatusEnum.PENDING
    query.error_message = None
    query.progress = 0
    query.end_time = None
    query.start_running_time = None
    query.executed_sql = None
    query.tmp_table_name = None
    query.catalog = None
    query.schema = "public"
    query.extra = {}
    query.set_extra_json_key = MagicMock()
    query.results_key = None
    query.select_as_cta = False
    query.rows = 0
    query.to_dict = MagicMock(return_value={"id": 123})
    query.database = MagicMock()
    query.database.db_engine_spec.extract_errors.return_value = []
    query.database.unique_name = "test_db"
    query.database.cache_timeout = 300
    return query


@pytest.fixture
def mock_database() -> MagicMock:
    """
    Create a mock Database.

    Returns:
        MagicMock database (id 1) whose ``db_engine_spec`` is a mock with
        explicit values for the attributes and methods set below
    """
    database = MagicMock()
    database.id = 1
    database.unique_name = "test_db"
    database.cache_timeout = 300
    database.sqlalchemy_uri = "postgresql://localhost/test"
    database.db_engine_spec = MagicMock()
    database.db_engine_spec.engine = "postgresql"
    database.db_engine_spec.run_multiple_statements_as_one = False
    database.db_engine_spec.allows_sql_comments = True
    database.db_engine_spec.extract_errors = MagicMock(return_value=[])
    database.db_engine_spec.execute_with_cursor = MagicMock()
    database.db_engine_spec.get_cancel_query_id = MagicMock(return_value=None)
    database.db_engine_spec.patch = MagicMock()
    database.db_engine_spec.fetch_data = MagicMock(return_value=[])
    return database


@pytest.fixture
def mock_result_set() -> MagicMock:
    """
    Create a mock SupersetResultSet.

    Returns:
        MagicMock reporting two rows with columns "id" and "name"; its
        ``to_pandas_df()`` returns a real two-row DataFrame
    """
    result_set = MagicMock()
    result_set.size = 2
    result_set.columns = [{"name": "id"}, {"name": "name"}]
    result_set.pa_table = MagicMock()
    result_set.to_pandas_df = MagicMock(
        return_value=pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
    )
    return result_set


@pytest.fixture
def database() -> Database:
    """Create a real test database instance (DML not allowed)."""
    return Database(
        id=1,
        database_name="test_db",
        sqlalchemy_uri="sqlite://",
        allow_dml=False,
    )


@pytest.fixture
def database_with_dml() -> Database:
    """Create a real test database instance with DML allowed."""
    return Database(
        id=2,
        database_name="test_db_dml",
        sqlalchemy_uri="sqlite://",
        allow_dml=True,
    )


# =============================================================================
# Helper Functions for Mock Creation
# =============================================================================


def create_mock_cursor(
    column_names: list[str],
    data: list[tuple[Any, ...]] | None = None,
) -> MagicMock:
    """
    Create a mock database cursor with column description.

    Args:
        column_names: List of column names
        data: Optional data to return from fetchall()

    Returns:
        Configured MagicMock cursor
    """
    mock_cursor = MagicMock()
    # One 7-item tuple per column; only the name slot is filled in.
    mock_cursor.description = [
        (name, None, None, None, None, None, None) for name in column_names
    ]
    # When no data is given, fetchall() keeps the default MagicMock behavior.
    if data is not None:
        mock_cursor.fetchall.return_value = data
    return mock_cursor


def create_mock_connection(mock_cursor: MagicMock | None = None) -> MagicMock:
    """
    Create a mock database connection.

    Args:
        mock_cursor: Optional cursor to return from cursor(); a cursor with no
            columns is created when omitted

    Returns:
        Configured MagicMock connection with context manager support
    """
    if mock_cursor is None:
        mock_cursor = create_mock_cursor([])

    mock_conn = MagicMock()
    mock_conn.cursor.return_value = mock_cursor
    mock_conn.close = MagicMock()
    # ``with mock_conn as c`` binds the connection itself; __exit__ returning
    # False means exceptions raised inside the block propagate.
    mock_conn.__enter__ = MagicMock(return_value=mock_conn)
    mock_conn.__exit__ = MagicMock(return_value=False)
    return mock_conn


def setup_mock_raw_connection(
    mock_database: MagicMock,
    mock_connection: MagicMock | None = None,
) -> MagicMock:
    """
    Setup get_raw_connection as a context manager on a mock database.

    Args:
        mock_database: The database mock to configure
        mock_connection: Optional connection to yield; one is created with
            create_mock_connection() when omitted

    Returns:
        The configured mock connection
    """
    if mock_connection is None:
        mock_connection = create_mock_connection()

    @contextmanager
    def _raw_connection(
        catalog: str | None = None,
        schema: str | None = None,
        nullpool: bool = True,
        source: Any | None = None,
    ) -> Iterator[MagicMock]:
        # The arguments are accepted but ignored; every call yields the same
        # connection.
        yield mock_connection

    mock_database.get_raw_connection = _raw_connection
    return mock_connection


def setup_db_session_query_mock(
    mock_db_session: MagicMock,
    return_value: Any = None,
) -> None:
    """
    Setup database session query chain for query lookup.

    Configures ``query(...).filter_by(...).one_or_none()`` on the session mock.

    Args:
        mock_db_session: The database session mock
        return_value: Value to return from one_or_none()
    """
    filter_mock = mock_db_session.query.return_value.filter_by.return_value
    filter_mock.one_or_none.return_value = return_value


def mock_query_execution(
    mocker: MockerFixture,
    database: Database,
    return_data: list[tuple[Any, ...]],
    column_names: list[str],
) -> MagicMock:
    """
    Mock the raw connection execution path for testing.

    This helper sets up all necessary mocks for executing a query through
    the database engine spec and returning results.

    Args:
        mocker: pytest-mock fixture
        database: Database instance to mock
        return_data: Data to return from fetch_data,
            e.g. [(1, "Alice"), (2, "Bob")]
        column_names: Column names for the result, e.g. ["id", "name"]

    Returns:
        The mock for get_raw_connection, so tests can make assertions on it
    """
    from superset.result_set import SupersetResultSet

    # Mock cursor and connection
    mock_cursor = create_mock_cursor(column_names, return_data)
    mock_conn = create_mock_connection(mock_cursor)

    get_raw_conn_mock = mocker.patch.object(
        database, "get_raw_connection", return_value=mock_conn
    )
    # Pass the SQL through unchanged regardless of keyword arguments.
    mocker.patch.object(
        database, "mutate_sql_based_on_config", side_effect=lambda sql, **kw: sql
    )
    mocker.patch.object(database.db_engine_spec, "execute")
    mocker.patch.object(database.db_engine_spec, "fetch_data", return_value=return_data)

    # Stand-in for SupersetResultSet: a mock restricted to the real class's
    # attributes (spec=...) whose to_pandas_df() returns a real DataFrame.
    mock_result_set = MagicMock(spec=SupersetResultSet)
    mock_result_set.to_pandas_df.return_value = pd.DataFrame(
        return_data, columns=column_names
    )
    # NOTE(review): this replaces the name inside ``superset.result_set`` only;
    # confirm the code under test resolves the class through that module.
    mocker.patch("superset.result_set.SupersetResultSet", return_value=mock_result_set)

    return get_raw_conn_mock


# =============================================================================
# Composite Fixtures for Common Test Patterns
# =============================================================================


@pytest.fixture
def default_sql_config(mocker: MockerFixture) -> None:
    """Patch app config with default SQL execution settings."""
    mocker.patch.dict(
        current_app.config,
        {
            "SQL_QUERY_MUTATOR": None,
            "SQLLAB_TIMEOUT": 30,
            "SQL_MAX_ROW": None,
            "QUERY_LOGGER": None,
        },
    )


@pytest.fixture
def mock_celery_task(mocker: MockerFixture) -> MagicMock:
    """Mock the Celery task for SQL execution."""
    return mocker.patch("superset.sql.execution.celery_task.execute_sql_task")


def setup_cache_mocks(
    mocker: MockerFixture,
    get_result: Any = None,
    store_side_effect: Any = None,
) -> tuple[MagicMock, MagicMock]:
    """
    Setup cache get/store mocks for executor tests.

    Args:
        mocker: pytest-mock fixture
        get_result: Value to return from _get_from_cache
        store_side_effect: Optional side effect for _store_in_cache

    Returns:
        Tuple of (mock_get, mock_store)
    """
    from superset.sql.execution.executor import SQLExecutor

    mock_get = mocker.patch.object(
        SQLExecutor, "_get_from_cache", return_value=get_result
    )
    mock_store = mocker.patch.object(
        SQLExecutor, "_store_in_cache", side_effect=store_side_effect
    )
    return mock_get, mock_store