| import threading |
|
|
| import numpy as np |
|
|
|
|
class ReplayBuffer:
    """Fixed-capacity store of whole episodes, guarded by a single lock.

    Each buffer key owns one pre-allocated NumPy array whose leading axis
    indexes episodes. New episodes fill free rows first; once the buffer is
    full they overwrite randomly chosen existing rows.
    """

    def __init__(self, buffer_shapes, size_in_transitions, T, sample_transitions):
        """Creates a replay buffer.

        Args:
            buffer_shapes (dict of tuples): maps each buffer key to the shape of
                one episode's worth of data for that key (the shape is unpacked
                after the leading episode axis when the arrays are allocated)
            size_in_transitions (int): the size of the buffer, measured in transitions
            T (int): the time horizon for episodes
            sample_transitions (function): a function that samples from the replay
                buffer; it is called as ``sample_transitions(buffers, batch_size)``
        """
        self.buffer_shapes = buffer_shapes
        # Capacity is tracked in episodes, so convert from transitions.
        self.size = size_in_transitions // T
        self.T = T
        self.sample_transitions = sample_transitions

        # self.buffers is {key: array(size_in_episodes x shape)}.
        # np.empty leaves the rows uninitialised; only the first
        # `current_size` rows are ever treated as valid data.
        self.buffers = {key: np.empty([self.size, *shape])
                        for key, shape in buffer_shapes.items()}

        # Memory management: number of valid episode rows, and the running
        # total of transitions ever stored (not reset by clear_buffer).
        self.current_size = 0
        self.n_transitions_stored = 0

        self.lock = threading.Lock()

    @property
    def full(self):
        """True when every episode slot currently holds a stored episode."""
        with self.lock:
            return self.current_size == self.size

    def sample(self, batch_size):
        """Samples transitions from the episodes stored so far.

        The valid rows of every buffer, plus the derived next-step entries
        'o_2' and 'ag_2' (the 'o' and 'ag' arrays with the first time step
        dropped), are handed to ``sample_transitions``.

        Args:
            batch_size (int): forwarded unchanged to ``sample_transitions``

        Returns:
            Whatever ``sample_transitions`` returns, after asserting that it
            contains 'r', 'o_2', 'ag_2' and every buffer key. Presumably a dict
            {key: array(batch_size x shapes[key])} -- the exact layout is
            decided by ``sample_transitions``.

        Raises:
            AssertionError: if the buffer is empty or a required key is missing
                from the sampled transitions.
        """
        buffers = {}

        with self.lock:
            assert self.current_size > 0
            # Basic slicing yields views, so no buffer data is copied here.
            for key in self.buffers:
                buffers[key] = self.buffers[key][:self.current_size]

        # Next-step observations / achieved goals: same episodes, shifted by
        # one time step. The 3-axis slice requires (episode, time, dim) arrays.
        buffers['o_2'] = buffers['o'][:, 1:, :]
        buffers['ag_2'] = buffers['ag'][:, 1:, :]

        transitions = self.sample_transitions(buffers, batch_size)

        for key in (['r', 'o_2', 'ag_2'] + list(self.buffers.keys())):
            assert key in transitions, "key %s missing from transitions" % key

        return transitions

    def store_episode(self, episode_batch):
        """Stores a batch of episodes.

        Args:
            episode_batch (dict): {key: array(batch_size x (T or T+1) x dim_key)};
                must contain every key of ``self.buffers`` (extra keys are
                ignored when writing but still take part in the batch-size check)

        Raises:
            AssertionError: if the entries of ``episode_batch`` do not all have
                the same length, or the batch exceeds the buffer capacity.
            KeyError: if a buffer key is absent from ``episode_batch``. This is
                detected before any slot is reserved, so the buffer is left
                untouched.
        """
        batch_sizes = [len(episode_batch[key]) for key in episode_batch.keys()]
        assert np.all(np.array(batch_sizes) == batch_sizes[0])
        batch_size = batch_sizes[0]

        # Validate up front: _get_storage_idx advances current_size, so failing
        # on a missing key afterwards would leave uninitialised rows counted as
        # valid episodes.
        missing = [key for key in self.buffers if key not in episode_batch]
        if missing:
            raise KeyError("episode_batch is missing buffer keys: %s" % missing)

        with self.lock:
            idxs = self._get_storage_idx(batch_size)

            # load inputs into buffers
            for key in self.buffers:
                self.buffers[key][idxs] = episode_batch[key]

            self.n_transitions_stored += batch_size * self.T

    def get_current_episode_size(self):
        """Returns the number of episodes currently held in the buffer."""
        with self.lock:
            return self.current_size

    def get_current_size(self):
        """Returns the number of transitions currently held (episodes * T)."""
        with self.lock:
            return self.current_size * self.T

    def get_transitions_stored(self):
        """Returns the running total of transitions passed to store_episode."""
        with self.lock:
            return self.n_transitions_stored

    def clear_buffer(self):
        """Marks the buffer as empty.

        Only ``current_size`` is reset; the arrays are not wiped and
        ``n_transitions_stored`` keeps its value.
        """
        with self.lock:
            self.current_size = 0

    def _get_storage_idx(self, inc=None):
        """Picks the episode rows for the next ``inc`` episodes and grows the buffer.

        Does not take the lock itself; store_episode calls it while holding
        ``self.lock``.

        Args:
            inc (int, optional): number of rows needed; None or 0 is treated as 1

        Returns:
            An array of ``inc`` row indices, or a single scalar index when
            ``inc`` is 1.

        Raises:
            AssertionError: if ``inc`` exceeds the buffer capacity.
        """
        inc = inc or 1   # size increment
        assert inc <= self.size, "Batch committed to replay is too large!"
        # go consecutively until you hit the end, and then go randomly.
        if self.current_size+inc <= self.size:
            idx = np.arange(self.current_size, self.current_size+inc)
        elif self.current_size < self.size:
            # Partly free: use the remaining free rows, then overwrite random
            # rows among those that were already filled.
            overflow = inc - (self.size - self.current_size)
            idx_a = np.arange(self.current_size, self.size)
            idx_b = np.random.randint(0, self.current_size, overflow)
            idx = np.concatenate([idx_a, idx_b])
        else:
            # Completely full: every new episode replaces a random row.
            idx = np.random.randint(0, self.size, inc)

        # update replay size
        self.current_size = min(self.size, self.current_size+inc)

        if inc == 1:
            idx = idx[0]
        return idx
|
|