mirror of
https://github.com/apache/superset.git
synced 2026-05-04 07:24:18 +00:00
Compare commits
5 Commits
docs/testi
...
snowflake-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
57ca29baf4 | ||
|
|
01c587361f | ||
|
|
fc64ac918a | ||
|
|
cd019bab3e | ||
|
|
a330fe6f7e |
@@ -164,7 +164,9 @@ class DatasourceKind(StrEnum):
|
|||||||
PHYSICAL = "physical"
|
PHYSICAL = "physical"
|
||||||
|
|
||||||
|
|
||||||
class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
|
class BaseDatasource(
|
||||||
|
AuditMixinNullable, ImportExportMixin
|
||||||
|
): # pylint: disable=too-many-public-methods
|
||||||
"""A common interface to objects that are queryable
|
"""A common interface to objects that are queryable
|
||||||
(tables and datasources)"""
|
(tables and datasources)"""
|
||||||
|
|
||||||
@@ -1778,7 +1780,9 @@ class SqlaTable(
|
|||||||
def default_query(qry: Query) -> Query:
|
def default_query(qry: Query) -> Query:
|
||||||
return qry.filter_by(is_sqllab_view=False)
|
return qry.filter_by(is_sqllab_view=False)
|
||||||
|
|
||||||
def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool: # noqa: C901
|
def has_extra_cache_key_calls(
|
||||||
|
self, query_obj: QueryObjectDict
|
||||||
|
) -> bool: # noqa: C901
|
||||||
"""
|
"""
|
||||||
Detects the presence of calls to `ExtraCache` methods in items in query_obj that
|
Detects the presence of calls to `ExtraCache` methods in items in query_obj that
|
||||||
can be templated. If any are present, the query must be evaluated to extract
|
can be templated. If any are present, the query must be evaluated to extract
|
||||||
|
|||||||
@@ -75,12 +75,11 @@ class DatasetDAO(BaseDAO[SqlaTable]):
|
|||||||
database: Database,
|
database: Database,
|
||||||
table: Table,
|
table: Table,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
with database.get_inspector(
|
||||||
database.get_table(table)
|
catalog=table.catalog,
|
||||||
return True
|
schema=table.schema,
|
||||||
except SQLAlchemyError as ex: # pragma: no cover
|
) as inspector:
|
||||||
logger.warning("Got an error %s validating table: %s", str(ex), table)
|
return database.db_engine_spec.has_table(database, inspector, table)
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_uniqueness(
|
def validate_uniqueness(
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
|
Type,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
@@ -62,6 +63,10 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
|
|||||||
from superset.databases.utils import get_table_metadata, make_url_safe
|
from superset.databases.utils import get_table_metadata, make_url_safe
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||||
|
from superset.extensions.semantic_layer import (
|
||||||
|
get_sqla_type_from_dimension_type,
|
||||||
|
SemanticLayer,
|
||||||
|
)
|
||||||
from superset.sql.parse import (
|
from superset.sql.parse import (
|
||||||
BaseSQLStatement,
|
BaseSQLStatement,
|
||||||
LimitMethod,
|
LimitMethod,
|
||||||
@@ -85,7 +90,7 @@ from superset.utils.network import is_hostname_valid, is_port_open
|
|||||||
from superset.utils.oauth2 import encode_oauth2_state
|
from superset.utils.oauth2 import encode_oauth2_state
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from superset.connectors.sqla.models import TableColumn
|
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||||
from superset.databases.schemas import TableMetadataResponse
|
from superset.databases.schemas import TableMetadataResponse
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
@@ -106,6 +111,15 @@ logger = logging.getLogger()
|
|||||||
GenericDBException = Exception
|
GenericDBException = Exception
|
||||||
|
|
||||||
|
|
||||||
|
class ValidColumnsType(TypedDict):
|
||||||
|
"""
|
||||||
|
Type for valid columns returned by `get_valid_metrics_and_dimensions`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dimensions: set[str]
|
||||||
|
metrics: set[str]
|
||||||
|
|
||||||
|
|
||||||
def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
|
def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
|
||||||
result_set_columns: list[ResultSetColumnType] = []
|
result_set_columns: list[ResultSetColumnType] = []
|
||||||
for col in cols:
|
for col in cols:
|
||||||
@@ -143,7 +157,9 @@ builtin_time_grains: dict[str | None, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors
|
class TimestampExpression(
|
||||||
|
ColumnClause
|
||||||
|
): # pylint: disable=abstract-method, too-many-ancestors
|
||||||
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
|
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
|
||||||
"""Sqlalchemy class that can be used to render native column elements respecting
|
"""Sqlalchemy class that can be used to render native column elements respecting
|
||||||
engine-specific quoting rules as part of a string-based expression.
|
engine-specific quoting rules as part of a string-based expression.
|
||||||
@@ -214,6 +230,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
"engine+driver://user:password@host:port/dbname[?key=value&key=value...]"
|
"engine+driver://user:password@host:port/dbname[?key=value&key=value...]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# databases can optionally specify a semantic layer
|
||||||
|
semantic_layer: Type[SemanticLayer] | None = None
|
||||||
|
|
||||||
disable_ssh_tunneling = False
|
disable_ssh_tunneling = False
|
||||||
|
|
||||||
_date_trunc_functions: dict[str, str] = {}
|
_date_trunc_functions: dict[str, str] = {}
|
||||||
@@ -388,9 +407,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
max_column_name_length: int | None = None
|
max_column_name_length: int | None = None
|
||||||
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
||||||
run_multiple_statements_as_one = False
|
run_multiple_statements_as_one = False
|
||||||
custom_errors: dict[
|
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
|
||||||
Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
|
{}
|
||||||
] = {}
|
)
|
||||||
|
|
||||||
# List of JSON path to fields in `encrypted_extra` that should be masked when the
|
# List of JSON path to fields in `encrypted_extra` that should be masked when the
|
||||||
# database is edited. By default everything is masked.
|
# database is edited. By default everything is masked.
|
||||||
@@ -1461,8 +1480,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
|
|
||||||
if schema and cls.try_remove_schema_from_table_name:
|
if schema and cls.try_remove_schema_from_table_name:
|
||||||
tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
|
tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
|
||||||
|
|
||||||
|
# add semantic views as tables too
|
||||||
|
if cls.semantic_layer:
|
||||||
|
semantic_layer = cls.semantic_layer(inspector.engine)
|
||||||
|
tables.update(
|
||||||
|
semantic_view.name
|
||||||
|
for semantic_view in semantic_layer.get_semantic_views()
|
||||||
|
)
|
||||||
|
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def has_table(
|
||||||
|
cls,
|
||||||
|
database: Database,
|
||||||
|
inspector: Inspector,
|
||||||
|
table: Table,
|
||||||
|
) -> bool:
|
||||||
|
if cls.semantic_layer:
|
||||||
|
semantic_layer = cls.semantic_layer(inspector.engine)
|
||||||
|
semantic_views = semantic_layer.get_semantic_views()
|
||||||
|
if table.table in {semantic_view.name for semantic_view in semantic_views}:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return inspector.has_table(table.table, table.schema)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_view_names( # pylint: disable=unused-argument
|
def get_view_names( # pylint: disable=unused-argument
|
||||||
cls,
|
cls,
|
||||||
@@ -1536,6 +1579,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_columns( # pylint: disable=unused-argument
|
def get_columns( # pylint: disable=unused-argument
|
||||||
cls,
|
cls,
|
||||||
|
database: Database,
|
||||||
inspector: Inspector,
|
inspector: Inspector,
|
||||||
table: Table,
|
table: Table,
|
||||||
options: dict[str, Any] | None = None,
|
options: dict[str, Any] | None = None,
|
||||||
@@ -1543,7 +1587,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
"""
|
"""
|
||||||
Get all columns from a given schema and table.
|
Get all columns from a given schema and table.
|
||||||
|
|
||||||
The inspector will be bound to a catalog, if one was specified.
|
The inspector will be bound to a catalog, if one was specified. If the database
|
||||||
|
supports semantic layers the method will check if the table is a semantic view,
|
||||||
|
and return columns (metrics and dimensions) from it instead.
|
||||||
|
|
||||||
:param inspector: SqlAlchemy Inspector instance
|
:param inspector: SqlAlchemy Inspector instance
|
||||||
:param table: Table instance
|
:param table: Table instance
|
||||||
@@ -1551,6 +1597,26 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
some databases
|
some databases
|
||||||
:return: All columns in table
|
:return: All columns in table
|
||||||
"""
|
"""
|
||||||
|
if cls.semantic_layer:
|
||||||
|
semantic_layer = cls.semantic_layer(inspector.engine)
|
||||||
|
semantic_views = {
|
||||||
|
semantic_view.name: semantic_view
|
||||||
|
for semantic_view in semantic_layer.get_semantic_views()
|
||||||
|
}
|
||||||
|
if semantic_view := semantic_views.get(table.table):
|
||||||
|
dialect = database.get_dialect()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": dimension.name,
|
||||||
|
"column_name": dimension.name,
|
||||||
|
"type": cls.column_datatype_to_string(
|
||||||
|
get_sqla_type_from_dimension_type(dimension.type),
|
||||||
|
dialect,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for dimension in semantic_layer.get_dimensions(semantic_view)
|
||||||
|
]
|
||||||
|
|
||||||
return convert_inspector_columns(
|
return convert_inspector_columns(
|
||||||
cast(
|
cast(
|
||||||
list[SQLAColumnType],
|
list[SQLAColumnType],
|
||||||
@@ -1568,6 +1634,22 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
"""
|
"""
|
||||||
Get all metrics from a given schema and table.
|
Get all metrics from a given schema and table.
|
||||||
"""
|
"""
|
||||||
|
if cls.semantic_layer:
|
||||||
|
semantic_layer = cls.semantic_layer(inspector.engine)
|
||||||
|
semantic_views = {
|
||||||
|
semantic_view.name: semantic_view
|
||||||
|
for semantic_view in semantic_layer.get_semantic_views()
|
||||||
|
}
|
||||||
|
if semantic_view := semantic_views.get(table.table):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"metric_name": metric.name,
|
||||||
|
"verbose_name": metric.name,
|
||||||
|
"expression": metric.sql,
|
||||||
|
}
|
||||||
|
for metric in semantic_layer.get_metrics(semantic_view)
|
||||||
|
]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"metric_name": "count",
|
"metric_name": "count",
|
||||||
@@ -1577,6 +1659,62 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_valid_metrics_and_dimensions(
|
||||||
|
cls,
|
||||||
|
database: Database,
|
||||||
|
table: SqlaTable,
|
||||||
|
dimensions: set[str],
|
||||||
|
metrics: set[str],
|
||||||
|
) -> ValidColumnsType:
|
||||||
|
"""
|
||||||
|
Get valid metrics and dimensions.
|
||||||
|
|
||||||
|
Given a datasource, and sets of selected metrics and dimensions, return the
|
||||||
|
sets of valid metrics and dimensions that can further be selected.
|
||||||
|
"""
|
||||||
|
if cls.semantic_layer:
|
||||||
|
with database.get_sqla_engine() as engine:
|
||||||
|
semantic_layer = cls.semantic_layer(engine)
|
||||||
|
semantic_views = {
|
||||||
|
semantic_view.name: semantic_view
|
||||||
|
for semantic_view in semantic_layer.get_semantic_views()
|
||||||
|
}
|
||||||
|
if semantic_view := semantic_views.get(table.table):
|
||||||
|
selected_metrics = {
|
||||||
|
metric
|
||||||
|
for metric in semantic_layer.get_metrics(semantic_view)
|
||||||
|
if metric.name in metrics
|
||||||
|
}
|
||||||
|
selected_dimensions = {
|
||||||
|
dimension
|
||||||
|
for dimension in semantic_layer.get_dimensions(semantic_view)
|
||||||
|
if dimension.name in dimensions
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"metrics": {
|
||||||
|
metric.name
|
||||||
|
for metric in semantic_layer.get_valid_metrics(
|
||||||
|
semantic_view,
|
||||||
|
selected_metrics,
|
||||||
|
selected_dimensions,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"dimensions": {
|
||||||
|
dimension.name
|
||||||
|
for dimension in semantic_layer.get_valid_dimensions(
|
||||||
|
semantic_view,
|
||||||
|
selected_metrics,
|
||||||
|
selected_dimensions,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dimensions": {column.column_name for column in table.columns},
|
||||||
|
"metrics": {metric.metric_name for metric in table.metrics},
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def where_latest_partition( # pylint: disable=unused-argument
|
def where_latest_partition( # pylint: disable=unused-argument
|
||||||
cls,
|
cls,
|
||||||
@@ -1843,6 +1981,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
:param kwargs: kwargs to be passed to cursor.execute()
|
:param kwargs: kwargs to be passed to cursor.execute()
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if cls.semantic_layer:
|
||||||
|
with cls.get_engine(database, schema="tpcds_sf10tcl") as engine:
|
||||||
|
semantic_layer = cls.semantic_layer(engine)
|
||||||
|
query = semantic_layer.get_query_from_standard_sql(query).sql
|
||||||
|
|
||||||
if cls.arraysize:
|
if cls.arraysize:
|
||||||
cursor.arraysize = cls.arraysize
|
cursor.arraysize = cls.arraysize
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -16,11 +16,13 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from re import Pattern
|
from re import Pattern
|
||||||
from typing import Any, Optional, TYPE_CHECKING, TypedDict
|
from typing import Any, Iterator, Optional, TYPE_CHECKING, TypedDict
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
from apispec import APISpec
|
from apispec import APISpec
|
||||||
@@ -30,20 +32,48 @@ from cryptography.hazmat.primitives import serialization
|
|||||||
from flask import current_app
|
from flask import current_app
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
from marshmallow import fields, Schema
|
from marshmallow import fields, Schema
|
||||||
from sqlalchemy import types
|
from sqlalchemy import text, types
|
||||||
|
from sqlalchemy.engine.interfaces import Dialect
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.engine.url import URL
|
from sqlalchemy.engine.url import URL
|
||||||
|
from sqlglot import exp, parse_one
|
||||||
|
|
||||||
from superset.constants import TimeGrain
|
from superset.constants import TimeGrain
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
|
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
|
||||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
|
from superset.extensions.semantic_layer import (
|
||||||
|
BINARY,
|
||||||
|
BOOLEAN,
|
||||||
|
Column as SemanticColumn,
|
||||||
|
DATE,
|
||||||
|
DATETIME,
|
||||||
|
DECIMAL,
|
||||||
|
Dimension as SemanticDimension,
|
||||||
|
Filter as SemanticFilter,
|
||||||
|
INTEGER,
|
||||||
|
Metric as SemanticMetric,
|
||||||
|
NoSort,
|
||||||
|
NUMBER,
|
||||||
|
OBJECT,
|
||||||
|
Query as SemanticQuery,
|
||||||
|
SemanticView,
|
||||||
|
Sort as SemanticSort,
|
||||||
|
SortDirectionEnum,
|
||||||
|
STRING,
|
||||||
|
Table as SemanticTable,
|
||||||
|
TIME,
|
||||||
|
Type as SemanticType,
|
||||||
|
)
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
from superset.sql.parse import Table
|
||||||
from superset.utils import json
|
from superset.utils import json
|
||||||
from superset.utils.core import get_user_agent, QuerySource
|
from superset.utils.core import get_user_agent, QuerySource
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.engine.base import Engine
|
||||||
|
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
||||||
# Regular expressions to catch custom errors
|
# Regular expressions to catch custom errors
|
||||||
@@ -77,6 +107,303 @@ class SnowflakeParametersType(TypedDict):
|
|||||||
warehouse: str
|
warehouse: str
|
||||||
|
|
||||||
|
|
||||||
|
class SnowflakeSemanticLayer:
|
||||||
|
def __init__(self, engine: Engine) -> None:
|
||||||
|
self.engine = engine
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
sql: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[dict[str, Any]]:
|
||||||
|
with self.engine.connect() as connection:
|
||||||
|
for row in connection.execute(text(sql), kwargs).mappings():
|
||||||
|
yield dict(row)
|
||||||
|
|
||||||
|
def get_semantic_views(self) -> set[SemanticView]:
|
||||||
|
sql = """
|
||||||
|
SHOW SEMANTIC VIEWS
|
||||||
|
->> SELECT "name" FROM $1;
|
||||||
|
"""
|
||||||
|
return {SemanticView(row["name"]) for row in self.execute(sql)}
|
||||||
|
|
||||||
|
def get_type(self, snowflake_type: str | None) -> type[SemanticType]:
|
||||||
|
if snowflake_type is None:
|
||||||
|
return STRING
|
||||||
|
|
||||||
|
type_map = {
|
||||||
|
STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
|
||||||
|
INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
|
||||||
|
DECIMAL: {r"NUMBER\(10,\s?2\)$"},
|
||||||
|
NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
|
||||||
|
BOOLEAN: {"BOOLEAN$"},
|
||||||
|
DATE: {"DATE$"},
|
||||||
|
DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP__NTZ$"},
|
||||||
|
TIME: {"TIME$"},
|
||||||
|
OBJECT: {"OBJECT$"},
|
||||||
|
BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
|
||||||
|
}
|
||||||
|
for semantic_type, patterns in type_map.items():
|
||||||
|
if any(
|
||||||
|
re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns
|
||||||
|
):
|
||||||
|
return semantic_type
|
||||||
|
|
||||||
|
return STRING
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def quote_table(cls, table: Table, dialect: Dialect) -> str:
|
||||||
|
"""
|
||||||
|
Fully quote a table name, including the schema and catalog.
|
||||||
|
"""
|
||||||
|
quoters = {
|
||||||
|
"catalog": dialect.identifier_preparer.quote_schema,
|
||||||
|
"schema": dialect.identifier_preparer.quote_schema,
|
||||||
|
"table": dialect.identifier_preparer.quote,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ".".join(
|
||||||
|
function(getattr(table, key))
|
||||||
|
for key, function in quoters.items()
|
||||||
|
if getattr(table, key)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_metrics(self, semantic_view: SemanticView) -> set[SemanticMetric]:
|
||||||
|
quoted_semantic_view_name = self.quote_table(
|
||||||
|
Table(semantic_view.name),
|
||||||
|
self.engine.dialect,
|
||||||
|
)
|
||||||
|
sql = f"""
|
||||||
|
DESC SEMANTIC VIEW {quoted_semantic_view_name}
|
||||||
|
->> SELECT "object_name", "property", "property_value"
|
||||||
|
FROM $1
|
||||||
|
WHERE
|
||||||
|
"object_kind" = 'METRIC' AND
|
||||||
|
"property" IN ('DATA_TYPE', 'TABLE');
|
||||||
|
""" # noqa: S608 (semantic_view.name is quoted)
|
||||||
|
rows = self.execute(sql)
|
||||||
|
|
||||||
|
metrics: set[SemanticMetric] = set()
|
||||||
|
for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
|
||||||
|
attributes = defaultdict(set)
|
||||||
|
for row in group:
|
||||||
|
attributes[row["property"]].add(row["property_value"])
|
||||||
|
|
||||||
|
table = next(iter(attributes["TABLE"]))
|
||||||
|
metric_name = table + "." + name
|
||||||
|
type_ = self.get_type(next(iter(attributes["DATA_TYPE"])))
|
||||||
|
sql = self.engine.dialect.identifier_preparer.quote(metric_name)
|
||||||
|
tables = frozenset(attributes["TABLE"])
|
||||||
|
join_columns = frozenset()
|
||||||
|
|
||||||
|
metrics.add(SemanticMetric(metric_name, type_, sql, tables, join_columns))
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def get_dimensions(self, semantic_view: SemanticView) -> set[SemanticDimension]:
|
||||||
|
quoted_semantic_view_name = self.quote_table(
|
||||||
|
Table(semantic_view.name),
|
||||||
|
self.engine.dialect,
|
||||||
|
)
|
||||||
|
sql = f"""
|
||||||
|
DESC SEMANTIC VIEW {quoted_semantic_view_name}
|
||||||
|
->> SELECT "object_name", "property", "property_value"
|
||||||
|
FROM $1
|
||||||
|
WHERE
|
||||||
|
"object_kind" = 'DIMENSION' AND
|
||||||
|
"property" IN ('DATA_TYPE', 'TABLE');
|
||||||
|
""" # noqa: S608 (semantic_view.name is quoted)
|
||||||
|
rows = self.execute(sql)
|
||||||
|
|
||||||
|
dimensions: set[SemanticDimension] = set()
|
||||||
|
for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
|
||||||
|
attributes = defaultdict(set)
|
||||||
|
for row in group:
|
||||||
|
attributes[row["property"]].add(row["property_value"])
|
||||||
|
|
||||||
|
table = next(iter(attributes["TABLE"]))
|
||||||
|
dimension_name = table + "." + name
|
||||||
|
column = SemanticColumn(SemanticTable(table), name)
|
||||||
|
type_ = self.get_type(next(iter(attributes["DATA_TYPE"])))
|
||||||
|
|
||||||
|
dimensions.add(SemanticDimension(column, dimension_name, type_))
|
||||||
|
|
||||||
|
return dimensions
|
||||||
|
|
||||||
|
def get_valid_metrics(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
metrics: set[SemanticMetric],
|
||||||
|
dimensions: set[SemanticDimension],
|
||||||
|
) -> set[SemanticMetric]:
|
||||||
|
# all metrics and dimensions are valid inside a given semantic view
|
||||||
|
return self.get_metrics(semantic_view)
|
||||||
|
|
||||||
|
def get_valid_dimensions(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
metrics: set[SemanticMetric],
|
||||||
|
dimensions: set[SemanticDimension],
|
||||||
|
) -> set[SemanticDimension]:
|
||||||
|
# all metrics and dimensions are valid inside a given semantic view
|
||||||
|
return self.get_dimensions(semantic_view)
|
||||||
|
|
||||||
|
def get_query(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
metrics: set[SemanticMetric],
|
||||||
|
dimensions: set[SemanticDimension],
|
||||||
|
filters: set[SemanticFilter],
|
||||||
|
sort: SemanticSort = NoSort,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> SemanticQuery:
|
||||||
|
ast = self.build_query(
|
||||||
|
semantic_view,
|
||||||
|
metrics,
|
||||||
|
dimensions,
|
||||||
|
filters,
|
||||||
|
sort,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
)
|
||||||
|
return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))
|
||||||
|
|
||||||
|
def build_query(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
metrics: set[SemanticMetric],
|
||||||
|
dimensions: set[SemanticDimension],
|
||||||
|
filters: set[SemanticFilter],
|
||||||
|
sort: SemanticSort = NoSort,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> exp.Select:
|
||||||
|
semantic_view = exp.SemanticView(
|
||||||
|
this=exp.Table(this=exp.Identifier(this=semantic_view.name, quoted=True)),
|
||||||
|
dimensions=[
|
||||||
|
exp.Column(
|
||||||
|
this=exp.Identifier(this=dimension.column.name, quoted=True),
|
||||||
|
table=exp.Identifier(
|
||||||
|
this=dimension.column.relation.name,
|
||||||
|
quoted=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for dimension in dimensions
|
||||||
|
],
|
||||||
|
metrics=[
|
||||||
|
exp.Column(
|
||||||
|
this=exp.Identifier(this=column, quoted=True),
|
||||||
|
table=exp.Identifier(this=table, quoted=True),
|
||||||
|
)
|
||||||
|
for table, column in (
|
||||||
|
metric.name.split(".", 1)
|
||||||
|
for metric in metrics
|
||||||
|
if "." in metric.name
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
query = exp.Select(
|
||||||
|
expressions=[exp.Star()],
|
||||||
|
**{"from": exp.From(this=exp.Table(this=semantic_view))},
|
||||||
|
)
|
||||||
|
|
||||||
|
if sort.items:
|
||||||
|
order = [
|
||||||
|
exp.Ordered(
|
||||||
|
this=exp.Column(this=exp.Identifier(this=item.field.name)),
|
||||||
|
desc=item.direction == SortDirectionEnum.DESC,
|
||||||
|
nulls_first=item.nulls_first,
|
||||||
|
)
|
||||||
|
for item in sort.items
|
||||||
|
]
|
||||||
|
query.args["order"] = exp.Order(expressions=order)
|
||||||
|
|
||||||
|
if offset:
|
||||||
|
query = query.offset(offset)
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
query = query.limit(limit)
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
def get_query_from_standard_sql(self, sql: str) -> SemanticQuery:
|
||||||
|
"""
|
||||||
|
Convert the Explore query into a proper query.
|
||||||
|
|
||||||
|
Explore will produce a pseudo-SQL query that references metrics and dimensions
|
||||||
|
as if they were columns in a table. This method replaces the table name with a
|
||||||
|
call to `SEMANTIC_VIEW`, and removes the `GROUP BY` clause, since all the
|
||||||
|
aggregations happen inside the `SEMANTIC_VIEW` call.
|
||||||
|
"""
|
||||||
|
ast = parse_one(sql, "snowflake")
|
||||||
|
table = ast.find(exp.Table)
|
||||||
|
if not table:
|
||||||
|
return SemanticQuery(sql=sql)
|
||||||
|
|
||||||
|
semantic_views = self.get_semantic_views()
|
||||||
|
if table.name not in {semantic_view.name for semantic_view in semantic_views}:
|
||||||
|
return SemanticQuery(sql=sql)
|
||||||
|
|
||||||
|
# collect all metric and dimensions
|
||||||
|
semantic_view = SemanticView(table.name)
|
||||||
|
all_metrics = self.get_metrics(semantic_view)
|
||||||
|
all_dimensions = self.get_dimensions(semantic_view)
|
||||||
|
|
||||||
|
# collect metrics and dimensions used in the query
|
||||||
|
columns = {column.name for column in ast.find_all(exp.Column)}
|
||||||
|
metrics = [metric for metric in all_metrics if metric.name in columns]
|
||||||
|
dimensions = [
|
||||||
|
dimension for dimension in all_dimensions if dimension.name in columns
|
||||||
|
]
|
||||||
|
|
||||||
|
# now replace table with a call to `SEMANTIC_VIEW`
|
||||||
|
udtf = exp.Table(
|
||||||
|
this=exp.SemanticView(
|
||||||
|
this=exp.Table(
|
||||||
|
this=exp.Identifier(this=semantic_view.name, quoted=True)
|
||||||
|
),
|
||||||
|
metrics=[
|
||||||
|
exp.Column(
|
||||||
|
this=exp.Identifier(this=column, quoted=True),
|
||||||
|
table=exp.Identifier(this=table, quoted=True),
|
||||||
|
)
|
||||||
|
for table, column in (
|
||||||
|
metric.name.split(".", 1)
|
||||||
|
for metric in metrics
|
||||||
|
if "." in metric.name
|
||||||
|
)
|
||||||
|
],
|
||||||
|
dimensions=[
|
||||||
|
exp.Column(
|
||||||
|
this=exp.Identifier(this=column, quoted=True),
|
||||||
|
table=exp.Identifier(this=table, quoted=True),
|
||||||
|
)
|
||||||
|
for table, column in (
|
||||||
|
dimension.name.split(".", 1)
|
||||||
|
for dimension in dimensions
|
||||||
|
if "." in dimension.name
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
alias=exp.TableAlias(
|
||||||
|
this=exp.Identifier(this="table_alias", quoted=False),
|
||||||
|
columns=[
|
||||||
|
exp.Identifier(this=column.name, quoted=True)
|
||||||
|
for column in metrics + dimensions
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
table.replace(udtf)
|
||||||
|
|
||||||
|
# remove group by, since aggregations are done inside the `SEMANTIC_VIEW` call
|
||||||
|
del ast.args["group"]
|
||||||
|
|
||||||
|
print("BETO")
|
||||||
|
print(ast.sql(dialect="snowflake", pretty=True))
|
||||||
|
return SemanticQuery(sql=ast.sql(dialect="snowflake", pretty=True))
|
||||||
|
|
||||||
|
|
||||||
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||||
engine = "snowflake"
|
engine = "snowflake"
|
||||||
engine_name = "Snowflake"
|
engine_name = "Snowflake"
|
||||||
@@ -90,6 +417,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
default_driver = "snowflake"
|
default_driver = "snowflake"
|
||||||
sqlalchemy_uri_placeholder = "snowflake://"
|
sqlalchemy_uri_placeholder = "snowflake://"
|
||||||
|
|
||||||
|
semantic_layer = SnowflakeSemanticLayer
|
||||||
|
|
||||||
supports_dynamic_schema = True
|
supports_dynamic_schema = True
|
||||||
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True
|
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True
|
||||||
|
|
||||||
|
|||||||
340
superset/extensions/semantic_layer.py
Normal file
340
superset/extensions/semantic_layer.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
import enum
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import timedelta
|
||||||
|
from functools import total_ordering
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from sqlalchemy import types as sqltypes
|
||||||
|
from sqlalchemy.engine.base import Engine
|
||||||
|
|
||||||
|
|
||||||
|
class Type:
|
||||||
|
"""
|
||||||
|
Base class for types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class INTEGER(Type):
|
||||||
|
"""
|
||||||
|
Represents an integer type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class NUMBER(Type):
|
||||||
|
"""
|
||||||
|
Represents a number type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DECIMAL(Type):
|
||||||
|
"""
|
||||||
|
Represents a decimal type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class STRING(Type):
|
||||||
|
"""
|
||||||
|
Represents a string type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BOOLEAN(Type):
|
||||||
|
"""
|
||||||
|
Represents a boolean type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DATE(Type):
|
||||||
|
"""
|
||||||
|
Represents a date type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TIME(Type):
|
||||||
|
"""
|
||||||
|
Represents a time type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DATETIME(DATE, TIME):
|
||||||
|
"""
|
||||||
|
Represents a datetime type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class INTERVAL(Type):
|
||||||
|
"""
|
||||||
|
Represents an interval type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class OBJECT(Type):
|
||||||
|
"""
|
||||||
|
Represents an object type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BINARY(Type):
|
||||||
|
"""
|
||||||
|
Represents a binary type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SemanticView:
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Relation:
|
||||||
|
name: str
|
||||||
|
schema: str | None = None
|
||||||
|
catalog: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Table:
|
||||||
|
name: str
|
||||||
|
schema: str | None = None
|
||||||
|
catalog: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class View:
|
||||||
|
name: str
|
||||||
|
sql: str
|
||||||
|
schema: str | None = None
|
||||||
|
catalog: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Virtual:
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Metric:
|
||||||
|
name: str
|
||||||
|
type: type[Type]
|
||||||
|
sql: str
|
||||||
|
tables: frozenset[Table]
|
||||||
|
join_columns: frozenset[str]
|
||||||
|
|
||||||
|
|
||||||
|
@total_ordering
|
||||||
|
class ComparableEnum(enum.Enum):
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, enum.Enum):
|
||||||
|
return self.value == other.value
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __lt__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, enum.Enum):
|
||||||
|
return self.value < other.value
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.__class__, self.name))
|
||||||
|
|
||||||
|
|
||||||
|
class TimeGrain(ComparableEnum):
|
||||||
|
second = timedelta(seconds=1)
|
||||||
|
minute = timedelta(minutes=1)
|
||||||
|
hour = timedelta(hours=1)
|
||||||
|
|
||||||
|
|
||||||
|
class DateGrain(ComparableEnum):
|
||||||
|
day = timedelta(days=1)
|
||||||
|
week = timedelta(weeks=1)
|
||||||
|
month = timedelta(days=30)
|
||||||
|
quarter = timedelta(days=90)
|
||||||
|
year = timedelta(days=365)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Column:
|
||||||
|
relation: Table | View | Virtual
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Dimension:
|
||||||
|
column: Column
|
||||||
|
name: str
|
||||||
|
type: type[Type]
|
||||||
|
grain: TimeGrain | DateGrain | None = None
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
metadata = f"[{self.grain.name}]" if self.grain else ""
|
||||||
|
return f"{self.type.__name__} {self.name} {metadata}".strip()
|
||||||
|
|
||||||
|
|
||||||
|
class FilterTypeEnum(enum.Enum):
|
||||||
|
WHERE = enum.auto()
|
||||||
|
HAVING = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Filter:
|
||||||
|
type: FilterTypeEnum
|
||||||
|
expression: str
|
||||||
|
|
||||||
|
|
||||||
|
class SortDirectionEnum(enum.Enum):
|
||||||
|
ASC = enum.auto()
|
||||||
|
DESC = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SortField:
|
||||||
|
field: Metric | Dimension
|
||||||
|
direction: SortDirectionEnum
|
||||||
|
nulls_first: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Sort:
|
||||||
|
items: list[SortField]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Query:
|
||||||
|
sql: str
|
||||||
|
|
||||||
|
|
||||||
|
NoSort = Sort(items=[])
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SemanticLayer(Protocol):
|
||||||
|
"""
|
||||||
|
A generic protocol for semantic layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, engine: Engine) -> None: ...
|
||||||
|
|
||||||
|
def get_semantic_views(self) -> set[SemanticView]:
|
||||||
|
"""
|
||||||
|
Return a set of the semantic views.
|
||||||
|
|
||||||
|
A semantic view is an organizational group of metrics and dimensions. It's not a
|
||||||
|
logical grouping, since metrics and dimensions from a given semantic view might
|
||||||
|
not be compatible. An implementation might expose a single semantic view for
|
||||||
|
exploration of available metric and dimesnions, and smaller curated semantic
|
||||||
|
views that are domain specific.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_metrics(self, semantic_view: SemanticView) -> set[Metric]:
|
||||||
|
"""
|
||||||
|
Return a set of metrics from a given semantic views.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_dimensions(self, semantic_view: SemanticView) -> set[Dimension]:
|
||||||
|
"""
|
||||||
|
Return a set of dimensions from a given semantic views.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_valid_metrics(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
metrics: set[Metric],
|
||||||
|
dimensions: set[Dimension],
|
||||||
|
) -> set[Metric]:
|
||||||
|
"""
|
||||||
|
Return compatible metrics for the given metrics and dimensions.
|
||||||
|
|
||||||
|
For metrics to be valid they must be compatible with all the provided
|
||||||
|
dimensions.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_valid_dimensions(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
metrics: set[Metric],
|
||||||
|
dimensions: set[Dimension],
|
||||||
|
) -> set[Dimension]:
|
||||||
|
"""
|
||||||
|
Return compatible dimensions for the given metrics.
|
||||||
|
|
||||||
|
For dimensions to be valid they must be compatible with all the provided
|
||||||
|
metrics.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_query(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
metrics: set[Metric],
|
||||||
|
dimensions: set[Dimension],
|
||||||
|
# populations: set[Population],
|
||||||
|
filters: set[Filter],
|
||||||
|
sort: Sort = NoSort,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> Query:
|
||||||
|
"""
|
||||||
|
Build a SQL query from the given metrics, dimensions, filters, and sort order.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_query_from_standard_sql(
|
||||||
|
self,
|
||||||
|
semantic_view: SemanticView,
|
||||||
|
sql: str,
|
||||||
|
) -> Query:
|
||||||
|
"""
|
||||||
|
Build a SQL query from a pseudo-query referencing metrics and dimensions.
|
||||||
|
|
||||||
|
For example, given `metric1` having the expression `COUNT(*)`, this query:
|
||||||
|
|
||||||
|
SELECT metric1, dim1
|
||||||
|
FROM semantic_layer
|
||||||
|
GROUP BY dim1
|
||||||
|
|
||||||
|
Becomes:
|
||||||
|
|
||||||
|
SELECT metric1, dim1
|
||||||
|
FROM (
|
||||||
|
SELECT COUNT(*) AS metric1, dim1
|
||||||
|
FROM fact_table
|
||||||
|
JOIN dim_table
|
||||||
|
ON fact_table.dim_id = dim_table.id
|
||||||
|
GROUP BY dim1
|
||||||
|
) AS semantic_view
|
||||||
|
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
TYPE_MAPPING: dict[Type, type[sqltypes.TypeEngine]] = {
|
||||||
|
# Numeric types
|
||||||
|
INTEGER: sqltypes.Integer,
|
||||||
|
NUMBER: sqltypes.Numeric,
|
||||||
|
DECIMAL: sqltypes.DECIMAL,
|
||||||
|
# String types
|
||||||
|
STRING: sqltypes.String,
|
||||||
|
# Boolean type
|
||||||
|
BOOLEAN: sqltypes.Boolean,
|
||||||
|
# Date/time types
|
||||||
|
DATE: sqltypes.Date,
|
||||||
|
TIME: sqltypes.Time,
|
||||||
|
DATETIME: sqltypes.DateTime,
|
||||||
|
INTERVAL: sqltypes.Interval,
|
||||||
|
# Complex types
|
||||||
|
OBJECT: sqltypes.JSON,
|
||||||
|
BINARY: sqltypes.LargeBinary,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_sqla_type_from_dimension_type(
|
||||||
|
dimension_type: Type,
|
||||||
|
) -> sqltypes.TypeEngine:
|
||||||
|
"""
|
||||||
|
Get the SQLAlchemy type corresponding to the given dimension type.
|
||||||
|
"""
|
||||||
|
return TYPE_MAPPING.get(dimension_type, sqltypes.String)()
|
||||||
@@ -126,7 +126,9 @@ class ConfigurationMethod(StrEnum):
|
|||||||
DYNAMIC_FORM = "dynamic_form"
|
DYNAMIC_FORM = "dynamic_form"
|
||||||
|
|
||||||
|
|
||||||
class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
|
class Database(
|
||||||
|
Model, AuditMixinNullable, ImportExportMixin
|
||||||
|
): # pylint: disable=too-many-public-methods
|
||||||
"""An ORM object that stores Database related information"""
|
"""An ORM object that stores Database related information"""
|
||||||
|
|
||||||
__tablename__ = "dbs"
|
__tablename__ = "dbs"
|
||||||
@@ -400,9 +402,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
return (
|
return (
|
||||||
username
|
username
|
||||||
if (username := get_username())
|
if (username := get_username())
|
||||||
else object_url.username
|
else object_url.username if self.impersonate_user else None
|
||||||
if self.impersonate_user
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -987,7 +987,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
schema=table.schema,
|
schema=table.schema,
|
||||||
) as inspector:
|
) as inspector:
|
||||||
return self.db_engine_spec.get_columns(
|
return self.db_engine_spec.get_columns(
|
||||||
inspector, table, self.schema_options
|
self,
|
||||||
|
inspector,
|
||||||
|
table,
|
||||||
|
self.schema_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_metrics(
|
def get_metrics(
|
||||||
@@ -1076,9 +1079,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
return self.perm
|
return self.perm
|
||||||
|
|
||||||
def has_table(self, table: Table) -> bool:
|
def has_table(self, table: Table) -> bool:
|
||||||
with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
|
with self.get_inspector(
|
||||||
# do not pass "" as an empty schema; force null
|
catalog=table.catalog,
|
||||||
return engine.has_table(table.table, table.schema or None)
|
schema=table.schema,
|
||||||
|
) as inspector:
|
||||||
|
return self.db_engine_spec.has_table(self, inspector, table)
|
||||||
|
|
||||||
def has_view(self, table: Table) -> bool:
|
def has_view(self, table: Table) -> bool:
|
||||||
with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
|
with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
|
||||||
|
|||||||
Reference in New Issue
Block a user