File size: 1,895 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import pytest
import torch
from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action


@pytest.mark.unittest
def test_vtrace_discrete_action():
    """Check ``vtrace_error_discrete_action`` on random discrete-action inputs.

    Every term of the returned loss must be a scalar. Back-propagating the
    sum of those terms must populate gradients on both trainable leaves
    (``target_output`` and ``value``).
    """
    T, B, N = 4, 8, 16
    # value has one more time step (T + 1 rows) than reward/action (T rows).
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    target_output = torch.randn(T, B, N).requires_grad_(True)
    behaviour_output = torch.randn(T, B, N)
    action = torch.randint(0, N, size=(T, B))
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
    loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    assert all(term.shape == tuple() for term in loss)
    # No gradient may exist before backward() is called.
    assert target_output.grad is None
    assert value.grad is None
    loss = sum(loss)
    loss.backward()
    # Assert on the gradients, not the leaves themselves: the leaves are
    # tensors whether or not backward() actually reached them.
    assert isinstance(target_output.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)


@pytest.mark.unittest
def test_vtrace_continuous_action():
    """Check ``vtrace_error_continuous_action`` on random Gaussian-policy inputs.

    Every term of the returned loss must be a scalar. Back-propagating the
    sum of those terms must populate gradients on every trainable leaf:
    ``target_output['mu']``, ``target_output['sigma']`` and ``value``.
    """
    T, B, N = 4, 8, 16
    # value has one more time step (T + 1 rows) than reward/action (T rows).
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    # exp() keeps sigma strictly positive. For the target policy, apply
    # requires_grad_ AFTER exp() so that sigma is a leaf tensor. Taking exp()
    # of a grad-requiring tensor would produce a non-leaf whose .grad is never
    # retained, and reading that .grad only triggers a PyTorch warning.
    target_output = {
        'mu': torch.randn(T, B, N).requires_grad_(True),
        'sigma': torch.exp(torch.randn(T, B, N)).requires_grad_(True),
    }
    behaviour_output = {
        'mu': torch.randn(T, B, N),
        'sigma': torch.exp(torch.randn(T, B, N)),
    }
    action = torch.randn((T, B, N))
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
    loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    assert all(term.shape == tuple() for term in loss)
    # No gradient may exist before backward() is called.
    assert target_output['mu'].grad is None
    assert target_output['sigma'].grad is None
    assert value.grad is None
    loss = sum(loss)
    loss.backward()
    # Assert on the gradients, not the leaves themselves: the leaves are
    # tensors whether or not backward() actually reached them.
    assert isinstance(target_output['mu'].grad, torch.Tensor)
    assert isinstance(target_output['sigma'].grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)