mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
fix(charting): correctly categorize numeric columns with NULL values (#34213)
This commit is contained in:
@@ -48,7 +48,16 @@ from enum import Enum, IntEnum
|
||||
from io import BytesIO
|
||||
from timeit import default_timer
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypedDict, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
)
|
||||
from urllib.parse import unquote_plus
|
||||
from zipfile import ZipFile
|
||||
|
||||
@@ -120,6 +129,32 @@ InputType = TypeVar("InputType") # pylint: disable=invalid-name
|
||||
|
||||
# Matches form-data keys that carry adhoc filter definitions.
ADHOC_FILTERS_REGEX = re.compile("^adhoc_filters")

# (SQL type-name fragment, pandas-style inferred type) pairs. Insertion order
# is significant: lookups are expected to take the first pattern that matches,
# e.g. "DATETIME" should hit the INT-free "DATE|TIME" entry.
_SQL_TYPE_FRAGMENTS: tuple[tuple[str, str], ...] = (
    (r"INT", "integer"),
    (r"CHAR|TEXT|VARCHAR", "string"),
    (r"DECIMAL|NUMERIC|FLOAT|DOUBLE", "floating"),
    (r"BOOL", "boolean"),
    (r"DATE|TIME", "datetime64"),
)

# Compiled, case-insensitive patterns mapping SQL type names to the type
# strings understood by pandas' `infer_objects`/`infer_dtype` machinery.
TYPE_MAPPING = {
    re.compile(fragment, re.IGNORECASE): inferred
    for fragment, inferred in _SQL_TYPE_FRAGMENTS
}

# SQL aggregation name -> pandas-style inferred type for saved metrics.
METRIC_MAP_TYPE = dict(
    SUM="floating",
    AVG="floating",
    COUNT="floating",
    COUNT_DISTINCT="floating",
    MIN="numeric",
    MAX="numeric",
    FIRST="string",
    LAST="string",
    GROUP_CONCAT="string",
    ARRAY_AGG="string",
    STRING_AGG="string",
    MEDIAN="floating",
    PERCENTILE="floating",
    VARIANCE="floating",
    STDDEV="floating",
)
|
||||
|
||||
|
||||
class AdhocMetricExpressionType(StrEnum):
|
||||
SIMPLE = "SIMPLE"
|
||||
@@ -1503,6 +1538,67 @@ def get_column_names_from_metrics(metrics: list[Metric]) -> list[str]:
|
||||
return [col for col in map(get_column_name_from_metric, metrics) if col]
|
||||
|
||||
|
||||
def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str:
    """
    Map a SQL type name to a type string recognized by pandas'
    `infer_objects` method.

    Falls back to ``"string"`` when no SQL type is provided or when the
    type is not recognized.

    :param sql_type: SQL type to map (may be ``None`` or empty)
    :return: string type recognized by pandas
    """
    default = "string"
    if not sql_type:
        # No type information available -- use the default.
        return default

    # Patterns are tried in insertion order; the first match wins.
    matches = (
        inferred
        for pattern, inferred in TYPE_MAPPING.items()
        if pattern.search(sql_type)
    )
    return next(matches, default)
|
||||
|
||||
|
||||
def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query) -> str:
    """
    Determine the metric type from a given column in a datasource.

    This function checks if the specified column is a saved metric in the
    provided datasource. If it is, it extracts the SQL expression associated
    with the metric and identifies the aggregation operation at its head
    (e.g. SUM, COUNT, STDDEV, ...), mapping it to a pandas-style type string
    via ``METRIC_MAP_TYPE``.

    :param column: The column name or identifier to check.
    :param datasource: The datasource containing metrics to search within.
    :return: The inferred metric type as a string, or an empty string if the
        column is not a metric or no recognized operation is found.
    """

    from superset.connectors.sqla.models import SqlMetric

    metric: SqlMetric = next(
        (metric for metric in datasource.metrics if metric.metric_name == column),
        SqlMetric(metric_name=""),
    )

    if metric.metric_name == "":
        # `column` is not a saved metric on this datasource.
        return ""

    expression: str = metric.expression
    if not expression:
        # Saved metric without a SQL expression -- nothing to infer from.
        return ""

    # Accept any leading function call (case-insensitive, since \w matches
    # both cases) rather than a fixed alternation, so every aggregation known
    # to METRIC_MAP_TYPE (e.g. STDDEV, MEDIAN, PERCENTILE) is recognized.
    match = re.match(r"\s*(\w+)\s*\((.*)\)", expression)
    if match:
        operation = match.group(1).upper()
        if operation in METRIC_MAP_TYPE:
            return METRIC_MAP_TYPE[operation]

    logger.warning("Unexpected metric expression type: %s", expression)
    return ""
|
||||
|
||||
|
||||
def extract_dataframe_dtypes(
|
||||
df: pd.DataFrame,
|
||||
datasource: BaseDatasource | Query | None = None,
|
||||
@@ -1533,7 +1629,17 @@ def extract_dataframe_dtypes(
|
||||
for column in df.columns:
|
||||
column_object = columns_by_name.get(column)
|
||||
series = df[column]
|
||||
inferred_type = infer_dtype(series)
|
||||
inferred_type: str = ""
|
||||
if series.isna().all():
|
||||
sql_type: Optional[str] = ""
|
||||
if datasource and hasattr(datasource, "columns_types"):
|
||||
if column in datasource.columns_types:
|
||||
sql_type = datasource.columns_types.get(column)
|
||||
inferred_type = map_sql_type_to_inferred_type(sql_type)
|
||||
else:
|
||||
inferred_type = get_metric_type_from_column(column, datasource)
|
||||
else:
|
||||
inferred_type = infer_dtype(series)
|
||||
if isinstance(column_object, dict):
|
||||
generic_type = (
|
||||
GenericDataType.TEMPORAL
|
||||
|
||||
New file: tests/unit_tests/utils/map_type_tests.py (86 lines)
@@ -0,0 +1,86 @@
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.utils.core import (
|
||||
get_metric_type_from_column,
|
||||
map_sql_type_to_inferred_type,
|
||||
)
|
||||
|
||||
|
||||
def test_column_not_in_datasource():
    """A column absent from the datasource's metrics yields an empty string."""
    datasource = MagicMock(metrics=[])
    assert get_metric_type_from_column("non_existent_column", datasource) == ""
|
||||
|
||||
|
||||
def test_column_with_valid_operation():
    """A metric whose expression aggregates with SUM maps to "floating"."""
    sum_metric = SqlMetric(metric_name="my_column", expression="SUM(my_column)")
    datasource = MagicMock(metrics=[sum_metric])
    assert get_metric_type_from_column("my_column", datasource) == "floating"
|
||||
|
||||
|
||||
def test_column_with_invalid_operation():
    """An unrecognized aggregation returns "" and logs exactly one warning."""
    bad_metric = SqlMetric(metric_name="my_column", expression="INVALID(my_column)")
    datasource = MagicMock(metrics=[bad_metric])
    with patch("superset.utils.core.logger.warning") as mock_warning:
        assert get_metric_type_from_column("my_column", datasource) == ""
        mock_warning.assert_called_once()
|
||||
|
||||
|
||||
def test_empty_datasource():
    """A datasource with no metrics at all yields an empty string."""
    datasource = MagicMock(metrics=[])
    assert get_metric_type_from_column("my_column", datasource) == ""
|
||||
|
||||
|
||||
def test_column_is_none():
    """Passing ``None`` as the column name is handled gracefully."""
    datasource = MagicMock(metrics=[])
    assert get_metric_type_from_column(None, datasource) == ""
|
||||
|
||||
|
||||
def test_datasource_is_none():
    """A ``None`` datasource raises AttributeError (no ``metrics`` attribute)."""
    with pytest.raises(AttributeError):
        get_metric_type_from_column("my_column", None)
|
||||
|
||||
|
||||
def test_none_input():
    """``None`` falls back to the default "string" type."""
    assert map_sql_type_to_inferred_type(None) == "string"
|
||||
|
||||
|
||||
def test_empty_string_input():
    """An empty SQL type falls back to the default "string" type."""
    assert map_sql_type_to_inferred_type("") == "string"
|
||||
|
||||
|
||||
def test_recognized_sql_type():
    """A known SQL type name maps to its pandas equivalent."""
    assert map_sql_type_to_inferred_type("INT") == "integer"
|
||||
|
||||
|
||||
def test_unrecognized_sql_type():
    """An unknown SQL type name falls back to "string"."""
    assert map_sql_type_to_inferred_type("unknown_type") == "string"
|
||||
|
||||
|
||||
def test_sql_type_with_special_chars():
    """Type names with precision/length suffixes still match, e.g. varchar(255)."""
    assert map_sql_type_to_inferred_type("varchar(255)") == "string"
|
||||
Reference in New Issue
Block a user