fix: Add aliases and groupby list to chart schemas (#38740)

This commit is contained in:
Kamil Gabryjelski
2026-03-19 16:15:58 +01:00
committed by GitHub
parent 972e15e601
commit 14b1b456e1
7 changed files with 262 additions and 120 deletions

View File

@@ -158,16 +158,31 @@ class TestXYChartConfig:
)
assert len(config.y) == 2
def test_group_by_duplicate_with_x_rejected(self) -> None:
"""Test that group_by conflicts with x are rejected."""
def test_group_by_duplicate_label_with_x_rejected(self) -> None:
"""Test that group_by with a custom label conflicting with x is rejected."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="region"),
y=[ColumnRef(name="sales", aggregate="SUM")],
group_by=ColumnRef(name="category", label="region"),
group_by=[ColumnRef(name="category", label="region")],
)
def test_group_by_same_column_as_x_allowed(self) -> None:
"""Test that group_by with the same column name as x is allowed.
The mapping layer filters these out, so validation should not reject them.
"""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="sales", aggregate="SUM")],
kind="line",
group_by=[ColumnRef(name="date")],
)
assert config.group_by is not None
assert config.group_by[0].name == "date"
def test_realistic_chart_configurations(self) -> None:
"""Test realistic chart configurations."""
# This should work - COUNT(product_line) != product_line
@@ -220,19 +235,34 @@ class TestXYChartConfig:
)
assert config.kind == "area"
def test_unknown_fields_rejected(self) -> None:
"""Test that unknown fields like 'series' are rejected."""
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="territory"),
y=[ColumnRef(name="sales", aggregate="SUM")],
kind="bar",
series=ColumnRef(name="year"),
)
def test_unknown_fields_ignored(self) -> None:
"""Test that unknown fields are silently ignored (extra='ignore')."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="territory"),
y=[ColumnRef(name="sales", aggregate="SUM")],
kind="bar",
unknown_field="bad",
)
assert config.kind == "bar"
assert not hasattr(config, "unknown_field")
def test_series_alias_accepted(self) -> None:
"""Test that 'series' is accepted as alias for 'group_by'."""
config = XYChartConfig.model_validate(
{
"chart_type": "xy",
"x": {"name": "territory"},
"y": [{"name": "sales", "aggregate": "SUM"}],
"kind": "bar",
"series": {"name": "year"},
}
)
assert config.group_by is not None
assert config.group_by[0].name == "year"
def test_group_by_accepted(self) -> None:
"""Test that group_by is the correct field for series grouping."""
"""Test that group_by accepts a single ColumnRef (auto-wrapped in list)."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="territory"),
@@ -241,7 +271,20 @@ class TestXYChartConfig:
group_by=ColumnRef(name="year"),
)
assert config.group_by is not None
assert config.group_by.name == "year"
assert len(config.group_by) == 1
assert config.group_by[0].name == "year"
def test_group_by_multiple(self) -> None:
"""Test that group_by accepts a list of ColumnRefs."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="territory"),
y=[ColumnRef(name="sales", aggregate="SUM")],
kind="bar",
group_by=[ColumnRef(name="year"), ColumnRef(name="region")],
)
assert config.group_by is not None
assert len(config.group_by) == 2
def test_orientation_horizontal_accepted(self) -> None:
"""Test that orientation='horizontal' is accepted for bar charts."""
@@ -374,11 +417,12 @@ class TestRowLimit:
class TestTableChartConfigExtraFields:
"""Test TableChartConfig rejects unknown fields."""
def test_unknown_fields_rejected(self) -> None:
"""Test that unknown fields are rejected."""
with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product")],
foo="bar",
)
def test_unknown_fields_ignored(self) -> None:
"""Test that unknown fields are silently ignored (extra='ignore')."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product")],
foo="bar",
)
assert len(config.columns) == 1
assert not hasattr(config, "foo")

View File

@@ -103,14 +103,15 @@ class TestPieChartConfigSchema:
assert config.filters is not None
assert len(config.filters) == 1
def test_pie_config_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError):
PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
unknown_field="bad",
)
def test_pie_config_ignores_extra_fields(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
unknown_field="bad",
)
assert config.dimension.name == "product"
assert not hasattr(config, "unknown_field")
def test_pie_config_missing_dimension(self) -> None:
with pytest.raises(ValidationError):
@@ -323,14 +324,15 @@ class TestPivotTableChartConfigSchema:
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
def test_pivot_table_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError):
PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
unknown_field="bad",
)
def test_pivot_table_ignores_extra_fields(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
unknown_field="bad",
)
assert config.rows[0].name == "product"
assert not hasattr(config, "unknown_field")
def test_pivot_table_valid_aggregate_functions(self) -> None:
for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]:
@@ -459,10 +461,10 @@ class TestMixedTimeseriesChartConfigSchema:
time_grain="P1M",
y=[ColumnRef(name="revenue", aggregate="SUM")],
primary_kind="area",
group_by=ColumnRef(name="region"),
group_by=[ColumnRef(name="region")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
secondary_kind="scatter",
group_by_secondary=ColumnRef(name="channel"),
group_by_secondary=[ColumnRef(name="channel")],
show_legend=False,
x_axis=AxisConfig(title="Date"),
y_axis=AxisConfig(title="Revenue", format="$,.2f"),
@@ -473,9 +475,9 @@ class TestMixedTimeseriesChartConfigSchema:
assert config.secondary_kind == "scatter"
assert config.time_grain == "P1M"
assert config.group_by is not None
assert config.group_by.name == "region"
assert config.group_by[0].name == "region"
assert config.group_by_secondary is not None
assert config.group_by_secondary.name == "channel"
assert config.group_by_secondary[0].name == "channel"
def test_mixed_timeseries_missing_y(self) -> None:
with pytest.raises(ValidationError):
@@ -502,15 +504,16 @@ class TestMixedTimeseriesChartConfigSchema:
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
def test_mixed_timeseries_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError):
MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
unknown_field="bad",
)
def test_mixed_timeseries_ignores_extra_fields(self) -> None:
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
unknown_field="bad",
)
assert config.x.name == "date"
assert not hasattr(config, "unknown_field")
def test_mixed_timeseries_default_row_limit(self) -> None:
config = MixedTimeseriesChartConfig(
@@ -606,9 +609,9 @@ class TestMapMixedTimeseriesConfig:
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
group_by=ColumnRef(name="region"),
group_by=[ColumnRef(name="region")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
group_by_secondary=ColumnRef(name="channel"),
group_by_secondary=[ColumnRef(name="channel")],
)
result = map_mixed_timeseries_config(config, dataset_id=1)
@@ -623,9 +626,9 @@ class TestMapMixedTimeseriesConfig:
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
group_by=ColumnRef(name="date"), # same as x
group_by=[ColumnRef(name="date")], # same as x
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
group_by_secondary=ColumnRef(name="date"), # same as x
group_by_secondary=[ColumnRef(name="date")], # same as x
)
result = map_mixed_timeseries_config(config, dataset_id=1)

View File

@@ -163,12 +163,12 @@ class TestNormalizeXYConfig:
"x": {"name": "OrderDate"},
"y": [{"name": "Sales", "aggregate": "SUM"}],
"kind": "line",
"group_by": {"name": "productline"},
"group_by": [{"name": "productline"}],
}
DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
assert config_dict["group_by"]["name"] == "ProductLine"
assert config_dict["group_by"][0]["name"] == "ProductLine"
class TestNormalizeTableConfig:
@@ -397,7 +397,7 @@ class TestNormalizeUppercaseDataset:
assert normalized.x.name == "ds"
assert normalized.y[0].name == "DISTANCE"
assert normalized.group_by is not None
assert normalized.group_by.name == "AIRLINE"
assert normalized.group_by[0].name == "AIRLINE"
assert normalized.filters is not None
assert normalized.filters[0].column == "AIRLINE"
@@ -531,7 +531,7 @@ class TestNormalizeEdgeCases:
assert normalized.y[0].name == "Sales"
assert normalized.y[1].name == "quantity_ordered"
assert normalized.group_by is not None
assert normalized.group_by.name == "ProductLine"
assert normalized.group_by[0].name == "ProductLine"
assert normalized.filters is not None
assert normalized.filters[0].column == "ProductLine"
assert normalized.filters[1].column == "OrderDate"
@@ -678,4 +678,4 @@ class TestNormalizeXAxisFilterConsistency:
assert normalized.group_by is not None
assert normalized.filters is not None
assert normalized.group_by.name == normalized.filters[0].column == "AIRLINE"
assert normalized.group_by[0].name == normalized.filters[0].column == "AIRLINE"