Files
superset2/superset/migrations/shared/migrate_viz/query_functions.py
2025-07-24 09:40:49 -07:00

1508 lines
48 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json # noqa: TID251
import math
from enum import Enum
from typing import Any, Dict, List, Optional, Union
class RollingType(Enum):
Mean = "mean"
Sum = "sum"
Std = "std"
Cumsum = "cumsum"
class ComparisonType(Enum):
Values = "values"
Difference = "difference"
Percentage = "percentage"
Ratio = "ratio"
class DatasourceType(Enum):
Table = "table"
Query = "query"
Dataset = "dataset"
SlTable = "sl_table"
SavedQuery = "saved_query"
UNARY_OPERATORS = ["IS NOT NULL", "IS NULL"]
BINARY_OPERATORS = [
"==",
"!=",
">",
"<",
">=",
"<=",
"ILIKE",
"LIKE",
"NOT LIKE",
"REGEX",
"TEMPORAL_RANGE",
]
SET_OPERATORS = ["IN", "NOT IN"]
unary_operator_set = set(UNARY_OPERATORS)
binary_operator_set = set(BINARY_OPERATORS)
set_operator_set = set(SET_OPERATORS)
class DatasourceKey:
def __init__(self, key: str):
id_str, type_str = key.split("__", 1)
self.id = int(id_str)
# Default to Table; if type_str is 'query', then use Query.
self.type = DatasourceType.Table
if type_str == "query":
self.type = DatasourceType.Query
def __str__(self) -> str:
return f"{self.id}__{self.type.value}"
def to_object(self) -> dict[str, Any]:
return {
"id": self.id,
"type": self.type.value,
}
TIME_COMPARISON_SEPARATOR = "__"
DTTM_ALIAS = "__timestamp"
NO_TIME_RANGE = "No filter"
EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS = [
"relative_start",
"relative_end",
"time_grain_sqla",
]
EXTRA_FORM_DATA_APPEND_KEYS = [
"adhoc_filters",
"filters",
"interactive_groupby",
"interactive_highlight",
"interactive_drilldown",
"custom_form_data",
]
EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS = {
"granularity": "granularity",
"granularity_sqla": "granularity",
"time_column": "time_column",
"time_grain": "time_grain",
"time_range": "time_range",
}
EXTRA_FORM_DATA_OVERRIDE_REGULAR_KEYS = list(
EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.keys()
)
EXTRA_FORM_DATA_OVERRIDE_KEYS = (
EXTRA_FORM_DATA_OVERRIDE_REGULAR_KEYS + EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS
)
def ensure_is_array(value: Optional[Union[List[Any], Any]] = None) -> List[Any]:
"""
Ensure a nullable value input is a list. Useful when consolidating
input format from a select control.
"""
if value is None:
return []
return value if isinstance(value, list) else [value]
def is_empty(value: Any) -> bool:
"""
A simple implementation similar to lodash's isEmpty.
Returns True if value is None or an empty collection.
"""
if value is None:
return True
if isinstance(value, (list, dict, str, tuple, set)):
return len(value) == 0
return False
def is_saved_metric(metric: Any) -> bool:
"""Return True if metric is a saved metric (str)."""
return isinstance(metric, str)
def is_adhoc_metric_simple(metric: Any) -> bool:
"""Return True if metric dict is a simple adhoc metric."""
return (
not isinstance(metric, str)
and isinstance(metric, dict)
and metric.get("expressionType") == "SIMPLE"
)
def is_adhoc_metric_sql(metric: Any) -> bool:
"""Return True if metric dict is an SQL adhoc metric."""
return (
not isinstance(metric, str)
and isinstance(metric, dict)
and metric.get("expressionType") == "SQL"
)
def is_query_form_metric(metric: Any) -> bool:
"""Return True if metric is of any query form type."""
return (
is_saved_metric(metric)
or is_adhoc_metric_simple(metric)
or is_adhoc_metric_sql(metric)
)
def get_metric_label(metric: Any | dict[str, Any]) -> Any | dict[str, Any]:
"""
Get the label for a given metric.
Args:
metric (dict): The metric object.
Returns:
dict: The label of the metric.
"""
if is_saved_metric(metric):
return metric
if "label" in metric and metric["label"]:
return metric["label"]
if is_adhoc_metric_simple(metric):
column_name = metric["column"].get("columnName") or metric["column"].get(
"column_name"
)
return f"{metric['aggregate']}({column_name})"
return metric["sqlExpression"]
def extract_extra_metrics(form_data: Dict[str, Any]) -> List[Any]:
"""
Extract extra metrics from the form data.
Args:
form_data (Dict[str, Any]): The query form data.
Returns:
List[Any]: A list of extra metrics.
"""
groupby = form_data.get("groupby", [])
timeseries_limit_metric = form_data.get("timeseries_limit_metric")
x_axis_sort = form_data.get("x_axis_sort")
metrics = form_data.get("metrics", [])
extra_metrics = []
limit_metric = (
ensure_is_array(timeseries_limit_metric)[0] if timeseries_limit_metric else None
)
if (
not groupby
and limit_metric
and get_metric_label(limit_metric) == x_axis_sort
and not any(get_metric_label(metric) == x_axis_sort for metric in metrics)
):
extra_metrics.append(limit_metric)
return extra_metrics
def get_metric_offsets_map(
form_data: dict[str, List[str]], query_object: dict[str, List[str]]
) -> dict[str, Any]:
"""
Return a dictionary mapping metric offset-labels to metric-labels.
Args:
form_data (Dict[str, List[str]]): The form data containing time comparisons.
query_object (Dict[str, List[str]]): The query object containing metrics.
Returns:
Dict[str, str]: A dictionary with offset-labels as keys and metric-labels
as values.
"""
query_metrics = ensure_is_array(query_object.get("metrics", []))
time_offsets = ensure_is_array(form_data.get("time_compare", []))
metric_labels = [get_metric_label(metric) for metric in query_metrics]
metric_offset_map = {}
for metric in metric_labels:
for offset in time_offsets:
key = f"{metric}{TIME_COMPARISON_SEPARATOR}{offset}"
metric_offset_map[key] = metric
return metric_offset_map
def is_time_comparison(form_data: dict[str, Any], query_object: dict[str, Any]) -> bool:
"""
Determine if the query involves a time comparison.
Args:
form_data (dict): The form data containing query parameters.
query_object (dict): The query object.
Returns:
bool: True if it is a time comparison, False otherwise.
"""
comparison_type = form_data.get("comparison_type")
metric_offset_map = get_metric_offsets_map(form_data, query_object)
return (
comparison_type in [ct.value for ct in ComparisonType]
and len(metric_offset_map) > 0
)
def ensure_is_int(value: Any, default_value: Any = None) -> Any | float:
"""
Convert the given value to an integer.
If conversion fails, returns default_value if provided,
otherwise returns NaN (as float('nan')).
"""
try:
val = int(str(value))
except (ValueError, TypeError):
return default_value if default_value is not None else float("nan")
return val
def is_physical_column(column: Any = None) -> bool:
"""Return True if column is a physical column (string)."""
return isinstance(column, str)
def is_adhoc_column(column: Any = None) -> bool:
"""Return True if column is an adhoc column (object with SQL expression)."""
if type(column) is not dict:
return False
return (
"sqlExpression" in column.keys()
and column["sqlExpression"] is not None
and "label" in column.keys()
and column["label"] is not None
and ("sqlExpression" not in column.keys() or column["expressionType"] == "SQL")
)
def is_query_form_column(column: Any) -> bool:
"""Return True if column is either physical or adhoc."""
return is_physical_column(column) or is_adhoc_column(column)
def is_x_axis_set(form_data: dict[str, Any]) -> bool:
"""Return True if the x_axis is specified in form_data."""
return is_query_form_column(form_data.get("x_axis"))
def get_x_axis_column(form_data: dict[str, Any]) -> Optional[Any]:
"""Return x_axis column."""
if not (form_data.get("granularity_sqla") or form_data.get("x_axis")):
return None
if is_x_axis_set(form_data):
return form_data.get("x_axis")
return DTTM_ALIAS
def get_column_label(column: Any) -> Optional[str]:
"""Return the string label for a column."""
if is_physical_column(column):
return column
if column and column.get("label"):
return column.get("label")
return column.get("sqlExpression", None)
def get_x_axis_label(form_data: dict[str, Any]) -> Optional[str]:
"""Return the x_axis label from form_data."""
if col := get_x_axis_column(form_data):
return get_column_label(col)
return None
def time_compare_pivot_operator(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> Optional[dict[str, Any]]:
"""
A post-processing factory function for pivot operations.
Args:
form_data: The form data containing configuration
query_object: The query object with series and columns information
Returns:
Dictionary with pivot operation configuration or None
"""
metric_offset_map = get_metric_offsets_map(form_data, query_object)
x_axis_label = get_x_axis_label(form_data)
columns = (
query_object.get("series_columns")
if query_object.get("series_columns") is not None
else query_object.get("columns")
)
if is_time_comparison(form_data, query_object) and x_axis_label:
# Create aggregates dictionary from metric offset map
metrics = list(metric_offset_map.values()) + list(metric_offset_map.keys())
aggregates = {
metric: {"operator": "mean"} # use 'mean' aggregates to avoid dropping NaN
for metric in metrics
}
return {
"operation": "pivot",
"options": {
"index": [x_axis_label],
"columns": [get_column_label(col) for col in ensure_is_array(columns)],
"drop_missing_columns": not form_data.get("show_empty_columns"),
"aggregates": aggregates,
},
}
return None
def pivot_operator(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> Optional[dict[str, Any]]:
"""
Construct a pivot operator configuration for post-processing.
This function extracts metric labels (including extra metrics) from the query object
and form data, and retrieves the x-axis label. If both an x-axis label and at
least one metric label are present, it builds a pivot configuration that sets
the index as the x-axis label, transforms the columns via get_column_label,
and creates dummy 'mean' aggregates for each metric.
Args:
form_data (dict): The form data containing query parameters.
query_object (dict): The base query object containing metrics
and column information.
Returns:
dict or None: A dict with the pivot operator configuration
if the conditions are met,
otherwise None.
"""
metric_labels = [
*ensure_is_array(query_object.get("metrics", [])),
*extract_extra_metrics(form_data),
]
metric_labels = [get_metric_label(metric) for metric in metric_labels]
x_axis_label = get_x_axis_label(form_data)
columns = (
query_object.get("series_columns")
if query_object.get("series_columns") is not None
else query_object.get("columns")
)
if x_axis_label and metric_labels:
cols_list = [get_column_label(col) for col in ensure_is_array(columns)]
return {
"operation": "pivot",
"options": {
"index": [x_axis_label],
"columns": cols_list,
# Create 'dummy' mean aggregates to assign cell values in pivot table
# using the 'mean' aggregates to avoid dropping NaN values
"aggregates": {
metric: {"operator": "mean"} for metric in metric_labels
},
"drop_missing_columns": not form_data.get("show_empty_columns"),
},
}
return None
def normalize_order_by(query_object: dict[str, Any]) -> dict[str, Any]:
"""
Normalize the orderby clause in the query object.
If the "orderby" key already contains a valid clause (a list whose first element
is a list of two elements, where the first element is truthy and the second a bool),
the original query_object is returned. Otherwise, the function creates a copy of
query_object, removes invalid orderby-related keys, and sets an orderby clause based
on available keys: "series_limit_metric", "legacy_order_by", or the first metric in
the "metrics" list. The sorting order is determined by the negation of "order_desc".
Args:
query_object (dict): The query object containing orderby and related keys.
Returns:
dict: A modified query object with a normalized "orderby" clause.
"""
if (
isinstance(query_object.get("orderby"), list)
and len(query_object.get("orderby", [])) > 0
):
# ensure a valid orderby clause
orderby_clause = query_object["orderby"][0]
if (
isinstance(orderby_clause, list)
and len(orderby_clause) == 2
and orderby_clause[0]
and isinstance(orderby_clause[1], bool)
):
return query_object
# remove invalid orderby keys from a copy
clone_query_object = query_object.copy()
clone_query_object.pop("series_limit_metric", None)
clone_query_object.pop("legacy_order_by", None)
clone_query_object.pop("order_desc", None)
clone_query_object.pop("orderby", None)
is_asc = not query_object.get("order_desc", False)
if query_object.get("series_limit_metric") is not None and query_object.get(
"series_limit_metric"
):
return {
**clone_query_object,
"orderby": [[query_object["series_limit_metric"], is_asc]],
}
# todo: Removed `legacy_order_by` after refactoring
if query_object.get("legacy_order_by") is not None and query_object.get(
"legacy_order_by"
):
return {
**clone_query_object,
"orderby": [[query_object["legacy_order_by"], is_asc]],
}
if (
isinstance(query_object.get("metrics"), list)
and len(query_object.get("metrics", [])) > 0
):
return {**clone_query_object, "orderby": [[query_object["metrics"][0], is_asc]]}
return clone_query_object
def remove_duplicates(items: Any, hash_func: Any = None) -> list[Any]:
"""
Remove duplicate items from a list.
Args:
items: List of items to deduplicate
hash_func: Optional function to generate a hash for comparison
Returns:
List with duplicates removed
"""
if hash_func:
seen = set()
result = []
for x in items:
item_hash = hash_func(x)
if item_hash not in seen:
seen.add(item_hash)
result.append(x)
return result
else:
# Using Python's built-in uniqueness for lists
return list(dict.fromkeys(items)) # Preserves order in Python 3.7+
def extract_fields_from_form_data(
rest_form_data: dict[str, Any],
query_field_aliases: dict[str, Any],
query_mode: Any | str,
) -> tuple[list[Any], list[Any], list[Any]]:
"""
Extract fields from form data based on aliases and query mode.
Args:
rest_form_data (dict): The residual form data.
query_field_aliases (dict): A mapping of key aliases.
query_mode (str): The query mode, e.g. 'aggregate' or 'raw'.
Returns:
tuple: A tuple of three lists: (columns, metrics, orderby)
"""
columns = []
metrics = []
orderby = []
for key, value in rest_form_data.items():
if value is None:
continue
normalized_key = query_field_aliases.get(key, key)
if query_mode == "aggregate" and normalized_key == "columns":
continue
if query_mode == "raw" and normalized_key in ["groupby", "metrics"]:
continue
if normalized_key == "groupby":
normalized_key = "columns"
if normalized_key == "metrics":
metrics.extend(value if isinstance(value, list) else [value])
elif normalized_key == "columns":
columns.extend(value if isinstance(value, list) else [value])
elif normalized_key == "orderby":
orderby.extend(value if isinstance(value, list) else [value])
return columns, metrics, orderby
def extract_query_fields(
form_data: dict[Any, Any], aliases: Any = None
) -> Union[dict[str, Any]]:
"""
Extract query fields from form data.
Args:
form_data: Form data residual
aliases: Query field aliases
Returns:
Dictionary with columns, metrics, and orderby fields
"""
query_field_aliases = {
"metric": "metrics",
"metric_2": "metrics",
"secondary_metric": "metrics",
"x": "metrics",
"y": "metrics",
"size": "metrics",
"all_columns": "columns",
"series": "groupby",
"order_by_cols": "orderby",
}
if aliases:
query_field_aliases.update(aliases)
query_mode = form_data.pop("query_mode", None)
rest_form_data = form_data
columns, metrics, orderby = extract_fields_from_form_data(
rest_form_data, query_field_aliases, query_mode
)
result: dict[str, Any] = {
"columns": remove_duplicates(
[col for col in columns if col != ""], get_column_label
),
"orderby": None,
}
if query_mode != "raw":
result["metrics"] = remove_duplicates(metrics, get_metric_label)
else:
result["metrics"] = None
if orderby:
result["orderby"] = []
for item in orderby:
if isinstance(item, str):
try:
result["orderby"].append(json.loads(item))
except Exception as err:
raise ValueError("Found invalid orderby options") from err
else:
result["orderby"].append(item)
return result
def extract_extras(form_data: dict[str, Any]) -> dict[str, Any]:
"""
Extract extras from the form_data analogous to the TS version.
"""
applied_time_extras: dict[str, Any] = {}
filters: list[Any] = []
extras: dict[str, Any] = {}
extract: dict[str, Any] = {
"filters": filters,
"extras": extras,
"applied_time_extras": applied_time_extras,
}
# Mapping reserved columns to query field names
reserved_columns_to_query_field = {
"__time_range": "time_range",
"__time_col": "granularity_sqla",
"__time_grain": "time_grain_sqla",
"__granularity": "granularity",
}
extra_filters = form_data.get("extra_filters", [])
for filter_item in extra_filters:
col = filter_item.get("col")
# Check if filter col is reserved
if col in reserved_columns_to_query_field:
query_field = reserved_columns_to_query_field[col]
# Assign the filter value to the extract dict
extract[query_field] = filter_item.get("val")
applied_time_extras[col] = filter_item.get("val")
else:
filters.append(filter_item)
# SQL: set extra properties based on TS logic
if "time_grain_sqla" in form_data.keys() or "time_grain_sqla" in extract.keys():
# If time_grain_sqla is set in form_data, use it
# Otherwise, use the value from extract
value = form_data.get("time_grain_sqla") or form_data.get("time_grain_sqla")
extras["time_grain_sqla"] = value
extract["granularity"] = (
extract.get("granularity_sqla")
or form_data.get("granularity")
or form_data.get("granularity_sqla")
)
# Remove temporary keys
extract.pop("granularity_sqla", None)
extract.pop("time_grain_sqla", None)
if extract["granularity"] is None:
extract.pop("granularity", None)
return extract
def is_defined(x: Any) -> bool:
"""
Returns True if x is not None.
This is equivalent to checking that x is neither null nor undefined in TypeScript.
"""
return x is not None
def sanitize_clause(clause: str) -> str:
"""
Sanitize a SQL clause. If the clause contains '--', append a newline.
Then wrap the clause in parentheses.
"""
if clause is None:
return ""
sanitized_clause = clause
if "--" in clause:
sanitized_clause = clause + "\n"
return f"({sanitized_clause})"
def is_unary_operator(operator: Any | str) -> bool:
"""Return True if operator is unary."""
return operator in unary_operator_set
def is_binary_operator(operator: Any | str) -> bool:
"""Return True if operator is binary."""
return operator in binary_operator_set
def is_set_operator(operator: Any | str) -> bool:
"""Return True if operator is a set operator."""
return operator in set_operator_set
def is_unary_adhoc_filter(filter_item: dict[str, Any]) -> bool:
"""Return True if the filter's operator is unary."""
return is_unary_operator(filter_item.get("operator"))
def is_binary_adhoc_filter(filter_item: dict[str, Any]) -> bool:
"""Return True if the filter's operator is binary."""
return is_binary_operator(filter_item.get("operator"))
def convert_filter(filter_item: dict[str, Any]) -> dict[str, Any]:
"""Convert an adhoc filter to a query clause dict."""
subject = filter_item.get("subject")
if is_unary_adhoc_filter(filter_item):
operator = filter_item.get("operator")
return {"col": subject, "op": operator}
if is_binary_adhoc_filter(filter_item):
operator = filter_item.get("operator")
val = filter_item.get("comparator")
result = {"col": subject, "op": operator}
if val is not None:
result["val"] = val
return result
operator = filter_item.get("operator")
val = filter_item.get("comparator")
result = {"col": subject, "op": operator}
if val is not None:
result["val"] = val
return result
def is_simple_adhoc_filter(filter_item: dict[str, Any]) -> bool:
"""Return True if the filter is a simple adhoc filter."""
return filter_item.get("expressionType") == "SIMPLE"
def process_filters(form_data: dict[str, Any]) -> dict[str, Any]:
"""
Process filters from form_data:
- Split adhoc_filters according to clause and expression type.
- Build simple filter and freeform SQL clauses for WHERE/HAVING.
- Place freeform clauses into extras.
"""
adhoc_filters = form_data.get("adhoc_filters", [])
extras = form_data.get("extras", {})
filters_list = form_data.get("filters", [])
# Copy filters_list into simple_where
simple_where = filters_list[:]
freeform_where = []
freeform_having = []
if where := form_data.get("where"):
freeform_where.append(where)
for filter_item in adhoc_filters:
clause = filter_item.get("clause")
if is_simple_adhoc_filter(filter_item):
filter_clause = convert_filter(filter_item)
if clause == "WHERE":
simple_where.append(filter_clause)
else:
sql_expression = filter_item.get("sqlExpression")
if clause == "WHERE":
freeform_where.append(sql_expression)
else:
freeform_having.append(sql_expression)
extras["having"] = " AND ".join([sanitize_clause(s) for s in freeform_having])
extras["where"] = " AND ".join([sanitize_clause(s) for s in freeform_where])
return {
"filters": simple_where,
"extras": extras,
}
def override_extra_form_data(
query_object: dict[str, Any], override_form_data: dict[str, Any]
) -> dict[str, Any]:
"""
Override parts of the query_object with values from override_form_data.
Mimics the behavior of the TypeScript function:
- For keys in EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS,
if set in override_form_data, assign the value in query_object
under the mapped target key.
- For keys in EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS,
if present in override_form_data, add them to query_object['extras'].
"""
# Create a copy of the query object
overridden_form_data = query_object.copy()
# Ensure extras is a mutable copy of what's in query_object (or an empty dict)
overridden_extras = overridden_form_data.get("extras", {}).copy()
# Process regular mappings
for key, target in EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.items():
value = override_form_data.get(key)
if value is not None:
overridden_form_data[target] = value
# Process extra keys
for key in EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS:
if key in override_form_data:
overridden_extras[key] = override_form_data[key]
if overridden_extras:
overridden_form_data["extras"] = overridden_extras
return overridden_form_data
def build_query_object(
form_data: dict[str, Any], query_fields: Any = None
) -> dict[str, Any]:
"""
Build a query object from form data.
Args:
form_data: Dictionary containing form data
query_fields: Optional query field aliases
Returns:
Dictionary representing the query object
"""
# Extract fields from form_data with defaults
annotation_layers = form_data.get("annotation_layers", [])
extra_form_data = form_data.get("extra_form_data", {})
time_range = form_data.get("time_range")
since = form_data.get("since")
until = form_data.get("until")
row_limit = form_data.get("row_limit")
row_offset = form_data.get("row_offset")
order_desc = form_data.get("order_desc")
limit: Any | int = form_data.get("limit")
timeseries_limit_metric = form_data.get("timeseries_limit_metric")
granularity = form_data.get("granularity")
url_params = form_data.get("url_params", {})
custom_params = form_data.get("custom_params", {})
series_columns = form_data.get("series_columns")
series_limit: Any | str = form_data.get("series_limit")
series_limit_metric = form_data.get("series_limit_metric")
# Create residual_form_data by removing extracted fields
residual_form_data = {
k: v
for k, v in form_data.items()
if k
not in [
"annotation_layers",
"extra_form_data",
"time_range",
"since",
"until",
"row_limit",
"row_offset",
"order_desc",
"limit",
"timeseries_limit_metric",
"granularity",
"url_params",
"custom_params",
"series_columns",
"series_limit",
"series_limit_metric",
]
}
# Extract fields from extra_form_data
append_adhoc_filters = (
extra_form_data.get("adhoc_filters", []) if extra_form_data else []
)
append_filters = extra_form_data.get("filters", []) if extra_form_data else []
custom_form_data = (
extra_form_data.get("custom_form_data", {}) if extra_form_data else {}
)
overrides = (
{
k: v
for k, v in extra_form_data.items()
if k not in ["adhoc_filters", "filters", "custom_form_data"]
}
if extra_form_data
else {}
)
# Convert to numeric values
numeric_row_limit: Any = float(row_limit) if row_limit is not None else None
numeric_row_offset: Any = float(row_offset) if row_offset is not None else None
# Extract query fields
extracted_fields = extract_query_fields(residual_form_data, query_fields)
metrics = extracted_fields.get("metrics")
columns = extracted_fields.get("columns")
orderby = extracted_fields.get("orderby")
# Collect and process filters
extras = extract_extras(form_data)
extra_filters = extras.get("filters", [])
filter_form_data = {
"filters": extra_filters + append_filters,
"adhoc_filters": (form_data.get("adhoc_filters") or []) + append_adhoc_filters,
}
extras_and_filters = process_filters({**form_data, **extras, **filter_form_data})
def normalize_series_limit_metric(metric: Any) -> Optional[Any]:
if is_query_form_metric(metric):
return metric
return None
# Build the query object
query_object: dict[Any, Any] = {
**extras,
**extras_and_filters,
"columns": columns,
"metrics": metrics,
"orderby": orderby,
"annotation_layers": annotation_layers,
"series_columns": series_columns,
"row_limit": (
None
if row_limit is None or math.isnan(numeric_row_limit)
else int(numeric_row_limit)
),
"series_limit": (
series_limit
if series_limit is not None
else (int(limit) if is_defined(limit) else 0)
),
"order_desc": True if order_desc is None else order_desc,
"url_params": url_params,
"custom_params": custom_params,
}
row_offset = (
None
if row_offset is None or math.isnan(numeric_row_offset)
else numeric_row_offset
)
temp = normalize_series_limit_metric(series_limit_metric)
series_limit_metric = temp if temp is not None else timeseries_limit_metric
for key, value in [
("time_range", time_range),
("since", since),
("until", until),
("granularity", granularity),
("series_limit_metric", series_limit_metric),
("row_offset", row_offset),
]:
if value is not None:
query_object[key] = value
# Override extra form data
query_object = override_extra_form_data(query_object, overrides)
query_object = {k: v for k, v in query_object.items() if v is not None}
# Return the final query object with custom form data
return {**query_object, "custom_form_data": custom_form_data}
def omit(d: dict[str, Any], keys: list[Any]) -> dict[str, Any]:
"""
Return a copy of dictionary d without the specified keys.
"""
return {k: v for k, v in d.items() if k not in keys}
def normalize_time_column(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> dict[str, Any]:
"""
If x_axis is set in form_data, find its index in query_object's columns and update
that column
with timeGrain and columnType information. The updated query_object omits
the 'is_timeseries' key.
"""
if not is_x_axis_set(form_data):
return query_object
_columns: list[Any] = query_object.get("columns", [])
_extras = query_object.get("extras", {})
# Create a shallow copy of columns
mutated_columns = list(_columns)
x_axis: Any = form_data.get("x_axis")
axis_idx = None
# Find the index of the x_axis in the columns list
for idx, col in enumerate(_columns):
if (
is_physical_column(col) and is_physical_column(x_axis) and col == x_axis
) or (
is_adhoc_column(col)
and is_adhoc_column(x_axis)
and col.get("sqlExpression") == x_axis.get("sqlExpression")
):
axis_idx = idx
break
if axis_idx is not None and axis_idx > -1 and x_axis and isinstance(_columns, list):
if is_adhoc_column(_columns[axis_idx]):
# Update the adhoc column with additional keys.
updated = dict(_columns[axis_idx])
updated["columnType"] = "BASE_AXIS"
if _extras:
if "time_grain_sqla" in _extras.keys():
updated["timeGrain"] = _extras["time_grain_sqla"]
mutated_columns[axis_idx] = updated
else:
# For physical columns, create a new column entry.
mutated_columns[axis_idx] = {
"columnType": "BASE_AXIS",
"sqlExpression": x_axis,
"label": x_axis,
"expressionType": "SQL",
}
if _extras:
if "time_grain_sqla" in _extras.keys():
mutated_columns[axis_idx]["timeGrain"] = _extras["time_grain_sqla"]
# Create a new query object without the 'is_timeseries' key.
new_query_object = omit(query_object, ["is_timeseries"])
new_query_object["columns"] = mutated_columns
return new_query_object
# Fallback: return the original query_object
return query_object
def build_query_context(
form_data: dict[str, Any], options: Any = None
) -> dict[str, Any]:
# Handle options based on type
def default_build_query(x: Any) -> list[Any]:
return [x]
if callable(options):
query_fields = {}
build_query = options
elif options:
query_fields = options.get("query_fields", {})
build_query = options.get("build_query", lambda x: [x])
else:
query_fields = {}
build_query = default_build_query
queries = build_query(build_query_object(form_data, query_fields))
for query in queries:
if isinstance(query.get("post_processing"), list):
query["post_processing"] = [p for p in query["post_processing"] if p]
if is_x_axis_set(form_data):
queries = [normalize_time_column(form_data, query) for query in queries]
return {
"datasource": DatasourceKey(form_data["datasource"]).to_object(),
"force": form_data.get("force", False),
"queries": queries,
"form_data": form_data,
"result_format": form_data.get("result_format", "json"),
"result_type": form_data.get("result_type", "full"),
}
def rolling_window_operator(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> Optional[dict[str, Any]]:
"""
Builds a post-processing configuration for a rolling window.
- If it's a time comparison, compute the columns from the metric offsets map.
- Otherwise, derive the columns from query_object.metrics.
- Then, based on the rolling_type, return a configuration dict.
"""
# Determine the columns to operate on
if is_time_comparison(form_data, query_object):
metrics_map = get_metric_offsets_map(form_data, query_object)
columns = list(metrics_map.values()) + list(metrics_map.keys())
else:
metrics = ensure_is_array(query_object.get("metrics"))
columns = []
for metric in metrics:
if isinstance(metric, str):
columns.append(metric)
elif isinstance(metric, dict):
columns.append(metric.get("label"))
# Build a columns map from the list of columns
columns_map = {col: col for col in columns if col is not None}
# Determine the operation based on rolling_type
rolling_type = form_data.get("rolling_type")
if rolling_type == RollingType.Cumsum.value:
return {
"operation": "cum",
"options": {
"operator": "sum",
"columns": columns_map,
},
}
if rolling_type in [
RollingType.Sum.value,
RollingType.Mean.value,
RollingType.Std.value,
]:
return {
"operation": "rolling",
"options": {
"rolling_type": rolling_type,
"window": ensure_is_int(form_data.get("rolling_periods"), 1),
"min_periods": ensure_is_int(form_data.get("min_periods"), 0),
"columns": columns_map,
},
}
return None
def time_compare_operator(
form_data: Dict[str, Any], query_object: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Returns a post-processing configuration for time comparison if applicable.
If time comparison is enabled and the comparison type is not 'values',
builds a configuration dict that specifies the operation and options.
"""
comparison_type = form_data.get("comparison_type")
metric_offset_map = get_metric_offsets_map(form_data, query_object)
if (
is_time_comparison(form_data, query_object)
and comparison_type != ComparisonType.Values.value
):
return {
"operation": "compare",
"options": {
"source_columns": list(metric_offset_map.values()),
"compare_columns": list(metric_offset_map.keys()),
"compare_type": comparison_type,
"drop_original_columns": True,
},
}
return None
def resample_operator(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> Any | dict[str, Any]:
"""
Returns a post-processing configuration for resampling if the required
resample_method and resample_rule are provided in form_data.
"""
resample_zero_fill = form_data.get("resample_method") == "zerofill"
resample_method = (
"asfreq" if resample_zero_fill else form_data.get("resample_method")
)
resample_rule = form_data.get("resample_rule")
if resample_method and resample_rule:
return {
"operation": "resample",
"options": {
"method": resample_method,
"rule": resample_rule,
"fill_value": 0 if resample_zero_fill else None,
},
}
return None
def rename_operator(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> Optional[dict[str, Any]]:
"""
Produces a post-processing configuration to rename columns based on the criteria:
1) Only one metric exists.
2) There is at least one dimension (series_columns or columns).
3) An x-axis label exists.
4) If time comparison is enabled and its comparison type is not one of
[difference, ratio, percentage].
5) The form data contains a truthy 'truncate_metric' flag.
Additionally, if time comparison is active and the comparison type is 'values',
the operator renames the metric with the corresponding offset label.
"""
metrics: Any = ensure_is_array(query_object.get("metrics"))
columns = ensure_is_array(
query_object.get("series_columns")
if query_object.get("series_columns") is not None
else query_object.get("columns")
)
truncate_metric = form_data.get("truncate_metric")
x_axis_label = get_x_axis_label(form_data)
# Check conditions for renaming
if (
len(metrics) == 1
and len(columns) > 0
and x_axis_label
and not (
is_time_comparison(form_data, query_object)
and form_data.get("comparison_type")
in {
ComparisonType.Difference.value,
ComparisonType.Ratio.value,
ComparisonType.Percentage.value,
}
)
and truncate_metric is not None
and bool(truncate_metric)
):
rename_pairs: Any = []
if (
is_time_comparison(form_data, query_object)
and form_data.get("comparison_type") == ComparisonType.Values.value
):
metric_offset_map = get_metric_offsets_map(form_data, query_object)
time_offsets = ensure_is_array(form_data.get("time_compare"))
for metric_with_offset in list(metric_offset_map.keys()):
offset_label = next(
(offset for offset in time_offsets if offset in metric_with_offset),
None,
)
rename_pairs.append((metric_with_offset, offset_label))
rename_pairs.append((get_metric_label(metrics[0]), None))
return {
"operation": "rename",
"options": {
"columns": dict(rename_pairs),
"level": 0,
"inplace": True,
},
}
return None
def contribution_operator(
form_data: dict[str, Any], query_object: dict[str, Any], time_shifts: Any
) -> Optional[dict[str, Any]]:
"""
Returns a post-processing configuration for contribution if
form_data.contributionMode is truthy.
"""
if form_data.get("contributionMode"):
return {
"operation": "contribution",
"options": {
"orientation": form_data.get("contributionMode"),
"time_shifts": time_shifts,
},
}
return None
def sort_operator(
form_data: Dict[str, Any], query_object: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Build a sort post-processing configuration if the conditions are met.
Conditions:
- form_data.x_axis_sort and form_data.x_axis_sort_asc are defined.
- The sort key exists in sortableLabels.
- groupby is empty.
If the sort key matches the x-axis label, sort using the index.
Otherwise, sort by the provided sort key.
"""
# Build the sortable labels list
sortable_labels: list[Any] = [
get_x_axis_label(form_data),
]
sortable_labels += [
get_metric_label(m) for m in ensure_is_array(form_data.get("metrics"))
]
sortable_labels += [get_metric_label(m) for m in extract_extra_metrics(form_data)]
# Filter out any falsy values
sortable_labels = [label for label in sortable_labels if label]
# Check the required conditions.
if (
is_defined(form_data.get("x_axis_sort"))
and is_defined(form_data.get("x_axis_sort_asc"))
and form_data.get("x_axis_sort") in sortable_labels
and is_empty(form_data.get("groupby")) ##
):
if form_data.get("x_axis_sort") == get_x_axis_label(form_data):
return {
"operation": "sort",
"options": {
"is_sort_index": True,
"ascending": form_data.get("x_axis_sort_asc"),
},
}
return {
"operation": "sort",
"options": {
"by": form_data.get("x_axis_sort"),
"ascending": form_data.get("x_axis_sort_asc"),
},
}
return None
def flatten_operator(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> dict[str, Any]:
"""
Returns a post-processing configuration that indicates a flatten operation.
"""
return {"operation": "flatten"}
def prophet_operator(
form_data: dict[str, Any], query_object: dict[str, Any]
) -> Any | dict[str, Any]:
"""
Returns a post-processing configuration for prophet forecasting
if forecast is enabled and an x-axis label is present.
"""
x_axis_label = get_x_axis_label(form_data)
if form_data.get("forecastEnabled") and x_axis_label:
try:
periods = int(form_data.get("forecastPeriods", 0))
except (TypeError, ValueError):
periods = 0
try:
confidence_interval = float(form_data.get("forecastInterval", 0))
except (TypeError, ValueError):
confidence_interval = 0.0
return {
"operation": "prophet",
"options": {
"time_grain": form_data.get("time_grain_sqla"),
"periods": periods,
"confidence_interval": confidence_interval,
"yearly_seasonality": form_data.get("forecastSeasonalityYearly"),
"weekly_seasonality": form_data.get("forecastSeasonalityWeekly"),
"daily_seasonality": form_data.get("forecastSeasonalityDaily"),
"index": x_axis_label,
},
}
return None
def rank_operator(
form_data: dict[str, Any], query_object: dict[str, Any], options: dict[str, Any]
) -> dict[str, Any]:
"""
Returns a post-processing configuration for ranking.
Args:
form_data (dict): The form data for the query.
query_object (dict): The base query object.
options (dict): Options for the rank operator.
Returns:
dict: A configuration dict with the ranking operation.
"""
options_dict = options
if options_dict.get("group_by") is None:
options_dict.pop("group_by", None)
return {
"operation": "rank",
"options": options_dict,
}
def drop_none_values(options: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in options.items() if v is not None}
def histogram_operator(
form_data: dict[str, str | Any], query_object: dict[str, Any]
) -> dict[str, Any]:
"""
Build a histogram operator configuration.
This function extracts histogram parameters from the form data and builds an
operator configuration for generating a histogram. It attempts to parse the
'bins' value as an integer (defaulting to 5 if parsing fails), retrieves the
column and groupby details by using get_column_label, and collects additional
options such as cumulative and normalize flags.
Args:
form_data (dict): Dictionary containing histogram parameters
such as 'bins', 'column','cumulative', 'groupby', and 'normalize'.
query_object (dict): Dictionary representing the query object
Returns:
dict: A dictionary with keys "operation" and "options"
that defines the histogram operator.
"""
bins: Any | int = form_data.get("bins")
column = form_data.get("column")
cumulative = form_data.get("cumulative")
groupby = form_data.get("groupby", [])
normalize = form_data.get("normalize")
try:
parsed_bins = int(bins)
except (TypeError, ValueError):
parsed_bins = 5
parsed_column = get_column_label(column)
parsed_groupby = [get_column_label(g) for g in groupby]
options = {
"column": parsed_column,
"groupby": parsed_groupby,
"bins": parsed_bins,
"cumulative": cumulative,
"normalize": normalize,
}
result = {"operation": "histogram", "options": drop_none_values(options)}
return result
def retain_form_data_suffix(
form_data: dict[str, Any], control_suffix: str
) -> dict[str, Any]:
"""
Retain keys from the form data that end with a specified suffix
and remove the suffix from them.
The function creates a new form data dictionary. For keys ending
with the provided
control_suffix, it removes the suffix and assigns the corresponding
value. If a key does
not end with the suffix and is not already set in the new dictionary
(i.e. via a suffixed key), it is retained as-is.
Args:
form_data (dict): The original form data dictionary.
control_suffix (str): The suffix string to look for in keys.
Returns:
dict: A new dictionary containing the retained and modified keys.
"""
new_form_data = {}
entries = sorted(
form_data.items(),
key=lambda kv: 1 if kv[0].endswith(control_suffix) else 0,
reverse=True,
)
for key, value in entries:
if key.endswith(control_suffix):
new_form_data[key[: -len(control_suffix)]] = value
if not key.endswith(control_suffix) and key not in new_form_data.keys():
new_form_data[key] = value
return new_form_data
def remove_form_data_suffix(
form_data: dict[str, Any], control_suffix: str
) -> dict[str, Any]:
"""
Remove keys from the form data that end with a specified suffix.
This function builds a new dictionary containing only those key-value pairs
where the key does NOT end with the given control_suffix.
Args:
form_data (dict): The original form data dictionary.
control_suffix (str): The suffix indicating which keys should be removed.
Returns:
dict: A new dictionary with the keys ending with control_suffix removed.
"""
new_form_data = {}
for key, value in form_data.items():
if not key.endswith(control_suffix):
new_form_data[key] = value
return new_form_data