File size: 2,148 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
import torch.nn.functional as F
from collections import namedtuple
from ding.rl_utils.isw import compute_importance_weights
def compute_q_retraces(
        q_values: torch.Tensor,
        v_pred: torch.Tensor,
        rewards: torch.Tensor,
        actions: torch.Tensor,
        weights: torch.Tensor,
        ratio: torch.Tensor,
        gamma: float = 0.9
) -> torch.Tensor:
    """
    Overview:
        Compute Retrace-style Q targets by sweeping backwards over an unroll of length T.
        With ``c_t = min(1, ratio_t[a_t])`` and ``Q_t = q_values_t[a_t]``, the recursion is::

            carry_T     = v_pred_T
            q_ret_T     = v_pred_T
            q_ret_t     = rewards_t + gamma * weights_t * carry_{t+1}
            carry_t     = c_t * (q_ret_t - Q_t) + v_pred_t

        for ``t = T-1, ..., 0``.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q value of every action at each step.
        - v_pred (:obj:`torch.Tensor`): state-value estimate at each step; its dtype/device \
            determine the dtype/device of the result.
        - rewards (:obj:`torch.Tensor`): reward at each of the T steps.
        - actions (:obj:`torch.Tensor`): index of the taken action (integer tensor usable by ``gather``).
        - weights (:obj:`torch.Tensor`): multiplier on the bootstrapped term; presumably a \
            not-done / validity mask — confirm with caller.
        - ratio (:obj:`torch.Tensor`): per-action importance ratio (presumably pi/mu — confirm); \
            only the entries at ``actions`` are read, truncated from above at 1.
        - gamma (:obj:`float`): discount factor.
    Returns:
        - q_retraces (:obj:`torch.Tensor`): Retrace targets; the last step equals ``v_pred[-1]``.
    Shapes:
        - q_values (:obj:`torch.Tensor`): :math:`(T + 1, B, N)`, where T is unroll_len, B is batch size, N is discrete \
            action dim.
        - v_pred (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
        - rewards (:obj:`torch.Tensor`): :math:`(T, B)`
        - actions (:obj:`torch.Tensor`): :math:`(T, B)`
        - weights (:obj:`torch.Tensor`): :math:`(T, B)`
        - ratio (:obj:`torch.Tensor`): :math:`(T, B, N)`
        - q_retraces (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
    Examples:
        >>> T=2
        >>> B=3
        >>> N=4
        >>> q_values=torch.randn(T+1, B, N)
        >>> v_pred=torch.randn(T+1, B, 1)
        >>> rewards=torch.randn(T, B)
        >>> actions=torch.randint(0, N, (T, B))
        >>> weights=torch.ones(T, B)
        >>> ratio=torch.randn(T, B, N)
        >>> q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)
    .. note::
        Only a forward computation is needed for these targets. This function does not disable
        autograd by itself, so callers wanting a gradient-free target should call it under
        ``torch.no_grad()``.
    """
    unroll_len = q_values.size(0) - 1
    action_index = actions.unsqueeze(-1)  # (T, B, 1)
    reward_col = rewards.unsqueeze(-1)  # (T, B, 1)
    weight_col = weights.unsqueeze(-1)  # (T, B, 1)

    # Q(s_t, a_t) of the taken actions, written into a buffer shaped/typed like v_pred so the
    # gathered values share v_pred's dtype and device. The final (index T) slot stays zero and is never read.
    q_taken = torch.zeros_like(v_pred)
    q_taken[:-1] = q_values[:-1].gather(-1, action_index)
    # Truncated trace coefficients c_t = min(1, ratio_t[a_t]); the clamp is elementwise, so
    # doing it once for all steps matches clamping each step separately.
    trace_coef = ratio.gather(-1, action_index).clamp(max=1.0)  # (T, B, 1)

    targets = torch.zeros_like(v_pred)  # (T + 1, B, 1)
    targets[-1] = v_pred[-1]
    carry = v_pred[-1]  # (B, 1): bootstrap value propagated backwards
    for t in range(unroll_len - 1, -1, -1):
        targets[t] = reward_col[t] + gamma * weight_col[t] * carry
        # Read the stored row back so the carry uses the value as cast into v_pred's dtype.
        carry = trace_coef[t] * (targets[t] - q_taken[t]) + v_pred[t]
    return targets
|