mirror of
https://github.com/apache/superset.git
synced 2026-05-21 15:55:10 +00:00
fix(sqllab): execute prequeries on streaming connection to fix PostgreSQL CSV export (#40194)
Co-authored-by: Matt Fitzgerald <matt.fitzgerald@preset.io> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -55,20 +55,23 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
|
||||
"""Validate permissions and query context."""
|
||||
self._query_context.raise_for_access()
|
||||
|
||||
def _get_sql_and_database(self) -> tuple[str, Any]:
|
||||
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
|
||||
"""
|
||||
Get the SQL query and database for chart export.
|
||||
Get the SQL query, database, catalog, and schema for chart export.
|
||||
|
||||
Returns:
|
||||
Tuple of (sql_query, database_object)
|
||||
Tuple of (sql_query, database_object, catalog, schema)
|
||||
"""
|
||||
# Get datasource and generate SQL query
|
||||
# Note: datasource should already be attached to a session from query_context
|
||||
datasource = self._query_context.datasource
|
||||
query_obj = self._query_context.queries[0]
|
||||
sql_query = datasource.get_query_str(query_obj.to_dict())
|
||||
database = getattr(datasource, "database", None)
|
||||
catalog = getattr(datasource, "catalog", None)
|
||||
schema = getattr(datasource, "schema", None)
|
||||
|
||||
return sql_query, getattr(datasource, "database", None)
|
||||
return sql_query, database, catalog, schema
|
||||
|
||||
def _get_row_limit(self) -> int | None:
|
||||
"""
|
||||
|
||||
@@ -87,12 +87,12 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand):
|
||||
status=403,
|
||||
) from ex
|
||||
|
||||
def _get_sql_and_database(self) -> tuple[str, Any]:
|
||||
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
|
||||
"""
|
||||
Get the SQL query and database for SQL Lab export.
|
||||
Get the SQL query, database, catalog, and schema for SQL Lab export.
|
||||
|
||||
Returns:
|
||||
Tuple of (sql_query, database_object)
|
||||
Tuple of (sql_query, database_object, catalog, schema)
|
||||
"""
|
||||
assert self._query is not None
|
||||
|
||||
@@ -103,7 +103,7 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand):
|
||||
# Get the SQL query
|
||||
sql = select_sql or executed_sql
|
||||
|
||||
return sql, database
|
||||
return sql, database, self._query.catalog, self._query.schema
|
||||
|
||||
def _get_row_limit(self) -> int | None:
|
||||
"""
|
||||
|
||||
@@ -79,12 +79,12 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
self._current_app = app._get_current_object()
|
||||
|
||||
@abstractmethod
|
||||
def _get_sql_and_database(self) -> tuple[str, Any]:
|
||||
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
|
||||
"""
|
||||
Get the SQL query and database for execution.
|
||||
Get the SQL query, database, catalog, and schema for execution.
|
||||
|
||||
Returns:
|
||||
Tuple of (sql_query, database_object)
|
||||
Tuple of (sql_query, database_object, catalog, schema)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -150,7 +150,12 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
yield remaining_data, row_count, data_bytes
|
||||
|
||||
def _execute_query_and_stream(
|
||||
self, sql: str, database: Any, limit: int | None
|
||||
self,
|
||||
sql: str,
|
||||
database: Any,
|
||||
limit: int | None,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Execute query with streaming and yield CSV chunks."""
|
||||
start_time = time.time()
|
||||
@@ -160,8 +165,9 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
# Merge database to prevent DetachedInstanceError
|
||||
merged_database = session.merge(database)
|
||||
|
||||
# Execute query with streaming
|
||||
with merged_database.get_sqla_engine() as engine:
|
||||
with merged_database.get_sqla_engine(
|
||||
catalog=catalog, schema=schema
|
||||
) as engine:
|
||||
with engine.connect() as connection:
|
||||
result_proxy = connection.execution_options(
|
||||
stream_results=True
|
||||
@@ -209,7 +215,7 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
"""
|
||||
# Load all needed data while session is still active
|
||||
# to avoid DetachedInstanceError
|
||||
sql, database = self._get_sql_and_database()
|
||||
sql, database, catalog, schema = self._get_sql_and_database()
|
||||
limit = self._get_row_limit()
|
||||
# Capture flask.g attributes to preserve request-scoped data
|
||||
# when the streaming generator runs in a new app context.
|
||||
@@ -222,7 +228,9 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
with self._current_app.app_context():
|
||||
with preserve_g_context(captured_g):
|
||||
try:
|
||||
yield from self._execute_query_and_stream(sql, database, limit)
|
||||
yield from self._execute_query_and_stream(
|
||||
sql, database, limit, catalog, schema
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error in streaming CSV generator: %s", e)
|
||||
import traceback
|
||||
|
||||
@@ -468,13 +468,34 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
|
||||
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
|
||||
with engine_context_manager(self, catalog, schema):
|
||||
with check_for_oauth2(self):
|
||||
yield self._get_sqla_engine(
|
||||
engine = self._get_sqla_engine(
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
nullpool=nullpool,
|
||||
source=source,
|
||||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
)
|
||||
prequeries = self.db_engine_spec.get_prequeries(
|
||||
database=self,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
)
|
||||
if prequeries:
|
||||
# SQLAlchemy connect event: runs prequeries on every new
|
||||
# DBAPI connection (e.g. SET search_path for PostgreSQL).
|
||||
def run_prequeries(
    dbapi_connection: Any,
    connection_record: Any,  # pylint: disable=unused-argument
) -> None:
    """Run each prequery, in order, on the given DBAPI connection.

    A single cursor is opened for all statements and is closed afterwards,
    including when one of the statements raises; the exception propagates.
    """
    with closing(dbapi_connection.cursor()) as prequery_cursor:
        for statement in prequeries:
            prequery_cursor.execute(statement)
|
||||
|
||||
sqla.event.listen(engine, "connect", run_prequeries)
|
||||
yield engine
|
||||
|
||||
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
|
||||
self,
|
||||
@@ -583,15 +604,6 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
|
||||
) as engine:
|
||||
with check_for_oauth2(self):
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
# pre-session queries are used to set the selected catalog/schema
|
||||
for prequery in self.db_engine_spec.get_prequeries(
|
||||
database=self,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(prequery)
|
||||
|
||||
yield conn
|
||||
|
||||
def get_default_catalog(self) -> str | None:
|
||||
|
||||
@@ -25,7 +25,10 @@ from superset.commands.chart.data.streaming_export_command import (
|
||||
|
||||
|
||||
def _setup_chart_mocks(
|
||||
mocker: MockerFixture, sql: str = "SELECT * FROM test"
|
||||
mocker: MockerFixture,
|
||||
sql: str = "SELECT * FROM test",
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> tuple[MockerFixture, MockerFixture, MockerFixture]:
|
||||
"""Set up common mocks for chart streaming export tests."""
|
||||
mock_db = mocker.patch("superset.commands.streaming_export.base.db")
|
||||
@@ -36,6 +39,8 @@ def _setup_chart_mocks(
|
||||
datasource = mocker.MagicMock()
|
||||
datasource.get_query_str.return_value = sql
|
||||
datasource.database = mocker.MagicMock()
|
||||
datasource.catalog = catalog
|
||||
datasource.schema = schema
|
||||
query_context.datasource = datasource
|
||||
query_context.queries = [mocker.MagicMock()]
|
||||
mock_session.merge.return_value = datasource.database
|
||||
@@ -256,3 +261,38 @@ def test_empty_result_set(mocker: MockerFixture) -> None:
|
||||
lines = [line.strip() for line in csv_data.strip().split("\n")]
|
||||
assert len(lines) == 1
|
||||
assert lines[0] == "col1,col2"
|
||||
|
||||
|
||||
def test_catalog_and_schema_passed_to_engine(mocker: MockerFixture) -> None:
    """Chart export opens its engine with the datasource's catalog and schema.

    The streaming command does not run prequeries (e.g. SET search_path for
    PostgreSQL) itself; get_sqla_engine registers a connect event listener
    that runs them, so only the forwarding of catalog and schema is checked.
    """
    _, query_context, datasource = _setup_chart_mocks(
        mocker, catalog="my_catalog", schema="my_schema"
    )
    get_engine = datasource.database.get_sqla_engine

    # Result proxy: one single-row batch followed by an empty batch.
    rows = mocker.MagicMock()
    rows.keys.return_value = ["col1"]
    rows.fetchmany.side_effect = [[("val",)], []]

    # Connection that works as a context manager yielding itself.
    conn = mocker.MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execution_options.return_value.execute.return_value = rows

    engine = mocker.MagicMock()
    engine.connect.return_value = conn
    get_engine.return_value.__enter__.return_value = engine

    export = StreamingCSVExportCommand(query_context)
    # run() returns a callable; drain what it yields so the query executes.
    list(export.run()())

    get_engine.assert_called_once_with(catalog="my_catalog", schema="my_schema")
|
||||
|
||||
@@ -55,6 +55,8 @@ def mock_query():
|
||||
query.select_sql = None
|
||||
query.executed_sql = "SELECT * FROM test_table"
|
||||
query.limiting_factor = LimitingFactor.NOT_LIMITED
|
||||
query.catalog = None
|
||||
query.schema = "public"
|
||||
query.database = MagicMock()
|
||||
query.database.db_engine_spec = MagicMock()
|
||||
query.database.db_engine_spec.engine = "postgresql"
|
||||
@@ -538,3 +540,40 @@ def test_null_values_handling(mocker, mock_query):
|
||||
assert "1,,100" in csv_data
|
||||
assert "2,test," in csv_data
|
||||
assert ",," in csv_data
|
||||
|
||||
|
||||
def test_catalog_and_schema_passed_to_engine(mocker, mock_query, mock_result_proxy):
    """SQL Lab export opens its engine with the query's catalog and schema.

    The streaming command does not run prequeries (e.g. SET search_path for
    PostgreSQL) itself; get_sqla_engine registers a connect event listener
    that runs them, so only the forwarding of catalog and schema is checked.
    """
    mock_query.select_sql = "SELECT * FROM test"
    mock_query.catalog = "my_catalog"
    mock_query.schema = "my_schema"

    _db, _session = _setup_sqllab_mocks(mocker, mock_query)
    get_engine = mock_query.database.get_sqla_engine

    # Connection that works as a context manager yielding itself.
    conn = MagicMock()
    conn.__enter__.return_value = conn
    conn.__exit__.return_value = None
    conn.execution_options.return_value.execute.return_value = mock_result_proxy

    engine = MagicMock()
    engine.connect.return_value = conn
    get_engine.return_value.__enter__.return_value = engine

    export = StreamingSqlResultExportCommand("test_client_123")
    export.validate()

    # run() returns a callable; drain what it yields so the query executes.
    list(export.run()())

    get_engine.assert_called_once_with(catalog="my_catalog", schema="my_schema")
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
@@ -261,21 +262,6 @@ def test_table_column_database() -> None:
|
||||
assert TableColumn(database=database).database is database
|
||||
|
||||
|
||||
def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Tests for ``get_prequeries``.
|
||||
"""
|
||||
mocker.patch.object(Database, "get_sqla_engine")
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
|
||||
|
||||
database = Database(database_name="db")
|
||||
with database.get_raw_connection() as conn:
|
||||
conn.cursor().execute.assert_has_calls(
|
||||
[mocker.call("set a=1"), mocker.call("set b=2")]
|
||||
)
|
||||
|
||||
|
||||
def test_catalog_cache() -> None:
|
||||
"""
|
||||
Test the catalog cache.
|
||||
@@ -634,6 +620,142 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None
|
||||
)
|
||||
|
||||
|
||||
def test_get_sqla_engine_registers_prequery_event_listener(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    get_sqla_engine must attach a "connect" listener that runs the prequeries.

    With the listener in place, callers of get_sqla_engine get prequeries
    (e.g. SET search_path) executed on new connections without having to call
    get_prequeries() themselves.
    """
    prequery = 'SET search_path = "my_schema"'

    engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=engine)
    spec = mocker.patch.object(Database, "db_engine_spec")
    spec.get_prequeries.return_value = [prequery]
    listen = mocker.patch("superset.models.core.sqla.event.listen")

    database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with database.get_sqla_engine(catalog="my_catalog", schema="my_schema"):
        pass

    spec.get_prequeries.assert_called_once_with(
        database=database,
        catalog="my_catalog",
        schema="my_schema",
    )
    listen.assert_called_once_with(engine, "connect", mocker.ANY)

    # Invoke the registered closure by hand: it should open a cursor,
    # execute the prequery, and close the cursor.
    listener = listen.call_args.args[2]
    cursor = mocker.MagicMock()
    dbapi_conn = mocker.MagicMock()
    dbapi_conn.cursor.return_value = cursor

    listener(dbapi_conn, None)

    cursor.execute.assert_called_once_with(prequery)
    cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sqla_engine_prequery_cursor_closed_on_exception(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    The prequery listener closes its cursor even when a prequery fails.
    """
    engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=engine)
    spec = mocker.patch.object(Database, "db_engine_spec")
    spec.get_prequeries.return_value = ['SET search_path = "bad_schema"']
    listen = mocker.patch("superset.models.core.sqla.event.listen")

    database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with database.get_sqla_engine(catalog=None, schema="bad_schema"):
        pass

    # Third positional argument to event.listen is the listener closure.
    listener = listen.call_args.args[2]

    failing_cursor = mocker.MagicMock()
    failing_cursor.execute.side_effect = Exception("invalid schema")
    dbapi_conn = mocker.MagicMock()
    dbapi_conn.cursor.return_value = failing_cursor

    with pytest.raises(Exception, match="invalid schema"):
        listener(dbapi_conn, None)

    failing_cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sqla_engine_no_prequeries_no_event_listener(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    An empty prequery list must not cause any event listener to be registered.
    """
    engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=engine)
    spec = mocker.patch.object(Database, "db_engine_spec")
    spec.get_prequeries.return_value = []
    listen = mocker.patch("superset.models.core.sqla.event.listen")

    database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with database.get_sqla_engine(catalog=None, schema=None):
        pass

    listen.assert_not_called()
|
||||
|
||||
|
||||
def test_get_raw_connection_executes_prequeries_exactly_once(
    app_context: None,
    mocker: MockerFixture,
) -> None:
    """
    get_raw_connection() must run each prequery once, via the connect listener.

    The listener registered by get_sqla_engine() is the only place prequeries
    are executed; this test simulates the "connect" event firing when
    raw_connection() is called and checks for exactly one execution, guarding
    against get_raw_connection() also running them itself.
    """
    prequery = 'SET search_path = "my_schema"'

    engine = mocker.MagicMock()
    mocker.patch.object(Database, "_get_sqla_engine", return_value=engine)
    spec = mocker.patch.object(Database, "db_engine_spec")
    spec.get_prequeries.return_value = [prequery]

    # Record every closure handed to sqla.event.listen instead of
    # registering it with SQLAlchemy.
    listeners: list[Callable[..., None]] = []

    def record_listener(target: Any, identifier: Any, fn: Any) -> None:
        listeners.append(fn)

    mocker.patch(
        "superset.models.core.sqla.event.listen",
        side_effect=record_listener,
    )

    cursor = mocker.MagicMock()
    dbapi_conn = mocker.MagicMock()
    dbapi_conn.cursor.return_value = cursor

    def fire_connect_event() -> Any:
        # Stand-in for SQLAlchemy firing "connect" on raw_connection().
        for fn in listeners:
            fn(dbapi_conn, None)
        return dbapi_conn

    engine.raw_connection.side_effect = fire_connect_event

    database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
    with database.get_raw_connection(schema="my_schema"):
        pass

    # Exactly one prequery, exactly once — not twice, not zero.
    cursor.execute.assert_called_once_with(prequery)
    cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_is_oauth2_enabled() -> None:
|
||||
"""
|
||||
Test the `is_oauth2_enabled` method.
|
||||
|
||||
Reference in New Issue
Block a user