# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A circular buffer where each element is a list of tensors. | |
Each element of the buffer is a list of tensors. An example use case is a replay | |
buffer in reinforcement learning, where each element is a list of tensors | |
representing the state, action, reward etc. | |
New elements are added sequentially, and once the buffer is full, we | |
start overwriting them in a circular fashion. Reading does not remove any | |
elements, only adding new elements does. | |
""" | |
import collections
import numpy as np
import tensorflow as tf
import gin.tf

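# Illustrative usage sketch (not part of the module itself; assumes TF1-style
# graph execution, and the placeholder names and shapes below are hypothetical):
#
#   buf = CircularBuffer(buffer_size=1000)
#   state = tf.placeholder(tf.float32, [4], name='state')
#   reward = tf.placeholder(tf.float32, [], name='reward')
#   add_op = buf.add({'state': state, 'reward': reward})
#   # After running `add_op` in a session enough times, sample a batch:
#   state_batch, reward_batch = buf.get_random_batch(32, keys=['state', 'reward'])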


class CircularBuffer(object):
  """A circular buffer where each element is a list of tensors."""

  def __init__(self, buffer_size=1000, scope='replay_buffer'):
    """Circular buffer of list of tensors.

    Args:
      buffer_size: (integer) maximum number of tensor lists the buffer can hold.
      scope: (string) variable scope for creating the variables.
    """
    self._buffer_size = np.int64(buffer_size)
    self._scope = scope
    self._tensors = collections.OrderedDict()
    with tf.variable_scope(self._scope):
      self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
      self._num_adds_cs = tf.CriticalSection(name='num_adds')

  def buffer_size(self):
    return self._buffer_size

  def scope(self):
    return self._scope

  def num_adds(self):
    return self._num_adds

  def _create_variables(self, tensors):
    with tf.variable_scope(self._scope):
      for name in tensors.keys():
        tensor = tensors[name]
        self._tensors[name] = tf.get_variable(
            name='BufferVariable_' + name,
            shape=[self._buffer_size] + tensor.get_shape().as_list(),
            dtype=tensor.dtype,
            trainable=False)

  def _validate(self, tensors):
    """Validates the shapes, dtypes and keys of the input tensors."""
    if len(tensors) != len(self._tensors):
      raise ValueError('Expected tensors to have %d elements. Received %d '
                       'instead.' % (len(self._tensors), len(tensors)))
    if self._tensors.keys() != tensors.keys():
      raise ValueError('The keys of tensors should always be the same. '
                       'Received %s instead of %s.' %
                       (tensors.keys(), self._tensors.keys()))
    for name, tensor in tensors.items():
      if tensor.get_shape().as_list() != self._tensors[
          name].get_shape().as_list()[1:]:
        raise ValueError('Tensor %s has incorrect shape.' % name)
      if not tensor.dtype.is_compatible_with(self._tensors[name].dtype):
        raise ValueError(
            'Tensor %s has incorrect data type. Expected %s, received %s' %
            (name, self._tensors[name].read_value().dtype, tensor.dtype))

  def add(self, tensors):
    """Adds an element (list/tuple/dict of tensors) to the buffer.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.

    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
      an enqueue_op.

    Raises:
      ValueError: If the shapes and data types of input `tensors` are not the
        same across calls to the add function.
    """
    return self.maybe_add(tensors, True)

  def maybe_add(self, tensors, condition):
    """Adds an element (tensors) to the buffer based on the condition.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.
      condition: A boolean Tensor controlling whether the tensors would be added
        to the buffer or not.

    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
      a maybe_enqueue_op.

    Raises:
      ValueError: If the shapes and data types of input `tensors` are not the
        same across calls to the add function.
    """
    if not isinstance(tensors, dict):
      names = [str(i) for i in range(len(tensors))]
      tensors = collections.OrderedDict(zip(names, tensors))
    if not isinstance(tensors, collections.OrderedDict):
      tensors = collections.OrderedDict(
          sorted(tensors.items(), key=lambda t: t[0]))
    if not self._tensors:
      self._create_variables(tensors)
    else:
      self._validate(tensors)

    # @tf.critical_section(self._position_mutex)
    def _increment_num_adds():
      # Adding 0 to the num_adds variable is a trick to read the value of the
      # variable and return a read-only tensor. Doing this in a critical
      # section allows us to capture a snapshot of the variable that will
      # not be affected by other threads updating num_adds.
      return self._num_adds.assign_add(1) + 0

    def _add():
      num_adds_inc = self._num_adds_cs.execute(_increment_num_adds)
      current_pos = tf.mod(num_adds_inc - 1, self._buffer_size)
      update_ops = []
      for name in self._tensors.keys():
        update_ops.append(
            tf.scatter_update(self._tensors[name], current_pos, tensors[name]))
      return tf.group(*update_ops)

    return tf.contrib.framework.smart_cond(condition, _add, tf.no_op)
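
  # Illustrative conditional-add sketch (hypothetical names): only store
  # transitions with positive reward.
  #
  #   should_store = tf.greater(reward, 0.0)
  #   maybe_add_op = buf.maybe_add({'state': state, 'reward': reward},
  #                                should_store)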

  def get_random_batch(self, batch_size, keys=None, num_steps=1):
    """Samples a batch of tensors from the buffer with replacement.

    Args:
      batch_size: (integer) number of elements to sample.
      keys: List of keys of tensors to retrieve. If None retrieve all.
      num_steps: (integer) length of trajectories to return. If > 1, each
        returned tensor holds trajectories of length num_steps stacked along a
        second dimension, i.e. it has shape [batch_size, num_steps, ...].

    Returns:
      A list of tensors, where each element in the list is a batch sampled from
      one of the tensors in the buffer.

    Raises:
      ValueError: If get_random_batch is called before calling the add function.
      tf.errors.InvalidArgumentError: If this operation is executed before any
        items are added to the buffer.
    """
    if not self._tensors:
      raise ValueError(
          'The add function must be called before get_random_batch.')
    if keys is None:
      keys = self._tensors.keys()

    latest_start_index = self.get_num_adds() - num_steps + 1
    empty_buffer_assert = tf.Assert(
        tf.greater(latest_start_index, 0),
        ['Not enough elements have been added to the buffer.'])
    with tf.control_dependencies([empty_buffer_assert]):
      max_index = tf.minimum(self._buffer_size, latest_start_index)
      indices = tf.random_uniform(
          [batch_size],
          minval=0,
          maxval=max_index,
          dtype=tf.int64)
      if num_steps == 1:
        return self.gather(indices, keys)
      else:
        return self.gather_nstep(num_steps, indices, keys)
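
  # Illustrative sampling sketch (hypothetical names, assuming 'state' and
  # 'reward' entries were added earlier):
  #
  #   state_traj, reward_traj = buf.get_random_batch(
  #       batch_size=32, keys=['state', 'reward'], num_steps=3)
  #   # state_traj has shape [32, 3] + state.shape; reward_traj has [32, 3].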

  def gather(self, indices, keys=None):
    """Returns elements at the specified indices from the buffer.

    Args:
      indices: (list of integers or rank 1 int Tensor) indices in the buffer to
        retrieve elements from.
      keys: List of keys of tensors to retrieve. If None retrieve all.

    Returns:
      A list of tensors, where each element in the list is obtained by indexing
      one of the tensors in the buffer.

    Raises:
      ValueError: If gather is called before calling the add function.
      tf.errors.InvalidArgumentError: If indices are bigger than the number of
        items in the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling gather.')
    if keys is None:
      keys = self._tensors.keys()
    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less(
              tf.to_int64(tf.reduce_max(indices)),
              tf.minimum(self.get_num_adds(), self._buffer_size)),
          ['Index out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        indices = tf.convert_to_tensor(indices)

      batch = []
      for key in keys:
        batch.append(tf.gather(self._tensors[key], indices, name=key))
      return batch

  def gather_nstep(self, num_steps, indices, keys=None):
    """Returns trajectories of length num_steps starting at the given indices.

    Args:
      num_steps: (integer) length of trajectories to return.
      indices: (list of integers or rank 1 int Tensor) start indices in the
        buffer; for each start index, num_steps consecutive elements are
        retrieved, wrapping around the end of the buffer if necessary.
      keys: List of keys of tensors to retrieve. If None retrieve all.

    Returns:
      A list of tensors, one per key, where each tensor stacks the gathered
      trajectories and has shape [len(indices), num_steps, ...].

    Raises:
      ValueError: If gather_nstep is called before calling the add function.
      tf.errors.InvalidArgumentError: If indices are bigger than the number of
        items in the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling gather.')
    if keys is None:
      keys = self._tensors.keys()
    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less_equal(
              tf.to_int64(tf.reduce_max(indices) + num_steps),
              self.get_num_adds()),
          ['Trajectory indices go out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        indices = tf.map_fn(
            lambda x: tf.mod(tf.range(x, x + num_steps), self._buffer_size),
            indices,
            dtype=tf.int64)

      batch = []
      for key in keys:

        def SampleTrajectories(trajectory_indices,
                               key=key,
                               num_steps=num_steps):
          trajectory_indices.set_shape([num_steps])
          return tf.gather(self._tensors[key], trajectory_indices, name=key)

        batch.append(tf.map_fn(SampleTrajectories, indices,
                               dtype=self._tensors[key].dtype))
      return batch
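
  # Wrap-around example: with buffer_size=5, a start index of 3 and
  # num_steps=4 reads positions [3, 4, 0, 1] (i.e. tf.range(3, 7) mod 5).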

  def get_position(self):
    """Returns the position at which the last element was added.

    Returns:
      An int tensor representing the index at which the last element was added
      to the buffer or -1 if no elements were added.
    """
    return tf.cond(self.get_num_adds() < 1,
                   lambda: self.get_num_adds() - 1,
                   lambda: tf.mod(self.get_num_adds() - 1, self._buffer_size))
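
  # Example: with buffer_size=1000 and 1003 adds so far, the last write went
  # to position (1003 - 1) % 1000 = 2.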

  def get_num_adds(self):
    """Returns the number of additions to the buffer.

    Returns:
      An int tensor representing the number of elements that were added.
    """
    def num_adds():
      return self._num_adds.value()

    return self._num_adds_cs.execute(num_adds)

  def get_num_tensors(self):
    """Returns the number of tensors (slots) in the buffer."""
    return len(self._tensors)