fix(charting): correctly categorize numeric columns with NULL values (#34213)

This commit is contained in:
LisaHusband
2025-07-24 20:46:58 +08:00
committed by GitHub
parent 1e5a4e9bdc
commit 7a1c056374
2 changed files with 194 additions and 2 deletions

View File

@@ -48,7 +48,16 @@ from enum import Enum, IntEnum
from io import BytesIO
from timeit import default_timer
from types import TracebackType
from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypedDict, TypeVar
from typing import (
Any,
Callable,
cast,
NamedTuple,
Optional,
TYPE_CHECKING,
TypedDict,
TypeVar,
)
from urllib.parse import unquote_plus
from zipfile import ZipFile
@@ -120,6 +129,32 @@ InputType = TypeVar("InputType") # pylint: disable=invalid-name
# Matches form-data field names beginning with "adhoc_filters".
ADHOC_FILTERS_REGEX = re.compile("^adhoc_filters")
# Maps SQL column-type patterns to pandas `infer_objects`-style type names.
# Iterated in insertion order by map_sql_type_to_inferred_type with
# first-match-wins semantics, so do not reorder entries casually: a broader
# pattern placed earlier would shadow narrower ones below it.
TYPE_MAPPING = {
    re.compile(r"INT", re.IGNORECASE): "integer",
    re.compile(r"CHAR|TEXT|VARCHAR", re.IGNORECASE): "string",
    re.compile(r"DECIMAL|NUMERIC|FLOAT|DOUBLE", re.IGNORECASE): "floating",
    re.compile(r"BOOL", re.IGNORECASE): "boolean",
    re.compile(r"DATE|TIME", re.IGNORECASE): "datetime64",
}
# Maps an upper-case SQL aggregation function name to the inferred type of
# its result, as consumed by get_metric_type_from_column.
METRIC_MAP_TYPE = {
    "SUM": "floating",
    "AVG": "floating",
    # NOTE(review): COUNT/COUNT_DISTINCT results are integral; they are
    # mapped to "floating" here — confirm downstream consumers expect that.
    "COUNT": "floating",
    "COUNT_DISTINCT": "floating",
    # NOTE(review): "numeric" is not a value pandas' infer_dtype produces;
    # verify the consumer of this mapping handles it.
    "MIN": "numeric",
    "MAX": "numeric",
    "FIRST": "string",
    "LAST": "string",
    "GROUP_CONCAT": "string",
    "ARRAY_AGG": "string",
    "STRING_AGG": "string",
    "MEDIAN": "floating",
    "PERCENTILE": "floating",
    "VARIANCE": "floating",
    "STDDEV": "floating",
}
class AdhocMetricExpressionType(StrEnum):
SIMPLE = "SIMPLE"
@@ -1503,6 +1538,67 @@ def get_column_names_from_metrics(metrics: list[Metric]) -> list[str]:
return [col for col in map(get_column_name_from_metric, metrics) if col]
def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str:
    """
    Translate a SQL column type into a pandas `infer_objects`-style type name.

    Falls back to "string" both when no SQL type is supplied and when no
    pattern in ``TYPE_MAPPING`` matches.

    :param sql_type: raw SQL type to map (e.g. "VARCHAR(255)"), may be empty
    :return: string type recognized by pandas
    """
    if not sql_type:
        # Missing/empty type information: default to "string".
        return "string"
    # First matching pattern in TYPE_MAPPING wins; default to "string".
    return next(
        (
            inferred_type
            for pattern, inferred_type in TYPE_MAPPING.items()
            if pattern.search(sql_type)
        ),
        "string",
    )
def get_metric_type_from_column(column: Any, datasource: BaseDatasource | Query) -> str:
    """
    Determine the inferred metric type for a column of a datasource.

    Looks the column up among the datasource's metrics; if found, parses the
    metric's SQL expression for the outermost aggregation function (e.g.
    ``SUM``, ``COUNT_DISTINCT``) and maps it through ``METRIC_MAP_TYPE`` to a
    pandas-style inferred type string.

    :param column: The column name or identifier to check.
    :param datasource: The datasource containing metrics to search within.
    :return: The inferred metric type as a string, or an empty string if the
        column is not a metric or no recognized aggregation is found.
    :raises AttributeError: if ``datasource`` is None (callers rely on this).
    """
    # Local import — NOTE(review): presumably avoids a circular import at
    # module load time; confirm against the module layout.
    from superset.connectors.sqla.models import SqlMetric

    metric: Optional[SqlMetric] = next(
        (m for m in datasource.metrics if m.metric_name == column),
        None,
    )
    if metric is None:
        return ""
    expression: str = metric.expression or ""
    # Build the alternation from METRIC_MAP_TYPE itself: the previous
    # hard-coded pattern omitted several declared aggregations (MEDIAN,
    # STDDEV, GROUP_CONCAT, ...), leaving their mappings unreachable.
    # Longest-first ordering keeps COUNT from shadowing COUNT_DISTINCT, and
    # IGNORECASE accepts lower-case SQL such as "sum(col)".
    operations = "|".join(sorted(METRIC_MAP_TYPE, key=len, reverse=True))
    match = re.match(rf"({operations})\((.*)\)", expression, re.IGNORECASE)
    if match:
        # Normalize to upper case before the lookup: keys are upper case.
        return METRIC_MAP_TYPE.get(match.group(1).upper(), "")
    logger.warning("Unexpected metric expression type: %s", expression)
    return ""
def extract_dataframe_dtypes(
df: pd.DataFrame,
datasource: BaseDatasource | Query | None = None,
@@ -1533,7 +1629,17 @@ def extract_dataframe_dtypes(
for column in df.columns:
column_object = columns_by_name.get(column)
series = df[column]
inferred_type = infer_dtype(series)
inferred_type: str = ""
if series.isna().all():
sql_type: Optional[str] = ""
if datasource and hasattr(datasource, "columns_types"):
if column in datasource.columns_types:
sql_type = datasource.columns_types.get(column)
inferred_type = map_sql_type_to_inferred_type(sql_type)
else:
inferred_type = get_metric_type_from_column(column, datasource)
else:
inferred_type = infer_dtype(series)
if isinstance(column_object, dict):
generic_type = (
GenericDataType.TEMPORAL

View File

@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock, patch
import pytest
from superset.connectors.sqla.models import SqlMetric
from superset.utils.core import (
get_metric_type_from_column,
map_sql_type_to_inferred_type,
)
def test_column_not_in_datasource():
    """A column absent from the datasource's metrics maps to ""."""
    datasource = MagicMock(metrics=[])
    assert get_metric_type_from_column("non_existent_column", datasource) == ""
def test_column_with_valid_operation():
    """A SUM-based metric resolves to the "floating" inferred type."""
    datasource = MagicMock(
        metrics=[SqlMetric(metric_name="my_column", expression="SUM(my_column)")]
    )
    assert get_metric_type_from_column("my_column", datasource) == "floating"
def test_column_with_invalid_operation():
    """An unrecognized aggregation yields "" and emits one warning."""
    datasource = MagicMock(
        metrics=[SqlMetric(metric_name="my_column", expression="INVALID(my_column)")]
    )
    with patch("superset.utils.core.logger.warning") as warning_mock:
        assert get_metric_type_from_column("my_column", datasource) == ""
        warning_mock.assert_called_once()
def test_empty_datasource():
    """A datasource with no metrics at all yields ""."""
    assert get_metric_type_from_column("my_column", MagicMock(metrics=[])) == ""
def test_column_is_none():
    """A None column never matches any metric and yields ""."""
    assert get_metric_type_from_column(None, MagicMock(metrics=[])) == ""
def test_datasource_is_none():
    """A None datasource propagates as AttributeError."""
    with pytest.raises(AttributeError):
        get_metric_type_from_column("my_column", None)
def test_none_input():
    """A None SQL type falls back to the default "string"."""
    assert map_sql_type_to_inferred_type(None) == "string"
def test_empty_string_input():
    """An empty SQL type falls back to the default "string"."""
    assert map_sql_type_to_inferred_type("") == "string"
def test_recognized_sql_type():
    """An INT column type maps to the pandas "integer" type."""
    assert map_sql_type_to_inferred_type("INT") == "integer"
def test_unrecognized_sql_type():
    """An unknown SQL type defaults to "string"."""
    assert map_sql_type_to_inferred_type("unknown_type") == "string"
def test_sql_type_with_special_chars():
    """A parameterized type like "varchar(255)" still maps to "string"."""
    assert map_sql_type_to_inferred_type("varchar(255)") == "string"