from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from functools import reduce

import treetensor.torch as ttorch
import numpy as np

from ditk import logging
from ding.utils import EasyTimer
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray, get_shape0

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class TransitionList:
    """
    Overview:
        Per-environment transition buffer. Transitions are stored by env_id, and the position of
        each finished episode is recorded so that the data can later be exported either as flat
        trajectories or as complete episodes.
    """

    def __init__(self, env_num: int) -> None:
        self.env_num = env_num
        self._transitions = [[] for _ in range(env_num)]
        self._done_idx = [[] for _ in range(env_num)]

    def append(self, env_id: int, transition: Any) -> None:
        self._transitions[env_id].append(transition)
        if transition.done:
            self._done_idx[env_id].append(len(self._transitions[env_id]))

    def to_trajectories(self) -> Tuple[List[Any], List[int]]:
        # Flatten the per-env lists into a single list and compute, for each env, the index of its
        # last transition in the flattened list (cumulative length minus one).
        trajectories = sum(self._transitions, [])
        lengths = [len(t) for t in self._transitions]
        trajectory_end_idx = [reduce(lambda x, y: x + y, lengths[:i + 1]) for i in range(len(lengths))]
        trajectory_end_idx = [t - 1 for t in trajectory_end_idx]
        return trajectories, trajectory_end_idx

    def to_episodes(self) -> List[List[Any]]:
        # Cut each env's transition list at the recorded done indices; transitions collected after
        # the last done flag (unfinished episodes) are not returned.
        episodes = []
        for env_id in range(self.env_num):
            last_idx = 0
            for done_idx in self._done_idx[env_id]:
                episodes.append(self._transitions[env_id][last_idx:done_idx])
                last_idx = done_idx
        return episodes

    def clear(self) -> None:
        for item in self._transitions:
            item.clear()
        for item in self._done_idx:
            item.clear()
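
# Illustrative usage sketch (not part of the original module). The transition objects only need a
# boolean-like `done` field for the bookkeeping above; a treetensor dict stands in here for the
# output of `policy.process_transition`.
#
#   tl = TransitionList(env_num=2)
#   tl.append(0, ttorch.as_tensor({'obs': [0.], 'done': False}))
#   tl.append(0, ttorch.as_tensor({'obs': [1.], 'done': True}))
#   tl.append(1, ttorch.as_tensor({'obs': [2.], 'done': False}))
#   trajectories, end_idx = tl.to_trajectories()   # 3 flat transitions, end_idx == [1, 2]
#   episodes = tl.to_episodes()                    # a single finished episode, from env 0
#   tl.clear()
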
def inferencer(seed: int, policy: Policy, env: BaseEnvManager) -> Callable:
    """
    Overview:
        The middleware that executes the inference process.
    Arguments:
        - seed (:obj:`int`): Random seed.
        - policy (:obj:`Policy`): The policy used for inference.
        - env (:obj:`BaseEnvManager`): The env where the inference process is performed. \
            The env.ready_obs (:obj:`tnp.array`) will be used as model input.
    """

    env.seed(seed)

    def _inference(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - obs (:obj:`Union[torch.Tensor, Dict[torch.Tensor]]`): The input observations collected \
                from all collector environments.
            - action (:obj:`List[np.ndarray]`): The inferred actions, listed by env_id.
            - inference_output (:obj:`Dict[int, Dict]`): The dict whose key is env_id (int) and \
                whose value is the inference result (Dict).
        """

        if env.closed:
            env.launch()

        obs = ttorch.as_tensor(env.ready_obs)
        ctx.obs = obs
        obs = obs.to(dtype=ttorch.float32)

        # Split the batched observation into a dict keyed by env index, which is the input format
        # expected by policy.forward in collect mode.
        obs = {i: obs[i] for i in range(get_shape0(obs))}
        inference_output = policy.forward(obs, **ctx.collect_kwargs)
        ctx.action = [to_ndarray(v['action']) for v in inference_output.values()]
        ctx.inference_output = inference_output

    return _inference
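
# Illustrative sketch (not part of the original module): what a single call of the returned
# `_inference` middleware writes into the context. `policy` and `collector_env` are placeholders for
# a collect-mode policy and a BaseEnvManager built elsewhere; `ctx.collect_kwargs` must already hold
# the keyword arguments expected by `policy.forward` (e.g. `{'eps': 0.1}` for eps-greedy policies).
#
#   from ding.framework import OnlineRLContext
#
#   infer = inferencer(seed=0, policy=policy, env=collector_env)
#   ctx = OnlineRLContext()
#   ctx.collect_kwargs = {}
#   infer(ctx)
#   # ctx.obs              -> batched observations taken from env.ready_obs
#   # ctx.action           -> list of np.ndarray actions, one per ready env
#   # ctx.inference_output -> {env_id: forward result dict}
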
def rolloutor(
        policy: Policy,
        env: BaseEnvManager,
        transitions: TransitionList,
        collect_print_freq: int = 100,
) -> Callable:
    """
    Overview:
        The middleware that executes the transition process in the env.
    Arguments:
        - policy (:obj:`Policy`): The policy to be used during transition.
        - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
            its derivatives are supported.
        - transitions (:obj:`TransitionList`): The transition information which will be filled \
            in this process, including `obs`, `next_obs`, `action`, `logit`, `value`, `reward` \
            and `done`.
        - collect_print_freq (:obj:`int`): The frequency, in train iterations, at which the \
            collection log is printed.
    """

    # Each env slot carries a monotonically increasing episode id, recorded in every transition as
    # `env_data_id`; `current_id` is the next id to assign when an episode finishes.
    env_episode_id = [_ for _ in range(env.env_num)]
    current_id = env.env_num
    timer = EasyTimer()
    last_train_iter = 0
    total_envstep_count = 0
    total_episode_count = 0
    total_train_sample_count = 0
    env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
    episode_info = []

    def _rollout(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - action (:obj:`List[np.ndarray]`): The inferred actions from the previous inference process.
            - obs (:obj:`Dict[Tensor]`): The states fed into the transition dict.
            - inference_output (:obj:`Dict[int, Dict]`): The inference results to be fed into the \
                transition dict.
            - train_iter (:obj:`int`): The train iteration count to be fed into the transition dict.
            - env_step (:obj:`int`): The count of env steps, which increases by the number of \
                transitions collected in each call.
            - env_episode (:obj:`int`): The count of env episodes, which increases by 1 each time a \
                trajectory finishes.
        """

        nonlocal current_id, env_info, episode_info, timer, \
            total_episode_count, total_envstep_count, total_train_sample_count, last_train_iter
        timesteps = env.step(ctx.action)
        ctx.env_step += len(timesteps)
        timesteps = [t.tensor() for t in timesteps]

        collected_sample = 0
        collected_step = 0
        collected_episode = 0
        interaction_duration = timer.value / len(timesteps)
        for i, timestep in enumerate(timesteps):
            with timer:
                transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
                transition = ttorch.as_tensor(transition)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
                transitions.append(timestep.env_id, transition)

                collected_step += 1
                collected_sample += len(transition.obs)
                env_info[timestep.env_id.item()]['step'] += 1
                env_info[timestep.env_id.item()]['train_sample'] += len(transition.obs)

            env_info[timestep.env_id.item()]['time'] += timer.value + interaction_duration
            if timestep.done:
                info = {
                    'reward': timestep.info['eval_episode_return'],
                    'time': env_info[timestep.env_id.item()]['time'],
                    'step': env_info[timestep.env_id.item()]['step'],
                    'train_sample': env_info[timestep.env_id.item()]['train_sample'],
                }

                episode_info.append(info)
                policy.reset([timestep.env_id.item()])
                env_episode_id[timestep.env_id.item()] = current_id
                collected_episode += 1
                current_id += 1
                ctx.env_episode += 1

        total_envstep_count += collected_step
        total_episode_count += collected_episode
        total_train_sample_count += collected_sample

        if (ctx.train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0:
            output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)
            last_train_iter = ctx.train_iter

    return _rollout
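
# Illustrative composition sketch (not part of the original module). In DI-engine pipelines these two
# middleware are normally wrapped by the higher-level StepCollector / EpisodeCollector, but they can
# also be chained directly in a `ding.framework` task. `cfg`, `policy` and `collector_env` are
# placeholders for objects built elsewhere in the pipeline.
#
#   from ding.framework import task, OnlineRLContext
#
#   transitions = TransitionList(collector_env.env_num)
#   with task.start(async_mode=False, ctx=OnlineRLContext()):
#       task.use(inferencer(cfg.seed, policy.collect_mode, collector_env))
#       task.use(rolloutor(policy.collect_mode, collector_env, transitions))
#       # ...training / evaluation middleware would follow here
#       task.run(max_step=int(1e4))
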
def output_log(episode_info: List[dict], total_episode_count: int, total_envstep_count: int,
               total_train_sample_count: int) -> None:
    """
    Overview:
        Print the output log information. You can refer to the docs of `Best Practice` to understand \
            the training generated logs and tensorboards.
    Arguments:
        - episode_info (:obj:`List[dict]`): The per-episode statistics (reward, time, step, train_sample) \
            collected since the last log; it is cleared after logging.
        - total_episode_count (:obj:`int`): The total number of episodes collected so far.
        - total_envstep_count (:obj:`int`): The total number of env steps collected so far.
        - total_train_sample_count (:obj:`int`): The total number of train samples collected so far.
    """
    episode_count = len(episode_info)
    envstep_count = sum([d['step'] for d in episode_info])
    train_sample_count = sum([d['train_sample'] for d in episode_info])
    duration = sum([d['time'] for d in episode_info])
    episode_return = [d['reward'].item() for d in episode_info]
    info = {
        'episode_count': episode_count,
        'envstep_count': envstep_count,
        'train_sample_count': train_sample_count,
        'avg_envstep_per_episode': envstep_count / episode_count,
        'avg_sample_per_episode': train_sample_count / episode_count,
        'avg_envstep_per_sec': envstep_count / duration,
        'avg_train_sample_per_sec': train_sample_count / duration,
        'avg_episode_per_sec': episode_count / duration,
        'reward_mean': np.mean(episode_return),
        'reward_std': np.std(episode_return),
        'reward_max': np.max(episode_return),
        'reward_min': np.min(episode_return),
        'total_envstep_count': total_envstep_count,
        'total_train_sample_count': total_train_sample_count,
        'total_episode_count': total_episode_count,
    }
    episode_info.clear()
    logging.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
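
# Illustrative sketch (not part of the original module): `output_log` can be exercised directly with a
# hand-built `episode_info` list, e.g. in a quick sanity check. The 'reward' entries must support
# `.item()` (0-d arrays or tensors), matching what `rolloutor` stores from
# `timestep.info['eval_episode_return']`.
#
#   toy_episodes = [
#       {'reward': np.float64(1.0), 'time': 0.5, 'step': 10, 'train_sample': 10},
#       {'reward': np.float64(3.0), 'time': 0.7, 'step': 12, 'train_sample': 12},
#   ]
#   output_log(toy_episodes, total_episode_count=2, total_envstep_count=22, total_train_sample_count=22)
#   # logs a "collect end:" block and clears `toy_episodes` in place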