import numpy as np


class ExperienceReplay:
    """Ring-buffer replay memory for stacked-frame transitions.

    Each raw frame is stored exactly once in a shared frame cache; a
    transition keeps only the ``num_frame_stack`` cache indices of its
    observation, so full stacked states are rebuilt on demand without
    duplicating pixel data.
    """

    def __init__(self, num_frame_stack=4, capacity=int(1e5), pic_size=(96, 96)):
        self.num_frame_stack = num_frame_stack
        self.capacity = capacity
        self.pic_size = pic_size
        self.counter = 0
        self.frame_window = None  # indices of the current stacked observation
        self.init_caches()
        self.expecting_new_episode = True

    def add_experience(self, frame, action, done, reward):
        """Record one environment step (new frame + action/done/reward)."""
        assert self.frame_window is not None, "start episode first"
        self.counter += 1
        slot = self.counter % self.max_frame_cache    # frame-cache slot for the new frame
        row = (self.counter - 1) % self.capacity      # transition ring-buffer row
        self.prev_states[row] = self.frame_window
        # slide the stack window: drop the oldest index, append the new slot
        self.frame_window = np.append(self.frame_window[1:], slot)
        self.next_states[row] = self.frame_window
        self.actions[row] = action
        self.is_done[row] = done
        self.frames[slot] = frame
        self.rewards[row] = reward
        if done:
            self.expecting_new_episode = True

    def start_new_episode(self, frame):
        """Begin an episode: the first frame fills the whole stack window."""
        # it should be okay not to increment counter here
        # because episode ending frames are not used
        assert self.expecting_new_episode, "previous episode didn't end yet"
        slot = self.counter % self.max_frame_cache
        self.frame_window = np.repeat(slot, self.num_frame_stack)
        self.frames[slot] = frame
        self.expecting_new_episode = False

    def sample_mini_batch(self, n):
        """Sample ``n`` stored transitions uniformly, with replacement."""
        stored = min(self.capacity, self.counter)
        picks = np.random.randint(stored, size=n)
        # gather frame stacks, then move the stack axis last: (n, H, W, stack)
        before = np.moveaxis(self.frames[self.prev_states[picks]], 1, -1)
        after = np.moveaxis(self.frames[self.next_states[picks]], 1, -1)
        return {
            "reward": self.rewards[picks],
            "prev_state": before,
            "next_state": after,
            "actions": self.actions[picks],
            "done_mask": self.is_done[picks],
        }

    def current_state(self):
        """Return the latest stacked observation, channels-last: (H, W, stack)."""
        assert self.frame_window is not None, "do something first"
        return np.moveaxis(self.frames[self.frame_window], 0, -1)

    def init_caches(self):
        """Allocate flat transition buffers and the shared frame cache."""
        self.rewards = np.zeros(self.capacity, dtype="float32")
        window_shape = (self.capacity, self.num_frame_stack)
        self.prev_states = -np.ones(window_shape, dtype="int32")
        self.next_states = -np.ones(window_shape, dtype="int32")
        self.is_done = -np.ones(self.capacity, "int32")
        self.actions = -np.ones(self.capacity, dtype="int32")
        # spare slots keep in-flight frame indices from colliding with live rows
        self.max_frame_cache = self.capacity + 2 * self.num_frame_stack + 1
        self.frames = -np.ones((self.max_frame_cache,) + self.pic_size, dtype="float32")