import pytest
import torch

from ding.rl_utils.value_rescale import value_inv_transform, value_transform, symlog, inv_symlog

# Absolute tolerance for round-trip (transform -> inverse) error. Inputs are kept
# within [-1, 1) so float32 error in the inverse transforms stays well below it.
_ATOL = 2e-5


def _signed_rand(*shape: int) -> torch.Tensor:
    """Return uniform samples in [-1, 1) with the given shape.

    ``torch.rand`` alone only produces values in [0, 1), so negative inputs were
    never exercised. The transforms under test are presumably sign-symmetric
    (symlog is conventionally sign(x) * log(1 + |x|)) -- this helper makes sure
    both signs are covered.
    """
    return torch.rand(shape) * 2 - 1


@pytest.mark.unittest
class TestValueRescale:
    """Tests for ``value_transform`` and its inverse ``value_inv_transform``."""

    def test_value_transform(self):
        """``value_transform`` returns a tensor with the input's shape."""
        for _ in range(10):
            t = _signed_rand(2, 3)
            out = value_transform(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_value_inv_transform(self):
        """``value_inv_transform`` returns a tensor with the input's shape."""
        for _ in range(10):
            t = _signed_rand(2, 3)
            out = value_inv_transform(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_trans_inverse(self):
        """``value_inv_transform`` undoes ``value_transform`` up to ``_ATOL``."""
        for _ in range(10):
            t = _signed_rand(4, 16)
            diff = value_inv_transform(value_transform(t)) - t
            assert diff.abs().max().item() == pytest.approx(0, abs=_ATOL)


@pytest.mark.unittest
class TestSymlog:
    """Tests for ``symlog`` and its inverse ``inv_symlog``."""

    def test_symlog(self):
        """``symlog`` returns a tensor with the input's shape."""
        for _ in range(10):
            t = _signed_rand(3, 4)
            out = symlog(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_inv_symlog(self):
        """``inv_symlog`` returns a tensor with the input's shape."""
        for _ in range(10):
            t = _signed_rand(3, 4)
            out = inv_symlog(t)
            assert isinstance(out, torch.Tensor)
            assert out.shape == t.shape

    def test_trans_inverse(self):
        """``inv_symlog`` undoes ``symlog`` up to ``_ATOL``."""
        for _ in range(10):
            t = _signed_rand(4, 16)
            diff = inv_symlog(symlog(t)) - t
            assert diff.abs().max().item() == pytest.approx(0, abs=_ATOL)