mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
refactor: use contextmanager for event_logger decorators (#11222)
This commit is contained in:
@@ -33,7 +33,6 @@ from superset.extensions import (
|
||||
talisman,
|
||||
)
|
||||
from superset.security import SupersetSecurityManager
|
||||
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
|
||||
|
||||
# All of the fields located here should be considered legacy. The correct way
|
||||
# to declare "global" dependencies is to define it in extensions.py,
|
||||
|
||||
@@ -19,9 +19,10 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, cast, Optional, Type
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, cast, Iterator, Optional, Type
|
||||
|
||||
from flask import current_app, g, request
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -36,58 +37,76 @@ class AbstractEventLogger(ABC):
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@contextmanager
|
||||
def log_context(self, action: str) -> Iterator[Callable[..., None]]:
|
||||
"""
|
||||
Log an event while reading information from the request context.
|
||||
`kwargs` will be appended directly to the log payload.
|
||||
"""
|
||||
from superset.views.core import get_form_data
|
||||
|
||||
start_time = time.time()
|
||||
referrer = request.referrer[:1000] if request.referrer else None
|
||||
user_id = g.user.get_id() if hasattr(g, "user") and g.user else None
|
||||
payload = request.form.to_dict() or {}
|
||||
# request parameters can overwrite post body
|
||||
payload.update(request.args.to_dict())
|
||||
|
||||
# yield a helper to update additional kwargs
|
||||
yield lambda **kwargs: payload.update(kwargs)
|
||||
|
||||
dashboard_id = payload.get("dashboard_id")
|
||||
|
||||
if "form_data" in payload:
|
||||
form_data, _ = get_form_data()
|
||||
payload["form_data"] = form_data
|
||||
slice_id = form_data.get("slice_id")
|
||||
else:
|
||||
slice_id = payload.get("slice_id")
|
||||
|
||||
try:
|
||||
slice_id = int(slice_id) # type: ignore
|
||||
except (TypeError, ValueError):
|
||||
slice_id = 0
|
||||
|
||||
self.stats_logger.incr(action)
|
||||
|
||||
# bulk insert
|
||||
try:
|
||||
explode_by = payload.get("explode")
|
||||
records = json.loads(payload.get(explode_by)) # type: ignore
|
||||
except Exception: # pylint: disable=broad-except
|
||||
records = [payload]
|
||||
|
||||
self.log(
|
||||
user_id,
|
||||
action,
|
||||
records=records,
|
||||
dashboard_id=dashboard_id,
|
||||
slice_id=slice_id,
|
||||
duration_ms=round((time.time() - start_time) * 1000),
|
||||
referrer=referrer,
|
||||
)
|
||||
|
||||
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
with self.log_context(f.__name__) as log:
|
||||
value = f(*args, **kwargs)
|
||||
log(**kwargs)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
|
||||
def log_manually(self, f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""Allow a function to manually update"""
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
user_id = None
|
||||
if hasattr(g, "user") and g.user:
|
||||
user_id = g.user.get_id()
|
||||
payload = request.form.to_dict() or {}
|
||||
|
||||
# request parameters can overwrite post body
|
||||
request_params = request.args.to_dict()
|
||||
payload.update(request_params)
|
||||
payload.update(kwargs)
|
||||
|
||||
dashboard_id = payload.get("dashboard_id")
|
||||
|
||||
if "form_data" in payload:
|
||||
form_data, _ = get_form_data()
|
||||
payload["form_data"] = form_data
|
||||
slice_id = form_data.get("slice_id")
|
||||
else:
|
||||
slice_id = payload.get("slice_id")
|
||||
|
||||
try:
|
||||
slice_id = int(slice_id) # type: ignore
|
||||
except (TypeError, ValueError):
|
||||
slice_id = 0
|
||||
|
||||
self.stats_logger.incr(f.__name__)
|
||||
start_dttm = datetime.now()
|
||||
value = f(*args, **kwargs)
|
||||
duration_ms = (datetime.now() - start_dttm).total_seconds() * 1000
|
||||
|
||||
# bulk insert
|
||||
try:
|
||||
explode_by = payload.get("explode")
|
||||
records = json.loads(payload.get(explode_by)) # type: ignore
|
||||
except Exception: # pylint: disable=broad-except
|
||||
records = [payload]
|
||||
|
||||
referrer = request.referrer[:1000] if request.referrer else None
|
||||
|
||||
self.log(
|
||||
user_id,
|
||||
f.__name__,
|
||||
records=records,
|
||||
dashboard_id=dashboard_id,
|
||||
slice_id=slice_id,
|
||||
duration_ms=duration_ms,
|
||||
referrer=referrer,
|
||||
)
|
||||
with self.log_context(f.__name__) as log:
|
||||
# updated_log_payload should be either the last positional
|
||||
# argument or one of the named arguments of the decorated function
|
||||
value = f(*args, update_log_payload=log, **kwargs)
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
@@ -141,6 +160,8 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
|
||||
|
||||
|
||||
class DBEventLogger(AbstractEventLogger):
|
||||
"""Event logger that commits logs to Superset DB"""
|
||||
|
||||
def log( # pylint: disable=too-many-locals
|
||||
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
|
||||
@@ -19,7 +19,7 @@ import logging
|
||||
import re
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from typing import Any, cast, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Union
|
||||
from urllib import parse
|
||||
|
||||
import backoff
|
||||
@@ -1602,8 +1602,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
||||
|
||||
@has_access
|
||||
@expose("/dashboard/<dashboard_id_or_slug>/")
|
||||
@event_logger.log_manually
|
||||
def dashboard( # pylint: disable=too-many-locals
|
||||
self, dashboard_id_or_slug: str
|
||||
self,
|
||||
dashboard_id_or_slug: str,
|
||||
# this parameter is added by `log_manually`,
|
||||
# set a default value to appease pylint
|
||||
update_log_payload: Callable[..., None] = lambda **kwargs: None,
|
||||
) -> FlaskResponse:
|
||||
"""Server side rendering for a dashboard"""
|
||||
session = db.session()
|
||||
@@ -1652,12 +1657,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
||||
request.args.get(utils.ReservedUrlParameters.EDIT_MODE.value) == "true"
|
||||
)
|
||||
|
||||
# Hack to log the dashboard_id properly, even when getting a slug
|
||||
@event_logger.log_this
|
||||
def dashboard(**_: Any) -> None:
|
||||
pass
|
||||
|
||||
dashboard(
|
||||
update_log_payload(
|
||||
dashboard_id=dash.id,
|
||||
dashboard_version="v2",
|
||||
dash_edit_perm=dash_edit_perm,
|
||||
|
||||
@@ -15,9 +15,17 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import time
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
|
||||
from superset.utils.log import (
|
||||
AbstractEventLogger,
|
||||
DBEventLogger,
|
||||
get_event_logger_from_cfg_value,
|
||||
)
|
||||
from tests.test_app import app
|
||||
|
||||
|
||||
class TestEventLogger(unittest.TestCase):
|
||||
@@ -42,3 +50,34 @@ class TestEventLogger(unittest.TestCase):
|
||||
# test that assignment of non AbstractEventLogger derived type raises TypeError
|
||||
with self.assertRaises(TypeError):
|
||||
get_event_logger_from_cfg_value(logging.getLogger())
|
||||
|
||||
@patch.object(DBEventLogger, "log")
|
||||
def test_log_this_decorator(self, mock_log):
|
||||
logger = DBEventLogger()
|
||||
|
||||
@logger.log_this
|
||||
def test_func():
|
||||
time.sleep(0.05)
|
||||
return 1
|
||||
|
||||
with app.test_request_context():
|
||||
result = test_func()
|
||||
self.assertEqual(result, 1)
|
||||
assert mock_log.call_args[1]["duration_ms"] >= 50
|
||||
|
||||
@patch.object(DBEventLogger, "log")
|
||||
def test_log_manually_decorator(self, mock_log):
|
||||
logger = DBEventLogger()
|
||||
|
||||
@logger.log_manually
|
||||
def test_func(arg1, update_log_payload, karg1=1):
|
||||
time.sleep(0.1)
|
||||
update_log_payload(foo="bar")
|
||||
return arg1 * karg1
|
||||
|
||||
with app.test_request_context():
|
||||
result = test_func(1, karg1=2) # pylint: disable=no-value-for-parameter
|
||||
self.assertEqual(result, 2)
|
||||
# should contain only manual payload
|
||||
self.assertEqual(mock_log.call_args[1]["records"], [{"foo": "bar"}])
|
||||
assert mock_log.call_args[1]["duration_ms"] >= 100
|
||||
|
||||
Reference in New Issue
Block a user