from typing import TYPE_CHECKING, Callable
from easydict import EasyDict
from ditk import logging
import torch
from ding.framework import task
if TYPE_CHECKING:
    from ding.framework import OnlineRLContext
    from ding.reward_model import BaseRewardModel, HerRewardModel
    from ding.data import Buffer


def reward_estimator(cfg: EasyDict, reward_model: "BaseRewardModel") -> Callable:
    """
    Overview:
        Estimate the reward of `ctx.train_data` in place using `reward_model`.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - reward_model (:obj:`BaseRewardModel`): Reward model.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_data (:obj:`List`): The list of data used for estimation.
        """
        reward_model.estimate(ctx.train_data)  # in-place modification of ctx.train_data

    return _enhance
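
# Usage sketch (illustrative only, not part of this module): `reward_estimator` is meant to be
# composed with other middleware via the standard `ding.framework` task API. The surrounding
# middleware is elided/assumed here because it depends on the concrete algorithm; only
# `task.start`, `task.use` and `task.run` are taken as given.
#
#     with task.start(ctx=OnlineRLContext()):
#         ...                                            # collector / data-fetcher middleware sets ctx.train_data
#         task.use(reward_estimator(cfg, reward_model))  # rewrites rewards in ctx.train_data in place
#         ...                                            # learner / trainer middleware consumes ctx.train_data
#         task.run()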


def her_data_enhancer(cfg: EasyDict, buffer_: "Buffer", her_reward_model: "HerRewardModel") -> Callable:
    """
    Overview:
        Fetch a batch of episodes from `buffer_`, \
        then use `her_reward_model` to turn the original episodes into HER-processed episodes.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. If `her_reward_model.episode_size` is None, it should \
            contain `cfg.policy.learn.batch_size`, which is then used as the sample size.
        - buffer\_ (:obj:`Buffer`): Buffer to sample data from.
        - her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model \
            which is used to process episodes.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _fetch_and_enhance(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - train_data (:obj:`List[treetensor.torch.Tensor]`): The HER-processed transitions, flattened across episodes.
        """
        if her_reward_model.episode_size is None:
            size = cfg.policy.learn.batch_size
        else:
            size = her_reward_model.episode_size
        try:
            buffered_episode = buffer_.sample(size)
            train_episode = [d.data for d in buffered_episode]
        except (ValueError, AssertionError):
            # You can adjust the data collection config (e.g. increase n_sample or n_episode) to avoid this warning.
            logging.warning(
                "Replay buffer does not have enough data to support training, so this iteration is skipped while waiting for more data."
            )
            ctx.train_data = None
            return

        # Flatten the nested HER output (a list of processed episodes per sampled episode) into a
        # single flat list of training transitions.
        her_episode = sum([her_reward_model.estimate(e) for e in train_episode], [])
        ctx.train_data = sum(her_episode, [])

    return _fetch_and_enhance
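
# Usage sketch (illustrative only, not part of this module): in an HER pipeline this middleware
# takes the place of the usual data fetcher, so that the training middleware consumes the
# HER-relabelled transitions it puts into ctx.train_data. The elided middleware is an assumption.
#
#     with task.start(ctx=OnlineRLContext()):
#         ...                                                          # collector + episode push into buffer_
#         task.use(her_data_enhancer(cfg, buffer_, her_reward_model))  # sets ctx.train_data (or None)
#         ...                                                          # training middleware reading ctx.train_data
#         task.run()
#
# Downstream middleware should tolerate ctx.train_data being None, which signals that the buffer
# does not yet hold enough episodes.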


def nstep_reward_enhancer(cfg: EasyDict) -> Callable:
    """
    Overview:
        Replace the one-step reward of each transition in `ctx.trajectories` with an n-step reward \
        vector of shape (nstep, ), zero-padded at trajectory/episode boundaries, and attach the \
        corresponding `value_gamma` to be applied to the bootstrapped value.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain `cfg.policy.nstep` and `cfg.policy.discount_factor`.
    """
    if task.router.is_active and (not task.has_role(task.role.LEARNER) and not task.has_role(task.role.COLLECTOR)):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        nstep = cfg.policy.nstep
        gamma = cfg.policy.discount_factor
        L = len(ctx.trajectories)
        reward_template = ctx.trajectories[0].reward
        nstep_rewards = []
        value_gamma = []
        for i in range(L):
            # Number of valid future steps for transition i: capped by nstep and by the end of the
            # trajectory, and truncated at the first done flag inside the window.
            valid = min(nstep, L - i)
            for j in range(1, valid):
                if ctx.trajectories[j + i].done:
                    valid = j
                    break
            value_gamma.append(torch.FloatTensor([gamma ** valid]))
            # Collect the valid rewards and zero-pad so every transition carries a vector of length nstep.
            nstep_reward = [ctx.trajectories[j].reward for j in range(i, i + valid)]
            if nstep > valid:
                nstep_reward.extend([torch.zeros_like(reward_template) for j in range(nstep - valid)])
            nstep_reward = torch.cat(nstep_reward)  # (nstep, )
            nstep_rewards.append(nstep_reward)
        for i in range(L):
            ctx.trajectories[i].reward = nstep_rewards[i]
            ctx.trajectories[i].value_gamma = value_gamma[i]

    return _enhance
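
# Worked example (illustrative, assuming scalar per-step rewards r0..r4 of shape (1, )): with nstep=3,
# gamma=0.99 and a trajectory of length 5 containing no done flag, the enhancer produces
#
#     i=0: reward -> cat([r0, r1, r2]),  value_gamma -> 0.99 ** 3
#     i=3: reward -> cat([r3, r4, 0]),   value_gamma -> 0.99 ** 2   (zero-padded to length nstep)
#     i=4: reward -> cat([r4, 0, 0]),    value_gamma -> 0.99 ** 1
#
# i.e. every transition ends up with a reward tensor of shape (nstep, ) plus the discount to apply to
# the bootstrapped value, which is the format expected by n-step TD targets.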


# TODO MBPO
# TODO SIL
# TODO TD3 VAE