mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
feat: removing dup logic in sqla/models.py and models/helpers.py (#34177)
This commit is contained in:
committed by
GitHub
parent
8a8248b575
commit
fe9eef9198
@@ -20,15 +20,12 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import dataclasses
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
||||
import dateutil.parser
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sqlalchemy as sa
|
||||
from flask_appbuilder import Model
|
||||
@@ -78,7 +75,6 @@ from superset.connectors.sqla.utils import (
|
||||
get_physical_table_metadata,
|
||||
get_virtual_table_metadata,
|
||||
)
|
||||
from superset.constants import EMPTY_STRING, NULL_STRING
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
|
||||
from superset.exceptions import (
|
||||
ColumnNotFoundException,
|
||||
@@ -108,8 +104,6 @@ from superset.sql.parse import Table
|
||||
from superset.superset_typing import (
|
||||
AdhocColumn,
|
||||
AdhocMetric,
|
||||
FilterValue,
|
||||
FilterValues,
|
||||
Metric,
|
||||
QueryObjectDict,
|
||||
ResultSetColumnType,
|
||||
@@ -506,66 +500,6 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def filter_values_handler( # pylint: disable=too-many-arguments # noqa: C901
|
||||
values: FilterValues | None,
|
||||
operator: str,
|
||||
target_generic_type: utils.GenericDataType,
|
||||
target_native_type: str | None = None,
|
||||
is_list_target: bool = False,
|
||||
db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
|
||||
db_extra: dict[str, Any] | None = None,
|
||||
) -> FilterValues | None:
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
def handle_single_value(value: FilterValue | None) -> FilterValue | None:
|
||||
if operator == utils.FilterOperator.TEMPORAL_RANGE:
|
||||
return value
|
||||
if (
|
||||
isinstance(value, (float, int))
|
||||
and target_generic_type == utils.GenericDataType.TEMPORAL
|
||||
and target_native_type is not None
|
||||
and db_engine_spec is not None
|
||||
):
|
||||
value = db_engine_spec.convert_dttm(
|
||||
target_type=target_native_type,
|
||||
dttm=datetime.utcfromtimestamp(value / 1000),
|
||||
db_extra=db_extra,
|
||||
)
|
||||
value = literal_column(value)
|
||||
if isinstance(value, str):
|
||||
value = value.strip("\t\n")
|
||||
|
||||
if (
|
||||
target_generic_type == utils.GenericDataType.NUMERIC
|
||||
and operator
|
||||
not in {
|
||||
utils.FilterOperator.ILIKE,
|
||||
utils.FilterOperator.LIKE,
|
||||
}
|
||||
):
|
||||
# For backwards compatibility and edge cases
|
||||
# where a column data type might have changed
|
||||
return utils.cast_to_num(value)
|
||||
if value == NULL_STRING:
|
||||
return None
|
||||
if value == EMPTY_STRING:
|
||||
return ""
|
||||
if target_generic_type == utils.GenericDataType.BOOLEAN:
|
||||
return utils.cast_to_boolean(value)
|
||||
return value
|
||||
|
||||
if isinstance(values, (list, tuple)):
|
||||
values = [handle_single_value(v) for v in values] # type: ignore
|
||||
else:
|
||||
values = handle_single_value(values)
|
||||
if is_list_target and not isinstance(values, (tuple, list)):
|
||||
values = [values] # type: ignore
|
||||
elif not is_list_target and isinstance(values, (tuple, list)):
|
||||
values = values[0] if values else None
|
||||
return values
|
||||
|
||||
def external_metadata(self) -> list[ResultSetColumnType]:
|
||||
"""Returns column information from the external system"""
|
||||
raise NotImplementedError()
|
||||
@@ -1227,29 +1161,10 @@ class SqlaTable(
|
||||
def db_extra(self) -> dict[str, Any]:
|
||||
return self.database.get_extra()
|
||||
|
||||
@staticmethod
|
||||
def _apply_cte(sql: str, cte: str | None) -> str:
|
||||
"""
|
||||
Append a CTE before the SELECT statement if defined
|
||||
|
||||
:param sql: SELECT statement
|
||||
:param cte: CTE statement
|
||||
:return:
|
||||
"""
|
||||
if cte:
|
||||
sql = f"{cte}\n{sql}"
|
||||
return sql
|
||||
|
||||
@property
|
||||
def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]:
|
||||
return self.database.db_engine_spec
|
||||
|
||||
@property
|
||||
def changed_by_name(self) -> str:
|
||||
if not self.changed_by:
|
||||
return ""
|
||||
return str(self.changed_by)
|
||||
|
||||
@property
|
||||
def connection(self) -> str:
|
||||
return str(self.database)
|
||||
@@ -1452,11 +1367,6 @@ class SqlaTable(
|
||||
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
|
||||
return get_template_processor(table=self, database=self.database, **kwargs)
|
||||
|
||||
def get_query_str(self, query_obj: QueryObjectDict) -> str:
|
||||
query_str_ext = self.get_query_str_extended(query_obj)
|
||||
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
|
||||
return ";\n\n".join(all_queries) + ";"
|
||||
|
||||
def get_sqla_table(self) -> TableClause:
|
||||
tbl = table(self.table_name)
|
||||
if self.schema:
|
||||
@@ -1588,38 +1498,6 @@ class SqlaTable(
|
||||
)
|
||||
return self.make_sqla_column_compatible(sqla_column, label)
|
||||
|
||||
def make_orderby_compatible(
|
||||
self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
|
||||
) -> None:
|
||||
"""
|
||||
If needed, make sure aliases for selected columns are not used in
|
||||
`ORDER BY`.
|
||||
|
||||
In some databases (e.g. Presto), `ORDER BY` clause is not able to
|
||||
automatically pick the source column if a `SELECT` clause alias is named
|
||||
the same as a source column. In this case, we update the SELECT alias to
|
||||
another name to avoid the conflict.
|
||||
"""
|
||||
if self.db_engine_spec.allows_alias_to_source_column:
|
||||
return
|
||||
|
||||
def is_alias_used_in_orderby(col: ColumnElement) -> bool:
|
||||
if not isinstance(col, Label):
|
||||
return False
|
||||
regexp = re.compile(f"\\(.*\\b{re.escape(col.name)}\\b.*\\)", re.IGNORECASE)
|
||||
return any(regexp.search(str(x)) for x in orderby_exprs)
|
||||
|
||||
# Iterate through selected columns, if column alias appears in orderby
|
||||
# use another `alias`. The final output columns will still use the
|
||||
# original names, because they are updated by `labels_expected` after
|
||||
# querying.
|
||||
for col in select_exprs:
|
||||
if is_alias_used_in_orderby(col):
|
||||
col.name = f"{col.name}__"
|
||||
|
||||
def text(self, clause: str) -> TextClause:
|
||||
return self.db_engine_spec.get_text_clause(clause)
|
||||
|
||||
def _get_series_orderby(
|
||||
self,
|
||||
series_limit_metric: Metric,
|
||||
@@ -1643,43 +1521,6 @@ class SqlaTable(
|
||||
)
|
||||
return ob
|
||||
|
||||
def _normalize_prequery_result_type(
|
||||
self,
|
||||
row: pd.Series,
|
||||
dimension: str,
|
||||
columns_by_name: dict[str, TableColumn],
|
||||
) -> str | int | float | bool | Text:
|
||||
"""
|
||||
Convert a prequery result type to its equivalent Python type.
|
||||
|
||||
Some databases like Druid will return timestamps as strings, but do not perform
|
||||
automatic casting when comparing these strings to a timestamp. For cases like
|
||||
this we convert the value via the appropriate SQL transform.
|
||||
|
||||
:param row: A prequery record
|
||||
:param dimension: The dimension name
|
||||
:param columns_by_name: The mapping of columns by name
|
||||
:return: equivalent primitive python type
|
||||
"""
|
||||
|
||||
value = row[dimension]
|
||||
|
||||
if isinstance(value, np.generic):
|
||||
value = value.item()
|
||||
|
||||
column_ = columns_by_name.get(dimension)
|
||||
db_extra: dict[str, Any] = self.database.get_extra()
|
||||
|
||||
if column_ and column_.type and column_.is_temporal and isinstance(value, str):
|
||||
sql = self.db_engine_spec.convert_dttm(
|
||||
column_.type, dateutil.parser.parse(value), db_extra=db_extra
|
||||
)
|
||||
|
||||
if sql:
|
||||
value = self.text(sql)
|
||||
|
||||
return value
|
||||
|
||||
def _get_top_groups(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
@@ -2053,6 +1894,14 @@ class SqlaTable(
|
||||
session = inspect(self).session # pylint: disable=disallowed-name
|
||||
self.database = session.query(Database).filter_by(id=self.database_id).one()
|
||||
|
||||
def get_query_str(self, query_obj: QueryObjectDict) -> str:
|
||||
"""Returns a query as a string using ExploreMixin implementation"""
|
||||
return ExploreMixin.get_query_str(self, query_obj)
|
||||
|
||||
def text(self, clause: str) -> TextClause:
|
||||
"""Returns a text clause using ExploreMixin implementation"""
|
||||
return ExploreMixin.text(self, clause)
|
||||
|
||||
|
||||
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
|
||||
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
|
||||
|
||||
@@ -919,10 +919,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
if isinstance(value, np.generic):
|
||||
value = value.item()
|
||||
|
||||
column_ = columns_by_name[dimension]
|
||||
column_ = columns_by_name.get(dimension)
|
||||
db_extra: dict[str, Any] = self.database.get_extra()
|
||||
|
||||
if isinstance(column_, dict):
|
||||
if column_ is None:
|
||||
# Column not found, return value as-is
|
||||
pass
|
||||
elif isinstance(column_, dict):
|
||||
if (
|
||||
column_.get("type")
|
||||
and column_.get("is_temporal")
|
||||
@@ -941,7 +944,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
)
|
||||
|
||||
if sql:
|
||||
value = self.text(sql)
|
||||
value = self.db_engine_spec.get_text_clause(sql)
|
||||
return value
|
||||
|
||||
def make_orderby_compatible(
|
||||
|
||||
Reference in New Issue
Block a user