mirror of
https://github.com/apache/superset.git
synced 2026-05-25 09:45:18 +00:00
fix(sqllab): execute prequeries on streaming connection to fix PostgreSQL CSV export (#40194)
Co-authored-by: Matt Fitzgerald <matt.fitzgerald@preset.io> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
@@ -261,21 +262,6 @@ def test_table_column_database() -> None:
|
||||
assert TableColumn(database=database).database is database
|
||||
|
||||
|
||||
def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Tests for ``get_prequeries``.
|
||||
"""
|
||||
mocker.patch.object(Database, "get_sqla_engine")
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
|
||||
|
||||
database = Database(database_name="db")
|
||||
with database.get_raw_connection() as conn:
|
||||
conn.cursor().execute.assert_has_calls(
|
||||
[mocker.call("set a=1"), mocker.call("set b=2")]
|
||||
)
|
||||
|
||||
|
||||
def test_catalog_cache() -> None:
|
||||
"""
|
||||
Test the catalog cache.
|
||||
@@ -634,6 +620,142 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None
|
||||
)
|
||||
|
||||
|
||||
def test_get_sqla_engine_registers_prequery_event_listener(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_sqla_engine registers a connect event listener for prequeries.
|
||||
|
||||
Engines returned by get_sqla_engine must automatically execute prequeries
|
||||
(e.g. SET search_path) on every new connection, so that callers don't need
|
||||
to remember to call get_prequeries() themselves.
|
||||
"""
|
||||
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ['SET search_path = "my_schema"']
|
||||
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_sqla_engine(catalog="my_catalog", schema="my_schema"):
|
||||
pass
|
||||
|
||||
db_engine_spec.get_prequeries.assert_called_once_with(
|
||||
database=database,
|
||||
catalog="my_catalog",
|
||||
schema="my_schema",
|
||||
)
|
||||
event_listen.assert_called_once_with(mock_engine, "connect", mocker.ANY)
|
||||
|
||||
# Call the captured closure directly to verify cursor create → execute → close.
|
||||
captured_fn = event_listen.call_args[0][2]
|
||||
mock_dbapi_conn = mocker.MagicMock()
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_dbapi_conn.cursor.return_value = mock_cursor
|
||||
captured_fn(mock_dbapi_conn, None)
|
||||
mock_cursor.execute.assert_called_once_with('SET search_path = "my_schema"')
|
||||
mock_cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sqla_engine_prequery_cursor_closed_on_exception(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that the cursor is always closed even when a prequery raises.
|
||||
"""
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ['SET search_path = "bad_schema"']
|
||||
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_sqla_engine(catalog=None, schema="bad_schema"):
|
||||
pass
|
||||
|
||||
captured_fn = event_listen.call_args[0][2]
|
||||
mock_dbapi_conn = mocker.MagicMock()
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_cursor.execute.side_effect = Exception("invalid schema")
|
||||
mock_dbapi_conn.cursor.return_value = mock_cursor
|
||||
|
||||
with pytest.raises(Exception, match="invalid schema"):
|
||||
captured_fn(mock_dbapi_conn, None)
|
||||
|
||||
mock_cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sqla_engine_no_prequeries_no_event_listener(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_sqla_engine does not register an event listener when there
|
||||
are no prequeries.
|
||||
"""
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = []
|
||||
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_sqla_engine(catalog=None, schema=None):
|
||||
pass
|
||||
|
||||
event_listen.assert_not_called()
|
||||
|
||||
|
||||
def test_get_raw_connection_executes_prequeries_exactly_once(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_raw_connection() runs prequeries exactly once through the
|
||||
connect event listener registered by get_sqla_engine().
|
||||
|
||||
Previously get_raw_connection() had its own manual prequery loop AND
|
||||
called get_sqla_engine() (which registers the listener), so prequeries
|
||||
ran twice. After removing the manual loop the listener is the sole
|
||||
execution point — this test proves exactly-once semantics.
|
||||
"""
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
prequery = 'SET search_path = "my_schema"'
|
||||
db_engine_spec.get_prequeries.return_value = [prequery]
|
||||
|
||||
# Capture the closure registered via sqla.event.listen.
|
||||
captured_listeners: list[Callable[..., None]] = []
|
||||
original_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
original_listen.side_effect = lambda engine, event, fn: captured_listeners.append(
|
||||
fn
|
||||
)
|
||||
|
||||
# Simulate SQLAlchemy firing the "connect" event when raw_connection() is called.
|
||||
mock_dbapi_conn = mocker.MagicMock()
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_dbapi_conn.cursor.return_value = mock_cursor
|
||||
|
||||
def raw_connection_side_effect() -> Any:
|
||||
for listener in captured_listeners:
|
||||
listener(mock_dbapi_conn, None)
|
||||
return mock_dbapi_conn
|
||||
|
||||
mock_engine.raw_connection.side_effect = raw_connection_side_effect
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_raw_connection(schema="my_schema"):
|
||||
pass
|
||||
|
||||
# Exactly one prequery, exactly once — not twice, not zero.
|
||||
mock_cursor.execute.assert_called_once_with(prequery)
|
||||
mock_cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_is_oauth2_enabled() -> None:
|
||||
"""
|
||||
Test the `is_oauth2_enabled` method.
|
||||
|
||||
Reference in New Issue
Block a user