From 5239ed15db4dff30c7563f516d7132b4e8480840 Mon Sep 17 00:00:00 2001 From: AAfghahi Date: Wed, 16 Nov 2022 11:27:00 -0500 Subject: [PATCH] add physical dataset --- superset/db_engine_specs/base.py | 34 +++++ superset/models/core.py | 19 ++- superset/views/core.py | 2 +- .../charts/data/api_tests.py | 5 +- tests/integration_tests/conftest.py | 125 +++++++++++++++++- 5 files changed, 177 insertions(+), 8 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index b4f4ec25c45..43ca6967b65 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1310,6 +1310,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return sqla_type, generic_type return None + @classmethod + def get_url_for_impersonation( + cls, url: URL, impersonate_user: bool, username: Optional[str] + ) -> URL: + """ + Return a modified URL with the username set. + :param url: SQLAlchemy URL object + :param impersonate_user: Flag indicating if impersonation is enabled + :param username: Effective username + """ + if impersonate_user and username is not None: + url = url.set(username=username) + + return url + @staticmethod def _mutate_label(label: str) -> str: """ @@ -1441,6 +1456,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods logger.error(ex, exc_info=True) raise ex + @staticmethod + def update_params_from_encrypted_extra( # pylint: disable=invalid-name + database: "Database", params: Dict[str, Any] + ) -> None: + """ + Some databases require some sensitive information which do not conform to + the username:password syntax normally used by SQLAlchemy. + :param database: database instance from which to extract extras + :param params: params to be updated + """ + if not database.encrypted_extra: + return + try: + encrypted_extra = json.loads(database.encrypted_extra) + params.update(encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise ex + @classmethod def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: """Pessimistic readonly, 100% sure statement won't mutate anything""" diff --git a/superset/models/core.py b/superset/models/core.py index d21ac56dad5..c31f727df9f 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -21,7 +21,7 @@ import json import logging import textwrap from ast import literal_eval -from contextlib import closing +from contextlib import closing, contextmanager from copy import deepcopy from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type @@ -339,6 +339,18 @@ class Database( else None ) + @contextmanager + def get_sqla_engine_with_context( + self, + schema: Optional[str] = None, + nullpool: bool = True, + source: Optional[utils.QuerySource] = None, + ) -> Engine: + try: + yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + except Exception as ex: + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) + @memoized( watch=( "impersonate_user", @@ -672,6 +684,11 @@ class Database( def update_encrypted_extra_params(self, params: Dict[str, Any]) -> None: self.db_engine_spec.update_encrypted_extra_params(self, params) + def update_params_from_encrypted_extra( # pylint: disable=invalid-name + self, params: Dict[str, Any] + ) -> None: + self.db_engine_spec.update_params_from_encrypted_extra(self, params) + def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) diff --git a/superset/views/core.py b/superset/views/core.py index dd0862852c5..83c0a33386e 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -322,7 +322,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @has_access @event_logger.log_this @expose("/approve", methods=["POST"]) - def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use + def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals def clean_fulfilled_requests(session: Session) -> None: for dar in session.query(DAR).all(): datasource = ConnectorRegistry.get_datasource( diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index f02b62af75f..4aba5b2e229 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -995,12 +995,12 @@ def physical_query_context(physical_dataset) -> Dict[str, Any]: }, }, ) -def test_cache_default_timeout(login_as_admin, physical_query_context): +def test_cache_default_timeout(test_client, login_as_admin, physical_query_context): rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 1234 -def test_custom_cache_timeout(login_as_admin, physical_query_context): +def test_custom_cache_timeout(test_client, login_as_admin, physical_query_context): physical_query_context["custom_cache_timeout"] = 5678 rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 5678 @@ -1018,6 +1018,7 @@ def test_custom_cache_timeout(login_as_admin, physical_query_context): }, ) def test_data_cache_default_timeout( + test_client, login_as_admin, physical_query_context, ): diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 0abdcff4c89..3e2744d2b19 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -16,15 +16,18 @@ # under the License. from __future__ import annotations +import contextlib import functools -from typing import Any, Callable, Generator, Optional, TYPE_CHECKING +import os +from typing import Any, Callable, Optional, TYPE_CHECKING from unittest.mock import patch import pytest from flask.ctx import AppContext +from flask_appbuilder.security.sqla import models as ab_models from sqlalchemy.engine import Engine -from superset import db +from superset import db, security_manager from superset.extensions import feature_flag_manager from superset.utils.core import json_dumps_w_dates from superset.utils.database import get_example_database, remove_database @@ -70,8 +73,6 @@ def setup_sample_data() -> Any: yield with app.app_context(): - engine = get_example_database().get_sqla_engine() - # drop sqlachemy tables db.session.commit() @@ -101,6 +102,50 @@ def login_as_admin(login_as: Callable[..., None]): yield login_as("admin") +@pytest.fixture +def create_user(app_context: AppContext): + def _create_user(username: str, role: str = "Admin", password: str = "general"): + security_manager.add_user( + username, + "firstname", + "lastname", + "email@exaple.com", + security_manager.find_role(role), + password, + ) + return security_manager.find_user(username) + + return _create_user + + +@pytest.fixture +def get_user(app_context: AppContext): + def _get_user(username: str) -> ab_models.User: + return ( + db.session.query(security_manager.user_model) + .filter_by(username=username) + .one_or_none() + ) + + return _get_user + + +@pytest.fixture +def get_or_create_user(get_user, create_user) -> ab_models.User: + @contextlib.contextmanager + def _get_user(username: str) -> ab_models.User: + user = get_user(username) + if not user: + # if user is created by test, remove it after done + user = create_user(username) + yield user + db.session.delete(user) + else: + yield user + + return _get_user + + def drop_from_schema(engine: Engine, schema_name: str): schemas = engine.execute(f"SHOW SCHEMAS").fetchall() if schema_name not in [s[0] for s in schemas]: @@ -206,3 +251,75 @@ def with_feature_flags(**mock_feature_flags): return functools.update_wrapper(wrapper, test_fn) return decorate + + +@pytest.fixture +def physical_dataset(): + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.connectors.sqla.utils import get_identifier_quoter + + example_database = get_example_database() + + with example_database.get_sqla_engine_with_context() as engine: + quoter = get_identifier_quoter(engine.name) + # sqlite can only execute one statement at a time + engine.execute( + f""" + CREATE TABLE IF NOT EXISTS physical_dataset( + col1 INTEGER, + col2 VARCHAR(255), + col3 DECIMAL(4,2), + col4 VARCHAR(255), + col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01', + col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01', + {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01' + ); + """ + ) + engine.execute( + """ + INSERT INTO physical_dataset values + (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'), + (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'), + (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'), + (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'), + (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'), + (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'), + (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'), + (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'), + (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'), + (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00'); + """ + ) + + dataset = SqlaTable( + table_name="physical_dataset", + database=example_database, + ) + TableColumn(column_name="col1", type="INTEGER", table=dataset) + TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset) + TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset) + TableColumn(column_name="col6", type="TIMESTAMP", is_dttm=True, table=dataset) + TableColumn( + column_name="time column with spaces", + type="TIMESTAMP", + is_dttm=True, + table=dataset, + ) + SqlMetric(metric_name="count", expression="count(*)", table=dataset) + db.session.merge(dataset) + db.session.commit() + + yield dataset + + engine.execute( + """ + DROP TABLE physical_dataset; + """ + ) + dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all() + for ds in dataset: + db.session.delete(ds) + db.session.commit()