From dfb3bf69a0d4388a04d23c3049ed33b5aedf524f Mon Sep 17 00:00:00 2001 From: serenajiang Date: Thu, 19 Sep 2019 16:51:01 -0700 Subject: [PATCH] [typing] add typing for superset/connectors and superset/common (#8138) --- superset/common/query_context.py | 45 +++-- superset/common/query_object.py | 43 ++-- superset/connectors/base/models.py | 81 ++++---- superset/connectors/connector_registry.py | 46 +++-- superset/connectors/druid/models.py | 227 ++++++++++++---------- superset/connectors/druid/views.py | 6 +- superset/connectors/sqla/models.py | 123 +++++++----- 7 files changed, 329 insertions(+), 242 deletions(-) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 6fe47a01980..c2534e9f64f 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -18,20 +18,22 @@ from datetime import datetime, timedelta import logging import pickle as pkl -from typing import Dict, List +from typing import Any, Dict, List, Optional import numpy as np import pandas as pd from superset import app, cache from superset import db +from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry +from superset.stats_logger import BaseStatsLogger from superset.utils import core as utils from superset.utils.core import DTTM_ALIAS from .query_object import QueryObject config = app.config -stats_logger = config.get("STATS_LOGGER") +stats_logger: BaseStatsLogger = config["STATS_LOGGER"] class QueryContext: @@ -40,8 +42,13 @@ class QueryContext: to retrieve the data payload for a given viz. """ - cache_type = "df" - enforce_numerical_metrics = True + cache_type: str = "df" + enforce_numerical_metrics: bool = True + + datasource: BaseDatasource + queries: List[QueryObject] + force: bool + custom_cache_timeout: Optional[int] # TODO: Type datasource and query_object dictionary with TypedDict when it becomes # a vanilla python type https://github.com/python/mypy/issues/5288 @@ -50,8 +57,8 @@ class QueryContext: datasource: Dict, queries: List[Dict], force: bool = False, - custom_cache_timeout: int = None, - ): + custom_cache_timeout: Optional[int] = None, + ) -> None: self.datasource = ConnectorRegistry.get_datasource( datasource.get("type"), int(datasource.get("id")), db.session # noqa: T400 ) @@ -61,9 +68,7 @@ class QueryContext: self.custom_cache_timeout = custom_cache_timeout - self.enforce_numerical_metrics = True - - def get_query_result(self, query_object): + def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: """Returns a pandas dataframe based on the query object""" # Here, we assume that all the queries will use the same datasource, which is @@ -109,23 +114,23 @@ class QueryContext: "df": df, } - def df_metrics_to_num(self, df, query_object): + def df_metrics_to_num(self, df: pd.DataFrame, query_object: QueryObject) -> None: """Converting metrics to numeric when pandas.read_sql cannot""" metrics = [metric for metric in query_object.metrics] for col, dtype in df.dtypes.items(): if dtype.type == np.object_ and col in metrics: df[col] = pd.to_numeric(df[col], errors="coerce") - def get_data(self, df): + def get_data(self, df: pd.DataFrame) -> List[Dict]: return df.to_dict(orient="records") - def get_single_payload(self, query_obj: QueryObject): + def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]: """Returns a payload of metadata and data""" payload = self.get_df_payload(query_obj) df = payload.get("df") status = payload.get("status") if status != utils.QueryStatus.FAILED: - if df is not None and df.empty: + if df is None or df.empty: payload["error"] = "No data" else: payload["data"] = self.get_data(df) @@ -133,12 +138,12 @@ class QueryContext: del payload["df"] return payload - def get_payload(self): + def get_payload(self) -> List[Dict[str, Any]]: """Get all the payloads from the arrays""" return [self.get_single_payload(query_object) for query_object in self.queries] @property - def cache_timeout(self): + def cache_timeout(self) -> int: if self.custom_cache_timeout is not None: return self.custom_cache_timeout if self.datasource.cache_timeout is not None: @@ -148,10 +153,10 @@ class QueryContext: and self.datasource.database.cache_timeout ) is not None: return self.datasource.database.cache_timeout - return config.get("CACHE_DEFAULT_TIMEOUT") + return config["CACHE_DEFAULT_TIMEOUT"] - def get_df_payload(self, query_obj: QueryObject, **kwargs): - """Handles caching around the df paylod retrieval""" + def get_df_payload(self, query_obj: QueryObject, **kwargs) -> Dict[str, Any]: + """Handles caching around the df payload retrieval""" extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict()) cache_key = ( query_obj.cache_key( @@ -207,9 +212,7 @@ class QueryContext: if is_loaded and cache_key and cache and status != utils.QueryStatus.FAILED: try: - cache_value = dict( - dttm=cached_dttm, df=df if df is not None else None, query=query - ) + cache_value = dict(dttm=cached_dttm, df=df, query=query) cache_binary = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL) logging.info( diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 28f2303c25d..71b690358a4 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=R +from datetime import datetime, timedelta import hashlib -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import simplejson as json @@ -34,12 +35,28 @@ class QueryObject: and druid. The query objects are constructed on the client. """ + granularity: str + from_dttm: datetime + to_dttm: datetime + is_timeseries: bool + time_shift: timedelta + groupby: List[str] + metrics: List[Union[Dict, str]] + row_limit: int + filter: List[str] + timeseries_limit: int + timeseries_limit_metric: Optional[Dict] + order_desc: bool + extras: Dict + columns: List[str] + orderby: List[List] + def __init__( self, granularity: str, metrics: List[Union[Dict, str]], - groupby: List[str] = None, - filters: List[str] = None, + groupby: Optional[List[str]] = None, + filters: Optional[List[str]] = None, time_range: Optional[str] = None, time_shift: Optional[str] = None, is_timeseries: bool = False, @@ -48,8 +65,8 @@ class QueryObject: timeseries_limit_metric: Optional[Dict] = None, order_desc: bool = True, extras: Optional[Dict] = None, - columns: List[str] = None, - orderby: List[List] = None, + columns: Optional[List[str]] = None, + orderby: Optional[List[List]] = None, relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"), relative_end: str = app.config.get("DEFAULT_RELATIVE_END_TIME", "today"), ): @@ -63,7 +80,7 @@ class QueryObject: self.is_timeseries = is_timeseries self.time_range = time_range self.time_shift = utils.parse_human_timedelta(time_shift) - self.groupby = groupby if groupby is not None else [] + self.groupby = groupby or [] # Temporal solution for backward compatability issue # due the new format of non-ad-hoc metric. @@ -72,15 +89,15 @@ class QueryObject: for metric in metrics ] self.row_limit = row_limit - self.filter = filters if filters is not None else [] + self.filter = filters or [] self.timeseries_limit = timeseries_limit self.timeseries_limit_metric = timeseries_limit_metric self.order_desc = order_desc - self.extras = extras if extras is not None else {} - self.columns = columns if columns is not None else [] - self.orderby = orderby if orderby is not None else [] + self.extras = extras or {} + self.columns = columns or [] + self.orderby = orderby or [] - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: query_object_dict = { "granularity": self.granularity, "from_dttm": self.from_dttm, @@ -99,7 +116,7 @@ class QueryObject: } return query_object_dict - def cache_key(self, **extra): + def cache_key(self, **extra) -> str: """ The cache key is made out of the key/values from to_dict(), plus any other key/values in `extra` @@ -117,7 +134,7 @@ class QueryObject: json_data = self.json_dumps(cache_dict, sort_keys=True) return hashlib.md5(json_data.encode("utf-8")).hexdigest() - def json_dumps(self, obj, sort_keys=False): + def json_dumps(self, obj: Any, sort_keys: bool = False) -> str: return json.dumps( obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys ) diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index d82f4d6aa02..d84910c6e56 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -16,14 +16,15 @@ # under the License. # pylint: disable=C,R,W import json -from typing import Any, List +from typing import Any, Dict, List, Optional +from flask_appbuilder.security.sqla.models import User from sqlalchemy import and_, Boolean, Column, Integer, String, Text from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import foreign, relationship +from sqlalchemy.orm import foreign, Query, relationship from superset.models.core import Slice -from superset.models.helpers import AuditMixinNullable, ImportMixin +from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult from superset.utils import core as utils @@ -59,9 +60,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): params = Column(String(1000)) perm = Column(String(1000)) - sql = None - owners = None - update_from_object_fields = None + sql: Optional[str] = None + owners: List[User] + update_from_object_fields: List[str] @declared_attr def slices(self): @@ -79,20 +80,20 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): metrics: List[Any] = [] @property - def uid(self): + def uid(self) -> str: """Unique id across datasource types""" return f"{self.id}__{self.type}" @property - def column_names(self): + def column_names(self) -> List[str]: return sorted([c.column_name for c in self.columns], key=lambda x: x or "") @property - def columns_types(self): + def columns_types(self) -> Dict: return {c.column_name: c.type for c in self.columns} @property - def main_dttm_col(self): + def main_dttm_col(self) -> str: return "timestamp" @property @@ -100,47 +101,47 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): raise NotImplementedError() @property - def connection(self): + def connection(self) -> Optional[str]: """String representing the context of the Datasource""" return None @property - def schema(self): + def schema(self) -> Optional[str]: """String representing the schema of the Datasource (if it applies)""" return None @property - def filterable_column_names(self): + def filterable_column_names(self) -> List[str]: return sorted([c.column_name for c in self.columns if c.filterable]) @property - def dttm_cols(self): + def dttm_cols(self) -> List: return [] @property - def url(self): + def url(self) -> str: return "/{}/edit/{}".format(self.baselink, self.id) @property - def explore_url(self): + def explore_url(self) -> str: if self.default_endpoint: return self.default_endpoint else: return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self) @property - def column_formats(self): + def column_formats(self) -> Dict[str, Optional[str]]: return {m.metric_name: m.d3format for m in self.metrics if m.d3format} - def add_missing_metrics(self, metrics): - exisiting_metrics = {m.metric_name for m in self.metrics} + def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None: + existing_metrics = {m.metric_name for m in self.metrics} for metric in metrics: - if metric.metric_name not in exisiting_metrics: + if metric.metric_name not in existing_metrics: metric.table_id = self.id - self.metrics += [metric] + self.metrics.append(metric) @property - def short_data(self): + def short_data(self) -> Dict[str, Any]: """Data representation of the datasource sent to the frontend""" return { "edit_url": self.url, @@ -158,7 +159,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): pass @property - def data(self): + def data(self) -> Dict[str, Any]: """Data representation of the datasource sent to the frontend""" order_by_choices = [] # self.column_names return sorted column_names @@ -239,14 +240,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): """Returns column information from the external system""" raise NotImplementedError() - def get_query_str(self, query_obj): + def get_query_str(self, query_obj) -> str: """Returns a query as a string This is used to be displayed to the user so that she/he can understand what is taking place behind the scene""" raise NotImplementedError() - def query(self, query_obj): + def query(self, query_obj) -> QueryResult: """Executes the query and returns a dataframe query_obj is a dictionary representing Superset's query interface. @@ -254,7 +255,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): """ raise NotImplementedError() - def values_for_column(self, column_name, limit=10000): + def values_for_column(self, column_name: str, limit: int = 10000) -> List: """Given a column, returns an iterable of distinct values This is used to populate the dropdown showing a list of @@ -262,13 +263,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): raise NotImplementedError() @staticmethod - def default_query(qry): + def default_query(qry) -> Query: return qry - def get_column(self, column_name): + def get_column(self, column_name: str) -> Optional["BaseColumn"]: for col in self.columns: if col.column_name == column_name: return col + return None def get_fk_many_from_list(self, object_list, fkmany, fkmany_class, key_attr): """Update ORM one-to-many list from object list @@ -276,10 +278,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): Used for syncing metrics and columns using the same code""" object_dict = {o.get(key_attr): o for o in object_list} - object_keys = [o.get(key_attr) for o in object_list] # delete fks that have been removed - fkmany = [o for o in fkmany if getattr(o, key_attr) in object_keys] + fkmany = [o for o in fkmany if getattr(o, key_attr) in object_dict] # sync existing fks for fk in fkmany: @@ -303,7 +304,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): fkmany += new_fks return fkmany - def update_from_object(self, obj): + def update_from_object(self, obj) -> None: """Update datasource from a data structure The UI's table editor crafts a complex data structure that @@ -330,7 +331,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): obj.get("columns"), self.columns, self.column_class, "column_name" ) - def get_extra_cache_keys(self, query_obj) -> List[Any]: + def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]: """ If a datasource needs to provide additional keys for calculation of cache keys, those can be provided via this method """ @@ -374,23 +375,23 @@ class BaseColumn(AuditMixinNullable, ImportMixin): str_types = ("VARCHAR", "STRING", "CHAR") @property - def is_num(self): - return self.type and any([t in self.type.upper() for t in self.num_types]) + def is_num(self) -> bool: + return self.type and any(map(lambda t: t in self.type.upper(), self.num_types)) @property - def is_time(self): - return self.type and any([t in self.type.upper() for t in self.date_types]) + def is_time(self) -> bool: + return self.type and any(map(lambda t: t in self.type.upper(), self.date_types)) @property - def is_string(self): - return self.type and any([t in self.type.upper() for t in self.str_types]) + def is_string(self) -> bool: + return self.type and any(map(lambda t: t in self.type.upper(), self.str_types)) @property def expression(self): raise NotImplementedError() @property - def data(self): + def data(self) -> Dict[str, Any]: attrs = ( "id", "column_name", @@ -443,7 +444,7 @@ class BaseMetric(AuditMixinNullable, ImportMixin): raise NotImplementedError() @property - def data(self): + def data(self) -> Dict[str, Any]: attrs = ( "id", "metric_name", diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py index d5e951aabab..9ce11802ff0 100644 --- a/superset/connectors/connector_registry.py +++ b/superset/connectors/connector_registry.py @@ -15,16 +15,23 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=C,R,W -from sqlalchemy.orm import subqueryload +from collections import OrderedDict +from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING + +from sqlalchemy.orm import Session, subqueryload + +if TYPE_CHECKING: + from superset.models.core import Database + from superset.connectors.base.models import BaseDatasource class ConnectorRegistry(object): """ Central Registry for all available datasource engines""" - sources = {} + sources: Dict[str, Type["BaseDatasource"]] = {} @classmethod - def register_sources(cls, datasource_config): + def register_sources(cls, datasource_config: OrderedDict) -> None: for module_name, class_names in datasource_config.items(): class_names = [str(s) for s in class_names] module_obj = __import__(module_name, fromlist=class_names) @@ -33,7 +40,9 @@ class ConnectorRegistry(object): cls.sources[source_class.type] = source_class @classmethod - def get_datasource(cls, datasource_type, datasource_id, session): + def get_datasource( + cls, datasource_type: str, datasource_id: int, session: Session + ) -> Optional["BaseDatasource"]: return ( session.query(cls.sources[datasource_type]) .filter_by(id=datasource_id) @@ -41,8 +50,8 @@ class ConnectorRegistry(object): ) @classmethod - def get_all_datasources(cls, session): - datasources = [] + def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]: + datasources: List["BaseDatasource"] = [] for source_type in ConnectorRegistry.sources: source_class = ConnectorRegistry.sources[source_type] qry = session.query(source_class) @@ -52,15 +61,22 @@ class ConnectorRegistry(object): @classmethod def get_datasource_by_name( - cls, session, datasource_type, datasource_name, schema, database_name - ): + cls, + session: Session, + datasource_type: str, + datasource_name: str, + schema: str, + database_name: str, + ) -> Optional["BaseDatasource"]: datasource_class = ConnectorRegistry.sources[datasource_type] return datasource_class.get_datasource_by_name( session, datasource_name, schema, database_name ) @classmethod - def query_datasources_by_permissions(cls, session, database, permissions): + def query_datasources_by_permissions( + cls, session: Session, database: "Database", permissions: Set[str] + ) -> List["BaseDatasource"]: datasource_class = ConnectorRegistry.sources[database.type] return ( session.query(datasource_class) @@ -70,7 +86,9 @@ class ConnectorRegistry(object): ) @classmethod - def get_eager_datasource(cls, session, datasource_type, datasource_id): + def get_eager_datasource( + cls, session: Session, datasource_type: str, datasource_id: int + ) -> "BaseDatasource": """Returns datasource with columns and metrics.""" datasource_class = ConnectorRegistry.sources[datasource_type] return ( @@ -84,7 +102,13 @@ class ConnectorRegistry(object): ) @classmethod - def query_datasources_by_name(cls, session, database, datasource_name, schema=None): + def query_datasources_by_name( + cls, + session: Session, + database: "Database", + datasource_name: str, + schema: Optional[str] = None, + ) -> List["BaseDatasource"]: datasource_class = ConnectorRegistry.sources[database.type] return datasource_class.query_datasources_by_name( session, database, datasource_name, schema=None diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index b908056aaed..581ec7db596 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -24,13 +24,15 @@ import json import logging from multiprocessing.pool import ThreadPool import re +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from dateutil.parser import parse as dparse from flask import escape, Markup from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders +from flask_appbuilder.security.sqla.models import User from flask_babel import lazy_gettext as _ -import pandas +import pandas as pd try: from pydruid.client import PyDruid @@ -41,7 +43,7 @@ try: RegisteredLookupExtraction, ) from pydruid.utils.filters import Dimension, Filter - from pydruid.utils.having import Aggregation + from pydruid.utils.having import Aggregation, Having from pydruid.utils.postaggregator import ( Const, Field, @@ -65,12 +67,13 @@ from sqlalchemy import ( Text, UniqueConstraint, ) -from sqlalchemy.orm import backref, relationship +from sqlalchemy.orm import backref, relationship, RelationshipProperty, Session from sqlalchemy_utils import EncryptedType from superset import conf, db, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.exceptions import SupersetException +from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult from superset.utils import core as utils, import_datasource @@ -78,6 +81,8 @@ try: from superset.utils.core import DimSelector, DTTM_ALIAS, flasher except ImportError: pass + + DRUID_TZ = conf.get("DRUID_TZ") POST_AGG_TYPE = "postagg" metadata = Model.metadata # pylint: disable=no-member @@ -150,22 +155,22 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): return self.__repr__() @property - def data(self): + def data(self) -> Dict: return {"id": self.id, "name": self.cluster_name, "backend": "druid"} @staticmethod - def get_base_url(host, port): + def get_base_url(host, port) -> str: if not re.match("http(s)?://", host): host = "http://" + host url = "{0}:{1}".format(host, port) if port else host return url - def get_base_broker_url(self): + def get_base_broker_url(self) -> str: base_url = self.get_base_url(self.broker_host, self.broker_port) return f"{base_url}/{self.broker_endpoint}" - def get_pydruid_client(self): + def get_pydruid_client(self) -> PyDruid: cli = PyDruid( self.get_base_url(self.broker_host, self.broker_port), self.broker_endpoint ) @@ -173,39 +178,44 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): cli.set_basic_auth_credentials(self.broker_user, self.broker_pass) return cli - def get_datasources(self): + def get_datasources(self) -> List[str]: endpoint = self.get_base_broker_url() + "/datasources" auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass) return json.loads(requests.get(endpoint, auth=auth).text) - def get_druid_version(self): + def get_druid_version(self) -> str: endpoint = self.get_base_url(self.broker_host, self.broker_port) + "/status" auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass) return json.loads(requests.get(endpoint, auth=auth).text)["version"] - @property + @property # noqa: T484 @utils.memoized - def druid_version(self): + def druid_version(self) -> str: return self.get_druid_version() def refresh_datasources( - self, datasource_name=None, merge_flag=True, refreshAll=True - ): + self, + datasource_name: Optional[str] = None, + merge_flag: bool = True, + refresh_all: bool = True, + ) -> None: """Refresh metadata of all datasources in the cluster If ``datasource_name`` is specified, only that datasource is updated """ ds_list = self.get_datasources() blacklist = conf.get("DRUID_DATA_SOURCE_BLACKLIST", []) - ds_refresh = [] + ds_refresh: List[str] = [] if not datasource_name: ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list)) elif datasource_name not in blacklist and datasource_name in ds_list: ds_refresh.append(datasource_name) else: return - self.refresh(ds_refresh, merge_flag, refreshAll) + self.refresh(ds_refresh, merge_flag, refresh_all) - def refresh(self, datasource_names, merge_flag, refreshAll): + def refresh( + self, datasource_names: List[str], merge_flag: bool, refresh_all: bool + ) -> None: """ Fetches metadata for the specified datasources and merges to the Superset database @@ -225,7 +235,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): session.add(datasource) flasher(_("Adding new datasource [{}]").format(ds_name), "success") ds_map[ds_name] = datasource - elif refreshAll: + elif refresh_all: flasher(_("Refreshing datasource [{}]").format(ds_name), "info") else: del ds_map[ds_name] @@ -270,19 +280,19 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): session.commit() @property - def perm(self): + def perm(self) -> str: return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) - def get_perm(self): + def get_perm(self) -> str: return self.perm @property - def name(self): - return self.verbose_name if self.verbose_name else self.cluster_name + def name(self) -> str: + return self.verbose_name or self.cluster_name @property - def unique_name(self): - return self.verbose_name if self.verbose_name else self.cluster_name + def unique_name(self) -> str: + return self.verbose_name or self.cluster_name class DruidColumn(Model, BaseColumn): @@ -318,25 +328,26 @@ class DruidColumn(Model, BaseColumn): return self.column_name or str(self.id) @property - def expression(self): + def expression(self) -> str: return self.dimension_spec_json @property - def dimension_spec(self): + def dimension_spec(self) -> Optional[Dict]: # noqa: T484 if self.dimension_spec_json: return json.loads(self.dimension_spec_json) - def get_metrics(self): - metrics = {} - metrics["count"] = DruidMetric( - metric_name="count", - verbose_name="COUNT(*)", - metric_type="count", - json=json.dumps({"type": "count", "name": "count"}), - ) + def get_metrics(self) -> Dict[str, "DruidMetric"]: + metrics = { + "count": DruidMetric( + metric_name="count", + verbose_name="COUNT(*)", + metric_type="count", + json=json.dumps({"type": "count", "name": "count"}), + ) + } return metrics - def refresh_metrics(self): + def refresh_metrics(self) -> None: """Refresh metrics based on the column metadata""" metrics = self.get_metrics() dbmetrics = ( @@ -356,8 +367,8 @@ class DruidColumn(Model, BaseColumn): db.session.add(metric) @classmethod - def import_obj(cls, i_column): - def lookup_obj(lookup_column): + def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn": + def lookup_obj(lookup_column: "DruidColumn") -> Optional["DruidColumn"]: return ( db.session.query(DruidColumn) .filter( @@ -404,7 +415,7 @@ class DruidMetric(Model, BaseMetric): return self.json @property - def json_obj(self): + def json_obj(self) -> Dict: try: obj = json.loads(self.json) except Exception: @@ -412,7 +423,7 @@ class DruidMetric(Model, BaseMetric): return obj @property - def perm(self): + def perm(self) -> Optional[str]: return ( ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format( obj=self, parent_name=self.datasource.full_name @@ -421,12 +432,12 @@ class DruidMetric(Model, BaseMetric): else None ) - def get_perm(self): + def get_perm(self) -> Optional[str]: return self.perm @classmethod - def import_obj(cls, i_metric): - def lookup_obj(lookup_metric): + def import_obj(cls, i_metric: "DruidMetric") -> "DruidMetric": + def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]: return ( db.session.query(DruidMetric) .filter( @@ -494,23 +505,23 @@ class DruidDatasource(Model, BaseDatasource): export_children = ["columns", "metrics"] @property - def database(self): + def database(self) -> RelationshipProperty: return self.cluster @property - def connection(self): + def connection(self) -> str: return str(self.database) @property - def num_cols(self): + def num_cols(self) -> List[str]: return [c.column_name for c in self.columns if c.is_num] @property - def name(self): + def name(self) -> str: return self.datasource_name @property - def schema(self): + def schema(self) -> Optional[str]: ds_name = self.datasource_name or "" name_pieces = ds_name.split(".") if len(name_pieces) > 1: @@ -519,11 +530,11 @@ class DruidDatasource(Model, BaseDatasource): return None @property - def schema_perm(self): + def schema_perm(self) -> Optional[str]: """Returns schema permission if present, cluster one otherwise.""" return security_manager.get_schema_perm(self.cluster, self.schema) - def get_perm(self): + def get_perm(self) -> str: return ("[{obj.cluster_name}].[{obj.datasource_name}]" "(id:{obj.id})").format( obj=self ) @@ -532,16 +543,16 @@ class DruidDatasource(Model, BaseDatasource): return NotImplementedError() @property - def link(self): + def link(self) -> Markup: name = escape(self.datasource_name) return Markup(f'{name}') @property - def full_name(self): + def full_name(self) -> str: return utils.get_datasource_full_name(self.cluster_name, self.datasource_name) @property - def time_column_grains(self): + def time_column_grains(self) -> Dict[str, List[str]]: return { "time_columns": [ "all", @@ -568,16 +579,18 @@ class DruidDatasource(Model, BaseDatasource): return self.datasource_name @renders("datasource_name") - def datasource_link(self): + def datasource_link(self) -> str: url = f"/superset/explore/{self.type}/{self.id}/" name = escape(self.datasource_name) return Markup(f'{name}') - def get_metric_obj(self, metric_name): + def get_metric_obj(self, metric_name: str) -> Dict: return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0] @classmethod - def import_obj(cls, i_datasource, import_time=None): + def import_obj( + cls, i_datasource: "DruidDatasource", import_time: Optional[int] = None + ) -> int: """Imports the datasource from the object to the database. Metrics and columns and datasource will be overridden if exists. @@ -585,7 +598,7 @@ class DruidDatasource(Model, BaseDatasource): superset instances. Audit metadata isn't copies over. """ - def lookup_datasource(d): + def lookup_datasource(d: DruidDatasource) -> Optional[DruidDatasource]: return ( db.session.query(DruidDatasource) .filter( @@ -595,7 +608,7 @@ class DruidDatasource(Model, BaseDatasource): .first() ) - def lookup_cluster(d): + def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]: return ( db.session.query(DruidCluster) .filter_by(cluster_name=d.cluster_name) @@ -659,12 +672,14 @@ class DruidDatasource(Model, BaseDatasource): if segment_metadata: return segment_metadata[-1]["columns"] - def refresh_metrics(self): + def refresh_metrics(self) -> None: for col in self.columns: col.refresh_metrics() @classmethod - def sync_to_db_from_config(cls, druid_config, user, cluster, refresh=True): + def sync_to_db_from_config( + cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True + ) -> None: """Merges the ds config from druid_config into one stored in the db.""" session = db.session datasource = ( @@ -742,13 +757,15 @@ class DruidDatasource(Model, BaseDatasource): session.commit() @staticmethod - def time_offset(granularity): + def time_offset(granularity: Union[str, Dict]) -> int: if granularity == "week_ending_saturday": return 6 * 24 * 3600 * 1000 # 6 days return 0 @classmethod - def get_datasource_by_name(cls, session, datasource_name, schema, database_name): + def get_datasource_by_name( + cls, session: Session, datasource_name: str, schema: str, database_name: str + ) -> Optional["DruidDatasource"]: query = ( session.query(cls) .join(DruidCluster) @@ -761,7 +778,9 @@ class DruidDatasource(Model, BaseDatasource): # http://druid.io/docs/0.8.0/querying/granularities.html # TODO: pass origin from the UI @staticmethod - def granularity(period_name, timezone=None, origin=None): + def granularity( + period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None + ) -> Union[str, Dict]: if not period_name or period_name == "all": return "all" iso_8601_dict = { @@ -810,7 +829,7 @@ class DruidDatasource(Model, BaseDatasource): return granularity @staticmethod - def get_post_agg(mconf): + def get_post_agg(mconf: Dict) -> Postaggregator: """ For a metric specified as `postagg` returns the kind of post aggregation for pydruid. @@ -839,7 +858,9 @@ class DruidDatasource(Model, BaseDatasource): return CustomPostAggregator(mconf.get("name", ""), mconf) @staticmethod - def find_postaggs_for(postagg_names, metrics_dict): + def find_postaggs_for( + postagg_names: Set[str], metrics_dict: Dict[str, DruidMetric] + ) -> List[DruidMetric]: """Return a list of metrics that are post aggregations""" postagg_metrics = [ metrics_dict[name] @@ -852,7 +873,7 @@ class DruidDatasource(Model, BaseDatasource): return postagg_metrics @staticmethod - def recursive_get_fields(_conf): + def recursive_get_fields(_conf: Dict) -> List[str]: _type = _conf.get("type") _field = _conf.get("field") _fields = _conf.get("fields") @@ -875,11 +896,9 @@ class DruidDatasource(Model, BaseDatasource): # Check if the fields are already in aggs # or is a previous postagg required_fields = set( - [ - field - for field in required_fields - if field not in visited_postaggs and field not in agg_names - ] + field + for field in required_fields + if field not in visited_postaggs and field not in agg_names ) # First try to find postaggs that match if len(required_fields) > 0: @@ -903,7 +922,11 @@ class DruidDatasource(Model, BaseDatasource): post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj) @staticmethod - def metrics_and_post_aggs(metrics, metrics_dict, druid_version=None): + def metrics_and_post_aggs( + metrics: List[Union[Dict, str]], + metrics_dict: Dict[str, DruidMetric], + druid_version=None, + ) -> Tuple[OrderedDict, OrderedDict]: # noqa: T484 # Separate metrics into those that are aggregations # and those that are post aggregations saved_agg_names = set() @@ -912,26 +935,26 @@ class DruidDatasource(Model, BaseDatasource): for metric in metrics: if utils.is_adhoc_metric(metric): adhoc_agg_configs.append(metric) - elif metrics_dict[metric].metric_type != POST_AGG_TYPE: + elif metrics_dict[metric].metric_type != POST_AGG_TYPE: # noqa: T484 saved_agg_names.add(metric) else: postagg_names.append(metric) # Create the post aggregations, maintain order since postaggs # may depend on previous ones - post_aggs = OrderedDict() + post_aggs = OrderedDict() # noqa: T484 visited_postaggs = set() for postagg_name in postagg_names: - postagg = metrics_dict[postagg_name] + postagg = metrics_dict[postagg_name] # noqa: T484 visited_postaggs.add(postagg_name) DruidDatasource.resolve_postagg( postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict ) - aggs = DruidDatasource.get_aggregations( + aggs = DruidDatasource.get_aggregations( # noqa: T484 metrics_dict, saved_agg_names, adhoc_agg_configs ) return aggs, post_aggs - def values_for_column(self, column_name, limit=10000): + def values_for_column(self, column_name: str, limit: int = 10000) -> List: """Retrieve some values for the given column""" logging.info( "Getting values for columns [{}] limited to [{}]".format(column_name, limit) @@ -955,12 +978,14 @@ class DruidDatasource(Model, BaseDatasource): client = self.cluster.get_pydruid_client() client.topn(**qry) df = client.export_pandas() - return [row[column_name] for row in df.to_records(index=False)] + return df[column_name].to_list() def get_query_str(self, query_obj, phase=1, client=None): return self.run_query(client=client, phase=phase, **query_obj) - def _add_filter_from_pre_query_data(self, df, dimensions, dim_filter): + def _add_filter_from_pre_query_data( + self, df: Optional[pd.DataFrame], dimensions, dim_filter + ): ret = dim_filter if df is not None and not df.empty: new_filters = [] @@ -1002,7 +1027,7 @@ class DruidDatasource(Model, BaseDatasource): return ret @staticmethod - def druid_type_from_adhoc_metric(adhoc_metric): + def druid_type_from_adhoc_metric(adhoc_metric: Dict) -> str: column_type = adhoc_metric["column"]["type"].lower() aggregate = adhoc_metric["aggregate"].lower() @@ -1014,7 +1039,9 @@ class DruidDatasource(Model, BaseDatasource): return column_type + aggregate.capitalize() @staticmethod - def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]): + def get_aggregations( + metrics_dict: Dict, saved_metrics: Iterable[str], adhoc_metrics: List[Dict] = [] + ) -> OrderedDict: """ Returns a dictionary of aggregation metric names to aggregation json objects @@ -1023,7 +1050,7 @@ class DruidDatasource(Model, BaseDatasource): :param adhoc_metrics: list of adhoc metric names :raise SupersetException: if one or more metric names are not aggregations """ - aggregations = OrderedDict() + aggregations: OrderedDict = OrderedDict() invalid_metric_names = [] for metric_name in saved_metrics: if metric_name in metrics_dict: @@ -1047,19 +1074,18 @@ class DruidDatasource(Model, BaseDatasource): } return aggregations - def get_dimensions(self, groupby, columns_dict): + def get_dimensions( + self, groupby: List[str], columns_dict: Dict[str, DruidColumn] + ) -> List[Union[str, Dict]]: dimensions = [] groupby = [gb for gb in groupby if gb in columns_dict] for column_name in groupby: col = columns_dict.get(column_name) dim_spec = col.dimension_spec if col else None - if dim_spec: - dimensions.append(dim_spec) - else: - dimensions.append(column_name) + dimensions.append(dim_spec or column_name) return dimensions - def intervals_from_dttms(self, from_dttm, to_dttm): + def intervals_from_dttms(self, from_dttm: datetime, to_dttm: datetime) -> str: # Couldn't find a way to just not filter on time... from_dttm = from_dttm or datetime(1901, 1, 1) to_dttm = to_dttm or datetime(2101, 1, 1) @@ -1091,7 +1117,7 @@ class DruidDatasource(Model, BaseDatasource): return values @staticmethod - def sanitize_metric_object(metric): + def sanitize_metric_object(metric: Dict) -> None: """ Update a metric with the correct type if necessary. :param dict metric: The metric to sanitize @@ -1122,7 +1148,7 @@ class DruidDatasource(Model, BaseDatasource): phase=2, client=None, order_desc=True, - ): + ) -> str: """Runs a query against Druid and returns a dataframe. """ # TODO refactor into using a TBD Query object @@ -1193,7 +1219,7 @@ class DruidDatasource(Model, BaseDatasource): del qry["dimensions"] client.timeseries(**qry) elif not having_filters and len(groupby) == 1 and order_desc: - dim = list(qry.get("dimensions"))[0] + dim = list(qry.get("dimensions"))[0] # noqa: T484 logging.info("Running two-phase topn query for dimension [{}]".format(dim)) pre_qry = deepcopy(qry) if timeseries_limit_metric: @@ -1324,7 +1350,7 @@ class DruidDatasource(Model, BaseDatasource): return query_str @staticmethod - def homogenize_types(df, groupby_cols): + def homogenize_types(df: pd.DataFrame, groupby_cols: Iterable[str]) -> pd.DataFrame: """Converting all GROUPBY columns to strings When grouping by a numeric (say FLOAT) column, pydruid returns @@ -1334,11 +1360,10 @@ class DruidDatasource(Model, BaseDatasource): Here we replace None with and make the whole series a str instead of an object. """ - for col in groupby_cols: - df[col] = df[col].fillna("").astype("unicode") + df[groupby_cols] = df[groupby_cols].fillna("").astype("unicode") return df - def query(self, query_obj): + def query(self, query_obj: Dict) -> QueryResult: qry_start_dttm = datetime.now() client = self.cluster.get_pydruid_client() query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2) @@ -1346,7 +1371,7 @@ class DruidDatasource(Model, BaseDatasource): if df is None or df.size == 0: return QueryResult( - df=pandas.DataFrame([]), + df=pd.DataFrame([]), query=query_str, duration=datetime.now() - qry_start_dttm, ) @@ -1363,7 +1388,7 @@ class DruidDatasource(Model, BaseDatasource): del df[DTTM_ALIAS] # Reordering columns - cols = [] + cols: List[str] = [] if DTTM_ALIAS in df.columns: cols += [DTTM_ALIAS] cols += query_obj.get("groupby") or [] @@ -1413,7 +1438,7 @@ class DruidDatasource(Model, BaseDatasource): return (col, extraction_fn) @classmethod - def get_filters(cls, raw_filters, num_cols, columns_dict): # noqa + def get_filters(cls, raw_filters, num_cols, columns_dict) -> Filter: # noqa: T484 """Given Superset filter data structure, returns pydruid Filter(s)""" filters = None for flt in raw_filters: @@ -1542,7 +1567,7 @@ class DruidDatasource(Model, BaseDatasource): return filters - def _get_having_obj(self, col, op, eq): + def _get_having_obj(self, col: str, op: str, eq: str) -> Having: cond = None if op == "==": if col in self.column_names: @@ -1556,7 +1581,7 @@ class DruidDatasource(Model, BaseDatasource): return cond - def get_having_filters(self, raw_filters): + def get_having_filters(self, raw_filters: List[Dict]) -> Having: filters = None reversed_op_map = {"!=": "==", ">=": "<", "<=": ">"} @@ -1579,7 +1604,9 @@ class DruidDatasource(Model, BaseDatasource): return filters @classmethod - def query_datasources_by_name(cls, session, database, datasource_name, schema=None): + def query_datasources_by_name( + cls, session: Session, database: Database, datasource_name: str, schema=None + ) -> List["DruidDatasource"]: return ( session.query(cls) .filter_by(cluster_name=database.id) @@ -1587,7 +1614,7 @@ class DruidDatasource(Model, BaseDatasource): .all() ) - def external_metadata(self): + def external_metadata(self) -> List[Dict]: self.merge_flag = True return [ {"name": k, "type": v.get("type")} diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py index 775e73f2163..606d3380b8a 100644 --- a/superset/connectors/druid/views.py +++ b/superset/connectors/druid/views.py @@ -395,7 +395,7 @@ class Druid(BaseSupersetView): @has_access @expose("/refresh_datasources/") - def refresh_datasources(self, refreshAll=True): + def refresh_datasources(self, refresh_all=True): """endpoint that refreshes druid datasources metadata""" session = db.session() DruidCluster = ConnectorRegistry.sources["druid"].cluster_class @@ -403,7 +403,7 @@ class Druid(BaseSupersetView): cluster_name = cluster.cluster_name valid_cluster = True try: - cluster.refresh_datasources(refreshAll=refreshAll) + cluster.refresh_datasources(refresh_all=refresh_all) except Exception as e: valid_cluster = False flash( @@ -432,7 +432,7 @@ class Druid(BaseSupersetView): Calling this endpoint will cause a scan for new datasources only and add them. """ - return self.refresh_datasources(refreshAll=False) + return self.refresh_datasources(refresh_all=False) appbuilder.add_view_no_menu(Druid) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 41a79579471..980e607772b 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -42,10 +42,10 @@ from sqlalchemy import ( Text, ) from sqlalchemy.exc import CompileError -from sqlalchemy.orm import backref, relationship +from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import column, literal_column, table, text +from sqlalchemy.sql import column, ColumnElement, literal_column, table, text from sqlalchemy.sql.expression import Label, Select, TextAsFrom import sqlparse @@ -83,7 +83,7 @@ class AnnotationDatasource(BaseDatasource): cache_timeout = 0 - def query(self, query_obj): + def query(self, query_obj: Dict) -> QueryResult: df = None error_message = None qry = db.session.query(Annotation) @@ -143,7 +143,7 @@ class TableColumn(Model, BaseColumn): update_from_object_fields = [s for s in export_fields if s not in ("table_id",)] export_parent = "table" - def get_sqla_col(self, label=None): + def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.column_name if not self.expression: db_engine_spec = self.table.database.db_engine_spec @@ -155,10 +155,12 @@ class TableColumn(Model, BaseColumn): return col @property - def datasource(self): + def datasource(self) -> RelationshipProperty: return self.table - def get_time_filter(self, start_dttm, end_dttm): + def get_time_filter( + self, start_dttm: DateTime, end_dttm: DateTime + ) -> ColumnElement: col = self.get_sqla_col(label="__time") l = [] # noqa: E741 if start_dttm: @@ -205,7 +207,7 @@ class TableColumn(Model, BaseColumn): return import_datasource.import_simple_obj(db.session, i_column, lookup_obj) - def dttm_sql_literal(self, dttm): + def dttm_sql_literal(self, dttm: DateTime) -> str: """Convert datetime object to a SQL expression string""" tf = self.python_date_format if tf: @@ -249,13 +251,13 @@ class SqlMetric(Model, BaseMetric): ) export_parent = "table" - def get_sqla_col(self, label=None): + def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.metric_name sqla_col = literal_column(self.expression) return self.table.make_sqla_column_compatible(sqla_col, label) @property - def perm(self): + def perm(self) -> Optional[str]: return ( ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format( obj=self, parent_name=self.table.full_name @@ -264,7 +266,7 @@ class SqlMetric(Model, BaseMetric): else None ) - def get_perm(self): + def get_perm(self) -> Optional[str]: return self.perm @classmethod @@ -351,7 +353,9 @@ class SqlaTable(Model, BaseDatasource): "MAX": sa.func.MAX, } - def make_sqla_column_compatible(self, sqla_col, label=None): + def make_sqla_column_compatible( + self, sqla_col: Column, label: Optional[str] = None + ) -> Column: """Takes a sql alchemy column object and adds label info if supported by engine. :param sqla_col: sql alchemy column instance :param label: alias/label that column is expected to have @@ -369,23 +373,29 @@ class SqlaTable(Model, BaseDatasource): return self.name @property - def connection(self): + def connection(self) -> str: return str(self.database) @property - def description_markeddown(self): + def description_markeddown(self) -> str: return utils.markdown(self.description) @property - def datasource_name(self): + def datasource_name(self) -> str: return self.table_name @property - def database_name(self): + def database_name(self) -> str: return self.database.name @classmethod - def get_datasource_by_name(cls, session, datasource_name, schema, database_name): + def get_datasource_by_name( + cls, + session: Session, + datasource_name: str, + schema: Optional[str], + database_name: str, + ) -> Optional["SqlaTable"]: schema = schema or None query = ( session.query(cls) @@ -398,52 +408,52 @@ class SqlaTable(Model, BaseDatasource): for tbl in query.all(): if schema == (tbl.schema or None): return tbl + return None @property - def link(self): + def link(self) -> Markup: name = escape(self.name) anchor = f'{name}' return Markup(anchor) @property - def schema_perm(self): + def schema_perm(self) -> Optional[str]: """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) - def get_perm(self): + def get_perm(self) -> str: return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self) @property - def name(self): + def name(self) -> str: if not self.schema: return self.table_name return "{}.{}".format(self.schema, self.table_name) @property - def full_name(self): + def full_name(self) -> str: return utils.get_datasource_full_name( self.database, self.table_name, schema=self.schema ) @property - def dttm_cols(self): + def dttm_cols(self) -> List: l = [c.column_name for c in self.columns if c.is_dttm] # noqa: E741 if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property - def num_cols(self): + def num_cols(self) -> List: return [c.column_name for c in self.columns if c.is_num] @property - def any_dttm_col(self): + def any_dttm_col(self) -> Optional[str]: cols = self.dttm_cols - if cols: - return cols[0] + return cols[0] if cols else None @property - def html(self): + def html(self) -> str: t = ((c.column_name, c.type) for c in self.columns) df = pd.DataFrame(t) df.columns = ["field", "type"] @@ -453,7 +463,7 @@ class SqlaTable(Model, BaseDatasource): ) @property - def sql_url(self): + def sql_url(self) -> str: return self.database.sql_url + "?table_name=" + str(self.table_name) def external_metadata(self): @@ -466,28 +476,29 @@ class SqlaTable(Model, BaseDatasource): return cols @property - def time_column_grains(self): + def time_column_grains(self) -> Dict[str, Any]: return { "time_columns": self.dttm_cols, "time_grains": [grain.name for grain in self.database.grains()], } @property - def select_star(self): + def select_star(self) -> str: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star( self.table_name, schema=self.schema, show_cols=False, latest_partition=False ) - def get_col(self, col_name): + def get_col(self, col_name: str) -> Optional[Column]: columns = self.columns for col in columns: if col_name == col.column_name: return col + return None @property - def data(self): + def data(self) -> Dict: d = super(SqlaTable, self).data if self.type == "table": grains = self.database.grains() or [] @@ -500,7 +511,7 @@ class SqlaTable(Model, BaseDatasource): d["template_params"] = self.template_params return d - def values_for_column(self, column_name, limit=10000): + def values_for_column(self, column_name: str, limit: int = 10000) -> List: """Runs query against sqla to retrieve some sample values for the given column. """ @@ -525,9 +536,9 @@ class SqlaTable(Model, BaseDatasource): sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) - return [row[0] for row in df.to_records(index=False)] + return df[column_name].to_list() - def mutate_query_from_config(self, sql): + def mutate_query_from_config(self, sql: str) -> str: """Apply config's SQL_QUERY_MUTATOR Typically adds comments to the query with context""" @@ -540,7 +551,7 @@ class SqlaTable(Model, BaseDatasource): def get_template_processor(self, **kwargs): return get_template_processor(table=self, database=self.database, **kwargs) - def get_query_str_extended(self, query_obj) -> QueryStringExtended: + def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) logging.info(sql) @@ -550,7 +561,7 @@ class SqlaTable(Model, BaseDatasource): labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries ) - def get_query_str(self, query_obj): + def get_query_str(self, query_obj: Dict) -> str: query_str_ext = self.get_query_str_extended(query_obj) all_queries = query_str_ext.prequeries + [query_str_ext.sql] return ";\n\n".join(all_queries) + ";" @@ -571,7 +582,7 @@ class SqlaTable(Model, BaseDatasource): return TextAsFrom(sa.text(from_sql), []).alias("expr_qry") return self.get_sqla_table() - def adhoc_metric_to_sqla(self, metric, cols): + def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]: """ Turn an adhoc metric into a sqlalchemy column. @@ -584,13 +595,13 @@ class SqlaTable(Model, BaseDatasource): label = utils.get_metric_name(metric) if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]: - column_name = metric.get("column").get("column_name") + column_name = metric["column"].get("column_name") table_column = cols.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() else: sqla_column = column(column_name) - sqla_metric = self.sqla_aggregations[metric.get("aggregate")](sqla_column) + sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]: sqla_metric = literal_column(metric.get("sqlExpression")) else: @@ -616,7 +627,7 @@ class SqlaTable(Model, BaseDatasource): extras=None, columns=None, order_desc=True, - ): + ) -> SqlaQuery: """Querying any sqla table from this common interface""" template_kwargs = { "from_dttm": from_dttm, @@ -643,8 +654,8 @@ class SqlaTable(Model, BaseDatasource): # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline - cols = {col.column_name: col for col in self.columns} - metrics_dict = {m.metric_name: m for m in self.metrics} + cols: Dict[str, Column] = {col.column_name: col for col in self.columns} + metrics_dict: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} if not granularity and is_timeseries: raise Exception( @@ -660,7 +671,7 @@ class SqlaTable(Model, BaseDatasource): if utils.is_adhoc_metric(m): metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols)) elif m in metrics_dict: - metrics_exprs.append(metrics_dict.get(m).get_sqla_col()) + metrics_exprs.append(metrics_dict[m].get_sqla_col()) else: raise Exception(_("Metric '%(metric)s' does not exist", metric=m)) if metrics_exprs: @@ -669,8 +680,8 @@ class SqlaTable(Model, BaseDatasource): main_metric_expr, label = literal_column("COUNT(*)"), "ccount" main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) - select_exprs = [] - groupby_exprs_sans_timestamp = OrderedDict() + select_exprs: List[Column] = [] + groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() if groupby: select_exprs = [] @@ -729,7 +740,7 @@ class SqlaTable(Model, BaseDatasource): qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] - having_clause_and = [] + having_clause_and: List = [] for flt in filter: if not all([flt.get(s) for s in ["col", "op"]]): continue @@ -899,7 +910,9 @@ class SqlaTable(Model, BaseDatasource): return ob - def _get_top_groups(self, df, dimensions, groupby_exprs): + def _get_top_groups( + self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict + ) -> ColumnElement: groups = [] for unused, row in df.iterrows(): group = [] @@ -909,7 +922,7 @@ class SqlaTable(Model, BaseDatasource): return or_(*groups) - def query(self, query_obj): + def query(self, query_obj: Dict) -> QueryResult: qry_start_dttm = datetime.now() query_str_ext = self.get_query_str_extended(query_obj) sql = query_str_ext.sql @@ -945,10 +958,10 @@ class SqlaTable(Model, BaseDatasource): error_message=error_message, ) - def get_sqla_table_object(self): + def get_sqla_table_object(self) -> Table: return self.database.get_table(self.table_name, schema=self.schema) - def fetch_metadata(self): + def fetch_metadata(self) -> None: """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() @@ -1012,7 +1025,7 @@ class SqlaTable(Model, BaseDatasource): db.session.commit() @classmethod - def import_obj(cls, i_datasource, import_time=None): + def import_obj(cls, i_datasource, import_time=None) -> int: """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. @@ -1052,7 +1065,9 @@ class SqlaTable(Model, BaseDatasource): ) @classmethod - def query_datasources_by_name(cls, session, database, datasource_name, schema=None): + def query_datasources_by_name( + cls, session: Session, database: Database, datasource_name: str, schema=None + ) -> List["SqlaTable"]: query = ( session.query(cls) .filter_by(database_id=database.id) @@ -1063,7 +1078,7 @@ class SqlaTable(Model, BaseDatasource): return query.all() @staticmethod - def default_query(qry): + def default_query(qry) -> Query: return qry.filter_by(is_sqllab_view=False) def has_extra_cache_keys(self, query_obj: Dict) -> bool: