# SelfDriving/exp_replay.py
import numpy as np
class ExperienceReplay:
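    """Frame-stacking experience replay for image observations.

    Raw frames are stored once in a shared ring buffer (`frames`); each
    transition only records the `num_frame_stack` frame indices that make up
    its previous and next observations, which keeps memory usage low for
    image inputs such as the default 96x96 screens.
    """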
def __init__(self,
num_frame_stack=4,
capacity=int(1e5),
pic_size=(96, 96)
):
self.num_frame_stack = num_frame_stack
self.capacity = capacity
self.pic_size = pic_size
self.counter = 0
self.frame_window = None
self.init_caches()
self.expecting_new_episode = True
def add_experience(self, frame, action, done, reward):
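        """Store one transition: the frame observed after taking `action`,
        together with the reward and the episode-termination flag."""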
assert self.frame_window is not None, "start episode first"
self.counter += 1
        # slot for the new frame in the frame ring buffer, and slot for this
        # transition in the experience ring buffer
        frame_idx = self.counter % self.max_frame_cache
        exp_idx = (self.counter - 1) % self.capacity
        # the window before sliding is the previous state; sliding the new
        # frame index in yields the next state
        self.prev_states[exp_idx] = self.frame_window
        self.frame_window = np.append(self.frame_window[1:], frame_idx)
        self.next_states[exp_idx] = self.frame_window
        self.actions[exp_idx] = action
        self.is_done[exp_idx] = done
        self.frames[frame_idx] = frame
        self.rewards[exp_idx] = reward
if done:
self.expecting_new_episode = True
def start_new_episode(self, frame):
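        """Begin a new episode whose first observation is `frame` repeated
        `num_frame_stack` times."""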
        # the counter is deliberately not incremented here: the initial frame
        # reuses the frame slot of the previous episode's terminal frame,
        # which is fine because terminal next_state frames are never used
        # (see done_mask)
assert self.expecting_new_episode, "previous episode didn't end yet"
frame_idx = self.counter % self.max_frame_cache
self.frame_window = np.repeat(frame_idx, self.num_frame_stack)
self.frames[frame_idx] = frame
self.expecting_new_episode = False
def sample_mini_batch(self, n):
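        """Sample `n` stored transitions uniformly at random; stacked states
        are returned channels-last, i.e. (n, height, width, num_frame_stack)."""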
count = min(self.capacity, self.counter)
batchidx = np.random.randint(count, size=n)
prev_frames = self.frames[self.prev_states[batchidx]]
next_frames = self.frames[self.next_states[batchidx]]
prev_frames = np.moveaxis(prev_frames, 1, -1)
next_frames = np.moveaxis(next_frames, 1, -1)
return {
"reward": self.rewards[batchidx],
"prev_state": prev_frames,
"next_state": next_frames,
"actions": self.actions[batchidx],
"done_mask": self.is_done[batchidx]
}
def current_state(self):
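        """Return the current stacked observation with shape
        (height, width, num_frame_stack)."""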
        # assert not self.expecting_new_episode, "start new episode first"
        assert self.frame_window is not None, "start episode first"
sf = self.frames[self.frame_window]
sf = np.moveaxis(sf, 0, -1)
return sf
def init_caches(self):
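        """Allocate the fixed-size storage arrays; index and flag arrays use
        -1 as an 'empty' sentinel."""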
self.rewards = np.zeros(self.capacity, dtype="float32")
self.prev_states = -np.ones((self.capacity, self.num_frame_stack),
dtype="int32")
self.next_states = -np.ones((self.capacity, self.num_frame_stack),
dtype="int32")
        self.is_done = -np.ones(self.capacity, dtype="int32")
self.actions = -np.ones(self.capacity, dtype="int32")
        # the frame cache is slightly larger than the transition capacity so
        # that frames referenced by the oldest stored transitions are not
        # overwritten before the transitions themselves are
        self.max_frame_cache = self.capacity + 2 * self.num_frame_stack + 1
self.frames = -np.ones((self.max_frame_cache,) + self.pic_size, dtype="float32")
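# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): drives the buffer
# with random 96x96 frames instead of a real environment and assumes a
# hypothetical 4-action discrete space, just to show the expected call order
# and output shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    replay = ExperienceReplay(num_frame_stack=4, capacity=1000, pic_size=(96, 96))
    rng = np.random.default_rng(0)
    replay.start_new_episode(rng.random((96, 96), dtype=np.float32))
    for step in range(200):
        action = int(rng.integers(4))            # hypothetical action id
        reward = float(rng.random())
        done = (step + 1) % 50 == 0              # end an episode every 50 steps
        replay.add_experience(rng.random((96, 96), dtype=np.float32),
                              action, done, reward)
        if done:
            replay.start_new_episode(rng.random((96, 96), dtype=np.float32))
    print(replay.current_state().shape)          # (96, 96, 4)
    batch = replay.sample_mini_batch(32)
    print(batch["prev_state"].shape)             # (32, 96, 96, 4)
    print(batch["actions"].shape)                # (32,)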