File size: 634 Bytes
9839b09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from dataclasses import dataclass, field
from typing import List

import numpy as np
import torch
@dataclass
class Trajectory:
    """Step-by-step record of one rollout.

    Each call to :meth:`add` appends one entry to ``obs``, ``act``, ``rew``
    and ``v`` together, so the four lists stay the same length as long as
    callers only grow them through ``add``. ``terminated`` is never changed
    by this class; callers set it themselves.

    ``Trajectory()`` creates an empty trajectory with its own fresh lists.
    Pre-filled data may also be passed to the generated constructor.
    """

    # Observation passed to each ``add`` call, in call order.
    obs: List[np.ndarray] = field(default_factory=list)
    # Action passed to each ``add`` call, in call order.
    act: List[np.ndarray] = field(default_factory=list)
    # Scalar reward passed to each ``add`` call, in call order.
    rew: List[float] = field(default_factory=list)
    # Scalar ``v`` passed to each ``add`` call; presumably a value
    # estimate -- TODO confirm against the caller.
    v: List[float] = field(default_factory=list)
    # Caller-managed flag; False for a newly created trajectory.
    terminated: bool = False

    def add(self, obs: np.ndarray, act: np.ndarray, rew: float, v: float) -> None:
        """Append one step's observation, action, reward and ``v`` value.

        Args:
            obs: Observation for this step.
            act: Action for this step.
            rew: Reward for this step.
            v: Per-step scalar stored alongside the reward.
        """
        self.obs.append(obs)
        self.act.append(act)
        self.rew.append(rew)
        self.v.append(v)

    def __len__(self) -> int:
        """Return the number of recorded steps (the length of ``obs``)."""
        return len(self.obs)
|