# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A UVF agent. | |
""" | |
import tensorflow as tf | |
import gin.tf | |
from agents import ddpg_agent | |
# pylint: disable=unused-import | |
import cond_fn | |
from utils import utils as uvf_utils | |
from context import gin_imports | |
# pylint: enable=unused-import | |
slim = tf.contrib.slim | |
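
# Note: this module assumes a TensorFlow 1.x runtime; it relies on
# tf.contrib.slim, tf.to_float, tf.distributions, and
# tf.train.exponential_decay, which are not available in TF 2.x.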


class UvfAgentCore(object):
  """Defines basic functions for UVF agent. Must be inherited with an RL agent.

  Used as lower-level agent.
  """

  def __init__(self,
               observation_spec,
               action_spec,
               tf_env,
               tf_context,
               step_cond_fn=cond_fn.env_transition,
               reset_episode_cond_fn=cond_fn.env_restart,
               reset_env_cond_fn=cond_fn.false_fn,
               metrics=None,
               **base_agent_kwargs):
    """Constructs a UVF agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      tf_env: A Tensorflow environment object.
      tf_context: A Context class.
      step_cond_fn: A function indicating whether to increment the num of steps.
      reset_episode_cond_fn: A function indicating whether to restart the
        episode, resampling the context.
      reset_env_cond_fn: A function indicating whether to perform a manual reset
        of the environment.
      metrics: A list of functions that evaluate metrics of the agent.
      **base_agent_kwargs: A dictionary of parameters for base RL Agent.
    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    self._step_cond_fn = step_cond_fn
    self._reset_episode_cond_fn = reset_episode_cond_fn
    self._reset_env_cond_fn = reset_env_cond_fn
    self.metrics = metrics

    # expose tf_context methods
    self.tf_context = tf_context(tf_env=tf_env)
    self.set_replay = self.tf_context.set_replay
    self.sample_contexts = self.tf_context.sample_contexts
    self.compute_rewards = self.tf_context.compute_rewards
    self.gamma_index = self.tf_context.gamma_index
    self.context_specs = self.tf_context.context_specs
    self.context_as_action_specs = self.tf_context.context_as_action_specs
    self.init_context_vars = self.tf_context.create_vars

    self.env_observation_spec = observation_spec[0]
    merged_observation_spec = (uvf_utils.merge_specs(
        (self.env_observation_spec,) + self.context_specs),)
    self._context_vars = dict()
    self._action_vars = dict()

    self.BASE_AGENT_CLASS.__init__(
        self,
        observation_spec=merged_observation_spec,
        action_spec=action_spec,
        **base_agent_kwargs
    )

  def set_meta_agent(self, agent=None):
    self._meta_agent = agent

  @property
  def meta_agent(self):
    return self._meta_agent

  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the actor loss for a batch of merged states.

    Only `states` is used; the remaining arguments are accepted for
    signature compatibility with the base agent.

    Args:
      states: A [batch_size, num_state_dims] tensor of merged states.
      actions: A [batch_size, num_action_dims] tensor of actions.
      rewards: A [batch_size] tensor of rewards.
      discounts: A [batch_size] tensor of discounts.
      next_states: A [batch_size, num_state_dims] tensor of next states.
    Returns:
      A scalar tensor representing the actor loss.
    """
    return self.BASE_AGENT_CLASS.actor_loss(self, states)

  def action(self, state, context=None):
    """Returns the next action for the state.

    Args:
      state: A [num_state_dims] tensor representing a state.
      context: A list of [num_context_dims] tensor representing a context.
    Returns:
      A [num_action_dims] tensor representing the action.
    """
    merged_state = self.merged_state(state, context)
    return self.BASE_AGENT_CLASS.action(self, merged_state)

  def actions(self, state, context=None):
    """Returns the next actions for a batch of states.

    Args:
      state: A [-1, num_state_dims] tensor representing a batch of states.
      context: A list of [-1, num_context_dims] tensors representing a batch
        of contexts.
    Returns:
      A [-1, num_action_dims] tensor representing the batch of actions.
    """
    merged_states = self.merged_states(state, context)
    return self.BASE_AGENT_CLASS.actor_net(self, merged_states)

  def log_probs(self, states, actions, state_reprs, contexts=None):
    """Returns a log-probability-like score for `actions` under `contexts`.

    The score is the negative squared error between the taken actions and the
    actions the deterministic actor would produce for the given contexts,
    normalized by the action spec range.
    """
    assert contexts is not None
    batch_dims = [tf.shape(states)[0], tf.shape(states)[1]]
    contexts = self.tf_context.context_multi_transition_fn(
        contexts, states=tf.to_float(state_reprs))

    flat_states = tf.reshape(states,
                             [batch_dims[0] * batch_dims[1], states.shape[-1]])
    flat_contexts = [tf.reshape(tf.cast(context, states.dtype),
                                [batch_dims[0] * batch_dims[1],
                                 context.shape[-1]])
                     for context in contexts]
    flat_pred_actions = self.actions(flat_states, flat_contexts)
    pred_actions = tf.reshape(flat_pred_actions,
                              batch_dims + [flat_pred_actions.shape[-1]])

    error = tf.square(actions - pred_actions)
    spec_range = (self._action_spec.maximum - self._action_spec.minimum) / 2
    normalized_error = (
        tf.cast(error, tf.float64) / tf.constant(spec_range) ** 2)
    return -normalized_error

  def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
                   clip=True, global_step=None):
    """Returns the action_fn with additive Gaussian noise.

    Args:
      action_fn: A callable(`state`, `context`) which returns a
        [num_action_dims] tensor representing an action.
      stddev: Standard deviation of the Gaussian exploration noise.
      debug: Print debug messages.
      clip: Whether to clip the noisy action to the action spec.
      global_step: If not None, decay the noise stddev over training steps.
    Returns:
      A callable(`state`, `context`) returning a noisy [num_action_dims]
      action tensor.
    """
    if global_step is not None:
      stddev *= tf.maximum(  # Decay exploration during training.
          tf.train.exponential_decay(1.0, global_step, 1e6, 0.8), 0.5)

    def noisy_action_fn(state, context=None):
      """Noisy action fn."""
      action = action_fn(state, context)
      if debug:
        action = uvf_utils.tf_print(
            action, [action],
            message='[add_noise_fn] pre-noise action',
            first_n=100)
      noise_dist = tf.distributions.Normal(tf.zeros_like(action),
                                           tf.ones_like(action) * stddev)
      noise = noise_dist.sample()
      action += noise
      if debug:
        action = uvf_utils.tf_print(
            action, [action],
            message='[add_noise_fn] post-noise action',
            first_n=100)
      if clip:
        action = uvf_utils.clip_to_spec(action, self._action_spec)
      return action
    return noisy_action_fn
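
  # Illustrative usage only (not part of the original training pipeline):
  #   noisy_action = agent.add_noise_fn(agent.action, stddev=0.3,
  #                                     global_step=tf.train.get_global_step())
  # and call noisy_action(state, context) when collecting experience.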

  def merged_state(self, state, context=None):
    """Returns the merged state from the environment state and contexts.

    Args:
      state: A [num_state_dims] tensor representing a state.
      context: A list of [num_context_dims] tensor representing a context.
        If None, use the internal context.
    Returns:
      A [num_merged_state_dims] tensor representing the merged state.
    """
    if context is None:
      context = list(self.context_vars)
    state = tf.concat([state,] + context, axis=-1)
    self._validate_states(self._batch_state(state))
    return state

  def merged_states(self, states, contexts=None):
    """Returns the batch merged state from the batch env state and contexts.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      contexts: A list of [batch_size, num_context_dims] tensor
        representing a batch of contexts. If None,
        use the internal context.
    Returns:
      A [batch_size, num_merged_state_dims] tensor representing the batch
      of merged states.
    """
    if contexts is None:
      contexts = [tf.tile(tf.expand_dims(context, axis=0),
                          (tf.shape(states)[0], 1)) for
                  context in self.context_vars]
    states = tf.concat([states,] + contexts, axis=-1)
    self._validate_states(states)
    return states

  def unmerged_states(self, merged_states):
    """Returns the batch state and contexts from the batch merged state.

    Args:
      merged_states: A [batch_size, num_merged_state_dims] tensor
        representing a batch of merged states.
    Returns:
      A [batch_size, num_state_dims] tensor and a list of
      [batch_size, num_context_dims] tensors representing the batch state
      and contexts respectively.
    """
    self._validate_states(merged_states)
    num_state_dims = self.env_observation_spec.shape.as_list()[0]
    num_context_dims_list = [c.shape.as_list()[0] for c in self.context_specs]
    states = merged_states[:, :num_state_dims]
    contexts = []
    i = num_state_dims
    for num_context_dims in num_context_dims_list:
      contexts.append(merged_states[:, i: i+num_context_dims])
      i += num_context_dims
    return states, contexts

  def sample_random_actions(self, batch_size=1):
    """Return random actions.

    Args:
      batch_size: Batch size.
    Returns:
      A [batch_size, num_action_dims] tensor representing the batch of actions.
    """
    actions = tf.concat(
        [
            tf.random_uniform(
                shape=(batch_size, 1),
                minval=self._action_spec.minimum[i],
                maxval=self._action_spec.maximum[i])
            for i in range(self._action_spec.shape[0].value)
        ],
        axis=1)
    return actions

  def clip_actions(self, actions):
    """Clip actions to spec.

    Args:
      actions: A [batch_size, num_action_dims] tensor representing
        the batch of actions.
    Returns:
      A [batch_size, num_action_dims] tensor representing the batch
      of clipped actions.
    """
    actions = tf.concat(
        [
            tf.clip_by_value(
                actions[:, i:i+1],
                self._action_spec.minimum[i],
                self._action_spec.maximum[i])
            for i in range(self._action_spec.shape[0].value)
        ],
        axis=1)
    return actions

  def mix_contexts(self, contexts, insert_contexts, indices):
    """Mix two contexts based on indices.

    Args:
      contexts: A list of [batch_size, num_context_dims] tensor representing
        the batch of contexts.
      insert_contexts: A list of [batch_size, num_context_dims] tensor
        representing the batch of contexts to be inserted.
      indices: A list of a list of integers denoting indices to replace.
    Returns:
      A list of resulting contexts.
    """
    if indices is None: indices = [[]] * len(contexts)
    assert len(contexts) == len(indices)
    assert all([spec.shape.ndims == 1 for spec in self.context_specs])
    mix_contexts = []
    for contexts_, insert_contexts_, indices_, spec in zip(
        contexts, insert_contexts, indices, self.context_specs):
      mix_contexts.append(
          tf.concat(
              [
                  insert_contexts_[:, i:i + 1] if i in indices_ else
                  contexts_[:, i:i + 1] for i in range(spec.shape.as_list()[0])
              ],
              axis=1))
    return mix_contexts

  def begin_episode_ops(self, mode, action_fn=None, state=None):
    """Returns ops that reset agent at beginning of episodes.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      action_fn: A callable providing the meta-agent's action; passed through
        to the context reset.
      state: A tensor representing the current state; passed through to the
        context reset.
    Returns:
      A list of ops.
    """
    all_ops = []
    for _, action_var in sorted(self._action_vars.items()):
      sample_action = self.sample_random_actions(1)[0]
      all_ops.append(tf.assign(action_var, sample_action))
    all_ops += self.tf_context.reset(mode=mode, agent=self._meta_agent,
                                     action_fn=action_fn, state=state)
    return all_ops

  def cond_begin_episode_op(self, cond, input_vars, mode, meta_action_fn):
    """Returns op that resets agent at beginning of episodes.

    A new episode is begun if the cond op evaluates to `False`.

    Args:
      cond: a Boolean tensor variable.
      input_vars: A list of tensor variables.
      mode: a string representing the mode=[train, explore, eval].
      meta_action_fn: A callable providing the meta-agent's action; used when
        stepping or resetting the context.
    Returns:
      Conditional begin op.
    """
    (state, action, reward, next_state,
     state_repr, next_state_repr) = input_vars

    def continue_fn():
      """Continue op fn."""
      items = [state, action, reward, next_state,
               state_repr, next_state_repr] + list(self.context_vars)
      batch_items = [tf.expand_dims(item, 0) for item in items]
      (states, actions, rewards, next_states,
       state_reprs, next_state_reprs) = batch_items[:6]
      context_reward = self.compute_rewards(
          mode, state_reprs, actions, rewards, next_state_reprs,
          batch_items[6:])[0][0]
      context_reward = tf.cast(context_reward, dtype=reward.dtype)
      if self.meta_agent is not None:
        meta_action = tf.concat(self.context_vars, -1)
        items = [state, meta_action, reward, next_state,
                 state_repr, next_state_repr] + list(
                     self.meta_agent.context_vars)
        batch_items = [tf.expand_dims(item, 0) for item in items]
        (states, meta_actions, rewards, next_states,
         state_reprs, next_state_reprs) = batch_items[:6]
        meta_reward = self.meta_agent.compute_rewards(
            mode, states, meta_actions, rewards,
            next_states, batch_items[6:])[0][0]
        meta_reward = tf.cast(meta_reward, dtype=reward.dtype)
      else:
        meta_reward = tf.constant(0, dtype=reward.dtype)

      with tf.control_dependencies([context_reward, meta_reward]):
        step_ops = self.tf_context.step(mode=mode, agent=self._meta_agent,
                                        state=state,
                                        next_state=next_state,
                                        state_repr=state_repr,
                                        next_state_repr=next_state_repr,
                                        action_fn=meta_action_fn)
      with tf.control_dependencies(step_ops):
        context_reward, meta_reward = map(tf.identity,
                                          [context_reward, meta_reward])
      return context_reward, meta_reward

    def begin_episode_fn():
      """Begin op fn."""
      begin_ops = self.begin_episode_ops(mode=mode, action_fn=meta_action_fn,
                                         state=state)
      with tf.control_dependencies(begin_ops):
        return tf.zeros_like(reward), tf.zeros_like(reward)

    with tf.control_dependencies(input_vars):
      cond_begin_episode_op = tf.cond(cond, continue_fn, begin_episode_fn)
    return cond_begin_episode_op

  def get_env_base_wrapper(self, env_base, **begin_kwargs):
    """Create a wrapper around env_base, with agent-specific begin/end_episode.

    Args:
      env_base: A python environment base.
      **begin_kwargs: Keyword args for begin_episode_ops.
    Returns:
      An object with begin_episode() and end_episode().
    """
    begin_ops = self.begin_episode_ops(**begin_kwargs)
    return uvf_utils.get_contextual_env_base(env_base, begin_ops)

  def init_action_vars(self, name, i=None):
    """Create and return a tensorflow Variable holding an action.

    Args:
      name: Name of the variables.
      i: Integer id.
    Returns:
      A [num_action_dims] tensor.
    """
    if i is not None:
      name += '_%d' % i
    assert name not in self._action_vars, ('Conflict! %s is already '
                                           'initialized.') % name
    self._action_vars[name] = tf.Variable(
        self.sample_random_actions(1)[0], name='%s_action' % (name))
    self._validate_actions(tf.expand_dims(self._action_vars[name], 0))
    return self._action_vars[name]

  def critic_function(self, critic_vals, states, critic_fn=None):
    """Computes q values based on outputs from the critic net.

    Args:
      critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs
        from the critic net.
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      critic_fn: A callable that processes outputs from critic_net and
        outputs a [batch_size] tensor representing q values.
    Returns:
      A tf.float32 [batch_size] tensor representing q values.
    """
    if critic_fn is not None:
      env_states, contexts = self.unmerged_states(states)
      critic_vals = critic_fn(critic_vals, env_states, contexts)
    critic_vals.shape.assert_has_rank(1)
    return critic_vals

  def get_action_vars(self, key):
    return self._action_vars[key]

  def get_context_vars(self, key):
    return self.tf_context.context_vars[key]

  def step_cond_fn(self, *args):
    return self._step_cond_fn(self, *args)

  def reset_episode_cond_fn(self, *args):
    return self._reset_episode_cond_fn(self, *args)

  def reset_env_cond_fn(self, *args):
    return self._reset_env_cond_fn(self, *args)

  @property
  def context_vars(self):
    return self.tf_context.vars


class MetaAgentCore(UvfAgentCore):
  """Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent.

  Used as higher-level agent.
  """

  def __init__(self,
               observation_spec,
               action_spec,
               tf_env,
               tf_context,
               sub_context,
               step_cond_fn=cond_fn.env_transition,
               reset_episode_cond_fn=cond_fn.env_restart,
               reset_env_cond_fn=cond_fn.false_fn,
               metrics=None,
               actions_reg=0.,
               k=2,
               **base_agent_kwargs):
    """Constructs a Meta agent.

    Args:
      observation_spec: A TensorSpec defining the observations.
      action_spec: A BoundedTensorSpec defining the actions.
      tf_env: A Tensorflow environment object.
      tf_context: A Context class.
      sub_context: A Context class for the lower-level agent; its
        context-as-action specs define this agent's action spec.
      step_cond_fn: A function indicating whether to increment the num of steps.
      reset_episode_cond_fn: A function indicating whether to restart the
        episode, resampling the context.
      reset_env_cond_fn: A function indicating whether to perform a manual reset
        of the environment.
      metrics: A list of functions that evaluate metrics of the agent.
      actions_reg: Coefficient of the L1 regularizer applied, in the actor
        loss, to action dimensions from index `k` onward.
      k: Number of leading action dimensions excluded from the regularizer.
      **base_agent_kwargs: A dictionary of parameters for base RL Agent.
    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    self._step_cond_fn = step_cond_fn
    self._reset_episode_cond_fn = reset_episode_cond_fn
    self._reset_env_cond_fn = reset_env_cond_fn
    self.metrics = metrics
    self._actions_reg = actions_reg
    self._k = k

    # expose tf_context methods
    self.tf_context = tf_context(tf_env=tf_env)
    self.sub_context = sub_context(tf_env=tf_env)
    self.set_replay = self.tf_context.set_replay
    self.sample_contexts = self.tf_context.sample_contexts
    self.compute_rewards = self.tf_context.compute_rewards
    self.gamma_index = self.tf_context.gamma_index
    self.context_specs = self.tf_context.context_specs
    self.context_as_action_specs = self.tf_context.context_as_action_specs
    self.sub_context_as_action_specs = self.sub_context.context_as_action_specs
    self.init_context_vars = self.tf_context.create_vars

    self.env_observation_spec = observation_spec[0]
    merged_observation_spec = (uvf_utils.merge_specs(
        (self.env_observation_spec,) + self.context_specs),)
    self._context_vars = dict()
    self._action_vars = dict()

    assert len(self.context_as_action_specs) == 1
    self.BASE_AGENT_CLASS.__init__(
        self,
        observation_spec=merged_observation_spec,
        action_spec=self.sub_context_as_action_specs,
        **base_agent_kwargs
    )

  def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
                   global_step=None):
    noisy_action_fn = super(MetaAgentCore, self).add_noise_fn(
        action_fn, stddev,
        clip=True, global_step=global_step)
    return noisy_action_fn

  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the actor loss with an L1 regularizer on higher action dims.

    Args:
      states: A [batch_size, num_state_dims] tensor of merged states.
      actions: A [batch_size, num_action_dims] tensor of actions.
      rewards: A [batch_size] tensor of rewards.
      discounts: A [batch_size] tensor of discounts.
      next_states: A [batch_size, num_state_dims] tensor of next states.
    Returns:
      A scalar tensor: the base actor loss plus the L1 regularizer on action
      dimensions from index `self._k` onward.
    """
    actions = self.actor_net(states, stop_gradients=False)
    regularizer = self._actions_reg * tf.reduce_mean(
        tf.reduce_sum(tf.abs(actions[:, self._k:]), -1), 0)
    loss = self.BASE_AGENT_CLASS.actor_loss(self, states)
    return regularizer + loss


class UvfAgent(UvfAgentCore, ddpg_agent.TD3Agent):
  """A DDPG agent with UVF.
  """
  BASE_AGENT_CLASS = ddpg_agent.TD3Agent
  ACTION_TYPE = 'continuous'

  def __init__(self, *args, **kwargs):
    UvfAgentCore.__init__(self, *args, **kwargs)


class MetaAgent(MetaAgentCore, ddpg_agent.TD3Agent):
  """A DDPG meta-agent.
  """
  BASE_AGENT_CLASS = ddpg_agent.TD3Agent
  ACTION_TYPE = 'continuous'

  def __init__(self, *args, **kwargs):
    MetaAgentCore.__init__(self, *args, **kwargs)


def state_preprocess_net(
    states,
    num_output_dims=2,
    states_hidden_layers=(100,),
    normalizer_fn=None,
    activation_fn=tf.nn.relu,
    zero_time=True,
    images=False):
  """Creates a simple feed forward net for embedding states.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    states_shape = tf.shape(states)
    states_dtype = states.dtype
    states = tf.to_float(states)
    if images:  # Zero-out x-y
      states *= tf.constant(
          [0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
    if zero_time:
      states *= tf.constant(
          [1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
    orig_states = states
    embed = states
    if states_hidden_layers:
      embed = slim.stack(embed, slim.fully_connected, states_hidden_layers,
                         scope='states')

    with slim.arg_scope([slim.fully_connected],
                        weights_regularizer=None,
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      embed = slim.fully_connected(embed, num_output_dims,
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   scope='value')

    output = embed
    output = tf.cast(output, states_dtype)
    return output


def action_embed_net(
    actions,
    states=None,
    num_output_dims=2,
    hidden_layers=(400, 300),
    normalizer_fn=None,
    activation_fn=tf.nn.relu,
    zero_time=True,
    images=False):
  """Creates a simple feed forward net for embedding actions.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    actions = tf.to_float(actions)
    if states is not None:
      if images:  # Zero-out x-y
        states *= tf.constant(
            [0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype)
      if zero_time:
        states *= tf.constant(
            [1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype)
      actions = tf.concat([actions, tf.to_float(states)], -1)

    embed = actions
    if hidden_layers:
      embed = slim.stack(embed, slim.fully_connected, hidden_layers,
                         scope='hidden')

    with slim.arg_scope([slim.fully_connected],
                        weights_regularizer=None,
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      embed = slim.fully_connected(embed, num_output_dims,
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   scope='value')
    if num_output_dims == 1:
      return embed[:, 0, ...]
    else:
      return embed


def huber(x, kappa=0.1):
  return (0.5 * tf.square(x) * tf.to_float(tf.abs(x) <= kappa) +
          kappa * (tf.abs(x) - 0.5 * kappa) * tf.to_float(tf.abs(x) > kappa)
          ) / kappa
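
# The expression above is the Huber loss scaled by 1/kappa:
#   huber(x) = x**2 / (2 * kappa)   for |x| <= kappa
#   huber(x) = |x| - kappa / 2      for |x| >  kappa
# Illustrative check (assumes a TF 1.x session; not part of the training code):
#   with tf.Session() as sess:
#     print(sess.run(huber(tf.constant([0.05, 0.5]))))  # ~[0.0125, 0.45]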


class StatePreprocess(object):
  STATE_PREPROCESS_NET_SCOPE = 'state_process_net'
  ACTION_EMBED_NET_SCOPE = 'action_embed_net'

  def __init__(self, trainable=False,
               state_preprocess_net=lambda states: states,
               action_embed_net=lambda actions, *args, **kwargs: actions,
               ndims=None):
    self.trainable = trainable
    self._scope = tf.get_variable_scope().name
    self._ndims = ndims
    self._state_preprocess_net = tf.make_template(
        self.STATE_PREPROCESS_NET_SCOPE, state_preprocess_net,
        create_scope_now_=True)
    self._action_embed_net = tf.make_template(
        self.ACTION_EMBED_NET_SCOPE, action_embed_net,
        create_scope_now_=True)

  def __call__(self, states):
    batched = states.get_shape().ndims != 1
    if not batched:
      states = tf.expand_dims(states, 0)
    embedded = self._state_preprocess_net(states)
    if self._ndims is not None:
      embedded = embedded[..., :self._ndims]
    if not batched:
      return embedded[0]
    return embedded

  def loss(self, states, next_states, low_actions, low_states):
    batch_size = tf.shape(states)[0]
    d = int(low_states.shape[1])
    # Sample indices into meta-transition to train on.
    probs = 0.99 ** tf.range(d, dtype=tf.float32)
    probs *= tf.constant([1.0] * (d - 1) + [1.0 / (1 - 0.99)],
                         dtype=tf.float32)
    probs /= tf.reduce_sum(probs)
    index_dist = tf.distributions.Categorical(probs=probs, dtype=tf.int64)
    indices = index_dist.sample(batch_size)
    batch_size = tf.cast(batch_size, tf.int64)
    next_indices = tf.concat(
        [tf.range(batch_size, dtype=tf.int64)[:, None],
         (1 + indices[:, None]) % d], -1)
    new_next_states = tf.where(indices < d - 1,
                               tf.gather_nd(low_states, next_indices),
                               next_states)
    next_states = new_next_states

    embed1 = tf.to_float(self._state_preprocess_net(states))
    embed2 = tf.to_float(self._state_preprocess_net(next_states))
    action_embed = self._action_embed_net(
        tf.layers.flatten(low_actions), states=states)

    tau = 2.0
    fn = lambda z: tau * tf.reduce_sum(huber(z), -1)
    all_embed = tf.get_variable('all_embed', [1024, int(embed1.shape[-1])],
                                initializer=tf.zeros_initializer())
    upd = all_embed.assign(tf.concat([all_embed[batch_size:], embed2], 0))
    with tf.control_dependencies([upd]):
      # Attractive term: the predicted next embedding should be close to the
      # actual next-state embedding.
      close = 1 * tf.reduce_mean(fn(embed1 + action_embed - embed2))
      prior_log_probs = tf.reduce_logsumexp(
          -fn((embed1 + action_embed)[:, None, :] - all_embed[None, :, :]),
          axis=-1) - tf.log(tf.to_float(all_embed.shape[0]))
      # Repulsive term: mismatched (shifted) pairs should be pushed apart.
      far = tf.reduce_mean(tf.exp(-fn((embed1 + action_embed)[1:] - embed2[:-1])
                                  - tf.stop_gradient(prior_log_probs[1:])))
    repr_log_probs = tf.stop_gradient(
        -fn(embed1 + action_embed - embed2) - prior_log_probs) / tau
    return close + far, repr_log_probs, indices

  def get_trainable_vars(self):
    return (
        slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope,
                                 self.STATE_PREPROCESS_NET_SCOPE)) +
        slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE)))


class InverseDynamics(object):
  INVERSE_DYNAMICS_NET_SCOPE = 'inverse_dynamics'

  def __init__(self, spec):
    self._spec = spec

  def sample(self, states, next_states, num_samples, orig_goals, sc=0.5):
    goal_dim = orig_goals.shape[-1]
    spec_range = (
        self._spec.maximum - self._spec.minimum) / 2 * tf.ones([goal_dim])
    loc = tf.cast(next_states - states, tf.float32)[:, :goal_dim]
    scale = sc * tf.tile(tf.reshape(spec_range, [1, goal_dim]),
                         [tf.shape(states)[0], 1])
    dist = tf.distributions.Normal(loc, scale)
    if num_samples == 1:
      return dist.sample()

    samples = tf.concat([dist.sample(num_samples - 2),
                         tf.expand_dims(loc, 0),
                         tf.expand_dims(orig_goals, 0)], 0)
    return uvf_utils.clip_to_spec(samples, self._spec)
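
# Note on InverseDynamics.sample(): for num_samples > 1 the result stacks
# (num_samples - 2) Gaussian samples centered on the state delta
# (next_states - states), plus the delta itself and the original goals,
# giving a [num_samples, batch_size, goal_dim] tensor clipped to the goal spec.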