# File: gomoku/LightZero/lzero/model/tests/test_stochastic_muzero_model.py
# (file-viewer page residue: commit 079c32c "init space" by zjowowen, 2.24 kB)
import torch
import pytest
from torch import nn
from lzero.model.stochastic_muzero_model import ChanceEncoder
@pytest.fixture
def encoder():
    """Build a ChanceEncoder for the tests that request this fixture.

    The encoder is constructed with the positional arguments ``(3, 32, 32)``
    and ``4``. Presumably these are the single-frame observation shape and
    the size of the chance space -- confirm against ``ChanceEncoder``.

    Returns:
        The newly constructed ``ChanceEncoder`` instance.
    """
    chance_encoder = ChanceEncoder((3, 32, 32), 4)
    return chance_encoder
def test_ChanceEncoder(encoder):
    """Check the output shapes and one-hot validity of a forward pass.

    Args:
        encoder: The object supplied by the ``encoder`` fixture; it is called
            with one tensor and must return a pair of tensors.
    """
    expected_shape = (1, 4)

    # Random input: batch of 1, 6 channels, 32x32 spatial size.
    # NOTE(review): 6 channels is twice the 3 given to the fixture's encoder;
    # presumably the current and previous frames concatenated -- confirm.
    dummy_input = torch.randn(1, 6, 32, 32)

    encoding, onehot = encoder(dummy_input)

    # Both outputs must have the same (batch, 4) shape.
    assert encoding.shape == expected_shape
    assert onehot.shape == expected_shape

    # A one-hot row contains nothing but zeros and ones ...
    only_binary = torch.logical_or(onehot.eq(0), onehot.eq(1))
    assert only_binary.all()

    # ... and its entries add up to exactly one.
    row_totals = onehot.sum(dim=1)
    assert row_totals.eq(1).all()
def test_ChanceEncoder_gradients_chance_encoding(encoder):
    """Check that a loss on the chance encoding gives every parameter a gradient.

    A mean-squared-error loss between the first encoder output and a random
    target is back-propagated; afterwards each parameter must hold a gradient
    whose shape equals the parameter's own shape.

    Args:
        encoder: The object supplied by the ``encoder`` fixture; it must
            expose ``parameters()`` and return two tensors when called.
    """
    # Random input: batch of 1, 6 channels, 32x32 spatial size.
    dummy_input = torch.randn(1, 6, 32, 32)
    encoding, _onehot = encoder(dummy_input)

    # Arbitrary regression target with the same (1, 4) shape as the encoding.
    dummy_target = torch.randn(1, 4)
    mse = nn.functional.mse_loss(encoding, dummy_target)
    mse.backward()

    weights = list(encoder.parameters())

    # First pass: every parameter received some gradient.
    for weight in weights:
        assert weight.grad is not None

    # Second pass: each gradient is shaped like its parameter.
    for weight in weights:
        assert weight.grad.shape == weight.shape
def test_ChanceEncoder_gradients_chance_onehot_t(encoder):
    """Check that a loss on the one-hot output gives every parameter a gradient.

    A mean-squared-error loss between the second encoder output and a random
    target is back-propagated; afterwards each parameter must hold a gradient
    whose shape equals the parameter's own shape.

    NOTE(review): this passes only if the one-hot output stays connected to
    the autograd graph (presumably via a straight-through style estimator
    inside the encoder) -- confirm against ``ChanceEncoder``.

    Args:
        encoder: The object supplied by the ``encoder`` fixture; it must
            expose ``parameters()`` and return two tensors when called.
    """
    # Random input: batch of 1, 6 channels, 32x32 spatial size.
    dummy_input = torch.randn(1, 6, 32, 32)
    _encoding, onehot = encoder(dummy_input)

    # Arbitrary regression target with the same (1, 4) shape as the one-hot.
    dummy_target = torch.randn(1, 4)
    mse = nn.functional.mse_loss(onehot, dummy_target)
    mse.backward()

    weights = list(encoder.parameters())

    # First pass: every parameter received some gradient.
    for weight in weights:
        assert weight.grad is not None

    # Second pass: each gradient is shaped like its parameter.
    for weight in weights:
        assert weight.grad.shape == weight.shape