refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)

This commit is contained in:
John Bodley
2024-02-14 06:20:15 +13:00
committed by GitHub
parent 827864b939
commit 847ed3f5b0
96 changed files with 656 additions and 730 deletions

View File

@@ -19,6 +19,8 @@ from typing import Any
from sqlalchemy.orm.session import Session
from superset import db
def test_put_invalid_dataset(
session: Session,
@@ -31,7 +33,7 @@ def test_put_invalid_dataset(
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
SqlaTable.metadata.create_all(session.get_bind())
SqlaTable.metadata.create_all(db.session.get_bind())
database = Database(
database_name="my_db",
@@ -41,8 +43,8 @@ def test_put_invalid_dataset(
table_name="test_put_invalid_dataset",
database=database,
)
session.add(dataset)
session.flush()
db.session.add(dataset)
db.session.flush()
response = client.put(
"/api/v1/dataset/1",

View File

@@ -20,6 +20,8 @@ import json
from sqlalchemy.orm.session import Session
from superset import db
def test_export(session: Session) -> None:
"""
@@ -29,12 +31,12 @@ def test_export(session: Session) -> None:
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.core import Database
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
columns = [
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),

View File

@@ -28,6 +28,7 @@ from flask import current_app
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.commands.dataset.exceptions import (
DatasetForbiddenDataURI,
ImportFailedError,
@@ -46,12 +47,12 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
config = {
@@ -108,7 +109,7 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
"database_id": database.id,
}
sqla_table = import_dataset(session, config)
sqla_table = import_dataset(config)
assert sqla_table.table_name == "my_table"
assert sqla_table.main_dttm_col == "ds"
assert sqla_table.description == "This is the description"
@@ -162,23 +163,23 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
dataset_uuid = uuid.uuid4()
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset = SqlaTable(
uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id
)
column = TableColumn(column_name="existing_column")
session.add(dataset)
session.add(column)
session.flush()
db.session.add(dataset)
db.session.add(column)
db.session.flush()
config = {
"table_name": dataset.table_name,
@@ -234,7 +235,7 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
"database_id": database.id,
}
sqla_table = import_dataset(session, config, overwrite=True)
sqla_table = import_dataset(config, overwrite=True)
assert sqla_table.table_name == dataset.table_name
assert sqla_table.main_dttm_col == "ds"
assert sqla_table.description == "This is the description"
@@ -288,12 +289,12 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -352,7 +353,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
sqla_table = import_dataset(session, dataset_config)
sqla_table = import_dataset(dataset_config)
assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
@@ -373,12 +374,12 @@ def test_import_dataset_extra_empty_string(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -417,7 +418,7 @@ def test_import_dataset_extra_empty_string(
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
sqla_table = import_dataset(session, dataset_config)
sqla_table = import_dataset(dataset_config)
assert sqla_table.extra == None
@@ -443,12 +444,12 @@ def test_import_column_allowed_data_url(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -495,9 +496,8 @@ def test_import_column_allowed_data_url(
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
_ = import_dataset(session, dataset_config, force_data=True)
session.connection()
assert [("value1",), ("value2",)] == session.execute(
_ = import_dataset(dataset_config, force_data=True)
assert [("value1",), ("value2",)] == db.session.execute(
"SELECT * FROM my_table"
).fetchall()
@@ -517,19 +517,19 @@ def test_import_dataset_managed_externally(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()
db.session.add(database)
db.session.flush()
config = copy.deepcopy(dataset_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_table"
config["database_id"] = database.id
sqla_table = import_dataset(session, config)
sqla_table = import_dataset(config)
assert sqla_table.is_managed_externally is True
assert sqla_table.external_url == "https://example.org/my_table"

View File

@@ -29,15 +29,15 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
session.add(db)
session.add(database)
session.add(sqla_table)
session.flush()
yield session
@@ -50,7 +50,6 @@ def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> N
result = DatasetDAO.find_by_id(
1,
session=session_with_data,
skip_base_filter=True,
)
@@ -67,7 +66,6 @@ def test_datasource_find_by_id_skip_base_filter_not_found(
result = DatasetDAO.find_by_id(
125326326,
session=session_with_data,
skip_base_filter=True,
)
assert result is None
@@ -79,7 +77,6 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) ->
result = DatasetDAO.find_by_ids(
[1, 125326326],
session=session_with_data,
skip_base_filter=True,
)
@@ -96,7 +93,6 @@ def test_datasource_find_by_ids_skip_base_filter_not_found(
result = DatasetDAO.find_by_ids(
[125326326, 125326326125326326],
session=session_with_data,
skip_base_filter=True,
)