# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""TensorFlow utility functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from copy import deepcopy
import tensorflow as tf

from tf_agents import specs
from tf_agents.utils import common

_tf_print_counts = dict()
_tf_print_running_sums = dict()
_tf_print_running_counts = dict()
_tf_print_ids = 0

def get_contextual_env_base(env_base, begin_ops=None, end_ops=None):
  """Wrap env_base with additional tf ops."""
  # pylint: disable=protected-access
  def init(self_, env_base):
    self_._env_base = env_base
    attribute_list = ["_render_mode", "_gym_env"]
    for attribute in attribute_list:
      if hasattr(env_base, attribute):
        setattr(self_, attribute, getattr(env_base, attribute))
    if hasattr(env_base, "physics"):
      self_._physics = env_base.physics
    elif hasattr(env_base, "gym"):
      class Physics(object):
        def render(self, *args, **kwargs):
          return env_base.gym.render("rgb_array")
      physics = Physics()
      self_._physics = physics
      self_.physics = physics

  def set_sess(self_, sess):
    self_._sess = sess
    if hasattr(self_._env_base, "set_sess"):
      self_._env_base.set_sess(sess)

  def begin_episode(self_):
    self_._env_base.reset()
    if begin_ops is not None:
      self_._sess.run(begin_ops)

  def end_episode(self_):
    self_._env_base.reset()
    if end_ops is not None:
      self_._sess.run(end_ops)

  return type("ContextualEnvBase", (env_base.__class__,), dict(
      __init__=init,
      set_sess=set_sess,
      begin_episode=begin_episode,
      end_episode=end_episode,
  ))(env_base)
  # pylint: enable=protected-access

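# Example usage (a sketch; `my_env`, `reset_context_op`, and `sess` are
# hypothetical stand-ins for an environment, a tf op, and a tf.Session):
#   env = get_contextual_env_base(my_env, begin_ops=[reset_context_op])
#   env.set_sess(sess)
#   env.begin_episode()  # resets the wrapped env, then runs `begin_ops`
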
def merge_specs(specs_):
  """Merge TensorSpecs.

  Args:
    specs_: List of TensorSpecs to be merged.

  Returns:
    a TensorSpec: a merged TensorSpec.
  """
  shape = specs_[0].shape
  dtype = specs_[0].dtype
  name = specs_[0].name
  for spec in specs_[1:]:
    assert shape[1:] == spec.shape[1:], "incompatible shapes: %s, %s" % (
        shape, spec.shape)
    assert dtype == spec.dtype, "incompatible dtypes: %s, %s" % (
        dtype, spec.dtype)
    shape = merge_shapes((shape, spec.shape), axis=0)
  return specs.TensorSpec(
      shape=shape,
      dtype=dtype,
      name=name,
  )

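# Example (a sketch): shapes must agree past axis 0; the merged spec's leading
# dimension is the sum of the inputs', and it keeps the first spec's name.
#   a = specs.TensorSpec(shape=[2, 3], dtype=tf.float32, name="a")
#   b = specs.TensorSpec(shape=[4, 3], dtype=tf.float32, name="b")
#   merged = merge_specs([a, b])  # shape [6, 3], dtype tf.float32, name "a"
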
def merge_shapes(shapes, axis=0):
  """Merge TensorShapes.

  Args:
    shapes: List of TensorShapes to be merged.
    axis: optional, the axis along which to merge the shapes.

  Returns:
    a TensorShape: a merged TensorShape.
  """
  assert len(shapes) > 1
  dims = deepcopy(shapes[0].dims)
  for shape in shapes[1:]:
    assert shapes[0].ndims == shape.ndims
    dims[axis] += shape.dims[axis]
  return tf.TensorShape(dims=dims)

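# Example (a sketch): dimensions are summed along `axis`; all shapes must have
# the same rank.
#   merged = merge_shapes([tf.TensorShape([2, 3]), tf.TensorShape([5, 3])])
#   # => TensorShape([7, 3])
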
def get_all_vars(ignore_scopes=None):
  """Get all tf variables in scope.

  Args:
    ignore_scopes: A list of scope names to ignore.

  Returns:
    A list of all tf variables in scope.
  """
  all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  all_vars = [var for var in all_vars if ignore_scopes is None or not
              any(var.name.startswith(scope) for scope in ignore_scopes)]
  return all_vars

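# Example (a sketch, assuming variables were created under a "target/" scope):
#   non_target_vars = get_all_vars(ignore_scopes=["target"])
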
def clip(tensor, range_=None):
  """Return a tf op which clips tensor according to range_.

  Args:
    tensor: A Tensor to be clipped.
    range_: None, or a tuple representing (minval, maxval).

  Returns:
    A clipped Tensor.
  """
  if range_ is None:
    return tf.identity(tensor)
  elif isinstance(range_, (tuple, list)):
    assert len(range_) == 2
    return tf.clip_by_value(tensor, range_[0], range_[1])
  else:
    raise NotImplementedError("Unacceptable range input: %r" % range_)

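# Example (a sketch; `actions` is a hypothetical float Tensor):
#   clipped = clip(actions, range_=(-1.0, 1.0))  # elementwise clip
#   passthrough = clip(actions)  # range_=None returns tf.identity(actions)
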
def clip_to_bounds(value, minimum, maximum):
  """Clips value to be between minimum and maximum.

  Args:
    value: (tensor) value to be clipped.
    minimum: (numpy float array) minimum value to clip to.
    maximum: (numpy float array) maximum value to clip to.

  Returns:
    clipped_value: (tensor) `value` clipped to between `minimum` and `maximum`.
  """
  value = tf.minimum(value, maximum)
  return tf.maximum(value, minimum)

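# Example (a sketch): the bounds broadcast elementwise against `value`.
#   import numpy as np
#   clipped = clip_to_bounds(actions, np.array([-1., -2.]), np.array([1., 2.]))
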
clip_to_spec = common.clip_to_spec
def _clip_to_spec(value, spec):
  """Clips value to a given bounded tensor spec.

  Args:
    value: (tensor) value to be clipped.
    spec: (BoundedTensorSpec) spec containing min. and max. values for clipping.

  Returns:
    clipped_value: (tensor) `value` clipped to be compatible with `spec`.
  """
  return clip_to_bounds(value, spec.minimum, spec.maximum)

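# Example (a sketch, using a hypothetical bounded action spec; `raw_actions`
# is a hypothetical Tensor):
#   action_spec = specs.BoundedTensorSpec(
#       shape=[2], dtype=tf.float32, minimum=-1.0, maximum=1.0)
#   clipped = clip_to_spec(raw_actions, action_spec)
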
join_scope = common.join_scope
def _join_scope(parent_scope, child_scope):
  """Joins a parent and child scope using `/`, checking for empty/none.

  Args:
    parent_scope: (string) parent/prefix scope.
    child_scope: (string) child/suffix scope.

  Returns:
    joined scope: (string) parent and child scopes joined by /.
  """
  if not parent_scope:
    return child_scope
  if not child_scope:
    return parent_scope
  return '/'.join([parent_scope, child_scope])

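# Example (a sketch):
#   join_scope("agent", "critic")  # => "agent/critic"
#   join_scope("", "critic")       # => "critic"
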
def assign_vars(vars_, values):
  """Returns the update ops for assigning a list of vars.

  Args:
    vars_: A list of variables.
    values: A list of tensors representing new values.

  Returns:
    A list of update ops for the variables.
  """
  return [var.assign(value) for var, value in zip(vars_, values)]

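# Example (a sketch): a hard target-network copy, where `target_vars` and
# `source_vars` are hypothetical matching variable lists.
#   copy_ops = assign_vars(target_vars, identity_vars(source_vars))
#   sess.run(copy_ops)
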
def identity_vars(vars_):
  """Return the identity ops for a list of tensors.

  Args:
    vars_: A list of tensors.

  Returns:
    A list of identity ops.
  """
  return [tf.identity(var) for var in vars_]

def tile(var, batch_size=1):
  """Return tiled tensor.

  Args:
    var: A tensor representing the state.
    batch_size: Batch size.

  Returns:
    A tensor with shape [batch_size,] + var.shape.
  """
  batch_var = tf.tile(
      tf.expand_dims(var, 0),
      (batch_size,) + (1,) * var.get_shape().ndims)
  return batch_var

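# Example (a sketch): replicating a single state across a batch.
#   state = tf.constant([0.0, 1.0])    # shape [2]
#   batch = tile(state, batch_size=4)  # shape [4, 2]
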
def batch_list(vars_list):
  """Batch a list of variables.

  Args:
    vars_list: A list of tensor variables.

  Returns:
    A list of tensor variables with additional first dimension.
  """
  return [tf.expand_dims(var, 0) for var in vars_list]

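# Example (a sketch): each tensor gains a leading batch dimension of 1.
#   batched = batch_list([state, reward])  # shapes [d] -> [1, d]
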
def tf_print(op,
             tensors,
             message="",
             first_n=-1,
             name=None,
             sub_messages=None,
             print_freq=-1,
             include_count=True):
  """tf.Print, but to stdout."""
  # TODO(shanegu): `name` is deprecated. Remove it from the rest of the code.
  global _tf_print_ids
  _tf_print_ids += 1
  name = _tf_print_ids
  _tf_print_counts[name] = 0
  if print_freq > 0:
    _tf_print_running_sums[name] = [0 for _ in tensors]
    _tf_print_running_counts[name] = 0
  def print_message(*xs):
    """print message fn."""
    _tf_print_counts[name] += 1
    if print_freq > 0:
      for i, x in enumerate(xs):
        _tf_print_running_sums[name][i] += x
      _tf_print_running_counts[name] += 1
    if (print_freq <= 0 or _tf_print_running_counts[name] >= print_freq) and (
        first_n < 0 or _tf_print_counts[name] <= first_n):
      for i, x in enumerate(xs):
        if print_freq > 0:
          del x
          x = _tf_print_running_sums[name][i] / _tf_print_running_counts[name]
        if sub_messages is None:
          sub_message = str(i)
        else:
          sub_message = sub_messages[i]
        log_message = "%s, %s" % (message, sub_message)
        if include_count:
          log_message += ", count=%d" % _tf_print_counts[name]
        tf.logging.info("[%s]: %s" % (log_message, x))
      if print_freq > 0:
        for i, x in enumerate(xs):
          _tf_print_running_sums[name][i] = 0
        _tf_print_running_counts[name] = 0
    return xs[0]
  print_op = tf.py_func(print_message, tensors, tensors[0].dtype)
  with tf.control_dependencies([print_op]):
    op = tf.identity(op)
  return op

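# Example (a sketch): log the running mean of a loss every 100 steps without
# disturbing the graph's data flow (`train_op` and `loss` are hypothetical).
#   train_op = tf_print(train_op, [loss], message="loss",
#                       sub_messages=["mean"], print_freq=100)
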
periodically = common.periodically
def _periodically(body, period, name='periodically'):
  """Periodically performs a tensorflow op."""
  if period is None or period == 0:
    return tf.no_op()

  if period < 0:
    raise ValueError("period cannot be less than 0.")

  if period == 1:
    return body()

  with tf.variable_scope(None, default_name=name):
    counter = tf.get_variable(
        "counter",
        shape=[],
        dtype=tf.int64,
        trainable=False,
        initializer=tf.constant_initializer(period, dtype=tf.int64))

    def _wrapped_body():
      with tf.control_dependencies([body()]):
        return counter.assign(1)

    update = tf.cond(
        tf.equal(counter, period), _wrapped_body,
        lambda: counter.assign_add(1))

  return update

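# Example (a sketch): run a hypothetical `update_target_op` every 1000 calls;
# the counter starts at `period`, so the body also fires on the first run.
#   periodic_update = periodically(lambda: update_target_op, period=1000)
#   sess.run(periodic_update)
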
soft_variables_update = common.soft_variables_update