Files
superset2/superset/datasets/dao.py
Will Barrett 8e23d4f369 chore: Upgrade pylint to 2.5.3 and fix most new rules (#10101)
* Bump pylint version to 2.5.3

* Add a global disable for the most common new pylint error

* Fix a bunch of files containing very few errors

* More pylint tweakage, low-hanging fruit

* More easy stuff...

* Fix more erroring files

* Fix the last couple of errors, clean pylint!

* Black

* Fix mypy issue in connectors/druid/models.py
2020-06-18 14:03:42 -07:00

189 lines
6.8 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.dao.base import BaseDAO
from superset.extensions import db
from superset.models.core import Database
from superset.views.base import DatasourceFilter
# Module-level logger, named after this module
logger = logging.getLogger(name=__name__)
class DatasetDAO(BaseDAO):
    """
    Data access object for datasets (``SqlaTable``): lookups, validation
    helpers, and the nested create/update of a dataset's columns
    (``TableColumn``) and metrics (``SqlMetric``).
    """

    model_cls = SqlaTable
    base_filter = DatasourceFilter

    @staticmethod
    def get_owner_by_id(owner_id: int) -> Optional[object]:
        """
        Fetches a user of the security manager's user model by id

        :param owner_id: The user id to look up
        :return: The user, or ``None`` if no user has that id
        """
        return (
            db.session.query(current_app.appbuilder.sm.user_model)
            .filter_by(id=owner_id)
            .one_or_none()
        )

    @staticmethod
    def get_database_by_id(database_id: int) -> Optional[Database]:
        """
        Fetches a ``Database`` by id

        :param database_id: The database id to look up
        :return: The database, or ``None`` if it does not exist or if the
            query raised ``SQLAlchemyError`` (the error is logged)
        """
        try:
            return db.session.query(Database).filter_by(id=database_id).one_or_none()
        except SQLAlchemyError as ex:  # pragma: no cover
            logger.error("Could not get database by id: %s", str(ex))
            return None

    @staticmethod
    def validate_table_exists(database: Database, table_name: str, schema: str) -> bool:
        """
        Checks that ``database.get_table`` can load the given table

        :param database: The database to inspect
        :param table_name: The physical table name
        :param schema: The schema passed through to ``get_table``
        :return: ``True`` if the table loaded, ``False`` if loading raised
            ``SQLAlchemyError`` (the error is logged); any other exception
            propagates to the caller
        """
        try:
            database.get_table(table_name, schema=schema)
            return True
        except SQLAlchemyError as ex:  # pragma: no cover
            logger.error("Got an error %s validating table: %s", str(ex), table_name)
            return False

    @staticmethod
    def validate_uniqueness(database_id: int, name: str) -> bool:
        """
        Checks that no dataset on the given database already uses ``name``
        as its table name

        :param database_id: The database id
        :param name: The candidate table name
        :return: ``True`` if the name is free on that database
        """
        dataset_query = db.session.query(SqlaTable).filter(
            SqlaTable.table_name == name, SqlaTable.database_id == database_id
        )
        return not db.session.query(dataset_query.exists()).scalar()

    @staticmethod
    def validate_update_uniqueness(
        database_id: int, dataset_id: int, name: str
    ) -> bool:
        """
        Checks that no *other* dataset on the given database uses ``name``
        as its table name (the dataset being updated is excluded)

        :param database_id: The database id
        :param dataset_id: The id of the dataset being updated
        :param name: The candidate table name
        :return: ``True`` if no other dataset on that database has the name
        """
        dataset_query = db.session.query(SqlaTable).filter(
            SqlaTable.table_name == name,
            SqlaTable.database_id == database_id,
            SqlaTable.id != dataset_id,
        )
        return not db.session.query(dataset_query.exists()).scalar()

    @staticmethod
    def validate_columns_exist(dataset_id: int, columns_ids: List[int]) -> bool:
        """
        Checks that every id in ``columns_ids`` is a column of the dataset

        Each matching row is returned only once, so duplicate ids in
        ``columns_ids`` always make this return ``False``.

        :param dataset_id: The dataset id
        :param columns_ids: The column ids to check
        :return: ``True`` if all ids belong to the dataset (trivially ``True``
            for an empty list)
        """
        if not columns_ids:
            # Nothing to check: skip the database round trip
            return True
        dataset_query = (
            db.session.query(TableColumn.id).filter(
                TableColumn.table_id == dataset_id, TableColumn.id.in_(columns_ids)
            )
        ).all()
        return len(columns_ids) == len(dataset_query)

    @staticmethod
    def validate_columns_uniqueness(dataset_id: int, columns_names: List[str]) -> bool:
        """
        Checks that none of ``columns_names`` is already a column name of the
        dataset

        :param dataset_id: The dataset id
        :param columns_names: The candidate column names
        :return: ``True`` if no name clashes (trivially ``True`` for an empty
            list)
        """
        if not columns_names:
            # Nothing to check: skip the database round trip
            return True
        dataset_query = db.session.query(TableColumn.id).filter(
            TableColumn.table_id == dataset_id,
            TableColumn.column_name.in_(columns_names),
        )
        # EXISTS instead of loading every clashing row, as in validate_uniqueness
        return not db.session.query(dataset_query.exists()).scalar()

    @staticmethod
    def validate_metrics_exist(dataset_id: int, metrics_ids: List[int]) -> bool:
        """
        Checks that every id in ``metrics_ids`` is a metric of the dataset

        Each matching row is returned only once, so duplicate ids in
        ``metrics_ids`` always make this return ``False``.

        :param dataset_id: The dataset id
        :param metrics_ids: The metric ids to check
        :return: ``True`` if all ids belong to the dataset (trivially ``True``
            for an empty list)
        """
        if not metrics_ids:
            # Nothing to check: skip the database round trip
            return True
        dataset_query = (
            db.session.query(SqlMetric.id).filter(
                SqlMetric.table_id == dataset_id, SqlMetric.id.in_(metrics_ids)
            )
        ).all()
        return len(metrics_ids) == len(dataset_query)

    @staticmethod
    def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bool:
        """
        Checks that none of ``metrics_names`` is already a metric name of the
        dataset

        :param dataset_id: The dataset id
        :param metrics_names: The candidate metric names
        :return: ``True`` if no name clashes (trivially ``True`` for an empty
            list)
        """
        if not metrics_names:
            # Nothing to check: skip the database round trip
            return True
        dataset_query = db.session.query(SqlMetric.id).filter(
            SqlMetric.table_id == dataset_id,
            SqlMetric.metric_name.in_(metrics_names),
        )
        # EXISTS instead of loading every clashing row, as in validate_uniqueness
        return not db.session.query(dataset_query.exists()).scalar()

    @classmethod
    def update(
        cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
    ) -> Optional[SqlaTable]:
        """
        Updates a Dataset model on the metadata DB

        When ``properties`` contains ``columns`` and/or ``metrics``, every
        entry with a truthy ``id`` updates the existing child row with that
        id, and every entry without one creates a new child row. The list in
        ``properties`` is then replaced, in place, with the resulting model
        objects before the dataset itself is updated.

        NOTE(review): every nested call receives the same ``commit`` flag, so
        with ``commit=True`` each child write may be committed on its own
        before the dataset update; confirm against ``BaseDAO`` whether a
        failure midway can leave partial changes behind.

        :param model: The dataset to update
        :param properties: Attributes to set; mutated in place as described
        :param commit: Passed through to every nested DAO call
        :return: The result of ``BaseDAO.update`` for the dataset
        """
        if "columns" in properties:
            new_columns = []
            for column in properties.get("columns", []):
                if column.get("id"):
                    column_obj = db.session.query(TableColumn).get(column.get("id"))
                    column_obj = cls.update_column(column_obj, column, commit=commit)
                else:
                    column_obj = cls.create_column(column, commit=commit)
                new_columns.append(column_obj)
            properties["columns"] = new_columns

        if "metrics" in properties:
            new_metrics = []
            for metric in properties.get("metrics", []):
                if metric.get("id"):
                    metric_obj = db.session.query(SqlMetric).get(metric.get("id"))
                    metric_obj = cls.update_metric(metric_obj, metric, commit=commit)
                else:
                    metric_obj = cls.create_metric(metric, commit=commit)
                new_metrics.append(metric_obj)
            properties["metrics"] = new_metrics

        return super().update(model, properties, commit=commit)

    @classmethod
    def update_column(
        cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True
    ) -> Optional[TableColumn]:
        """
        Updates a Dataset column model on the metadata DB
        """
        return DatasetColumnDAO.update(model, properties, commit=commit)

    @classmethod
    def create_column(
        cls, properties: Dict[str, Any], commit: bool = True
    ) -> Optional[TableColumn]:
        """
        Creates a Dataset column model on the metadata DB
        """
        return DatasetColumnDAO.create(properties, commit=commit)

    @classmethod
    def update_metric(
        cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True
    ) -> Optional[SqlMetric]:
        """
        Updates a Dataset metric model on the metadata DB
        """
        return DatasetMetricDAO.update(model, properties, commit=commit)

    @classmethod
    def create_metric(
        cls, properties: Dict[str, Any], commit: bool = True
    ) -> Optional[SqlMetric]:
        """
        Creates a Dataset metric model on the metadata DB
        """
        return DatasetMetricDAO.create(properties, commit=commit)
class DatasetColumnDAO(BaseDAO):
    """
    Data access object bound to the dataset column model (``TableColumn``)
    """

    model_cls = TableColumn
class DatasetMetricDAO(BaseDAO):
    """
    Data access object bound to the dataset metric model (``SqlMetric``)
    """

    model_cls = SqlMetric