diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 6fe47a01980..c2534e9f64f 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -18,20 +18,22 @@
from datetime import datetime, timedelta
import logging
import pickle as pkl
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from superset import app, cache
from superset import db
+from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
+from superset.stats_logger import BaseStatsLogger
from superset.utils import core as utils
from superset.utils.core import DTTM_ALIAS
from .query_object import QueryObject
config = app.config
-stats_logger = config.get("STATS_LOGGER")
+stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
class QueryContext:
@@ -40,8 +42,13 @@ class QueryContext:
to retrieve the data payload for a given viz.
"""
- cache_type = "df"
- enforce_numerical_metrics = True
+ cache_type: str = "df"
+ enforce_numerical_metrics: bool = True
+
+ datasource: BaseDatasource
+ queries: List[QueryObject]
+ force: bool
+ custom_cache_timeout: Optional[int]
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
@@ -50,8 +57,8 @@ class QueryContext:
datasource: Dict,
queries: List[Dict],
force: bool = False,
- custom_cache_timeout: int = None,
- ):
+ custom_cache_timeout: Optional[int] = None,
+ ) -> None:
self.datasource = ConnectorRegistry.get_datasource(
datasource.get("type"), int(datasource.get("id")), db.session # noqa: T400
)
@@ -61,9 +68,7 @@ class QueryContext:
self.custom_cache_timeout = custom_cache_timeout
- self.enforce_numerical_metrics = True
-
- def get_query_result(self, query_object):
+ def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
"""Returns a pandas dataframe based on the query object"""
# Here, we assume that all the queries will use the same datasource, which is
@@ -109,23 +114,23 @@ class QueryContext:
"df": df,
}
- def df_metrics_to_num(self, df, query_object):
+ def df_metrics_to_num(self, df: pd.DataFrame, query_object: QueryObject) -> None:
"""Converting metrics to numeric when pandas.read_sql cannot"""
metrics = [metric for metric in query_object.metrics]
for col, dtype in df.dtypes.items():
if dtype.type == np.object_ and col in metrics:
df[col] = pd.to_numeric(df[col], errors="coerce")
- def get_data(self, df):
+ def get_data(self, df: pd.DataFrame) -> List[Dict]:
return df.to_dict(orient="records")
- def get_single_payload(self, query_obj: QueryObject):
+ def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
"""Returns a payload of metadata and data"""
payload = self.get_df_payload(query_obj)
df = payload.get("df")
status = payload.get("status")
if status != utils.QueryStatus.FAILED:
- if df is not None and df.empty:
+ if df is None or df.empty:
payload["error"] = "No data"
else:
payload["data"] = self.get_data(df)
@@ -133,12 +138,12 @@ class QueryContext:
del payload["df"]
return payload
- def get_payload(self):
+ def get_payload(self) -> List[Dict[str, Any]]:
"""Get all the payloads from the arrays"""
return [self.get_single_payload(query_object) for query_object in self.queries]
@property
- def cache_timeout(self):
+ def cache_timeout(self) -> int:
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.datasource.cache_timeout is not None:
@@ -148,10 +153,10 @@ class QueryContext:
and self.datasource.database.cache_timeout
) is not None:
return self.datasource.database.cache_timeout
- return config.get("CACHE_DEFAULT_TIMEOUT")
+ return config["CACHE_DEFAULT_TIMEOUT"]
- def get_df_payload(self, query_obj: QueryObject, **kwargs):
- """Handles caching around the df paylod retrieval"""
+ def get_df_payload(self, query_obj: QueryObject, **kwargs) -> Dict[str, Any]:
+ """Handles caching around the df payload retrieval"""
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
cache_key = (
query_obj.cache_key(
@@ -207,9 +212,7 @@ class QueryContext:
if is_loaded and cache_key and cache and status != utils.QueryStatus.FAILED:
try:
- cache_value = dict(
- dttm=cached_dttm, df=df if df is not None else None, query=query
- )
+ cache_value = dict(dttm=cached_dttm, df=df, query=query)
cache_binary = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
logging.info(
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 28f2303c25d..71b690358a4 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=R
+from datetime import datetime, timedelta
import hashlib
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
import simplejson as json
@@ -34,12 +35,28 @@ class QueryObject:
and druid. The query objects are constructed on the client.
"""
+ granularity: str
+ from_dttm: datetime
+ to_dttm: datetime
+ is_timeseries: bool
+ time_shift: timedelta
+ groupby: List[str]
+ metrics: List[Union[Dict, str]]
+ row_limit: int
+ filter: List[str]
+ timeseries_limit: int
+ timeseries_limit_metric: Optional[Dict]
+ order_desc: bool
+ extras: Dict
+ columns: List[str]
+ orderby: List[List]
+
def __init__(
self,
granularity: str,
metrics: List[Union[Dict, str]],
- groupby: List[str] = None,
- filters: List[str] = None,
+ groupby: Optional[List[str]] = None,
+ filters: Optional[List[str]] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: bool = False,
@@ -48,8 +65,8 @@ class QueryObject:
timeseries_limit_metric: Optional[Dict] = None,
order_desc: bool = True,
extras: Optional[Dict] = None,
- columns: List[str] = None,
- orderby: List[List] = None,
+ columns: Optional[List[str]] = None,
+ orderby: Optional[List[List]] = None,
relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"),
relative_end: str = app.config.get("DEFAULT_RELATIVE_END_TIME", "today"),
):
@@ -63,7 +80,7 @@ class QueryObject:
self.is_timeseries = is_timeseries
self.time_range = time_range
self.time_shift = utils.parse_human_timedelta(time_shift)
- self.groupby = groupby if groupby is not None else []
+ self.groupby = groupby or []
# Temporal solution for backward compatability issue
# due the new format of non-ad-hoc metric.
@@ -72,15 +89,15 @@ class QueryObject:
for metric in metrics
]
self.row_limit = row_limit
- self.filter = filters if filters is not None else []
+ self.filter = filters or []
self.timeseries_limit = timeseries_limit
self.timeseries_limit_metric = timeseries_limit_metric
self.order_desc = order_desc
- self.extras = extras if extras is not None else {}
- self.columns = columns if columns is not None else []
- self.orderby = orderby if orderby is not None else []
+ self.extras = extras or {}
+ self.columns = columns or []
+ self.orderby = orderby or []
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
query_object_dict = {
"granularity": self.granularity,
"from_dttm": self.from_dttm,
@@ -99,7 +116,7 @@ class QueryObject:
}
return query_object_dict
- def cache_key(self, **extra):
+ def cache_key(self, **extra) -> str:
"""
The cache key is made out of the key/values from to_dict(), plus any
other key/values in `extra`
@@ -117,7 +134,7 @@ class QueryObject:
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
- def json_dumps(self, obj, sort_keys=False):
+ def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
return json.dumps(
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index d82f4d6aa02..d84910c6e56 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -16,14 +16,15 @@
# under the License.
# pylint: disable=C,R,W
import json
-from typing import Any, List
+from typing import Any, Dict, List, Optional
+from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import foreign, relationship
+from sqlalchemy.orm import foreign, Query, relationship
from superset.models.core import Slice
-from superset.models.helpers import AuditMixinNullable, ImportMixin
+from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.utils import core as utils
@@ -59,9 +60,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
params = Column(String(1000))
perm = Column(String(1000))
- sql = None
- owners = None
- update_from_object_fields = None
+ sql: Optional[str] = None
+ owners: List[User]
+ update_from_object_fields: List[str]
@declared_attr
def slices(self):
@@ -79,20 +80,20 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
metrics: List[Any] = []
@property
- def uid(self):
+ def uid(self) -> str:
"""Unique id across datasource types"""
return f"{self.id}__{self.type}"
@property
- def column_names(self):
+ def column_names(self) -> List[str]:
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property
- def columns_types(self):
+ def columns_types(self) -> Dict:
return {c.column_name: c.type for c in self.columns}
@property
- def main_dttm_col(self):
+ def main_dttm_col(self) -> str:
return "timestamp"
@property
@@ -100,47 +101,47 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
raise NotImplementedError()
@property
- def connection(self):
+ def connection(self) -> Optional[str]:
"""String representing the context of the Datasource"""
return None
@property
- def schema(self):
+ def schema(self) -> Optional[str]:
"""String representing the schema of the Datasource (if it applies)"""
return None
@property
- def filterable_column_names(self):
+ def filterable_column_names(self) -> List[str]:
return sorted([c.column_name for c in self.columns if c.filterable])
@property
- def dttm_cols(self):
+ def dttm_cols(self) -> List:
return []
@property
- def url(self):
+ def url(self) -> str:
return "/{}/edit/{}".format(self.baselink, self.id)
@property
- def explore_url(self):
+ def explore_url(self) -> str:
if self.default_endpoint:
return self.default_endpoint
else:
return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
@property
- def column_formats(self):
+ def column_formats(self) -> Dict[str, Optional[str]]:
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
- def add_missing_metrics(self, metrics):
- exisiting_metrics = {m.metric_name for m in self.metrics}
+ def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
+ existing_metrics = {m.metric_name for m in self.metrics}
for metric in metrics:
- if metric.metric_name not in exisiting_metrics:
+ if metric.metric_name not in existing_metrics:
metric.table_id = self.id
- self.metrics += [metric]
+ self.metrics.append(metric)
@property
- def short_data(self):
+ def short_data(self) -> Dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
"edit_url": self.url,
@@ -158,7 +159,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
pass
@property
- def data(self):
+ def data(self) -> Dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
order_by_choices = []
# self.column_names return sorted column_names
@@ -239,14 +240,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
"""Returns column information from the external system"""
raise NotImplementedError()
- def get_query_str(self, query_obj):
+ def get_query_str(self, query_obj) -> str:
"""Returns a query as a string
This is used to be displayed to the user so that she/he can
understand what is taking place behind the scene"""
raise NotImplementedError()
- def query(self, query_obj):
+ def query(self, query_obj) -> QueryResult:
"""Executes the query and returns a dataframe
query_obj is a dictionary representing Superset's query interface.
@@ -254,7 +255,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
"""
raise NotImplementedError()
- def values_for_column(self, column_name, limit=10000):
+ def values_for_column(self, column_name: str, limit: int = 10000) -> List:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
@@ -262,13 +263,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
raise NotImplementedError()
@staticmethod
- def default_query(qry):
+ def default_query(qry) -> Query:
return qry
- def get_column(self, column_name):
+ def get_column(self, column_name: str) -> Optional["BaseColumn"]:
for col in self.columns:
if col.column_name == column_name:
return col
+ return None
def get_fk_many_from_list(self, object_list, fkmany, fkmany_class, key_attr):
"""Update ORM one-to-many list from object list
@@ -276,10 +278,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
Used for syncing metrics and columns using the same code"""
object_dict = {o.get(key_attr): o for o in object_list}
- object_keys = [o.get(key_attr) for o in object_list]
# delete fks that have been removed
- fkmany = [o for o in fkmany if getattr(o, key_attr) in object_keys]
+ fkmany = [o for o in fkmany if getattr(o, key_attr) in object_dict]
# sync existing fks
for fk in fkmany:
@@ -303,7 +304,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
fkmany += new_fks
return fkmany
- def update_from_object(self, obj):
+ def update_from_object(self, obj) -> None:
"""Update datasource from a data structure
The UI's table editor crafts a complex data structure that
@@ -330,7 +331,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
obj.get("columns"), self.columns, self.column_class, "column_name"
)
- def get_extra_cache_keys(self, query_obj) -> List[Any]:
+ def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
""" If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
"""
@@ -374,23 +375,23 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
str_types = ("VARCHAR", "STRING", "CHAR")
@property
- def is_num(self):
- return self.type and any([t in self.type.upper() for t in self.num_types])
+ def is_num(self) -> bool:
+ return bool(self.type) and any(t in self.type.upper() for t in self.num_types)
@property
- def is_time(self):
- return self.type and any([t in self.type.upper() for t in self.date_types])
+ def is_time(self) -> bool:
+ return bool(self.type) and any(t in self.type.upper() for t in self.date_types)
@property
- def is_string(self):
- return self.type and any([t in self.type.upper() for t in self.str_types])
+ def is_string(self) -> bool:
+ return bool(self.type) and any(t in self.type.upper() for t in self.str_types)
@property
def expression(self):
raise NotImplementedError()
@property
- def data(self):
+ def data(self) -> Dict[str, Any]:
attrs = (
"id",
"column_name",
@@ -443,7 +444,7 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
raise NotImplementedError()
@property
- def data(self):
+ def data(self) -> Dict[str, Any]:
attrs = (
"id",
"metric_name",
diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py
index d5e951aabab..9ce11802ff0 100644
--- a/superset/connectors/connector_registry.py
+++ b/superset/connectors/connector_registry.py
@@ -15,16 +15,23 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
-from sqlalchemy.orm import subqueryload
+from collections import OrderedDict
+from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
+
+from sqlalchemy.orm import Session, subqueryload
+
+if TYPE_CHECKING:
+ from superset.models.core import Database
+ from superset.connectors.base.models import BaseDatasource
class ConnectorRegistry(object):
""" Central Registry for all available datasource engines"""
- sources = {}
+ sources: Dict[str, Type["BaseDatasource"]] = {}
@classmethod
- def register_sources(cls, datasource_config):
+ def register_sources(cls, datasource_config: OrderedDict) -> None:
for module_name, class_names in datasource_config.items():
class_names = [str(s) for s in class_names]
module_obj = __import__(module_name, fromlist=class_names)
@@ -33,7 +40,9 @@ class ConnectorRegistry(object):
cls.sources[source_class.type] = source_class
@classmethod
- def get_datasource(cls, datasource_type, datasource_id, session):
+ def get_datasource(
+ cls, datasource_type: str, datasource_id: int, session: Session
+ ) -> Optional["BaseDatasource"]:
return (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
@@ -41,8 +50,8 @@ class ConnectorRegistry(object):
)
@classmethod
- def get_all_datasources(cls, session):
- datasources = []
+ def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
+ datasources: List["BaseDatasource"] = []
for source_type in ConnectorRegistry.sources:
source_class = ConnectorRegistry.sources[source_type]
qry = session.query(source_class)
@@ -52,15 +61,22 @@ class ConnectorRegistry(object):
@classmethod
def get_datasource_by_name(
- cls, session, datasource_type, datasource_name, schema, database_name
- ):
+ cls,
+ session: Session,
+ datasource_type: str,
+ datasource_name: str,
+ schema: str,
+ database_name: str,
+ ) -> Optional["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[datasource_type]
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
)
@classmethod
- def query_datasources_by_permissions(cls, session, database, permissions):
+ def query_datasources_by_permissions(
+ cls, session: Session, database: "Database", permissions: Set[str]
+ ) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return (
session.query(datasource_class)
@@ -70,7 +86,9 @@ class ConnectorRegistry(object):
)
@classmethod
- def get_eager_datasource(cls, session, datasource_type, datasource_id):
+ def get_eager_datasource(
+ cls, session: Session, datasource_type: str, datasource_id: int
+ ) -> "BaseDatasource":
"""Returns datasource with columns and metrics."""
datasource_class = ConnectorRegistry.sources[datasource_type]
return (
@@ -84,7 +102,13 @@ class ConnectorRegistry(object):
)
@classmethod
- def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
+ def query_datasources_by_name(
+ cls,
+ session: Session,
+ database: "Database",
+ datasource_name: str,
+ schema: Optional[str] = None,
+ ) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=None
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index b908056aaed..581ec7db596 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -24,13 +24,15 @@ import json
import logging
from multiprocessing.pool import ThreadPool
import re
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from dateutil.parser import parse as dparse
from flask import escape, Markup
from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
+from flask_appbuilder.security.sqla.models import User
from flask_babel import lazy_gettext as _
-import pandas
+import pandas as pd
try:
from pydruid.client import PyDruid
@@ -41,7 +43,7 @@ try:
RegisteredLookupExtraction,
)
from pydruid.utils.filters import Dimension, Filter
- from pydruid.utils.having import Aggregation
+ from pydruid.utils.having import Aggregation, Having
from pydruid.utils.postaggregator import (
Const,
Field,
@@ -65,12 +67,13 @@ from sqlalchemy import (
Text,
UniqueConstraint,
)
-from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm import backref, relationship, RelationshipProperty, Session
from sqlalchemy_utils import EncryptedType
from superset import conf, db, security_manager
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.exceptions import SupersetException
+from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.utils import core as utils, import_datasource
@@ -78,6 +81,8 @@ try:
from superset.utils.core import DimSelector, DTTM_ALIAS, flasher
except ImportError:
pass
+
+
DRUID_TZ = conf.get("DRUID_TZ")
POST_AGG_TYPE = "postagg"
metadata = Model.metadata # pylint: disable=no-member
@@ -150,22 +155,22 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
return self.__repr__()
@property
- def data(self):
+ def data(self) -> Dict:
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
@staticmethod
- def get_base_url(host, port):
+ def get_base_url(host, port) -> str:
if not re.match("http(s)?://", host):
host = "http://" + host
url = "{0}:{1}".format(host, port) if port else host
return url
- def get_base_broker_url(self):
+ def get_base_broker_url(self) -> str:
base_url = self.get_base_url(self.broker_host, self.broker_port)
return f"{base_url}/{self.broker_endpoint}"
- def get_pydruid_client(self):
+ def get_pydruid_client(self) -> "PyDruid":
cli = PyDruid(
self.get_base_url(self.broker_host, self.broker_port), self.broker_endpoint
)
@@ -173,39 +178,44 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
cli.set_basic_auth_credentials(self.broker_user, self.broker_pass)
return cli
- def get_datasources(self):
+ def get_datasources(self) -> List[str]:
endpoint = self.get_base_broker_url() + "/datasources"
auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass)
return json.loads(requests.get(endpoint, auth=auth).text)
- def get_druid_version(self):
+ def get_druid_version(self) -> str:
endpoint = self.get_base_url(self.broker_host, self.broker_port) + "/status"
auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass)
return json.loads(requests.get(endpoint, auth=auth).text)["version"]
- @property
+ @property # noqa: T484
@utils.memoized
- def druid_version(self):
+ def druid_version(self) -> str:
return self.get_druid_version()
def refresh_datasources(
- self, datasource_name=None, merge_flag=True, refreshAll=True
- ):
+ self,
+ datasource_name: Optional[str] = None,
+ merge_flag: bool = True,
+ refresh_all: bool = True,
+ ) -> None:
"""Refresh metadata of all datasources in the cluster
If ``datasource_name`` is specified, only that datasource is updated
"""
ds_list = self.get_datasources()
blacklist = conf.get("DRUID_DATA_SOURCE_BLACKLIST", [])
- ds_refresh = []
+ ds_refresh: List[str] = []
if not datasource_name:
ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list))
elif datasource_name not in blacklist and datasource_name in ds_list:
ds_refresh.append(datasource_name)
else:
return
- self.refresh(ds_refresh, merge_flag, refreshAll)
+ self.refresh(ds_refresh, merge_flag, refresh_all)
- def refresh(self, datasource_names, merge_flag, refreshAll):
+ def refresh(
+ self, datasource_names: List[str], merge_flag: bool, refresh_all: bool
+ ) -> None:
"""
Fetches metadata for the specified datasources and
merges to the Superset database
@@ -225,7 +235,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
session.add(datasource)
flasher(_("Adding new datasource [{}]").format(ds_name), "success")
ds_map[ds_name] = datasource
- elif refreshAll:
+ elif refresh_all:
flasher(_("Refreshing datasource [{}]").format(ds_name), "info")
else:
del ds_map[ds_name]
@@ -270,19 +280,19 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
session.commit()
@property
- def perm(self):
+ def perm(self) -> str:
return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
- def get_perm(self):
+ def get_perm(self) -> str:
return self.perm
@property
- def name(self):
- return self.verbose_name if self.verbose_name else self.cluster_name
+ def name(self) -> str:
+ return self.verbose_name or self.cluster_name
@property
- def unique_name(self):
- return self.verbose_name if self.verbose_name else self.cluster_name
+ def unique_name(self) -> str:
+ return self.verbose_name or self.cluster_name
class DruidColumn(Model, BaseColumn):
@@ -318,25 +328,26 @@ class DruidColumn(Model, BaseColumn):
return self.column_name or str(self.id)
@property
- def expression(self):
+ def expression(self) -> str:
return self.dimension_spec_json
@property
- def dimension_spec(self):
+ def dimension_spec(self) -> Optional[Dict]: # noqa: T484
if self.dimension_spec_json:
return json.loads(self.dimension_spec_json)
- def get_metrics(self):
- metrics = {}
- metrics["count"] = DruidMetric(
- metric_name="count",
- verbose_name="COUNT(*)",
- metric_type="count",
- json=json.dumps({"type": "count", "name": "count"}),
- )
+ def get_metrics(self) -> Dict[str, "DruidMetric"]:
+ metrics = {
+ "count": DruidMetric(
+ metric_name="count",
+ verbose_name="COUNT(*)",
+ metric_type="count",
+ json=json.dumps({"type": "count", "name": "count"}),
+ )
+ }
return metrics
- def refresh_metrics(self):
+ def refresh_metrics(self) -> None:
"""Refresh metrics based on the column metadata"""
metrics = self.get_metrics()
dbmetrics = (
@@ -356,8 +367,8 @@ class DruidColumn(Model, BaseColumn):
db.session.add(metric)
@classmethod
- def import_obj(cls, i_column):
- def lookup_obj(lookup_column):
+ def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn":
+ def lookup_obj(lookup_column: "DruidColumn") -> Optional["DruidColumn"]:
return (
db.session.query(DruidColumn)
.filter(
@@ -404,7 +415,7 @@ class DruidMetric(Model, BaseMetric):
return self.json
@property
- def json_obj(self):
+ def json_obj(self) -> Dict:
try:
obj = json.loads(self.json)
except Exception:
@@ -412,7 +423,7 @@ class DruidMetric(Model, BaseMetric):
return obj
@property
- def perm(self):
+ def perm(self) -> Optional[str]:
return (
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
obj=self, parent_name=self.datasource.full_name
@@ -421,12 +432,12 @@ class DruidMetric(Model, BaseMetric):
else None
)
- def get_perm(self):
+ def get_perm(self) -> Optional[str]:
return self.perm
@classmethod
- def import_obj(cls, i_metric):
- def lookup_obj(lookup_metric):
+ def import_obj(cls, i_metric: "DruidMetric") -> "DruidMetric":
+ def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]:
return (
db.session.query(DruidMetric)
.filter(
@@ -494,23 +505,23 @@ class DruidDatasource(Model, BaseDatasource):
export_children = ["columns", "metrics"]
@property
- def database(self):
+ def database(self) -> "DruidCluster":
return self.cluster
@property
- def connection(self):
+ def connection(self) -> str:
return str(self.database)
@property
- def num_cols(self):
+ def num_cols(self) -> List[str]:
return [c.column_name for c in self.columns if c.is_num]
@property
- def name(self):
+ def name(self) -> str:
return self.datasource_name
@property
- def schema(self):
+ def schema(self) -> Optional[str]:
ds_name = self.datasource_name or ""
name_pieces = ds_name.split(".")
if len(name_pieces) > 1:
@@ -519,11 +530,11 @@ class DruidDatasource(Model, BaseDatasource):
return None
@property
- def schema_perm(self):
+ def schema_perm(self) -> Optional[str]:
"""Returns schema permission if present, cluster one otherwise."""
return security_manager.get_schema_perm(self.cluster, self.schema)
- def get_perm(self):
+ def get_perm(self) -> str:
return ("[{obj.cluster_name}].[{obj.datasource_name}]" "(id:{obj.id})").format(
obj=self
)
@@ -532,16 +543,16 @@ class DruidDatasource(Model, BaseDatasource):
return NotImplementedError()
@property
- def link(self):
+ def link(self) -> Markup:
name = escape(self.datasource_name)
return Markup(f'{name}')
@property
- def full_name(self):
+ def full_name(self) -> str:
return utils.get_datasource_full_name(self.cluster_name, self.datasource_name)
@property
- def time_column_grains(self):
+ def time_column_grains(self) -> Dict[str, List[str]]:
return {
"time_columns": [
"all",
@@ -568,16 +579,18 @@ class DruidDatasource(Model, BaseDatasource):
return self.datasource_name
@renders("datasource_name")
- def datasource_link(self):
+ def datasource_link(self) -> Markup:
url = f"/superset/explore/{self.type}/{self.id}/"
name = escape(self.datasource_name)
return Markup(f'{name}')
- def get_metric_obj(self, metric_name):
+ def get_metric_obj(self, metric_name: str) -> Dict:
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
@classmethod
- def import_obj(cls, i_datasource, import_time=None):
+ def import_obj(
+ cls, i_datasource: "DruidDatasource", import_time: Optional[int] = None
+ ) -> int:
"""Imports the datasource from the object to the database.
Metrics and columns and datasource will be overridden if exists.
@@ -585,7 +598,7 @@ class DruidDatasource(Model, BaseDatasource):
superset instances. Audit metadata isn't copies over.
"""
- def lookup_datasource(d):
+ def lookup_datasource(d: DruidDatasource) -> Optional[DruidDatasource]:
return (
db.session.query(DruidDatasource)
.filter(
@@ -595,7 +608,7 @@ class DruidDatasource(Model, BaseDatasource):
.first()
)
- def lookup_cluster(d):
+ def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]:
return (
db.session.query(DruidCluster)
.filter_by(cluster_name=d.cluster_name)
@@ -659,12 +672,14 @@ class DruidDatasource(Model, BaseDatasource):
if segment_metadata:
return segment_metadata[-1]["columns"]
- def refresh_metrics(self):
+ def refresh_metrics(self) -> None:
for col in self.columns:
col.refresh_metrics()
@classmethod
- def sync_to_db_from_config(cls, druid_config, user, cluster, refresh=True):
+ def sync_to_db_from_config(
+ cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True
+ ) -> None:
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session
datasource = (
@@ -742,13 +757,15 @@ class DruidDatasource(Model, BaseDatasource):
session.commit()
@staticmethod
- def time_offset(granularity):
+ def time_offset(granularity: Union[str, Dict]) -> int:
if granularity == "week_ending_saturday":
return 6 * 24 * 3600 * 1000 # 6 days
return 0
@classmethod
- def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
+ def get_datasource_by_name(
+ cls, session: Session, datasource_name: str, schema: str, database_name: str
+ ) -> Optional["DruidDatasource"]:
query = (
session.query(cls)
.join(DruidCluster)
@@ -761,7 +778,9 @@ class DruidDatasource(Model, BaseDatasource):
# http://druid.io/docs/0.8.0/querying/granularities.html
# TODO: pass origin from the UI
@staticmethod
- def granularity(period_name, timezone=None, origin=None):
+ def granularity(
+ period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None
+ ) -> Union[str, Dict]:
if not period_name or period_name == "all":
return "all"
iso_8601_dict = {
@@ -810,7 +829,7 @@ class DruidDatasource(Model, BaseDatasource):
return granularity
@staticmethod
- def get_post_agg(mconf):
+ def get_post_agg(mconf: Dict) -> "Postaggregator":
"""
For a metric specified as `postagg` returns the
kind of post aggregation for pydruid.
@@ -839,7 +858,9 @@ class DruidDatasource(Model, BaseDatasource):
return CustomPostAggregator(mconf.get("name", ""), mconf)
@staticmethod
- def find_postaggs_for(postagg_names, metrics_dict):
+ def find_postaggs_for(
+ postagg_names: Set[str], metrics_dict: Dict[str, DruidMetric]
+ ) -> List[DruidMetric]:
"""Return a list of metrics that are post aggregations"""
postagg_metrics = [
metrics_dict[name]
@@ -852,7 +873,7 @@ class DruidDatasource(Model, BaseDatasource):
return postagg_metrics
@staticmethod
- def recursive_get_fields(_conf):
+ def recursive_get_fields(_conf: Dict) -> List[str]:
_type = _conf.get("type")
_field = _conf.get("field")
_fields = _conf.get("fields")
@@ -875,11 +896,9 @@ class DruidDatasource(Model, BaseDatasource):
# Check if the fields are already in aggs
# or is a previous postagg
required_fields = set(
- [
- field
- for field in required_fields
- if field not in visited_postaggs and field not in agg_names
- ]
+ field
+ for field in required_fields
+ if field not in visited_postaggs and field not in agg_names
)
# First try to find postaggs that match
if len(required_fields) > 0:
@@ -903,7 +922,11 @@ class DruidDatasource(Model, BaseDatasource):
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
@staticmethod
- def metrics_and_post_aggs(metrics, metrics_dict, druid_version=None):
+ def metrics_and_post_aggs(
+ metrics: List[Union[Dict, str]],
+ metrics_dict: Dict[str, DruidMetric],
+ druid_version=None,
+ ) -> Tuple[OrderedDict, OrderedDict]: # noqa: T484
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
@@ -912,26 +935,26 @@ class DruidDatasource(Model, BaseDatasource):
for metric in metrics:
if utils.is_adhoc_metric(metric):
adhoc_agg_configs.append(metric)
- elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
+ elif metrics_dict[metric].metric_type != POST_AGG_TYPE: # noqa: T484
saved_agg_names.add(metric)
else:
postagg_names.append(metric)
# Create the post aggregations, maintain order since postaggs
# may depend on previous ones
- post_aggs = OrderedDict()
+ post_aggs = OrderedDict() # noqa: T484
visited_postaggs = set()
for postagg_name in postagg_names:
- postagg = metrics_dict[postagg_name]
+ postagg = metrics_dict[postagg_name] # noqa: T484
visited_postaggs.add(postagg_name)
DruidDatasource.resolve_postagg(
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict
)
- aggs = DruidDatasource.get_aggregations(
+ aggs = DruidDatasource.get_aggregations( # noqa: T484
metrics_dict, saved_agg_names, adhoc_agg_configs
)
return aggs, post_aggs
- def values_for_column(self, column_name, limit=10000):
+ def values_for_column(self, column_name: str, limit: int = 10000) -> List:
"""Retrieve some values for the given column"""
logging.info(
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
@@ -955,12 +978,14 @@ class DruidDatasource(Model, BaseDatasource):
client = self.cluster.get_pydruid_client()
client.topn(**qry)
df = client.export_pandas()
- return [row[column_name] for row in df.to_records(index=False)]
+ return df[column_name].to_list()
def get_query_str(self, query_obj, phase=1, client=None):
return self.run_query(client=client, phase=phase, **query_obj)
- def _add_filter_from_pre_query_data(self, df, dimensions, dim_filter):
+ def _add_filter_from_pre_query_data(
+ self, df: Optional[pd.DataFrame], dimensions, dim_filter
+ ):
ret = dim_filter
if df is not None and not df.empty:
new_filters = []
@@ -1002,7 +1027,7 @@ class DruidDatasource(Model, BaseDatasource):
return ret
@staticmethod
- def druid_type_from_adhoc_metric(adhoc_metric):
+ def druid_type_from_adhoc_metric(adhoc_metric: Dict) -> str:
column_type = adhoc_metric["column"]["type"].lower()
aggregate = adhoc_metric["aggregate"].lower()
@@ -1014,7 +1039,9 @@ class DruidDatasource(Model, BaseDatasource):
return column_type + aggregate.capitalize()
@staticmethod
- def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]):
+ def get_aggregations(
+ metrics_dict: Dict, saved_metrics: Iterable[str], adhoc_metrics: List[Dict] = []
+ ) -> OrderedDict:
"""
Returns a dictionary of aggregation metric names to aggregation json objects
@@ -1023,7 +1050,7 @@ class DruidDatasource(Model, BaseDatasource):
:param adhoc_metrics: list of adhoc metric names
:raise SupersetException: if one or more metric names are not aggregations
"""
- aggregations = OrderedDict()
+ aggregations: OrderedDict = OrderedDict()
invalid_metric_names = []
for metric_name in saved_metrics:
if metric_name in metrics_dict:
@@ -1047,19 +1074,18 @@ class DruidDatasource(Model, BaseDatasource):
}
return aggregations
- def get_dimensions(self, groupby, columns_dict):
+ def get_dimensions(
+ self, groupby: List[str], columns_dict: Dict[str, DruidColumn]
+ ) -> List[Union[str, Dict]]:
dimensions = []
groupby = [gb for gb in groupby if gb in columns_dict]
for column_name in groupby:
col = columns_dict.get(column_name)
dim_spec = col.dimension_spec if col else None
- if dim_spec:
- dimensions.append(dim_spec)
- else:
- dimensions.append(column_name)
+ dimensions.append(dim_spec or column_name)
return dimensions
- def intervals_from_dttms(self, from_dttm, to_dttm):
+ def intervals_from_dttms(self, from_dttm: datetime, to_dttm: datetime) -> str:
# Couldn't find a way to just not filter on time...
from_dttm = from_dttm or datetime(1901, 1, 1)
to_dttm = to_dttm or datetime(2101, 1, 1)
@@ -1091,7 +1117,7 @@ class DruidDatasource(Model, BaseDatasource):
return values
@staticmethod
- def sanitize_metric_object(metric):
+ def sanitize_metric_object(metric: Dict) -> None:
"""
Update a metric with the correct type if necessary.
:param dict metric: The metric to sanitize
@@ -1122,7 +1148,7 @@ class DruidDatasource(Model, BaseDatasource):
phase=2,
client=None,
order_desc=True,
- ):
+ ) -> str:
"""Runs a query against Druid and returns a dataframe.
"""
# TODO refactor into using a TBD Query object
@@ -1193,7 +1219,7 @@ class DruidDatasource(Model, BaseDatasource):
del qry["dimensions"]
client.timeseries(**qry)
elif not having_filters and len(groupby) == 1 and order_desc:
- dim = list(qry.get("dimensions"))[0]
+ dim = list(qry.get("dimensions"))[0] # noqa: T484
logging.info("Running two-phase topn query for dimension [{}]".format(dim))
pre_qry = deepcopy(qry)
if timeseries_limit_metric:
@@ -1324,7 +1350,7 @@ class DruidDatasource(Model, BaseDatasource):
return query_str
@staticmethod
- def homogenize_types(df, groupby_cols):
+ def homogenize_types(df: pd.DataFrame, groupby_cols: Iterable[str]) -> pd.DataFrame:
"""Converting all GROUPBY columns to strings
When grouping by a numeric (say FLOAT) column, pydruid returns
@@ -1334,11 +1360,10 @@ class DruidDatasource(Model, BaseDatasource):
Here we replace None with and make the whole series a
str instead of an object.
"""
- for col in groupby_cols:
- df[col] = df[col].fillna("").astype("unicode")
+ df[groupby_cols] = df[groupby_cols].fillna("").astype("unicode")
return df
- def query(self, query_obj):
+ def query(self, query_obj: Dict) -> QueryResult:
qry_start_dttm = datetime.now()
client = self.cluster.get_pydruid_client()
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
@@ -1346,7 +1371,7 @@ class DruidDatasource(Model, BaseDatasource):
if df is None or df.size == 0:
return QueryResult(
- df=pandas.DataFrame([]),
+ df=pd.DataFrame([]),
query=query_str,
duration=datetime.now() - qry_start_dttm,
)
@@ -1363,7 +1388,7 @@ class DruidDatasource(Model, BaseDatasource):
del df[DTTM_ALIAS]
# Reordering columns
- cols = []
+ cols: List[str] = []
if DTTM_ALIAS in df.columns:
cols += [DTTM_ALIAS]
cols += query_obj.get("groupby") or []
@@ -1413,7 +1438,7 @@ class DruidDatasource(Model, BaseDatasource):
return (col, extraction_fn)
@classmethod
- def get_filters(cls, raw_filters, num_cols, columns_dict): # noqa
+ def get_filters(cls, raw_filters, num_cols, columns_dict) -> Filter: # noqa: T484
"""Given Superset filter data structure, returns pydruid Filter(s)"""
filters = None
for flt in raw_filters:
@@ -1542,7 +1567,7 @@ class DruidDatasource(Model, BaseDatasource):
return filters
- def _get_having_obj(self, col, op, eq):
+ def _get_having_obj(self, col: str, op: str, eq: str) -> Having:
cond = None
if op == "==":
if col in self.column_names:
@@ -1556,7 +1581,7 @@ class DruidDatasource(Model, BaseDatasource):
return cond
- def get_having_filters(self, raw_filters):
+ def get_having_filters(self, raw_filters: List[Dict]) -> Having:
filters = None
reversed_op_map = {"!=": "==", ">=": "<", "<=": ">"}
@@ -1579,7 +1604,9 @@ class DruidDatasource(Model, BaseDatasource):
return filters
@classmethod
- def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
+ def query_datasources_by_name(
+ cls, session: Session, database: Database, datasource_name: str, schema=None
+ ) -> List["DruidDatasource"]:
return (
session.query(cls)
.filter_by(cluster_name=database.id)
@@ -1587,7 +1614,7 @@ class DruidDatasource(Model, BaseDatasource):
.all()
)
- def external_metadata(self):
+ def external_metadata(self) -> List[Dict]:
self.merge_flag = True
return [
{"name": k, "type": v.get("type")}
diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py
index 775e73f2163..606d3380b8a 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -395,7 +395,7 @@ class Druid(BaseSupersetView):
@has_access
@expose("/refresh_datasources/")
- def refresh_datasources(self, refreshAll=True):
+ def refresh_datasources(self, refresh_all=True):
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidCluster = ConnectorRegistry.sources["druid"].cluster_class
@@ -403,7 +403,7 @@ class Druid(BaseSupersetView):
cluster_name = cluster.cluster_name
valid_cluster = True
try:
- cluster.refresh_datasources(refreshAll=refreshAll)
+ cluster.refresh_datasources(refresh_all=refresh_all)
except Exception as e:
valid_cluster = False
flash(
@@ -432,7 +432,7 @@ class Druid(BaseSupersetView):
Calling this endpoint will cause a scan for new
datasources only and add them.
"""
- return self.refresh_datasources(refreshAll=False)
+ return self.refresh_datasources(refresh_all=False)
appbuilder.add_view_no_menu(Druid)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 41a79579471..980e607772b 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -42,10 +42,10 @@ from sqlalchemy import (
Text,
)
from sqlalchemy.exc import CompileError
-from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import UniqueConstraint
-from sqlalchemy.sql import column, literal_column, table, text
+from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
import sqlparse
@@ -83,7 +83,7 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
- def query(self, query_obj):
+ def query(self, query_obj: Dict) -> QueryResult:
df = None
error_message = None
qry = db.session.query(Annotation)
@@ -143,7 +143,7 @@ class TableColumn(Model, BaseColumn):
update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
export_parent = "table"
- def get_sqla_col(self, label=None):
+ def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.column_name
if not self.expression:
db_engine_spec = self.table.database.db_engine_spec
@@ -155,10 +155,12 @@ class TableColumn(Model, BaseColumn):
return col
@property
- def datasource(self):
+ def datasource(self) -> RelationshipProperty:
return self.table
- def get_time_filter(self, start_dttm, end_dttm):
+ def get_time_filter(
+ self, start_dttm: DateTime, end_dttm: DateTime
+ ) -> ColumnElement:
col = self.get_sqla_col(label="__time")
l = [] # noqa: E741
if start_dttm:
@@ -205,7 +207,7 @@ class TableColumn(Model, BaseColumn):
return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
- def dttm_sql_literal(self, dttm):
+ def dttm_sql_literal(self, dttm: DateTime) -> str:
"""Convert datetime object to a SQL expression string"""
tf = self.python_date_format
if tf:
@@ -249,13 +251,13 @@ class SqlMetric(Model, BaseMetric):
)
export_parent = "table"
- def get_sqla_col(self, label=None):
+ def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.metric_name
sqla_col = literal_column(self.expression)
return self.table.make_sqla_column_compatible(sqla_col, label)
@property
- def perm(self):
+ def perm(self) -> Optional[str]:
return (
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
obj=self, parent_name=self.table.full_name
@@ -264,7 +266,7 @@ class SqlMetric(Model, BaseMetric):
else None
)
- def get_perm(self):
+ def get_perm(self) -> Optional[str]:
return self.perm
@classmethod
@@ -351,7 +353,9 @@ class SqlaTable(Model, BaseDatasource):
"MAX": sa.func.MAX,
}
- def make_sqla_column_compatible(self, sqla_col, label=None):
+ def make_sqla_column_compatible(
+ self, sqla_col: Column, label: Optional[str] = None
+ ) -> Column:
"""Takes a sql alchemy column object and adds label info if supported by engine.
:param sqla_col: sql alchemy column instance
:param label: alias/label that column is expected to have
@@ -369,23 +373,29 @@ class SqlaTable(Model, BaseDatasource):
return self.name
@property
- def connection(self):
+ def connection(self) -> str:
return str(self.database)
@property
- def description_markeddown(self):
+ def description_markeddown(self) -> str:
return utils.markdown(self.description)
@property
- def datasource_name(self):
+ def datasource_name(self) -> str:
return self.table_name
@property
- def database_name(self):
+ def database_name(self) -> str:
return self.database.name
@classmethod
- def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
+ def get_datasource_by_name(
+ cls,
+ session: Session,
+ datasource_name: str,
+ schema: Optional[str],
+ database_name: str,
+ ) -> Optional["SqlaTable"]:
schema = schema or None
query = (
session.query(cls)
@@ -398,52 +408,52 @@ class SqlaTable(Model, BaseDatasource):
for tbl in query.all():
if schema == (tbl.schema or None):
return tbl
+ return None
@property
- def link(self):
+ def link(self) -> Markup:
name = escape(self.name)
anchor = f'{name}'
return Markup(anchor)
@property
- def schema_perm(self):
+ def schema_perm(self) -> Optional[str]:
"""Returns schema permission if present, database one otherwise."""
return security_manager.get_schema_perm(self.database, self.schema)
- def get_perm(self):
+ def get_perm(self) -> str:
return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self)
@property
- def name(self):
+ def name(self) -> str:
if not self.schema:
return self.table_name
return "{}.{}".format(self.schema, self.table_name)
@property
- def full_name(self):
+ def full_name(self) -> str:
return utils.get_datasource_full_name(
self.database, self.table_name, schema=self.schema
)
@property
- def dttm_cols(self):
+ def dttm_cols(self) -> List:
l = [c.column_name for c in self.columns if c.is_dttm] # noqa: E741
if self.main_dttm_col and self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
- def num_cols(self):
+ def num_cols(self) -> List:
return [c.column_name for c in self.columns if c.is_num]
@property
- def any_dttm_col(self):
+ def any_dttm_col(self) -> Optional[str]:
cols = self.dttm_cols
- if cols:
- return cols[0]
+ return cols[0] if cols else None
@property
- def html(self):
+ def html(self) -> str:
t = ((c.column_name, c.type) for c in self.columns)
df = pd.DataFrame(t)
df.columns = ["field", "type"]
@@ -453,7 +463,7 @@ class SqlaTable(Model, BaseDatasource):
)
@property
- def sql_url(self):
+ def sql_url(self) -> str:
return self.database.sql_url + "?table_name=" + str(self.table_name)
def external_metadata(self):
@@ -466,28 +476,29 @@ class SqlaTable(Model, BaseDatasource):
return cols
@property
- def time_column_grains(self):
+ def time_column_grains(self) -> Dict[str, Any]:
return {
"time_columns": self.dttm_cols,
"time_grains": [grain.name for grain in self.database.grains()],
}
@property
- def select_star(self):
+ def select_star(self) -> str:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
self.table_name, schema=self.schema, show_cols=False, latest_partition=False
)
- def get_col(self, col_name):
+ def get_col(self, col_name: str) -> Optional[Column]:
columns = self.columns
for col in columns:
if col_name == col.column_name:
return col
+ return None
@property
- def data(self):
+ def data(self) -> Dict:
d = super(SqlaTable, self).data
if self.type == "table":
grains = self.database.grains() or []
@@ -500,7 +511,7 @@ class SqlaTable(Model, BaseDatasource):
d["template_params"] = self.template_params
return d
- def values_for_column(self, column_name, limit=10000):
+ def values_for_column(self, column_name: str, limit: int = 10000) -> List:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@@ -525,9 +536,9 @@ class SqlaTable(Model, BaseDatasource):
sql = self.mutate_query_from_config(sql)
df = pd.read_sql_query(sql=sql, con=engine)
- return [row[0] for row in df.to_records(index=False)]
+ return df[column_name].to_list()
- def mutate_query_from_config(self, sql):
+ def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR
Typically adds comments to the query with context"""
@@ -540,7 +551,7 @@ class SqlaTable(Model, BaseDatasource):
def get_template_processor(self, **kwargs):
return get_template_processor(table=self, database=self.database, **kwargs)
- def get_query_str_extended(self, query_obj) -> QueryStringExtended:
+ def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
logging.info(sql)
@@ -550,7 +561,7 @@ class SqlaTable(Model, BaseDatasource):
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
)
- def get_query_str(self, query_obj):
+ def get_query_str(self, query_obj: Dict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"
@@ -571,7 +582,7 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table()
- def adhoc_metric_to_sqla(self, metric, cols):
+ def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]:
"""
Turn an adhoc metric into a sqlalchemy column.
@@ -584,13 +595,13 @@ class SqlaTable(Model, BaseDatasource):
label = utils.get_metric_name(metric)
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
- column_name = metric.get("column").get("column_name")
+ column_name = metric["column"].get("column_name")
table_column = cols.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col()
else:
sqla_column = column(column_name)
- sqla_metric = self.sqla_aggregations[metric.get("aggregate")](sqla_column)
+ sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
sqla_metric = literal_column(metric.get("sqlExpression"))
else:
@@ -616,7 +627,7 @@ class SqlaTable(Model, BaseDatasource):
extras=None,
columns=None,
order_desc=True,
- ):
+ ) -> SqlaQuery:
"""Querying any sqla table from this common interface"""
template_kwargs = {
"from_dttm": from_dttm,
@@ -643,8 +654,8 @@ class SqlaTable(Model, BaseDatasource):
# Database spec supports join-free timeslot grouping
time_groupby_inline = db_engine_spec.time_groupby_inline
- cols = {col.column_name: col for col in self.columns}
- metrics_dict = {m.metric_name: m for m in self.metrics}
+ cols: Dict[str, Column] = {col.column_name: col for col in self.columns}
+ metrics_dict: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
if not granularity and is_timeseries:
raise Exception(
@@ -660,7 +671,7 @@ class SqlaTable(Model, BaseDatasource):
if utils.is_adhoc_metric(m):
metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict:
- metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
+ metrics_exprs.append(metrics_dict[m].get_sqla_col())
else:
raise Exception(_("Metric '%(metric)s' does not exist", metric=m))
if metrics_exprs:
@@ -669,8 +680,8 @@ class SqlaTable(Model, BaseDatasource):
main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
- select_exprs = []
- groupby_exprs_sans_timestamp = OrderedDict()
+ select_exprs: List[Column] = []
+ groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()
if groupby:
select_exprs = []
@@ -729,7 +740,7 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
where_clause_and = []
- having_clause_and = []
+ having_clause_and: List = []
for flt in filter:
if not all([flt.get(s) for s in ["col", "op"]]):
continue
@@ -899,7 +910,9 @@ class SqlaTable(Model, BaseDatasource):
return ob
- def _get_top_groups(self, df, dimensions, groupby_exprs):
+ def _get_top_groups(
+ self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict
+ ) -> ColumnElement:
groups = []
for unused, row in df.iterrows():
group = []
@@ -909,7 +922,7 @@ class SqlaTable(Model, BaseDatasource):
return or_(*groups)
- def query(self, query_obj):
+ def query(self, query_obj: Dict) -> QueryResult:
qry_start_dttm = datetime.now()
query_str_ext = self.get_query_str_extended(query_obj)
sql = query_str_ext.sql
@@ -945,10 +958,10 @@ class SqlaTable(Model, BaseDatasource):
error_message=error_message,
)
- def get_sqla_table_object(self):
+ def get_sqla_table_object(self) -> Table:
return self.database.get_table(self.table_name, schema=self.schema)
- def fetch_metadata(self):
+ def fetch_metadata(self) -> None:
"""Fetches the metadata for the table and merges it in"""
try:
table = self.get_sqla_table_object()
@@ -1012,7 +1025,7 @@ class SqlaTable(Model, BaseDatasource):
db.session.commit()
@classmethod
- def import_obj(cls, i_datasource, import_time=None):
+ def import_obj(cls, i_datasource, import_time=None) -> int:
"""Imports the datasource from the object to the database.
Metrics and columns and datasource will be overrided if exists.
@@ -1052,7 +1065,9 @@ class SqlaTable(Model, BaseDatasource):
)
@classmethod
- def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
+ def query_datasources_by_name(
+ cls, session: Session, database: Database, datasource_name: str, schema=None
+ ) -> List["SqlaTable"]:
query = (
session.query(cls)
.filter_by(database_id=database.id)
@@ -1063,7 +1078,7 @@ class SqlaTable(Model, BaseDatasource):
return query.all()
@staticmethod
- def default_query(qry):
+ def default_query(qry) -> Query:
return qry.filter_by(is_sqllab_view=False)
def has_extra_cache_keys(self, query_obj: Dict) -> bool: