mirror of https://github.com/apache/superset.git
synced 2026-04-30 13:34:20 +00:00

Compare commits (5 commits): semantic-l… → snowflake-…

| SHA1 |
|---|
| 57ca29baf4 |
| 01c587361f |
| fc64ac918a |
| cd019bab3e |
| a330fe6f7e |
@@ -164,7 +164,9 @@ class DatasourceKind(StrEnum):
    PHYSICAL = "physical"


class BaseDatasource(AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
class BaseDatasource(
    AuditMixinNullable, ImportExportMixin
):  # pylint: disable=too-many-public-methods
    """A common interface to objects that are queryable
    (tables and datasources)"""
@@ -1778,7 +1780,9 @@ class SqlaTable(
    def default_query(qry: Query) -> Query:
        return qry.filter_by(is_sqllab_view=False)

    def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool:  # noqa: C901
    def has_extra_cache_key_calls(
        self, query_obj: QueryObjectDict
    ) -> bool:  # noqa: C901
        """
        Detects the presence of calls to `ExtraCache` methods in items in query_obj that
        can be templated. If any are present, the query must be evaluated to extract
@@ -75,12 +75,11 @@ class DatasetDAO(BaseDAO[SqlaTable]):
        database: Database,
        table: Table,
    ) -> bool:
        try:
            database.get_table(table)
            return True
        except SQLAlchemyError as ex:  # pragma: no cover
            logger.warning("Got an error %s validating table: %s", str(ex), table)
            return False
        with database.get_inspector(
            catalog=table.catalog,
            schema=table.schema,
        ) as inspector:
            return database.db_engine_spec.has_table(database, inspector, table)

    @staticmethod
    def validate_uniqueness(
@@ -30,6 +30,7 @@ from typing import (
    cast,
    ContextManager,
    NamedTuple,
    Type,
    TYPE_CHECKING,
    TypedDict,
    Union,
@@ -62,6 +63,10 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions.semantic_layer import (
    get_sqla_type_from_dimension_type,
    SemanticLayer,
)
from superset.sql.parse import (
    BaseSQLStatement,
    LimitMethod,
@@ -85,7 +90,7 @@ from superset.utils.network import is_hostname_valid, is_port_open
from superset.utils.oauth2 import encode_oauth2_state

if TYPE_CHECKING:
    from superset.connectors.sqla.models import TableColumn
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.databases.schemas import TableMetadataResponse
    from superset.models.core import Database
    from superset.models.sql_lab import Query
@@ -106,6 +111,15 @@ logger = logging.getLogger()
GenericDBException = Exception


class ValidColumnsType(TypedDict):
    """
    Type for valid columns returned by `get_valid_metrics_and_dimensions`.
    """

    dimensions: set[str]
    metrics: set[str]


def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
    result_set_columns: list[ResultSetColumnType] = []
    for col in cols:
@@ -143,7 +157,9 @@ builtin_time_grains: dict[str | None, str] = {
}


class TimestampExpression(ColumnClause):  # pylint: disable=abstract-method, too-many-ancestors
class TimestampExpression(
    ColumnClause
):  # pylint: disable=abstract-method, too-many-ancestors
    def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
        """Sqlalchemy class that can be used to render native column elements respecting
        engine-specific quoting rules as part of a string-based expression.
@@ -214,6 +230,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
        "engine+driver://user:password@host:port/dbname[?key=value&key=value...]"
    )

    # databases can optionally specify a semantic layer
    semantic_layer: Type[SemanticLayer] | None = None

    disable_ssh_tunneling = False

    _date_trunc_functions: dict[str, str] = {}
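The new `semantic_layer` class attribute is the opt-in point for engine specs. A minimal, hypothetical sketch of how another engine spec could adopt it; `ExampleSemanticLayer` and `ExampleEngineSpec` are invented names for illustration and are not part of this change:

```python
# Hypothetical sketch only: opting an engine spec into the semantic layer.
from superset.db_engine_specs.base import BaseEngineSpec
from superset.extensions.semantic_layer import SemanticView


class ExampleSemanticLayer:
    """A class expected to satisfy the SemanticLayer protocol."""

    def __init__(self, engine) -> None:
        self.engine = engine

    def get_semantic_views(self) -> set[SemanticView]:
        # a real implementation would query the backend's metadata here
        return {SemanticView("sales")}

    # ... plus get_metrics, get_dimensions, get_valid_metrics,
    # get_valid_dimensions, get_query, and get_query_from_standard_sql


class ExampleEngineSpec(BaseEngineSpec):
    engine = "example"
    engine_name = "Example"

    # once set, the base spec's table listing, has_table, get_columns and
    # get_metrics methods also consult the semantic layer for semantic views
    semantic_layer = ExampleSemanticLayer
```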
@@ -388,9 +407,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
    max_column_name_length: int | None = None
    try_remove_schema_from_table_name = True  # pylint: disable=invalid-name
    run_multiple_statements_as_one = False
    custom_errors: dict[
        Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
    ] = {}
    custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
        {}
    )

    # List of JSON path to fields in `encrypted_extra` that should be masked when the
    # database is edited. By default everything is masked.
@@ -1461,8 +1480,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
        if schema and cls.try_remove_schema_from_table_name:
            tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}

        # add semantic views as tables too
        if cls.semantic_layer:
            semantic_layer = cls.semantic_layer(inspector.engine)
            tables.update(
                semantic_view.name
                for semantic_view in semantic_layer.get_semantic_views()
            )

        return tables

    @classmethod
    def has_table(
        cls,
        database: Database,
        inspector: Inspector,
        table: Table,
    ) -> bool:
        if cls.semantic_layer:
            semantic_layer = cls.semantic_layer(inspector.engine)
            semantic_views = semantic_layer.get_semantic_views()
            if table.table in {semantic_view.name for semantic_view in semantic_views}:
                return True

        return inspector.has_table(table.table, table.schema)

    @classmethod
    def get_view_names(  # pylint: disable=unused-argument
        cls,
@@ -1536,6 +1579,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
    @classmethod
    def get_columns(  # pylint: disable=unused-argument
        cls,
        database: Database,
        inspector: Inspector,
        table: Table,
        options: dict[str, Any] | None = None,
@@ -1543,7 +1587,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
        """
        Get all columns from a given schema and table.

        The inspector will be bound to a catalog, if one was specified.
        The inspector will be bound to a catalog, if one was specified. If the database
        supports semantic layers the method will check if the table is a semantic view,
        and return columns (metrics and dimensions) from it instead.

        :param inspector: SqlAlchemy Inspector instance
        :param table: Table instance
@@ -1551,6 +1597,26 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
            some databases
        :return: All columns in table
        """
        if cls.semantic_layer:
            semantic_layer = cls.semantic_layer(inspector.engine)
            semantic_views = {
                semantic_view.name: semantic_view
                for semantic_view in semantic_layer.get_semantic_views()
            }
            if semantic_view := semantic_views.get(table.table):
                dialect = database.get_dialect()
                return [
                    {
                        "name": dimension.name,
                        "column_name": dimension.name,
                        "type": cls.column_datatype_to_string(
                            get_sqla_type_from_dimension_type(dimension.type),
                            dialect,
                        ),
                    }
                    for dimension in semantic_layer.get_dimensions(semantic_view)
                ]

        return convert_inspector_columns(
            cast(
                list[SQLAColumnType],
@@ -1568,6 +1634,22 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
        """
        Get all metrics from a given schema and table.
        """
        if cls.semantic_layer:
            semantic_layer = cls.semantic_layer(inspector.engine)
            semantic_views = {
                semantic_view.name: semantic_view
                for semantic_view in semantic_layer.get_semantic_views()
            }
            if semantic_view := semantic_views.get(table.table):
                return [
                    {
                        "metric_name": metric.name,
                        "verbose_name": metric.name,
                        "expression": metric.sql,
                    }
                    for metric in semantic_layer.get_metrics(semantic_view)
                ]

        return [
            {
                "metric_name": "count",
@@ -1577,6 +1659,62 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
            }
        ]

    @classmethod
    def get_valid_metrics_and_dimensions(
        cls,
        database: Database,
        table: SqlaTable,
        dimensions: set[str],
        metrics: set[str],
    ) -> ValidColumnsType:
        """
        Get valid metrics and dimensions.

        Given a datasource, and sets of selected metrics and dimensions, return the
        sets of valid metrics and dimensions that can further be selected.
        """
        if cls.semantic_layer:
            with database.get_sqla_engine() as engine:
                semantic_layer = cls.semantic_layer(engine)
                semantic_views = {
                    semantic_view.name: semantic_view
                    for semantic_view in semantic_layer.get_semantic_views()
                }
                if semantic_view := semantic_views.get(table.table):
                    selected_metrics = {
                        metric
                        for metric in semantic_layer.get_metrics(semantic_view)
                        if metric.name in metrics
                    }
                    selected_dimensions = {
                        dimension
                        for dimension in semantic_layer.get_dimensions(semantic_view)
                        if dimension.name in dimensions
                    }
                    return {
                        "metrics": {
                            metric.name
                            for metric in semantic_layer.get_valid_metrics(
                                semantic_view,
                                selected_metrics,
                                selected_dimensions,
                            )
                        },
                        "dimensions": {
                            dimension.name
                            for dimension in semantic_layer.get_valid_dimensions(
                                semantic_view,
                                selected_metrics,
                                selected_dimensions,
                            )
                        },
                    }

        return {
            "dimensions": {column.column_name for column in table.columns},
            "metrics": {metric.metric_name for metric in table.metrics},
        }

    @classmethod
    def where_latest_partition(  # pylint: disable=unused-argument
        cls,
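For context, a hedged sketch of how a caller might use the new `get_valid_metrics_and_dimensions` to narrow what can still be picked once a user has selected some columns. The `database` and `dataset` objects are assumed to already exist, and the metric and dimension names are made up:

```python
# Illustrative usage only.
selected_metrics = {"orders.order_count"}
selected_dimensions = {"customers.region"}

valid = database.db_engine_spec.get_valid_metrics_and_dimensions(
    database,
    dataset,  # a SqlaTable; for semantic-layer databases it points at a semantic view
    selected_dimensions,
    selected_metrics,
)

# Anything not returned is incompatible with the current selection. For engine
# specs without a semantic layer this simply returns every column and metric
# defined on the dataset itself.
still_selectable_metrics = valid["metrics"] - selected_metrics
still_selectable_dimensions = valid["dimensions"] - selected_dimensions
```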
@@ -1843,6 +1981,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
        :param kwargs: kwargs to be passed to cursor.execute()
        :return:
        """
        if cls.semantic_layer:
            with cls.get_engine(database, schema="tpcds_sf10tcl") as engine:
                semantic_layer = cls.semantic_layer(engine)
                query = semantic_layer.get_query_from_standard_sql(query).sql

        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        try:
@@ -16,11 +16,13 @@
# under the License.
from __future__ import annotations

import itertools
import logging
import re
from collections import defaultdict
from datetime import datetime
from re import Pattern
from typing import Any, Optional, TYPE_CHECKING, TypedDict
from typing import Any, Iterator, Optional, TYPE_CHECKING, TypedDict
from urllib import parse

from apispec import APISpec
@@ -30,20 +32,48 @@ from cryptography.hazmat.primitives import serialization
from flask import current_app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from sqlalchemy import types
from sqlalchemy import text, types
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlglot import exp, parse_one

from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions.semantic_layer import (
    BINARY,
    BOOLEAN,
    Column as SemanticColumn,
    DATE,
    DATETIME,
    DECIMAL,
    Dimension as SemanticDimension,
    Filter as SemanticFilter,
    INTEGER,
    Metric as SemanticMetric,
    NoSort,
    NUMBER,
    OBJECT,
    Query as SemanticQuery,
    SemanticView,
    Sort as SemanticSort,
    SortDirectionEnum,
    STRING,
    Table as SemanticTable,
    TIME,
    Type as SemanticType,
)
from superset.models.sql_lab import Query
from superset.sql.parse import Table
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource

if TYPE_CHECKING:
    from sqlalchemy.engine.base import Engine

    from superset.models.core import Database

# Regular expressions to catch custom errors
@@ -77,6 +107,303 @@ class SnowflakeParametersType(TypedDict):
    warehouse: str


class SnowflakeSemanticLayer:
    def __init__(self, engine: Engine) -> None:
        self.engine = engine

    def execute(
        self,
        sql: str,
        **kwargs: Any,
    ) -> Iterator[dict[str, Any]]:
        with self.engine.connect() as connection:
            for row in connection.execute(text(sql), kwargs).mappings():
                yield dict(row)

    def get_semantic_views(self) -> set[SemanticView]:
        sql = """
            SHOW SEMANTIC VIEWS
            ->> SELECT "name" FROM $1;
        """
        return {SemanticView(row["name"]) for row in self.execute(sql)}

    def get_type(self, snowflake_type: str | None) -> type[SemanticType]:
        if snowflake_type is None:
            return STRING

        type_map = {
            STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
            INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
            DECIMAL: {r"NUMBER\(10,\s?2\)$"},
            NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
            BOOLEAN: {"BOOLEAN$"},
            DATE: {"DATE$"},
            DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP_NTZ$"},
            TIME: {"TIME$"},
            OBJECT: {"OBJECT$"},
            BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
        }
        for semantic_type, patterns in type_map.items():
            if any(
                re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns
            ):
                return semantic_type

        return STRING

    @classmethod
    def quote_table(cls, table: Table, dialect: Dialect) -> str:
        """
        Fully quote a table name, including the schema and catalog.
        """
        quoters = {
            "catalog": dialect.identifier_preparer.quote_schema,
            "schema": dialect.identifier_preparer.quote_schema,
            "table": dialect.identifier_preparer.quote,
        }

        return ".".join(
            function(getattr(table, key))
            for key, function in quoters.items()
            if getattr(table, key)
        )

    def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]:
        quoted_semantic_view_name = self.quote_table(
            Table(semantic_view.name),
            self.engine.dialect,
        )
        sql = f"""
            DESC SEMANTIC VIEW {quoted_semantic_view_name}
            ->> SELECT "object_name", "property", "property_value"
                FROM $1
                WHERE
                    "object_kind" = 'METRIC' AND
                    "property" IN ('DATA_TYPE', 'TABLE');
        """  # noqa: S608 (semantic_view.name is quoted)
        rows = self.execute(sql)

        metrics: set[SemanticMetric] = set()
        for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
            attributes = defaultdict(set)
            for row in group:
                attributes[row["property"]].add(row["property_value"])

            table = next(iter(attributes["TABLE"]))
            metric_name = table + "." + name
            type_ = self.get_type(next(iter(attributes["DATA_TYPE"])))
            sql = self.engine.dialect.identifier_preparer.quote(metric_name)
            tables = frozenset(attributes["TABLE"])
            join_columns = frozenset()

            metrics.add(SemanticMetric(metric_name, type_, sql, tables, join_columns))

        return metrics

    def get_dimensions(self, semantic_view: SemanticView) -> set[SemanticDimension]:
        quoted_semantic_view_name = self.quote_table(
            Table(semantic_view.name),
            self.engine.dialect,
        )
        sql = f"""
            DESC SEMANTIC VIEW {quoted_semantic_view_name}
            ->> SELECT "object_name", "property", "property_value"
                FROM $1
                WHERE
                    "object_kind" = 'DIMENSION' AND
                    "property" IN ('DATA_TYPE', 'TABLE');
        """  # noqa: S608 (semantic_view.name is quoted)
        rows = self.execute(sql)

        dimensions: set[SemanticDimension] = set()
        for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
            attributes = defaultdict(set)
            for row in group:
                attributes[row["property"]].add(row["property_value"])

            table = next(iter(attributes["TABLE"]))
            dimension_name = table + "." + name
            column = SemanticColumn(SemanticTable(table), name)
            type_ = self.get_type(next(iter(attributes["DATA_TYPE"])))

            dimensions.add(SemanticDimension(column, dimension_name, type_))

        return dimensions

    def get_valid_metrics(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
    ) -> set[SemanticMetric]:
        # all metrics and dimensions are valid inside a given semantic view
        return self.get_metrics(semantic_view)

    def get_valid_dimensions(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
    ) -> set[SemanticDimension]:
        # all metrics and dimensions are valid inside a given semantic view
        return self.get_dimensions(semantic_view)

    def get_query(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
        filters: set[SemanticFilter],
        sort: SemanticSort = NoSort,
        limit: int | None = None,
        offset: int | None = None,
    ) -> SemanticQuery:
        ast = self.build_query(
            semantic_view,
            metrics,
            dimensions,
            filters,
            sort,
            limit,
            offset,
        )
        return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))

    def build_query(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
        filters: set[SemanticFilter],
        sort: SemanticSort = NoSort,
        limit: int | None = None,
        offset: int | None = None,
    ) -> exp.Select:
        semantic_view = exp.SemanticView(
            this=exp.Table(this=exp.Identifier(this=semantic_view.name, quoted=True)),
            dimensions=[
                exp.Column(
                    this=exp.Identifier(this=dimension.column.name, quoted=True),
                    table=exp.Identifier(
                        this=dimension.column.relation.name,
                        quoted=True,
                    ),
                )
                for dimension in dimensions
            ],
            metrics=[
                exp.Column(
                    this=exp.Identifier(this=column, quoted=True),
                    table=exp.Identifier(this=table, quoted=True),
                )
                for table, column in (
                    metric.name.split(".", 1)
                    for metric in metrics
                    if "." in metric.name
                )
            ],
        )
        query = exp.Select(
            expressions=[exp.Star()],
            **{"from": exp.From(this=exp.Table(this=semantic_view))},
        )

        if sort.items:
            order = [
                exp.Ordered(
                    this=exp.Column(this=exp.Identifier(this=item.field.name)),
                    desc=item.direction == SortDirectionEnum.DESC,
                    nulls_first=item.nulls_first,
                )
                for item in sort.items
            ]
            query.args["order"] = exp.Order(expressions=order)

        if offset:
            query = query.offset(offset)

        if limit:
            query = query.limit(limit)

        return query

    def get_query_from_standard_sql(self, sql: str) -> SemanticQuery:
        """
        Convert the Explore query into a proper query.

        Explore will produce a pseudo-SQL query that references metrics and dimensions
        as if they were columns in a table. This method replaces the table name with a
        call to `SEMANTIC_VIEW`, and removes the `GROUP BY` clause, since all the
        aggregations happen inside the `SEMANTIC_VIEW` call.
        """
        ast = parse_one(sql, "snowflake")
        table = ast.find(exp.Table)
        if not table:
            return SemanticQuery(sql=sql)

        semantic_views = self.get_semantic_views()
        if table.name not in {semantic_view.name for semantic_view in semantic_views}:
            return SemanticQuery(sql=sql)

        # collect all metrics and dimensions
        semantic_view = SemanticView(table.name)
        all_metrics = self.get_metrics(semantic_view)
        all_dimensions = self.get_dimensions(semantic_view)

        # collect metrics and dimensions used in the query
        columns = {column.name for column in ast.find_all(exp.Column)}
        metrics = [metric for metric in all_metrics if metric.name in columns]
        dimensions = [
            dimension for dimension in all_dimensions if dimension.name in columns
        ]

        # now replace table with a call to `SEMANTIC_VIEW`
        udtf = exp.Table(
            this=exp.SemanticView(
                this=exp.Table(
                    this=exp.Identifier(this=semantic_view.name, quoted=True)
                ),
                metrics=[
                    exp.Column(
                        this=exp.Identifier(this=column, quoted=True),
                        table=exp.Identifier(this=table, quoted=True),
                    )
                    for table, column in (
                        metric.name.split(".", 1)
                        for metric in metrics
                        if "." in metric.name
                    )
                ],
                dimensions=[
                    exp.Column(
                        this=exp.Identifier(this=column, quoted=True),
                        table=exp.Identifier(this=table, quoted=True),
                    )
                    for table, column in (
                        dimension.name.split(".", 1)
                        for dimension in dimensions
                        if "." in dimension.name
                    )
                ],
            ),
            alias=exp.TableAlias(
                this=exp.Identifier(this="table_alias", quoted=False),
                columns=[
                    exp.Identifier(this=column.name, quoted=True)
                    for column in metrics + dimensions
                ],
            ),
        )
        table.replace(udtf)

        # remove group by, since aggregations are done inside the `SEMANTIC_VIEW` call
        del ast.args["group"]

        return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))


class SnowflakeEngineSpec(PostgresBaseEngineSpec):
    engine = "snowflake"
    engine_name = "Snowflake"
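To make the rewrite in `get_query_from_standard_sql` concrete, here is a hedged illustration of the transformation. The view and column names are invented, and the exact output depends on sqlglot's Snowflake renderer and on the metrics and dimensions defined in the view; this is a sketch of the shape, not verbatim output:

```python
# Pseudo-SQL produced by Explore, which treats metrics and dimensions as if
# they were plain columns of a table named after the semantic view:
pseudo_sql = """
SELECT "orders.order_count", "customers.region"
FROM tpch_sv
GROUP BY "customers.region"
"""

# After SnowflakeSemanticLayer.get_query_from_standard_sql(pseudo_sql), the
# table reference becomes a SEMANTIC_VIEW table function and the GROUP BY is
# dropped, since the aggregation happens inside the call, roughly:
rewritten_sql = """
SELECT "orders.order_count", "customers.region"
FROM SEMANTIC_VIEW(
  "tpch_sv"
  METRICS "orders"."order_count"
  DIMENSIONS "customers"."region"
) AS table_alias ("orders.order_count", "customers.region")
"""
```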
@@ -90,6 +417,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
    default_driver = "snowflake"
    sqlalchemy_uri_placeholder = "snowflake://"

    semantic_layer = SnowflakeSemanticLayer

    supports_dynamic_schema = True
    supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True
superset/extensions/semantic_layer.py (new file, 340 lines)
@@ -0,0 +1,340 @@
import enum
from dataclasses import dataclass
from datetime import timedelta
from functools import total_ordering
from typing import Protocol, runtime_checkable

from sqlalchemy import types as sqltypes
from sqlalchemy.engine.base import Engine


class Type:
    """
    Base class for types.
    """


class INTEGER(Type):
    """
    Represents an integer type.
    """


class NUMBER(Type):
    """
    Represents a number type.
    """


class DECIMAL(Type):
    """
    Represents a decimal type.
    """


class STRING(Type):
    """
    Represents a string type.
    """


class BOOLEAN(Type):
    """
    Represents a boolean type.
    """


class DATE(Type):
    """
    Represents a date type.
    """


class TIME(Type):
    """
    Represents a time type.
    """


class DATETIME(DATE, TIME):
    """
    Represents a datetime type.
    """


class INTERVAL(Type):
    """
    Represents an interval type.
    """


class OBJECT(Type):
    """
    Represents an object type.
    """


class BINARY(Type):
    """
    Represents a binary type.
    """


@dataclass(frozen=True)
class SemanticView:
    name: str
    description: str | None = None


@dataclass(frozen=True)
class Relation:
    name: str
    schema: str | None = None
    catalog: str | None = None


@dataclass(frozen=True)
class Table:
    name: str
    schema: str | None = None
    catalog: str | None = None


@dataclass(frozen=True)
class View:
    name: str
    sql: str
    schema: str | None = None
    catalog: str | None = None


@dataclass(frozen=True)
class Virtual:
    name: str


@dataclass(frozen=True)
class Metric:
    name: str
    type: type[Type]
    sql: str
    tables: frozenset[Table]
    join_columns: frozenset[str]


@total_ordering
class ComparableEnum(enum.Enum):
    def __eq__(self, other: object) -> bool:
        if isinstance(other, enum.Enum):
            return self.value == other.value
        return NotImplemented

    def __lt__(self, other: object) -> bool:
        if isinstance(other, enum.Enum):
            return self.value < other.value
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self.__class__, self.name))


class TimeGrain(ComparableEnum):
    second = timedelta(seconds=1)
    minute = timedelta(minutes=1)
    hour = timedelta(hours=1)


class DateGrain(ComparableEnum):
    day = timedelta(days=1)
    week = timedelta(weeks=1)
    month = timedelta(days=30)
    quarter = timedelta(days=90)
    year = timedelta(days=365)


@dataclass(frozen=True)
class Column:
    relation: Table | View | Virtual
    name: str


@dataclass(frozen=True)
class Dimension:
    column: Column
    name: str
    type: type[Type]
    grain: TimeGrain | DateGrain | None = None

    def __repr__(self) -> str:
        metadata = f"[{self.grain.name}]" if self.grain else ""
        return f"{self.type.__name__} {self.name} {metadata}".strip()


class FilterTypeEnum(enum.Enum):
    WHERE = enum.auto()
    HAVING = enum.auto()


@dataclass(frozen=True)
class Filter:
    type: FilterTypeEnum
    expression: str


class SortDirectionEnum(enum.Enum):
    ASC = enum.auto()
    DESC = enum.auto()


@dataclass(frozen=True)
class SortField:
    field: Metric | Dimension
    direction: SortDirectionEnum
    nulls_first: bool = True


@dataclass(frozen=True)
class Sort:
    items: list[SortField]


@dataclass(frozen=True)
class Query:
    sql: str


NoSort = Sort(items=[])


@runtime_checkable
class SemanticLayer(Protocol):
    """
    A generic protocol for semantic layers.
    """

    def __init__(self, engine: Engine) -> None: ...

    def get_semantic_views(self) -> set[SemanticView]:
        """
        Return a set of the semantic views.

        A semantic view is an organizational group of metrics and dimensions. It's not a
        logical grouping, since metrics and dimensions from a given semantic view might
        not be compatible. An implementation might expose a single semantic view for
        exploration of available metrics and dimensions, and smaller curated semantic
        views that are domain specific.
        """
        ...

    def get_metrics(self, semantic_view: SemanticView) -> set[Metric]:
        """
        Return a set of metrics from a given semantic view.
        """
        ...

    def get_dimensions(self, semantic_view: SemanticView) -> set[Dimension]:
        """
        Return a set of dimensions from a given semantic view.
        """
        ...

    def get_valid_metrics(
        self,
        semantic_view: SemanticView,
        metrics: set[Metric],
        dimensions: set[Dimension],
    ) -> set[Metric]:
        """
        Return compatible metrics for the given metrics and dimensions.

        For metrics to be valid they must be compatible with all the provided
        dimensions.
        """
        ...

    def get_valid_dimensions(
        self,
        semantic_view: SemanticView,
        metrics: set[Metric],
        dimensions: set[Dimension],
    ) -> set[Dimension]:
        """
        Return compatible dimensions for the given metrics.

        For dimensions to be valid they must be compatible with all the provided
        metrics.
        """
        ...

    def get_query(
        self,
        semantic_view: SemanticView,
        metrics: set[Metric],
        dimensions: set[Dimension],
        # populations: set[Population],
        filters: set[Filter],
        sort: Sort = NoSort,
        limit: int | None = None,
        offset: int | None = None,
    ) -> Query:
        """
        Build a SQL query from the given metrics, dimensions, filters, and sort order.
        """
        ...

    def get_query_from_standard_sql(
        self,
        semantic_view: SemanticView,
        sql: str,
    ) -> Query:
        """
        Build a SQL query from a pseudo-query referencing metrics and dimensions.

        For example, given `metric1` having the expression `COUNT(*)`, this query:

            SELECT metric1, dim1
            FROM semantic_layer
            GROUP BY dim1

        Becomes:

            SELECT metric1, dim1
            FROM (
                SELECT COUNT(*) AS metric1, dim1
                FROM fact_table
                JOIN dim_table
                ON fact_table.dim_id = dim_table.id
                GROUP BY dim1
            ) AS semantic_view

        """
        ...


TYPE_MAPPING: dict[type[Type], type[sqltypes.TypeEngine]] = {
    # Numeric types
    INTEGER: sqltypes.Integer,
    NUMBER: sqltypes.Numeric,
    DECIMAL: sqltypes.DECIMAL,
    # String types
    STRING: sqltypes.String,
    # Boolean type
    BOOLEAN: sqltypes.Boolean,
    # Date/time types
    DATE: sqltypes.Date,
    TIME: sqltypes.Time,
    DATETIME: sqltypes.DateTime,
    INTERVAL: sqltypes.Interval,
    # Complex types
    OBJECT: sqltypes.JSON,
    BINARY: sqltypes.LargeBinary,
}


def get_sqla_type_from_dimension_type(
    dimension_type: type[Type],
) -> sqltypes.TypeEngine:
    """
    Get the SQLAlchemy type corresponding to the given dimension type.
    """
    return TYPE_MAPPING.get(dimension_type, sqltypes.String)()
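Since `SemanticLayer` is a `runtime_checkable` `Protocol`, an engine spec's semantic layer only has to match these method signatures structurally; no inheritance or registration is required. A minimal, hypothetical in-memory implementation to illustrate the contract (not part of this change; all names and SQL are made up):

```python
from sqlalchemy.engine.base import Engine

from superset.extensions.semantic_layer import (
    Column,
    Dimension,
    INTEGER,
    Metric,
    NoSort,
    Query,
    SemanticLayer,
    SemanticView,
    STRING,
    Table,
)


class StaticSemanticLayer:
    """Toy implementation backed by hard-coded metadata."""

    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        self.view = SemanticView("demo")
        self.metric = Metric(
            "orders.order_count",
            INTEGER,
            "COUNT(*)",
            frozenset({Table("orders")}),
            frozenset(),
        )
        self.dimension = Dimension(
            Column(Table("orders"), "region"), "orders.region", STRING
        )

    def get_semantic_views(self) -> set[SemanticView]:
        return {self.view}

    def get_metrics(self, semantic_view: SemanticView) -> set[Metric]:
        return {self.metric}

    def get_dimensions(self, semantic_view: SemanticView) -> set[Dimension]:
        return {self.dimension}

    def get_valid_metrics(self, semantic_view, metrics, dimensions) -> set[Metric]:
        # everything is compatible in this toy example
        return {self.metric}

    def get_valid_dimensions(self, semantic_view, metrics, dimensions) -> set[Dimension]:
        return {self.dimension}

    def get_query(
        self, semantic_view, metrics, dimensions, filters, sort=NoSort, limit=None, offset=None
    ) -> Query:
        return Query(sql="SELECT COUNT(*) AS order_count, region FROM orders GROUP BY region")

    def get_query_from_standard_sql(self, semantic_view: SemanticView, sql: str) -> Query:
        return Query(sql=sql)


# The structural check enabled by @runtime_checkable only verifies that the
# methods exist: isinstance(StaticSemanticLayer(engine), SemanticLayer) is True.
```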
@@ -126,7 +126,9 @@ class ConfigurationMethod(StrEnum):
    DYNAMIC_FORM = "dynamic_form"


class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
class Database(
    Model, AuditMixinNullable, ImportExportMixin
):  # pylint: disable=too-many-public-methods
    """An ORM object that stores Database related information"""

    __tablename__ = "dbs"
@@ -400,9 +402,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
        return (
            username
            if (username := get_username())
            else object_url.username
            if self.impersonate_user
            else None
            else object_url.username if self.impersonate_user else None
        )

    @contextmanager
@@ -987,7 +987,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
            schema=table.schema,
        ) as inspector:
            return self.db_engine_spec.get_columns(
                inspector, table, self.schema_options
                self,
                inspector,
                table,
                self.schema_options,
            )

    def get_metrics(
@@ -1076,9 +1079,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
        return self.perm

    def has_table(self, table: Table) -> bool:
        with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
            # do not pass "" as an empty schema; force null
            return engine.has_table(table.table, table.schema or None)
        with self.get_inspector(
            catalog=table.catalog,
            schema=table.schema,
        ) as inspector:
            return self.db_engine_spec.has_table(self, inspector, table)

    def has_view(self, table: Table) -> bool:
        with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: