File size: 1,658 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import pytest
import torch
from ding.rl_utils.value_rescale import value_inv_transform, value_transform, symlog, inv_symlog
@pytest.mark.unittest
class TestValueRescale:
    """Unit tests for the ``value_transform`` / ``value_inv_transform`` pair."""

    def test_value_transform(self):
        """``value_transform`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # torch.rand alone only yields [0, 1); rescale to [-1, 1) so
            # negative inputs are covered as well.
            t = torch.rand((2, 3)) * 2 - 1
            out = value_transform(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_value_inv_transform(self):
        """``value_inv_transform`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # Same [-1, 1) range as above so negative inputs are covered.
            t = torch.rand((2, 3)) * 2 - 1
            out = value_inv_transform(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_trans_inverse(self):
        """``value_inv_transform(value_transform(t))`` recovers ``t`` within 2e-5."""
        for _ in range(10):
            # Draw from [-1, 1) so the round trip is checked for negative values too.
            t = torch.rand((4, 16)) * 2 - 1
            diff = value_inv_transform(value_transform(t)) - t
            # Tolerance kept at the original 2e-5 (presumably sized for float32
            # round-off in the inverse; TODO confirm if the range is widened).
            assert diff.abs().max().item() == pytest.approx(0, abs=2e-5)
@pytest.mark.unittest
class TestSymlog:
    """Unit tests for the ``symlog`` / ``inv_symlog`` pair."""

    def test_symlog(self):
        """``symlog`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # torch.rand alone only yields [0, 1); rescale to [-1, 1) so
            # negative inputs are covered as well.
            t = torch.rand((3, 4)) * 2 - 1
            out = symlog(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_inv_symlog(self):
        """``inv_symlog`` returns a tensor with the same shape as its input."""
        for _ in range(10):
            # Same [-1, 1) range as above so negative inputs are covered.
            t = torch.rand((3, 4)) * 2 - 1
            out = inv_symlog(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_trans_inverse(self):
        """``inv_symlog(symlog(t))`` recovers ``t`` within 2e-5."""
        for _ in range(10):
            # Draw from [-1, 1) so the round trip is checked for negative values too.
            t = torch.rand((4, 16)) * 2 - 1
            diff = inv_symlog(symlog(t)) - t
            # Tolerance kept at the original 2e-5 (presumably sized for float32
            # round-off; TODO confirm if the input range is widened).
            assert diff.abs().max().item() == pytest.approx(0, abs=2e-5)
|