fix(sqllab): execute prequeries on streaming connection to fix PostgreSQL CSV export (#40194)

Co-authored-by: Matt Fitzgerald <matt.fitzgerald@preset.io>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mafi
2026-05-18 23:43:06 +10:00
committed by GitHub
parent 61b77fa35d
commit b66c104fde
7 changed files with 266 additions and 42 deletions

View File

@@ -17,6 +17,7 @@
# pylint: disable=import-outside-toplevel
from datetime import datetime
from typing import Any, Callable
import pytest
from flask import current_app
@@ -261,21 +262,6 @@ def test_table_column_database() -> None:
assert TableColumn(database=database).database is database
def test_get_prequeries(mocker: MockerFixture) -> None:
    """
    Check that opening a raw connection executes the engine spec's prequeries.

    ``get_sqla_engine`` and ``db_engine_spec`` are both patched on
    ``Database``. The spec is stubbed to report two prequeries, and both are
    expected, in order, among the ``execute`` calls made on the cursor of the
    connection yielded by ``get_raw_connection()``.
    """
    prequeries = ("set a=1", "set b=2")

    mocker.patch.object(Database, "get_sqla_engine")
    engine_spec = mocker.patch.object(Database, "db_engine_spec")
    engine_spec.get_prequeries.return_value = list(prequeries)

    db = Database(database_name="db")
    with db.get_raw_connection() as conn:
        expected_calls = [mocker.call(statement) for statement in prequeries]
        cursor = conn.cursor()
        cursor.execute.assert_has_calls(expected_calls)
def test_catalog_cache() -> None:
"""
Test the catalog cache.
@@ -634,6 +620,142 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None
)
def test_get_sqla_engine_registers_prequery_event_listener(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    Verify that ``get_sqla_engine`` hooks prequeries onto the "connect" event.

    With the engine spec stubbed to return one prequery, entering
    ``get_sqla_engine`` must:

    * ask the spec for prequeries with the database, catalog and schema;
    * register exactly one "connect" listener on the engine.

    The registered listener is then invoked by hand with a fake DBAPI
    connection to confirm it executes the prequery on a cursor and closes
    that cursor afterwards.
    """
    prequery = 'SET search_path = "my_schema"'

    engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=engine)
    engine_spec = mocker.patch.object(Database, "db_engine_spec")
    engine_spec.get_prequeries.return_value = [prequery]
    listen = mocker.patch("superset.models.core.sqla.event.listen")

    db = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with db.get_sqla_engine(catalog="my_catalog", schema="my_schema"):
        pass

    expected_kwargs = {
        "database": db,
        "catalog": "my_catalog",
        "schema": "my_schema",
    }
    engine_spec.get_prequeries.assert_called_once_with(**expected_kwargs)
    listen.assert_called_once_with(engine, "connect", mocker.ANY)

    # The listener is the third positional argument passed to event.listen.
    on_connect = listen.call_args.args[2]

    # Drive the listener manually with a fake DBAPI connection.
    cursor = mocker.MagicMock()
    dbapi_conn = mocker.MagicMock()
    dbapi_conn.cursor.return_value = cursor
    on_connect(dbapi_conn, None)

    cursor.execute.assert_called_once_with(prequery)
    cursor.close.assert_called_once()
def test_get_sqla_engine_prequery_cursor_closed_on_exception(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    Test that the cursor is always closed even when a prequery raises.

    The "connect" listener registered by ``get_sqla_engine`` is captured from
    the patched ``sqla.event.listen`` and invoked with a fake DBAPI connection
    whose cursor raises on ``execute``. The exception must propagate to the
    caller and the cursor must still be closed exactly once.
    """
    mock_engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
    db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
    db_engine_spec.get_prequeries.return_value = ['SET search_path = "bad_schema"']
    event_listen = mocker.patch("superset.models.core.sqla.event.listen")

    database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with database.get_sqla_engine(catalog=None, schema="bad_schema"):
        pass

    # Fail with a clear assertion if no listener was registered; otherwise
    # ``call_args`` would be None and the subscript below would raise an
    # unrelated TypeError that hides the real problem.
    event_listen.assert_called_once()
    captured_fn = event_listen.call_args[0][2]

    mock_dbapi_conn = mocker.MagicMock()
    mock_cursor = mocker.MagicMock()
    mock_cursor.execute.side_effect = Exception("invalid schema")
    mock_dbapi_conn.cursor.return_value = mock_cursor

    with pytest.raises(Exception, match="invalid schema"):
        captured_fn(mock_dbapi_conn, None)

    # Cleanup must happen despite the failure inside execute().
    mock_cursor.close.assert_called_once()
def test_get_sqla_engine_no_prequeries_no_event_listener(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    Verify that no event listener is registered when the engine spec reports
    an empty list of prequeries.
    """
    engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=engine)
    engine_spec = mocker.patch.object(Database, "db_engine_spec")
    engine_spec.get_prequeries.return_value = []
    listen = mocker.patch("superset.models.core.sqla.event.listen")

    db = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with db.get_sqla_engine(catalog=None, schema=None):
        pass

    # Nothing to run on connect, so event.listen must never be touched.
    listen.assert_not_called()
def test_get_raw_connection_executes_prequeries_exactly_once(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    Verify that ``get_raw_connection()`` runs each prequery exactly once.

    Prequeries are expected to run only through the "connect" listener that
    ``get_sqla_engine()`` registers. An earlier implementation also looped
    over the prequeries manually inside ``get_raw_connection()``, which made
    them run twice; this test guards against both the double execution and
    the prequery not running at all.

    The test records every listener passed to the patched
    ``sqla.event.listen`` and replays them from the engine's
    ``raw_connection()`` stub, imitating SQLAlchemy firing the "connect"
    event.
    """
    prequery = 'SET search_path = "my_schema"'

    engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=engine)
    engine_spec = mocker.patch.object(Database, "db_engine_spec")
    engine_spec.get_prequeries.return_value = [prequery]

    # Collect every callback handed to sqla.event.listen.
    captured_listeners: list[Callable[..., None]] = []

    def record_listener(engine: Any, event: Any, fn: Any) -> None:
        captured_listeners.append(fn)

    listen = mocker.patch("superset.models.core.sqla.event.listen")
    listen.side_effect = record_listener

    # Fake DBAPI connection whose cursor we can inspect afterwards.
    cursor = mocker.MagicMock()
    dbapi_conn = mocker.MagicMock()
    dbapi_conn.cursor.return_value = cursor

    def fire_connect_and_return() -> Any:
        # Stand-in for SQLAlchemy emitting "connect" on raw_connection().
        for on_connect in captured_listeners:
            on_connect(dbapi_conn, None)
        return dbapi_conn

    engine.raw_connection.side_effect = fire_connect_and_return

    db = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with db.get_raw_connection(schema="my_schema"):
        pass

    # One prequery, executed a single time, with the cursor released once.
    cursor.execute.assert_called_once_with(prequery)
    cursor.close.assert_called_once()
def test_is_oauth2_enabled() -> None:
"""
Test the `is_oauth2_enabled` method.