from typing import TYPE_CHECKING, Callable, List, Tuple, Union, Dict, Optional
from easydict import EasyDict
from collections import deque
from ding.framework import task
from ding.data import Buffer
from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer

if TYPE_CHECKING:
    from ding.framework import Context, OnlineRLContext
    from ding.policy import Policy
    from ding.reward_model import BaseRewardModel


class OffPolicyLearner:
    """
    Overview:
        The class of the off-policy learner, including data fetching and model training. Use \
        the `__call__` method to execute the whole learning process.
    """

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.LEARNER):
            return task.void()
        return super(OffPolicyLearner, cls).__new__(cls)

    def __init__(
            self,
            cfg: EasyDict,
            policy: 'Policy',
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            reward_model: Optional['BaseRewardModel'] = None,
            log_freq: int = 100,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training.
            - reward_model (:obj:`BaseRewardModel`): An optional additional reward estimator, such as RND or ICM. \
                Defaults to None.
            - log_freq (:obj:`int`): The frequency (in training iterations) of logging.
        """
        self.cfg = cfg
        self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
        self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq))
        if reward_model is not None:
            self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model))
        else:
            self._reward_estimator = None

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Output of ctx:
            - train_output (:obj:`List`): The list of training outputs, one entry per training iteration.
        """
        train_output_queue = []
        for _ in range(self.cfg.policy.learn.update_per_collect):
            # Fetch a batch of training data from the buffer; stop early if none is available.
            self._fetcher(ctx)
            if ctx.train_data is None:
                break
            # Optionally post-process rewards (e.g. with an intrinsic reward model) before training.
            if self._reward_estimator:
                self._reward_estimator(ctx)
            self._trainer(ctx)
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue
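
# Illustrative usage (commented out so that importing this module has no side effects):
# a minimal sketch of how OffPolicyLearner is typically chained with collector and buffer
# middleware inside a `task` pipeline. The surrounding objects (`cfg`, `policy`,
# `collector_env`, `buffer_`) are assumed to be created elsewhere; treat this as a hedged
# sketch rather than a canonical recipe.
#
#     from ding.framework import task, OnlineRLContext
#     from ding.framework.middleware import StepCollector, data_pusher, OffPolicyLearner
#
#     with task.start(ctx=OnlineRLContext()):
#         task.use(StepCollector(cfg, policy.collect_mode, collector_env))
#         task.use(data_pusher(cfg, buffer_))
#         task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
#         task.run(max_step=10000)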


class HERLearner:
    """
    Overview:
        The class of the learner with Hindsight Experience Replay (HER). \
        Use the `__call__` method to execute the data fetching and training \
        process.
    """

    def __init__(
            self,
            cfg: EasyDict,
            policy,
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            her_reward_model,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training.
            - her_reward_model (:obj:`HerRewardModel`): HER reward model.
        """
        self.cfg = cfg
        self._fetcher = task.wrap(her_data_enhancer(cfg, buffer_, her_reward_model))
        self._trainer = task.wrap(trainer(cfg, policy))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Output of ctx:
            - train_output (:obj:`List`): The list of training outputs, one entry per training iteration.
        """
        train_output_queue = []
        for _ in range(self.cfg.policy.learn.update_per_collect):
            # Fetch HER-enhanced training data; stop early if none is available.
            self._fetcher(ctx)
            if ctx.train_data is None:
                break
            self._trainer(ctx)
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue
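
# Illustrative usage (commented out): HERLearner swaps the plain data fetcher for
# her_data_enhancer, which relabels transitions with a HER reward model before each
# training step. The construction of `her_reward_model` below is an assumption for
# illustration only; check ding.reward_model.HerRewardModel for the actual interface.
#
#     from ding.reward_model import HerRewardModel
#
#     her_reward_model = HerRewardModel(cfg.policy.other.her, cfg.policy.cuda)
#     with task.start(ctx=OnlineRLContext()):
#         task.use(StepCollector(cfg, policy.collect_mode, collector_env))
#         task.use(data_pusher(cfg, buffer_))
#         task.use(HERLearner(cfg, policy.learn_mode, buffer_, her_reward_model))
#         task.run(max_step=10000)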