gomoku / LightZero /lzero /mcts /tests /test_image_transform.py
zjowowen's picture
init space
079c32c
raw
history blame
340 Bytes
import pytest
import torch
from lzero.model import ImageTransforms
@pytest.mark.unittest
def test_image_transform():
    """Check that ``ImageTransforms`` preserves the batch shape and alters pixels.

    A random batch is passed through the ``'shift'`` and ``'intensity'``
    augmentations. The output must have the same shape as the input and must
    not be element-wise identical to it.
    """
    # Assumed layout: (batch, channels, height, width) -- TODO confirm against ImageTransforms.
    img = torch.rand((4, 3, 96, 96))
    transform = ImageTransforms(['shift', 'intensity'])
    processed_img = transform.transform(img)
    # Verify the *output* shape; asserting on ``img`` alone is trivially true
    # because it was constructed with this exact shape two lines above.
    assert img.shape == (4, 3, 96, 96)
    assert processed_img.shape == img.shape
    # At least one element must differ, otherwise the augmentation was a no-op.
    assert not (img == processed_img).all()