From 4e9ea4b17ac21181a45b977110ca4604a657d806 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 23 Jul 2025 22:43:14 -0400 Subject: [PATCH] Fix cache --- superset/db_engine_specs/metricflow.py | 39 ++++++++++---------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/superset/db_engine_specs/metricflow.py b/superset/db_engine_specs/metricflow.py index 87ba25c8d6d..101088ede98 100644 --- a/superset/db_engine_specs/metricflow.py +++ b/superset/db_engine_specs/metricflow.py @@ -20,15 +20,12 @@ An interface to dbt's semantic layer, Metric Flow. from __future__ import annotations -from datetime import timedelta from typing import Any, TYPE_CHECKING, TypedDict from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name from shillelagh.backends.apsw.dialects.metricflow import TABLE_NAME from sqlalchemy import event -from sqlalchemy.engine import Engine -from sqlalchemy.pool import _ConnectionRecord -from sqlalchemy.pool.base import _ConnectionFairy +from sqlalchemy.engine import Connection, Engine from superset.connectors.sqla.models import SqlaTable from superset.constants import TimeGrain @@ -46,45 +43,39 @@ if TYPE_CHECKING: from superset.superset_typing import ResultSetColumnType -@event.listens_for(Engine, "connect") -def receive_connect( - dbapi_connection: _ConnectionFairy, - connection_record: _ConnectionRecord, -) -> None: +@event.listens_for(Engine, "engine_connect") +def receive_engine_connect(conn: Connection, branch: bool) -> None: """ Called when a new DB connection is created. This hook adds a cache to the `_build_column_from_dimension` method of the Metric Flow adapter, since it's called frequently and can be expensive. """ - engine = connection_record.info.get("engine") - if ( - not engine - or not engine.name == "metricflow" - or getattr(engine.dialect, "_patched", False) - ): + engine = conn.engine + if not engine or not engine.name == "metricflow": return - original_method = engine.dialect._build_column_from_dimension + from shillelagh.adapters.api.dbt_metricflow import DbtMetricFlowAPI + + if getattr(DbtMetricFlowAPI, "_patched", False): + return + + original_method = DbtMetricFlowAPI._build_column_from_dimension @memoized_func( key="metricflow:dimension:{name}", cache=cache_manager.data_cache, ) def cached_build_column_from_dimension( - self, + self: DbtMetricFlowAPI, name: str, *args: Any, **kwargs: Any, ) -> Field: - return original_method( - self, - name, - cache_timeout=int(timedelta(days=1).total_seconds()), - ) + return original_method(self, name) - engine.dialect._build_column_from_dimension = cached_build_column_from_dimension - engine.dialect._patched = True + DbtMetricFlowAPI._build_column_from_dimension = cached_build_column_from_dimension + DbtMetricFlowAPI._patched = True SELECT_STAR_MESSAGE = (