feat: removing dup logic in sqla/models.py and models/helpers.py (#34177)

This commit is contained in:
Maxime Beauchemin
2025-07-15 14:02:57 -07:00
committed by GitHub
parent 8a8248b575
commit fe9eef9198
2 changed files with 14 additions and 162 deletions

View File

@@ -20,15 +20,12 @@ from __future__ import annotations
import builtins
import dataclasses
import logging
import re
from collections import defaultdict
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Callable, cast, Optional, Union
import dateutil.parser
import numpy as np
import pandas as pd
import sqlalchemy as sa
from flask_appbuilder import Model
@@ -78,7 +75,6 @@ from superset.connectors.sqla.utils import (
get_physical_table_metadata,
get_virtual_table_metadata,
)
from superset.constants import EMPTY_STRING, NULL_STRING
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import (
ColumnNotFoundException,
@@ -108,8 +104,6 @@ from superset.sql.parse import Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
FilterValue,
FilterValues,
Metric,
QueryObjectDict,
ResultSetColumnType,
@@ -506,66 +500,6 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
return data
@staticmethod
def filter_values_handler( # pylint: disable=too-many-arguments # noqa: C901
values: FilterValues | None,
operator: str,
target_generic_type: utils.GenericDataType,
target_native_type: str | None = None,
is_list_target: bool = False,
db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
db_extra: dict[str, Any] | None = None,
) -> FilterValues | None:
if values is None:
return None
def handle_single_value(value: FilterValue | None) -> FilterValue | None:
if operator == utils.FilterOperator.TEMPORAL_RANGE:
return value
if (
isinstance(value, (float, int))
and target_generic_type == utils.GenericDataType.TEMPORAL
and target_native_type is not None
and db_engine_spec is not None
):
value = db_engine_spec.convert_dttm(
target_type=target_native_type,
dttm=datetime.utcfromtimestamp(value / 1000),
db_extra=db_extra,
)
value = literal_column(value)
if isinstance(value, str):
value = value.strip("\t\n")
if (
target_generic_type == utils.GenericDataType.NUMERIC
and operator
not in {
utils.FilterOperator.ILIKE,
utils.FilterOperator.LIKE,
}
):
# For backwards compatibility and edge cases
# where a column data type might have changed
return utils.cast_to_num(value)
if value == NULL_STRING:
return None
if value == EMPTY_STRING:
return ""
if target_generic_type == utils.GenericDataType.BOOLEAN:
return utils.cast_to_boolean(value)
return value
if isinstance(values, (list, tuple)):
values = [handle_single_value(v) for v in values] # type: ignore
else:
values = handle_single_value(values)
if is_list_target and not isinstance(values, (tuple, list)):
values = [values] # type: ignore
elif not is_list_target and isinstance(values, (tuple, list)):
values = values[0] if values else None
return values
def external_metadata(self) -> list[ResultSetColumnType]:
"""Returns column information from the external system"""
raise NotImplementedError()
@@ -1227,29 +1161,10 @@ class SqlaTable(
def db_extra(self) -> dict[str, Any]:
return self.database.get_extra()
@staticmethod
def _apply_cte(sql: str, cte: str | None) -> str:
"""
Append a CTE before the SELECT statement if defined
:param sql: SELECT statement
:param cte: CTE statement
:return:
"""
if cte:
sql = f"{cte}\n{sql}"
return sql
@property
def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]:
return self.database.db_engine_spec
@property
def changed_by_name(self) -> str:
if not self.changed_by:
return ""
return str(self.changed_by)
@property
def connection(self) -> str:
return str(self.database)
@@ -1452,11 +1367,6 @@ class SqlaTable(
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(table=self, database=self.database, **kwargs)
def get_query_str(self, query_obj: QueryObjectDict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"
def get_sqla_table(self) -> TableClause:
tbl = table(self.table_name)
if self.schema:
@@ -1588,38 +1498,6 @@ class SqlaTable(
)
return self.make_sqla_column_compatible(sqla_column, label)
def make_orderby_compatible(
self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
"""
If needed, make sure aliases for selected columns are not used in
`ORDER BY`.
In some databases (e.g. Presto), `ORDER BY` clause is not able to
automatically pick the source column if a `SELECT` clause alias is named
the same as a source column. In this case, we update the SELECT alias to
another name to avoid the conflict.
"""
if self.db_engine_spec.allows_alias_to_source_column:
return
def is_alias_used_in_orderby(col: ColumnElement) -> bool:
if not isinstance(col, Label):
return False
regexp = re.compile(f"\\(.*\\b{re.escape(col.name)}\\b.*\\)", re.IGNORECASE)
return any(regexp.search(str(x)) for x in orderby_exprs)
# Iterate through selected columns, if column alias appears in orderby
# use another `alias`. The final output columns will still use the
# original names, because they are updated by `labels_expected` after
# querying.
for col in select_exprs:
if is_alias_used_in_orderby(col):
col.name = f"{col.name}__"
def text(self, clause: str) -> TextClause:
return self.db_engine_spec.get_text_clause(clause)
def _get_series_orderby(
self,
series_limit_metric: Metric,
@@ -1643,43 +1521,6 @@ class SqlaTable(
)
return ob
def _normalize_prequery_result_type(
self,
row: pd.Series,
dimension: str,
columns_by_name: dict[str, TableColumn],
) -> str | int | float | bool | Text:
"""
Convert a prequery result type to its equivalent Python type.
Some databases like Druid will return timestamps as strings, but do not perform
automatic casting when comparing these strings to a timestamp. For cases like
this we convert the value via the appropriate SQL transform.
:param row: A prequery record
:param dimension: The dimension name
:param columns_by_name: The mapping of columns by name
:return: equivalent primitive python type
"""
value = row[dimension]
if isinstance(value, np.generic):
value = value.item()
column_ = columns_by_name.get(dimension)
db_extra: dict[str, Any] = self.database.get_extra()
if column_ and column_.type and column_.is_temporal and isinstance(value, str):
sql = self.db_engine_spec.convert_dttm(
column_.type, dateutil.parser.parse(value), db_extra=db_extra
)
if sql:
value = self.text(sql)
return value
def _get_top_groups(
self,
df: pd.DataFrame,
@@ -2053,6 +1894,14 @@ class SqlaTable(
session = inspect(self).session # pylint: disable=disallowed-name
self.database = session.query(Database).filter_by(id=self.database_id).one()
def get_query_str(self, query_obj: QueryObjectDict) -> str:
"""Returns a query as a string using ExploreMixin implementation"""
return ExploreMixin.get_query_str(self, query_obj)
def text(self, clause: str) -> TextClause:
"""Returns a text clause using ExploreMixin implementation"""
return ExploreMixin.text(self, clause)
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)

View File

@@ -919,10 +919,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if isinstance(value, np.generic):
value = value.item()
column_ = columns_by_name[dimension]
column_ = columns_by_name.get(dimension)
db_extra: dict[str, Any] = self.database.get_extra()
if isinstance(column_, dict):
if column_ is None:
# Column not found, return value as-is
pass
elif isinstance(column_, dict):
if (
column_.get("type")
and column_.get("is_temporal")
@@ -941,7 +944,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
)
if sql:
value = self.text(sql)
value = self.db_engine_spec.get_text_clause(sql)
return value
def make_orderby_compatible(