"""Unit tests for the ChanceEncoder of LZero's stochastic MuZero model."""
import pytest
import torch
from torch import nn

from lzero.model.stochastic_muzero_model import ChanceEncoder


@pytest.fixture
def encoder():
    # A ChanceEncoder for (C, H, W) = (3, 32, 32) observations with a
    # 4-dimensional chance space. pytest fixtures are function-scoped by
    # default, so each test below gets a fresh encoder with no accumulated
    # gradients.
    return ChanceEncoder((3, 32, 32), 4)


def test_ChanceEncoder(encoder):
    # The encoder consumes the current and previous observations stacked
    # along the channel dimension: 2 * 3 = 6 channels.
    x_and_last_x = torch.randn(1, 6, 32, 32)

    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Both outputs live in the 4-dimensional chance space.
    assert chance_encoding_t.shape == (1, 4)
    assert chance_onehot_t.shape == (1, 4)

    # The one-hot output must be binary and contain exactly one 1 per sample.
    assert torch.all((chance_onehot_t == 0) | (chance_onehot_t == 1))
    assert torch.all(torch.sum(chance_onehot_t, dim=1) == 1)


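# A small additional sketch beyond the original suite: it assumes the encoder
# is batch-agnostic (nothing above suggests the batch size of 1 is special),
# so the same shape checks should hold for a larger batch.
def test_ChanceEncoder_batched_input(encoder):
    batch = 8
    x_and_last_x = torch.randn(batch, 6, 32, 32)
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)
    assert chance_encoding_t.shape == (batch, 4)
    assert chance_onehot_t.shape == (batch, 4)

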
def test_ChanceEncoder_gradients_chance_encoding(encoder):
    x_and_last_x = torch.randn(1, 6, 32, 32)

    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Regress the encoding against a random target so that backpropagation
    # produces a gradient signal for every parameter.
    target = torch.randn(1, 4)
    loss = nn.MSELoss()(chance_encoding_t, target)
    loss.backward()

    # Every parameter should receive a gradient of matching shape.
    for param in encoder.parameters():
        assert param.grad is not None
        assert param.grad.shape == param.shape


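# Another small sketch, not part of the original tests: pytest fixtures are
# function-scoped by default, so the encoder is rebuilt for each test and
# freshly created parameters report no gradient before any backward pass.
# This guards the two gradient tests against leaking state into each other.
def test_ChanceEncoder_fresh_parameters_have_no_grad(encoder):
    for param in encoder.parameters():
        assert param.grad is None

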
def test_ChanceEncoder_gradients_chance_onehot_t(encoder):
    x_and_last_x = torch.randn(1, 6, 32, 32)

    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # chance_onehot_t comes from a discrete selection, so this test only
    # passes if the encoder routes gradients around it (e.g. with a
    # straight-through estimator); a hard argmax alone is non-differentiable.
    target = torch.randn(1, 4)
    loss = nn.MSELoss()(chance_onehot_t, target)
    loss.backward()

    for param in encoder.parameters():
        assert param.grad is not None
        assert param.grad.shape == param.shape

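
# A final hedged sketch: assuming chance_onehot_t is the one-hot encoding of
# the argmax over chance_encoding_t (the usual straight-through construction;
# not verified against the ChanceEncoder source), the two outputs should
# select the same chance index.
def test_ChanceEncoder_onehot_matches_argmax(encoder):
    x_and_last_x = torch.randn(1, 6, 32, 32)
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)
    assert torch.equal(
        chance_onehot_t.argmax(dim=1), chance_encoding_t.argmax(dim=1)
    )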