mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
This commit is contained in:
@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.commands.chart.importers.v1.utils import import_chart
|
||||
from superset.commands.exceptions import ImportFailedError
|
||||
from superset.connectors.sqla.models import Database, SqlaTable
|
||||
@@ -82,7 +82,7 @@ def test_import_chart(mocker: MockFixture, session_with_schema: Session) -> None
|
||||
config["datasource_id"] = 1
|
||||
config["datasource_type"] = "table"
|
||||
|
||||
chart = import_chart(session_with_schema, config)
|
||||
chart = import_chart(config)
|
||||
assert chart.slice_name == "Deck Path"
|
||||
assert chart.viz_type == "deck_path"
|
||||
assert chart.is_managed_externally is False
|
||||
@@ -106,7 +106,7 @@ def test_import_chart_managed_externally(
|
||||
config["is_managed_externally"] = True
|
||||
config["external_url"] = "https://example.org/my_chart"
|
||||
|
||||
chart = import_chart(session_with_schema, config)
|
||||
chart = import_chart(config)
|
||||
assert chart.is_managed_externally is True
|
||||
assert chart.external_url == "https://example.org/my_chart"
|
||||
|
||||
@@ -128,7 +128,7 @@ def test_import_chart_without_permission(
|
||||
config["datasource_type"] = "table"
|
||||
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
import_chart(session_with_schema, config)
|
||||
import_chart(config)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "Chart doesn't exist and user doesn't have permission to create charts"
|
||||
@@ -173,7 +173,7 @@ def test_import_existing_chart_without_permission(
|
||||
|
||||
with override_user("admin"):
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
import_chart(session_with_data, chart_config, overwrite=True)
|
||||
import_chart(chart_config, overwrite=True)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "A chart already exists and user doesn't have permissions to overwrite it"
|
||||
@@ -213,7 +213,7 @@ def test_import_existing_chart_with_permission(
|
||||
)
|
||||
|
||||
with override_user(admin):
|
||||
import_chart(session_with_data, config, overwrite=True)
|
||||
import_chart(config, overwrite=True)
|
||||
# Assert that the can write to chart was checked
|
||||
security_manager.can_access.assert_called_once_with("can_write", "Chart")
|
||||
security_manager.can_access_chart.assert_called_once_with(slice)
|
||||
|
||||
@@ -48,7 +48,7 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
|
||||
from superset.daos.chart import ChartDAO
|
||||
from superset.models.slice import Slice
|
||||
|
||||
result = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
|
||||
result = ChartDAO.find_by_id(1, skip_base_filter=True)
|
||||
|
||||
assert result
|
||||
assert 1 == result.id
|
||||
@@ -57,20 +57,18 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
|
||||
|
||||
|
||||
def test_datasource_find_by_id_skip_base_filter_not_found(
|
||||
session_with_data: Session,
|
||||
session: Session,
|
||||
) -> None:
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
result = ChartDAO.find_by_id(
|
||||
125326326, session=session_with_data, skip_base_filter=True
|
||||
)
|
||||
result = ChartDAO.find_by_id(125326326, skip_base_filter=True)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_add_favorite(session_with_data: Session) -> None:
|
||||
def test_add_favorite(session: Session) -> None:
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
|
||||
chart = ChartDAO.find_by_id(1, skip_base_filter=True)
|
||||
if not chart:
|
||||
return
|
||||
assert len(ChartDAO.favorited_ids([chart])) == 0
|
||||
@@ -82,10 +80,10 @@ def test_add_favorite(session_with_data: Session) -> None:
|
||||
assert len(ChartDAO.favorited_ids([chart])) == 1
|
||||
|
||||
|
||||
def test_remove_favorite(session_with_data: Session) -> None:
|
||||
def test_remove_favorite(session: Session) -> None:
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
|
||||
chart = ChartDAO.find_by_id(1, skip_base_filter=True)
|
||||
if not chart:
|
||||
return
|
||||
assert len(ChartDAO.favorited_ids([chart])) == 0
|
||||
|
||||
@@ -1965,12 +1965,13 @@ def test_apply_post_process_json_format_data_is_none():
|
||||
|
||||
|
||||
def test_apply_post_process_verbose_map(session: Session):
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
@@ -1982,7 +1983,7 @@ def test_apply_post_process_verbose_map(session: Session):
|
||||
expression="COUNT(*)",
|
||||
)
|
||||
],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
|
||||
result = {
|
||||
|
||||
@@ -24,9 +24,10 @@ def test_column_model(session: Session) -> None:
|
||||
"""
|
||||
Test basic attributes of a ``Column``.
|
||||
"""
|
||||
from superset import db
|
||||
from superset.columns.models import Column
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Column.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
column = Column(
|
||||
@@ -35,8 +36,8 @@ def test_column_model(session: Session) -> None:
|
||||
expression="ds",
|
||||
)
|
||||
|
||||
session.add(column)
|
||||
session.flush()
|
||||
db.session.add(column)
|
||||
db.session.flush()
|
||||
|
||||
assert column.id == 1
|
||||
assert column.uuid is not None
|
||||
|
||||
@@ -35,14 +35,14 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
|
||||
"""
|
||||
Test that all new assets are imported correctly.
|
||||
"""
|
||||
from superset import security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.commands.importers.v1.assets import ImportAssetsCommand
|
||||
from superset.models.dashboard import dashboard_slices
|
||||
from superset.models.slice import Slice
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Slice.metadata.create_all(engine) # pylint: disable=no-member
|
||||
configs = {
|
||||
**copy.deepcopy(databases_config),
|
||||
@@ -53,11 +53,11 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
|
||||
expected_number_of_dashboards = len(dashboards_config_1)
|
||||
expected_number_of_charts = len(charts_config_1)
|
||||
|
||||
ImportAssetsCommand._import(session, configs)
|
||||
dashboard_ids = session.scalars(
|
||||
ImportAssetsCommand._import(configs)
|
||||
dashboard_ids = db.session.scalars(
|
||||
select(dashboard_slices.c.dashboard_id).distinct()
|
||||
).all()
|
||||
chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
|
||||
chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
|
||||
|
||||
assert len(chart_ids) == expected_number_of_charts
|
||||
assert len(dashboard_ids) == expected_number_of_dashboards
|
||||
@@ -67,14 +67,14 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
|
||||
"""
|
||||
Test that existing dashboards are updated with new charts.
|
||||
"""
|
||||
from superset import security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.commands.importers.v1.assets import ImportAssetsCommand
|
||||
from superset.models.dashboard import dashboard_slices
|
||||
from superset.models.slice import Slice
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Slice.metadata.create_all(engine) # pylint: disable=no-member
|
||||
base_configs = {
|
||||
**copy.deepcopy(databases_config),
|
||||
@@ -91,12 +91,12 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
|
||||
expected_number_of_dashboards = len(dashboards_config_1)
|
||||
expected_number_of_charts = len(charts_config_1)
|
||||
|
||||
ImportAssetsCommand._import(session, base_configs)
|
||||
ImportAssetsCommand._import(session, new_configs)
|
||||
dashboard_ids = session.scalars(
|
||||
ImportAssetsCommand._import(base_configs)
|
||||
ImportAssetsCommand._import(new_configs)
|
||||
dashboard_ids = db.session.scalars(
|
||||
select(dashboard_slices.c.dashboard_id).distinct()
|
||||
).all()
|
||||
chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
|
||||
chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
|
||||
|
||||
assert len(chart_ids) == expected_number_of_charts
|
||||
assert len(dashboard_ids) == expected_number_of_dashboards
|
||||
@@ -106,14 +106,14 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
|
||||
"""
|
||||
Test that existing dashboards are updated without old charts.
|
||||
"""
|
||||
from superset import security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.commands.importers.v1.assets import ImportAssetsCommand
|
||||
from superset.models.dashboard import dashboard_slices
|
||||
from superset.models.slice import Slice
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Slice.metadata.create_all(engine) # pylint: disable=no-member
|
||||
base_configs = {
|
||||
**copy.deepcopy(databases_config),
|
||||
@@ -130,12 +130,12 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
|
||||
expected_number_of_dashboards = len(dashboards_config_2)
|
||||
expected_number_of_charts = len(charts_config_2)
|
||||
|
||||
ImportAssetsCommand._import(session, base_configs)
|
||||
ImportAssetsCommand._import(session, new_configs)
|
||||
dashboard_ids = session.scalars(
|
||||
ImportAssetsCommand._import(base_configs)
|
||||
ImportAssetsCommand._import(new_configs)
|
||||
dashboard_ids = db.session.scalars(
|
||||
select(dashboard_slices.c.dashboard_id).distinct()
|
||||
).all()
|
||||
chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
|
||||
chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
|
||||
|
||||
assert len(chart_ids) == expected_number_of_charts
|
||||
assert len(dashboard_ids) == expected_number_of_dashboards
|
||||
|
||||
@@ -23,6 +23,8 @@ import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
@@ -81,7 +83,7 @@ def test_table(session: Session) -> "SqlaTable":
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
columns = [
|
||||
|
||||
@@ -41,7 +41,7 @@ from superset.initialization import SupersetAppInitializer
|
||||
@pytest.fixture
|
||||
def get_session(mocker: MockFixture) -> Callable[[], Session]:
|
||||
"""
|
||||
Create an in-memory SQLite session to test models.
|
||||
Create an in-memory SQLite db.session.to test models.
|
||||
"""
|
||||
engine = create_engine("sqlite://")
|
||||
|
||||
@@ -49,7 +49,7 @@ def get_session(mocker: MockFixture) -> Callable[[], Session]:
|
||||
Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name
|
||||
in_memory_session = Session_()
|
||||
|
||||
# flask calls session.remove()
|
||||
# flask calls db.session.remove()
|
||||
in_memory_session.remove = lambda: None
|
||||
|
||||
# patch session
|
||||
|
||||
@@ -27,6 +27,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
||||
In particular, allow datasets with the same name in the same database as long as they
|
||||
are in different schemas
|
||||
"""
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
|
||||
@@ -46,8 +47,8 @@ def test_validate_update_uniqueness(session: Session) -> None:
|
||||
schema="dev",
|
||||
database=database,
|
||||
)
|
||||
session.add_all([database, dataset1, dataset2])
|
||||
session.flush()
|
||||
db.session.add_all([database, dataset1, dataset2])
|
||||
db.session.flush()
|
||||
|
||||
# same table name, different schema
|
||||
assert (
|
||||
|
||||
@@ -25,17 +25,18 @@ from superset.exceptions import QueryNotFoundException, SupersetCancelQueryExcep
|
||||
|
||||
|
||||
def test_query_dao_save_metadata(session: Session) -> None:
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
@@ -48,30 +49,31 @@ def test_query_dao_save_metadata(session: Session) -> None:
|
||||
results_key="abc",
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(query_obj)
|
||||
db.session.add(database)
|
||||
db.session.add(query_obj)
|
||||
|
||||
from superset.daos.query import QueryDAO
|
||||
|
||||
query = session.query(Query).one()
|
||||
query = db.session.query(Query).one()
|
||||
QueryDAO.save_metadata(query=query, payload={"columns": []})
|
||||
assert query.extra.get("columns", None) == []
|
||||
|
||||
|
||||
def test_query_dao_get_queries_changed_after(session: Session) -> None:
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
old_query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
@@ -87,7 +89,7 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
|
||||
|
||||
updated_query_obj = Query(
|
||||
client_id="updated_foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from foo",
|
||||
@@ -101,9 +103,9 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
|
||||
changed_on=now - timedelta(days=1),
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(old_query_obj)
|
||||
session.add(updated_query_obj)
|
||||
db.session.add(database)
|
||||
db.session.add(old_query_obj)
|
||||
db.session.add(updated_query_obj)
|
||||
|
||||
from superset.daos.query import QueryDAO
|
||||
|
||||
@@ -116,18 +118,19 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
|
||||
def test_query_dao_stop_query_not_found(
|
||||
mocker: MockFixture, app: Any, session: Session
|
||||
) -> None:
|
||||
from superset import db
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
@@ -141,8 +144,8 @@ def test_query_dao_stop_query_not_found(
|
||||
status=QueryStatus.RUNNING,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(query_obj)
|
||||
db.session.add(database)
|
||||
db.session.add(query_obj)
|
||||
|
||||
mocker.patch("superset.sql_lab.cancel_query", return_value=False)
|
||||
|
||||
@@ -151,25 +154,26 @@ def test_query_dao_stop_query_not_found(
|
||||
with pytest.raises(QueryNotFoundException):
|
||||
QueryDAO.stop_query("foo2")
|
||||
|
||||
query = session.query(Query).one()
|
||||
query = db.session.query(Query).one()
|
||||
assert query.status == QueryStatus.RUNNING
|
||||
|
||||
|
||||
def test_query_dao_stop_query_not_running(
|
||||
mocker: MockFixture, app: Any, session: Session
|
||||
) -> None:
|
||||
from superset import db
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
@@ -183,31 +187,32 @@ def test_query_dao_stop_query_not_running(
|
||||
status=QueryStatus.FAILED,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(query_obj)
|
||||
db.session.add(database)
|
||||
db.session.add(query_obj)
|
||||
|
||||
from superset.daos.query import QueryDAO
|
||||
|
||||
QueryDAO.stop_query(query_obj.client_id)
|
||||
query = session.query(Query).one()
|
||||
query = db.session.query(Query).one()
|
||||
assert query.status == QueryStatus.FAILED
|
||||
|
||||
|
||||
def test_query_dao_stop_query_failed(
|
||||
mocker: MockFixture, app: Any, session: Session
|
||||
) -> None:
|
||||
from superset import db
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
@@ -221,8 +226,8 @@ def test_query_dao_stop_query_failed(
|
||||
status=QueryStatus.RUNNING,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(query_obj)
|
||||
db.session.add(database)
|
||||
db.session.add(query_obj)
|
||||
|
||||
mocker.patch("superset.sql_lab.cancel_query", return_value=False)
|
||||
|
||||
@@ -231,23 +236,24 @@ def test_query_dao_stop_query_failed(
|
||||
with pytest.raises(SupersetCancelQueryException):
|
||||
QueryDAO.stop_query(query_obj.client_id)
|
||||
|
||||
query = session.query(Query).one()
|
||||
query = db.session.query(Query).one()
|
||||
assert query.status == QueryStatus.RUNNING
|
||||
|
||||
|
||||
def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -> None:
|
||||
from superset import db
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
@@ -261,13 +267,13 @@ def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -
|
||||
status=QueryStatus.RUNNING,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(query_obj)
|
||||
db.session.add(database)
|
||||
db.session.add(query_obj)
|
||||
|
||||
mocker.patch("superset.sql_lab.cancel_query", return_value=True)
|
||||
|
||||
from superset.daos.query import QueryDAO
|
||||
|
||||
QueryDAO.stop_query(query_obj.client_id)
|
||||
query = session.query(Query).one()
|
||||
query = db.session.query(Query).one()
|
||||
assert query.status == QueryStatus.STOPPED
|
||||
|
||||
@@ -70,7 +70,7 @@ def test_remove_user_favorite_tag(mocker):
|
||||
# Check that users_favorited no longer contains the user
|
||||
assert mock_user not in mock_tag.users_favorited
|
||||
|
||||
# Check that the session was committed
|
||||
# Check that the db.session.was committed
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.commands.dashboard.importers.v1.utils import import_dashboard
|
||||
from superset.commands.exceptions import ImportFailedError
|
||||
from superset.models.dashboard import Dashboard
|
||||
@@ -67,7 +67,7 @@ def test_import_dashboard(mocker: MockFixture, session_with_schema: Session) ->
|
||||
"""
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
dashboard = import_dashboard(session_with_schema, dashboard_config)
|
||||
dashboard = import_dashboard(dashboard_config)
|
||||
assert dashboard.dashboard_title == "Test dash"
|
||||
assert dashboard.description is None
|
||||
assert dashboard.is_managed_externally is False
|
||||
@@ -88,8 +88,7 @@ def test_import_dashboard_managed_externally(
|
||||
config = copy.deepcopy(dashboard_config)
|
||||
config["is_managed_externally"] = True
|
||||
config["external_url"] = "https://example.org/my_dashboard"
|
||||
|
||||
dashboard = import_dashboard(session_with_schema, config)
|
||||
dashboard = import_dashboard(config)
|
||||
assert dashboard.is_managed_externally is True
|
||||
assert dashboard.external_url == "https://example.org/my_dashboard"
|
||||
|
||||
@@ -107,7 +106,7 @@ def test_import_dashboard_without_permission(
|
||||
mocker.patch.object(security_manager, "can_access", return_value=False)
|
||||
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
import_dashboard(session_with_schema, dashboard_config)
|
||||
import_dashboard(dashboard_config)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "Dashboard doesn't exist and user doesn't have permission to create dashboards"
|
||||
@@ -135,7 +134,7 @@ def test_import_existing_dashboard_without_permission(
|
||||
|
||||
with override_user("admin"):
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
import_dashboard(session_with_data, dashboard_config, overwrite=True)
|
||||
import_dashboard(dashboard_config, overwrite=True)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "A dashboard already exists and user doesn't have permissions to overwrite it"
|
||||
@@ -171,7 +170,8 @@ def test_import_existing_dashboard_with_permission(
|
||||
)
|
||||
|
||||
with override_user(admin):
|
||||
import_dashboard(session_with_data, dashboard_config, overwrite=True)
|
||||
import_dashboard(dashboard_config, overwrite=True)
|
||||
|
||||
# Assert that the can write to dashboard was checked
|
||||
security_manager.can_access.assert_called_once_with("can_write", "Dashboard")
|
||||
security_manager.can_access_dashboard.assert_called_once_with(dashboard)
|
||||
|
||||
@@ -42,12 +42,10 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_add_favorite(session_with_data: Session) -> None:
|
||||
def test_add_favorite(session: Session) -> None:
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
|
||||
dashboard = DashboardDAO.find_by_id(
|
||||
100, session=session_with_data, skip_base_filter=True
|
||||
)
|
||||
dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
|
||||
if not dashboard:
|
||||
return
|
||||
assert len(DashboardDAO.favorited_ids([dashboard])) == 0
|
||||
@@ -59,12 +57,10 @@ def test_add_favorite(session_with_data: Session) -> None:
|
||||
assert len(DashboardDAO.favorited_ids([dashboard])) == 1
|
||||
|
||||
|
||||
def test_remove_favorite(session_with_data: Session) -> None:
|
||||
def test_remove_favorite(session: Session) -> None:
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
|
||||
dashboard = DashboardDAO.find_by_id(
|
||||
100, session=session_with_data, skip_base_filter=True
|
||||
)
|
||||
dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
|
||||
if not dashboard:
|
||||
return
|
||||
assert len(DashboardDAO.favorited_ids([dashboard])) == 0
|
||||
|
||||
@@ -28,6 +28,8 @@ from flask import current_app
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
def test_filter_by_uuid(
|
||||
session: Session,
|
||||
@@ -49,14 +51,14 @@ def test_filter_by_uuid(
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
session.add(
|
||||
db.session.add(
|
||||
Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/database/?q=(filters:!((col:uuid,opr:eq,value:"
|
||||
@@ -96,7 +98,7 @@ def test_post_with_uuid(
|
||||
payload = response.json
|
||||
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
|
||||
|
||||
database = session.query(Database).one()
|
||||
database = db.session.query(Database).one()
|
||||
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
|
||||
|
||||
|
||||
@@ -139,8 +141,8 @@ def test_password_mask(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -195,8 +197,8 @@ def test_database_connection(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -331,8 +333,8 @@ def test_update_with_password_mask(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
client.put(
|
||||
"/api/v1/database/1",
|
||||
@@ -347,7 +349,7 @@ def test_update_with_password_mask(
|
||||
),
|
||||
},
|
||||
)
|
||||
database = session.query(Database).one()
|
||||
database = db.session.query(Database).one()
|
||||
assert (
|
||||
database.encrypted_extra
|
||||
== '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}'
|
||||
@@ -429,8 +431,8 @@ def test_delete_ssh_tunnel(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -446,8 +448,8 @@ def test_delete_ssh_tunnel(
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
db.session.add(tunnel)
|
||||
db.session.commit()
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
@@ -505,8 +507,8 @@ def test_delete_ssh_tunnel_not_found(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -522,8 +524,8 @@ def test_delete_ssh_tunnel_not_found(
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
db.session.add(tunnel)
|
||||
db.session.commit()
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
|
||||
@@ -576,8 +578,8 @@ def test_apply_dynamic_database_filter(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# Create our Second Database
|
||||
database = Database(
|
||||
@@ -592,8 +594,8 @@ def test_apply_dynamic_database_filter(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
|
||||
@@ -23,6 +23,7 @@ import pytest
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.commands.exceptions import ImportFailedError
|
||||
|
||||
|
||||
@@ -37,11 +38,11 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
database = import_database(session, config)
|
||||
database = import_database(config)
|
||||
assert database.database_name == "imported_database"
|
||||
assert database.sqlalchemy_uri == "someengine://user:pass@host1"
|
||||
assert database.cache_timeout is None
|
||||
@@ -60,9 +61,9 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
|
||||
# missing
|
||||
config = copy.deepcopy(database_config)
|
||||
del config["allow_dml"]
|
||||
session.delete(database)
|
||||
session.flush()
|
||||
database = import_database(session, config)
|
||||
db.session.delete(database)
|
||||
db.session.flush()
|
||||
database = import_database(config)
|
||||
assert database.allow_dml is False
|
||||
|
||||
|
||||
@@ -78,12 +79,12 @@ def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -
|
||||
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config_sqlite)
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
_ = import_database(session, config)
|
||||
_ = import_database(config)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
|
||||
@@ -106,14 +107,14 @@ def test_import_database_managed_externally(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
config["is_managed_externally"] = True
|
||||
config["external_url"] = "https://example.org/my_database"
|
||||
|
||||
database = import_database(session, config)
|
||||
database = import_database(config)
|
||||
assert database.is_managed_externally is True
|
||||
assert database.external_url == "https://example.org/my_database"
|
||||
|
||||
@@ -132,13 +133,13 @@ def test_import_database_without_permission(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=False)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
import_database(session, config)
|
||||
import_database(config)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "Database doesn't exist and user doesn't have permission to create databases"
|
||||
@@ -156,10 +157,10 @@ def test_import_database_with_version(mocker: MockFixture, session: Session) ->
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
config["extra"]["version"] = "1.1.1"
|
||||
database = import_database(session, config)
|
||||
database = import_database(config)
|
||||
assert json.loads(database.extra)["version"] == "1.1.1"
|
||||
|
||||
@@ -30,19 +30,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
|
||||
@@ -27,17 +27,17 @@ def test_create_ssh_tunnel_command() -> None:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"database_id": database.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = CreateSSHTunnelCommand(db, properties).run()
|
||||
result = CreateSSHTunnelCommand(database, properties).run()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
@@ -48,19 +48,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
# If we are trying to create a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
properties = {
|
||||
"database": db,
|
||||
"database": database,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"private_key_password": "bar",
|
||||
}
|
||||
|
||||
command = CreateSSHTunnelCommand(db, properties)
|
||||
command = CreateSSHTunnelCommand(database, properties)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
|
||||
@@ -31,19 +31,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
|
||||
@@ -32,16 +32,18 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=database.id, database=database, server_address="Test"
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
|
||||
@@ -25,11 +25,11 @@ def test_create_ssh_tunnel():
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
result = SSHTunnelDAO.create(
|
||||
attributes={
|
||||
"database_id": db.id,
|
||||
"database_id": database.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
|
||||
@@ -19,6 +19,8 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
def test_put_invalid_dataset(
|
||||
session: Session,
|
||||
@@ -31,7 +33,7 @@ def test_put_invalid_dataset(
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
|
||||
SqlaTable.metadata.create_all(session.get_bind())
|
||||
SqlaTable.metadata.create_all(db.session.get_bind())
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
@@ -41,8 +43,8 @@ def test_put_invalid_dataset(
|
||||
table_name="test_put_invalid_dataset",
|
||||
database=database,
|
||||
)
|
||||
session.add(dataset)
|
||||
session.flush()
|
||||
db.session.add(dataset)
|
||||
db.session.flush()
|
||||
|
||||
response = client.put(
|
||||
"/api/v1/dataset/1",
|
||||
|
||||
@@ -20,6 +20,8 @@ import json
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
def test_export(session: Session) -> None:
|
||||
"""
|
||||
@@ -29,12 +31,12 @@ def test_export(session: Session) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
columns = [
|
||||
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
|
||||
|
||||
@@ -28,6 +28,7 @@ from flask import current_app
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.commands.dataset.exceptions import (
|
||||
DatasetForbiddenDataURI,
|
||||
ImportFailedError,
|
||||
@@ -46,12 +47,12 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
config = {
|
||||
@@ -108,7 +109,7 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
|
||||
"database_id": database.id,
|
||||
}
|
||||
|
||||
sqla_table = import_dataset(session, config)
|
||||
sqla_table = import_dataset(config)
|
||||
assert sqla_table.table_name == "my_table"
|
||||
assert sqla_table.main_dttm_col == "ds"
|
||||
assert sqla_table.description == "This is the description"
|
||||
@@ -162,23 +163,23 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset = SqlaTable(
|
||||
uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id
|
||||
)
|
||||
column = TableColumn(column_name="existing_column")
|
||||
session.add(dataset)
|
||||
session.add(column)
|
||||
session.flush()
|
||||
db.session.add(dataset)
|
||||
db.session.add(column)
|
||||
db.session.flush()
|
||||
|
||||
config = {
|
||||
"table_name": dataset.table_name,
|
||||
@@ -234,7 +235,7 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
|
||||
"database_id": database.id,
|
||||
}
|
||||
|
||||
sqla_table = import_dataset(session, config, overwrite=True)
|
||||
sqla_table = import_dataset(config, overwrite=True)
|
||||
assert sqla_table.table_name == dataset.table_name
|
||||
assert sqla_table.main_dttm_col == "ds"
|
||||
assert sqla_table.description == "This is the description"
|
||||
@@ -288,12 +289,12 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
yaml_config: dict[str, Any] = {
|
||||
@@ -352,7 +353,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
|
||||
schema = ImportV1DatasetSchema()
|
||||
dataset_config = schema.load(yaml_config)
|
||||
dataset_config["database_id"] = database.id
|
||||
sqla_table = import_dataset(session, dataset_config)
|
||||
sqla_table = import_dataset(dataset_config)
|
||||
|
||||
assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
|
||||
assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
|
||||
@@ -373,12 +374,12 @@ def test_import_dataset_extra_empty_string(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
yaml_config: dict[str, Any] = {
|
||||
@@ -417,7 +418,7 @@ def test_import_dataset_extra_empty_string(
|
||||
schema = ImportV1DatasetSchema()
|
||||
dataset_config = schema.load(yaml_config)
|
||||
dataset_config["database_id"] = database.id
|
||||
sqla_table = import_dataset(session, dataset_config)
|
||||
sqla_table = import_dataset(dataset_config)
|
||||
|
||||
assert sqla_table.extra == None
|
||||
|
||||
@@ -443,12 +444,12 @@ def test_import_column_allowed_data_url(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
yaml_config: dict[str, Any] = {
|
||||
@@ -495,9 +496,8 @@ def test_import_column_allowed_data_url(
|
||||
schema = ImportV1DatasetSchema()
|
||||
dataset_config = schema.load(yaml_config)
|
||||
dataset_config["database_id"] = database.id
|
||||
_ = import_dataset(session, dataset_config, force_data=True)
|
||||
session.connection()
|
||||
assert [("value1",), ("value2",)] == session.execute(
|
||||
_ = import_dataset(dataset_config, force_data=True)
|
||||
assert [("value1",), ("value2",)] == db.session.execute(
|
||||
"SELECT * FROM my_table"
|
||||
).fetchall()
|
||||
|
||||
@@ -517,19 +517,19 @@ def test_import_dataset_managed_externally(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
config = copy.deepcopy(dataset_config)
|
||||
config["is_managed_externally"] = True
|
||||
config["external_url"] = "https://example.org/my_table"
|
||||
config["database_id"] = database.id
|
||||
|
||||
sqla_table = import_dataset(session, config)
|
||||
sqla_table = import_dataset(config)
|
||||
assert sqla_table.is_managed_externally is True
|
||||
assert sqla_table.external_url == "https://example.org/my_table"
|
||||
|
||||
|
||||
@@ -29,15 +29,15 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.flush()
|
||||
yield session
|
||||
@@ -50,7 +50,6 @@ def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> N
|
||||
|
||||
result = DatasetDAO.find_by_id(
|
||||
1,
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
|
||||
@@ -67,7 +66,6 @@ def test_datasource_find_by_id_skip_base_filter_not_found(
|
||||
|
||||
result = DatasetDAO.find_by_id(
|
||||
125326326,
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
assert result is None
|
||||
@@ -79,7 +77,6 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) ->
|
||||
|
||||
result = DatasetDAO.find_by_ids(
|
||||
[1, 125326326],
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
|
||||
@@ -96,7 +93,6 @@ def test_datasource_find_by_ids_skip_base_filter_not_found(
|
||||
|
||||
result = DatasetDAO.find_by_ids(
|
||||
[125326326, 125326326125326326],
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
columns = [
|
||||
TableColumn(column_name="a", type="INTEGER"),
|
||||
@@ -45,12 +45,12 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
table_name="my_sqla_table",
|
||||
columns=columns,
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
@@ -63,13 +63,13 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
results_key="abc",
|
||||
)
|
||||
|
||||
saved_query = SavedQuery(database=db, sql="select * from foo")
|
||||
saved_query = SavedQuery(database=database, sql="select * from foo")
|
||||
|
||||
table = Table(
|
||||
name="my_table",
|
||||
schema="my_schema",
|
||||
catalog="my_catalog",
|
||||
database=db,
|
||||
database=database,
|
||||
columns=[],
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ FROM my_catalog.my_schema.my_table
|
||||
session.add(table)
|
||||
session.add(saved_query)
|
||||
session.add(query_obj)
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.flush()
|
||||
yield session
|
||||
@@ -190,7 +190,7 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
|
||||
def test_get_all_datasources(session_with_data: Session) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
result = SqlaTable.get_all_datasources(session=session_with_data)
|
||||
result = SqlaTable.get_all_datasources()
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
|
||||
@@ -74,10 +74,10 @@ def test_extras_without_ssl() -> None:
|
||||
from superset.db_engine_specs.druid import DruidEngineSpec
|
||||
from tests.integration_tests.fixtures.database import default_db_extra
|
||||
|
||||
db = mock.Mock()
|
||||
db.extra = default_db_extra
|
||||
db.server_cert = None
|
||||
extras = DruidEngineSpec.get_extra_params(db)
|
||||
database = mock.Mock()
|
||||
database.extra = default_db_extra
|
||||
database.server_cert = None
|
||||
extras = DruidEngineSpec.get_extra_params(database)
|
||||
assert "connect_args" not in extras["engine_params"]
|
||||
|
||||
|
||||
@@ -86,10 +86,10 @@ def test_extras_with_ssl() -> None:
|
||||
from tests.integration_tests.fixtures.certificates import ssl_certificate
|
||||
from tests.integration_tests.fixtures.database import default_db_extra
|
||||
|
||||
db = mock.Mock()
|
||||
db.extra = default_db_extra
|
||||
db.server_cert = ssl_certificate
|
||||
extras = DruidEngineSpec.get_extra_params(db)
|
||||
database = mock.Mock()
|
||||
database.extra = default_db_extra
|
||||
database.server_cert = ssl_certificate
|
||||
extras = DruidEngineSpec.get_extra_params(database)
|
||||
connect_args = extras["engine_params"]["connect_args"]
|
||||
assert connect_args["scheme"] == "https"
|
||||
assert "ssl_verify_cert" in connect_args
|
||||
|
||||
@@ -50,8 +50,8 @@ def test_extras_without_ssl() -> None:
|
||||
from superset.db_engine_specs.pinot import PinotEngineSpec as spec
|
||||
from tests.integration_tests.fixtures.database import default_db_extra
|
||||
|
||||
db = mock.Mock()
|
||||
db.extra = default_db_extra
|
||||
db.server_cert = None
|
||||
extras = spec.get_extra_params(db)
|
||||
database = mock.Mock()
|
||||
database.extra = default_db_extra
|
||||
database.server_cert = None
|
||||
extras = spec.get_extra_params(database)
|
||||
assert "connect_args" not in extras["engine_params"]
|
||||
|
||||
@@ -26,6 +26,7 @@ from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from tests.unit_tests.conftest import with_feature_flags
|
||||
@@ -38,7 +39,7 @@ if TYPE_CHECKING:
|
||||
def database1(session: Session) -> Iterator["Database"]:
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.connection().engine
|
||||
engine = db.session.connection().engine
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(
|
||||
@@ -46,13 +47,13 @@ def database1(session: Session) -> Iterator["Database"]:
|
||||
sqlalchemy_uri="sqlite:///database1.db",
|
||||
allow_dml=True,
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
yield database
|
||||
|
||||
session.delete(database)
|
||||
session.commit()
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
os.unlink("database1.db")
|
||||
|
||||
|
||||
@@ -62,12 +63,12 @@ def table1(session: Session, database1: "Database") -> Iterator[None]:
|
||||
conn = engine.connect()
|
||||
conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)")
|
||||
conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)")
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
conn.execute("DROP TABLE table1")
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -79,13 +80,13 @@ def database2(session: Session) -> Iterator["Database"]:
|
||||
sqlalchemy_uri="sqlite:///database2.db",
|
||||
allow_dml=False,
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
yield database
|
||||
|
||||
session.delete(database)
|
||||
session.commit()
|
||||
db.session.delete(database)
|
||||
db.session.commit()
|
||||
os.unlink("database2.db")
|
||||
|
||||
|
||||
@@ -95,12 +96,12 @@ def table2(session: Session, database2: "Database") -> Iterator[None]:
|
||||
conn = engine.connect()
|
||||
conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)")
|
||||
conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')")
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
conn.execute("DROP TABLE table2")
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@with_feature_flags(ENABLE_SUPERSET_META_DB=True)
|
||||
|
||||
@@ -22,10 +22,10 @@ def test_column_attributes_on_query():
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
query_obj = Query(
|
||||
client_id="foo",
|
||||
database=db,
|
||||
database=database,
|
||||
tab_name="test_tab",
|
||||
sql_editor_id="test_editor_id",
|
||||
sql="select * from bar",
|
||||
|
||||
@@ -125,7 +125,7 @@ def test_sql_lab_insert_rls_as_subquery(
|
||||
from superset.sql_lab import execute_sql_statement
|
||||
from superset.utils.core import RowLevelSecurityFilterType
|
||||
|
||||
engine = session.connection().engine
|
||||
engine = db.session.connection().engine
|
||||
Query.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
connection = engine.raw_connection()
|
||||
@@ -143,8 +143,8 @@ def test_sql_lab_insert_rls_as_subquery(
|
||||
limit=5,
|
||||
select_as_cta_used=False,
|
||||
)
|
||||
session.add(query)
|
||||
session.commit()
|
||||
db.session.add(query)
|
||||
db.session.commit()
|
||||
|
||||
admin = User(
|
||||
first_name="Alice",
|
||||
@@ -185,8 +185,8 @@ def test_sql_lab_insert_rls_as_subquery(
|
||||
group_key=None,
|
||||
clause="c > 5",
|
||||
)
|
||||
session.add(rls)
|
||||
session.flush()
|
||||
db.session.add(rls)
|
||||
db.session.flush()
|
||||
mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
|
||||
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
|
||||
|
||||
|
||||
@@ -1759,8 +1759,7 @@ def test_get_rls_for_table(mocker: MockerFixture) -> None:
|
||||
Tests for ``get_rls_for_table``.
|
||||
"""
|
||||
candidate = Identifier([Token(Name, "some_table")])
|
||||
db = mocker.patch("superset.db")
|
||||
dataset = db.session.query().filter().one_or_none()
|
||||
dataset = mocker.patch("superset.db").session.query().filter().one_or_none()
|
||||
dataset.__str__.return_value = "some_table"
|
||||
|
||||
dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=import-outside-toplevel, unused-argument
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
def test_table_model(session: Session) -> None:
|
||||
"""
|
||||
@@ -28,7 +28,7 @@ def test_table_model(session: Session) -> None:
|
||||
from superset.models.core import Database
|
||||
from superset.tables.models import Table
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Table.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
table = Table(
|
||||
@@ -44,8 +44,8 @@ def test_table_model(session: Session) -> None:
|
||||
)
|
||||
],
|
||||
)
|
||||
session.add(table)
|
||||
session.flush()
|
||||
db.session.add(table)
|
||||
db.session.flush()
|
||||
|
||||
assert table.id == 1
|
||||
assert table.uuid is not None
|
||||
|
||||
@@ -18,6 +18,7 @@ import pytest
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
@@ -40,13 +41,15 @@ def session_with_data(session: Session):
|
||||
slice_name="slice_name",
|
||||
)
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
|
||||
|
||||
columns = [
|
||||
TableColumn(column_name="a", type="INTEGER"),
|
||||
]
|
||||
|
||||
saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
|
||||
saved_query = SavedQuery(
|
||||
label="test_query", database=database, sql="select * from foo"
|
||||
)
|
||||
|
||||
dashboard_obj = Dashboard(
|
||||
id=100,
|
||||
@@ -57,7 +60,7 @@ def session_with_data(session: Session):
|
||||
)
|
||||
|
||||
session.add(slice_obj)
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(saved_query)
|
||||
session.add(dashboard_obj)
|
||||
session.commit()
|
||||
@@ -74,9 +77,9 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
|
||||
from superset.tags.models import ObjectType, TaggedObject
|
||||
|
||||
# Define a list of objects to tag
|
||||
query = session_with_data.query(SavedQuery).first()
|
||||
chart = session_with_data.query(Slice).first()
|
||||
dashboard = session_with_data.query(Dashboard).first()
|
||||
query = db.session.query(SavedQuery).first()
|
||||
chart = db.session.query(Slice).first()
|
||||
dashboard = db.session.query(Dashboard).first()
|
||||
|
||||
mocker.patch(
|
||||
"superset.security.SupersetSecurityManager.is_admin", return_value=True
|
||||
@@ -94,10 +97,10 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
|
||||
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
|
||||
).run()
|
||||
|
||||
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
for object_type, object_id in objects_to_tag:
|
||||
assert (
|
||||
session_with_data.query(TaggedObject)
|
||||
db.session.query(TaggedObject)
|
||||
.filter(
|
||||
TaggedObject.object_type == object_type,
|
||||
TaggedObject.object_id == object_id,
|
||||
@@ -117,9 +120,9 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
|
||||
from superset.tags.models import ObjectType, TaggedObject
|
||||
|
||||
# Define a list of objects to tag
|
||||
query = session_with_data.query(SavedQuery).first()
|
||||
chart = session_with_data.query(Slice).first()
|
||||
dashboard = session_with_data.query(Dashboard).first()
|
||||
query = db.session.query(SavedQuery).first()
|
||||
chart = db.session.query(Slice).first()
|
||||
dashboard = db.session.query(Dashboard).first()
|
||||
|
||||
mocker.patch(
|
||||
"superset.security.SupersetSecurityManager.is_admin", return_value=True
|
||||
@@ -136,10 +139,10 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
|
||||
CreateCustomTagWithRelationshipsCommand(
|
||||
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
|
||||
).run()
|
||||
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
|
||||
CreateCustomTagWithRelationshipsCommand(
|
||||
data={"name": "test_tag", "objects_to_tag": []}
|
||||
).run()
|
||||
|
||||
assert len(session_with_data.query(TaggedObject).all()) == 0
|
||||
assert len(db.session.query(TaggedObject).all()) == 0
|
||||
|
||||
@@ -18,6 +18,7 @@ import pytest
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
@@ -41,7 +42,7 @@ def session_with_data(session: Session):
|
||||
slice_name="slice_name",
|
||||
)
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
|
||||
|
||||
columns = [
|
||||
TableColumn(column_name="a", type="INTEGER"),
|
||||
@@ -51,7 +52,7 @@ def session_with_data(session: Session):
|
||||
table_name="my_sqla_table",
|
||||
columns=columns,
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
|
||||
dashboard_obj = Dashboard(
|
||||
@@ -62,7 +63,9 @@ def session_with_data(session: Session):
|
||||
published=True,
|
||||
)
|
||||
|
||||
saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
|
||||
saved_query = SavedQuery(
|
||||
label="test_query", database=database, sql="select * from foo"
|
||||
)
|
||||
|
||||
tag = Tag(name="test_name", description="test_description")
|
||||
|
||||
@@ -79,7 +82,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.tags.models import ObjectType, TaggedObject
|
||||
|
||||
dashboard = session_with_data.query(Dashboard).first()
|
||||
dashboard = db.session.query(Dashboard).first()
|
||||
mocker.patch(
|
||||
"superset.security.SupersetSecurityManager.is_admin", return_value=True
|
||||
)
|
||||
@@ -104,7 +107,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
|
||||
updated_tag = TagDAO.find_by_name("new_name")
|
||||
assert updated_tag is not None
|
||||
assert updated_tag.description == "new_description"
|
||||
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
|
||||
|
||||
def test_update_command_success_duplicates(
|
||||
@@ -117,8 +120,8 @@ def test_update_command_success_duplicates(
|
||||
from superset.models.slice import Slice
|
||||
from superset.tags.models import ObjectType, TaggedObject
|
||||
|
||||
dashboard = session_with_data.query(Dashboard).first()
|
||||
chart = session_with_data.query(Slice).first()
|
||||
dashboard = db.session.query(Dashboard).first()
|
||||
chart = db.session.query(Slice).first()
|
||||
|
||||
mocker.patch(
|
||||
"superset.security.SupersetSecurityManager.is_admin", return_value=True
|
||||
@@ -153,7 +156,7 @@ def test_update_command_success_duplicates(
|
||||
updated_tag = TagDAO.find_by_name("new_name")
|
||||
assert updated_tag is not None
|
||||
assert updated_tag.description == "new_description"
|
||||
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
|
||||
assert changed_model.objects[0].object_id == chart.id
|
||||
|
||||
|
||||
@@ -168,8 +171,8 @@ def test_update_command_failed_validation(
|
||||
from superset.models.slice import Slice
|
||||
from superset.tags.models import ObjectType
|
||||
|
||||
dashboard = session_with_data.query(Dashboard).first()
|
||||
chart = session_with_data.query(Slice).first()
|
||||
dashboard = db.session.query(Dashboard).first()
|
||||
chart = db.session.query(Slice).first()
|
||||
objects_to_tag = [
|
||||
(ObjectType.chart, chart.id),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user