Files
superset2/superset/models/helpers.py
2025-12-04 13:18:34 -05:00

3449 lines
130 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions"""
from __future__ import annotations
import builtins
import copy
import dataclasses
import logging
import re
import uuid
from collections.abc import Hashable
from datetime import datetime, timedelta
from typing import (
Any,
Callable,
cast,
NamedTuple,
Optional,
TYPE_CHECKING,
TypedDict,
Union,
)
import dateutil.parser
import humanize
import numpy as np
import pandas as pd
import pytz
import sqlalchemy as sa
import yaml
from flask import current_app as app, g
from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from flask_appbuilder.models.mixins import AuditMixin
from flask_appbuilder.security.sqla.models import User
from flask_babel import get_locale, lazy_gettext as _
from jinja2.exceptions import TemplateError
from markupsafe import escape, Markup
from pandas import DateOffset
from sqlalchemy import and_, Column, or_, UniqueConstraint
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Mapper, validates
from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy_utils import UUIDType
from superset import db, is_feature_enabled
from superset.advanced_data_type.types import AdvancedDataTypeResponse
from superset.common.db_query_status import QueryStatus
from superset.common.utils import dataframe_utils
from superset.common.utils.time_range_utils import (
get_since_until_from_query_object,
get_since_until_from_time_range,
)
from superset.constants import CacheRegion, EMPTY_STRING, NULL_STRING, TimeGrain
from superset.db_engine_specs.base import TimestampExpression
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
AdvancedDataTypeResponseError,
ColumnNotFoundException,
InvalidPostProcessingError,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetErrorException,
SupersetErrorsException,
SupersetSecurityException,
SupersetSyntaxErrorException,
)
from superset.extensions import feature_flag_manager
from superset.jinja_context import BaseTemplateProcessor
from superset.sql.parse import sanitize_clause, SQLScript, SQLStatement
from superset.superset_typing import (
AdhocMetric,
Column as ColumnTyping,
FilterValue,
FilterValues,
Metric,
OrderBy,
QueryObjectDict,
)
from superset.utils import core as utils, json
from superset.utils.core import (
DateColumn,
DTTM_ALIAS,
FilterOperator,
GenericDataType,
get_base_axis_labels,
get_column_name,
get_metric_names,
get_non_base_axis_columns,
get_user_id,
get_x_axis_label,
is_adhoc_column,
MediumText,
normalize_dttm_col,
QueryObjectFilterClause,
remove_duplicates,
SqlExpressionType,
TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
from superset.utils.dates import datetime_to_epoch
from superset.utils.rls import apply_rls
class ValidationResultDict(TypedDict):
    """Shape of the result object returned by ``validate_expression``."""

    # whether the expression was accepted
    valid: bool
    # error descriptions, each stored as a dict
    errors: list[dict[str, Any]]
if TYPE_CHECKING:
from superset.common.query_object import QueryObject
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.db_engine_specs import BaseEngineSpec
from superset.models.core import Database
logger = logging.getLogger(__name__)

# Alias name used for virtual tables in generated SQL
VIRTUAL_TABLE_ALIAS = "virtual_table"
# Alias name used for the series-limit subquery
SERIES_LIMIT_SUBQ_ALIAS = "series_limit"
# Column-name suffix used when joining time-offset results
OFFSET_JOIN_COLUMN_SUFFIX = "__offset_join_column_"
# Suffix applied to right-hand columns when joining time-offset results
R_SUFFIX = "__right_suffix"
class CachedTimeOffset(TypedDict):
    """Result bundle produced while processing time offsets.

    Carries the resulting dataframe plus the queries and cache keys gathered
    along the way.
    """

    df: pd.DataFrame  # resulting dataframe
    queries: list[str]  # SQL strings that were produced
    cache_keys: list[str | None]  # cache keys; entries may be None
# Keys of a QueryObjectDict that are forwarded as keyword arguments to
# get_sqla_query; every other key is filtered out before the call.
SQLA_QUERY_KEYS = {
    # selected columns / metrics
    "columns",
    "groupby",
    "metrics",
    "series_columns",
    # filtering
    "apply_fetch_values_predicate",
    "extras",
    "filter",
    # time range and granularity
    "from_dttm",
    "to_dttm",
    "inner_from_dttm",
    "inner_to_dttm",
    "granularity",
    "is_timeseries",
    "time_shift",
    # ordering
    "orderby",
    "order_desc",
    # limits
    "row_limit",
    "row_offset",
    "series_limit",
    "series_limit_metric",
    "timeseries_limit",
    "timeseries_limit_metric",
    "group_others_when_limit_reached",
    # row-count queries
    "is_rowcount",
}
def validate_adhoc_subquery(
    sql: str,
    database: Database,
    catalog: str | None,
    default_schema: str,
    engine: str,
) -> str:
    """
    Validate an adhoc SQL expression with respect to sub-queries.

    Sub-queries are rejected unless the ``ALLOW_ADHOC_SUBQUERY`` feature flag is
    enabled. When they are allowed, the parsed statement is passed through
    ``apply_rls`` so applicable RLS predicates get inserted.

    :param sql: adhoc sql expression
    :param database: database passed to ``apply_rls``
    :param catalog: catalog passed to ``apply_rls``
    :param default_schema: schema passed to ``apply_rls``
    :param engine: dialect name used to parse the statement
    :return: the (possibly RLS-modified) statement, formatted
    :raise SupersetSecurityException: if sql contains sub-queries while the
        ``ALLOW_ADHOC_SUBQUERY`` feature flag is disabled
    """
    statement = SQLStatement(sql, engine)
    if not statement.has_subquery():
        return statement.format()

    if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
        raise SupersetSecurityException(
            SupersetError(
                error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
                message=_("Custom SQL fields cannot contain sub-queries."),
                level=ErrorLevel.ERROR,
            )
        )

    # sub-queries are permitted: enforce RLS rules in any relevant tables
    apply_rls(database, catalog, default_schema, statement)
    return statement.format()
def json_to_dict(json_str: str) -> dict[Any, Any]:
    """
    Parse a JSON string, tolerating trailing commas before closing brackets.

    A comma followed by optional JSON whitespace (space, tab, CR, LF) and then
    ``}`` or ``]`` is removed before decoding. Both ``{"a": 1,}`` and
    multi-line documents with a dangling comma are therefore accepted.

    Note: the clean-up is purely textual, so a string value that itself
    contains such a sequence is altered as well.

    :param json_str: JSON text; a falsy value (``None``, ``""``) yields ``{}``
    :return: the decoded JSON value
    """
    if not json_str:
        return {}
    # ``*`` instead of ``+``: a comma placed directly before the closing
    # bracket must be dropped too, not only one followed by whitespace
    cleaned = re.sub(r",[ \t\r\n]*}", "}", json_str)
    cleaned = re.sub(r",[ \t\r\n]*\]", "]", cleaned)
    return json.loads(cleaned)
def convert_uuids(obj: Any) -> Any:
    """
    Recursively replace ``uuid.UUID`` values with their string form.

    Lists and dicts are rebuilt with converted contents; any other value is
    returned untouched. This lets the result be passed to ``yaml.safe_dump``.
    """
    if isinstance(obj, dict):
        return {key: convert_uuids(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return list(map(convert_uuids, obj))
    return str(obj) if isinstance(obj, uuid.UUID) else obj
class UUIDMixin:  # pylint: disable=too-few-public-methods
    """Adds a unique ``uuid`` column (binary storage, ``uuid4`` default)."""

    uuid = sa.Column(
        UUIDType(binary=True),
        primary_key=False,
        unique=True,
        default=uuid.uuid4,
    )

    @property
    def short_uuid(self) -> str:
        """The first eight characters of ``str(self.uuid)``."""
        full_text = str(self.uuid)
        return full_text[:8]
class ImportExportMixin(UUIDMixin):
    """
    Mixin adding dictionary-based import and export to SQLAlchemy models.

    Subclasses choose which columns and relationships take part through the
    ``export_*`` and ``extra_import_fields`` class attributes.
    """

    # The name of the attribute with the SQL Alchemy back reference
    export_parent: Optional[str] = None
    # List of (str) names of attributes with the SQL Alchemy forward references
    export_children: list[str] = []
    # The names of the attributes that are available for import and export
    export_fields: list[str] = []
    # Additional fields that should be imported, even though they were not exported
    extra_import_fields: list[str] = []

    __mapper__: Mapper

    @classmethod
    def _unique_constraints(cls) -> list[set[str]]:
        """Get all (single column and multi column) unique constraints"""
        # models are not required to declare __table_args__
        unique = [
            {c.name for c in u.columns}
            for u in getattr(cls, "__table_args__", ())  # type: ignore
            if isinstance(u, UniqueConstraint)
        ]
        unique.extend({c.name} for c in cls.__table__.columns if c.unique)  # type: ignore
        return unique

    @classmethod
    def parent_foreign_key_mappings(cls) -> dict[str, str]:
        """
        Map each local foreign-key column name to the parent column it references.

        :return: ``{local_column_name: remote_column_name}`` built from the
            ``export_parent`` relationship, or ``{}`` when there is none
        """
        parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
        if parent_rel:
            return {
                local.name: remote.name
                for (local, remote) in parent_rel.local_remote_pairs
            }
        return {}

    @classmethod
    def _child_class(cls, relationship_name: str) -> Any:
        """Return the mapped class targeted by relationship ``relationship_name``."""
        # ``.mapper`` resolves the target whether the relationship was declared
        # with a class, a mapper or a string name
        return cls.__mapper__.relationships[relationship_name].mapper.class_

    @classmethod
    def export_schema(
        cls, recursive: bool = True, include_parent_ref: bool = False
    ) -> dict[str, Any]:
        """
        Export schema as a dictionary.

        :param recursive: also describe the classes listed in ``export_children``
        :param include_parent_ref: keep the foreign-key columns of the parent
            relationship
        :return: mapping of column name to a type description; each child
            relationship maps to a one-element list holding the child schema
        """
        parent_excludes = set()
        if not include_parent_ref:
            parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
            if parent_ref:
                parent_excludes = {column.name for column in parent_ref.local_columns}

        def formatter(column: sa.Column) -> str:
            return (
                f"{str(column.type)} Default ({column.default.arg})"
                if column.default
                else str(column.type)
            )

        schema: dict[str, Any] = {
            column.name: formatter(column)
            for column in cls.__table__.columns  # type: ignore
            if (column.name in cls.export_fields and column.name not in parent_excludes)
        }
        if recursive:
            for column in cls.export_children:
                child_class = cls._child_class(column)
                schema[column] = [
                    child_class.export_schema(
                        recursive=recursive, include_parent_ref=include_parent_ref
                    )
                ]
        return schema

    @classmethod
    def import_from_dict(  # noqa: C901
        # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
        cls,
        dict_rep: dict[Any, Any],
        parent: Optional[Any] = None,
        recursive: bool = True,
        sync: Optional[list[str]] = None,
        allow_reparenting: bool = False,
    ) -> Any:
        """
        Import obj from a dictionary, creating or updating it.

        An existing row is searched using the model's unique constraints and,
        when a parent is given and ``allow_reparenting`` is false, the parent
        foreign keys. A match is updated in place; otherwise a new object is
        added to the session.

        Note: ``dict_rep`` is modified in place. Non-importable keys are removed
        and, when ``parent`` is given, the parent foreign keys are filled in.

        :param dict_rep: field values, possibly holding lists of child dicts
        :param parent: parent object whose keys are copied into ``dict_rep``
        :param recursive: also import the relationships in ``export_children``
        :param sync: child relationship names for which existing children of an
            existing object that were not part of this import are deleted
        :param allow_reparenting: do not restrict the lookup to ``parent``
        :return: the created or updated object
        :raises RuntimeError: if no parent is given and a parent key is missing
        :raises MultipleResultsFound: if more than one existing row matches
        """
        if sync is None:
            sync = []
        parent_refs = cls.parent_foreign_key_mappings()
        export_fields = (
            set(cls.export_fields)
            | set(cls.extra_import_fields)
            | set(parent_refs.keys())
            | {"uuid"}
        )
        new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep}
        unique_constraints = cls._unique_constraints()

        filters = []  # Using these filters to check if obj already exists

        # Remove fields that should not get imported
        for k in list(dict_rep):
            if k not in export_fields and k not in parent_refs:
                del dict_rep[k]

        if not parent:
            if cls.export_parent:
                for prnt in parent_refs.keys():
                    if prnt not in dict_rep:
                        raise RuntimeError(f"{cls.__name__}: Missing field {prnt}")
        else:
            # Set foreign keys to parent obj
            for k, v in parent_refs.items():
                dict_rep[k] = getattr(parent, v)
            if not allow_reparenting:
                # Add filter for parent obj
                filters.extend(
                    [getattr(cls, k) == dict_rep.get(k) for k in parent_refs.keys()]
                )

        # Add filter for unique constraints
        ucs = [
            and_(
                *[
                    getattr(cls, k) == dict_rep.get(k)
                    for k in cs
                    if dict_rep.get(k) is not None
                ]
            )
            for cs in unique_constraints
        ]
        filters.append(or_(*ucs))

        # Check if object already exists in DB, break if more than one is found
        try:
            obj_query = db.session.query(cls).filter(and_(*filters))
            obj = obj_query.one_or_none()
        except MultipleResultsFound:
            logger.error(
                "Error importing %s \n %s \n %s",
                cls.__name__,
                str(obj_query),
                yaml.safe_dump(dict_rep),
                exc_info=True,
            )
            raise

        if not obj:
            is_new_obj = True
            # Create new DB object
            obj = cls(**dict_rep)
            logger.debug("Importing new %s %s", obj.__tablename__, str(obj))
            if cls.export_parent and parent:
                setattr(obj, cls.export_parent, parent)
            db.session.add(obj)
        else:
            is_new_obj = False
            logger.debug("Updating %s %s", obj.__tablename__, str(obj))
            # Update columns
            for k, v in dict_rep.items():
                setattr(obj, k, v)

        # Recursively create children
        if recursive:
            for child in cls.export_children:
                child_class = cls._child_class(child)
                added = []
                for c_obj in new_children.get(child, []):
                    added.append(
                        child_class.import_from_dict(
                            dict_rep=c_obj, parent=obj, sync=sync
                        )
                    )
                # If children should get synced, delete the ones that did not
                # get updated.
                if child in sync and not is_new_obj:
                    back_refs = child_class.parent_foreign_key_mappings()
                    delete_filters = [
                        getattr(child_class, k) == getattr(obj, back_refs.get(k))
                        for k in back_refs.keys()
                    ]
                    to_delete = set(
                        db.session.query(child_class).filter(and_(*delete_filters))
                    ).difference(set(added))
                    for o in to_delete:
                        # log the child being removed, not its parent
                        logger.debug("Deleting %s %s", child, str(o))
                        db.session.delete(o)

        return obj

    def export_to_dict(
        self,
        recursive: bool = True,
        include_parent_ref: bool = False,
        include_defaults: bool = False,
        export_uuids: bool = False,
    ) -> dict[Any, Any]:
        """
        Export obj to dictionary.

        :param recursive: also export the children in ``export_children``
        :param include_parent_ref: parent foreign-key columns are dropped only
            when ``recursive`` is true and this flag is false
        :param include_defaults: keep ``None`` values and values equal to the
            column default
        :param export_uuids: add the ``uuid`` field and drop ``id``; this flag
            is not forwarded to children
        :return: exported fields ordered like ``export_fields``, with UUIDs
            converted to strings
        """
        export_fields = set(self.export_fields)
        if export_uuids:
            export_fields.add("uuid")
            if "id" in export_fields:
                export_fields.remove("id")

        cls = self.__class__
        parent_excludes = set()
        if recursive and not include_parent_ref:
            parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
            if parent_ref:
                parent_excludes = {c.name for c in parent_ref.local_columns}
        dict_rep = {
            # Convert c.name to str to handle SQLAlchemy's quoted_name type
            # which is not YAML-serializable
            str(c.name): getattr(self, c.name)
            for c in cls.__table__.columns  # type: ignore
            if (
                c.name in export_fields
                and c.name not in parent_excludes
                and (
                    include_defaults
                    or (
                        getattr(self, c.name) is not None
                        and (not c.default or getattr(self, c.name) != c.default.arg)
                    )
                )
            )
        }

        # sort according to export_fields using DSU (decorate, sort, undecorate)
        order = {field: i for i, field in enumerate(self.export_fields)}
        decorated_keys = [(order.get(k, len(order)), k) for k in dict_rep]
        decorated_keys.sort()
        dict_rep = {k: dict_rep[k] for _, k in decorated_keys}

        if recursive:
            for cld in self.export_children:
                # sorting to make lists of children stable
                # NOTE(review): this key sorts the characters of
                # str(k.items()), not the items themselves; kept as-is so the
                # existing export order does not change
                dict_rep[cld] = sorted(
                    [
                        child.export_to_dict(
                            recursive=recursive,
                            include_parent_ref=include_parent_ref,
                            include_defaults=include_defaults,
                        )
                        for child in getattr(self, cld)
                    ],
                    key=lambda k: sorted(str(k.items())),
                )

        return convert_uuids(dict_rep)

    def override(self, obj: Any) -> None:
        """Copy every ``export_fields`` attribute of ``obj`` onto this object."""
        for field in obj.__class__.export_fields:
            setattr(self, field, getattr(obj, field))

    def copy(self) -> Any:
        """Create a copy of this object (export fields only, no relationships)."""
        new_obj = self.__class__()
        new_obj.override(self)
        return new_obj

    def alter_params(self, **kwargs: Any) -> None:
        """Merge ``kwargs`` into the JSON ``params`` attribute."""
        params = self.params_dict
        params.update(kwargs)
        self.params = json.dumps(params)

    def remove_params(self, param_to_remove: str) -> None:
        """Remove ``param_to_remove`` from the JSON ``params`` (no-op if absent)."""
        params = self.params_dict
        params.pop(param_to_remove, None)
        self.params = json.dumps(params)

    def reset_ownership(self) -> None:
        """Make the current user (if any) the sole owner of this object."""
        # make sure the object doesn't have relations to a user
        # it will be filled by appbuilder on save
        self.created_by = None
        self.changed_by = None
        # flask global context might not exist (in cli or tests for example)
        self.owners = []
        if g and hasattr(g, "user"):
            self.owners = [g.user]

    @property
    def params_dict(self) -> dict[Any, Any]:
        """The ``params`` attribute decoded with ``json_to_dict``."""
        return json_to_dict(self.params)

    @property
    def template_params_dict(self) -> dict[Any, Any]:
        """The ``template_params`` attribute decoded with ``json_to_dict``."""
        return json_to_dict(self.template_params)  # type: ignore
def _user(user: User) -> str:
    """Return ``user`` HTML-escaped, or ``""`` when it is falsy."""
    return escape(user) if user else ""
class AuditMixinNullable(AuditMixin):
    """Altering the AuditMixin to use nullable fields

    Allows creating objects programmatically outside of CRUD
    """

    created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True)
    changed_on = sa.Column(
        sa.DateTime, default=datetime.now, onupdate=datetime.now, nullable=True
    )

    @declared_attr
    def created_by_fk(self) -> sa.Column:  # pylint: disable=arguments-renamed
        return sa.Column(
            sa.Integer,
            sa.ForeignKey("ab_user.id"),
            default=get_user_id,
            nullable=True,
        )

    @declared_attr
    def changed_by_fk(self) -> sa.Column:  # pylint: disable=arguments-renamed
        return sa.Column(
            sa.Integer,
            sa.ForeignKey("ab_user.id"),
            default=get_user_id,
            onupdate=get_user_id,
            nullable=True,
        )

    @property
    def created_by_name(self) -> str:
        """HTML-escaped creator name, or ``""`` when unknown."""
        if self.created_by:
            return escape(f"{self.created_by}")
        return ""

    @property
    def changed_by_name(self) -> str:
        """HTML-escaped last modifier name, or ``""`` when unknown."""
        if self.changed_by:
            return escape(f"{self.changed_by}")
        return ""

    @renders("created_by")
    def creator(self) -> Union[Markup, str]:
        return _user(self.created_by)

    @property
    def changed_by_(self) -> Union[Markup, str]:
        return _user(self.changed_by)

    @renders("changed_on")
    def changed_on_(self) -> Markup:
        return Markup(f'<span class="no-wrap">{self.changed_on}</span>')

    @renders("changed_on")
    def changed_on_delta_humanized(self) -> str:
        return self.changed_on_humanized

    @renders("changed_on")
    def changed_on_dttm(self) -> float:
        return datetime_to_epoch(self.changed_on)

    @renders("created_on")
    def created_on_delta_humanized(self) -> str:
        return self.created_on_humanized

    @renders("changed_on")
    def changed_on_utc(self) -> str:
        # astimezone() interprets a naive datetime as local time before
        # converting it to UTC
        return self.changed_on.astimezone(pytz.utc).strftime("%Y-%m-%dT%H:%M:%S.%f%z")

    def _format_time_humanized(self, timestamp: datetime) -> str:
        """
        Render how long ago ``timestamp`` was, using the locale from ``get_locale()``.

        Falls back to humanize's default locale when the requested locale
        cannot be activated or rendering with it fails.
        """
        locale = str(get_locale())
        time_diff = datetime.now() - timestamp

        # Skip activation for 'en' locale as it's humanize's default locale
        if locale == "en":
            return humanize.naturaltime(time_diff)

        try:
            humanize.i18n.activate(locale)
            try:
                return humanize.naturaltime(time_diff)
            finally:
                # always deactivate, even if rendering fails, so the requested
                # translation does not stay active for later humanize calls
                humanize.i18n.deactivate()
        except Exception as e:  # pylint: disable=broad-except
            logger.warning("Locale '%s' is not supported in humanize: %s", locale, e)
            return humanize.naturaltime(time_diff)

    @property
    def changed_on_humanized(self) -> str:
        return self._format_time_humanized(self.changed_on)

    @property
    def created_on_humanized(self) -> str:
        return self._format_time_humanized(self.created_on)

    @renders("changed_on")
    def modified(self) -> Markup:
        return Markup(f'<span class="no-wrap">{self.changed_on_humanized}</span>')
class QueryResult:  # pylint: disable=too-few-public-methods
    """Outcome of a query: the data, the SQL that ran, timing and status."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        df: pd.DataFrame,
        query: str,
        duration: timedelta,
        applied_template_filters: Optional[list[str]] = None,
        applied_filter_columns: Optional[list[ColumnTyping]] = None,
        rejected_filter_columns: Optional[list[ColumnTyping]] = None,
        status: str = QueryStatus.SUCCESS,
        error_message: Optional[str] = None,
        errors: Optional[list[dict[str, Any]]] = None,
        from_dttm: Optional[datetime] = None,
        to_dttm: Optional[datetime] = None,
    ) -> None:
        # payload
        self.df = df
        self.query = query
        self.duration = duration
        # filter bookkeeping; missing lists become empty lists
        self.applied_template_filters = applied_template_filters or []
        self.applied_filter_columns = applied_filter_columns or []
        self.rejected_filter_columns = rejected_filter_columns or []
        # outcome
        self.status = status
        self.error_message = error_message
        self.errors = errors or []
        # time window
        self.from_dttm = from_dttm
        self.to_dttm = to_dttm
        # number of rows; an empty frame counts as zero
        self.sql_rowcount = 0 if df.empty else len(df.index)
class ExtraJSONMixin:
    """Mixin to add an `extra` column (JSON) and utility methods"""

    extra_json = sa.Column(MediumText(), default="{}")

    @property
    def extra(self) -> dict[str, Any]:
        """
        The decoded ``extra_json`` column.

        Returns ``{}`` when the column is empty, cannot be decoded, or does not
        hold a JSON object.
        """
        try:
            decoded = json.loads(self.extra_json or "{}") or {}
        except (TypeError, json.JSONDecodeError) as exc:
            logger.error(
                "Unable to load an extra json: %r. Leaving empty.", exc, exc_info=True
            )
            return {}
        if not isinstance(decoded, dict):
            # a non-object payload cannot serve as a key/value mapping
            # (set_extra_json_key would fail on it)
            logger.warning(
                "Extra json is not an object (got %s). Leaving empty.",
                type(decoded).__name__,
            )
            return {}
        return decoded

    @extra.setter
    def extra(self, extras: dict[str, Any]) -> None:
        self.extra_json = json.dumps(extras)

    def set_extra_json_key(self, key: str, value: Any) -> None:
        """Set ``key`` to ``value`` inside ``extra_json``, keeping other keys."""
        extra = self.extra
        extra[key] = value
        self.extra_json = json.dumps(extra)

    @validates("extra_json")
    def ensure_extra_json_is_not_none(
        self,
        _: str,
        value: Optional[str],
    ) -> Any:
        """Store ``"{}"`` instead of ``None`` in ``extra_json``."""
        if value is None:
            return "{}"
        return value
class CertificationMixin:
    """Mixin to add extra certification fields"""

    extra = sa.Column(sa.Text, default="{}")

    def get_extra_dict(self) -> dict[str, Any]:
        """Decode ``extra``; ``{}`` if it is missing, malformed or not a JSON object."""
        try:
            decoded = json.loads(self.extra)
        except (TypeError, json.JSONDecodeError):
            return {}
        # e.g. a stored "null" would otherwise break the ``.get`` calls below
        return decoded if isinstance(decoded, dict) else {}

    def _certification_dict(self) -> dict[str, Any]:
        """The ``certification`` entry of ``extra``, or ``{}`` if absent or not a dict."""
        certification = self.get_extra_dict().get("certification")
        return certification if isinstance(certification, dict) else {}

    @property
    def is_certified(self) -> bool:
        return bool(self.get_extra_dict().get("certification"))

    @property
    def certified_by(self) -> Optional[str]:
        return self._certification_dict().get("certified_by")

    @property
    def certification_details(self) -> Optional[str]:
        return self._certification_dict().get("details")

    @property
    def warning_markdown(self) -> Optional[str]:
        return self.get_extra_dict().get("warning_markdown")
def clone_model(
    target: Model,
    ignore: Optional[list[str]] = None,
    keep_relations: Optional[list[str]] = None,
    **kwargs: Any,
) -> Model:
    """
    Clone a SQLAlchemy model. By default will only clone naive column attributes.

    To include relationship attributes, use `keep_relations`.

    :param target: instance to copy
    :param ignore: attribute names to leave out
    :param keep_relations: relationship attribute names to copy as well
    :param kwargs: values that override copied attributes
    :return: a new instance of ``target``'s class
    """
    table = target.__table__
    skipped = set(table.primary_key.columns.keys()) | set(ignore or [])
    candidates = [*table.columns.keys(), *(keep_relations or [])]
    data = {name: getattr(target, name) for name in candidates if name not in skipped}
    data.update(kwargs)
    return target.__class__(**data)
# todo(hugh): centralize where this code lives
class QueryStringExtended(NamedTuple):
    """Compiled SQL string plus the metadata collected while building it."""

    applied_template_filters: Optional[list[str]]  # template filters applied
    applied_filter_columns: list[ColumnTyping]  # filter columns applied
    rejected_filter_columns: list[ColumnTyping]  # filter columns rejected
    labels_expected: list[str]  # output column labels expected
    prequeries: list[str]  # queries to run beforehand
    sql: str  # the final SQL
class SqlaQuery(NamedTuple):
    """SQLAlchemy ``Select`` plus the metadata collected while building it."""

    applied_template_filters: list[str]  # template filters applied
    applied_filter_columns: list[ColumnTyping]  # filter columns applied
    rejected_filter_columns: list[ColumnTyping]  # filter columns rejected
    cte: Optional[str]  # optional CTE text
    extra_cache_keys: list[Any]  # additional cache key material
    labels_expected: list[str]  # output column labels expected
    prequeries: list[str]  # queries to run beforehand
    sqla_query: Select  # the SQLAlchemy select
class ExploreMixin: # pylint: disable=too-many-public-methods
"""
Allows any flask_appbuilder.Model (Query, Table, etc.)
to be used to power a chart inside /explore
"""
sqla_aggregations = {
"COUNT_DISTINCT": lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
"COUNT": sa.func.COUNT,
"SUM": sa.func.SUM,
"AVG": sa.func.AVG,
"MIN": sa.func.MIN,
"MAX": sa.func.MAX,
}
fetch_values_predicate = None
normalize_columns = False
@property
def type(self) -> str:
    """Datasource type name; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def db_extra(self) -> Optional[dict[str, Any]]:
    """Database extra settings; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def database_id(self) -> int:
    """Id of the backing database; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def owners_data(self) -> list[Any]:
    """Owner information; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def metrics(self) -> list[Any]:
    """Metrics of this datasource; none by default (a fresh list per call)."""
    return list()
@property
def uid(self) -> str:
    """Unique identifier; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def is_rls_supported(self) -> bool:
    """Whether row-level security applies; abstract, must be overridden."""
    raise NotImplementedError
@property
def cache_timeout(self) -> int | None:
    """Cache timeout; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def column_names(self) -> list[str]:
    """Names of the columns; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def offset(self) -> int:
    """Offset value; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def main_dttm_col(self) -> Optional[str]:
    """Main datetime column name; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def always_filter_main_dttm(self) -> Optional[bool]:
    """Whether to always filter on the main datetime column; off by default."""
    flag = False
    return flag
@property
def dttm_cols(self) -> list[str]:
    """Datetime column names; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def db_engine_spec(self) -> builtins.type["BaseEngineSpec"]:
    """Engine spec class of the database; abstract, must be overridden."""
    raise NotImplementedError
@property
def database(self) -> Database:
    """Backing database object; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def catalog(self) -> str:
    """Catalog name; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def schema(self) -> str:
    """Schema name; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def sql(self) -> str:
    """SQL text of the datasource; abstract, concrete classes must override."""
    raise NotImplementedError
@property
def columns(self) -> list[Any]:
    """Column objects; abstract, concrete classes must override."""
    raise NotImplementedError
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
    """Extra cache key material for ``query_obj``; abstract, must be overridden."""
    raise NotImplementedError
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
    """Template processor for this datasource; abstract, must be overridden."""
    raise NotImplementedError
def get_fetch_values_predicate(
    self,
    template_processor: Optional[  # pylint: disable=unused-argument
        BaseTemplateProcessor
    ] = None,
) -> TextClause:
    """
    Return the ``fetch_values_predicate`` attribute as-is.

    ``template_processor`` is accepted for interface compatibility and ignored.
    """
    predicate = self.fetch_values_predicate
    return predicate
def get_sqla_row_level_filters(
    self,
    template_processor: Optional[BaseTemplateProcessor] = None,  # pylint: disable=unused-argument
) -> list[TextClause]:
    """Return row-level filters; this base implementation has none."""
    # TODO: We should refactor this mixin and remove this method
    # as it exists in the BaseDatasource and is not applicable
    # for datasources of type query
    no_filters: list[TextClause] = []
    return no_filters
def _process_sql_expression( # pylint: disable=too-many-arguments
self,
expression: Optional[str],
database_id: int,
engine: str,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
self.database,
self.catalog,
schema,
engine,
)
try:
expression = sanitize_clause(expression, engine)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
return expression
def _process_select_expression(
self,
expression: Optional[str],
database_id: int,
engine: str,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
"""
Validate and process an adhoc expression used as a column or metric.
This requires prefixing the expression with a dummy SELECT statement, so it can
be properly parsed and validated.
"""
if expression:
expression = f"SELECT {expression}"
if processed := self._process_sql_expression(
expression=expression,
database_id=database_id,
engine=engine,
schema=schema,
template_processor=template_processor,
):
prefix, expression = re.split(
r"SELECT\s+",
processed,
maxsplit=1,
flags=re.IGNORECASE,
)
return expression.strip()
return None
def _process_orderby_expression(
self,
expression: Optional[str],
database_id: int,
engine: str,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
"""
Validate and process an ORDER BY clause expression.
This requires prefixing the expression with a dummy SELECT statement, so it can
be properly parsed and validated.
"""
if expression:
expression = f"SELECT 1 ORDER BY {expression}"
if processed := self._process_sql_expression(
expression=expression,
database_id=database_id,
engine=engine,
schema=schema,
template_processor=template_processor,
):
prefix, expression = re.split(
r"ORDER\s+BY",
processed,
maxsplit=1,
flags=re.IGNORECASE,
)
return expression.strip()
return None
def make_sqla_column_compatible(
    self, sqla_col: ColumnElement, label: Optional[str] = None
) -> ColumnElement:
    """Takes a sqlalchemy column object and adds label info if supported by engine.

    :param sqla_col: sqlalchemy column instance
    :param label: alias/label that column is expected to have; defaults to
        the column's own name
    :return: either a sql alchemy column or label instance if supported by engine
    """
    expected_label = label or sqla_col.name
    spec = self.db_engine_spec
    if spec.get_allows_alias_in_select(self.database):
        # the engine may adjust the label (e.g. quoting or mutation)
        sqla_col = sqla_col.label(spec.make_label_compatible(expected_label))
    # the key always carries the expected, unmodified label
    sqla_col.key = expected_label
    return sqla_col
@staticmethod
def _apply_cte(sql: str, cte: Optional[str]) -> str:
"""
Append a CTE before the SELECT statement if defined
:param sql: SELECT statement
:param cte: CTE statement
:return:
"""
if cte:
sql = f"{cte}\n{sql}"
return sql
def get_query_str_extended(
    self,
    query_obj: QueryObjectDict,
    mutate: bool = True,
) -> QueryStringExtended:
    """
    Build and compile the SQL for ``query_obj``.

    :param query_obj: query object; only keys in ``SQLA_QUERY_KEYS`` are
        forwarded to ``get_sqla_query``
    :param mutate: apply ``mutate_sql_based_on_config`` to the final SQL
    :return: the SQL string together with the metadata from ``get_sqla_query``
    """
    # Filter out keys that aren't parameters to get_sqla_query
    sqla_kwargs = {
        key: value for key, value in query_obj.items() if key in SQLA_QUERY_KEYS
    }
    sqlaq = self.get_sqla_query(**cast(Any, sqla_kwargs))
    compiled_sql = self.database.compile_sqla_query(
        sqlaq.sqla_query,
        catalog=self.catalog,
        schema=self.schema,
        is_virtual=bool(self.sql),
    )
    final_sql = self._apply_cte(compiled_sql, sqlaq.cte)
    if mutate:
        final_sql = self.database.mutate_sql_based_on_config(final_sql)

    return QueryStringExtended(
        applied_template_filters=sqlaq.applied_template_filters,
        applied_filter_columns=sqlaq.applied_filter_columns,
        rejected_filter_columns=sqlaq.rejected_filter_columns,
        labels_expected=sqlaq.labels_expected,
        prequeries=sqlaq.prequeries,
        sql=final_sql,
    )
def _normalize_prequery_result_type(
self,
row: pd.Series,
dimension: str,
columns_by_name: dict[str, "TableColumn"],
) -> Union[str, int, float, bool, str]:
"""
Convert a prequery result type to its equivalent Python type.
Some databases like Druid will return timestamps as strings, but do not perform
automatic casting when comparing these strings to a timestamp. For cases like
this we convert the value via the appropriate SQL transform.
:param row: A prequery record
:param dimension: The dimension name
:param columns_by_name: The mapping of columns by name
:return: equivalent primitive python type
"""
value = row[dimension]
if isinstance(value, np.generic):
value = value.item()
column_ = columns_by_name.get(dimension)
db_extra: dict[str, Any] = self.database.get_extra()
if column_ is None:
# Column not found, return value as-is
pass
elif isinstance(column_, dict):
if (
column_.get("type")
and column_.get("is_temporal")
and isinstance(value, str)
):
sql = self.db_engine_spec.convert_dttm(
column_.get("type"), dateutil.parser.parse(value), db_extra=None
)
if sql:
value = self.db_engine_spec.get_text_clause(sql)
else:
if column_.type and column_.is_temporal and isinstance(value, str):
sql = self.db_engine_spec.convert_dttm(
column_.type, dateutil.parser.parse(value), db_extra=db_extra
)
if sql:
value = self.db_engine_spec.get_text_clause(sql)
return value
def make_orderby_compatible(
    self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
    """
    If needed, make sure aliases for selected columns are not used in
    `ORDER BY`.

    In some databases (e.g. Presto), `ORDER BY` clause is not able to
    automatically pick the source column if a `SELECT` clause alias is named
    the same as a source column. In this case, we update the SELECT alias to
    another name to avoid the conflict.

    Labels in ``select_exprs`` are renamed in place by appending ``__``. The
    final output columns still use the original names, because they are
    replaced by ``labels_expected`` after querying.
    """
    if self.db_engine_spec.allows_alias_to_source_column:
        return

    for col in select_exprs:
        if not isinstance(col, Label):
            continue
        # alias referenced as a whole word somewhere inside parentheses
        alias_pattern = re.compile(
            rf"\(.*\b{re.escape(col.name)}\b.*\)", re.IGNORECASE
        )
        if any(alias_pattern.search(str(expr)) for expr in orderby_exprs):
            col.name = f"{col.name}__"
def query(self, query_obj: QueryObjectDict) -> QueryResult:
    """
    Executes the query and returns a dataframe.

    This method is the unified entry point for query execution across all
    datasource types (Query, SqlaTable, etc.).

    Execution failures are reported through the returned ``QueryResult``
    (status FAILED, empty dataframe, extracted errors), except for
    ``SupersetErrorException`` / ``SupersetErrorsException``, which propagate.
    """
    started_at = datetime.now()
    query_str_ext = self.get_query_str_extended(query_obj)
    sql = query_str_ext.sql

    def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """
        Some engines change the case or generate bespoke column names, either by
        default or due to lack of support for aliasing. This function ensures that
        the column names in the DataFrame correspond to what is expected by
        the viz components.

        Sometimes a query may also contain only order by columns that are not used
        as metrics or groupby columns, but need to present in the SQL `select`,
        filtering by `labels_expected` make sure we only return columns users want.

        :param df: Original DataFrame returned by the engine
        :return: Mutated DataFrame
        """
        expected = query_str_ext.labels_expected
        if df is None or df.empty:
            return df
        expected_count = len(expected)
        actual_count = len(df.columns)
        if actual_count < expected_count:
            raise QueryObjectValidationError(
                _("Db engine did not return all queried columns")
            )
        if actual_count > expected_count:
            # keep only the leading, expected columns
            df = df.iloc[:, 0:expected_count]
        df.columns = expected
        return df

    status = QueryStatus.SUCCESS
    errors = None
    error_message = None
    try:
        df = self.database.get_df(
            sql,
            self.catalog,
            self.schema,
            mutator=assign_column_label,
        )
    except (SupersetErrorException, SupersetErrorsException):
        # includes OAuth2RedirectError; bubble up to the API layer
        raise
    except Exception as ex:  # pylint: disable=broad-except
        df = pd.DataFrame()
        status = QueryStatus.FAILED
        logger.warning(
            "Query %s on schema %s failed", sql, self.schema, exc_info=True
        )
        engine_spec = self.db_engine_spec
        extracted = engine_spec.extract_errors(
            ex, database_name=self.database.unique_name
        )
        errors = [dataclasses.asdict(error) for error in extracted]
        error_message = utils.error_msg_from_exception(ex)

    return QueryResult(
        applied_template_filters=query_str_ext.applied_template_filters,
        applied_filter_columns=query_str_ext.applied_filter_columns,
        rejected_filter_columns=query_str_ext.rejected_filter_columns,
        status=status,
        df=df,
        duration=datetime.now() - started_at,
        query=sql,
        errors=errors,
        error_message=error_message,
    )
def exc_query(self, qry: Any) -> QueryResult:
    """
    Execute ``qry`` through :meth:`query` and hand back its result as-is.

    Deprecated: Use query() instead.
    This method is kept for backward compatibility.

    :param qry: The query payload forwarded unchanged to ``query``.
    :returns: Whatever ``query`` returns.
    """
    delegated_result = self.query(qry)
    return delegated_result
def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
    """
    Normalize ``df`` in place and return it.

    Steps:
    1. Temporal columns are collected: the base-axis labels plus
       ``query_object.granularity`` whose dataset column reports ``is_dttm``,
       and the legacy ``DTTM_ALIAS`` column when it is present in ``df``.
       They are handed to ``normalize_dttm_col`` together with any
       ``datetime_format`` stored on the dataset's columns.
    2. Unless ``enforce_numerical_metrics`` is falsy, metrics are converted
       with ``dataframe_utils.df_metrics_to_num``.
    3. ``inf`` / ``-inf`` are replaced with ``NaN``.

    :param df: The dataframe to normalize
    :param query_object: The query object with metadata about columns
    :return: Normalized dataframe (the same object as ``df``)
    """

    def _timestamp_format_for(column: str | None) -> str | None:
        # The dataset column's ``python_date_format``, when one is configured.
        if not hasattr(self, "get_column"):
            return None
        column_obj = self.get_column(column)
        if not column_obj or not hasattr(column_obj, "python_date_format"):
            return None
        formatter = column_obj.python_date_format
        return str(formatter) if formatter else None

    def _is_temporal(label: str | None) -> bool:
        # Column metadata may be an object or a plain dict.
        if not hasattr(self, "get_column"):
            return False
        column_obj = self.get_column(label)
        if not column_obj:
            return False
        if isinstance(column_obj, dict):
            return bool(column_obj.get("is_dttm"))
        return bool(column_obj.is_dttm)

    candidate_labels = [
        *get_base_axis_labels(query_object.columns),
        query_object.granularity,
    ]
    temporal_labels = tuple(
        label for label in candidate_labels if _is_temporal(label)
    )

    dttm_cols: list[DateColumn] = []
    for label in temporal_labels:
        if not label:
            continue
        dttm_cols.append(
            DateColumn(
                timestamp_format=_timestamp_format_for(label),
                offset=self.offset,
                time_shift=query_object.time_shift,
                col_label=label,
            )
        )

    if DTTM_ALIAS in df:
        legacy_column = DateColumn.get_legacy_time_column(
            timestamp_format=_timestamp_format_for(query_object.granularity),
            offset=self.offset,
            time_shift=query_object.time_shift,
        )
        dttm_cols.append(legacy_column)

    # Detected datetime formats stored on the dataset columns, by column name.
    format_map: dict[str, str] = {}
    if hasattr(self, "columns"):
        format_map = {
            column.column_name: column.datetime_format
            for column in self.columns
            if hasattr(column, "datetime_format") and column.datetime_format
        }

    normalize_dttm_col(
        df=df,
        dttm_cols=tuple(dttm_cols),
        format_map=format_map or None,
    )

    if getattr(self, "enforce_numerical_metrics", True):
        dataframe_utils.df_metrics_to_num(df, query_object)

    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    return df
def get_query_result(self, query_object: QueryObject) -> QueryResult:
    """
    Run ``query_object`` and post-process the resulting dataframe.

    The pipeline is:
    1. execute the base query through ``self.query``;
    2. when rows came back, normalize them (``normalize_df``), join the
       time-offset comparison queries if ``time_offsets`` is set, and
       apply ``query_object``'s post-processing operations.

    The returned ``QueryResult`` is the one produced by ``self.query``. Its
    ``df``, ``query``, ``from_dttm`` and ``to_dttm`` are overwritten.
    ``query`` holds every executed statement, each followed by ``";\\n\\n"``.

    :param query_object: The query configuration
    :return: QueryResult with processed dataframe
    :raises QueryObjectValidationError: when post-processing is invalid
    """
    result = self.query(query_object.to_dict())
    sql_log = result.query + ";\n\n" if result.query else ""
    processed_df = result.df

    if not processed_df.empty:
        processed_df = self.normalize_df(processed_df, query_object)

        if query_object.time_offsets:
            # No query context is available here, so caching is disabled.
            offsets_result = self.processing_time_offsets(
                processed_df,
                query_object,
                cache_key_fn=None,
                cache_timeout_fn=None,
            )
            processed_df = offsets_result["df"]
            sql_log += ";\n\n".join(offsets_result["queries"]) + ";\n\n"

        try:
            processed_df = query_object.exec_post_processing(processed_df)
        except InvalidPostProcessingError as ex:
            raise QueryObjectValidationError(ex.message) from ex

    result.df = processed_df
    result.query = sql_log
    result.from_dttm = query_object.from_dttm
    result.to_dttm = query_object.to_dttm
    return result
def processing_time_offsets( # pylint: disable=too-many-locals,too-many-statements # noqa: C901
self,
df: pd.DataFrame,
query_object: QueryObject,
cache_key_fn: Callable[[QueryObject, str, Any], str | None] | None = None,
cache_timeout_fn: Callable[[], int] | None = None,
force_cache: bool = False,
) -> CachedTimeOffset:
"""
Process time offsets for time comparison feature.
This method handles both relative time offsets (e.g., "1 week ago") and
absolute date range offsets (e.g., "2015-01-03 : 2015-01-04").
:param df: The main dataframe
:param query_object: The query object with time offset configuration
:param cache_key_fn: Optional function to generate cache keys
:param cache_timeout_fn: Optional function to get cache timeout
:param force_cache: Whether to force cache refresh
:return: CachedTimeOffset with processed dataframe and queries
"""
# Import here to avoid circular dependency
# pylint: disable=import-outside-toplevel
from superset.common.utils.query_cache_manager import QueryCacheManager
# ensure query_object is immutable
query_object_clone = copy.copy(query_object)
queries: list[str] = []
cache_keys: list[str | None] = []
offset_dfs: dict[str, pd.DataFrame] = {}
outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
if not outer_from_dttm or not outer_to_dttm:
raise QueryObjectValidationError(
_(
"An enclosed time range (both start and end) must be specified "
"when using a Time Comparison."
)
)
time_grain = self.get_time_grain(query_object)
metric_names = get_metric_names(query_object.metrics)
# use columns that are not metrics as join keys
join_keys = [col for col in df.columns if col not in metric_names]
for offset in query_object.time_offsets:
try:
original_offset = offset
is_date_range_offset = self.is_valid_date_range(offset)
if is_date_range_offset and feature_flag_manager.is_feature_enabled(
"DATE_RANGE_TIMESHIFTS_ENABLED"
):
# DATE RANGE OFFSET LOGIC (like "2015-01-03 : 2015-01-04")
try:
# Parse the specified range
offset_from_dttm, offset_to_dttm = (
get_since_until_from_time_range(time_range=offset)
)
except ValueError as ex:
raise QueryObjectValidationError(str(ex)) from ex
# Use the specified range directly
query_object_clone.from_dttm = offset_from_dttm
query_object_clone.to_dttm = offset_to_dttm
# For date range offsets, we must NOT set inner bounds
# These create additional WHERE clauses that conflict with our
# date range
query_object_clone.inner_from_dttm = None
query_object_clone.inner_to_dttm = None
elif is_date_range_offset:
# Date range timeshift feature is disabled
raise QueryObjectValidationError(
"Date range timeshifts are not enabled. "
"Please contact your administrator to enable the "
"DATE_RANGE_TIMESHIFTS_ENABLED feature flag."
)
else:
# RELATIVE OFFSET LOGIC (like "1 day ago")
if self.is_valid_date(offset) or offset == "inherit":
offset = self.get_offset_custom_or_inherit(
offset,
outer_from_dttm,
outer_to_dttm,
)
query_object_clone.from_dttm = get_past_or_future(
offset,
outer_from_dttm,
)
query_object_clone.to_dttm = get_past_or_future(
offset, outer_to_dttm
)
query_object_clone.inner_from_dttm = query_object_clone.from_dttm
query_object_clone.inner_to_dttm = query_object_clone.to_dttm
x_axis_label = get_x_axis_label(query_object.columns)
query_object_clone.granularity = (
query_object_clone.granularity or x_axis_label
)
except ValueError as ex:
raise QueryObjectValidationError(str(ex)) from ex
query_object_clone.time_offsets = []
query_object_clone.post_processing = []
# Get time offset index
index = (get_base_axis_labels(query_object.columns) or [DTTM_ALIAS])[0]
if is_date_range_offset and feature_flag_manager.is_feature_enabled(
"DATE_RANGE_TIMESHIFTS_ENABLED"
):
# Create a completely new filter list to preserve original filters
query_object_clone.filter = copy.deepcopy(query_object_clone.filter)
# Remove any existing temporal filters that might conflict
query_object_clone.filter = [
flt
for flt in query_object_clone.filter
if not (flt.get("op") == FilterOperator.TEMPORAL_RANGE)
]
# Determine the temporal column with multiple fallback strategies
temporal_col = self._get_temporal_column_for_filter(
query_object, x_axis_label
)
# Always add a temporal filter for date range offsets
if temporal_col:
new_temporal_filter: QueryObjectFilterClause = {
"col": temporal_col,
"op": FilterOperator.TEMPORAL_RANGE,
"val": (
f"{query_object_clone.from_dttm} : "
f"{query_object_clone.to_dttm}"
),
}
query_object_clone.filter.append(new_temporal_filter)
else:
# This should rarely happen with proper fallbacks
raise QueryObjectValidationError(
_(
"Unable to identify temporal column for date range time comparison." # noqa: E501
"Please ensure your dataset has a properly configured time column." # noqa: E501
)
)
else:
# RELATIVE OFFSET: Original logic for non-date-range offsets
# The comparison is not using a temporal column so we need to modify
# the temporal filter so we run the query with the correct time range
if not dataframe_utils.is_datetime_series(df.get(index)):
query_object_clone.filter = copy.deepcopy(query_object_clone.filter)
# Find and update temporal filters
for flt in query_object_clone.filter:
if flt.get(
"op"
) == FilterOperator.TEMPORAL_RANGE and isinstance(
flt.get("val"), str
):
time_range = cast(str, flt.get("val"))
(
new_outer_from_dttm,
new_outer_to_dttm,
) = get_since_until_from_time_range(
time_range=time_range,
time_shift=offset,
)
flt["val"] = f"{new_outer_from_dttm} : {new_outer_to_dttm}"
else:
# If it IS a datetime series, we still need to clear conflicts
query_object_clone.filter = copy.deepcopy(query_object_clone.filter)
# For relative offsets with datetime series, ensure the temporal
# filter matches our range
temporal_col = query_object_clone.granularity or x_axis_label
# Update any existing temporal filters to match our shifted range
for flt in query_object_clone.filter:
if (
flt.get("op") == FilterOperator.TEMPORAL_RANGE
and flt.get("col") == temporal_col
):
flt["val"] = (
f"{query_object_clone.from_dttm} : "
f"{query_object_clone.to_dttm}"
)
# Remove non-temporal x-axis filters (but keep temporal ones)
query_object_clone.filter = [
flt
for flt in query_object_clone.filter
if not (
flt.get("col") == x_axis_label
and flt.get("op") != FilterOperator.TEMPORAL_RANGE
)
]
# Continue with the rest of the method (caching, execution, etc.)
cached_time_offset_key = (
offset if offset == original_offset else f"{offset}_{original_offset}"
)
cache_key = None
if cache_key_fn:
cache_key = cache_key_fn(
query_object_clone,
cached_time_offset_key,
time_grain,
)
cache = QueryCacheManager.get(cache_key, CacheRegion.DATA, force_cache)
if cache.is_loaded:
offset_dfs[offset] = cache.df
queries.append(cache.query)
cache_keys.append(cache_key)
continue
query_object_clone_dct = query_object_clone.to_dict()
# rename metrics: SUM(value) => SUM(value) 1 year ago
metrics_mapping = {
metric: TIME_COMPARISON.join([metric, original_offset])
for metric in metric_names
}
# When the original query has limit or offset we wont apply those
# to the subquery so we prevent data inconsistency due to missing records
# in the dataframes when performing the join
if query_object.row_limit or query_object.row_offset:
query_object_clone_dct["row_limit"] = app.config["ROW_LIMIT"]
query_object_clone_dct["row_offset"] = 0
# Call the unified query method on the datasource
result = self.query(query_object_clone_dct)
queries.append(result.query)
cache_keys.append(None)
offset_metrics_df = result.df
if offset_metrics_df.empty:
offset_metrics_df = pd.DataFrame(
{
col: [np.NaN]
for col in join_keys + list(metrics_mapping.values())
}
)
else:
# 1. normalize df, set dttm column
offset_metrics_df = self.normalize_df(
offset_metrics_df, query_object_clone
)
# 2. rename extra query columns
offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping)
# cache df and query if caching is enabled
if cache_key and cache_timeout_fn:
value = {
"df": offset_metrics_df,
"query": result.query,
}
cache.set(
key=cache_key,
value=value,
timeout=cache_timeout_fn(),
datasource_uid=self.uid,
region=CacheRegion.DATA,
)
offset_dfs[offset] = offset_metrics_df
if offset_dfs:
df = self.join_offset_dfs(
df,
offset_dfs,
time_grain,
join_keys,
)
return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys)
@staticmethod
def get_time_grain(query_object: QueryObject) -> Any | None:
    """
    Return the time grain requested by ``query_object``.

    An adhoc (dict) column in first position carries its own ``timeGrain``,
    which takes precedence. Otherwise the legacy ``time_grain_sqla`` extra
    is used.

    :param query_object: The query being executed.
    :returns: The time grain, or ``None`` when none is set.
    """
    columns = query_object.columns
    leading_column = columns[0] if columns else None
    if isinstance(leading_column, dict):
        return leading_column.get("timeGrain")
    return query_object.extras.get("time_grain_sqla")
def is_valid_date(self, date_string: str) -> bool:
    """
    Tell whether ``date_string`` is exactly a ``YYYY-MM-DD`` calendar date.

    :param date_string: Candidate date text.
    :returns: ``True`` if it parses with ``%Y-%m-%d``, else ``False``.
    """
    expected_format = "%Y-%m-%d"
    try:
        datetime.strptime(date_string, expected_format)
    except ValueError:
        return False
    return True
def is_valid_date_range(self, date_range: str) -> bool:
    """
    Tell whether ``date_range`` is ``YYYY-MM-DD:YYYY-MM-DD``.

    Whitespace around either date is ignored. Anything other than exactly
    one ``:`` separator, or an unparsable date, yields ``False``.

    :param date_range: Candidate range text such as ``"2015-01-03 : 2015-01-04"``.
    :returns: ``True`` for a well-formed date range.
    """
    endpoints = date_range.split(":")
    if len(endpoints) != 2:
        return False
    try:
        for endpoint in endpoints:
            datetime.strptime(endpoint.strip(), "%Y-%m-%d")
    except ValueError:
        return False
    return True
def get_offset_custom_or_inherit(
    self,
    offset: str,
    outer_from_dttm: datetime,
    outer_to_dttm: datetime,
) -> str:
    """
    Translate an ``"inherit"`` or ``YYYY-MM-DD`` offset into ``"<N> days ago"``.

    ``"inherit"`` uses the length of the outer range in whole days. A date
    uses the number of whole days from that date to ``outer_from_dttm``.

    :param offset: The offset string.
    :param outer_from_dttm: The outer from datetime.
    :param outer_to_dttm: The outer to datetime.
    :returns: The relative offset, or ``""`` for any other input.
    """
    if offset == "inherit":
        range_length = outer_to_dttm - outer_from_dttm
        return f"{range_length.days} days ago"
    if not self.is_valid_date(offset):
        return ""
    anchor_date = datetime.strptime(offset, "%Y-%m-%d")
    gap = outer_from_dttm - anchor_date
    return f"{gap.days} days ago"
def _get_temporal_column_for_filter( # noqa: C901
self, query_object: QueryObject, x_axis_label: str | None
) -> str | None:
"""
Helper method to reliably determine the temporal column for filtering.
This method tries multiple strategies to find the correct temporal column:
1. Use the column from existing TEMPORAL_RANGE filter
2. Use explicitly set granularity
3. Use x_axis_label if it exists
:param query_object: The query object
:param x_axis_label: The x-axis label from the query
:return: The name of the temporal column, or None if not found
"""
# Strategy 1: Use the column from existing TEMPORAL_RANGE filter
if query_object.filter:
for flt in query_object.filter:
if flt.get("op") == FilterOperator.TEMPORAL_RANGE:
col = flt.get("col")
if isinstance(col, str):
return col
elif isinstance(col, dict) and col.get("sqlExpression"):
return str(col.get("label") or col.get("sqlExpression"))
# Strategy 2: Use explicitly set granularity
if query_object.granularity:
return query_object.granularity
# Strategy 3: Use x_axis_label if it exists
if x_axis_label:
return x_axis_label
return None
def _process_date_range_offset(
self, offset_df: pd.DataFrame, join_keys: list[str]
) -> tuple[pd.DataFrame, list[str]]:
"""Process date range offset data and return modified DataFrame and keys."""
temporal_cols = ["ds", "__timestamp", "dttm"]
non_temporal_join_keys = [key for key in join_keys if key not in temporal_cols]
if non_temporal_join_keys:
return offset_df, non_temporal_join_keys
metric_columns = [col for col in offset_df.columns if col not in temporal_cols]
if metric_columns:
aggregated_values = {}
for col in metric_columns:
if pd.api.types.is_numeric_dtype(offset_df[col]):
aggregated_values[col] = offset_df[col].sum()
else:
aggregated_values[col] = (
offset_df[col].iloc[0] if not offset_df.empty else None
)
offset_df = pd.DataFrame([aggregated_values])
return offset_df, []
def _apply_cleanup_logic(
    self,
    df: pd.DataFrame,
    offset: str,
    time_grain: str | None,
    join_keys: list[str],
    is_date_range_offset: bool,
) -> pd.DataFrame:
    """
    Drop helper columns left over after joining an offset frame (in place).

    For a relative offset with a time grain, the first join key is moved
    back to the front. Columns matching the offset-join suffix or the
    right-hand suffix are then dropped. In every other case only the
    right-hand-suffix duplicates are dropped.

    :param df: Joined frame; modified in place.
    :param offset: The offset being processed (not used here).
    :param time_grain: Time grain of the query, if any.
    :param join_keys: Original join keys of the main query.
    :param is_date_range_offset: Whether the offset is an absolute date range.
    :returns: ``df``.
    """
    if time_grain and not is_date_range_offset:
        if join_keys:
            leading_key = df.pop(join_keys[0])
            df.insert(0, leading_key.name, leading_key)
        drop_pattern = f"{OFFSET_JOIN_COLUMN_SUFFIX}|{R_SUFFIX}"
    else:
        # Date-range offsets and grain-less offsets share the same cleanup.
        drop_pattern = f"{R_SUFFIX}"

    df.drop(
        list(df.filter(regex=drop_pattern)),
        axis=1,
        inplace=True,
    )
    return df
def _determine_join_keys(
self,
df: pd.DataFrame,
offset_df: pd.DataFrame,
offset: str,
time_grain: str | None,
join_keys: list[str],
is_date_range_offset: bool,
join_column_producer: Any,
) -> tuple[pd.DataFrame, list[str]]:
"""Determine appropriate join keys and modify DataFrames if needed."""
if time_grain and not is_date_range_offset:
column_name = OFFSET_JOIN_COLUMN_SUFFIX + offset
# Add offset join columns for relative time offsets
self.add_offset_join_column(
df, column_name, time_grain, offset, join_column_producer
)
self.add_offset_join_column(
offset_df, column_name, time_grain, None, join_column_producer
)
return offset_df, [column_name, *join_keys[1:]]
elif is_date_range_offset:
return self._process_date_range_offset(offset_df, join_keys)
else:
return offset_df, join_keys
def _perform_join(
    self, df: pd.DataFrame, offset_df: pd.DataFrame, actual_join_keys: list[str]
) -> pd.DataFrame:
    """
    Left-join ``offset_df`` onto ``df``.

    With join keys this is a plain keyed left join. Without keys, every
    row of ``df`` is combined with every row of ``offset_df`` through a
    constant temporary key, which is removed afterwards. The temporary key
    is added to copies (``DataFrame.assign``), so neither input frame is
    modified.

    :param df: Main frame (left side).
    :param offset_df: Offset frame (right side); clashing names get ``R_SUFFIX``.
    :param actual_join_keys: Columns to join on; may be empty.
    :returns: The joined frame.
    """
    if actual_join_keys:
        return dataframe_utils.left_join_df(
            left_df=df,
            right_df=offset_df,
            join_keys=actual_join_keys,
            rsuffix=R_SUFFIX,
        )

    temp_key = "__temp_join_key__"
    # ``assign`` returns new frames. Writing ``df[temp_key] = 1`` would leak
    # the helper column into the caller's frames.
    result_df = dataframe_utils.left_join_df(
        left_df=df.assign(**{temp_key: 1}),
        right_df=offset_df.assign(**{temp_key: 1}),
        join_keys=[temp_key],
        rsuffix=R_SUFFIX,
    )
    # Remove the temporary key, including a suffixed duplicate should one appear.
    result_df.drop(
        columns=[temp_key, f"{temp_key}{R_SUFFIX}"],
        inplace=True,
        errors="ignore",
    )
    return result_df
def join_offset_dfs(
    self,
    df: pd.DataFrame,
    offset_dfs: dict[str, pd.DataFrame],
    time_grain: str | None,
    join_keys: list[str],
) -> pd.DataFrame:
    """
    Join every offset DataFrame onto the main DataFrame, one offset at a time.

    :param df: The main DataFrame.
    :param offset_dfs: Mapping from offset string to that offset's DataFrame.
    :param time_grain: The time grain used to calculate the temporal join key.
    :param join_keys: The keys to join on.
    :returns: The main DataFrame with all offset columns joined in.
    :raises QueryObjectValidationError: if a join-column producer is configured
        while ``time_grain`` is falsy.
    """
    producers = app.config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"]
    join_column_producer = producers.get(time_grain)

    # NOTE(review): a producer is looked up *by* ``time_grain``, so this only
    # fires when the config registers one under a falsy key. Confirm intent.
    if join_column_producer and not time_grain:
        raise QueryObjectValidationError(
            _("Time Grain must be specified when using Time Shift.")
        )

    for offset, shifted_df in offset_dfs.items():
        is_date_range_offset = self.is_valid_date_range(
            offset
        ) and feature_flag_manager.is_feature_enabled(
            "DATE_RANGE_TIMESHIFTS_ENABLED"
        )
        shifted_df, keys_for_join = self._determine_join_keys(
            df,
            shifted_df,
            offset,
            time_grain,
            join_keys,
            is_date_range_offset,
            join_column_producer,
        )
        joined = self._perform_join(df, shifted_df, keys_for_join)
        df = self._apply_cleanup_logic(
            joined, offset, time_grain, join_keys, is_date_range_offset
        )
    return df
def add_offset_join_column(
    self,
    df: pd.DataFrame,
    name: str,
    time_grain: str,
    time_offset: str | None = None,
    join_column_producer: Any = None,
) -> None:
    """
    Add a join-key column called ``name`` to ``df``, in place.

    Each row's key is computed from the row's first column (index 0). A
    configured ``join_column_producer`` is used when given; otherwise
    ``generate_join_column`` is used with ``time_grain`` and ``time_offset``.

    :param df: pandas DataFrame to which the offset join column will be added.
    :param name: The name of the new column to be added.
    :param time_grain: The time grain used to calculate the new column.
    :param time_offset: The time offset used to calculate the new column.
    :param join_column_producer: A function to generate the join column.
    """

    def _key_for(row: pd.Series) -> Any:
        if join_column_producer:
            return join_column_producer(row, 0)
        return self.generate_join_column(row, 0, time_grain, time_offset)

    df[name] = df.apply(_key_for, axis=1)
@staticmethod
def generate_join_column(
    row: pd.Series,
    column_index: int,
    time_grain: str,
    time_offset: str | None = None,
) -> str:
    """
    Build the string join key for the value at position ``column_index`` in ``row``.

    For datetime-like values (anything with ``strftime``), the value is
    first shifted by ``time_offset`` unless that offset is an absolute date
    range. It is then bucketed by ``time_grain``: week (Sunday- or
    Monday-based), month, quarter or year. For other grains, and for
    non-datetime values, the key is ``str(value)``.

    :param row: A DataFrame row.
    :param column_index: Position (not label) of the value within ``row``.
    :param time_grain: Time grain used to bucket datetime values.
    :param time_offset: Optional relative offset such as ``"1 year ago"``.
    :returns: The join key.
    """
    # Use positional access. ``row[column_index]`` relies on pandas'
    # deprecated label-to-position fallback (FutureWarning since pandas 2.1).
    # It also returns the wrong value if a column is labelled with that int.
    value = row.iloc[column_index]
    if hasattr(value, "strftime"):
        if time_offset and not ExploreMixin.is_valid_date_range_static(time_offset):
            value = value + DateOffset(**normalize_time_delta(time_offset))
        if time_grain in (
            TimeGrain.WEEK_STARTING_SUNDAY,
            TimeGrain.WEEK_ENDING_SATURDAY,
        ):
            return value.strftime("%Y-W%U")
        if time_grain in (
            TimeGrain.WEEK,
            TimeGrain.WEEK_STARTING_MONDAY,
            TimeGrain.WEEK_ENDING_SUNDAY,
        ):
            return value.strftime("%Y-W%W")
        if time_grain == TimeGrain.MONTH:
            return value.strftime("%Y-%m")
        if time_grain == TimeGrain.QUARTER:
            return value.strftime("%Y-Q") + str(value.quarter)
        if time_grain == TimeGrain.YEAR:
            return value.strftime("%Y")
    return str(value)
@staticmethod
def is_valid_date_range_static(date_range: str) -> bool:
    """
    Static twin of ``is_valid_date_range`` for callers without an instance.

    :param date_range: Candidate text in ``YYYY-MM-DD:YYYY-MM-DD`` form
        (whitespace around each date allowed).
    :returns: ``True`` only for exactly two parseable dates separated by ``:``.
    """
    endpoints = date_range.split(":")
    if len(endpoints) != 2:
        return False
    try:
        for endpoint in endpoints:
            datetime.strptime(endpoint.strip(), "%Y-%m-%d")
    except ValueError:
        return False
    return True
def get_rendered_sql(
    self,
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> str:
    """
    Render the dataset SQL with the template engine (Jinja) and validate it.

    Tabs, newlines, spaces and semicolons are stripped from both ends of
    ``self.sql`` before rendering.

    :param template_processor: Optional processor used to render templates.
    :returns: The rendered SQL, or ``""`` when the dataset has no SQL.
    :raises QueryObjectValidationError: if rendering fails, the rendered SQL
        is empty or whitespace-only, or it contains more than one statement.
    """
    if not self.sql:
        return ""
    sql = self.sql.strip("\t\r\n; ")
    if template_processor:
        try:
            sql = template_processor.process_template(sql)
        except (TemplateError, SupersetSyntaxErrorException) as ex:
            # Extract error message from different exception types
            if isinstance(ex, TemplateError):
                error_msg = ex.message
            else:  # SupersetSyntaxErrorException
                error_msg = str(ex.errors[0].message if ex.errors else ex)
            raise QueryObjectValidationError(
                _(
                    "Error while rendering virtual dataset query: %(msg)s",
                    msg=error_msg,
                )
            ) from ex
    # Reject blank output before parsing. Users then get the intended
    # message instead of a parser error, and a whitespace-only render is
    # never wrapped as an empty subquery.
    if not sql.strip():
        raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
    script = SQLScript(sql, engine=self.db_engine_spec.engine)
    if len(script.statements) > 1:
        raise QueryObjectValidationError(
            _("Virtual dataset query cannot consist of multiple statements")
        )
    return sql
def text(self, clause: str) -> TextClause:
    """
    Build a SQL text clause for ``clause`` via this datasource's engine spec.

    :param clause: Raw SQL fragment.
    :returns: The engine spec's text clause for it.
    """
    engine_spec = self.db_engine_spec
    return engine_spec.get_text_clause(clause)
def get_from_clause(
    self, template_processor: Optional[BaseTemplateProcessor] = None
) -> tuple[Union[TableClause, Alias], Optional[str]]:
    """
    Return where to select the columns and metrics from.

    This is either a physical table or a virtual table with its own
    subquery. When the engine spec extracts a CTE from the rendered SQL,
    the FROM target is a bare table named after the engine's CTE alias, and
    the CTE text is returned as the second tuple element.

    For virtual datasets, the RLS filters of each underlying table are
    applied to every statement, to prevent RLS bypass.

    :param template_processor: Optional processor used to render the SQL.
    :returns: ``(from_clause, cte)``; ``cte`` is what ``get_cte_query`` returned.
    :raises QueryObjectValidationError: if the SQL contains a mutation.
    """
    # Appending "\n" keeps a trailing line comment from swallowing whatever
    # gets placed after this SQL.
    from_sql = self.get_rendered_sql(template_processor) + "\n"
    script = SQLScript(from_sql, engine=self.db_engine_spec.engine)
    if script.has_mutation():
        raise QueryObjectValidationError(
            _("Virtual dataset query must be read-only")
        )

    if script.statements:
        default_schema = self.database.get_default_schema(self.catalog)
        try:
            rls_schema = self.schema or default_schema or ""
            for statement in script.statements:
                apply_rls(self.database, self.catalog, rls_schema, statement)
            # Regenerate the SQL after RLS application
            from_sql = script.format()
        except Exception as ex:
            # NOTE(review): best-effort by design. On failure the SQL is used
            # *without* RLS predicates (fail-open); confirm that is acceptable.
            logger.warning("Failed to apply RLS to virtual dataset SQL: %s", ex)

    engine_spec = self.db_engine_spec
    cte = engine_spec.get_cte_query(from_sql)
    if cte:
        return sa.table(engine_spec.cte_alias), cte
    return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS), cte
def adhoc_metric_to_sqla(
    self,
    metric: AdhocMetric,
    columns_by_name: dict[str, "TableColumn"],  # pylint: disable=unused-argument
    template_processor: Optional[BaseTemplateProcessor] = None,
    processed: bool = False,
) -> ColumnElement:
    """
    Turn an adhoc metric into a labelled SQLAlchemy column expression.

    SIMPLE metrics apply the ``sqla_aggregations`` entry named by
    ``aggregate`` to the referenced column. SQL metrics use
    ``sqlExpression`` as-is when ``processed`` is true. Otherwise the
    expression is first passed through ``_process_select_expression``
    together with ``template_processor``.

    :param metric: Adhoc metric definition
    :param columns_by_name: Columns for the current table (not used here)
    :param template_processor: template_processor instance
    :param processed: Whether the sqlExpression has already been processed
    :returns: The metric expression labelled via ``make_sqla_column_compatible``
    :raises QueryObjectValidationError: for an unknown ``expressionType``
    """
    expression_type = metric.get("expressionType")
    label = utils.get_metric_name(metric)

    if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
        column_spec = metric.get("column") or {}
        target_column = sa.column(cast(str, column_spec.get("column_name")))
        aggregate_fn = self.sqla_aggregations[metric["aggregate"]]
        sqla_metric = aggregate_fn(target_column)
    elif expression_type == utils.AdhocMetricExpressionType.SQL:
        if processed:
            sql_expression = metric.get("sqlExpression")
        else:
            sql_expression = self._process_select_expression(
                expression=metric["sqlExpression"],
                database_id=self.database_id,
                engine=self.database.backend,
                schema=self.schema,
                template_processor=template_processor,
            )
        sqla_metric = literal_column(sql_expression)
    else:
        raise QueryObjectValidationError("Adhoc metric expressionType is invalid")

    return self.make_sqla_column_compatible(sqla_metric, label)
@property
def template_params_dict(self) -> dict[Any, Any]:
    """Template parameters of this datasource; the base version has none (a new empty dict)."""
    return dict()
@staticmethod
def filter_values_handler(  # pylint: disable=too-many-arguments # noqa: C901
    values: Optional[FilterValues],
    operator: str,
    target_generic_type: utils.GenericDataType,
    target_native_type: Optional[str] = None,
    is_list_target: bool = False,
    db_engine_spec: Optional[
        builtins.type["BaseEngineSpec"]
    ] = None,  # fix(hughhh): Optional[Type[BaseEngineSpec]]
    db_extra: Optional[dict[str, Any]] = None,
) -> Optional[FilterValues]:
    """
    Coerce raw filter value(s) to what the target column and operator expect.

    Each value is handled as follows:
    - ``TEMPORAL_RANGE`` values pass through untouched.
    - Numbers aimed at a temporal column are read as epoch milliseconds and
      converted to a literal via ``db_engine_spec.convert_dttm``, when both
      a native type and an engine spec are given.
    - Strings lose surrounding tabs/newlines and are cast to numbers for
      numeric targets (except with ``LIKE``/``ILIKE``).
    - The ``NULL_STRING`` / ``EMPTY_STRING`` sentinels become ``None`` / ``""``.
    - Values for boolean targets are cast to booleans.

    Finally the result is wrapped in a list, or reduced to its first item,
    to match ``is_list_target``.

    :returns: The coerced value(s), or ``None`` when ``values`` is ``None``.
    """
    # pylint: disable=import-outside-toplevel
    from datetime import timezone

    if values is None:
        return None

    def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
        if operator == utils.FilterOperator.TEMPORAL_RANGE:
            return value
        if (
            isinstance(value, (float, int))
            and target_generic_type == utils.GenericDataType.TEMPORAL
            and target_native_type is not None
            and db_engine_spec is not None
        ):
            # ``datetime.utcfromtimestamp`` is deprecated since Python 3.12.
            # This is the documented equivalent that yields the same naive UTC value.
            naive_utc = datetime.fromtimestamp(value / 1000, tz=timezone.utc).replace(
                tzinfo=None
            )
            value = db_engine_spec.convert_dttm(
                target_type=target_native_type,
                dttm=naive_utc,
                db_extra=db_extra,
            )
            value = literal_column(value)
        if isinstance(value, str):
            value = value.strip("\t\n")
            if (
                target_generic_type == utils.GenericDataType.NUMERIC
                and operator
                not in {
                    utils.FilterOperator.ILIKE,
                    utils.FilterOperator.LIKE,
                }
            ):
                # For backwards compatibility and edge cases
                # where a column data type might have changed
                return utils.cast_to_num(value)
            if value == NULL_STRING:
                return None
            if value == EMPTY_STRING:
                return ""
        if target_generic_type == utils.GenericDataType.BOOLEAN:
            return utils.cast_to_boolean(value)
        return value

    if isinstance(values, (list, tuple)):
        values = [handle_single_value(v) for v in values]  # type: ignore
    else:
        values = handle_single_value(values)
    if is_list_target and not isinstance(values, (tuple, list)):
        values = [values]  # type: ignore
    elif not is_list_target and isinstance(values, (tuple, list)):
        values = values[0] if values else None
    return values
def get_query_str(self, query_obj: QueryObjectDict) -> str:
    """
    Return the full SQL for ``query_obj``: pre-queries, then the main query.

    Statements are separated by ``";\\n\\n"`` and the last one ends with ``";"``.
    """
    extended = self.get_query_str_extended(query_obj)
    statements = extended.prequeries + [extended.sql]
    separator = ";\n\n"
    return separator.join(statements) + ";"
def _get_series_orderby(
    self,
    series_limit_metric: Metric,
    metrics_by_name: dict[str, "SqlMetric"],
    columns_by_name: dict[str, "TableColumn"],
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
    """
    Resolve the series-limit metric into the expression used to rank series.

    :param series_limit_metric: An adhoc metric dict or a saved metric name.
    :param metrics_by_name: Saved metrics keyed by name.
    :param columns_by_name: Dataset columns keyed by name.
    :param template_processor: Renders templates in the metric SQL. It is
        forwarded for adhoc and saved metrics alike.
    :returns: The ordering expression.
    :raises QueryObjectValidationError: if the metric is neither adhoc nor a
        known saved metric.
    """
    if utils.is_adhoc_metric(series_limit_metric):
        # ``assert`` only narrows the type for static checkers.
        assert isinstance(series_limit_metric, dict)
        # Pass the template processor so Jinja inside an adhoc series-limit
        # metric is rendered, exactly as for saved metrics below.
        ob = self.adhoc_metric_to_sqla(
            series_limit_metric,
            columns_by_name,
            template_processor=template_processor,
        )
    elif (
        isinstance(series_limit_metric, str)
        and series_limit_metric in metrics_by_name
    ):
        ob = metrics_by_name[series_limit_metric].get_sqla_col(
            template_processor=template_processor
        )
    else:
        raise QueryObjectValidationError(
            _("Metric '%(metric)s' does not exist", metric=series_limit_metric)
        )
    return ob
def _reapply_query_filters(
self,
qry: Select,
apply_fetch_values_predicate: bool,
template_processor: Optional[BaseTemplateProcessor],
granularity: str | None,
time_filters: list[ColumnElement],
where_clause_and: list[ColumnElement],
having_clause_and: list[ColumnElement],
) -> Select:
"""
Re-apply WHERE and HAVING clauses to a reconstructed query.
When group_others_when_limit_reached=True, the query is reconstructed
with sa.select(), losing previously applied filters. This method
re-applies those filters to maintain query correctness.
The WHERE clause includes: user filters, RLS filters, extra WHERE
clauses, and time range filters accumulated in where_clause_and
and time_filters.
:param qry: The reconstructed SQLAlchemy Select object
:param apply_fetch_values_predicate: Whether to apply fetch values predicate
:param template_processor: Template processor for dynamic filters
:param granularity: Time granularity (if None, time_filters not applied)
:param time_filters: Time-based filter conditions
:param where_clause_and: Accumulated WHERE clause conditions
:param having_clause_and: Accumulated HAVING clause conditions
:return: The query with filters re-applied
"""
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(
self.get_fetch_values_predicate(template_processor=template_processor)
)
if granularity:
if time_filters or where_clause_and:
qry = qry.where(and_(*(time_filters + where_clause_and)))
else:
all_filters = time_filters + where_clause_and
if all_filters:
qry = qry.where(and_(*all_filters))
if having_clause_and:
qry = qry.having(and_(*having_clause_and))
return qry
def adhoc_column_to_sqla(
    self,
    col: "AdhocColumn",  # type: ignore # noqa: F821
    force_type_check: bool = False,
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
    """
    Convert an adhoc column into a SQLAlchemy column element.

    This base implementation always raises ``NotImplementedError``;
    presumably concrete datasources override it.
    """
    raise NotImplementedError()
def _get_top_groups(
    self,
    df: pd.DataFrame,
    dimensions: list[str],
    groupby_exprs: dict[str, Any],
    columns_by_name: dict[str, "TableColumn"],
) -> ColumnElement:
    """
    Build a predicate that matches any of the series listed in ``df``.

    Each row becomes ``dim_1 = v_1 AND dim_2 = v_2 ...``, where each value
    first goes through ``_normalize_prequery_result_type``. The per-row
    conditions are OR-ed together.

    :param df: Pre-query result, one row per top series.
    :param dimensions: Columns identifying a series.
    :param groupby_exprs: Group-by expressions keyed by dimension.
    :param columns_by_name: Dataset columns keyed by name.
    :returns: The combined OR predicate.
    """

    def _row_condition(row: pd.Series) -> ColumnElement:
        equalities = []
        for dimension in dimensions:
            normalized = self._normalize_prequery_result_type(
                row,
                dimension,
                columns_by_name,
            )
            equalities.append(groupby_exprs[dimension] == normalized)
        return and_(*equalities)

    row_conditions = [_row_condition(row) for _row_idx, row in df.iterrows()]
    return or_(*row_conditions)
def _apply_series_others_grouping(
    self,
    select_exprs: list[Any],
    groupby_all_columns: dict[str, Any],
    groupby_series_columns: dict[str, Any],
    condition_factory: Callable[[str, Any], Any],
) -> tuple[list[Any], dict[str, Any]]:
    """
    Apply "Others" grouping to series columns in both SELECT and GROUP BY clauses.

    Each series column is replaced by
    ``CASE WHEN <condition> THEN <column> ELSE 'Others' END``. Once the
    series limit is reached, the remaining series therefore collapse into a
    single "Others" bucket.

    Args:
        select_exprs: List of SELECT expressions to modify
        groupby_all_columns: Dict of GROUP BY columns to modify
        groupby_series_columns: Dict of series columns to apply Others grouping to
        condition_factory: Function that takes (col_name, original_expr) and returns
            the condition for when to keep original value vs use "Others"

    Returns:
        Tuple of (modified_select_exprs, modified_groupby_all_columns)
    """

    def _others_case(condition: Any, expr: Any) -> Any:
        # Pass the WHEN tuple positionally. The list form ``case([...])`` is
        # deprecated in SQLAlchemy 1.4 and removed in 2.0.
        return sa.case((condition, expr), else_=sa.literal("Others"))

    # Modify SELECT expressions
    modified_select_exprs = []
    for expr in select_exprs:
        if hasattr(expr, "name") and expr.name in groupby_series_columns:
            condition = condition_factory(expr.name, expr)
            case_expr = self.make_sqla_column_compatible(
                _others_case(condition, expr), expr.name
            )
            modified_select_exprs.append(case_expr)
        else:
            modified_select_exprs.append(expr)

    # Modify GROUP BY expressions
    modified_groupby_all_columns = {}
    for col_name, gby_expr in groupby_all_columns.items():
        if col_name in groupby_series_columns:
            condition = condition_factory(col_name, gby_expr)
            # Don't apply make_sqla_column_compatible to GROUP BY expressions.
            # When make_sqla_column_compatible adds a label to the expression,
            # it can cause SQLAlchemy to incorrectly render string literals
            # without quotes in the GROUP BY clause (e.g., "ELSE Others"
            # instead of "ELSE 'Others'")
            modified_groupby_all_columns[col_name] = _others_case(condition, gby_expr)
        else:
            modified_groupby_all_columns[col_name] = gby_expr

    return modified_select_exprs, modified_groupby_all_columns
def dttm_sql_literal(self, dttm: datetime, col: "TableColumn") -> str:
    """Convert a datetime object into a SQL expression string for ``col``.

    Resolution order:

    1. the engine spec's ``convert_dttm`` for the column type, when the column
       has a type and the spec returns a non-empty expression;
    2. the column's ``python_date_format``, falling back to the database
       extra ``python_date_format_by_column_name`` mapping;
    3. a quoted ``'%Y-%m-%d %H:%M:%S.%f'`` string.

    The special formats ``epoch_s`` and ``epoch_ms`` produce an unquoted
    integer number of seconds / milliseconds since the Unix epoch, derived
    from ``datetime.timestamp()`` (so a naive datetime is interpreted in the
    process's local time zone, as ``timestamp()`` does).

    :param dttm: the datetime to render
    :param col: the column the literal will be compared against
    :return: a SQL expression string
    """
    sql = (
        self.db_engine_spec.convert_dttm(col.type, dttm, db_extra=self.db_extra)
        if col.type
        else None
    )
    if sql:
        return sql

    tf = col.python_date_format

    # Fallback to the per-column default format configured on the database.
    if not tf and self.db_extra:
        tf = self.db_extra.get("python_date_format_by_column_name", {}).get(
            col.column_name
        )

    if tf == "epoch_s":
        return str(int(dttm.timestamp()))

    if tf == "epoch_ms":
        # Keep millisecond precision. Whole seconds are taken from the
        # microsecond-free value, whose timestamp is an exact integer (and so
        # is floored correctly before 1970). The milliseconds come from the
        # integer ``microsecond`` field, which avoids float rounding.
        whole_seconds = int(dttm.replace(microsecond=0).timestamp())
        return str(whole_seconds * 1000 + dttm.microsecond // 1000)

    if tf:
        return f"'{dttm.strftime(tf)}'"

    return f"""'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'"""
def get_time_filter(  # pylint: disable=too-many-arguments
    self,
    time_col: "TableColumn",
    start_dttm: Optional[sa.DateTime],
    end_dttm: Optional[sa.DateTime],
    time_grain: Optional[str] = None,
    label: Optional[str] = "__time",
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
    """Build the time-range predicate for ``time_col``.

    With a ``time_grain`` the column is rendered via its
    ``get_timestamp_expression``; otherwise via
    ``convert_tbl_column_to_sqla_col``. The range is half-open: the start
    bound is inclusive (``>=``) and the end bound exclusive (``<``). Each
    bound is only added when truthy, and is rendered as an engine text clause
    built from ``dttm_sql_literal``.

    :return: the AND of the bound conditions (an empty ``and_()`` when
        neither bound is set)
    """
    if time_grain:
        time_expr = time_col.get_timestamp_expression(
            time_grain=time_grain,
            label=label,
            template_processor=template_processor,
        )
    else:
        time_expr = self.convert_tbl_column_to_sqla_col(
            time_col, label=label, template_processor=template_processor
        )

    bounds = []
    if start_dttm:
        lower_sql = self.dttm_sql_literal(start_dttm, time_col)
        bounds.append(time_expr >= self.db_engine_spec.get_text_clause(lower_sql))
    if end_dttm:
        upper_sql = self.dttm_sql_literal(end_dttm, time_col)
        bounds.append(time_expr < self.db_engine_spec.get_text_clause(upper_sql))
    return and_(*bounds)
def values_for_column(  # pylint: disable=too-many-locals
    self,
    column_name: str,
    limit: int = 10000,
    denormalize_column: bool = False,
) -> list[Any]:
    """Return the distinct values of one column of this dataset.

    The query applies the dataset's fetch-values predicate (when configured)
    and its row-level security filters.

    :param column_name: name of the column to read
    :param limit: maximum number of distinct values; a falsy value disables
        the LIMIT clause
    :param denormalize_column: when True, ``column_name`` is first passed
        through the engine spec's ``denormalize_name``
    :return: the values, with NaN replaced by None so they serialize to JSON
    :raises KeyError: if the (possibly denormalized) name is not a column of
        this dataset
    """
    dialect = self.database.get_dialect()
    if denormalize_column:
        lookup_name = self.database.db_engine_spec.denormalize_name(
            dialect, column_name
        )
    else:
        lookup_name = column_name

    columns_by_name = {tbl_col.column_name: tbl_col for tbl_col in self.columns}
    target = columns_by_name[lookup_name]

    processor = self.get_template_processor()
    from_clause, cte = self.get_from_clause(processor)

    # A fixed alias keeps the dataframe column name deterministic: with
    # DISTINCT some dialects invent a random alias, others uppercase names.
    projection = target.get_sqla_col(template_processor=processor).label(
        "column_values"
    )
    query = sa.select([projection]).select_from(from_clause).distinct()

    if limit:
        query = query.limit(limit)
    if self.fetch_values_predicate:
        query = query.where(
            self.get_fetch_values_predicate(template_processor=processor)
        )
    if row_filters := self.get_sqla_row_level_filters(template_processor=processor):
        query = query.where(and_(*row_filters))

    with self.database.get_sqla_engine() as engine:
        compiled = query.compile(engine, compile_kwargs={"literal_binds": True})
        sql = self._apply_cte(str(compiled), cte)
        # pylint: disable=protected-access
        if engine.dialect.identifier_preparer._double_percents:
            sql = sql.replace("%%", "%")
        sql = self.database.mutate_sql_based_on_config(sql)

        with engine.connect() as con:
            frame = pd.read_sql_query(sql=self.text(sql), con=con)
            # NaN is not JSON-serializable; use None instead.
            frame = frame.replace({np.nan: None})
            return frame["column_values"].to_list()
def validate_expression(
    self,
    expression: str,
    expression_type: SqlExpressionType = SqlExpressionType.WHERE,
) -> ValidationResultDict:
    """
    Check whether a SQL expression is accepted by this datasource's database.

    The expression is rendered through the template processor, embedded in
    a probe query matching ``expression_type``, and executed. Any exception
    raised along the way is reported as a single validation error.

    :param expression: SQL expression to validate
    :param expression_type: where the expression is used (column, metric,
        where, having)
    :return: dict with ``valid`` and a list of error annotations
    """
    from superset.sql_validators.base import SQLValidationAnnotation

    try:
        processor = self.get_template_processor()
        rendered = self._process_expression_template(expression, processor)
        from_clause, cte = self.get_from_clause(processor)
        probe = self._build_validation_query(rendered, expression_type)
        return self._execute_validation_query(
            probe, from_clause, cte or "", processor, rendered
        )
    except Exception as ex:  # pylint: disable=broad-except
        # Prefer the DB-API error wrapped by SQLAlchemy, when there is one.
        annotation = SQLValidationAnnotation(
            message=str(getattr(ex, "orig", ex)),
            line_number=1,
            start_column=0,
            end_column=len(expression),
        )
        return ValidationResultDict(valid=False, errors=[annotation.to_dict()])
def _process_expression_template(
self, expression: str, tp: Optional[BaseTemplateProcessor]
) -> str:
"""Process expression through template processor. Raises on error."""
if not tp:
return expression
if hasattr(tp, "process_template"):
return tp.process_template(expression)
return expression
def _build_validation_query(
    self, expression: str, expression_type: SqlExpressionType
) -> Select:
    """Wrap ``expression`` in a probe SELECT according to its usage.

    COLUMN and METRIC expressions become a labelled projection, WHERE
    expressions a WHERE clause, and HAVING expressions a HAVING clause over a
    dummy grouped constant.

    :raises ValueError: for an unsupported ``expression_type``
    """
    if expression_type == SqlExpressionType.COLUMN:
        return sa.select([sa.literal_column(expression).label("test_col")])

    if expression_type == SqlExpressionType.METRIC:
        return sa.select([sa.literal_column(expression).label("test_metric")])

    if expression_type == SqlExpressionType.WHERE:
        return sa.select([sa.literal(1)]).where(sa.text(expression))

    if expression_type == SqlExpressionType.HAVING:
        grouped = sa.select([sa.literal("A").label("dummy")]).group_by(
            sa.text("dummy")
        )
        return grouped.having(sa.text(expression))

    raise ValueError(f"Unsupported expression type: {expression_type}")
def _execute_validation_query(
    self,
    validation_query: Select,
    tbl: TableClause | Alias,
    cte: str,
    tp: Optional[BaseTemplateProcessor],
    expression: str,
) -> ValidationResultDict:
    """Run a validation probe against the database and report success.

    The probe is scoped to ``tbl``, restricted with ``WHERE false`` so that
    it yields no rows, and combined with the row-level security filters.
    It is compiled with literal binds, has ``cte`` prepended and database
    SQL mutation applied, then executed. ``expression`` is accepted for
    interface compatibility but is not used here. Execution errors
    propagate to the caller.
    """
    probe = validation_query.select_from(tbl).where(sa.literal(False))

    if row_filters := self.get_sqla_row_level_filters(template_processor=tp):
        probe = probe.where(and_(*row_filters))

    with self.database.get_sqla_engine() as engine:
        compiled = probe.compile(engine, compile_kwargs={"literal_binds": True})
        sql = self.database.mutate_sql_based_on_config(
            self._apply_cte(str(compiled), cte)
        )
        with engine.connect() as con:
            con.execute(self.text(sql))

    return ValidationResultDict(valid=True, errors=[])
def get_timestamp_expression(
    self,
    column: dict[str, Any],
    time_grain: Optional[str],
    label: Optional[str] = None,
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> Union[TimestampExpression, Label]:
    """
    Build a time-grain expression for a column described by a dict.

    :param column: column description; its ``column_name`` and ``type`` keys
        are read
    :param time_grain: optional time grain, e.g. P1Y, handed to the engine spec
    :param label: alias for the result; defaults to ``utils.DTTM_ALIAS``
    :param template_processor: when given, ``column_name`` is rendered as a
        template and used as a raw SQL expression rather than a plain column
        reference
    :return: the engine spec's timestamp expression, made label-compatible
    """
    spec = self.db_engine_spec.get_column_spec(column.get("type"))
    sqla_type = spec.sqla_type if spec else sa.DateTime

    if template_processor:
        rendered = template_processor.process_template(column["column_name"])
        base_col = sa.literal_column(rendered, type_=sqla_type)
    else:
        base_col = sa.column(column.get("column_name"), type_=sqla_type)

    time_expr = self.db_engine_spec.get_timestamp_expr(base_col, None, time_grain)
    return self.make_sqla_column_compatible(time_expr, label or utils.DTTM_ALIAS)
def convert_tbl_column_to_sqla_col(
    self,
    tbl_column: "TableColumn",
    label: Optional[str] = None,
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
    """Turn a dataset column into a labelled SQLAlchemy column element.

    Calculated columns (non-empty ``expression``) are emitted as a literal
    SQL expression, rendered through ``template_processor`` when one is
    given. Physical columns are emitted as a column reference.

    :param tbl_column: the dataset column
    :param label: output alias; defaults to the column name
    :param template_processor: optional processor for calculated columns
    :return: the label-compatible column element
    """
    resolved_label = label or tbl_column.column_name
    # NOTE(review): the type lookup uses ``self.type`` (this datasource's
    # type), not ``tbl_column.type`` -- confirm this is intended.
    column_spec = self.db_engine_spec.get_column_spec(
        self.type, db_extra=self.db_extra
    )
    sqla_type = column_spec.sqla_type if column_spec else None

    expression = tbl_column.expression
    if not expression:
        sqla_col = sa.column(tbl_column.column_name, type_=sqla_type)
    else:
        if template_processor:
            expression = template_processor.process_template(expression)
        sqla_col = literal_column(expression, type_=sqla_type)

    return self.make_sqla_column_compatible(sqla_col, resolved_label)
def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements # noqa: C901
    self,
    apply_fetch_values_predicate: bool = False,
    columns: Optional[list[Column]] = None,
    extras: Optional[dict[str, Any]] = None,
    filter: Optional[  # pylint: disable=redefined-builtin
        list[utils.QueryObjectFilterClause]
    ] = None,
    from_dttm: Optional[datetime] = None,
    granularity: Optional[str] = None,
    groupby: Optional[list[Column]] = None,
    inner_from_dttm: Optional[datetime] = None,
    inner_to_dttm: Optional[datetime] = None,
    is_rowcount: bool = False,
    is_timeseries: bool = True,
    metrics: Optional[list[Metric]] = None,
    orderby: Optional[list[OrderBy]] = None,
    order_desc: bool = True,
    to_dttm: Optional[datetime] = None,
    series_columns: Optional[list[Column]] = None,
    series_limit: Optional[int] = None,
    series_limit_metric: Optional[Metric] = None,
    group_others_when_limit_reached: bool = False,
    row_limit: Optional[int] = None,
    row_offset: Optional[int] = None,
    timeseries_limit: Optional[int] = None,
    timeseries_limit_metric: Optional[Metric] = None,
    time_shift: Optional[str] = None,
) -> SqlaQuery:
    """Querying any sqla table from this common interface.

    Builds a SQLAlchemy SELECT for this datasource from query-object parts:

    * metrics, ORDER BY terms and groupby/raw columns become SELECT and
      GROUP BY expressions;
    * the granularity column adds a timestamp projection (timeseries) and
      time-range filters, unless a Jinja template already handled them;
    * filters become WHERE predicates, or HAVING predicates when they target
      a saved metric; ``extras`` may add raw WHERE/HAVING SQL;
    * ORDER BY, LIMIT and OFFSET are applied;
    * a series limit is enforced either with a joined subquery or with a
      prequery of top groups, optionally folding the remaining series into
      an "Others" bucket;
    * with ``is_rowcount`` the query is wrapped in ``SELECT COUNT(*)``.

    :param filter: filter clauses; ``None`` is treated as no filters
    :return: a ``SqlaQuery`` bundling the query, its CTE, expected labels,
        prequeries, extra cache keys and applied/rejected filter columns
    :raises QueryObjectValidationError: on an invalid query object (missing
        time column, unknown metric/column, empty filter list, etc.)
    """
    if granularity not in self.dttm_cols and granularity is not None:
        granularity = self.main_dttm_col

    extras = extras or {}
    time_grain = extras.get("time_grain_sqla")

    # DB-specific quoting for identifiers
    with self.database.get_sqla_engine() as engine:
        quote = engine.dialect.identifier_preparer.quote

    template_kwargs = {
        "columns": columns,
        "from_dttm": from_dttm.isoformat() if from_dttm else None,
        "groupby": groupby,
        "metrics": metrics,
        "row_limit": row_limit,
        "row_offset": row_offset,
        "time_column": granularity,
        "time_grain": time_grain,
        "to_dttm": to_dttm.isoformat() if to_dttm else None,
        "table_columns": [col.column_name for col in self.columns],
        "filter": filter,
    }
    columns = columns or []
    groupby = groupby or []
    rejected_adhoc_filters_columns: list[Union[str, ColumnTyping]] = []
    applied_adhoc_filters_columns: list[Union[str, ColumnTyping]] = []
    db_engine_spec = self.db_engine_spec
    series_column_labels = [
        db_engine_spec.make_label_compatible(column)
        for column in utils.get_column_names(
            columns=series_columns or [],
        )
    ]
    # deprecated, to be removed in 2.0
    if is_timeseries and timeseries_limit:
        series_limit = timeseries_limit
    series_limit_metric = series_limit_metric or timeseries_limit_metric

    template_kwargs.update(self.template_params_dict)
    extra_cache_keys: list[Any] = []
    template_kwargs["extra_cache_keys"] = extra_cache_keys
    # Filled in by the template processor with the filters a Jinja template
    # handled itself, so they are not applied a second time below.
    removed_filters: list[str] = []
    applied_template_filters: list[str] = []
    template_kwargs["removed_filters"] = removed_filters
    template_kwargs["applied_filters"] = applied_template_filters
    template_processor = self.get_template_processor(**template_kwargs)
    prequeries: list[str] = []
    orderby = orderby or []
    need_groupby = bool(metrics is not None or groupby)
    metrics = metrics or []

    columns_by_name: dict[str, "TableColumn"] = {
        col.column_name: col for col in self.columns
    }
    quoted_columns_by_name = {quote(k): v for k, v in columns_by_name.items()}

    metrics_by_name: dict[str, "SqlMetric"] = {
        m.metric_name: m for m in self.metrics
    }

    if not granularity and is_timeseries:
        raise QueryObjectValidationError(
            _(
                "Datetime column not provided as part table configuration "
                "and is required by this type of chart"
            )
        )
    if not metrics and not columns and not groupby:
        raise QueryObjectValidationError(_("Empty query?"))

    metrics_exprs: list[ColumnElement] = []
    for metric in metrics:
        if utils.is_adhoc_metric(metric):
            assert isinstance(metric, dict)
            metrics_exprs.append(
                self.adhoc_metric_to_sqla(
                    metric=metric,
                    columns_by_name=columns_by_name,
                    template_processor=template_processor,
                )
            )
        elif isinstance(metric, str) and metric in metrics_by_name:
            metrics_exprs.append(
                metrics_by_name[metric].get_sqla_col(
                    template_processor=template_processor
                )
            )
        else:
            raise QueryObjectValidationError(
                _("Metric '%(metric)s' does not exist", metric=metric)
            )

    if metrics_exprs:
        main_metric_expr = metrics_exprs[0]
    else:
        main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
        main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)

    # To ensure correct handling of the ORDER BY labeling we need to reference the
    # metric instance if defined in the SELECT clause.
    # use the key of the ColumnClause for the expected label
    metrics_exprs_by_label = {m.key: m for m in metrics_exprs}
    metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}

    # Since orderby may use adhoc metrics, too; we need to process them first
    orderby_exprs: list[ColumnElement] = []
    for orig_col, ascending in orderby:  # noqa: B007
        col: Union[AdhocMetric, ColumnElement] = orig_col
        if isinstance(col, dict):
            col = cast(AdhocMetric, col)
            if col.get("sqlExpression"):
                col["sqlExpression"] = self._process_orderby_expression(
                    expression=col["sqlExpression"],
                    database_id=self.database_id,
                    engine=self.database.backend,
                    schema=self.schema,
                    template_processor=template_processor,
                )
            if utils.is_adhoc_metric(col):
                # add adhoc sort by column to columns_by_name if not exists
                col = self.adhoc_metric_to_sqla(
                    col,
                    columns_by_name,
                    processed=True,
                )
                # use the existing instance, if possible
                col = metrics_exprs_by_expr.get(str(col), col)
                need_groupby = True
        elif col in metrics_exprs_by_label:
            col = metrics_exprs_by_label[col]
            need_groupby = True
        elif col in metrics_by_name:
            col = metrics_by_name[col].get_sqla_col(
                template_processor=template_processor
            )
            need_groupby = True
        elif col in columns_by_name:
            col = self.convert_tbl_column_to_sqla_col(
                columns_by_name[col], template_processor=template_processor
            )

        if isinstance(col, ColumnElement):
            orderby_exprs.append(col)
        else:
            # Could not convert a column reference to valid ColumnElement
            raise QueryObjectValidationError(
                _("Unknown column used in orderby: %(col)s", col=orig_col)
            )

    select_exprs: list[Union[Column, Label]] = []
    groupby_all_columns = {}
    groupby_series_columns = {}

    # filter out the pseudo column __timestamp from columns
    columns = [col for col in columns if col != utils.DTTM_ALIAS]
    dttm_col = columns_by_name.get(granularity) if granularity else None

    if need_groupby:
        # dedup columns while preserving order
        columns = groupby or columns
        for selected in columns:
            if isinstance(selected, str):
                # if groupby field/expr equals granularity field/expr
                if selected == granularity:
                    table_col = columns_by_name[selected]
                    outer = table_col.get_timestamp_expression(
                        time_grain=time_grain,
                        label=selected,
                        template_processor=template_processor,
                    )
                # if groupby field equals a selected column
                elif selected in columns_by_name:
                    outer = self.convert_tbl_column_to_sqla_col(
                        columns_by_name[selected],
                        template_processor=template_processor,
                    )
                else:
                    selected = self._process_select_expression(
                        expression=selected,
                        database_id=self.database_id,
                        engine=self.database.backend,
                        schema=self.schema,
                        template_processor=template_processor,
                    )
                    outer = literal_column(f"({selected})")
                    outer = self.make_sqla_column_compatible(outer, selected)
            else:
                outer = self.adhoc_column_to_sqla(
                    col=selected,
                    template_processor=template_processor,
                )
            groupby_all_columns[outer.name] = outer
            if (
                is_timeseries and not series_column_labels
            ) or outer.name in series_column_labels:
                groupby_series_columns[outer.name] = outer
            select_exprs.append(outer)
    elif columns:
        for selected in columns:
            if is_adhoc_column(selected):
                _sql = selected["sqlExpression"]
                _column_label = selected["label"]
            elif isinstance(selected, str):
                _sql = quote(selected)
                _column_label = selected

            selected = self._process_select_expression(
                expression=_sql,
                database_id=self.database_id,
                engine=self.database.backend,
                schema=self.schema,
                template_processor=template_processor,
            )

            select_exprs.append(
                self.convert_tbl_column_to_sqla_col(
                    quoted_columns_by_name[selected],
                    template_processor=template_processor,
                    label=_column_label,
                )
                if selected in quoted_columns_by_name
                else self.make_sqla_column_compatible(
                    literal_column(selected), _column_label
                )
            )
        # Raw-columns mode: no aggregation.
        metrics_exprs = []

    time_filters = []

    # Process FROM clause early to populate removed_filters from virtual dataset
    # templates before we decide whether to add time filters
    tbl, cte = self.get_from_clause(template_processor)

    if granularity:
        if granularity not in columns_by_name or not dttm_col:
            raise QueryObjectValidationError(
                _(
                    'Time column "%(col)s" does not exist in dataset',
                    col=granularity,
                )
            )

        if is_timeseries:
            timestamp = dttm_col.get_timestamp_expression(
                time_grain=time_grain, template_processor=template_processor
            )
            # always put timestamp as the first column
            select_exprs.insert(0, timestamp)
            groupby_all_columns[timestamp.name] = timestamp

        # Use main dttm column to support index with secondary dttm columns.
        if (
            self.always_filter_main_dttm
            and self.main_dttm_col in self.dttm_cols
            and self.main_dttm_col != dttm_col.column_name
            and self.main_dttm_col not in removed_filters
        ):
            time_filters.append(
                self.get_time_filter(
                    time_col=columns_by_name[self.main_dttm_col],
                    start_dttm=from_dttm,
                    end_dttm=to_dttm,
                    template_processor=template_processor,
                )
            )

        # Check if time filter should be skipped because it was handled in template.
        # Check both the actual column name and __timestamp alias
        should_skip_time_filter = (
            dttm_col.column_name in removed_filters
            or utils.DTTM_ALIAS in removed_filters
        )

        if not should_skip_time_filter:
            time_filter_column = self.get_time_filter(
                time_col=dttm_col,
                start_dttm=from_dttm,
                end_dttm=to_dttm,
                template_processor=template_processor,
            )
            time_filters.append(time_filter_column)

    # Always remove duplicates by column name, as sometimes `metrics_exprs`
    # can have the same name as a groupby column (e.g. when users use
    # raw columns as custom SQL adhoc metric).
    select_exprs = remove_duplicates(
        select_exprs + metrics_exprs, key=lambda x: x.name
    )

    # Expected output columns
    labels_expected = [c.key for c in select_exprs]

    # Order by columns are "hidden" columns, some databases require them
    # always be present in SELECT if an aggregation function is used
    if not db_engine_spec.allows_hidden_orderby_agg:
        select_exprs = remove_duplicates(select_exprs + orderby_exprs)

    qry = sa.select(select_exprs)

    if groupby_all_columns:
        qry = qry.group_by(*groupby_all_columns.values())

    where_clause_and: list[ColumnElement] = []
    having_clause_and: list[ColumnElement] = []

    # ``filter`` defaults to None; treat that as "no filters" rather than
    # failing with a TypeError.
    for flt in filter or []:
        if not all(flt.get(s) for s in ["col", "op"]):
            continue
        flt_col = flt["col"]
        val = flt.get("val")
        flt_grain = flt.get("grain")
        op = utils.FilterOperator(flt["op"].upper())
        col_obj: Optional["TableColumn"] = None
        sqla_col: Optional[Column] = None
        # Track if this is a filter on a metric (needs HAVING clause)
        is_metric_filter = False

        if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
            col_obj = dttm_col
        elif is_adhoc_column(flt_col):
            try:
                sqla_col = self.adhoc_column_to_sqla(flt_col, force_type_check=True)
                applied_adhoc_filters_columns.append(flt_col)
            except ColumnNotFoundException:
                rejected_adhoc_filters_columns.append(flt_col)
                continue
        else:
            # Check if it's a regular column first
            col_obj = columns_by_name.get(cast(str, flt_col))
            # If not found in columns, check if it's a metric.
            # This supports filtering on metric columns for any chart type.
            if (
                col_obj is None
                and isinstance(flt_col, str)
                and flt_col in metrics_by_name
            ):
                # Convert metric to SQLA column expression
                sqla_col = metrics_by_name[flt_col].get_sqla_col(
                    template_processor=template_processor
                )
                is_metric_filter = True

        # Check if this filter should be skipped because it was handled in
        # template. Special handling for __timestamp alias: check both the
        # alias and the actual column name
        filter_col_name = get_column_name(flt_col)
        should_skip_filter = filter_col_name in removed_filters
        if not should_skip_filter and flt_col == utils.DTTM_ALIAS and col_obj:
            # For __timestamp, also check if the actual datetime column was
            # removed
            should_skip_filter = col_obj.column_name in removed_filters

        if should_skip_filter:
            # Skip generating SQLA filter when the jinja template handles it.
            continue

        if not col_obj and sqla_col is None:
            # Column not found: silently skip. This can happen for removed
            # columns or invalid filters.
            continue

        # Determine which clause list to use: HAVING for metrics, WHERE for columns
        # Metric filters use HAVING clause because they involve aggregate functions
        target_clause_list = having_clause_and if is_metric_filter else where_clause_and

        if col_obj and sqla_col is None:
            if flt_grain:
                sqla_col = col_obj.get_timestamp_expression(
                    time_grain=flt_grain, template_processor=template_processor
                )
            else:
                sqla_col = self.convert_tbl_column_to_sqla_col(
                    tbl_column=col_obj, template_processor=template_processor
                )

        col_type = col_obj.type if col_obj else None
        col_spec = db_engine_spec.get_column_spec(native_type=col_type)
        is_list_target = op in (
            utils.FilterOperator.IN,
            utils.FilterOperator.NOT_IN,
        )

        col_advanced_data_type = col_obj.advanced_data_type if col_obj else ""

        if col_spec and not col_advanced_data_type:
            target_generic_type = col_spec.generic_type
        else:
            target_generic_type = GenericDataType.STRING
        eq = self.filter_values_handler(
            values=val,
            operator=op,
            target_generic_type=target_generic_type,
            target_native_type=col_type,
            is_list_target=is_list_target,
            db_engine_spec=db_engine_spec,
        )
        # Get ADVANCED_DATA_TYPES from config when needed
        ADVANCED_DATA_TYPES = app.config.get("ADVANCED_DATA_TYPES", {})  # noqa: N806
        if (
            col_advanced_data_type != ""
            and feature_flag_manager.is_feature_enabled("ENABLE_ADVANCED_DATA_TYPES")
            and col_advanced_data_type in ADVANCED_DATA_TYPES
        ):
            values = eq if is_list_target else [eq]  # type: ignore
            bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[
                col_advanced_data_type
            ].translate_type(
                {
                    "type": col_advanced_data_type,
                    "values": values,
                }
            )
            if bus_resp["error_message"]:
                raise AdvancedDataTypeResponseError(_(bus_resp["error_message"]))

            target_clause_list.append(
                ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter(
                    sqla_col, op, bus_resp["values"]
                )
            )
        elif is_list_target:
            assert isinstance(eq, (tuple, list))
            if len(eq) == 0:
                raise QueryObjectValidationError(
                    _("Filter value list cannot be empty")
                )
            if len(eq) > len(eq_without_none := [x for x in eq if x is not None]):
                is_null_cond = sqla_col.is_(None)
                # Only add an IN clause when non-NULL values remain. A
                # ``[None]`` list becomes a plain IS NULL, with no empty IN().
                if eq_without_none:
                    cond = or_(is_null_cond, sqla_col.in_(eq_without_none))
                else:
                    cond = is_null_cond
            else:
                cond = sqla_col.in_(eq)
            if op == utils.FilterOperator.NOT_IN:
                cond = ~cond
            target_clause_list.append(cond)
        elif op in {
            utils.FilterOperator.IS_NULL,
            utils.FilterOperator.IS_NOT_NULL,
        }:
            target_clause_list.append(db_engine_spec.handle_null_filter(sqla_col, op))
        elif op == utils.FilterOperator.IS_TRUE:
            target_clause_list.append(
                db_engine_spec.handle_boolean_filter(sqla_col, op, True)
            )
        elif op == utils.FilterOperator.IS_FALSE:
            target_clause_list.append(
                db_engine_spec.handle_boolean_filter(sqla_col, op, False)
            )
        else:
            if (
                op
                not in {
                    utils.FilterOperator.EQUALS,
                    utils.FilterOperator.NOT_EQUALS,
                }
                and eq is None
            ):
                raise QueryObjectValidationError(
                    _(
                        "Must specify a value for filters "
                        "with comparison operators"
                    )
                )
            if op in {
                utils.FilterOperator.EQUALS,
                utils.FilterOperator.NOT_EQUALS,
                utils.FilterOperator.GREATER_THAN,
                utils.FilterOperator.LESS_THAN,
                utils.FilterOperator.GREATER_THAN_OR_EQUALS,
                utils.FilterOperator.LESS_THAN_OR_EQUALS,
            }:
                target_clause_list.append(
                    db_engine_spec.handle_comparison_filter(sqla_col, op, eq)
                )
            elif op in {
                utils.FilterOperator.ILIKE,
                utils.FilterOperator.LIKE,
            }:
                if target_generic_type != GenericDataType.STRING:
                    sqla_col = sa.cast(sqla_col, sa.String)

                if op == utils.FilterOperator.LIKE:
                    target_clause_list.append(sqla_col.like(eq))
                else:
                    target_clause_list.append(sqla_col.ilike(eq))
            elif op in {utils.FilterOperator.NOT_LIKE}:
                if target_generic_type != GenericDataType.STRING:
                    sqla_col = sa.cast(sqla_col, sa.String)
                target_clause_list.append(sqla_col.not_like(eq))
            elif (
                op == utils.FilterOperator.TEMPORAL_RANGE
                and isinstance(eq, str)
                and col_obj is not None
            ):
                _since, _until = get_since_until_from_time_range(
                    time_range=eq,
                    time_shift=time_shift,
                    extras=extras,
                )
                target_clause_list.append(
                    self.get_time_filter(
                        time_col=col_obj,
                        start_dttm=_since,
                        end_dttm=_until,
                        time_grain=flt_grain,
                        label=sqla_col.key,
                        template_processor=template_processor,
                    )
                )
            else:
                raise QueryObjectValidationError(
                    _("Invalid filter operation type: %(op)s", op=op)
                )

    where_clause_and += self.get_sqla_row_level_filters(template_processor)
    if extras:
        where = extras.get("where")
        if where:
            where = self._process_select_expression(
                expression=where,
                database_id=self.database_id,
                engine=self.database.backend,
                schema=self.schema,
                template_processor=template_processor,
            )
            where_clause_and += [self.text(where)]
        having = extras.get("having")
        if having:
            having = self._process_select_expression(
                expression=having,
                database_id=self.database_id,
                engine=self.database.backend,
                schema=self.schema,
                template_processor=template_processor,
            )
            having_clause_and += [self.text(having)]

    if apply_fetch_values_predicate and self.fetch_values_predicate:
        qry = qry.where(
            self.get_fetch_values_predicate(template_processor=template_processor)
        )
    if granularity:
        qry = qry.where(and_(*(time_filters + where_clause_and)))
    else:
        qry = qry.where(and_(*where_clause_and))
    qry = qry.having(and_(*having_clause_and))

    self.make_orderby_compatible(select_exprs, orderby_exprs)

    for col, (_orig_col, ascending) in zip(orderby_exprs, orderby, strict=False):  # noqa: B007
        if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label):
            # if engine does not allow using SELECT alias in ORDER BY
            # revert to the underlying column
            col = col.element

        if (
            db_engine_spec.get_allows_alias_in_select(self.database)
            and db_engine_spec.allows_hidden_cc_in_orderby
            and col.name in [select_col.name for select_col in select_exprs]
        ):
            col = literal_column(quote(col.name))
        direction = sa.asc if ascending else sa.desc
        qry = qry.order_by(direction(col))

    if row_limit:
        qry = qry.limit(row_limit)
    if row_offset:
        qry = qry.offset(row_offset)

    if series_limit and groupby_series_columns:
        if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries:
            # some sql dialects require for order by expressions
            # to also be in the select clause -- others, e.g. vertica,
            # require a unique inner alias
            inner_main_metric_expr = self.make_sqla_column_compatible(
                main_metric_expr, "mme_inner__"
            )
            inner_groupby_exprs = []
            inner_select_exprs = []
            for gby_name, gby_obj in groupby_series_columns.items():
                inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
                inner_groupby_exprs.append(inner)
                inner_select_exprs.append(inner)

            inner_select_exprs += [inner_main_metric_expr]
            subq = sa.select(inner_select_exprs).select_from(tbl)
            inner_time_filter = []

            if dttm_col and not db_engine_spec.time_groupby_inline:
                inner_time_filter = [
                    self.get_time_filter(
                        time_col=dttm_col,
                        start_dttm=inner_from_dttm or from_dttm,
                        end_dttm=inner_to_dttm or to_dttm,
                        template_processor=template_processor,
                    )
                ]
            subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
            subq = subq.group_by(*inner_groupby_exprs)

            ob = inner_main_metric_expr
            if series_limit_metric:
                ob = self._get_series_orderby(
                    series_limit_metric=series_limit_metric,
                    metrics_by_name=metrics_by_name,
                    columns_by_name=columns_by_name,
                    template_processor=template_processor,
                )
            direction = sa.desc if order_desc else sa.asc
            subq = subq.order_by(direction(ob))
            subq = subq.limit(series_limit)

            on_clause = []
            for gby_name, gby_obj in groupby_series_columns.items():
                # in this case the column name, not the alias, needs to be
                # conditionally mutated, as it refers to the column alias in
                # the inner query
                col_name = db_engine_spec.make_label_compatible(gby_name + "__")
                on_clause.append(gby_obj == sa.column(col_name))

            # Use LEFT JOIN when grouping others, INNER JOIN otherwise
            if group_others_when_limit_reached:
                # Create the alias once and reuse it
                subq_alias = subq.alias(SERIES_LIMIT_SUBQ_ALIAS)
                tbl = tbl.join(
                    subq_alias,
                    and_(*on_clause),
                    isouter=True,
                )

                # Rows whose series matched the top-N subquery keep their value;
                # unmatched rows (NULL on the LEFT JOIN side) become "Others".
                def _create_join_condition(col_name: str, expr: Any) -> Any:
                    # Get the corresponding column from the subquery
                    subq_col_name = db_engine_spec.make_label_compatible(
                        col_name + "__"
                    )
                    # Reference the column from the already-created aliased subquery
                    subq_col = subq_alias.c[subq_col_name]
                    return subq_col.is_not(None)

                select_exprs, groupby_all_columns = (
                    self._apply_series_others_grouping(
                        select_exprs,
                        groupby_all_columns,
                        groupby_series_columns,
                        _create_join_condition,
                    )
                )

                # Reconstruct query with modified expressions.
                # NOTE(review): this rebuild drops the ORDER BY, LIMIT and
                # OFFSET applied above; ``_reapply_query_filters`` is said to
                # restore WHERE/HAVING only -- confirm whether ordering and
                # row limits should be re-applied here.
                qry = sa.select(select_exprs)
                if groupby_all_columns:
                    qry = qry.group_by(*groupby_all_columns.values())

                # Re-apply WHERE and HAVING clauses lost during query reconstruction
                qry = self._reapply_query_filters(
                    qry,
                    apply_fetch_values_predicate,
                    template_processor,
                    granularity,
                    time_filters,
                    where_clause_and,
                    having_clause_and,
                )
            else:
                tbl = tbl.join(subq.alias(SERIES_LIMIT_SUBQ_ALIAS), and_(*on_clause))
        else:
            if series_limit_metric:
                orderby = [
                    (
                        self._get_series_orderby(
                            series_limit_metric=series_limit_metric,
                            metrics_by_name=metrics_by_name,
                            columns_by_name=columns_by_name,
                            template_processor=template_processor,
                        ),
                        not order_desc,
                    )
                ]

            # run prequery to get top groups
            prequery_obj: QueryObjectDict = {
                "is_timeseries": False,
                "row_limit": series_limit,
                "metrics": metrics,
                "granularity": granularity,
                "groupby": groupby,
                "from_dttm": inner_from_dttm or from_dttm,
                "to_dttm": inner_to_dttm or to_dttm,
                "filter": filter or [],
                "orderby": orderby,
                "extras": extras,
                "columns": get_non_base_axis_columns(columns),
                "order_desc": True,
            }

            result = self.query(prequery_obj)
            prequeries.append(result.query)
            dimensions = [
                c
                for c in result.df.columns
                if c not in metrics and c in groupby_series_columns
            ]
            top_groups = self._get_top_groups(
                result.df, dimensions, groupby_series_columns, columns_by_name
            )

            if group_others_when_limit_reached:
                # Rows belonging to a top group keep their series values;
                # everything else is folded into "Others".
                def _create_top_groups_condition(col_name: str, expr: Any) -> Any:
                    return top_groups

                select_exprs, groupby_all_columns = (
                    self._apply_series_others_grouping(
                        select_exprs,
                        groupby_all_columns,
                        groupby_series_columns,
                        _create_top_groups_condition,
                    )
                )

                # Reconstruct query with modified expressions.
                # NOTE(review): as in the join branch, this rebuild drops the
                # ORDER BY, LIMIT and OFFSET applied above -- confirm intent.
                qry = sa.select(select_exprs)
                if groupby_all_columns:
                    qry = qry.group_by(*groupby_all_columns.values())

                # Re-apply WHERE and HAVING clauses lost during query reconstruction
                qry = self._reapply_query_filters(
                    qry,
                    apply_fetch_values_predicate,
                    template_processor,
                    granularity,
                    time_filters,
                    where_clause_and,
                    having_clause_and,
                )
            else:
                # Original behavior: filter to only top groups
                qry = qry.where(top_groups)

    qry = qry.select_from(tbl)

    if is_rowcount:
        if not db_engine_spec.allows_subqueries:
            raise QueryObjectValidationError(_("Database does not support subqueries"))
        label = "rowcount"
        col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
        qry = sa.select([col]).select_from(qry.alias("rowcount_qry"))
        labels_expected = [label]

    filter_columns = [flt.get("col") for flt in filter] if filter else []
    rejected_filter_columns = [
        col
        for col in filter_columns
        if col
        and not is_adhoc_column(col)
        and col not in self.column_names
        and col not in applied_template_filters
    ] + rejected_adhoc_filters_columns

    applied_filter_columns = [
        col
        for col in filter_columns
        if col
        and not is_adhoc_column(col)
        and (col in self.column_names or col in applied_template_filters)
    ] + applied_adhoc_filters_columns

    return SqlaQuery(
        applied_template_filters=applied_template_filters,
        cte=cte,
        applied_filter_columns=applied_filter_columns,
        rejected_filter_columns=rejected_filter_columns,
        extra_cache_keys=extra_cache_keys,
        labels_expected=labels_expected,
        sqla_query=qry,
        prequeries=prequeries,
    )