diff --git a/superset/commands/chart/data/streaming_export_command.py b/superset/commands/chart/data/streaming_export_command.py index 7dec6bc1d41..6c0dbbd09b4 100644 --- a/superset/commands/chart/data/streaming_export_command.py +++ b/superset/commands/chart/data/streaming_export_command.py @@ -55,20 +55,23 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand): """Validate permissions and query context.""" self._query_context.raise_for_access() - def _get_sql_and_database(self) -> tuple[str, Any]: + def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]: """ - Get the SQL query and database for chart export. + Get the SQL query, database, catalog, and schema for chart export. Returns: - Tuple of (sql_query, database_object) + Tuple of (sql_query, database_object, catalog, schema) """ # Get datasource and generate SQL query # Note: datasource should already be attached to a session from query_context datasource = self._query_context.datasource query_obj = self._query_context.queries[0] sql_query = datasource.get_query_str(query_obj.to_dict()) + database = getattr(datasource, "database", None) + catalog = getattr(datasource, "catalog", None) + schema = getattr(datasource, "schema", None) - return sql_query, getattr(datasource, "database", None) + return sql_query, database, catalog, schema def _get_row_limit(self) -> int | None: """ diff --git a/superset/commands/sql_lab/streaming_export_command.py b/superset/commands/sql_lab/streaming_export_command.py index 6b6585ac442..17685355a1f 100644 --- a/superset/commands/sql_lab/streaming_export_command.py +++ b/superset/commands/sql_lab/streaming_export_command.py @@ -87,12 +87,12 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand): status=403, ) from ex - def _get_sql_and_database(self) -> tuple[str, Any]: + def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]: """ - Get the SQL query and database for SQL Lab export. + Get the SQL query, database, catalog, and schema for SQL Lab export. Returns: - Tuple of (sql_query, database_object) + Tuple of (sql_query, database_object, catalog, schema) """ assert self._query is not None @@ -103,7 +103,7 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand): # Get the SQL query sql = select_sql or executed_sql - return sql, database + return sql, database, self._query.catalog, self._query.schema def _get_row_limit(self) -> int | None: """ diff --git a/superset/commands/streaming_export/base.py b/superset/commands/streaming_export/base.py index 1393cf2b193..39b7d233c34 100644 --- a/superset/commands/streaming_export/base.py +++ b/superset/commands/streaming_export/base.py @@ -79,12 +79,12 @@ class BaseStreamingCSVExportCommand(BaseCommand): self._current_app = app._get_current_object() @abstractmethod - def _get_sql_and_database(self) -> tuple[str, Any]: + def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]: """ - Get the SQL query and database for execution. + Get the SQL query, database, catalog, and schema for execution. Returns: - Tuple of (sql_query, database_object) + Tuple of (sql_query, database_object, catalog, schema) """ @abstractmethod @@ -150,7 +150,12 @@ class BaseStreamingCSVExportCommand(BaseCommand): yield remaining_data, row_count, data_bytes def _execute_query_and_stream( - self, sql: str, database: Any, limit: int | None + self, + sql: str, + database: Any, + limit: int | None, + catalog: str | None = None, + schema: str | None = None, ) -> Generator[str, None, None]: """Execute query with streaming and yield CSV chunks.""" start_time = time.time() @@ -160,8 +165,9 @@ class BaseStreamingCSVExportCommand(BaseCommand): # Merge database to prevent DetachedInstanceError merged_database = session.merge(database) - # Execute query with streaming - with merged_database.get_sqla_engine() as engine: + with merged_database.get_sqla_engine( + catalog=catalog, schema=schema + ) as engine: with engine.connect() as connection: result_proxy = connection.execution_options( stream_results=True @@ -209,7 +215,7 @@ class BaseStreamingCSVExportCommand(BaseCommand): """ # Load all needed data while session is still active # to avoid DetachedInstanceError - sql, database = self._get_sql_and_database() + sql, database, catalog, schema = self._get_sql_and_database() limit = self._get_row_limit() # Capture flask.g attributes to preserve request-scoped data # when the streaming generator runs in a new app context. @@ -222,7 +228,9 @@ class BaseStreamingCSVExportCommand(BaseCommand): with self._current_app.app_context(): with preserve_g_context(captured_g): try: - yield from self._execute_query_and_stream(sql, database, limit) + yield from self._execute_query_and_stream( + sql, database, limit, catalog, schema + ) except Exception as e: logger.error("Error in streaming CSV generator: %s", e) import traceback diff --git a/superset/models/core.py b/superset/models/core.py index 612ad6aaf37..1f99630ab3d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -468,13 +468,34 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"] with engine_context_manager(self, catalog, schema): with check_for_oauth2(self): - yield self._get_sqla_engine( + engine = self._get_sqla_engine( catalog=catalog, schema=schema, nullpool=nullpool, source=source, sqlalchemy_uri=sqlalchemy_uri, ) + prequeries = self.db_engine_spec.get_prequeries( + database=self, + catalog=catalog, + schema=schema, + ) + if prequeries: + # SQLAlchemy connect event: runs prequeries on every new + # DBAPI connection (e.g. SET search_path for PostgreSQL). + def run_prequeries( + dbapi_connection: Any, + connection_record: Any, # pylint: disable=unused-argument + ) -> None: + cursor = dbapi_connection.cursor() + try: + for prequery in prequeries: + cursor.execute(prequery) + finally: + cursor.close() + + sqla.event.listen(engine, "connect", run_prequeries) + yield engine def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901 self, @@ -583,15 +604,6 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: ) as engine: with check_for_oauth2(self): with closing(engine.raw_connection()) as conn: - # pre-session queries are used to set the selected catalog/schema - for prequery in self.db_engine_spec.get_prequeries( - database=self, - catalog=catalog, - schema=schema, - ): - cursor = conn.cursor() - cursor.execute(prequery) - yield conn def get_default_catalog(self) -> str | None: diff --git a/tests/unit_tests/commands/chart/streaming_export_command_test.py b/tests/unit_tests/commands/chart/streaming_export_command_test.py index 6096eaf5d7c..fc12da53322 100644 --- a/tests/unit_tests/commands/chart/streaming_export_command_test.py +++ b/tests/unit_tests/commands/chart/streaming_export_command_test.py @@ -25,7 +25,10 @@ from superset.commands.chart.data.streaming_export_command import ( def _setup_chart_mocks( - mocker: MockerFixture, sql: str = "SELECT * FROM test" + mocker: MockerFixture, + sql: str = "SELECT * FROM test", + catalog: str | None = None, + schema: str | None = None, ) -> tuple[MockerFixture, MockerFixture, MockerFixture]: """Set up common mocks for chart streaming export tests.""" mock_db = mocker.patch("superset.commands.streaming_export.base.db") @@ -36,6 +39,8 @@ def _setup_chart_mocks( datasource = mocker.MagicMock() datasource.get_query_str.return_value = sql datasource.database = mocker.MagicMock() + datasource.catalog = catalog + datasource.schema = schema query_context.datasource = datasource query_context.queries = [mocker.MagicMock()] mock_session.merge.return_value = datasource.database @@ -256,3 +261,38 @@ def test_empty_result_set(mocker: MockerFixture) -> None: lines = [line.strip() for line in csv_data.strip().split("\n")] assert len(lines) == 1 assert lines[0] == "col1,col2" + + +def test_catalog_and_schema_passed_to_engine(mocker: MockerFixture) -> None: + """Test that catalog and schema are forwarded to get_sqla_engine. + + Prequeries (e.g. SET search_path for PostgreSQL) are now run automatically + via a connect event listener registered inside get_sqla_engine, not by the + streaming command itself. + """ + mock_db, query_context, datasource = _setup_chart_mocks( + mocker, catalog="my_catalog", schema="my_schema" + ) + + mock_result = mocker.MagicMock() + mock_result.keys.return_value = ["col1"] + mock_result.fetchmany.side_effect = [[("val",)], []] + + mock_connection = mocker.MagicMock() + mock_connection.execution_options.return_value.execute.return_value = mock_result + mock_connection.__enter__.return_value = mock_connection + mock_connection.__exit__.return_value = None + + mock_engine = mocker.MagicMock() + mock_engine.connect.return_value = mock_connection + datasource.database.get_sqla_engine.return_value.__enter__.return_value = ( + mock_engine + ) + + command = StreamingCSVExportCommand(query_context) + list(command.run()()) + + datasource.database.get_sqla_engine.assert_called_once_with( + catalog="my_catalog", + schema="my_schema", + ) diff --git a/tests/unit_tests/commands/sql_lab/streaming_export_command_test.py b/tests/unit_tests/commands/sql_lab/streaming_export_command_test.py index 5c7d4ac2d48..d0d881b20a2 100644 --- a/tests/unit_tests/commands/sql_lab/streaming_export_command_test.py +++ b/tests/unit_tests/commands/sql_lab/streaming_export_command_test.py @@ -55,6 +55,8 @@ def mock_query(): query.select_sql = None query.executed_sql = "SELECT * FROM test_table" query.limiting_factor = LimitingFactor.NOT_LIMITED + query.catalog = None + query.schema = "public" query.database = MagicMock() query.database.db_engine_spec = MagicMock() query.database.db_engine_spec.engine = "postgresql" @@ -538,3 +540,40 @@ def test_null_values_handling(mocker, mock_query): assert "1,,100" in csv_data assert "2,test," in csv_data assert ",," in csv_data + + +def test_catalog_and_schema_passed_to_engine(mocker, mock_query, mock_result_proxy): + """Test that catalog and schema are forwarded to get_sqla_engine. + + Prequeries (e.g. SET search_path for PostgreSQL) are now run automatically + via a connect event listener registered inside get_sqla_engine, not by the + streaming command itself. + """ + mock_query.select_sql = "SELECT * FROM test" + mock_query.catalog = "my_catalog" + mock_query.schema = "my_schema" + + mock_db, mock_session = _setup_sqllab_mocks(mocker, mock_query) + + mock_connection = MagicMock() + mock_connection.execution_options.return_value.execute.return_value = ( + mock_result_proxy + ) + mock_connection.__enter__.return_value = mock_connection + mock_connection.__exit__.return_value = None + + mock_engine = MagicMock() + mock_engine.connect.return_value = mock_connection + mock_query.database.get_sqla_engine.return_value.__enter__.return_value = ( + mock_engine + ) + + command = StreamingSqlResultExportCommand("test_client_123") + command.validate() + + list(command.run()()) + + mock_query.database.get_sqla_engine.assert_called_once_with( + catalog="my_catalog", + schema="my_schema", + ) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 23d62218700..31df840e445 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -17,6 +17,7 @@ # pylint: disable=import-outside-toplevel from datetime import datetime +from typing import Any, Callable import pytest from flask import current_app @@ -261,21 +262,6 @@ def test_table_column_database() -> None: assert TableColumn(database=database).database is database -def test_get_prequeries(mocker: MockerFixture) -> None: - """ - Tests for ``get_prequeries``. - """ - mocker.patch.object(Database, "get_sqla_engine") - db_engine_spec = mocker.patch.object(Database, "db_engine_spec") - db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"] - - database = Database(database_name="db") - with database.get_raw_connection() as conn: - conn.cursor().execute.assert_has_calls( - [mocker.call("set a=1"), mocker.call("set b=2")] - ) - - def test_catalog_cache() -> None: """ Test the catalog cache. @@ -634,6 +620,142 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None ) +def test_get_sqla_engine_registers_prequery_event_listener( + app_context: None, + mocker: MockerFixture, +) -> None: + """ + Test that get_sqla_engine registers a connect event listener for prequeries. + + Engines returned by get_sqla_engine must automatically execute prequeries + (e.g. SET search_path) on every new connection, so that callers don't need + to remember to call get_prequeries() themselves. + """ + + mock_engine = mocker.MagicMock() + mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine) + db_engine_spec = mocker.patch.object(Database, "db_engine_spec") + db_engine_spec.get_prequeries.return_value = ['SET search_path = "my_schema"'] + event_listen = mocker.patch("superset.models.core.sqla.event.listen") + + database = Database(database_name="my_db", sqlalchemy_uri="postgresql://") + with database.get_sqla_engine(catalog="my_catalog", schema="my_schema"): + pass + + db_engine_spec.get_prequeries.assert_called_once_with( + database=database, + catalog="my_catalog", + schema="my_schema", + ) + event_listen.assert_called_once_with(mock_engine, "connect", mocker.ANY) + + # Call the captured closure directly to verify cursor create → execute → close. + captured_fn = event_listen.call_args[0][2] + mock_dbapi_conn = mocker.MagicMock() + mock_cursor = mocker.MagicMock() + mock_dbapi_conn.cursor.return_value = mock_cursor + captured_fn(mock_dbapi_conn, None) + mock_cursor.execute.assert_called_once_with('SET search_path = "my_schema"') + mock_cursor.close.assert_called_once() + + +def test_get_sqla_engine_prequery_cursor_closed_on_exception( + app_context: None, + mocker: MockerFixture, +) -> None: + """ + Test that the cursor is always closed even when a prequery raises. + """ + mock_engine = mocker.MagicMock() + mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine) + db_engine_spec = mocker.patch.object(Database, "db_engine_spec") + db_engine_spec.get_prequeries.return_value = ['SET search_path = "bad_schema"'] + event_listen = mocker.patch("superset.models.core.sqla.event.listen") + + database = Database(database_name="my_db", sqlalchemy_uri="postgresql://") + with database.get_sqla_engine(catalog=None, schema="bad_schema"): + pass + + captured_fn = event_listen.call_args[0][2] + mock_dbapi_conn = mocker.MagicMock() + mock_cursor = mocker.MagicMock() + mock_cursor.execute.side_effect = Exception("invalid schema") + mock_dbapi_conn.cursor.return_value = mock_cursor + + with pytest.raises(Exception, match="invalid schema"): + captured_fn(mock_dbapi_conn, None) + + mock_cursor.close.assert_called_once() + + +def test_get_sqla_engine_no_prequeries_no_event_listener( + app_context: None, + mocker: MockerFixture, +) -> None: + """ + Test that get_sqla_engine does not register an event listener when there + are no prequeries. + """ + mock_engine = mocker.MagicMock() + mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine) + db_engine_spec = mocker.patch.object(Database, "db_engine_spec") + db_engine_spec.get_prequeries.return_value = [] + event_listen = mocker.patch("superset.models.core.sqla.event.listen") + + database = Database(database_name="my_db", sqlalchemy_uri="postgresql://") + with database.get_sqla_engine(catalog=None, schema=None): + pass + + event_listen.assert_not_called() + + +def test_get_raw_connection_executes_prequeries_exactly_once( + app_context: None, + mocker: MockerFixture, +) -> None: + """ + Test that get_raw_connection() runs prequeries exactly once through the + connect event listener registered by get_sqla_engine(). + + Previously get_raw_connection() had its own manual prequery loop AND + called get_sqla_engine() (which registers the listener), so prequeries + ran twice. After removing the manual loop the listener is the sole + execution point — this test proves exactly-once semantics. + """ + mock_engine = mocker.MagicMock() + mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine) + db_engine_spec = mocker.patch.object(Database, "db_engine_spec") + prequery = 'SET search_path = "my_schema"' + db_engine_spec.get_prequeries.return_value = [prequery] + + # Capture the closure registered via sqla.event.listen. + captured_listeners: list[Callable[..., None]] = [] + original_listen = mocker.patch("superset.models.core.sqla.event.listen") + original_listen.side_effect = lambda engine, event, fn: captured_listeners.append( + fn + ) + + # Simulate SQLAlchemy firing the "connect" event when raw_connection() is called. + mock_dbapi_conn = mocker.MagicMock() + mock_cursor = mocker.MagicMock() + mock_dbapi_conn.cursor.return_value = mock_cursor + + def raw_connection_side_effect() -> Any: + for listener in captured_listeners: + listener(mock_dbapi_conn, None) + return mock_dbapi_conn + + mock_engine.raw_connection.side_effect = raw_connection_side_effect + + database = Database(database_name="my_db", sqlalchemy_uri="postgresql://") + with database.get_raw_connection(schema="my_schema"): + pass + + # Exactly one prequery, exactly once — not twice, not zero. + mock_cursor.execute.assert_called_once_with(prequery) + mock_cursor.close.assert_called_once() + + def test_is_oauth2_enabled() -> None: """ Test the `is_oauth2_enabled` method.