mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
This commit is contained in:
@@ -28,6 +28,8 @@ from flask import current_app
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
def test_filter_by_uuid(
|
||||
session: Session,
|
||||
@@ -49,14 +51,14 @@ def test_filter_by_uuid(
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
session.add(
|
||||
db.session.add(
|
||||
Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/database/?q=(filters:!((col:uuid,opr:eq,value:"
|
||||
@@ -96,7 +98,7 @@ def test_post_with_uuid(
|
||||
payload = response.json
|
||||
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
|
||||
|
||||
database = session.query(Database).one()
|
||||
database = db.session.query(Database).one()
|
||||
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
|
||||
|
||||
|
||||
@@ -139,8 +141,8 @@ def test_password_mask(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -195,8 +197,8 @@ def test_database_connection(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -331,8 +333,8 @@ def test_update_with_password_mask(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
client.put(
|
||||
"/api/v1/database/1",
|
||||
@@ -347,7 +349,7 @@ def test_update_with_password_mask(
|
||||
),
|
||||
},
|
||||
)
|
||||
database = session.query(Database).one()
|
||||
database = db.session.query(Database).one()
|
||||
assert (
|
||||
database.encrypted_extra
|
||||
== '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}'
|
||||
@@ -429,8 +431,8 @@ def test_delete_ssh_tunnel(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -446,8 +448,8 @@ def test_delete_ssh_tunnel(
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
db.session.add(tunnel)
|
||||
db.session.commit()
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
@@ -505,8 +507,8 @@ def test_delete_ssh_tunnel_not_found(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
@@ -522,8 +524,8 @@ def test_delete_ssh_tunnel_not_found(
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
db.session.add(tunnel)
|
||||
db.session.commit()
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
|
||||
@@ -576,8 +578,8 @@ def test_apply_dynamic_database_filter(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# Create our Second Database
|
||||
database = Database(
|
||||
@@ -592,8 +594,8 @@ def test_apply_dynamic_database_filter(
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
|
||||
@@ -23,6 +23,7 @@ import pytest
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset import db
|
||||
from superset.commands.exceptions import ImportFailedError
|
||||
|
||||
|
||||
@@ -37,11 +38,11 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
database = import_database(session, config)
|
||||
database = import_database(config)
|
||||
assert database.database_name == "imported_database"
|
||||
assert database.sqlalchemy_uri == "someengine://user:pass@host1"
|
||||
assert database.cache_timeout is None
|
||||
@@ -60,9 +61,9 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
|
||||
# missing
|
||||
config = copy.deepcopy(database_config)
|
||||
del config["allow_dml"]
|
||||
session.delete(database)
|
||||
session.flush()
|
||||
database = import_database(session, config)
|
||||
db.session.delete(database)
|
||||
db.session.flush()
|
||||
database = import_database(config)
|
||||
assert database.allow_dml is False
|
||||
|
||||
|
||||
@@ -78,12 +79,12 @@ def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -
|
||||
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config_sqlite)
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
_ = import_database(session, config)
|
||||
_ = import_database(config)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
|
||||
@@ -106,14 +107,14 @@ def test_import_database_managed_externally(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
config["is_managed_externally"] = True
|
||||
config["external_url"] = "https://example.org/my_database"
|
||||
|
||||
database = import_database(session, config)
|
||||
database = import_database(config)
|
||||
assert database.is_managed_externally is True
|
||||
assert database.external_url == "https://example.org/my_database"
|
||||
|
||||
@@ -132,13 +133,13 @@ def test_import_database_without_permission(
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=False)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
|
||||
with pytest.raises(ImportFailedError) as excinfo:
|
||||
import_database(session, config)
|
||||
import_database(config)
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "Database doesn't exist and user doesn't have permission to create databases"
|
||||
@@ -156,10 +157,10 @@ def test_import_database_with_version(mocker: MockFixture, session: Session) ->
|
||||
|
||||
mocker.patch.object(security_manager, "can_access", return_value=True)
|
||||
|
||||
engine = session.get_bind()
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
config = copy.deepcopy(database_config)
|
||||
config["extra"]["version"] = "1.1.1"
|
||||
database = import_database(session, config)
|
||||
database = import_database(config)
|
||||
assert json.loads(database.extra)["version"] == "1.1.1"
|
||||
|
||||
@@ -30,19 +30,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
|
||||
@@ -27,17 +27,17 @@ def test_create_ssh_tunnel_command() -> None:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"database_id": database.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = CreateSSHTunnelCommand(db, properties).run()
|
||||
result = CreateSSHTunnelCommand(database, properties).run()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
@@ -48,19 +48,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
# If we are trying to create a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
properties = {
|
||||
"database": db,
|
||||
"database": database,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"private_key_password": "bar",
|
||||
}
|
||||
|
||||
command = CreateSSHTunnelCommand(db, properties)
|
||||
command = CreateSSHTunnelCommand(database, properties)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
|
||||
@@ -31,19 +31,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
|
||||
@@ -32,16 +32,18 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=database.id, database=database, server_address="Test"
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
|
||||
|
||||
session.add(db)
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
|
||||
@@ -25,11 +25,11 @@ def test_create_ssh_tunnel():
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
result = SSHTunnelDAO.create(
|
||||
attributes={
|
||||
"database_id": db.id,
|
||||
"database_id": database.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
|
||||
Reference in New Issue
Block a user