refactor: use contextmanager for event_logger decorators (#11222)

This commit is contained in:
Jesse Yang
2020-10-14 10:44:06 -07:00
committed by GitHub
parent bb2e6cfca9
commit 634676d467
4 changed files with 119 additions and 60 deletions

View File

@@ -33,7 +33,6 @@ from superset.extensions import (
talisman,
)
from superset.security import SupersetSecurityManager
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
# All of the fields located here should be considered legacy. The correct way
# to declare "global" dependencies is to define it in extensions.py,

View File

@@ -19,9 +19,10 @@ import inspect
import json
import logging
import textwrap
import time
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Callable, cast, Optional, Type
from contextlib import contextmanager
from typing import Any, Callable, cast, Iterator, Optional, Type
from flask import current_app, g, request
from sqlalchemy.exc import SQLAlchemyError
@@ -36,58 +37,76 @@ class AbstractEventLogger(ABC):
) -> None:
pass
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
@contextmanager
def log_context(self, action: str) -> Iterator[Callable[..., None]]:
    """
    Emit one log record for `action`, enriched from the Flask request.

    The payload starts as the request's form fields overlaid with its query
    parameters. The context yields a callable; keyword arguments passed to
    it are merged straight into that payload. After the `with` body returns
    without raising, the stats counter for `action` is incremented and
    `self.log` is called with the collected records and the time elapsed
    since the context was entered, in milliseconds.
    """
    # NOTE(review): function-scope import, presumably to avoid an import
    # cycle with the views module — confirm before hoisting it.
    from superset.views.core import get_form_data

    entered_at = time.time()

    # Snapshot request-derived values before control goes to the caller.
    referrer = None
    if request.referrer:
        referrer = request.referrer[:1000]

    user_id = None
    if hasattr(g, "user") and g.user:
        user_id = g.user.get_id()

    # Query-string values are merged last, so they win over the POST body.
    form_values = request.form.to_dict() or {}
    payload = {**form_values, **request.args.to_dict()}

    def merge_into_payload(**extra: Any) -> None:
        payload.update(extra)

    # yield a helper to update additional kwargs
    yield merge_into_payload

    dashboard_id = payload.get("dashboard_id")

    if "form_data" in payload:
        form_data, _ = get_form_data()
        payload["form_data"] = form_data
        raw_slice_id = form_data.get("slice_id")
    else:
        raw_slice_id = payload.get("slice_id")

    # Anything that cannot be read as an integer is recorded as slice 0.
    try:
        slice_id = int(raw_slice_id)  # type: ignore
    except (TypeError, ValueError):
        slice_id = 0

    self.stats_logger.incr(action)

    # bulk insert: when `explode` names a payload key holding valid JSON,
    # the decoded value becomes the records; on any failure the payload
    # itself is logged as the single record.
    try:
        explode_key = payload.get("explode")
        records = json.loads(payload.get(explode_key))  # type: ignore
    except Exception:  # pylint: disable=broad-except
        records = [payload]

    elapsed_ms = round((time.time() - entered_at) * 1000)
    self.log(
        user_id,
        action,
        records=records,
        dashboard_id=dashboard_id,
        slice_id=slice_id,
        duration_ms=elapsed_ms,
        referrer=referrer,
    )
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorate `f` so that every call is recorded through `log_context`.

    The event is named after `f.__name__`. Once `f` has returned, the
    keyword arguments it was called with are added to the log payload.
    """

    @functools.wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        with self.log_context(f.__name__) as add_to_payload:
            result = f(*args, **kwargs)
            # Only reached when `f` did not raise.
            add_to_payload(**kwargs)
        return result

    return wrapper
def log_manually(self, f: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorate `f` so that it can add entries to its own log payload.

    The wrapped function runs inside `log_context`, with the event named
    after `f.__name__`, and receives the payload-updating helper as the
    `update_log_payload` keyword argument. Whatever keyword arguments `f`
    passes to that helper end up in the logged payload.

    :param f: the function to decorate; it must accept `update_log_payload`
    :returns: the wrapping function, which returns whatever `f` returns
    """

    @functools.wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Request parsing, timing and the final `self.log` call all live in
        # `log_context`; repeating them here would invoke `f` and emit the
        # event twice.
        with self.log_context(f.__name__) as log:
            # update_log_payload should be either the last positional
            # argument or one of the named arguments of the decorated function
            value = f(*args, update_log_payload=log, **kwargs)
        return value

    return wrapper
@@ -141,6 +160,8 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
class DBEventLogger(AbstractEventLogger):
"""Event logger that commits logs to Superset DB"""
def log( # pylint: disable=too-many-locals
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
) -> None:

View File

@@ -19,7 +19,7 @@ import logging
import re
from contextlib import closing
from datetime import datetime
from typing import Any, cast, Dict, List, Optional, Union
from typing import Any, Callable, cast, Dict, List, Optional, Union
from urllib import parse
import backoff
@@ -1602,8 +1602,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@has_access
@expose("/dashboard/<dashboard_id_or_slug>/")
@event_logger.log_manually
def dashboard( # pylint: disable=too-many-locals
self, dashboard_id_or_slug: str
self,
dashboard_id_or_slug: str,
# this parameter is added by `log_manually`,
# set a default value to appease pylint
update_log_payload: Callable[..., None] = lambda **kwargs: None,
) -> FlaskResponse:
"""Server side rendering for a dashboard"""
session = db.session()
@@ -1652,12 +1657,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
request.args.get(utils.ReservedUrlParameters.EDIT_MODE.value) == "true"
)
# Hack to log the dashboard_id properly, even when getting a slug
@event_logger.log_this
def dashboard(**_: Any) -> None:
pass
dashboard(
update_log_payload(
dashboard_id=dash.id,
dashboard_version="v2",
dash_edit_perm=dash_edit_perm,

View File

@@ -15,9 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import logging
import time
import unittest
from datetime import datetime
from unittest.mock import patch
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
from superset.utils.log import (
AbstractEventLogger,
DBEventLogger,
get_event_logger_from_cfg_value,
)
from tests.test_app import app
class TestEventLogger(unittest.TestCase):
@@ -42,3 +50,34 @@ class TestEventLogger(unittest.TestCase):
# test that assignment of non AbstractEventLogger derived type raises TypeError
with self.assertRaises(TypeError):
get_event_logger_from_cfg_value(logging.getLogger())
@patch.object(DBEventLogger, "log")
def test_log_this_decorator(self, mock_log):
    # `log_this` must hand back the wrapped function's result and report
    # a duration covering the time spent inside it.
    event_logger = DBEventLogger()

    @event_logger.log_this
    def test_func():
        time.sleep(0.05)
        return 1

    with app.test_request_context():
        outcome = test_func()

    self.assertEqual(outcome, 1)
    _, logged = mock_log.call_args
    assert logged["duration_ms"] >= 50
@patch.object(DBEventLogger, "log")
def test_log_manually_decorator(self, mock_log):
    # `log_manually` injects `update_log_payload`; the decorated function
    # uses it to choose what ends up in the logged records.
    event_logger = DBEventLogger()

    @event_logger.log_manually
    def test_func(arg1, update_log_payload, karg1=1):
        time.sleep(0.1)
        update_log_payload(foo="bar")
        return arg1 * karg1

    with app.test_request_context():
        product = test_func(1, karg1=2)  # pylint: disable=no-value-for-parameter

    self.assertEqual(product, 2)
    _, logged = mock_log.call_args
    # should contain only manual payload
    self.assertEqual(logged["records"], [{"foo": "bar"}])
    assert logged["duration_ms"] >= 100