import torch
import torch.nn.functional as F
from collections import namedtuple
from ding.rl_utils.isw import compute_importance_weights


def compute_q_retraces(
        q_values: torch.Tensor,
        v_pred: torch.Tensor,
        rewards: torch.Tensor,
        actions: torch.Tensor,
        weights: torch.Tensor,
        ratio: torch.Tensor,
        gamma: float = 0.9
) -> torch.Tensor:
    """
    Shapes:
        - q_values (:obj:`torch.Tensor`): :math:`(T + 1, B, N)`, where T is unroll_len, B is batch size, N is discrete \
            action dim.
        - v_pred (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
        - rewards (:obj:`torch.Tensor`): :math:`(T, B)`
        - actions (:obj:`torch.Tensor`): :math:`(T, B)`
        - weights (:obj:`torch.Tensor`): :math:`(T, B)`
        - ratio (:obj:`torch.Tensor`): :math:`(T, B, N)`
        - q_retraces (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
    Examples:
        >>> T=2
        >>> B=3
        >>> N=4
        >>> q_values=torch.randn(T+1, B, N)
        >>> v_pred=torch.randn(T+1, B, 1)
        >>> rewards=torch.randn(T, B)
        >>> actions=torch.randint(0, N, (T, B))
        >>> weights=torch.ones(T, B)
        >>> ratio=torch.rand(T, B, N)
        >>> q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)

    .. note::
        The q_retrace operation does not require gradient computation; it only executes a forward pass.
    """
    T = q_values.size()[0] - 1  # unroll length
    # expand to (T, B, 1) so these tensors broadcast/gather against the value-shaped tensors below
    rewards = rewards.unsqueeze(-1)
    actions = actions.unsqueeze(-1)
    weights = weights.unsqueeze(-1)
    q_retraces = torch.zeros_like(v_pred)  # shape (T+1), B, 1
    tmp_retraces = v_pred[-1]  # bootstrap from the last value prediction, shape B, 1
    q_retraces[-1] = v_pred[-1]

    q_gather = torch.zeros_like(v_pred)
    q_gather[0:-1] = q_values[0:-1].gather(-1, actions)  # Q of the taken actions, shape (T+1), B, 1
    ratio_gather = ratio.gather(-1, actions)  # importance ratio of the taken actions, shape T, B, 1

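    # Backward retrace recursion (a sketch of the standard formula that the loop below implements):
    #     Q_ret(x_t, a_t) = r_t + gamma * w_t * [ c_{t+1} * (Q_ret(x_{t+1}, a_{t+1}) - Q(x_{t+1}, a_{t+1})) + V(x_{t+1}) ]
    # where c = min(1, ratio) is the truncated importance ratio and w is the caller-provided per-step
    # weight (typically a termination mask, though that depends on the caller).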
    for idx in reversed(range(T)):
        q_retraces[idx] = rewards[idx] + gamma * weights[idx] * tmp_retraces
        tmp_retraces = ratio_gather[idx].clamp(max=1.0) * (q_retraces[idx] - q_gather[idx]) + v_pred[idx]
    return q_retraces  # shape (T+1),B,1
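

# A minimal usage sketch (illustrative only, not part of the original module): build toy tensors with the
# documented shapes and compute the retrace targets without gradient tracking. All sizes and values here
# are assumptions chosen just for the demo.
if __name__ == "__main__":
    T, B, N = 2, 3, 4
    q_values = torch.randn(T + 1, B, N)
    v_pred = torch.randn(T + 1, B, 1)
    rewards = torch.randn(T, B)
    actions = torch.randint(0, N, (T, B))
    weights = torch.ones(T, B)
    ratio = torch.rand(T, B, N)  # importance ratios are non-negative
    with torch.no_grad():  # the retrace target is a fixed target, so no gradient is needed
        q_ret = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, gamma=0.99)
    print(q_ret.shape)  # torch.Size([3, 3, 1]), i.e. (T + 1, B, 1)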