Compare commits

...

5 Commits

Author SHA1 Message Date
Beto Dealmeida
57ca29baf4 Working 2025-07-30 16:08:32 -04:00
Beto Dealmeida
01c587361f WIP 2025-07-28 12:28:13 -04:00
Beto Dealmeida
fc64ac918a WIP 2025-07-28 10:56:22 -04:00
Beto Dealmeida
cd019bab3e WIP 2025-07-23 15:51:07 -04:00
Beto Dealmeida
a330fe6f7e WIP 2025-07-22 16:18:50 -04:00
6 changed files with 844 additions and 24 deletions

View File

@@ -164,7 +164,9 @@ class DatasourceKind(StrEnum):
PHYSICAL = "physical"
class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
class BaseDatasource(
AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods
"""A common interface to objects that are queryable
(tables and datasources)"""
@@ -1778,7 +1780,9 @@ class SqlaTable(
def default_query(qry: Query) -> Query:
return qry.filter_by(is_sqllab_view=False)
def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool: # noqa: C901
def has_extra_cache_key_calls(
self, query_obj: QueryObjectDict
) -> bool: # noqa: C901
"""
Detects the presence of calls to `ExtraCache` methods in items in query_obj that
can be templated. If any are present, the query must be evaluated to extract

View File

@@ -75,12 +75,11 @@ class DatasetDAO(BaseDAO[SqlaTable]):
database: Database,
table: Table,
) -> bool:
try:
database.get_table(table)
return True
except SQLAlchemyError as ex: # pragma: no cover
logger.warning("Got an error %s validating table: %s", str(ex), table)
return False
with database.get_inspector(
catalog=table.catalog,
schema=table.schema,
) as inspector:
return database.db_engine_spec.has_table(database, inspector, table)
@staticmethod
def validate_uniqueness(

View File

@@ -30,6 +30,7 @@ from typing import (
cast,
ContextManager,
NamedTuple,
Type,
TYPE_CHECKING,
TypedDict,
Union,
@@ -62,6 +63,10 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions.semantic_layer import (
get_sqla_type_from_dimension_type,
SemanticLayer,
)
from superset.sql.parse import (
BaseSQLStatement,
LimitMethod,
@@ -85,7 +90,7 @@ from superset.utils.network import is_hostname_valid, is_port_open
from superset.utils.oauth2 import encode_oauth2_state
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.databases.schemas import TableMetadataResponse
from superset.models.core import Database
from superset.models.sql_lab import Query
@@ -106,6 +111,15 @@ logger = logging.getLogger()
GenericDBException = Exception
class ValidColumnsType(TypedDict):
    """
    Type for valid columns returned by `get_valid_metrics_and_dimensions`.

    Both sets contain names (not objects): the metrics and dimensions that can
    still be selected given the current selection.
    """

    # Names of dimensions that can further be selected.
    dimensions: set[str]
    # Names of metrics that can further be selected.
    metrics: set[str]
def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
result_set_columns: list[ResultSetColumnType] = []
for col in cols:
@@ -143,7 +157,9 @@ builtin_time_grains: dict[str | None, str] = {
}
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
class TimestampExpression(
ColumnClause
): # pylint: disable=abstract-method, too-many-ancestors
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be used to render native column elements respecting
engine-specific quoting rules as part of a string-based expression.
@@ -214,6 +230,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"engine+driver://user:password@host:port/dbname[?key=value&key=value...]"
)
# databases can optionally specify a semantic layer
semantic_layer: Type[SemanticLayer] | None = None
disable_ssh_tunneling = False
_date_trunc_functions: dict[str, str] = {}
@@ -388,9 +407,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
max_column_name_length: int | None = None
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
custom_errors: dict[
Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
] = {}
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
{}
)
# List of JSON path to fields in `encrypted_extra` that should be masked when the
# database is edited. By default everything is masked.
@@ -1461,8 +1480,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if schema and cls.try_remove_schema_from_table_name:
tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
# add semantic views as tables too
if cls.semantic_layer:
semantic_layer = cls.semantic_layer(inspector.engine)
tables.update(
semantic_view.name
for semantic_view in semantic_layer.get_semantic_views()
)
return tables
@classmethod
def has_table(
    cls,
    database: Database,
    inspector: Inspector,
    table: Table,
) -> bool:
    """
    Return True if `table` exists.

    When the engine spec declares a semantic layer, a table name that matches
    one of its semantic views also counts as existing; otherwise fall back to
    the SQLAlchemy inspector.
    """
    if cls.semantic_layer:
        layer = cls.semantic_layer(inspector.engine)
        if any(view.name == table.table for view in layer.get_semantic_views()):
            return True

    return inspector.has_table(table.table, table.schema)
@classmethod
def get_view_names( # pylint: disable=unused-argument
cls,
@@ -1536,6 +1579,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_columns( # pylint: disable=unused-argument
cls,
database: Database,
inspector: Inspector,
table: Table,
options: dict[str, Any] | None = None,
@@ -1543,7 +1587,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Get all columns from a given schema and table.
The inspector will be bound to a catalog, if one was specified.
The inspector will be bound to a catalog, if one was specified. If the database
supports semantic layers the method will check if the table is a semantic view,
and return columns (metrics and dimensions) from it instead.
:param inspector: SqlAlchemy Inspector instance
:param table: Table instance
@@ -1551,6 +1597,26 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
some databases
:return: All columns in table
"""
if cls.semantic_layer:
semantic_layer = cls.semantic_layer(inspector.engine)
semantic_views = {
semantic_view.name: semantic_view
for semantic_view in semantic_layer.get_semantic_views()
}
if semantic_view := semantic_views.get(table.table):
dialect = database.get_dialect()
return [
{
"name": dimension.name,
"column_name": dimension.name,
"type": cls.column_datatype_to_string(
get_sqla_type_from_dimension_type(dimension.type),
dialect,
),
}
for dimension in semantic_layer.get_dimensions(semantic_view)
]
return convert_inspector_columns(
cast(
list[SQLAColumnType],
@@ -1568,6 +1634,22 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
Get all metrics from a given schema and table.
"""
if cls.semantic_layer:
semantic_layer = cls.semantic_layer(inspector.engine)
semantic_views = {
semantic_view.name: semantic_view
for semantic_view in semantic_layer.get_semantic_views()
}
if semantic_view := semantic_views.get(table.table):
return [
{
"metric_name": metric.name,
"verbose_name": metric.name,
"expression": metric.sql,
}
for metric in semantic_layer.get_metrics(semantic_view)
]
return [
{
"metric_name": "count",
@@ -1577,6 +1659,62 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
}
]
@classmethod
def get_valid_metrics_and_dimensions(
    cls,
    database: Database,
    table: SqlaTable,
    dimensions: set[str],
    metrics: set[str],
) -> ValidColumnsType:
    """
    Get valid metrics and dimensions.

    Given a datasource, and sets of selected metrics and dimensions, return the
    sets of valid metrics and dimensions that can further be selected.

    If the engine spec has no semantic layer, or the table is not one of the
    semantic views, every column/metric of the datasource is considered valid.
    """
    if cls.semantic_layer:
        # NOTE(review): engine is created without catalog/schema — confirm
        # semantic views resolve correctly from the default connection context.
        with database.get_sqla_engine() as engine:
            semantic_layer = cls.semantic_layer(engine)
            # map view name -> view, so the datasource's table name can be
            # matched against the available semantic views
            semantic_views = {
                semantic_view.name: semantic_view
                for semantic_view in semantic_layer.get_semantic_views()
            }
            if semantic_view := semantic_views.get(table.table):
                # narrow the name sets down to the actual metric/dimension
                # objects the semantic layer knows about
                selected_metrics = {
                    metric
                    for metric in semantic_layer.get_metrics(semantic_view)
                    if metric.name in metrics
                }
                selected_dimensions = {
                    dimension
                    for dimension in semantic_layer.get_dimensions(semantic_view)
                    if dimension.name in dimensions
                }
                # ask the semantic layer which metrics/dimensions remain
                # compatible with the current selection, returning names only
                return {
                    "metrics": {
                        metric.name
                        for metric in semantic_layer.get_valid_metrics(
                            semantic_view,
                            selected_metrics,
                            selected_dimensions,
                        )
                    },
                    "dimensions": {
                        dimension.name
                        for dimension in semantic_layer.get_valid_dimensions(
                            semantic_view,
                            selected_metrics,
                            selected_dimensions,
                        )
                    },
                }

    # fallback: no semantic layer (or not a semantic view), so everything in
    # the datasource is valid
    return {
        "dimensions": {column.column_name for column in table.columns},
        "metrics": {metric.metric_name for metric in table.metrics},
    }
@classmethod
def where_latest_partition( # pylint: disable=unused-argument
cls,
@@ -1843,6 +1981,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param kwargs: kwargs to be passed to cursor.execute()
:return:
"""
if cls.semantic_layer:
with cls.get_engine(database, schema="tpcds_sf10tcl") as engine:
semantic_layer = cls.semantic_layer(engine)
query = semantic_layer.get_query_from_standard_sql(query).sql
if cls.arraysize:
cursor.arraysize = cls.arraysize
try:

View File

@@ -16,11 +16,13 @@
# under the License.
from __future__ import annotations
import itertools
import logging
import re
from collections import defaultdict
from datetime import datetime
from re import Pattern
from typing import Any, Optional, TYPE_CHECKING, TypedDict
from typing import Any, Iterator, Optional, TYPE_CHECKING, TypedDict
from urllib import parse
from apispec import APISpec
@@ -30,20 +32,48 @@ from cryptography.hazmat.primitives import serialization
from flask import current_app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from sqlalchemy import types
from sqlalchemy import text, types
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlglot import exp, parse_one
from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions.semantic_layer import (
BINARY,
BOOLEAN,
Column as SemanticColumn,
DATE,
DATETIME,
DECIMAL,
Dimension as SemanticDimension,
Filter as SemanticFilter,
INTEGER,
Metric as SemanticMetric,
NoSort,
NUMBER,
OBJECT,
Query as SemanticQuery,
SemanticView,
Sort as SemanticSort,
SortDirectionEnum,
STRING,
Table as SemanticTable,
TIME,
Type as SemanticType,
)
from superset.models.sql_lab import Query
from superset.sql.parse import Table
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource
if TYPE_CHECKING:
from sqlalchemy.engine.base import Engine
from superset.models.core import Database
# Regular expressions to catch custom errors
@@ -77,6 +107,303 @@ class SnowflakeParametersType(TypedDict):
warehouse: str
class SnowflakeSemanticLayer:
    """
    Semantic layer implementation backed by Snowflake semantic views.

    Metrics and dimensions are discovered via `SHOW`/`DESC SEMANTIC VIEW`
    commands, and queries are built around Snowflake's `SEMANTIC_VIEW(...)`
    table function.
    """

    def __init__(self, engine: Engine) -> None:
        self.engine = engine

    def execute(
        self,
        sql: str,
        **kwargs: Any,
    ) -> Iterator[dict[str, Any]]:
        """
        Execute `sql` with bound parameters, yielding each row as a dict.
        """
        with self.engine.connect() as connection:
            for row in connection.execute(text(sql), kwargs).mappings():
                yield dict(row)

    def get_semantic_views(self) -> set[SemanticView]:
        """
        Return all semantic views visible to the current connection.
        """
        sql = """
            SHOW SEMANTIC VIEWS
            ->> SELECT "name" FROM $1;
        """
        return {SemanticView(row["name"]) for row in self.execute(sql)}

    def get_type(self, snowflake_type: str | None) -> type[SemanticType]:
        """
        Map a Snowflake type string (as reported by `DESC`) to a semantic
        type class, falling back to `STRING` for unknown or missing types.
        """
        if snowflake_type is None:
            return STRING

        # Order matters: the INTEGER/DECIMAL patterns are more specific than
        # the generic NUMBER pattern and must be tested first (dicts preserve
        # insertion order).
        type_map: dict[type[SemanticType], set[str]] = {
            STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
            INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
            DECIMAL: {r"NUMBER\(10,\s?2\)$"},
            NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
            BOOLEAN: {"BOOLEAN$"},
            DATE: {"DATE$"},
            # fixed typo (was TIMESTAMP__NTZ) and added TIMESTAMP_LTZ, the
            # third timestamp variation Snowflake reports
            DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP_NTZ$", "TIMESTAMP_LTZ$"},
            TIME: {"TIME$"},
            OBJECT: {"OBJECT$"},
            BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
        }
        for semantic_type, patterns in type_map.items():
            if any(
                re.match(pattern, snowflake_type, re.IGNORECASE)
                for pattern in patterns
            ):
                return semantic_type

        return STRING

    @classmethod
    def quote_table(cls, table: Table, dialect: Dialect) -> str:
        """
        Fully quote a table name, including the schema and catalog.
        """
        quoters = {
            "catalog": dialect.identifier_preparer.quote_schema,
            "schema": dialect.identifier_preparer.quote_schema,
            "table": dialect.identifier_preparer.quote,
        }
        return ".".join(
            function(getattr(table, key))
            for key, function in quoters.items()
            if getattr(table, key)
        )

    def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]:
        """
        Return the metrics defined in `semantic_view`.

        Metric names are qualified as `<table>.<name>`, matching how they are
        referenced in queries.
        """
        quoted_semantic_view_name = self.quote_table(
            Table(semantic_view.name),
            self.engine.dialect,
        )
        sql = f"""
            DESC SEMANTIC VIEW {quoted_semantic_view_name}
            ->> SELECT "object_name", "property", "property_value"
                FROM $1
                WHERE
                    "object_kind" = 'METRIC' AND
                    "property" IN ('DATA_TYPE', 'TABLE');
        """  # noqa: S608 (semantic_view.name is quoted)

        # NOTE: `itertools.groupby` only groups *consecutive* rows; this
        # assumes `DESC` returns each metric's properties contiguously — TODO
        # confirm against Snowflake output ordering.
        rows = self.execute(sql)
        metrics: set[SemanticMetric] = set()
        for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
            attributes = defaultdict(set)
            for row in group:
                attributes[row["property"]].add(row["property_value"])

            table = next(iter(attributes["TABLE"]))
            metric_name = f"{table}.{name}"
            type_ = self.get_type(next(iter(attributes["DATA_TYPE"])))
            # use a dedicated name so the outer `sql` variable isn't clobbered
            metric_sql = self.engine.dialect.identifier_preparer.quote(metric_name)
            tables = frozenset(attributes["TABLE"])
            join_columns: frozenset[str] = frozenset()
            metrics.add(
                SemanticMetric(metric_name, type_, metric_sql, tables, join_columns)
            )

        return metrics

    def get_dimensions(self, semantic_view: SemanticView) -> set[SemanticDimension]:
        """
        Return the dimensions defined in `semantic_view`.

        Dimension names are qualified as `<table>.<name>`.
        """
        quoted_semantic_view_name = self.quote_table(
            Table(semantic_view.name),
            self.engine.dialect,
        )
        sql = f"""
            DESC SEMANTIC VIEW {quoted_semantic_view_name}
            ->> SELECT "object_name", "property", "property_value"
                FROM $1
                WHERE
                    "object_kind" = 'DIMENSION' AND
                    "property" IN ('DATA_TYPE', 'TABLE');
        """  # noqa: S608 (semantic_view.name is quoted)

        # See note in `get_metrics` about `groupby` requiring contiguous rows.
        rows = self.execute(sql)
        dimensions: set[SemanticDimension] = set()
        for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
            attributes = defaultdict(set)
            for row in group:
                attributes[row["property"]].add(row["property_value"])

            table = next(iter(attributes["TABLE"]))
            dimension_name = f"{table}.{name}"
            column = SemanticColumn(SemanticTable(table), name)
            type_ = self.get_type(next(iter(attributes["DATA_TYPE"])))
            dimensions.add(SemanticDimension(column, dimension_name, type_))

        return dimensions

    def get_valid_metrics(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
    ) -> set[SemanticMetric]:
        """
        Return metrics compatible with the current selection.
        """
        # all metrics and dimensions are valid inside a given semantic view
        return self.get_metrics(semantic_view)

    def get_valid_dimensions(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
    ) -> set[SemanticDimension]:
        """
        Return dimensions compatible with the current selection.
        """
        # all metrics and dimensions are valid inside a given semantic view
        return self.get_dimensions(semantic_view)

    def get_query(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
        filters: set[SemanticFilter],
        sort: SemanticSort = NoSort,
        limit: int | None = None,
        offset: int | None = None,
    ) -> SemanticQuery:
        """
        Render a Snowflake SQL query for the given metrics and dimensions.
        """
        ast = self.build_query(
            semantic_view,
            metrics,
            dimensions,
            filters,
            sort,
            limit,
            offset,
        )
        return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))

    def build_query(
        self,
        semantic_view: SemanticView,
        metrics: set[SemanticMetric],
        dimensions: set[SemanticDimension],
        filters: set[SemanticFilter],
        sort: SemanticSort = NoSort,
        limit: int | None = None,
        offset: int | None = None,
    ) -> exp.Select:
        """
        Build a sqlglot AST selecting from a `SEMANTIC_VIEW(...)` call.
        """
        # don't shadow the `semantic_view` parameter with the AST node
        semantic_view_node = exp.SemanticView(
            this=exp.Table(this=exp.Identifier(this=semantic_view.name, quoted=True)),
            dimensions=[
                exp.Column(
                    this=exp.Identifier(this=dimension.column.name, quoted=True),
                    table=exp.Identifier(
                        this=dimension.column.relation.name,
                        quoted=True,
                    ),
                )
                for dimension in dimensions
            ],
            # metric names are `<table>.<column>`; split once to recover both
            metrics=[
                exp.Column(
                    this=exp.Identifier(this=column_name, quoted=True),
                    table=exp.Identifier(this=table_name, quoted=True),
                )
                for table_name, column_name in (
                    metric.name.split(".", 1)
                    for metric in metrics
                    if "." in metric.name
                )
            ],
        )
        query = exp.Select(
            expressions=[exp.Star()],
            **{"from": exp.From(this=exp.Table(this=semantic_view_node))},
        )

        if sort.items:
            order = [
                exp.Ordered(
                    this=exp.Column(this=exp.Identifier(this=item.field.name)),
                    desc=item.direction == SortDirectionEnum.DESC,
                    nulls_first=item.nulls_first,
                )
                for item in sort.items
            ]
            query.args["order"] = exp.Order(expressions=order)

        if offset:
            query = query.offset(offset)
        if limit:
            query = query.limit(limit)

        return query

    def get_query_from_standard_sql(self, sql: str) -> SemanticQuery:
        """
        Convert the Explore query into a proper query.

        Explore will produce a pseudo-SQL query that references metrics and
        dimensions as if they were columns in a table. This method replaces the
        table name with a call to `SEMANTIC_VIEW`, and removes the `GROUP BY`
        clause, since all the aggregations happen inside the `SEMANTIC_VIEW`
        call. Queries that don't reference a semantic view are returned
        unchanged.
        """
        ast = parse_one(sql, "snowflake")
        table = ast.find(exp.Table)
        if not table:
            return SemanticQuery(sql=sql)

        semantic_views = self.get_semantic_views()
        if table.name not in {semantic_view.name for semantic_view in semantic_views}:
            return SemanticQuery(sql=sql)

        # collect all metrics and dimensions
        semantic_view = SemanticView(table.name)
        all_metrics = self.get_metrics(semantic_view)
        all_dimensions = self.get_dimensions(semantic_view)

        # collect metrics and dimensions used in the query; sort so the
        # generated SQL is deterministic (sets are unordered)
        columns = {column.name for column in ast.find_all(exp.Column)}
        metrics = sorted(
            (metric for metric in all_metrics if metric.name in columns),
            key=lambda metric: metric.name,
        )
        dimensions = sorted(
            (dimension for dimension in all_dimensions if dimension.name in columns),
            key=lambda dimension: dimension.name,
        )

        # now replace table with a call to `SEMANTIC_VIEW`
        udtf = exp.Table(
            this=exp.SemanticView(
                this=exp.Table(
                    this=exp.Identifier(this=semantic_view.name, quoted=True)
                ),
                metrics=[
                    exp.Column(
                        this=exp.Identifier(this=column_name, quoted=True),
                        table=exp.Identifier(this=table_name, quoted=True),
                    )
                    for table_name, column_name in (
                        metric.name.split(".", 1)
                        for metric in metrics
                        if "." in metric.name
                    )
                ],
                dimensions=[
                    exp.Column(
                        this=exp.Identifier(this=column_name, quoted=True),
                        table=exp.Identifier(this=table_name, quoted=True),
                    )
                    for table_name, column_name in (
                        dimension.name.split(".", 1)
                        for dimension in dimensions
                        if "." in dimension.name
                    )
                ],
            ),
            alias=exp.TableAlias(
                this=exp.Identifier(this="table_alias", quoted=False),
                columns=[
                    exp.Identifier(this=column.name, quoted=True)
                    for column in metrics + dimensions
                ],
            ),
        )
        table.replace(udtf)

        # remove GROUP BY, since aggregations are done inside the
        # `SEMANTIC_VIEW` call; `pop` because the query might not have one
        ast.args.pop("group", None)

        return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = "snowflake"
engine_name = "Snowflake"
@@ -90,6 +417,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
default_driver = "snowflake"
sqlalchemy_uri_placeholder = "snowflake://"
semantic_layer = SnowflakeSemanticLayer
supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True

View File

@@ -0,0 +1,340 @@
import enum
from dataclasses import dataclass
from datetime import timedelta
from functools import total_ordering
from typing import Protocol, runtime_checkable
from sqlalchemy import types as sqltypes
from sqlalchemy.engine.base import Engine
class Type:
    """
    Base class for types.

    Subclasses are used as values themselves (i.e. `type[Type]`), not
    instantiated: `Metric.type` and `Dimension.type` hold the class, and
    `TYPE_MAPPING` is keyed by these classes.
    """


class INTEGER(Type):
    """
    Represents an integer type.
    """


class NUMBER(Type):
    """
    Represents a number type.
    """


class DECIMAL(Type):
    """
    Represents a decimal type.
    """


class STRING(Type):
    """
    Represents a string type.
    """


class BOOLEAN(Type):
    """
    Represents a boolean type.
    """


class DATE(Type):
    """
    Represents a date type.
    """


class TIME(Type):
    """
    Represents a time type.
    """


class DATETIME(DATE, TIME):
    """
    Represents a datetime type (a DATE and a TIME, via multiple inheritance).
    """


class INTERVAL(Type):
    """
    Represents an interval type.
    """


class OBJECT(Type):
    """
    Represents an object type.
    """


class BINARY(Type):
    """
    Represents a binary type.
    """
@dataclass(frozen=True)
class SemanticView:
    """A named group of metrics and dimensions exposed by a semantic layer."""

    name: str
    description: str | None = None


@dataclass(frozen=True)
class Relation:
    """A generic relation (catalog/schema-qualified name)."""

    name: str
    schema: str | None = None
    catalog: str | None = None


@dataclass(frozen=True)
class Table:
    """A physical table."""

    name: str
    schema: str | None = None
    catalog: str | None = None


@dataclass(frozen=True)
class View:
    """A view, defined by its SQL."""

    name: str
    sql: str
    schema: str | None = None
    catalog: str | None = None


@dataclass(frozen=True)
class Virtual:
    """A virtual relation — one with no physical/SQL backing, only a name."""

    name: str


@dataclass(frozen=True)
class Metric:
    """A metric defined in a semantic view."""

    name: str
    type: type[Type]
    # SQL expression that computes the metric.
    sql: str
    # Tables the metric reads from.
    tables: frozenset[Table]
    # Columns used to join the tables above.
    join_columns: frozenset[str]
@total_ordering
class ComparableEnum(enum.Enum):
    """
    Enum whose members compare by value, including across different
    `ComparableEnum` subclasses (e.g. a `TimeGrain` can be compared with a
    `DateGrain`, since both are valued by `timedelta`s).

    `total_ordering` derives `<=`, `>` and `>=` from `__eq__` and `__lt__`.
    """

    def __eq__(self, other: object) -> bool:
        if isinstance(other, enum.Enum):
            return self.value == other.value
        return NotImplemented

    def __lt__(self, other: object) -> bool:
        if isinstance(other, enum.Enum):
            return self.value < other.value
        return NotImplemented

    def __hash__(self) -> int:
        # Must hash by value: `__eq__` compares values across enum classes, so
        # hashing by (class, name) — as a previous version did — would let two
        # equal members have different hashes, breaking set/dict semantics.
        return hash(self.value)
class TimeGrain(ComparableEnum):
    """Sub-day time grains, valued by their duration."""

    second = timedelta(seconds=1)
    minute = timedelta(minutes=1)
    hour = timedelta(hours=1)


class DateGrain(ComparableEnum):
    """
    Day-and-above grains, valued by their duration.

    Month/quarter/year durations are nominal approximations (30/90/365 days),
    used only for ordering grains relative to each other.
    """

    day = timedelta(days=1)
    week = timedelta(weeks=1)
    month = timedelta(days=30)
    quarter = timedelta(days=90)
    year = timedelta(days=365)
@dataclass(frozen=True)
class Column:
    """A column belonging to a table, view, or virtual relation."""

    relation: Table | View | Virtual
    name: str


@dataclass(frozen=True)
class Dimension:
    """A dimension defined in a semantic view, backed by a column."""

    column: Column
    name: str
    type: type[Type]
    # Optional time/date grain for temporal dimensions.
    grain: TimeGrain | DateGrain | None = None

    def __repr__(self) -> str:
        # e.g. "DATETIME orders.ts [hour]"; strip() drops the trailing space
        # when there is no grain.
        metadata = f"[{self.grain.name}]" if self.grain else ""
        return f"{self.type.__name__} {self.name} {metadata}".strip()
class FilterTypeEnum(enum.Enum):
    """Whether a filter applies pre-aggregation (WHERE) or post (HAVING)."""

    WHERE = enum.auto()
    HAVING = enum.auto()


@dataclass(frozen=True)
class Filter:
    """A filter to apply to a semantic-layer query."""

    type: FilterTypeEnum
    # Raw SQL predicate — presumably dialect-specific; confirm with callers.
    expression: str
class SortDirectionEnum(enum.Enum):
    """Sort direction for a sort field."""

    ASC = enum.auto()
    DESC = enum.auto()


@dataclass(frozen=True)
class SortField:
    """A single ORDER BY entry: a metric or dimension plus a direction."""

    field: Metric | Dimension
    direction: SortDirectionEnum
    nulls_first: bool = True


@dataclass(frozen=True)
class Sort:
    """An ordered list of sort fields."""

    # NOTE(review): `list` makes `Sort` unhashable and mutable despite
    # `frozen=True`; a tuple would be safer — confirm no caller mutates this.
    items: list[SortField]


@dataclass(frozen=True)
class Query:
    """A rendered SQL query produced by a semantic layer."""

    sql: str


# Shared "no sorting" sentinel, used as a default argument — do not mutate
# its `items` list.
NoSort = Sort(items=[])
@runtime_checkable
class SemanticLayer(Protocol):
    """
    A generic protocol for semantic layers.
    """

    def __init__(self, engine: Engine) -> None: ...

    def get_semantic_views(self) -> set[SemanticView]:
        """
        Return a set of the semantic views.

        A semantic view is an organizational group of metrics and dimensions.
        It's not a logical grouping, since metrics and dimensions from a given
        semantic view might not be compatible. An implementation might expose a
        single semantic view for exploration of available metrics and
        dimensions, and smaller curated semantic views that are domain
        specific.
        """
        ...

    def get_metrics(self, semantic_view: SemanticView) -> set[Metric]:
        """
        Return a set of metrics from a given semantic view.
        """
        ...

    def get_dimensions(self, semantic_view: SemanticView) -> set[Dimension]:
        """
        Return a set of dimensions from a given semantic view.
        """
        ...

    def get_valid_metrics(
        self,
        semantic_view: SemanticView,
        metrics: set[Metric],
        dimensions: set[Dimension],
    ) -> set[Metric]:
        """
        Return compatible metrics for the given metrics and dimensions.

        For metrics to be valid they must be compatible with all the provided
        dimensions.
        """
        ...

    def get_valid_dimensions(
        self,
        semantic_view: SemanticView,
        metrics: set[Metric],
        dimensions: set[Dimension],
    ) -> set[Dimension]:
        """
        Return compatible dimensions for the given metrics.

        For dimensions to be valid they must be compatible with all the
        provided metrics.
        """
        ...

    def get_query(
        self,
        semantic_view: SemanticView,
        metrics: set[Metric],
        dimensions: set[Dimension],
        # populations: set[Population],
        filters: set[Filter],
        sort: Sort = NoSort,
        limit: int | None = None,
        offset: int | None = None,
    ) -> Query:
        """
        Build a SQL query from the given metrics, dimensions, filters, and sort
        order.
        """
        ...

    def get_query_from_standard_sql(
        self,
        sql: str,
    ) -> Query:
        """
        Build a SQL query from a pseudo-query referencing metrics and dimensions.

        Implementations resolve the semantic view from the table referenced in
        the pseudo-query, so no explicit view argument is needed — this matches
        both the Snowflake implementation and the call site in
        `BaseEngineSpec.execute`.

        For example, given `metric1` having the expression `COUNT(*)`, this
        query:

            SELECT metric1, dim1
            FROM semantic_layer
            GROUP BY dim1

        Becomes:

            SELECT metric1, dim1
            FROM (
                SELECT COUNT(*) AS metric1, dim1
                FROM fact_table
                JOIN dim_table
                ON fact_table.dim_id = dim_table.id
                GROUP BY dim1
            ) AS semantic_view
        """
        ...
# Maps semantic-layer type classes to SQLAlchemy type classes. Keyed by the
# classes themselves, hence `type[Type]` (the previous annotation, `Type`,
# described instances, which is not what the dict holds).
TYPE_MAPPING: dict[type[Type], type[sqltypes.TypeEngine]] = {
    # Numeric types
    INTEGER: sqltypes.Integer,
    NUMBER: sqltypes.Numeric,
    DECIMAL: sqltypes.DECIMAL,
    # String types
    STRING: sqltypes.String,
    # Boolean type
    BOOLEAN: sqltypes.Boolean,
    # Date/time types
    DATE: sqltypes.Date,
    TIME: sqltypes.Time,
    DATETIME: sqltypes.DateTime,
    INTERVAL: sqltypes.Interval,
    # Complex types
    OBJECT: sqltypes.JSON,
    BINARY: sqltypes.LargeBinary,
}
def get_sqla_type_from_dimension_type(
    dimension_type: type[Type],
) -> sqltypes.TypeEngine:
    """
    Get the SQLAlchemy type corresponding to the given dimension type.

    Callers pass the type *class* (e.g. `Dimension.type`), so the parameter is
    annotated `type[Type]`. The lookup is by exact class — subclasses such as
    `DATETIME` need their own entry in `TYPE_MAPPING` — and unknown types fall
    back to `String`. Returns an instance of the SQLAlchemy type.
    """
    return TYPE_MAPPING.get(dimension_type, sqltypes.String)()

View File

@@ -126,7 +126,9 @@ class ConfigurationMethod(StrEnum):
DYNAMIC_FORM = "dynamic_form"
class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
class Database(
Model, AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods
"""An ORM object that stores Database related information"""
__tablename__ = "dbs"
@@ -400,9 +402,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return (
username
if (username := get_username())
else object_url.username
if self.impersonate_user
else None
else object_url.username if self.impersonate_user else None
)
@contextmanager
@@ -987,7 +987,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
schema=table.schema,
) as inspector:
return self.db_engine_spec.get_columns(
inspector, table, self.schema_options
self,
inspector,
table,
self.schema_options,
)
def get_metrics(
@@ -1076,9 +1079,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return self.perm
def has_table(self, table: Table) -> bool:
with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
# do not pass "" as an empty schema; force null
return engine.has_table(table.table, table.schema or None)
with self.get_inspector(
catalog=table.catalog,
schema=table.schema,
) as inspector:
return self.db_engine_spec.has_table(self, inspector, table)
def has_view(self, table: Table) -> bool:
with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: