from typing import TYPE_CHECKING, Callable

from easydict import EasyDict
from ditk import logging
import torch

from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext
    from ding.reward_model import BaseRewardModel, HerRewardModel
    from ding.data import Buffer


def reward_estimator(cfg: EasyDict, reward_model: "BaseRewardModel") -> Callable:
    """
    Overview:
        Estimate the reward of `train_data` using `reward_model`.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - reward_model (:obj:`BaseRewardModel`): Reward model.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_data (:obj:`List`): The list of data used for estimation.
        """
        reward_model.estimate(ctx.train_data)

    return _enhance
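
# A minimal wiring sketch (illustrative only: the surrounding pipeline, `my_cfg`,
# `my_reward_model` and the fetcher/trainer middleware are placeholders that are
# not defined in this module):
#
#     with task.start(ctx=OnlineRLContext()):
#         task.use(...)  # some data fetcher that fills ctx.train_data
#         task.use(reward_estimator(my_cfg, my_reward_model))
#         task.use(...)  # a trainer that consumes ctx.train_data
#         task.run()
#
# The estimator works in place: it relabels the rewards inside ctx.train_data
# before the downstream trainer middleware consumes them.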


def her_data_enhancer(cfg: EasyDict, buffer_: "Buffer", her_reward_model: "HerRewardModel") -> Callable:
    """
    Overview:
        Fetch a batch of data/episode from `buffer_`, \
        then use `her_reward_model` to get HER processed episodes from original episodes.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys \
            if her_reward_model.episode_size is None: `cfg.policy.learn.batch_size`.
        - buffer\_ (:obj:`Buffer`): Buffer to sample data from.
        - her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model \
            which is used to process episodes.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _fetch_and_enhance(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - train_data (:obj:`List[treetensor.torch.Tensor]`): The HER processed episodes.
        """
        if her_reward_model.episode_size is None:
            size = cfg.policy.learn.batch_size
        else:
            size = her_reward_model.episode_size
        try:
            buffered_episode = buffer_.sample(size)
            train_episode = [d.data for d in buffered_episode]
        except (ValueError, AssertionError):
            # The buffer has not accumulated enough episodes yet, so skip this round of training.
            logging.warning(
                "Replay buffer does not have enough data to support training, "
                "so this round of training is skipped while waiting for more data."
            )
            ctx.train_data = None
            return

        her_episode = sum([her_reward_model.estimate(e) for e in train_episode], [])
        ctx.train_data = sum(her_episode, [])

    return _fetch_and_enhance
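
# A minimal wiring sketch (illustrative only: `her_cfg`, `episode_buffer`,
# `her_model` and the trainer middleware are placeholders created elsewhere):
#
#     with task.start(ctx=OnlineRLContext()):
#         task.use(her_data_enhancer(her_cfg, episode_buffer, her_model))
#         task.use(...)  # a trainer that consumes ctx.train_data
#         task.run()
#
# Unlike reward_estimator, this middleware both samples episodes from the buffer
# and rewrites them with HER goals, so no separate data fetcher is needed before it.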


def nstep_reward_enhancer(cfg: EasyDict) -> Callable:
    """
    Overview:
        Rewrite the reward of each transition in `ctx.trajectories` into a fixed-length n-step \
        reward vector and attach the corresponding `value_gamma` used to discount the bootstrap value.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain `cfg.policy.nstep` and `cfg.policy.discount_factor`.
    """
    if task.router.is_active and (not task.has_role(task.role.LEARNER) and not task.has_role(task.role.COLLECTOR)):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        nstep = cfg.policy.nstep
        gamma = cfg.policy.discount_factor
        L = len(ctx.trajectories)
        reward_template = ctx.trajectories[0].reward
        nstep_rewards = []
        value_gamma = []
        for i in range(L):
            # Number of valid future steps, limited by the trajectory length and by episode termination.
            valid = min(nstep, L - i)
            for j in range(1, valid):
                if ctx.trajectories[j + i].done:
                    valid = j
                    break
            value_gamma.append(torch.FloatTensor([gamma ** valid]))
            # Collect the valid rewards and zero-pad up to a fixed length of nstep.
            nstep_reward = [ctx.trajectories[j].reward for j in range(i, i + valid)]
            if nstep > valid:
                nstep_reward.extend([torch.zeros_like(reward_template) for j in range(nstep - valid)])
            nstep_reward = torch.cat(nstep_reward)
            nstep_rewards.append(nstep_reward)
        for i in range(L):
            ctx.trajectories[i].reward = nstep_rewards[i]
            ctx.trajectories[i].value_gamma = value_gamma[i]

    return _enhance
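
# Worked example (assuming scalar rewards of shape (1,), nstep=3, and a single
# episode of four transitions t0..t3 where only t3.done is True):
#
#     i=0: valid=3 -> reward=cat([r0, r1, r2]),  value_gamma=gamma**3
#     i=1: valid=2 -> reward=cat([r1, r2, 0]),   value_gamma=gamma**2
#     i=2: valid=1 -> reward=cat([r2, 0, 0]),    value_gamma=gamma**1
#     i=3: valid=1 -> reward=cat([r3, 0, 0]),    value_gamma=gamma**1
#
# Every transition thus carries a reward vector of exactly nstep entries,
# zero-padded past the episode end, plus the discount applied to its bootstrap value.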