chore: Cleanup database sessions (#10427)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2020-07-30 23:07:56 -07:00
committed by GitHub
parent 7ff1757448
commit 7645fc85c3
39 changed files with 488 additions and 637 deletions

View File

@@ -22,10 +22,10 @@ from io import BytesIO
from typing import Any, Dict, Optional
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import Session
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.exceptions import DashboardImportException
from superset.extensions import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -71,7 +71,6 @@ def decode_dashboards( # pylint: disable=too-many-return-statements
def import_dashboards(
session: Session,
data_stream: BytesIO,
database_id: Optional[int] = None,
import_time: Optional[int] = None,
@@ -84,16 +83,16 @@ def import_dashboards(
raise DashboardImportException(_("No data in file"))
for table in data["datasources"]:
type(table).import_obj(table, database_id, import_time=import_time)
session.commit()
db.session.commit()
for dashboard in data["dashboards"]:
Dashboard.import_obj(dashboard, import_time=import_time)
session.commit()
db.session.commit()
def export_dashboards(session: Session) -> str:
def export_dashboards() -> str:
"""Returns all dashboards metadata as a json dump"""
logger.info("Starting export")
dashboards = session.query(Dashboard)
dashboards = db.session.query(Dashboard)
dashboard_ids = []
for dashboard in dashboards:
dashboard_ids.append(dashboard.id)

View File

@@ -17,9 +17,8 @@
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from superset.connectors.druid.models import DruidCluster
from superset.extensions import db
from superset.models.core import Database
DATABASES_KEY = "databases"
@@ -44,11 +43,11 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
def export_to_dict(
session: Session, recursive: bool, back_references: bool, include_defaults: bool
recursive: bool, back_references: bool, include_defaults: bool
) -> Dict[str, Any]:
"""Exports databases and druid clusters to a dictionary"""
logger.info("Starting export")
dbs = session.query(Database)
dbs = db.session.query(Database)
databases = [
database.export_to_dict(
recursive=recursive,
@@ -58,7 +57,7 @@ def export_to_dict(
for database in dbs
]
logger.info("Exported %d %s", len(databases), DATABASES_KEY)
cls = session.query(DruidCluster)
cls = db.session.query(DruidCluster)
clusters = [
cluster.export_to_dict(
recursive=recursive,
@@ -76,22 +75,20 @@ def export_to_dict(
return data
def import_from_dict(
session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
) -> None:
def import_from_dict(data: Dict[str, Any], sync: Optional[List[str]] = None) -> None:
"""Imports databases and druid clusters from dictionary"""
if not sync:
sync = []
if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):
Database.import_from_dict(session, database, sync=sync)
Database.import_from_dict(database, sync=sync)
logger.info(
"Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
)
for datasource in data.get(DRUID_CLUSTERS_KEY, []):
DruidCluster.import_from_dict(session, datasource, sync=sync)
session.commit()
DruidCluster.import_from_dict(datasource, sync=sync)
db.session.commit()
else:
logger.info("Supplied object is not a dictionary.")

View File

@@ -18,14 +18,14 @@ import logging
from typing import Callable, Optional
from flask_appbuilder import Model
from sqlalchemy.orm import Session
from sqlalchemy.orm.session import make_transient
from superset.extensions import db
logger = logging.getLogger(__name__)
def import_datasource( # pylint: disable=too-many-arguments
session: Session,
i_datasource: Model,
lookup_database: Callable[[Model], Model],
lookup_datasource: Callable[[Model], Model],
@@ -52,11 +52,11 @@ def import_datasource( # pylint: disable=too-many-arguments
if datasource:
datasource.override(i_datasource)
session.flush()
db.session.flush()
else:
datasource = i_datasource.copy()
session.add(datasource)
session.flush()
db.session.add(datasource)
db.session.flush()
for metric in i_datasource.metrics:
new_m = metric.copy()
@@ -81,13 +81,11 @@ def import_datasource( # pylint: disable=too-many-arguments
imported_c = i_datasource.column_class.import_obj(new_c)
if imported_c.column_name not in [c.column_name for c in datasource.columns]:
datasource.columns.append(imported_c)
session.flush()
db.session.flush()
return datasource.id
def import_simple_obj(
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
) -> Model:
def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None
@@ -97,9 +95,9 @@ def import_simple_obj(
i_obj.table = None
if existing_column:
existing_column.override(i_obj)
session.flush()
db.session.flush()
return existing_column
session.add(i_obj)
session.flush()
db.session.add(i_obj)
db.session.flush()
return i_obj