fix(sqllab): execute prequeries on streaming connection to fix PostgreSQL CSV export (#40194)

Co-authored-by: Matt Fitzgerald <matt.fitzgerald@preset.io>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mafi
2026-05-18 23:43:06 +10:00
committed by GitHub
parent 61b77fa35d
commit b66c104fde
7 changed files with 266 additions and 42 deletions

View File

@@ -55,20 +55,23 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
"""Validate permissions and query context."""
self._query_context.raise_for_access()
def _get_sql_and_database(self) -> tuple[str, Any]:
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
"""
Get the SQL query and database for chart export.
Get the SQL query, database, catalog, and schema for chart export.
Returns:
Tuple of (sql_query, database_object)
Tuple of (sql_query, database_object, catalog, schema)
"""
# Get datasource and generate SQL query
# Note: datasource should already be attached to a session from query_context
datasource = self._query_context.datasource
query_obj = self._query_context.queries[0]
sql_query = datasource.get_query_str(query_obj.to_dict())
database = getattr(datasource, "database", None)
catalog = getattr(datasource, "catalog", None)
schema = getattr(datasource, "schema", None)
return sql_query, getattr(datasource, "database", None)
return sql_query, database, catalog, schema
def _get_row_limit(self) -> int | None:
"""

View File

@@ -87,12 +87,12 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand):
status=403,
) from ex
def _get_sql_and_database(self) -> tuple[str, Any]:
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
"""
Get the SQL query and database for SQL Lab export.
Get the SQL query, database, catalog, and schema for SQL Lab export.
Returns:
Tuple of (sql_query, database_object)
Tuple of (sql_query, database_object, catalog, schema)
"""
assert self._query is not None
@@ -103,7 +103,7 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand):
# Get the SQL query
sql = select_sql or executed_sql
return sql, database
return sql, database, self._query.catalog, self._query.schema
def _get_row_limit(self) -> int | None:
"""

View File

@@ -79,12 +79,12 @@ class BaseStreamingCSVExportCommand(BaseCommand):
self._current_app = app._get_current_object()
@abstractmethod
def _get_sql_and_database(self) -> tuple[str, Any]:
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
"""
Get the SQL query and database for execution.
Get the SQL query, database, catalog, and schema for execution.
Returns:
Tuple of (sql_query, database_object)
Tuple of (sql_query, database_object, catalog, schema)
"""
@abstractmethod
@@ -150,7 +150,12 @@ class BaseStreamingCSVExportCommand(BaseCommand):
yield remaining_data, row_count, data_bytes
def _execute_query_and_stream(
self, sql: str, database: Any, limit: int | None
self,
sql: str,
database: Any,
limit: int | None,
catalog: str | None = None,
schema: str | None = None,
) -> Generator[str, None, None]:
"""Execute query with streaming and yield CSV chunks."""
start_time = time.time()
@@ -160,8 +165,9 @@ class BaseStreamingCSVExportCommand(BaseCommand):
# Merge database to prevent DetachedInstanceError
merged_database = session.merge(database)
# Execute query with streaming
with merged_database.get_sqla_engine() as engine:
with merged_database.get_sqla_engine(
catalog=catalog, schema=schema
) as engine:
with engine.connect() as connection:
result_proxy = connection.execution_options(
stream_results=True
@@ -209,7 +215,7 @@ class BaseStreamingCSVExportCommand(BaseCommand):
"""
# Load all needed data while session is still active
# to avoid DetachedInstanceError
sql, database = self._get_sql_and_database()
sql, database, catalog, schema = self._get_sql_and_database()
limit = self._get_row_limit()
# Capture flask.g attributes to preserve request-scoped data
# when the streaming generator runs in a new app context.
@@ -222,7 +228,9 @@ class BaseStreamingCSVExportCommand(BaseCommand):
with self._current_app.app_context():
with preserve_g_context(captured_g):
try:
yield from self._execute_query_and_stream(sql, database, limit)
yield from self._execute_query_and_stream(
sql, database, limit, catalog, schema
)
except Exception as e:
logger.error("Error in streaming CSV generator: %s", e)
import traceback

View File

@@ -468,13 +468,34 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
with check_for_oauth2(self):
yield self._get_sqla_engine(
engine = self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
prequeries = self.db_engine_spec.get_prequeries(
database=self,
catalog=catalog,
schema=schema,
)
if prequeries:
# SQLAlchemy connect event: runs prequeries on every new
# DBAPI connection (e.g. SET search_path for PostgreSQL).
def run_prequeries(
dbapi_connection: Any,
connection_record: Any, # pylint: disable=unused-argument
) -> None:
cursor = dbapi_connection.cursor()
try:
for prequery in prequeries:
cursor.execute(prequery)
finally:
cursor.close()
sqla.event.listen(engine, "connect", run_prequeries)
yield engine
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
self,
@@ -583,15 +604,6 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
) as engine:
with check_for_oauth2(self):
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected catalog/schema
for prequery in self.db_engine_spec.get_prequeries(
database=self,
catalog=catalog,
schema=schema,
):
cursor = conn.cursor()
cursor.execute(prequery)
yield conn
def get_default_catalog(self) -> str | None:

View File

@@ -25,7 +25,10 @@ from superset.commands.chart.data.streaming_export_command import (
def _setup_chart_mocks(
mocker: MockerFixture, sql: str = "SELECT * FROM test"
mocker: MockerFixture,
sql: str = "SELECT * FROM test",
catalog: str | None = None,
schema: str | None = None,
) -> tuple[MockerFixture, MockerFixture, MockerFixture]:
"""Set up common mocks for chart streaming export tests."""
mock_db = mocker.patch("superset.commands.streaming_export.base.db")
@@ -36,6 +39,8 @@ def _setup_chart_mocks(
datasource = mocker.MagicMock()
datasource.get_query_str.return_value = sql
datasource.database = mocker.MagicMock()
datasource.catalog = catalog
datasource.schema = schema
query_context.datasource = datasource
query_context.queries = [mocker.MagicMock()]
mock_session.merge.return_value = datasource.database
@@ -256,3 +261,38 @@ def test_empty_result_set(mocker: MockerFixture) -> None:
lines = [line.strip() for line in csv_data.strip().split("\n")]
assert len(lines) == 1
assert lines[0] == "col1,col2"
def test_catalog_and_schema_passed_to_engine(mocker: MockerFixture) -> None:
"""Test that catalog and schema are forwarded to get_sqla_engine.
Prequeries (e.g. SET search_path for PostgreSQL) are now run automatically
via a connect event listener registered inside get_sqla_engine, not by the
streaming command itself.
"""
mock_db, query_context, datasource = _setup_chart_mocks(
mocker, catalog="my_catalog", schema="my_schema"
)
mock_result = mocker.MagicMock()
mock_result.keys.return_value = ["col1"]
mock_result.fetchmany.side_effect = [[("val",)], []]
mock_connection = mocker.MagicMock()
mock_connection.execution_options.return_value.execute.return_value = mock_result
mock_connection.__enter__.return_value = mock_connection
mock_connection.__exit__.return_value = None
mock_engine = mocker.MagicMock()
mock_engine.connect.return_value = mock_connection
datasource.database.get_sqla_engine.return_value.__enter__.return_value = (
mock_engine
)
command = StreamingCSVExportCommand(query_context)
list(command.run()())
datasource.database.get_sqla_engine.assert_called_once_with(
catalog="my_catalog",
schema="my_schema",
)

View File

@@ -55,6 +55,8 @@ def mock_query():
query.select_sql = None
query.executed_sql = "SELECT * FROM test_table"
query.limiting_factor = LimitingFactor.NOT_LIMITED
query.catalog = None
query.schema = "public"
query.database = MagicMock()
query.database.db_engine_spec = MagicMock()
query.database.db_engine_spec.engine = "postgresql"
@@ -538,3 +540,40 @@ def test_null_values_handling(mocker, mock_query):
assert "1,,100" in csv_data
assert "2,test," in csv_data
assert ",," in csv_data
def test_catalog_and_schema_passed_to_engine(mocker, mock_query, mock_result_proxy):
"""Test that catalog and schema are forwarded to get_sqla_engine.
Prequeries (e.g. SET search_path for PostgreSQL) are now run automatically
via a connect event listener registered inside get_sqla_engine, not by the
streaming command itself.
"""
mock_query.select_sql = "SELECT * FROM test"
mock_query.catalog = "my_catalog"
mock_query.schema = "my_schema"
mock_db, mock_session = _setup_sqllab_mocks(mocker, mock_query)
mock_connection = MagicMock()
mock_connection.execution_options.return_value.execute.return_value = (
mock_result_proxy
)
mock_connection.__enter__.return_value = mock_connection
mock_connection.__exit__.return_value = None
mock_engine = MagicMock()
mock_engine.connect.return_value = mock_connection
mock_query.database.get_sqla_engine.return_value.__enter__.return_value = (
mock_engine
)
command = StreamingSqlResultExportCommand("test_client_123")
command.validate()
list(command.run()())
mock_query.database.get_sqla_engine.assert_called_once_with(
catalog="my_catalog",
schema="my_schema",
)

View File

@@ -17,6 +17,7 @@
# pylint: disable=import-outside-toplevel
from datetime import datetime
from typing import Any, Callable
import pytest
from flask import current_app
@@ -261,21 +262,6 @@ def test_table_column_database() -> None:
assert TableColumn(database=database).database is database
def test_get_prequeries(mocker: MockerFixture) -> None:
"""
Tests for ``get_prequeries``.
"""
mocker.patch.object(Database, "get_sqla_engine")
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
database = Database(database_name="db")
with database.get_raw_connection() as conn:
conn.cursor().execute.assert_has_calls(
[mocker.call("set a=1"), mocker.call("set b=2")]
)
def test_catalog_cache() -> None:
"""
Test the catalog cache.
@@ -634,6 +620,142 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None
)
def test_get_sqla_engine_registers_prequery_event_listener(
app_context: None,
mocker: MockerFixture,
) -> None:
"""
Test that get_sqla_engine registers a connect event listener for prequeries.
Engines returned by get_sqla_engine must automatically execute prequeries
(e.g. SET search_path) on every new connection, so that callers don't need
to remember to call get_prequeries() themselves.
"""
mock_engine = mocker.MagicMock()
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ['SET search_path = "my_schema"']
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
with database.get_sqla_engine(catalog="my_catalog", schema="my_schema"):
pass
db_engine_spec.get_prequeries.assert_called_once_with(
database=database,
catalog="my_catalog",
schema="my_schema",
)
event_listen.assert_called_once_with(mock_engine, "connect", mocker.ANY)
# Call the captured closure directly to verify cursor create → execute → close.
captured_fn = event_listen.call_args[0][2]
mock_dbapi_conn = mocker.MagicMock()
mock_cursor = mocker.MagicMock()
mock_dbapi_conn.cursor.return_value = mock_cursor
captured_fn(mock_dbapi_conn, None)
mock_cursor.execute.assert_called_once_with('SET search_path = "my_schema"')
mock_cursor.close.assert_called_once()
def test_get_sqla_engine_prequery_cursor_closed_on_exception(
app_context: None,
mocker: MockerFixture,
) -> None:
"""
Test that the cursor is always closed even when a prequery raises.
"""
mock_engine = mocker.MagicMock()
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ['SET search_path = "bad_schema"']
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
with database.get_sqla_engine(catalog=None, schema="bad_schema"):
pass
captured_fn = event_listen.call_args[0][2]
mock_dbapi_conn = mocker.MagicMock()
mock_cursor = mocker.MagicMock()
mock_cursor.execute.side_effect = Exception("invalid schema")
mock_dbapi_conn.cursor.return_value = mock_cursor
with pytest.raises(Exception, match="invalid schema"):
captured_fn(mock_dbapi_conn, None)
mock_cursor.close.assert_called_once()
def test_get_sqla_engine_no_prequeries_no_event_listener(
app_context: None,
mocker: MockerFixture,
) -> None:
"""
Test that get_sqla_engine does not register an event listener when there
are no prequeries.
"""
mock_engine = mocker.MagicMock()
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = []
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
with database.get_sqla_engine(catalog=None, schema=None):
pass
event_listen.assert_not_called()
def test_get_raw_connection_executes_prequeries_exactly_once(
app_context: None,
mocker: MockerFixture,
) -> None:
"""
Test that get_raw_connection() runs prequeries exactly once through the
connect event listener registered by get_sqla_engine().
Previously get_raw_connection() had its own manual prequery loop AND
called get_sqla_engine() (which registers the listener), so prequeries
ran twice. After removing the manual loop the listener is the sole
execution point — this test proves exactly-once semantics.
"""
mock_engine = mocker.MagicMock()
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
prequery = 'SET search_path = "my_schema"'
db_engine_spec.get_prequeries.return_value = [prequery]
# Capture the closure registered via sqla.event.listen.
captured_listeners: list[Callable[..., None]] = []
original_listen = mocker.patch("superset.models.core.sqla.event.listen")
original_listen.side_effect = lambda engine, event, fn: captured_listeners.append(
fn
)
# Simulate SQLAlchemy firing the "connect" event when raw_connection() is called.
mock_dbapi_conn = mocker.MagicMock()
mock_cursor = mocker.MagicMock()
mock_dbapi_conn.cursor.return_value = mock_cursor
def raw_connection_side_effect() -> Any:
for listener in captured_listeners:
listener(mock_dbapi_conn, None)
return mock_dbapi_conn
mock_engine.raw_connection.side_effect = raw_connection_side_effect
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
with database.get_raw_connection(schema="my_schema"):
pass
# Exactly one prequery, exactly once — not twice, not zero.
mock_cursor.execute.assert_called_once_with(prequery)
mock_cursor.close.assert_called_once()
def test_is_oauth2_enabled() -> None:
"""
Test the `is_oauth2_enabled` method.