From be6b9b8fecac14341b748bc987a0b2021c3c025a Mon Sep 17 00:00:00 2001 From: Bogdan Date: Wed, 17 Jun 2020 13:46:45 -0700 Subject: [PATCH] feat: implement dttm column configuration through db extra config (#9444) * Implement table mutator and examples how to set date defaults Fix tests * Fix flaky test Co-authored-by: bogdan kyryliuk --- superset/config.py | 8 ++ superset/connectors/sqla/models.py | 3 + tests/config_tests.py | 176 +++++++++++++++++++++++++++++ tests/dict_import_export_tests.py | 3 + 4 files changed, 190 insertions(+) create mode 100644 tests/config_tests.py diff --git a/superset/config.py b/superset/config.py index 84d11a603af..b862a9c9844 100644 --- a/superset/config.py +++ b/superset/config.py @@ -876,6 +876,14 @@ SIP_15_TOAST_MESSAGE = ( 'class="alert-link">here.' ) + +# SQLA table mutator, every time we fetch the metadata for a certain table +# (superset.connectors.sqla.models.SqlaTable), we call this hook +# to allow mutating the object with this callback. +# This can be used to set any properties of the object based on naming +# conventions and such. You can find examples in the tests. +SQLA_TABLE_MUTATOR = lambda table: table + if CONFIG_PATH_ENV_VAR in os.environ: # Explicitly import config module that is not necessarily in pythonpath; useful # for case where app is being executed via pex. diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index be6df5219ef..1f1fe296fac 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1209,6 +1209,9 @@ class SqlaTable(Model, BaseDatasource): self.main_dttm_col = any_date_col self.add_missing_metrics(metrics) + # Apply config supplied mutations. + config["SQLA_TABLE_MUTATOR"](self) + db.session.merge(self) if commit: db.session.commit() diff --git a/tests/config_tests.py b/tests/config_tests.py new file mode 100644 index 00000000000..851bb6bddf4 --- /dev/null +++ b/tests/config_tests.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file + +import unittest +from typing import Any, Dict + +from tests.base_tests import SupersetTestCase +from tests.test_app import app + +from superset import db +from superset.connectors.sqla.models import SqlaTable +from superset.utils.core import get_or_create_db + +FULL_DTTM_DEFAULTS_EXAMPLE = { + "main_dttm_col": "id", + "dttm_columns": { + "dttm": { + "python_date_format": "epoch_s", + "expression": "CAST(dttm as INTEGER)", + }, + "id": {"python_date_format": "epoch_ms"}, + "month": { + "python_date_format": "%Y-%m-%d", + "expression": "CASE WHEN length(month) = 7 THEN month || '-01' ELSE month END", + }, + }, +} + + +def apply_dttm_defaults(table: SqlaTable, dttm_defaults: Dict[str, Any]): + """Applies dttm defaults to the table, mutates in place.""" + for dbcol in table.columns: + # Set is_dttm is column is listed in dttm_columns. + if dbcol.column_name in dttm_defaults.get("dttm_columns", {}): + dbcol.is_dttm = True + + # Skip non dttm columns. + if dbcol.column_name not in dttm_defaults.get("dttm_columns", {}): + continue + + # Set table main_dttm_col. + if dbcol.column_name == dttm_defaults.get("main_dttm_col"): + table.main_dttm_col = dbcol.column_name + + # Apply defaults if empty. + dttm_column_defaults = dttm_defaults.get("dttm_columns", {}).get( + dbcol.column_name, {} + ) + dbcol.is_dttm = True + if ( + not dbcol.python_date_format + and "python_date_format" in dttm_column_defaults + ): + dbcol.python_date_format = dttm_column_defaults["python_date_format"] + if not dbcol.expression and "expression" in dttm_column_defaults: + dbcol.expression = dttm_column_defaults["expression"] + + +class ConfigTests(SupersetTestCase): + def __init__(self, *args, **kwargs): + super(ConfigTests, self).__init__(*args, **kwargs) + + def setUp(self) -> None: + self.login(username="admin") + self._test_db_id = get_or_create_db( + "column_test_db", app.config["SQLALCHEMY_DATABASE_URI"] + ).id + self._old_sqla_table_mutator = app.config["SQLA_TABLE_MUTATOR"] + + def createTable(self, dttm_defaults): + app.config["SQLA_TABLE_MUTATOR"] = lambda t: apply_dttm_defaults( + t, dttm_defaults + ) + resp = self.client.post( + "/tablemodelview/add", + data=dict(database=self._test_db_id, table_name="logs"), + follow_redirects=True, + ) + self.assertEqual(resp.status_code, 200) + self._logs_table = ( + db.session.query(SqlaTable).filter_by(table_name="logs").one() + ) + + def tearDown(self): + app.config["SQLA_TABLE_MUTATOR"] = self._old_sqla_table_mutator + if hasattr(self, "_logs_table"): + db.session.delete(self._logs_table) + db.session.delete(self._logs_table.database) + db.session.commit() + + def test_main_dttm_col(self): + # Make sure that dttm column is set properly. + self.createTable({"main_dttm_col": "id", "dttm_columns": {"id": {}}}) + self.assertEqual(self._logs_table.main_dttm_col, "id") + + def test_main_dttm_col_nonexistent(self): + self.createTable({"main_dttm_col": "nonexistent"}) + # Column doesn't exist, falls back to dttm. + self.assertEqual(self._logs_table.main_dttm_col, "dttm") + + def test_main_dttm_col_nondttm(self): + self.createTable({"main_dttm_col": "duration_ms"}) + # duration_ms is not dttm column, falls back to dttm. + self.assertEqual(self._logs_table.main_dttm_col, "dttm") + + def test_python_date_format_by_column_name(self): + table_defaults = { + "dttm_columns": { + "id": {"python_date_format": "epoch_ms"}, + "dttm": {"python_date_format": "epoch_s"}, + "duration_ms": {"python_date_format": "invalid"}, + } + } + self.createTable(table_defaults) + id_col = [c for c in self._logs_table.columns if c.column_name == "id"][0] + self.assertTrue(id_col.is_dttm) + self.assertEquals(id_col.python_date_format, "epoch_ms") + dttm_col = [c for c in self._logs_table.columns if c.column_name == "dttm"][0] + self.assertTrue(dttm_col.is_dttm) + self.assertEquals(dttm_col.python_date_format, "epoch_s") + dms_col = [ + c for c in self._logs_table.columns if c.column_name == "duration_ms" + ][0] + self.assertTrue(dms_col.is_dttm) + self.assertEquals(dms_col.python_date_format, "invalid") + + def test_expression_by_column_name(self): + table_defaults = { + "dttm_columns": { + "dttm": {"expression": "CAST(dttm as INTEGER)"}, + "duration_ms": {"expression": "CAST(duration_ms as DOUBLE)"}, + } + } + self.createTable(table_defaults) + dttm_col = [c for c in self._logs_table.columns if c.column_name == "dttm"][0] + self.assertTrue(dttm_col.is_dttm) + self.assertEqual(dttm_col.expression, "CAST(dttm as INTEGER)") + dms_col = [ + c for c in self._logs_table.columns if c.column_name == "duration_ms" + ][0] + self.assertEqual(dms_col.expression, "CAST(duration_ms as DOUBLE)") + self.assertTrue(dms_col.is_dttm) + + def test_full_setting(self): + self.createTable(FULL_DTTM_DEFAULTS_EXAMPLE) + + self.assertEqual(self._logs_table.main_dttm_col, "id") + + id_col = [c for c in self._logs_table.columns if c.column_name == "id"][0] + self.assertTrue(id_col.is_dttm) + self.assertEquals(id_col.python_date_format, "epoch_ms") + self.assertIsNone(id_col.expression) + + dttm_col = [c for c in self._logs_table.columns if c.column_name == "dttm"][0] + self.assertTrue(dttm_col.is_dttm) + self.assertEquals(dttm_col.python_date_format, "epoch_s") + self.assertEqual(dttm_col.expression, "CAST(dttm as INTEGER)") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py index e8574797e7b..3597eac3eae 100644 --- a/tests/dict_import_export_tests.py +++ b/tests/dict_import_export_tests.py @@ -280,6 +280,9 @@ class DictImportExportTests(SupersetTestCase): ) def test_export_datasource_ui_cli(self): + # TODO(bkyryliuk): find fake db is leaking from + self.delete_fake_db() + cli_export = export_to_dict( session=db.session, recursive=True,