mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
fix: Add aliases and groupby list to chart schemas (#38740)
This commit is contained in:
committed by
GitHub
parent
972e15e601
commit
14b1b456e1
@@ -158,16 +158,31 @@ class TestXYChartConfig:
|
||||
)
|
||||
assert len(config.y) == 2
|
||||
|
||||
def test_group_by_duplicate_with_x_rejected(self) -> None:
|
||||
"""Test that group_by conflicts with x are rejected."""
|
||||
def test_group_by_duplicate_label_with_x_rejected(self) -> None:
|
||||
"""Test that group_by with a custom label conflicting with x is rejected."""
|
||||
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="region"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
group_by=ColumnRef(name="category", label="region"),
|
||||
group_by=[ColumnRef(name="category", label="region")],
|
||||
)
|
||||
|
||||
def test_group_by_same_column_as_x_allowed(self) -> None:
|
||||
"""Test that group_by with the same column name as x is allowed.
|
||||
|
||||
The mapping layer filters these out, so validation should not reject them.
|
||||
"""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="line",
|
||||
group_by=[ColumnRef(name="date")],
|
||||
)
|
||||
assert config.group_by is not None
|
||||
assert config.group_by[0].name == "date"
|
||||
|
||||
def test_realistic_chart_configurations(self) -> None:
|
||||
"""Test realistic chart configurations."""
|
||||
# This should work - COUNT(product_line) != product_line
|
||||
@@ -220,19 +235,34 @@ class TestXYChartConfig:
|
||||
)
|
||||
assert config.kind == "area"
|
||||
|
||||
def test_unknown_fields_rejected(self) -> None:
|
||||
"""Test that unknown fields like 'series' are rejected."""
|
||||
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="bar",
|
||||
series=ColumnRef(name="year"),
|
||||
)
|
||||
def test_unknown_fields_ignored(self) -> None:
|
||||
"""Test that unknown fields are silently ignored (extra='ignore')."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="bar",
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.kind == "bar"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
|
||||
def test_series_alias_accepted(self) -> None:
|
||||
"""Test that 'series' is accepted as alias for 'group_by'."""
|
||||
config = XYChartConfig.model_validate(
|
||||
{
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "territory"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"kind": "bar",
|
||||
"series": {"name": "year"},
|
||||
}
|
||||
)
|
||||
assert config.group_by is not None
|
||||
assert config.group_by[0].name == "year"
|
||||
|
||||
def test_group_by_accepted(self) -> None:
|
||||
"""Test that group_by is the correct field for series grouping."""
|
||||
"""Test that group_by accepts a single ColumnRef (auto-wrapped in list)."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
@@ -241,7 +271,20 @@ class TestXYChartConfig:
|
||||
group_by=ColumnRef(name="year"),
|
||||
)
|
||||
assert config.group_by is not None
|
||||
assert config.group_by.name == "year"
|
||||
assert len(config.group_by) == 1
|
||||
assert config.group_by[0].name == "year"
|
||||
|
||||
def test_group_by_multiple(self) -> None:
|
||||
"""Test that group_by accepts a list of ColumnRefs."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="territory"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="bar",
|
||||
group_by=[ColumnRef(name="year"), ColumnRef(name="region")],
|
||||
)
|
||||
assert config.group_by is not None
|
||||
assert len(config.group_by) == 2
|
||||
|
||||
def test_orientation_horizontal_accepted(self) -> None:
|
||||
"""Test that orientation='horizontal' is accepted for bar charts."""
|
||||
@@ -374,11 +417,12 @@ class TestRowLimit:
|
||||
class TestTableChartConfigExtraFields:
|
||||
"""Test TableChartConfig rejects unknown fields."""
|
||||
|
||||
def test_unknown_fields_rejected(self) -> None:
|
||||
"""Test that unknown fields are rejected."""
|
||||
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product")],
|
||||
foo="bar",
|
||||
)
|
||||
def test_unknown_fields_ignored(self) -> None:
|
||||
"""Test that unknown fields are silently ignored (extra='ignore')."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product")],
|
||||
foo="bar",
|
||||
)
|
||||
assert len(config.columns) == 1
|
||||
assert not hasattr(config, "foo")
|
||||
|
||||
@@ -103,14 +103,15 @@ class TestPieChartConfigSchema:
|
||||
assert config.filters is not None
|
||||
assert len(config.filters) == 1
|
||||
|
||||
def test_pie_config_rejects_extra_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
unknown_field="bad",
|
||||
)
|
||||
def test_pie_config_ignores_extra_fields(self) -> None:
|
||||
config = PieChartConfig(
|
||||
chart_type="pie",
|
||||
dimension=ColumnRef(name="product"),
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.dimension.name == "product"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
|
||||
def test_pie_config_missing_dimension(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
@@ -323,14 +324,15 @@ class TestPivotTableChartConfigSchema:
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
)
|
||||
|
||||
def test_pivot_table_rejects_extra_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
PivotTableChartConfig(
|
||||
chart_type="pivot_table",
|
||||
rows=[ColumnRef(name="product")],
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
def test_pivot_table_ignores_extra_fields(self) -> None:
|
||||
config = PivotTableChartConfig(
|
||||
chart_type="pivot_table",
|
||||
rows=[ColumnRef(name="product")],
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.rows[0].name == "product"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
|
||||
def test_pivot_table_valid_aggregate_functions(self) -> None:
|
||||
for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]:
|
||||
@@ -459,10 +461,10 @@ class TestMixedTimeseriesChartConfigSchema:
|
||||
time_grain="P1M",
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
primary_kind="area",
|
||||
group_by=ColumnRef(name="region"),
|
||||
group_by=[ColumnRef(name="region")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
secondary_kind="scatter",
|
||||
group_by_secondary=ColumnRef(name="channel"),
|
||||
group_by_secondary=[ColumnRef(name="channel")],
|
||||
show_legend=False,
|
||||
x_axis=AxisConfig(title="Date"),
|
||||
y_axis=AxisConfig(title="Revenue", format="$,.2f"),
|
||||
@@ -473,9 +475,9 @@ class TestMixedTimeseriesChartConfigSchema:
|
||||
assert config.secondary_kind == "scatter"
|
||||
assert config.time_grain == "P1M"
|
||||
assert config.group_by is not None
|
||||
assert config.group_by.name == "region"
|
||||
assert config.group_by[0].name == "region"
|
||||
assert config.group_by_secondary is not None
|
||||
assert config.group_by_secondary.name == "channel"
|
||||
assert config.group_by_secondary[0].name == "channel"
|
||||
|
||||
def test_mixed_timeseries_missing_y(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
@@ -502,15 +504,16 @@ class TestMixedTimeseriesChartConfigSchema:
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
)
|
||||
|
||||
def test_mixed_timeseries_rejects_extra_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
def test_mixed_timeseries_ignores_extra_fields(self) -> None:
|
||||
config = MixedTimeseriesChartConfig(
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
unknown_field="bad",
|
||||
)
|
||||
assert config.x.name == "date"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
|
||||
def test_mixed_timeseries_default_row_limit(self) -> None:
|
||||
config = MixedTimeseriesChartConfig(
|
||||
@@ -606,9 +609,9 @@ class TestMapMixedTimeseriesConfig:
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
group_by=ColumnRef(name="region"),
|
||||
group_by=[ColumnRef(name="region")],
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
group_by_secondary=ColumnRef(name="channel"),
|
||||
group_by_secondary=[ColumnRef(name="channel")],
|
||||
)
|
||||
result = map_mixed_timeseries_config(config, dataset_id=1)
|
||||
|
||||
@@ -623,9 +626,9 @@ class TestMapMixedTimeseriesConfig:
|
||||
chart_type="mixed_timeseries",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
group_by=ColumnRef(name="date"), # same as x
|
||||
group_by=[ColumnRef(name="date")], # same as x
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
group_by_secondary=ColumnRef(name="date"), # same as x
|
||||
group_by_secondary=[ColumnRef(name="date")], # same as x
|
||||
)
|
||||
result = map_mixed_timeseries_config(config, dataset_id=1)
|
||||
|
||||
|
||||
@@ -163,12 +163,12 @@ class TestNormalizeXYConfig:
|
||||
"x": {"name": "OrderDate"},
|
||||
"y": [{"name": "Sales", "aggregate": "SUM"}],
|
||||
"kind": "line",
|
||||
"group_by": {"name": "productline"},
|
||||
"group_by": [{"name": "productline"}],
|
||||
}
|
||||
|
||||
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
|
||||
|
||||
assert config_dict["group_by"]["name"] == "ProductLine"
|
||||
assert config_dict["group_by"][0]["name"] == "ProductLine"
|
||||
|
||||
|
||||
class TestNormalizeTableConfig:
|
||||
@@ -397,7 +397,7 @@ class TestNormalizeUppercaseDataset:
|
||||
assert normalized.x.name == "ds"
|
||||
assert normalized.y[0].name == "DISTANCE"
|
||||
assert normalized.group_by is not None
|
||||
assert normalized.group_by.name == "AIRLINE"
|
||||
assert normalized.group_by[0].name == "AIRLINE"
|
||||
assert normalized.filters is not None
|
||||
assert normalized.filters[0].column == "AIRLINE"
|
||||
|
||||
@@ -531,7 +531,7 @@ class TestNormalizeEdgeCases:
|
||||
assert normalized.y[0].name == "Sales"
|
||||
assert normalized.y[1].name == "quantity_ordered"
|
||||
assert normalized.group_by is not None
|
||||
assert normalized.group_by.name == "ProductLine"
|
||||
assert normalized.group_by[0].name == "ProductLine"
|
||||
assert normalized.filters is not None
|
||||
assert normalized.filters[0].column == "ProductLine"
|
||||
assert normalized.filters[1].column == "OrderDate"
|
||||
@@ -678,4 +678,4 @@ class TestNormalizeXAxisFilterConsistency:
|
||||
|
||||
assert normalized.group_by is not None
|
||||
assert normalized.filters is not None
|
||||
assert normalized.group_by.name == normalized.filters[0].column == "AIRLINE"
|
||||
assert normalized.group_by[0].name == normalized.filters[0].column == "AIRLINE"
|
||||
|
||||
Reference in New Issue
Block a user