# VPG playing CartPole-v1, from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
# 9dc837c
import numpy as np | |
from dataclasses import dataclass, field | |
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs | |
from typing import Generic, List, Optional, Type, TypeVar | |
@dataclass
class Trajectory:
    """Transitions recorded for a single environment, one entry per step.

    ``obs``, ``act``, ``rew`` and ``v`` are parallel lists that grow by one
    element on every :meth:`add` call, so ``len(self)`` is the number of
    recorded steps. ``@dataclass`` is required here: without it the
    ``field(default_factory=list)`` defaults stay class-level ``Field``
    objects and every instance lacks its own lists.
    """

    # Observation at each recorded step.
    obs: List[np.ndarray] = field(default_factory=list)
    # Action taken at each recorded step.
    act: List[np.ndarray] = field(default_factory=list)
    # Observation following the most recent step only; reset to None when
    # that step was terminal.
    next_obs: Optional[np.ndarray] = None
    # Reward received at each recorded step.
    rew: List[float] = field(default_factory=list)
    # Whether the most recently added step was terminal.
    terminated: bool = False
    # Per-step `v` value as supplied by the caller (presumably a value
    # estimate -- confirm against the caller).
    v: List[float] = field(default_factory=list)

    def add(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        rew: float,
        terminated: bool,
        v: float,
    ) -> None:
        """Append one step to the trajectory.

        Args:
            obs: Observation at this step.
            act: Action taken at this step.
            next_obs: Observation after this step; stored only when the step
                is not terminal (otherwise ``self.next_obs`` becomes None).
            rew: Reward for this step.
            terminated: Whether this step ended the trajectory.
            v: Per-step value to record alongside the transition.
        """
        self.obs.append(obs)
        self.act.append(act)
        # Only the latest next_obs is kept; a terminal step has no successor.
        self.next_obs = next_obs if not terminated else None
        self.rew.append(rew)
        self.terminated = terminated
        self.v.append(v)

    def __len__(self) -> int:
        """Return the number of steps recorded so far."""
        return len(self.obs)
# Concrete trajectory type collected by a TrajectoryAccumulator (Trajectory or
# a subclass of it).
T = TypeVar("T", bound=Trajectory)


class TrajectoryAccumulator(Generic[T]):
    """Split a vectorized environment's step stream into per-env trajectories.

    One in-progress trajectory is kept per environment. When an environment
    reports ``done``, its trajectory is moved to the completed list, replaced
    by a fresh ``trajectory_class()`` instance, and :meth:`on_done` is called.
    """

    def __init__(self, num_envs: int, trajectory_class: Type[T] = Trajectory) -> None:
        """Create an accumulator for ``num_envs`` parallel environments.

        Args:
            num_envs: Number of environments; one in-progress trajectory is
                created for each.
            trajectory_class: Zero-argument factory used to start every new
                trajectory.
        """
        self.num_envs = num_envs
        self.trajectory_class = trajectory_class
        # Trajectories whose environment has reported done, in completion order.
        self._trajectories: List[T] = []
        # The in-progress trajectory for each environment, indexed by env.
        self._current_trajectories: List[T] = [
            trajectory_class() for _ in range(num_envs)
        ]

    def step(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        next_obs: VecEnvObs,
        reward: np.ndarray,
        done: np.ndarray,
        val: np.ndarray,
        *args,
    ) -> None:
        """Record one vectorized step, one row per environment.

        Row ``i`` of every argument is forwarded positionally to
        ``add(obs, act, next_obs, rew, terminated, v, *extra)`` on environment
        ``i``'s current trajectory. A trajectory whose ``done[i]`` is truthy
        is then finalized.

        Args:
            obs: Batched observations; must be an ``np.ndarray``.
            action: Batched actions.
            next_obs: Batched next observations; must be an ``np.ndarray``.
            reward: Batched rewards.
            done: Batched done flags.
            val: Batched per-step values stored as the trajectory's ``v``.
            *args: Extra batched sequences forwarded after ``val`` (for
                trajectory subclasses whose ``add`` accepts more fields).
        """
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        per_env_steps = zip(obs, action, next_obs, reward, done, val, *args)
        # Named step_args (not args) so the loop does not shadow *args.
        for i, step_args in enumerate(per_env_steps):
            trajectory = self._current_trajectories[i]
            # TODO: Eventually take advantage of terminated/truncated differentiation in
            # later versions of gym.
            trajectory.add(*step_args)
            if done[i]:
                self._trajectories.append(trajectory)
                self._current_trajectories[i] = self.trajectory_class()
                self.on_done(i, trajectory)

    def all_trajectories(self) -> List[T]:
        """Return completed trajectories followed by non-empty in-progress ones.

        The returned list is new, but the trajectory objects are shared with
        the accumulator (not copied).
        """
        return self._trajectories + [t for t in self._current_trajectories if len(t)]

    def n_timesteps(self) -> int:
        """Return the total number of steps across all trajectories."""
        # all_trajectories is a method and must be called; iterating the bound
        # method itself raises TypeError.
        return sum(len(t) for t in self.all_trajectories())

    def on_done(self, env_idx: int, trajectory: T) -> None:
        """Hook called after ``trajectory`` for ``env_idx`` finishes.

        By the time this runs, the trajectory has already been appended to the
        completed list and replaced by a fresh one. No-op by default;
        subclasses may override.
        """
        pass