mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
This commit is contained in:
@@ -19,6 +19,8 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
def test_put_invalid_dataset(
|
||||
session: Session,
|
||||
@@ -31,7 +33,7 @@ def test_put_invalid_dataset(
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
|
||||
SqlaTable.metadata.create_all(session.get_bind())
|
||||
SqlaTable.metadata.create_all(db.session.get_bind())
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
@@ -41,8 +43,8 @@ def test_put_invalid_dataset(
|
||||
table_name="test_put_invalid_dataset",
|
||||
database=database,
|
||||
)
|
||||
session.add(dataset)
|
||||
session.flush()
|
||||
db.session.add(dataset)
|
||||
db.session.flush()
|
||||
|
||||
response = client.put(
|
||||
"/api/v1/dataset/1",
|
||||
|
||||
@@ -20,6 +20,8 @@ import json
|
||||
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
def test_export(session: Session) -> None:
|
||||
"""
|
||||
@@ -29,12 +31,12 @@ def test_export(session: Session) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
columns = [
|
||||
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
|
||||
|
||||
@@ -28,6 +28,7 @@ from flask import current_app
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.commands.dataset.exceptions import (
|
||||
DatasetForbiddenDataURI,
|
||||
ImportFailedError,
|
||||
@@ -46,12 +47,12 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
config = {
|
||||
@@ -108,7 +109,7 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
|
||||
"database_id": database.id,
|
||||
}
|
||||
|
||||
sqla_table = import_dataset(session, config)
|
||||
sqla_table = import_dataset(config)
|
||||
assert sqla_table.table_name == "my_table"
|
||||
assert sqla_table.main_dttm_col == "ds"
|
||||
assert sqla_table.description == "This is the description"
|
||||
@@ -162,23 +163,23 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset = SqlaTable(
|
||||
uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id
|
||||
)
|
||||
column = TableColumn(column_name="existing_column")
|
||||
session.add(dataset)
|
||||
session.add(column)
|
||||
session.flush()
|
||||
db.session.add(dataset)
|
||||
db.session.add(column)
|
||||
db.session.flush()
|
||||
|
||||
config = {
|
||||
"table_name": dataset.table_name,
|
||||
@@ -234,7 +235,7 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
|
||||
"database_id": database.id,
|
||||
}
|
||||
|
||||
sqla_table = import_dataset(session, config, overwrite=True)
|
||||
sqla_table = import_dataset(config, overwrite=True)
|
||||
assert sqla_table.table_name == dataset.table_name
|
||||
assert sqla_table.main_dttm_col == "ds"
|
||||
assert sqla_table.description == "This is the description"
|
||||
@@ -288,12 +289,12 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
yaml_config: dict[str, Any] = {
|
||||
@@ -352,7 +353,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
|
||||
schema = ImportV1DatasetSchema()
|
||||
dataset_config = schema.load(yaml_config)
|
||||
dataset_config["database_id"] = database.id
|
||||
sqla_table = import_dataset(session, dataset_config)
|
||||
sqla_table = import_dataset(dataset_config)
|
||||
|
||||
assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
|
||||
assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
|
||||
@@ -373,12 +374,12 @@ def test_import_dataset_extra_empty_string(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
yaml_config: dict[str, Any] = {
|
||||
@@ -417,7 +418,7 @@ def test_import_dataset_extra_empty_string(
|
||||
schema = ImportV1DatasetSchema()
|
||||
dataset_config = schema.load(yaml_config)
|
||||
dataset_config["database_id"] = database.id
|
||||
sqla_table = import_dataset(session, dataset_config)
|
||||
sqla_table = import_dataset(dataset_config)
|
||||
|
||||
assert sqla_table.extra == None
|
||||
|
||||
@@ -443,12 +444,12 @@ def test_import_column_allowed_data_url(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
dataset_uuid = uuid.uuid4()
|
||||
yaml_config: dict[str, Any] = {
|
||||
@@ -495,9 +496,8 @@ def test_import_column_allowed_data_url(
|
||||
schema = ImportV1DatasetSchema()
|
||||
dataset_config = schema.load(yaml_config)
|
||||
dataset_config["database_id"] = database.id
|
||||
_ = import_dataset(session, dataset_config, force_data=True)
|
||||
session.connection()
|
||||
assert [("value1",), ("value2",)] == session.execute(
|
||||
_ = import_dataset(dataset_config, force_data=True)
|
||||
assert [("value1",), ("value2",)] == db.session.execute(
|
||||
"SELECT * FROM my_table"
|
||||
).fetchall()
|
||||
|
||||
@@ -517,19 +517,19 @@ def test_import_dataset_managed_externally(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
session.add(database)
|
||||
session.flush()
|
||||
db.session.add(database)
|
||||
db.session.flush()
|
||||
|
||||
config = copy.deepcopy(dataset_config)
|
||||
config["is_managed_externally"] = True
|
||||
config["external_url"] = "https://example.org/my_table"
|
||||
config["database_id"] = database.id
|
||||
|
||||
sqla_table = import_dataset(session, config)
|
||||
sqla_table = import_dataset(config)
|
||||
assert sqla_table.is_managed_externally is True
|
||||
assert sqla_table.external_url == "https://example.org/my_table"
|
||||
|
||||
|
||||
@@ -29,15 +29,15 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.flush()
|
||||
yield session
|
||||
@@ -50,7 +50,6 @@ def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> N
|
||||
|
||||
result = DatasetDAO.find_by_id(
|
||||
1,
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
|
||||
@@ -67,7 +66,6 @@ def test_datasource_find_by_id_skip_base_filter_not_found(
|
||||
|
||||
result = DatasetDAO.find_by_id(
|
||||
125326326,
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
assert result is None
|
||||
@@ -79,7 +77,6 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) ->
|
||||
|
||||
result = DatasetDAO.find_by_ids(
|
||||
[1, 125326326],
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
|
||||
@@ -96,7 +93,6 @@ def test_datasource_find_by_ids_skip_base_filter_not_found(
|
||||
|
||||
result = DatasetDAO.find_by_ids(
|
||||
[125326326, 125326326125326326],
|
||||
session=session_with_data,
|
||||
skip_base_filter=True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user