refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)

This commit is contained in:
John Bodley
2024-02-14 06:20:15 +13:00
committed by GitHub
parent 827864b939
commit 847ed3f5b0
96 changed files with 656 additions and 730 deletions

View File

@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import security_manager
from superset import db, security_manager
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.exceptions import ImportFailedError
from superset.connectors.sqla.models import Database, SqlaTable
@@ -82,7 +82,7 @@ def test_import_chart(mocker: MockFixture, session_with_schema: Session) -> None
config["datasource_id"] = 1
config["datasource_type"] = "table"
chart = import_chart(session_with_schema, config)
chart = import_chart(config)
assert chart.slice_name == "Deck Path"
assert chart.viz_type == "deck_path"
assert chart.is_managed_externally is False
@@ -106,7 +106,7 @@ def test_import_chart_managed_externally(
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_chart"
chart = import_chart(session_with_schema, config)
chart = import_chart(config)
assert chart.is_managed_externally is True
assert chart.external_url == "https://example.org/my_chart"
@@ -128,7 +128,7 @@ def test_import_chart_without_permission(
config["datasource_type"] = "table"
with pytest.raises(ImportFailedError) as excinfo:
import_chart(session_with_schema, config)
import_chart(config)
assert (
str(excinfo.value)
== "Chart doesn't exist and user doesn't have permission to create charts"
@@ -173,7 +173,7 @@ def test_import_existing_chart_without_permission(
with override_user("admin"):
with pytest.raises(ImportFailedError) as excinfo:
import_chart(session_with_data, chart_config, overwrite=True)
import_chart(chart_config, overwrite=True)
assert (
str(excinfo.value)
== "A chart already exists and user doesn't have permissions to overwrite it"
@@ -213,7 +213,7 @@ def test_import_existing_chart_with_permission(
)
with override_user(admin):
import_chart(session_with_data, config, overwrite=True)
import_chart(config, overwrite=True)
# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
security_manager.can_access_chart.assert_called_once_with(slice)

View File

@@ -48,7 +48,7 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
from superset.daos.chart import ChartDAO
from superset.models.slice import Slice
result = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
result = ChartDAO.find_by_id(1, skip_base_filter=True)
assert result
assert 1 == result.id
@@ -57,20 +57,18 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
def test_datasource_find_by_id_skip_base_filter_not_found(
session_with_data: Session,
session: Session,
) -> None:
from superset.daos.chart import ChartDAO
result = ChartDAO.find_by_id(
125326326, session=session_with_data, skip_base_filter=True
)
result = ChartDAO.find_by_id(125326326, skip_base_filter=True)
assert result is None
def test_add_favorite(session_with_data: Session) -> None:
def test_add_favorite(session: Session) -> None:
from superset.daos.chart import ChartDAO
chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
chart = ChartDAO.find_by_id(1, skip_base_filter=True)
if not chart:
return
assert len(ChartDAO.favorited_ids([chart])) == 0
@@ -82,10 +80,10 @@ def test_add_favorite(session_with_data: Session) -> None:
assert len(ChartDAO.favorited_ids([chart])) == 1
def test_remove_favorite(session_with_data: Session) -> None:
def test_remove_favorite(session: Session) -> None:
from superset.daos.chart import ChartDAO
chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
chart = ChartDAO.find_by_id(1, skip_base_filter=True)
if not chart:
return
assert len(ChartDAO.favorited_ids([chart])) == 0

View File

@@ -1965,12 +1965,13 @@ def test_apply_post_process_json_format_data_is_none():
def test_apply_post_process_verbose_map(session: Session):
from superset import db
from superset.connectors.sqla.models import SqlaTable, SqlMetric
from superset.models.core import Database
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
@@ -1982,7 +1983,7 @@ def test_apply_post_process_verbose_map(session: Session):
expression="COUNT(*)",
)
],
database=db,
database=database,
)
result = {

View File

@@ -24,9 +24,10 @@ def test_column_model(session: Session) -> None:
"""
Test basic attributes of a ``Column``.
"""
from superset import db
from superset.columns.models import Column
engine = session.get_bind()
engine = db.session.get_bind()
Column.metadata.create_all(engine) # pylint: disable=no-member
column = Column(
@@ -35,8 +36,8 @@ def test_column_model(session: Session) -> None:
expression="ds",
)
session.add(column)
session.flush()
db.session.add(column)
db.session.flush()
assert column.id == 1
assert column.uuid is not None

View File

@@ -35,14 +35,14 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
"""
Test that all new assets are imported correctly.
"""
from superset import security_manager
from superset import db, security_manager
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.models.dashboard import dashboard_slices
from superset.models.slice import Slice
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Slice.metadata.create_all(engine) # pylint: disable=no-member
configs = {
**copy.deepcopy(databases_config),
@@ -53,11 +53,11 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
expected_number_of_dashboards = len(dashboards_config_1)
expected_number_of_charts = len(charts_config_1)
ImportAssetsCommand._import(session, configs)
dashboard_ids = session.scalars(
ImportAssetsCommand._import(configs)
dashboard_ids = db.session.scalars(
select(dashboard_slices.c.dashboard_id).distinct()
).all()
chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
assert len(chart_ids) == expected_number_of_charts
assert len(dashboard_ids) == expected_number_of_dashboards
@@ -67,14 +67,14 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
"""
Test that existing dashboards are updated with new charts.
"""
from superset import security_manager
from superset import db, security_manager
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.models.dashboard import dashboard_slices
from superset.models.slice import Slice
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Slice.metadata.create_all(engine) # pylint: disable=no-member
base_configs = {
**copy.deepcopy(databases_config),
@@ -91,12 +91,12 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
expected_number_of_dashboards = len(dashboards_config_1)
expected_number_of_charts = len(charts_config_1)
ImportAssetsCommand._import(session, base_configs)
ImportAssetsCommand._import(session, new_configs)
dashboard_ids = session.scalars(
ImportAssetsCommand._import(base_configs)
ImportAssetsCommand._import(new_configs)
dashboard_ids = db.session.scalars(
select(dashboard_slices.c.dashboard_id).distinct()
).all()
chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
assert len(chart_ids) == expected_number_of_charts
assert len(dashboard_ids) == expected_number_of_dashboards
@@ -106,14 +106,14 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
"""
Test that existing dashboards are updated without old charts.
"""
from superset import security_manager
from superset import db, security_manager
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.models.dashboard import dashboard_slices
from superset.models.slice import Slice
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Slice.metadata.create_all(engine) # pylint: disable=no-member
base_configs = {
**copy.deepcopy(databases_config),
@@ -130,12 +130,12 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
expected_number_of_dashboards = len(dashboards_config_2)
expected_number_of_charts = len(charts_config_2)
ImportAssetsCommand._import(session, base_configs)
ImportAssetsCommand._import(session, new_configs)
dashboard_ids = session.scalars(
ImportAssetsCommand._import(base_configs)
ImportAssetsCommand._import(new_configs)
dashboard_ids = db.session.scalars(
select(dashboard_slices.c.dashboard_id).distinct()
).all()
chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
assert len(chart_ids) == expected_number_of_charts
assert len(dashboard_ids) == expected_number_of_dashboards

View File

@@ -23,6 +23,8 @@ import pytest
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
from superset import db
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@@ -81,7 +83,7 @@ def test_table(session: Session) -> "SqlaTable":
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.models.core import Database
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
columns = [

View File

@@ -41,7 +41,7 @@ from superset.initialization import SupersetAppInitializer
@pytest.fixture
def get_session(mocker: MockFixture) -> Callable[[], Session]:
"""
Create an in-memory SQLite session to test models.
Create an in-memory SQLite db.session.to test models.
"""
engine = create_engine("sqlite://")
@@ -49,7 +49,7 @@ def get_session(mocker: MockFixture) -> Callable[[], Session]:
Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name
in_memory_session = Session_()
# flask calls session.remove()
# flask calls db.session.remove()
in_memory_session.remove = lambda: None
# patch session

View File

@@ -27,6 +27,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
In particular, allow datasets with the same name in the same database as long as they
are in different schemas
"""
from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
@@ -46,8 +47,8 @@ def test_validate_update_uniqueness(session: Session) -> None:
schema="dev",
database=database,
)
session.add_all([database, dataset1, dataset2])
session.flush()
db.session.add_all([database, dataset1, dataset2])
db.session.flush()
# same table name, different schema
assert (

View File

@@ -25,17 +25,18 @@ from superset.exceptions import QueryNotFoundException, SupersetCancelQueryExcep
def test_query_dao_save_metadata(session: Session) -> None:
from superset import db
from superset.models.core import Database
from superset.models.sql_lab import Query
engine = session.get_bind()
engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -48,30 +49,31 @@ def test_query_dao_save_metadata(session: Session) -> None:
results_key="abc",
)
session.add(db)
session.add(query_obj)
db.session.add(database)
db.session.add(query_obj)
from superset.daos.query import QueryDAO
query = session.query(Query).one()
query = db.session.query(Query).one()
QueryDAO.save_metadata(query=query, payload={"columns": []})
assert query.extra.get("columns", None) == []
def test_query_dao_get_queries_changed_after(session: Session) -> None:
from superset import db
from superset.models.core import Database
from superset.models.sql_lab import Query
engine = session.get_bind()
engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
now = datetime.utcnow()
old_query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -87,7 +89,7 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
updated_query_obj = Query(
client_id="updated_foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from foo",
@@ -101,9 +103,9 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
changed_on=now - timedelta(days=1),
)
session.add(db)
session.add(old_query_obj)
session.add(updated_query_obj)
db.session.add(database)
db.session.add(old_query_obj)
db.session.add(updated_query_obj)
from superset.daos.query import QueryDAO
@@ -116,18 +118,19 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
def test_query_dao_stop_query_not_found(
mocker: MockFixture, app: Any, session: Session
) -> None:
from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
engine = session.get_bind()
engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -141,8 +144,8 @@ def test_query_dao_stop_query_not_found(
status=QueryStatus.RUNNING,
)
session.add(db)
session.add(query_obj)
db.session.add(database)
db.session.add(query_obj)
mocker.patch("superset.sql_lab.cancel_query", return_value=False)
@@ -151,25 +154,26 @@ def test_query_dao_stop_query_not_found(
with pytest.raises(QueryNotFoundException):
QueryDAO.stop_query("foo2")
query = session.query(Query).one()
query = db.session.query(Query).one()
assert query.status == QueryStatus.RUNNING
def test_query_dao_stop_query_not_running(
mocker: MockFixture, app: Any, session: Session
) -> None:
from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
engine = session.get_bind()
engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -183,31 +187,32 @@ def test_query_dao_stop_query_not_running(
status=QueryStatus.FAILED,
)
session.add(db)
session.add(query_obj)
db.session.add(database)
db.session.add(query_obj)
from superset.daos.query import QueryDAO
QueryDAO.stop_query(query_obj.client_id)
query = session.query(Query).one()
query = db.session.query(Query).one()
assert query.status == QueryStatus.FAILED
def test_query_dao_stop_query_failed(
mocker: MockFixture, app: Any, session: Session
) -> None:
from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
engine = session.get_bind()
engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -221,8 +226,8 @@ def test_query_dao_stop_query_failed(
status=QueryStatus.RUNNING,
)
session.add(db)
session.add(query_obj)
db.session.add(database)
db.session.add(query_obj)
mocker.patch("superset.sql_lab.cancel_query", return_value=False)
@@ -231,23 +236,24 @@ def test_query_dao_stop_query_failed(
with pytest.raises(SupersetCancelQueryException):
QueryDAO.stop_query(query_obj.client_id)
query = session.query(Query).one()
query = db.session.query(Query).one()
assert query.status == QueryStatus.RUNNING
def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -> None:
from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
engine = session.get_bind()
engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -261,13 +267,13 @@ def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -
status=QueryStatus.RUNNING,
)
session.add(db)
session.add(query_obj)
db.session.add(database)
db.session.add(query_obj)
mocker.patch("superset.sql_lab.cancel_query", return_value=True)
from superset.daos.query import QueryDAO
QueryDAO.stop_query(query_obj.client_id)
query = session.query(Query).one()
query = db.session.query(Query).one()
assert query.status == QueryStatus.STOPPED

View File

@@ -70,7 +70,7 @@ def test_remove_user_favorite_tag(mocker):
# Check that users_favorited no longer contains the user
assert mock_user not in mock_tag.users_favorited
# Check that the session was committed
# Check that the db.session.was committed
mock_session.commit.assert_called_once()

View File

@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import security_manager
from superset import db, security_manager
from superset.commands.dashboard.importers.v1.utils import import_dashboard
from superset.commands.exceptions import ImportFailedError
from superset.models.dashboard import Dashboard
@@ -67,7 +67,7 @@ def test_import_dashboard(mocker: MockFixture, session_with_schema: Session) ->
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
dashboard = import_dashboard(session_with_schema, dashboard_config)
dashboard = import_dashboard(dashboard_config)
assert dashboard.dashboard_title == "Test dash"
assert dashboard.description is None
assert dashboard.is_managed_externally is False
@@ -88,8 +88,7 @@ def test_import_dashboard_managed_externally(
config = copy.deepcopy(dashboard_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_dashboard"
dashboard = import_dashboard(session_with_schema, config)
dashboard = import_dashboard(config)
assert dashboard.is_managed_externally is True
assert dashboard.external_url == "https://example.org/my_dashboard"
@@ -107,7 +106,7 @@ def test_import_dashboard_without_permission(
mocker.patch.object(security_manager, "can_access", return_value=False)
with pytest.raises(ImportFailedError) as excinfo:
import_dashboard(session_with_schema, dashboard_config)
import_dashboard(dashboard_config)
assert (
str(excinfo.value)
== "Dashboard doesn't exist and user doesn't have permission to create dashboards"
@@ -135,7 +134,7 @@ def test_import_existing_dashboard_without_permission(
with override_user("admin"):
with pytest.raises(ImportFailedError) as excinfo:
import_dashboard(session_with_data, dashboard_config, overwrite=True)
import_dashboard(dashboard_config, overwrite=True)
assert (
str(excinfo.value)
== "A dashboard already exists and user doesn't have permissions to overwrite it"
@@ -171,7 +170,8 @@ def test_import_existing_dashboard_with_permission(
)
with override_user(admin):
import_dashboard(session_with_data, dashboard_config, overwrite=True)
import_dashboard(dashboard_config, overwrite=True)
# Assert that the can write to dashboard was checked
security_manager.can_access.assert_called_once_with("can_write", "Dashboard")
security_manager.can_access_dashboard.assert_called_once_with(dashboard)

View File

@@ -42,12 +42,10 @@ def session_with_data(session: Session) -> Iterator[Session]:
session.rollback()
def test_add_favorite(session_with_data: Session) -> None:
def test_add_favorite(session: Session) -> None:
from superset.daos.dashboard import DashboardDAO
dashboard = DashboardDAO.find_by_id(
100, session=session_with_data, skip_base_filter=True
)
dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
if not dashboard:
return
assert len(DashboardDAO.favorited_ids([dashboard])) == 0
@@ -59,12 +57,10 @@ def test_add_favorite(session_with_data: Session) -> None:
assert len(DashboardDAO.favorited_ids([dashboard])) == 1
def test_remove_favorite(session_with_data: Session) -> None:
def test_remove_favorite(session: Session) -> None:
from superset.daos.dashboard import DashboardDAO
dashboard = DashboardDAO.find_by_id(
100, session=session_with_data, skip_base_filter=True
)
dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
if not dashboard:
return
assert len(DashboardDAO.favorited_ids([dashboard])) == 0

View File

@@ -28,6 +28,8 @@ from flask import current_app
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
def test_filter_by_uuid(
session: Session,
@@ -49,14 +51,14 @@ def test_filter_by_uuid(
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
session.add(
db.session.add(
Database(
database_name="my_db",
sqlalchemy_uri="sqlite://",
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
)
)
session.commit()
db.session.commit()
response = client.get(
"/api/v1/database/?q=(filters:!((col:uuid,opr:eq,value:"
@@ -96,7 +98,7 @@ def test_post_with_uuid(
payload = response.json
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
database = session.query(Database).one()
database = db.session.query(Database).one()
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
@@ -139,8 +141,8 @@ def test_password_mask(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -195,8 +197,8 @@ def test_database_connection(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -331,8 +333,8 @@ def test_update_with_password_mask(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
client.put(
"/api/v1/database/1",
@@ -347,7 +349,7 @@ def test_update_with_password_mask(
),
},
)
database = session.query(Database).one()
database = db.session.query(Database).one()
assert (
database.encrypted_extra
== '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}'
@@ -429,8 +431,8 @@ def test_delete_ssh_tunnel(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -446,8 +448,8 @@ def test_delete_ssh_tunnel(
database=database,
)
session.add(tunnel)
session.commit()
db.session.add(tunnel)
db.session.commit()
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
@@ -505,8 +507,8 @@ def test_delete_ssh_tunnel_not_found(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -522,8 +524,8 @@ def test_delete_ssh_tunnel_not_found(
database=database,
)
session.add(tunnel)
session.commit()
db.session.add(tunnel)
db.session.commit()
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
@@ -576,8 +578,8 @@ def test_apply_dynamic_database_filter(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# Create our Second Database
database = Database(
@@ -592,8 +594,8 @@ def test_apply_dynamic_database_filter(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")

View File

@@ -23,6 +23,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.commands.exceptions import ImportFailedError
@@ -37,11 +38,11 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
database = import_database(session, config)
database = import_database(config)
assert database.database_name == "imported_database"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"
assert database.cache_timeout is None
@@ -60,9 +61,9 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
# missing
config = copy.deepcopy(database_config)
del config["allow_dml"]
session.delete(database)
session.flush()
database = import_database(session, config)
db.session.delete(database)
db.session.flush()
database = import_database(config)
assert database.allow_dml is False
@@ -78,12 +79,12 @@ def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config_sqlite)
with pytest.raises(ImportFailedError) as excinfo:
_ = import_database(session, config)
_ = import_database(config)
assert (
str(excinfo.value)
== "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
@@ -106,14 +107,14 @@ def test_import_database_managed_externally(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_database"
database = import_database(session, config)
database = import_database(config)
assert database.is_managed_externally is True
assert database.external_url == "https://example.org/my_database"
@@ -132,13 +133,13 @@ def test_import_database_without_permission(
mocker.patch.object(security_manager, "can_access", return_value=False)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
with pytest.raises(ImportFailedError) as excinfo:
import_database(session, config)
import_database(config)
assert (
str(excinfo.value)
== "Database doesn't exist and user doesn't have permission to create databases"
@@ -156,10 +157,10 @@ def test_import_database_with_version(mocker: MockFixture, session: Session) ->
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
config["extra"]["version"] = "1.1.1"
database = import_database(session, config)
database = import_database(config)
assert json.loads(database.extra)["version"] == "1.1.1"

View File

@@ -30,19 +30,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
database_id=database.id,
database=database,
)
session.add(db)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()

View File

@@ -27,17 +27,17 @@ def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
"database_id": db.id,
"database_id": database.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = CreateSSHTunnelCommand(db, properties).run()
result = CreateSSHTunnelCommand(database, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
@@ -48,19 +48,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database": db,
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}
command = CreateSSHTunnelCommand(db, properties)
command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()

View File

@@ -31,19 +31,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
database_id=database.id,
database=database,
)
session.add(db)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()

View File

@@ -32,16 +32,18 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=database.id, database=database, server_address="Test"
)
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
session.add(db)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()

View File

@@ -25,11 +25,11 @@ def test_create_ssh_tunnel():
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
result = SSHTunnelDAO.create(
attributes={
"database_id": db.id,
"database_id": database.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",

View File

@@ -19,6 +19,8 @@ from typing import Any
from sqlalchemy.orm.session import Session
from superset import db
def test_put_invalid_dataset(
session: Session,
@@ -31,7 +33,7 @@ def test_put_invalid_dataset(
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
SqlaTable.metadata.create_all(session.get_bind())
SqlaTable.metadata.create_all(db.session.get_bind())
database = Database(
database_name="my_db",
@@ -41,8 +43,8 @@ def test_put_invalid_dataset(
table_name="test_put_invalid_dataset",
database=database,
)
session.add(dataset)
session.flush()
db.session.add(dataset)
db.session.flush()
response = client.put(
"/api/v1/dataset/1",

View File

@@ -20,6 +20,8 @@ import json
from sqlalchemy.orm.session import Session
from superset import db
def test_export(session: Session) -> None:
"""
@@ -29,12 +31,12 @@ def test_export(session: Session) -> None:
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.core import Database
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
columns = [
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),

View File

@@ -28,6 +28,7 @@ from flask import current_app
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.commands.dataset.exceptions import (
DatasetForbiddenDataURI,
ImportFailedError,
@@ -46,12 +47,12 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
config = {
@@ -108,7 +109,7 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
"database_id": database.id,
}
sqla_table = import_dataset(session, config)
sqla_table = import_dataset(config)
assert sqla_table.table_name == "my_table"
assert sqla_table.main_dttm_col == "ds"
assert sqla_table.description == "This is the description"
@@ -162,23 +163,23 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
dataset_uuid = uuid.uuid4()
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset = SqlaTable(
uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id
)
column = TableColumn(column_name="existing_column")
session.add(dataset)
session.add(column)
session.flush()
db.session.add(dataset)
db.session.add(column)
db.session.flush()
config = {
"table_name": dataset.table_name,
@@ -234,7 +235,7 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
"database_id": database.id,
}
sqla_table = import_dataset(session, config, overwrite=True)
sqla_table = import_dataset(config, overwrite=True)
assert sqla_table.table_name == dataset.table_name
assert sqla_table.main_dttm_col == "ds"
assert sqla_table.description == "This is the description"
@@ -288,12 +289,12 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -352,7 +353,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
sqla_table = import_dataset(session, dataset_config)
sqla_table = import_dataset(dataset_config)
assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
@@ -373,12 +374,12 @@ def test_import_dataset_extra_empty_string(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -417,7 +418,7 @@ def test_import_dataset_extra_empty_string(
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
sqla_table = import_dataset(session, dataset_config)
sqla_table = import_dataset(dataset_config)
assert sqla_table.extra == None
@@ -443,12 +444,12 @@ def test_import_column_allowed_data_url(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -495,9 +496,8 @@ def test_import_column_allowed_data_url(
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
_ = import_dataset(session, dataset_config, force_data=True)
session.connection()
assert [("value1",), ("value2",)] == session.execute(
_ = import_dataset(dataset_config, force_data=True)
assert [("value1",), ("value2",)] == db.session.execute(
"SELECT * FROM my_table"
).fetchall()
@@ -517,19 +517,19 @@ def test_import_dataset_managed_externally(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
config = copy.deepcopy(dataset_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_table"
config["database_id"] = database.id
sqla_table = import_dataset(session, config)
sqla_table = import_dataset(config)
assert sqla_table.is_managed_externally is True
assert sqla_table.external_url == "https://example.org/my_table"

View File

@@ -29,15 +29,15 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
session.add(db)
session.add(database)
session.add(sqla_table)
session.flush()
yield session
@@ -50,7 +50,6 @@ def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> N
result = DatasetDAO.find_by_id(
1,
session=session_with_data,
skip_base_filter=True,
)
@@ -67,7 +66,6 @@ def test_datasource_find_by_id_skip_base_filter_not_found(
result = DatasetDAO.find_by_id(
125326326,
session=session_with_data,
skip_base_filter=True,
)
assert result is None
@@ -79,7 +77,6 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) ->
result = DatasetDAO.find_by_ids(
[1, 125326326],
session=session_with_data,
skip_base_filter=True,
)
@@ -96,7 +93,6 @@ def test_datasource_find_by_ids_skip_base_filter_not_found(
result = DatasetDAO.find_by_ids(
[125326326, 125326326125326326],
session=session_with_data,
skip_base_filter=True,
)

View File

@@ -35,7 +35,7 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
columns = [
TableColumn(column_name="a", type="INTEGER"),
@@ -45,12 +45,12 @@ def session_with_data(session: Session) -> Iterator[Session]:
table_name="my_sqla_table",
columns=columns,
metrics=[],
database=db,
database=database,
)
query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -63,13 +63,13 @@ def session_with_data(session: Session) -> Iterator[Session]:
results_key="abc",
)
saved_query = SavedQuery(database=db, sql="select * from foo")
saved_query = SavedQuery(database=database, sql="select * from foo")
table = Table(
name="my_table",
schema="my_schema",
catalog="my_catalog",
database=db,
database=database,
columns=[],
)
@@ -93,7 +93,7 @@ FROM my_catalog.my_schema.my_table
session.add(table)
session.add(saved_query)
session.add(query_obj)
session.add(db)
session.add(database)
session.add(sqla_table)
session.flush()
yield session
@@ -190,7 +190,7 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
def test_get_all_datasources(session_with_data: Session) -> None:
from superset.connectors.sqla.models import SqlaTable
result = SqlaTable.get_all_datasources(session=session_with_data)
result = SqlaTable.get_all_datasources()
assert len(result) == 1

View File

@@ -74,10 +74,10 @@ def test_extras_without_ssl() -> None:
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.fixtures.database import default_db_extra
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DruidEngineSpec.get_extra_params(db)
database = mock.Mock()
database.extra = default_db_extra
database.server_cert = None
extras = DruidEngineSpec.get_extra_params(database)
assert "connect_args" not in extras["engine_params"]
@@ -86,10 +86,10 @@ def test_extras_with_ssl() -> None:
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(db)
database = mock.Mock()
database.extra = default_db_extra
database.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args

View File

@@ -50,8 +50,8 @@ def test_extras_without_ssl() -> None:
from superset.db_engine_specs.pinot import PinotEngineSpec as spec
from tests.integration_tests.fixtures.database import default_db_extra
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = spec.get_extra_params(db)
database = mock.Mock()
database.extra = default_db_extra
database.server_cert = None
extras = spec.get_extra_params(database)
assert "connect_args" not in extras["engine_params"]

View File

@@ -26,6 +26,7 @@ from sqlalchemy.engine import create_engine
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm.session import Session
from superset import db
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from tests.unit_tests.conftest import with_feature_flags
@@ -38,7 +39,7 @@ if TYPE_CHECKING:
def database1(session: Session) -> Iterator["Database"]:
from superset.models.core import Database
engine = session.connection().engine
engine = db.session.connection().engine
Database.metadata.create_all(engine) # pylint: disable=no-member
database = Database(
@@ -46,13 +47,13 @@ def database1(session: Session) -> Iterator["Database"]:
sqlalchemy_uri="sqlite:///database1.db",
allow_dml=True,
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
yield database
session.delete(database)
session.commit()
db.session.delete(database)
db.session.commit()
os.unlink("database1.db")
@@ -62,12 +63,12 @@ def table1(session: Session, database1: "Database") -> Iterator[None]:
conn = engine.connect()
conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)")
conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)")
session.commit()
db.session.commit()
yield
conn.execute("DROP TABLE table1")
session.commit()
db.session.commit()
@pytest.fixture
@@ -79,13 +80,13 @@ def database2(session: Session) -> Iterator["Database"]:
sqlalchemy_uri="sqlite:///database2.db",
allow_dml=False,
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
yield database
session.delete(database)
session.commit()
db.session.delete(database)
db.session.commit()
os.unlink("database2.db")
@@ -95,12 +96,12 @@ def table2(session: Session, database2: "Database") -> Iterator[None]:
conn = engine.connect()
conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)")
conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')")
session.commit()
db.session.commit()
yield
conn.execute("DROP TABLE table2")
session.commit()
db.session.commit()
@with_feature_flags(ENABLE_SUPERSET_META_DB=True)

View File

@@ -22,10 +22,10 @@ def test_column_attributes_on_query():
from superset.models.core import Database
from superset.models.sql_lab import Query
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
database=db,
database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",

View File

@@ -125,7 +125,7 @@ def test_sql_lab_insert_rls_as_subquery(
from superset.sql_lab import execute_sql_statement
from superset.utils.core import RowLevelSecurityFilterType
engine = session.connection().engine
engine = db.session.connection().engine
Query.metadata.create_all(engine) # pylint: disable=no-member
connection = engine.raw_connection()
@@ -143,8 +143,8 @@ def test_sql_lab_insert_rls_as_subquery(
limit=5,
select_as_cta_used=False,
)
session.add(query)
session.commit()
db.session.add(query)
db.session.commit()
admin = User(
first_name="Alice",
@@ -185,8 +185,8 @@ def test_sql_lab_insert_rls_as_subquery(
group_key=None,
clause="c > 5",
)
session.add(rls)
session.flush()
db.session.add(rls)
db.session.flush()
mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

View File

@@ -1759,8 +1759,7 @@ def test_get_rls_for_table(mocker: MockerFixture) -> None:
Tests for ``get_rls_for_table``.
"""
candidate = Identifier([Token(Name, "some_table")])
db = mocker.patch("superset.db")
dataset = db.session.query().filter().one_or_none()
dataset = mocker.patch("superset.db").session.query().filter().one_or_none()
dataset.__str__.return_value = "some_table"
dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]

View File

@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument
from sqlalchemy.orm.session import Session
from superset import db
def test_table_model(session: Session) -> None:
"""
@@ -28,7 +28,7 @@ def test_table_model(session: Session) -> None:
from superset.models.core import Database
from superset.tables.models import Table
engine = session.get_bind()
engine = db.session.get_bind()
Table.metadata.create_all(engine) # pylint: disable=no-member
table = Table(
@@ -44,8 +44,8 @@ def test_table_model(session: Session) -> None:
)
],
)
session.add(table)
session.flush()
db.session.add(table)
db.session.flush()
assert table.id == 1
assert table.uuid is not None

View File

@@ -18,6 +18,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.utils.core import DatasourceType
@@ -40,13 +41,15 @@ def session_with_data(session: Session):
slice_name="slice_name",
)
db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
columns = [
TableColumn(column_name="a", type="INTEGER"),
]
saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
saved_query = SavedQuery(
label="test_query", database=database, sql="select * from foo"
)
dashboard_obj = Dashboard(
id=100,
@@ -57,7 +60,7 @@ def session_with_data(session: Session):
)
session.add(slice_obj)
session.add(db)
session.add(database)
session.add(saved_query)
session.add(dashboard_obj)
session.commit()
@@ -74,9 +77,9 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
from superset.tags.models import ObjectType, TaggedObject
# Define a list of objects to tag
query = session_with_data.query(SavedQuery).first()
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()
query = db.session.query(SavedQuery).first()
chart = db.session.query(Slice).first()
dashboard = db.session.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -94,10 +97,10 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
for object_type, object_id in objects_to_tag:
assert (
session_with_data.query(TaggedObject)
db.session.query(TaggedObject)
.filter(
TaggedObject.object_type == object_type,
TaggedObject.object_id == object_id,
@@ -117,9 +120,9 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
from superset.tags.models import ObjectType, TaggedObject
# Define a list of objects to tag
query = session_with_data.query(SavedQuery).first()
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()
query = db.session.query(SavedQuery).first()
chart = db.session.query(Slice).first()
dashboard = db.session.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -136,10 +139,10 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": []}
).run()
assert len(session_with_data.query(TaggedObject).all()) == 0
assert len(db.session.query(TaggedObject).all()) == 0

View File

@@ -18,6 +18,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.utils.core import DatasourceType
@@ -41,7 +42,7 @@ def session_with_data(session: Session):
slice_name="slice_name",
)
db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
columns = [
TableColumn(column_name="a", type="INTEGER"),
@@ -51,7 +52,7 @@ def session_with_data(session: Session):
table_name="my_sqla_table",
columns=columns,
metrics=[],
database=db,
database=database,
)
dashboard_obj = Dashboard(
@@ -62,7 +63,9 @@ def session_with_data(session: Session):
published=True,
)
saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
saved_query = SavedQuery(
label="test_query", database=database, sql="select * from foo"
)
tag = Tag(name="test_name", description="test_description")
@@ -79,7 +82,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
from superset.models.dashboard import Dashboard
from superset.tags.models import ObjectType, TaggedObject
dashboard = session_with_data.query(Dashboard).first()
dashboard = db.session.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
@@ -104,7 +107,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
updated_tag = TagDAO.find_by_name("new_name")
assert updated_tag is not None
assert updated_tag.description == "new_description"
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
def test_update_command_success_duplicates(
@@ -117,8 +120,8 @@ def test_update_command_success_duplicates(
from superset.models.slice import Slice
from superset.tags.models import ObjectType, TaggedObject
dashboard = session_with_data.query(Dashboard).first()
chart = session_with_data.query(Slice).first()
dashboard = db.session.query(Dashboard).first()
chart = db.session.query(Slice).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -153,7 +156,7 @@ def test_update_command_success_duplicates(
updated_tag = TagDAO.find_by_name("new_name")
assert updated_tag is not None
assert updated_tag.description == "new_description"
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
assert changed_model.objects[0].object_id == chart.id
@@ -168,8 +171,8 @@ def test_update_command_failed_validation(
from superset.models.slice import Slice
from superset.tags.models import ObjectType
dashboard = session_with_data.query(Dashboard).first()
chart = session_with_data.query(Slice).first()
dashboard = db.session.query(Dashboard).first()
chart = db.session.query(Slice).first()
objects_to_tag = [
(ObjectType.chart, chart.id),
]