# ppo-CarRacing-v0 / shared / trajectory.py
# sgoodfriend's picture
# PPO playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c
# 0ca6846
# raw
# history blame
# 634 Bytes
import numpy as np
import torch
from dataclasses import dataclass
from typing import List
@dataclass
class Trajectory:
    """Append-only record of one rollout, stored as parallel per-step lists.

    Each call to :meth:`add` appends one entry to ``obs``, ``act``, ``rew``
    and ``v``, so index ``i`` in each list refers to the same step.

    The class is decorated with ``@dataclass`` but defines its own
    ``__init__``; ``dataclass`` does not replace an ``__init__`` that the
    class already defines, so instances are always created with no
    arguments and start out empty.
    """

    # Per-step observations.
    obs: List[np.ndarray]
    # Per-step actions.
    act: List[np.ndarray]
    # Per-step rewards.
    rew: List[float]
    # Per-step ``v`` values (presumably value estimates; confirm with caller).
    v: List[float]
    # Starts as False; nothing in this class changes it afterwards.
    terminated: bool

    def __init__(self) -> None:
        """Create an empty trajectory with ``terminated`` set to False."""
        # Fresh lists per instance, so trajectories never share storage.
        self.obs = []
        self.act = []
        self.rew = []
        self.v = []
        self.terminated = False

    def add(self, obs: np.ndarray, act: np.ndarray, rew: float, v: float) -> None:
        """Append one step to the trajectory.

        Args:
            obs: Observation for the step.
            act: Action for the step.
            rew: Reward for the step.
            v: Value stored alongside the step in ``self.v``.
        """
        self.obs.append(obs)
        self.act.append(act)
        self.rew.append(rew)
        self.v.append(v)

    def __len__(self) -> int:
        """Return the number of recorded steps (the length of ``obs``)."""
        return len(self.obs)