Compare commits

...

1 Commit

Author: Beto Dealmeida
SHA1: b4299cafd6
Message: WIP
Date: 2024-05-17 13:28:07 -04:00
5 changed files with 440 additions and 3 deletions


@@ -64,9 +64,11 @@ setup(
"postgres.psycopg2 = sqlalchemy.dialects.postgresql:dialect",
"postgres = sqlalchemy.dialects.postgresql:dialect",
"superset = superset.extensions.metadb:SupersetAPSWDialect",
"dbt = superset.extensions.dbt:DbtMetricFlowDialect",
],
"shillelagh.adapter": [
"superset=superset.extensions.metadb:SupersetShillelaghAdapter"
"superset = superset.extensions.metadb:SupersetShillelaghAdapter",
"presetdbtmetricflowapi = superset.extensions.dbt:PresetDbtMetricFlowAPI",
],
},
download_url="https://www.apache.org/dist/superset/" + version_string,
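For context, a minimal sketch (not part of this change) of how the new entry points are expected to be exercised once the package is installed: SQLAlchemy resolves the "dbt" dialect name and Shillelagh resolves the "presetdbtmetricflowapi" adapter from the registrations above. The environment id and service token below are placeholders.

from sqlalchemy import create_engine, inspect

# Placeholder credentials; the URI shape follows the sqlalchemy_uri_placeholder
# introduced later in this change.
engine = create_engine("dbt:///12345?service_token=dbts_XXXX")
inspector = inspect(engine)
print(inspector.get_table_names())  # expected: ["metrics"]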


@@ -82,6 +82,7 @@ if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
from superset.databases.schemas import TableMetadataResponse
from superset.models.core import Database
from superset.models.helpers import ExploreMixin
from superset.models.sql_lab import Query
@@ -131,7 +132,9 @@ builtin_time_grains: dict[str | None, str] = {
}
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
class TimestampExpression(
ColumnClause
): # pylint: disable=abstract-method, too-many-ancestors
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be used to render native column elements respecting
engine-specific quoting rules as part of a string-based expression.
@@ -182,6 +185,15 @@ class MetricType(TypedDict, total=False):
extra: str | None
class ValidColumnsType(TypedDict):
"""
Type for valid columns returned by `get_valid_columns`.
"""
columns: set[str]
metrics: set[str]
class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""Abstract class for database engine specific configurations
@@ -419,6 +431,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
# This attribute is used for semantic layers, where only certain combinations of
# metrics and dimensions are valid for a given datasource. For traditional databases
# this should be set to False.
supports_dynamic_columns = False
@classmethod
def is_oauth2_enabled(cls) -> bool:
return (
@@ -1573,6 +1590,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
}
]
@classmethod
def get_valid_columns(
cls,
database: Database,
inspector: Inspector,
datasource: ExploreMixin,
columns: set[str],
metrics: set[str],
) -> ValidColumnsType:
"""
Given a selection of columns/metrics from a datasource, return the columns and
metrics that remain valid.
This method is used for semantic layers, where tables can have columns and
metrics that cannot be computed together. When the user selects a given metric,
this allows the UI to filter the remaining metrics and dimensions so that only
valid combinations are possible.
The method should only be called when ``supports_dynamic_columns`` is set to
``True``. The default implementation in the base class ignores the selected
columns and metrics and simply returns everything, for reference.
"""
return {
"columns": {column.column_name for column in datasource.columns},
"metrics": {metric.metric_name for metric in datasource.metrics},
}
@classmethod
def where_latest_partition( # pylint: disable=unused-argument
cls,

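The hunks above only add the hook; nothing in this commit calls it yet. A minimal sketch of how a caller might gate on ``supports_dynamic_columns`` (``narrow_explore_options`` is a hypothetical helper name invented here for illustration):

def narrow_explore_options(database, inspector, datasource, columns, metrics):
    """Hypothetical helper: keep only column/metric combinations the engine allows."""
    spec = database.db_engine_spec
    if not spec.supports_dynamic_columns:
        # Traditional databases: any combination is valid, so return everything.
        return {
            "columns": {c.column_name for c in datasource.columns},
            "metrics": {m.metric_name for m in datasource.metrics},
        }
    return spec.get_valid_columns(database, inspector, datasource, columns, metrics)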

@@ -0,0 +1,161 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
An interface to dbt's semantic layer.
"""
from __future__ import annotations
from typing import Any, TypedDict, TYPE_CHECKING
from shillelagh.backends.apsw.dialects.base import get_adapter_for_table_name
from superset.constants import TimeGrain
from superset.db_engine_specs.base import ValidColumnsType
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
from superset.extensions.dbt import TABLE_NAME
from superset.models.helpers import ExploreMixin
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
from superset.models.core import Database
from superset.sql_parse import Table
from sqlalchemy.engine.reflection import Inspector
SELECT_STAR_MESSAGE = (
'The dbt semantic layer does not support data preview, since the "metrics" table is '
"a virtual table that is not materialized. An administrator should configure the "
'database in Apache Superset so that the "Disable SQL Lab data preview queries" '
'option under "Advanced" -> "SQL Lab" is enabled.'
)
class MetricType(TypedDict, total=False):
"""
Type for metrics returned by `get_metrics`.
"""
metric_name: str
expression: str
verbose_name: str | None
metric_type: str | None
description: str | None
d3format: str | None
warning_text: str | None
extra: str | None
class DbtMetricFlowEngineSpec(ShillelaghEngineSpec):
"""
Engine for the dbt semantic layer.
"""
engine = "dbt"
engine_name = "dbt Semantic Layer"
sqlalchemy_uri_placeholder = "dbt:///<environment_id>?service_token=<service_token>"
supports_dynamic_columns = True
_time_grain_expressions = {
TimeGrain.DAY: "{col}__day",
TimeGrain.WEEK: "{col}__week",
TimeGrain.MONTH: "{col}__month",
TimeGrain.QUARTER: "{col}__quarter",
TimeGrain.YEAR: "{col}__year",
}
@classmethod
def select_star(cls, *args: Any, **kwargs: Any) -> str:
"""
Return a ``SELECT *`` query.
"""
message = SELECT_STAR_MESSAGE.replace("'", "''")
return f"SELECT '{message}' AS warning"
@classmethod
def get_columns(
cls,
inspector: Inspector,
table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
Get columns.
This method enriches the columns returned by the SQLAlchemy dialect with the
dimension descriptions.
"""
connection = inspector.engine.connect()
adapter = get_adapter_for_table_name(connection, table.table)
return [
{
"name": column["name"],
"column_name": column["name"],
"type": column["type"],
"nullable": column["nullable"],
"default": column["default"],
"comment": adapter.dimensions.get(column["name"], ""),
}
for column in inspector.get_columns(table.table, table.schema)
]
@classmethod
def get_metrics(
cls,
database: Database,
inspector: Inspector,
table: Table,
) -> list[MetricType]:
"""
Get all metrics.
"""
connection = inspector.engine.connect()
adapter = get_adapter_for_table_name(connection, table.table)
return [
{
"metric_name": metric,
"expression": metric,
"description": description,
}
for metric, description in adapter.metrics.items()
]
@classmethod
def get_valid_columns(
cls,
database: Database,
inspector: Inspector,
datasource: ExploreMixin,
columns: set[str],
metrics: set[str],
) -> ValidColumnsType:
"""
Get valid columns.
Given a datasource, and sets of selected metrics and dimensions, return the
sets of valid metrics and dimensions that can further be selected.
"""
connection = inspector.engine.connect()
adapter = get_adapter_for_table_name(connection, TABLE_NAME)
return {
"columns": adapter._get_dimensions_for_metrics(metrics),
"metrics": adapter._get_metrics_for_dimensions(columns),
}
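Two user-visible consequences of the engine spec above are worth spelling out: data preview does not run a real ``SELECT *`` but returns a single row carrying the warning text, and time grains are expressed by suffixing the dimension name rather than wrapping it in a SQL date function. A rough sketch of both, with ``metric_time`` being a hypothetical dimension name:

# SELECT_STAR_MESSAGE is the module-level constant defined above.
message = SELECT_STAR_MESSAGE.replace("'", "''")  # single quotes doubled for SQL escaping
preview_sql = f"SELECT '{message}' AS warning"

# A monthly grain on a hypothetical dimension "metric_time" becomes a suffixed
# column, not something like DATE_TRUNC('month', metric_time).
monthly = "{col}__month".format(col="metric_time")  # -> "metric_time__month"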

superset/extensions/dbt.py (new file, 231 lines)

@@ -0,0 +1,231 @@
from __future__ import annotations
from datetime import timedelta
from typing import Any
import sqlalchemy.types
from shillelagh.adapters.api.dbt_metricflow import DbtMetricFlowAPI
from shillelagh.backends.apsw.dialects.base import (
APSWDialect,
get_adapter_for_table_name,
)
from shillelagh.fields import Field
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.visitors import VisitableType
from superset.extensions import cache_manager
from superset.utils.cache import memoized_func
TABLE_NAME = "metrics"
def get_sqla_type(field: Field) -> VisitableType:
"""
Convert from Shillelagh to SQLAlchemy types.
"""
type_map = {
"BOOLEAN": sqlalchemy.types.BOOLEAN,
"INTEGER": sqlalchemy.types.INT,
"DECIMAL": sqlalchemy.types.DECIMAL,
"TIMESTAMP": sqlalchemy.types.TIMESTAMP,
"DATE": sqlalchemy.types.DATE,
"TIME": sqlalchemy.types.TIME,
"TEXT": sqlalchemy.types.TEXT,
}
return type_map.get(field.type, sqlalchemy.types.TEXT)
class PresetDbtMetricFlowAPI(DbtMetricFlowAPI):
"""
Custom API adapter for dbt Metric Flow API.
In the original adapter, the SQL queries a base dbt API URL, e.g.:
SELECT * FROM "https://semantic-layer.cloud.getdbt.com/";
SELECT * FROM "https://ab123.us1.dbt.com/"; -- custom user URL
For this adapter, we want a leaner URI, mimicking a table:
SELECT * FROM metrics;
In order to do this, we override the ``supports`` method to only accept the value
of ``TABLE_NAME`` ("metrics") instead of the URL, which is then passed to the
adapter when it is instantiated.
One problem with this change is that the adapter needs the base URL in order to
determine the GraphQL endpoint. To solve this we pass the original URL via a new
argument ``url``, and override the ``_get_endpoint`` method to use it instead of the
table name.
"""
@staticmethod
def supports(uri: str, fast: bool = True, **kwargs: Any) -> bool:
return uri == TABLE_NAME
def __init__(
self,
table: str,
service_token: str,
environment_id: int,
url: str,
) -> None:
self.url = url
super().__init__(table, service_token, environment_id)
def _get_endpoint(self, url: str) -> str:
"""
Compute the GraphQL endpoint.
Instead of using ``url`` (which points to ``TABLE_NAME`` in this adapter), we
should call the method using the actual dbt API base URL.
"""
return super()._get_endpoint(self.url)
def _build_column_from_dimension(self, name: str) -> Field:
"""
Build a Shillelagh column from a dbt dimension.
This method is terribly slow, since it needs to do a full data request for each
dimension in order to determine its type. To improve UX we cache the results
for one day.
"""
return self._cached_build_column_from_dimension(
name,
cache_timeout=int(timedelta(days=1).total_seconds()),
)
@memoized_func(key="dbt:dimension:{name}", cache=cache_manager.data_cache)
def _cached_build_column_from_dimension(
self,
name: str,
*args: Any,
**kwargs: Any,
) -> Field:
"""
Cached version of ``_build_column_from_dimension``.
"""
return super()._build_column_from_dimension(name)
class DbtMetricFlowDialect(APSWDialect):
"""
A dbt Metric Flow dialect.
URL should look like:
dbt:///<environment_id>?service_token=<service_token>
Or when using a custom URL:
dbt://ab123.us1.dbt.com/<environment_id>?service_token=<service_token>
"""
name = "dbt"
supports_statement_cache = True
def create_connect_args(self, url: URL) -> tuple[tuple[()], dict[str, Any]]:
baseurl = (
f"https://{url.host}/"
if url.host
else "https://semantic-layer.cloud.getdbt.com/"
)
return (
(),
{
"path": ":memory:",
"adapters": ["presetdbtmetricflowapi"],
"adapter_kwargs": {
"presetdbtmetricflowapi": {
"service_token": url.query["service_token"],
"environment_id": int(url.database),
"url": baseurl,
},
},
"safe": True,
"isolation_level": self.isolation_level,
},
)
def get_table_names(
self,
connection: Connection,
schema: str | None = None,
sqlite_include_internal: bool = False,
**kwargs: Any,
) -> list[str]:
return [TABLE_NAME]
def has_table(
self,
connection: Connection,
table_name: str,
schema: str | None = None,
**kwargs: Any,
) -> bool:
return table_name == TABLE_NAME
def get_columns(
self,
connection: Connection,
table_name: str,
schema: str | None = None,
**kwargs: Any,
) -> list[dict[str, Any]]:
adapter = get_adapter_for_table_name(connection, table_name)
columns = {
adapter.grains[dimension][0]
if dimension in adapter.grains
else dimension: adapter.columns[dimension]
for dimension in adapter.dimensions
}
return [
{
"name": name,
"type": get_sqla_type(field),
"nullable": True,
"default": None,
}
for name, field in columns.items()
]
def get_schema_names(
self,
connection: Connection,
**kwargs: Any,
) -> list[str]:
return ["main"]
def get_pk_constraint(
self,
connection: Connection,
table_name: str,
schema: str | None = None,
**kwargs: Any,
) -> dict[str, Any]:
return {"constrained_columns": [], "name": None}
def get_foreign_keys(
self,
connection: Connection,
table_name: str,
schema: str | None = None,
**kwargs: Any,
) -> list[dict[str, Any]]:
return []
get_check_constraints = get_foreign_keys
get_indexes = get_foreign_keys
get_unique_constraints = get_foreign_keys
def get_table_comment(self, connection, table_name, schema=None, **kwargs):
return {
"text": "A virtual table that gives access to all dbt metrics and dimensions."
}
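To make the moving parts concrete, a hedged end-to-end sketch: with the dialect and adapter registered, the semantic layer is queried as if it were a single table named ``metrics``. The credentials, the metric (``revenue``) and the grained dimension (``metric_time__month``) below are invented for illustration.

from sqlalchemy import create_engine, text

engine = create_engine("dbt:///12345?service_token=dbts_XXXX")  # placeholder credentials
with engine.connect() as conn:
    # The monthly grain uses the suffix convention from the engine spec above.
    for row in conn.execute(text("SELECT metric_time__month, revenue FROM metrics")):
        print(row)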


@@ -111,7 +111,7 @@ class SupersetAPSWDialect(APSWDialect):
"superset": {
"prefix": None,
"allowed_dbs": self.allowed_dbs,
}
},
},
"safe": True,
"isolation_level": self.isolation_level,