refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)

This commit is contained in:
John Bodley
2024-02-14 06:20:15 +13:00
committed by GitHub
parent 827864b939
commit 847ed3f5b0
96 changed files with 656 additions and 730 deletions

View File

@@ -28,6 +28,8 @@ from flask import current_app
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
def test_filter_by_uuid(
session: Session,
@@ -49,14 +51,14 @@ def test_filter_by_uuid(
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
session.add(
db.session.add(
Database(
database_name="my_db",
sqlalchemy_uri="sqlite://",
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
)
)
session.commit()
db.session.commit()
response = client.get(
"/api/v1/database/?q=(filters:!((col:uuid,opr:eq,value:"
@@ -96,7 +98,7 @@ def test_post_with_uuid(
payload = response.json
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
database = session.query(Database).one()
database = db.session.query(Database).one()
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
@@ -139,8 +141,8 @@ def test_password_mask(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -195,8 +197,8 @@ def test_database_connection(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -331,8 +333,8 @@ def test_update_with_password_mask(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
client.put(
"/api/v1/database/1",
@@ -347,7 +349,7 @@ def test_update_with_password_mask(
),
},
)
database = session.query(Database).one()
database = db.session.query(Database).one()
assert (
database.encrypted_extra
== '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}'
@@ -429,8 +431,8 @@ def test_delete_ssh_tunnel(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -446,8 +448,8 @@ def test_delete_ssh_tunnel(
database=database,
)
session.add(tunnel)
session.commit()
db.session.add(tunnel)
db.session.commit()
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
@@ -505,8 +507,8 @@ def test_delete_ssh_tunnel_not_found(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -522,8 +524,8 @@ def test_delete_ssh_tunnel_not_found(
database=database,
)
session.add(tunnel)
session.commit()
db.session.add(tunnel)
db.session.commit()
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
@@ -576,8 +578,8 @@ def test_apply_dynamic_database_filter(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# Create our Second Database
database = Database(
@@ -592,8 +594,8 @@ def test_apply_dynamic_database_filter(
}
),
)
session.add(database)
session.commit()
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")

View File

@@ -23,6 +23,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
from superset import db
from superset.commands.exceptions import ImportFailedError
@@ -37,11 +38,11 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
database = import_database(session, config)
database = import_database(config)
assert database.database_name == "imported_database"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"
assert database.cache_timeout is None
@@ -60,9 +61,9 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
# missing
config = copy.deepcopy(database_config)
del config["allow_dml"]
session.delete(database)
session.flush()
database = import_database(session, config)
db.session.delete(database)
db.session.flush()
database = import_database(config)
assert database.allow_dml is False
@@ -78,12 +79,12 @@ def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config_sqlite)
with pytest.raises(ImportFailedError) as excinfo:
_ = import_database(session, config)
_ = import_database(config)
assert (
str(excinfo.value)
== "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
@@ -106,14 +107,14 @@ def test_import_database_managed_externally(
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_database"
database = import_database(session, config)
database = import_database(config)
assert database.is_managed_externally is True
assert database.external_url == "https://example.org/my_database"
@@ -132,13 +133,13 @@ def test_import_database_without_permission(
mocker.patch.object(security_manager, "can_access", return_value=False)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
with pytest.raises(ImportFailedError) as excinfo:
import_database(session, config)
import_database(config)
assert (
str(excinfo.value)
== "Database doesn't exist and user doesn't have permission to create databases"
@@ -156,10 +157,10 @@ def test_import_database_with_version(mocker: MockFixture, session: Session) ->
mocker.patch.object(security_manager, "can_access", return_value=True)
engine = session.get_bind()
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
config["extra"]["version"] = "1.1.1"
database = import_database(session, config)
database = import_database(config)
assert json.loads(database.extra)["version"] == "1.1.1"

View File

@@ -30,19 +30,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
database_id=database.id,
database=database,
)
session.add(db)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()

View File

@@ -27,17 +27,17 @@ def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
"database_id": db.id,
"database_id": database.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = CreateSSHTunnelCommand(db, properties).run()
result = CreateSSHTunnelCommand(database, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
@@ -48,19 +48,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database": db,
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}
command = CreateSSHTunnelCommand(db, properties)
command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()

View File

@@ -31,19 +31,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
database_id=database.id,
database=database,
)
session.add(db)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()

View File

@@ -32,16 +32,18 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=database.id, database=database, server_address="Test"
)
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
session.add(db)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()

View File

@@ -25,11 +25,11 @@ def test_create_ssh_tunnel():
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
result = SSHTunnelDAO.create(
attributes={
"database_id": db.id,
"database_id": database.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",