| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import pytest |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint |
| |
|
| |
|
@pytest.fixture
def checkpoints1(tmp_path):
    """Write the first test checkpoint and return its path.

    The saved module holds a parameter ``p1`` (created with
    ``requires_grad=False``) and an integer buffer ``p2``. The dict
    ``{"a": 10, "b": 20}`` is handed to ``save_checkpoint`` as ``params``.
    """
    model = nn.Module()
    model.p1 = nn.Parameter(torch.tensor([10.0, 20.0]), requires_grad=False)
    model.register_buffer("p2", torch.tensor([10, 100]))

    ckpt_path = tmp_path / "f.pt"
    save_checkpoint(ckpt_path, model, params={"a": 10, "b": 20})
    return ckpt_path
| |
|
| |
|
@pytest.fixture
def checkpoints2(tmp_path):
    """Write the second test checkpoint and return its path.

    The saved module holds a parameter ``p1`` and an integer buffer ``p2``.
    The dict ``{"a": 100, "b": 200}`` is handed to ``save_checkpoint`` as
    ``params``.
    """
    model = nn.Module()
    model.p1 = nn.Parameter(torch.Tensor([50, 30.0]))
    model.register_buffer("p2", torch.tensor([1, 3]))

    ckpt_path = tmp_path / "f2.pt"
    save_checkpoint(ckpt_path, model, params={"a": 100, "b": 200})
    return ckpt_path
| |
|
| |
|
def test_load_checkpoints(checkpoints1):
    """Load the ``checkpoints1`` file into a fresh module and verify the result.

    Checks that ``p1`` takes the saved values and that the object returned by
    ``load_checkpoint`` exposes the saved ``a`` and ``b`` entries.
    """
    model = nn.Module()
    # Zero-valued placeholders with the same names as the saved tensors.
    model.p1 = nn.Parameter(torch.Tensor([0, 0.0]))
    model.p2 = nn.Parameter(torch.Tensor([0, 0]))

    loaded = load_checkpoint(checkpoints1, model)

    expected_p1 = torch.Tensor([10.0, 20])
    assert torch.allclose(model.p1, expected_p1)
    assert (loaded["a"], loaded["b"]) == (10, 20)
| |
|
| |
|
def test_average_checkpoints(checkpoints1, checkpoints2):
    """Average the two fixture checkpoints and compare against fixed values."""
    averaged = average_checkpoints([checkpoints1, checkpoints2])

    # p1: element-wise mean of [10, 20] and [50, 30].
    # p2: the expected [5, 51] is the floored mean of [10, 100] and [1, 3];
    # presumably integer tensors are averaged with integer division --
    # confirm against average_checkpoints.
    expected = {
        "p1": torch.Tensor([30, 25.0]),
        "p2": torch.tensor([5, 51]),
    }
    for key, want in expected.items():
        assert torch.allclose(averaged[key], want)
| |
|