| | import numpy as np |
| |
|
| | from dataclasses import dataclass, field |
| | from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs |
| | from typing import Generic, List, Optional, Type, TypeVar |
| |
|
| |
|
@dataclass
class Trajectory:
    """Step-by-step record of a single environment's rollout.

    ``obs``, ``act``, ``rew`` and ``v`` each grow by one entry per call to
    :meth:`add`. ``next_obs`` and ``terminated`` are not histories: they only
    reflect the most recently added step.
    """

    # One entry per recorded step.
    obs: List[np.ndarray] = field(default_factory=list)
    act: List[np.ndarray] = field(default_factory=list)
    # Observation following the latest step; None before any step is added
    # and whenever the latest step was terminal.
    next_obs: Optional[np.ndarray] = None
    rew: List[float] = field(default_factory=list)
    # Termination flag of the latest step.
    terminated: bool = False
    # Per-step scalar passed in as ``v`` (presumably a value estimate --
    # confirm against the caller).
    v: List[float] = field(default_factory=list)

    def add(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        rew: float,
        terminated: bool,
        v: float,
    ):
        """Record one transition.

        Args:
            obs: Observation for this step; appended to ``self.obs``.
            act: Action for this step; appended to ``self.act``.
            next_obs: Observation after the step. Stored as ``self.next_obs``
                unless ``terminated`` is truthy, in which case ``None`` is
                stored instead.
            rew: Reward for this step; appended to ``self.rew``.
            terminated: Stored as ``self.terminated``.
            v: Scalar appended to ``self.v``.
        """
        # Latest-step state: a terminal step leaves no follow-up observation.
        self.terminated = terminated
        if terminated:
            self.next_obs = None
        else:
            self.next_obs = next_obs

        # Per-step histories all grow together, one entry each.
        for history, entry in (
            (self.obs, obs),
            (self.act, act),
            (self.rew, rew),
            (self.v, v),
        ):
            history.append(entry)

    def __len__(self) -> int:
        """Return the number of recorded steps."""
        recorded_steps = len(self.obs)
        return recorded_steps
| |
|
| |
|
# Trajectory type handled by an accumulator; any Trajectory subclass qualifies.
T = TypeVar("T", bound=Trajectory)


class TrajectoryAccumulator(Generic[T]):
    """Collect per-environment trajectories from batched step data.

    One in-progress trajectory is kept per environment index. When an
    environment reports ``done``, its trajectory is moved to the finished
    list, a fresh one is started in its slot, and :meth:`on_done` is called.
    """

    def __init__(self, num_envs: int, trajectory_class: Type[T] = Trajectory) -> None:
        """Create one empty in-progress trajectory per environment.

        Args:
            num_envs: Number of environment slots to track.
            trajectory_class: Zero-argument callable (a Trajectory class) used
                to create every new trajectory.
        """
        self.num_envs = num_envs
        self.trajectory_class = trajectory_class

        # Finished trajectories, in the order they completed.
        self._trajectories = []
        # In-progress trajectory for each environment index.
        self._current_trajectories = [trajectory_class() for _ in range(num_envs)]

    def step(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        next_obs: VecEnvObs,
        reward: np.ndarray,
        done: np.ndarray,
        val: np.ndarray,
        *args,
    ) -> None:
        """Append one transition per environment and roll over finished ones.

        Every argument is iterated along its first axis; the i-th slices are
        forwarded positionally to the i-th in-progress trajectory's ``add``.
        Extra ``*args`` are zipped in the same way and passed after ``val``.

        Args:
            obs: Batched observations; must be an ``np.ndarray``.
            action: Batched actions.
            next_obs: Batched next observations; must be an ``np.ndarray``.
            reward: Batched rewards.
            done: Batched done flags; a truthy entry finishes that
                environment's current trajectory.
            val: Batched per-step values.
            *args: Additional batched sequences forwarded to ``add``.
        """
        # Only array observations are supported (not dict/tuple observations).
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)

        per_env_transitions = zip(obs, action, next_obs, reward, done, val, *args)
        for env_idx, transition in enumerate(per_env_transitions):
            active = self._current_trajectories[env_idx]
            active.add(*transition)
            if not done[env_idx]:
                continue
            # Episode finished: archive it and start a new one in this slot
            # before notifying the hook.
            self._trajectories.append(active)
            self._current_trajectories[env_idx] = self.trajectory_class()
            self.on_done(env_idx, active)

    @property
    def all_trajectories(self) -> List[T]:
        """Finished trajectories followed by the non-empty in-progress ones.

        A new list is built on every access.
        """
        in_progress = [
            trajectory
            for trajectory in self._current_trajectories
            if len(trajectory)
        ]
        return self._trajectories + in_progress

    def n_timesteps(self) -> int:
        """Return the total number of steps across ``all_trajectories``."""
        total = 0
        for trajectory in self.all_trajectories:
            total += len(trajectory)
        return total

    def on_done(self, env_idx: int, trajectory: T) -> None:
        """Hook called from ``step`` after a trajectory finishes; no-op here.

        Args:
            env_idx: Index of the environment whose trajectory finished.
            trajectory: The finished trajectory, already archived.
        """
        return None
| |
|