# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import json import textwrap from typing import Dict, List, Tuple, Union import pandas as pd from flask_appbuilder.security.sqla.models import User from sqlalchemy import DateTime, String from sqlalchemy.sql import column from superset import app, db, security_manager from superset.connectors.base.models import BaseDatasource from superset.connectors.sqla.models import SqlMetric, TableColumn from superset.exceptions import NoDataException from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.utils.core import get_example_database from .helpers import ( get_example_data, get_slice_json, get_table_connector_registry, merge_slice, misc_dash_slices, update_slice_ids, ) def get_admin_user() -> User: admin = security_manager.find_user("admin") if admin is None: raise NoDataException( "Admin user does not exist. " "Please, check if test users are properly loaded " "(`superset load_test_users`)." ) return admin def gen_filter( subject: str, comparator: str, operator: str = "==" ) -> Dict[str, Union[bool, str]]: return { "clause": "WHERE", "comparator": comparator, "expressionType": "SIMPLE", "operator": operator, "subject": subject, } def load_data(tbl_name: str, database: Database, sample: bool = False) -> None: pdf = pd.read_json(get_example_data("birth_names2.json.gz")) # TODO(bkyryliuk): move load examples data into the pytest fixture if database.backend == "presto": pdf.ds = pd.to_datetime(pdf.ds, unit="ms") pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S") else: pdf.ds = pd.to_datetime(pdf.ds, unit="ms") pdf = pdf.head(100) if sample else pdf pdf.to_sql( tbl_name, database.get_sqla_engine(), if_exists="replace", chunksize=500, dtype={ # TODO(bkyryliuk): use TIMESTAMP type for presto "ds": DateTime if database.backend != "presto" else String(255), "gender": String(16), "state": String(10), "name": String(255), }, method="multi", index=False, ) print("Done loading table!") print("-" * 80) def load_birth_names( only_metadata: bool = False, force: bool = False, sample: bool = False ) -> None: """Loading birth name dataset from a zip file in the repo""" # pylint: disable=too-many-locals tbl_name = "birth_names" database = get_example_database() table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): load_data(tbl_name, database, sample=sample) table = get_table_connector_registry() obj = db.session.query(table).filter_by(table_name=tbl_name).first() if not obj: print(f"Creating table [{tbl_name}] reference") obj = table(table_name=tbl_name) db.session.add(obj) _set_table_metadata(obj, database) _add_table_metrics(obj) db.session.commit() slices, _ = create_slices(obj, admin_owner=True) create_dashboard(slices) def _set_table_metadata(datasource: "BaseDatasource", database: "Database") -> None: datasource.main_dttm_col = "ds" # type: ignore datasource.database = database datasource.filter_select_enabled = True datasource.fetch_metadata() def _add_table_metrics(datasource: "BaseDatasource") -> None: if not any(col.column_name == "num_california" for col in datasource.columns): col_state = str(column("state").compile(db.engine)) col_num = str(column("num").compile(db.engine)) datasource.columns.append( TableColumn( column_name="num_california", expression=f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END", ) ) if not any(col.metric_name == "sum__num" for col in datasource.metrics): col = str(column("num").compile(db.engine)) datasource.metrics.append( SqlMetric(metric_name="sum__num", expression=f"SUM({col})") ) for col in datasource.columns: if col.column_name == "ds": col.is_dttm = True # type: ignore break def create_slices( tbl: BaseDatasource, admin_owner: bool ) -> Tuple[List[Slice], List[Slice]]: metrics = [ { "expressionType": "SIMPLE", "column": {"column_name": "num", "type": "BIGINT"}, "aggregate": "SUM", "label": "Births", "optionName": "metric_11", } ] metric = "sum__num" defaults = { "compare_lag": "10", "compare_suffix": "o10Y", "limit": "25", "time_range": "No filter", "time_range_endpoints": ["inclusive", "exclusive"], "granularity_sqla": "ds", "groupby": [], "row_limit": app.config["ROW_LIMIT"], "since": "100 years ago", "until": "now", "viz_type": "table", "markup_type": "markdown", } admin = get_admin_user() if admin_owner: slice_props = dict( datasource_id=tbl.id, datasource_type="table", owners=[admin], created_by=admin, ) else: slice_props = dict( datasource_id=tbl.id, datasource_type="table", owners=[], created_by=admin ) print("Creating some slices") slices = [ Slice( **slice_props, slice_name="Participants", viz_type="big_number", params=get_slice_json( defaults, viz_type="big_number", granularity_sqla="ds", compare_lag="5", compare_suffix="over 5Y", metric=metric, ), ), Slice( **slice_props, slice_name="Genders", viz_type="pie", params=get_slice_json( defaults, viz_type="pie", groupby=["gender"], metric=metric ), ), Slice( **slice_props, slice_name="Trends", viz_type="line", params=get_slice_json( defaults, viz_type="line", groupby=["name"], granularity_sqla="ds", rich_tooltip=True, show_legend=True, metrics=metrics, ), ), Slice( **slice_props, slice_name="Genders by State", viz_type="dist_bar", params=get_slice_json( defaults, adhoc_filters=[ { "clause": "WHERE", "expressionType": "SIMPLE", "filterOptionName": "2745eae5", "comparator": ["other"], "operator": "NOT IN", "subject": "state", } ], viz_type="dist_bar", metrics=[ { "expressionType": "SIMPLE", "column": {"column_name": "num_boys", "type": "BIGINT(20)"}, "aggregate": "SUM", "label": "Boys", "optionName": "metric_11", }, { "expressionType": "SIMPLE", "column": {"column_name": "num_girls", "type": "BIGINT(20)"}, "aggregate": "SUM", "label": "Girls", "optionName": "metric_12", }, ], groupby=["state"], ), ), Slice( **slice_props, slice_name="Girls", viz_type="table", params=get_slice_json( defaults, groupby=["name"], adhoc_filters=[gen_filter("gender", "girl")], row_limit=50, timeseries_limit_metric=metric, metrics=[metric], ), ), Slice( **slice_props, slice_name="Girl Name Cloud", viz_type="word_cloud", params=get_slice_json( defaults, viz_type="word_cloud", size_from="10", series="name", size_to="70", rotation="square", limit="100", adhoc_filters=[gen_filter("gender", "girl")], metric=metric, ), ), Slice( **slice_props, slice_name="Boys", viz_type="table", params=get_slice_json( defaults, groupby=["name"], adhoc_filters=[gen_filter("gender", "boy")], row_limit=50, timeseries_limit_metric=metric, metrics=[metric], ), ), Slice( **slice_props, slice_name="Boy Name Cloud", viz_type="word_cloud", params=get_slice_json( defaults, viz_type="word_cloud", size_from="10", series="name", size_to="70", rotation="square", limit="100", adhoc_filters=[gen_filter("gender", "boy")], metric=metric, ), ), Slice( **slice_props, slice_name="Top 10 Girl Name Share", viz_type="area", params=get_slice_json( defaults, adhoc_filters=[gen_filter("gender", "girl")], comparison_type="values", groupby=["name"], limit=10, stacked_style="expand", time_grain_sqla="P1D", viz_type="area", x_axis_forma="smart_date", metrics=metrics, ), ), Slice( **slice_props, slice_name="Top 10 Boy Name Share", viz_type="area", params=get_slice_json( defaults, adhoc_filters=[gen_filter("gender", "boy")], comparison_type="values", groupby=["name"], limit=10, stacked_style="expand", time_grain_sqla="P1D", viz_type="area", x_axis_forma="smart_date", metrics=metrics, ), ), ] misc_slices = [ Slice( **slice_props, slice_name="Average and Sum Trends", viz_type="dual_line", params=get_slice_json( defaults, viz_type="dual_line", metric={ "expressionType": "SIMPLE", "column": {"column_name": "num", "type": "BIGINT(20)"}, "aggregate": "AVG", "label": "AVG(num)", "optionName": "metric_vgops097wej_g8uff99zhk7", }, metric_2="sum__num", granularity_sqla="ds", metrics=metrics, ), ), Slice( **slice_props, slice_name="Num Births Trend", viz_type="line", params=get_slice_json(defaults, viz_type="line", metrics=metrics), ), Slice( **slice_props, slice_name="Daily Totals", viz_type="table", params=get_slice_json( defaults, groupby=["ds"], since="40 years ago", until="now", viz_type="table", metrics=metrics, ), ), Slice( **slice_props, slice_name="Number of California Births", viz_type="big_number_total", params=get_slice_json( defaults, metric={ "expressionType": "SIMPLE", "column": { "column_name": "num_california", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", }, "aggregate": "SUM", "label": "SUM(num_california)", }, viz_type="big_number_total", granularity_sqla="ds", ), ), Slice( **slice_props, slice_name="Top 10 California Names Timeseries", viz_type="line", params=get_slice_json( defaults, metrics=[ { "expressionType": "SIMPLE", "column": { "column_name": "num_california", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", }, "aggregate": "SUM", "label": "SUM(num_california)", } ], viz_type="line", granularity_sqla="ds", groupby=["name"], timeseries_limit_metric={ "expressionType": "SIMPLE", "column": { "column_name": "num_california", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", }, "aggregate": "SUM", "label": "SUM(num_california)", }, limit="10", ), ), Slice( **slice_props, slice_name="Names Sorted by Num in California", viz_type="table", params=get_slice_json( defaults, metrics=metrics, groupby=["name"], row_limit=50, timeseries_limit_metric={ "expressionType": "SIMPLE", "column": { "column_name": "num_california", "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", }, "aggregate": "SUM", "label": "SUM(num_california)", }, ), ), Slice( **slice_props, slice_name="Number of Girls", viz_type="big_number_total", params=get_slice_json( defaults, metric=metric, viz_type="big_number_total", granularity_sqla="ds", adhoc_filters=[gen_filter("gender", "girl")], subheader="total female participants", ), ), Slice( **slice_props, slice_name="Pivot Table", viz_type="pivot_table", params=get_slice_json( defaults, viz_type="pivot_table", groupby=["name"], columns=["state"], metrics=metrics, ), ), ] for slc in slices: merge_slice(slc) for slc in misc_slices: merge_slice(slc) misc_dash_slices.add(slc.slice_name) return slices, misc_slices def create_dashboard(slices: List[Slice]) -> Dashboard: print("Creating a dashboard") admin = get_admin_user() dash = db.session.query(Dashboard).filter_by(slug="births").first() if not dash: dash = Dashboard() dash.owners = [admin] dash.created_by = admin db.session.add(dash) dash.published = True dash.json_metadata = textwrap.dedent( """\ { "label_colors": { "Girls": "#FF69B4", "Boys": "#ADD8E6", "girl": "#FF69B4", "boy": "#ADD8E6" } }""" ) # pylint: disable=line-too-long pos = json.loads( textwrap.dedent( """\ { "CHART-6GdlekVise": { "children": [], "id": "CHART-6GdlekVise", "meta": { "chartId": 5547, "height": 50, "sliceName": "Top 10 Girl Name Share", "width": 5 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-eh0w37bWbR" ], "type": "CHART" }, "CHART-6n9jxb30JG": { "children": [], "id": "CHART-6n9jxb30JG", "meta": { "chartId": 5540, "height": 36, "sliceName": "Genders by State", "width": 5 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW--EyBZQlDi" ], "type": "CHART" }, "CHART-Jj9qh1ol-N": { "children": [], "id": "CHART-Jj9qh1ol-N", "meta": { "chartId": 5545, "height": 50, "sliceName": "Boy Name Cloud", "width": 4 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-kzWtcvo8R1" ], "type": "CHART" }, "CHART-ODvantb_bF": { "children": [], "id": "CHART-ODvantb_bF", "meta": { "chartId": 5548, "height": 50, "sliceName": "Top 10 Boy Name Share", "width": 5 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-kzWtcvo8R1" ], "type": "CHART" }, "CHART-PAXUUqwmX9": { "children": [], "id": "CHART-PAXUUqwmX9", "meta": { "chartId": 5538, "height": 34, "sliceName": "Genders", "width": 3 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-2n0XgiHDgs" ], "type": "CHART" }, "CHART-_T6n_K9iQN": { "children": [], "id": "CHART-_T6n_K9iQN", "meta": { "chartId": 5539, "height": 36, "sliceName": "Trends", "width": 7 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW--EyBZQlDi" ], "type": "CHART" }, "CHART-eNY0tcE_ic": { "children": [], "id": "CHART-eNY0tcE_ic", "meta": { "chartId": 5537, "height": 34, "sliceName": "Participants", "width": 3 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-2n0XgiHDgs" ], "type": "CHART" }, "CHART-g075mMgyYb": { "children": [], "id": "CHART-g075mMgyYb", "meta": { "chartId": 5541, "height": 50, "sliceName": "Girls", "width": 3 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-eh0w37bWbR" ], "type": "CHART" }, "CHART-n-zGGE6S1y": { "children": [], "id": "CHART-n-zGGE6S1y", "meta": { "chartId": 5542, "height": 50, "sliceName": "Girl Name Cloud", "width": 4 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-eh0w37bWbR" ], "type": "CHART" }, "CHART-vJIPjmcbD3": { "children": [], "id": "CHART-vJIPjmcbD3", "meta": { "chartId": 5543, "height": 50, "sliceName": "Boys", "width": 3 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-kzWtcvo8R1" ], "type": "CHART" }, "DASHBOARD_VERSION_KEY": "v2", "GRID_ID": { "children": [ "ROW-2n0XgiHDgs", "ROW--EyBZQlDi", "ROW-eh0w37bWbR", "ROW-kzWtcvo8R1" ], "id": "GRID_ID", "parents": [ "ROOT_ID" ], "type": "GRID" }, "HEADER_ID": { "id": "HEADER_ID", "meta": { "text": "Births" }, "type": "HEADER" }, "MARKDOWN-zaflB60tbC": { "children": [], "id": "MARKDOWN-zaflB60tbC", "meta": { "code": "

Birth Names Dashboard

", "height": 34, "width": 6 }, "parents": [ "ROOT_ID", "GRID_ID", "ROW-2n0XgiHDgs" ], "type": "MARKDOWN" }, "ROOT_ID": { "children": [ "GRID_ID" ], "id": "ROOT_ID", "type": "ROOT" }, "ROW--EyBZQlDi": { "children": [ "CHART-_T6n_K9iQN", "CHART-6n9jxb30JG" ], "id": "ROW--EyBZQlDi", "meta": { "background": "BACKGROUND_TRANSPARENT" }, "parents": [ "ROOT_ID", "GRID_ID" ], "type": "ROW" }, "ROW-2n0XgiHDgs": { "children": [ "CHART-eNY0tcE_ic", "MARKDOWN-zaflB60tbC", "CHART-PAXUUqwmX9" ], "id": "ROW-2n0XgiHDgs", "meta": { "background": "BACKGROUND_TRANSPARENT" }, "parents": [ "ROOT_ID", "GRID_ID" ], "type": "ROW" }, "ROW-eh0w37bWbR": { "children": [ "CHART-g075mMgyYb", "CHART-n-zGGE6S1y", "CHART-6GdlekVise" ], "id": "ROW-eh0w37bWbR", "meta": { "background": "BACKGROUND_TRANSPARENT" }, "parents": [ "ROOT_ID", "GRID_ID" ], "type": "ROW" }, "ROW-kzWtcvo8R1": { "children": [ "CHART-vJIPjmcbD3", "CHART-Jj9qh1ol-N", "CHART-ODvantb_bF" ], "id": "ROW-kzWtcvo8R1", "meta": { "background": "BACKGROUND_TRANSPARENT" }, "parents": [ "ROOT_ID", "GRID_ID" ], "type": "ROW" } } """ ) ) # pylint: enable=line-too-long # dashboard v2 doesn't allow add markup slice dash.slices = [slc for slc in slices if slc.viz_type != "markup"] update_slice_ids(pos, dash.slices) dash.dashboard_title = "USA Births Names" dash.position_json = json.dumps(pos, indent=4) dash.slug = "births" db.session.commit() return dash