# Scraped page header (not part of the original source file):
# NCTCMumbai's picture
# Upload 2571 files
# 0b8359d
# raw
# history blame
# 29 kB
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A UVF agent.
"""
import tensorflow as tf
import gin.tf
from agents import ddpg_agent
# pylint: disable=unused-import
import cond_fn
from utils import utils as uvf_utils
from context import gin_imports
# pylint: enable=unused-import
# Short alias for TF-Slim; its layer helpers (fully_connected, stack,
# arg_scope, get_trainable_variables) are used throughout this module.
slim = tf.contrib.slim
@gin.configurable
class UvfAgentCore(object):
  """Defines basic functions for UVF agent. Must be inherited with an RL agent.

  Used as lower-level agent. The concrete subclass must define
  `BASE_AGENT_CLASS`; the base agent's methods are invoked explicitly on
  `self` with observations that are the environment state concatenated with
  the current context(s).
  """

  def __init__(self,
               observation_spec,
               action_spec,
               tf_env,
               tf_context,
               step_cond_fn=cond_fn.env_transition,
               reset_episode_cond_fn=cond_fn.env_restart,
               reset_env_cond_fn=cond_fn.false_fn,
               metrics=None,
               **base_agent_kwargs):
    """Constructs a UVF agent.

    Args:
      observation_spec: A sequence whose first element is the TensorSpec of
        the environment observations.
      action_spec: A BoundedTensorSpec defining the actions.
      tf_env: A Tensorflow environment object.
      tf_context: A Context class; it is instantiated here with `tf_env`.
      step_cond_fn: A function indicating whether to increment the num of steps.
      reset_episode_cond_fn: A function indicating whether to restart the
        episode, resampling the context.
      reset_env_cond_fn: A function indicating whether to perform a manual reset
        of the environment.
      metrics: A list of functions that evaluate metrics of the agent.
      **base_agent_kwargs: A dictionary of parameters for base RL Agent.
    Raises:
      ValueError: If 'dqda_clipping' is < 0 (presumably checked by the base
        agent's constructor, which receives `base_agent_kwargs`).
    """
    self._step_cond_fn = step_cond_fn
    self._reset_episode_cond_fn = reset_episode_cond_fn
    self._reset_env_cond_fn = reset_env_cond_fn
    self.metrics = metrics

    # expose tf_context methods
    self.tf_context = tf_context(tf_env=tf_env)
    self.set_replay = self.tf_context.set_replay
    self.sample_contexts = self.tf_context.sample_contexts
    self.compute_rewards = self.tf_context.compute_rewards
    self.gamma_index = self.tf_context.gamma_index
    self.context_specs = self.tf_context.context_specs
    self.context_as_action_specs = self.tf_context.context_as_action_specs
    self.init_context_vars = self.tf_context.create_vars

    self.env_observation_spec = observation_spec[0]
    # The base agent sees the env observation concatenated with all contexts.
    merged_observation_spec = (uvf_utils.merge_specs(
        (self.env_observation_spec,) + self.context_specs),)
    self._context_vars = dict()
    self._action_vars = dict()
    # Attached later through set_meta_agent(). Default to None so that
    # `meta_agent` and the episode ops do not raise AttributeError before then.
    self._meta_agent = None

    self.BASE_AGENT_CLASS.__init__(
        self,
        observation_spec=merged_observation_spec,
        action_spec=action_spec,
        **base_agent_kwargs
    )

  def set_meta_agent(self, agent=None):
    """Sets (or clears, when `agent` is None) the higher-level meta agent."""
    self._meta_agent = agent

  @property
  def meta_agent(self):
    """The higher-level meta agent, or None if none has been set."""
    return self._meta_agent

  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the actor loss computed by the base agent.

    Only `states` is used; the remaining arguments are ignored by this
    implementation.

    Args:
      states: A batch of merged states (env state + contexts).
      actions: Unused.
      rewards: Unused.
      discounts: Unused.
      next_states: Unused.
    Returns:
      The result of `BASE_AGENT_CLASS.actor_loss(self, states)`.
    """
    return self.BASE_AGENT_CLASS.actor_loss(self, states)

  def action(self, state, context=None):
    """Returns the next action for the state.

    Args:
      state: A [num_state_dims] tensor representing a state.
      context: A list of [num_context_dims] tensor representing a context.
        If None, the internal context variables are used.
    Returns:
      A [num_action_dims] tensor representing the action.
    """
    merged_state = self.merged_state(state, context)
    return self.BASE_AGENT_CLASS.action(self, merged_state)

  def actions(self, state, context=None):
    """Returns the next actions for a batch of states.

    Args:
      state: A [-1, num_state_dims] tensor representing a state.
      context: A list of [-1, num_context_dims] tensor representing a context.
        If None, the internal context variables are tiled over the batch.
    Returns:
      A [-1, num_action_dims] tensor representing the action.
    """
    merged_states = self.merged_states(state, context)
    return self.BASE_AGENT_CLASS.actor_net(self, merged_states)

  def log_probs(self, states, actions, state_reprs, contexts=None):
    """Scores `actions` by their negative normalized squared error.

    The contexts are first advanced through the context's
    `context_multi_transition_fn`. The policy's actions are then computed for
    every (state, context) pair and compared with `actions`.

    Args:
      states: A tensor whose first two axes are flattened together below;
        presumably [batch, time, num_state_dims] -- TODO confirm.
      actions: A tensor of actions with the same leading two axes as `states`.
      state_reprs: State representations passed to
        `context_multi_transition_fn`.
      contexts: A list of context tensors. Must not be None.
    Returns:
      A float64 tensor equal to minus the squared action error, divided
      per dimension by the squared half-range of the action spec.
    """
    assert contexts is not None
    batch_dims = [tf.shape(states)[0], tf.shape(states)[1]]
    contexts = self.tf_context.context_multi_transition_fn(
        contexts, states=tf.to_float(state_reprs))

    flat_states = tf.reshape(states,
                             [batch_dims[0] * batch_dims[1], states.shape[-1]])
    flat_contexts = [tf.reshape(tf.cast(context, states.dtype),
                                [batch_dims[0] * batch_dims[1], context.shape[-1]])
                     for context in contexts]
    flat_pred_actions = self.actions(flat_states, flat_contexts)
    pred_actions = tf.reshape(flat_pred_actions,
                              batch_dims + [flat_pred_actions.shape[-1]])

    error = tf.square(actions - pred_actions)
    spec_range = (self._action_spec.maximum - self._action_spec.minimum) / 2
    normalized_error = tf.cast(error, tf.float64) / tf.constant(spec_range) ** 2
    return -normalized_error

  @gin.configurable('uvf_add_noise_fn')
  def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
                   clip=True, global_step=None):
    """Returns the action_fn with additive Gaussian noise.

    Args:
      action_fn: A callable(`state`, `context`) which returns a
        [num_action_dims] tensor representing a action.
      stddev: Standard deviation of the additive Gaussian noise.
      debug: Print debug messages.
      clip: If True, clip the noisy action to the action spec.
      global_step: If given, `stddev` is multiplied by
        max(0.8 ** (global_step / 1e6), 0.5) to decay exploration.
    Returns:
      A callable(`state`, `context=None`) returning the noisy action tensor.
    """
    if global_step is not None:
      stddev *= tf.maximum(  # Decay exploration during training.
          tf.train.exponential_decay(1.0, global_step, 1e6, 0.8), 0.5)

    def noisy_action_fn(state, context=None):
      """Noisy action fn."""
      action = action_fn(state, context)
      if debug:
        action = uvf_utils.tf_print(
            action, [action],
            message='[add_noise_fn] pre-noise action',
            first_n=100)
      noise_dist = tf.distributions.Normal(tf.zeros_like(action),
                                           tf.ones_like(action) * stddev)
      noise = noise_dist.sample()
      action += noise
      if debug:
        action = uvf_utils.tf_print(
            action, [action],
            message='[add_noise_fn] post-noise action',
            first_n=100)
      if clip:
        action = uvf_utils.clip_to_spec(action, self._action_spec)
      return action
    return noisy_action_fn

  def merged_state(self, state, context=None):
    """Returns the merged state from the environment state and contexts.

    Args:
      state: A [num_state_dims] tensor representing a state.
      context: A list of [num_context_dims] tensor representing a context.
        If None, use the internal context.
    Returns:
      A [num_merged_state_dims] tensor representing the merged state.
    """
    if context is None:
      context = list(self.context_vars)
    state = tf.concat([state,] + context, axis=-1)
    self._validate_states(self._batch_state(state))
    return state

  def merged_states(self, states, contexts=None):
    """Returns the batch merged state from the batch env state and contexts.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      contexts: A list of [batch_size, num_context_dims] tensor
        representing a batch of contexts. If None,
        use the internal context.
    Returns:
      A [batch_size, num_merged_state_dims] tensor representing the batch
      of merged states.
    """
    if contexts is None:
      # Tile each internal context variable across the batch dimension.
      contexts = [tf.tile(tf.expand_dims(context, axis=0),
                          (tf.shape(states)[0], 1)) for
                  context in self.context_vars]
    states = tf.concat([states,] + contexts, axis=-1)
    self._validate_states(states)
    return states

  def unmerged_states(self, merged_states):
    """Returns the batch state and contexts from the batch merged state.

    Args:
      merged_states: A [batch_size, num_merged_state_dims] tensor
        representing a batch of merged states.
    Returns:
      A [batch_size, num_state_dims] tensor and a list of
      [batch_size, num_context_dims] tensors representing the batch state
      and contexts respectively.
    """
    self._validate_states(merged_states)
    num_state_dims = self.env_observation_spec.shape.as_list()[0]
    num_context_dims_list = [c.shape.as_list()[0] for c in self.context_specs]
    states = merged_states[:, :num_state_dims]
    contexts = []
    i = num_state_dims
    for num_context_dims in num_context_dims_list:
      contexts.append(merged_states[:, i: i+num_context_dims])
      i += num_context_dims
    return states, contexts

  def sample_random_actions(self, batch_size=1):
    """Return random actions.

    Each action dimension is drawn uniformly from its [minimum, maximum]
    range in the action spec.

    Args:
      batch_size: Batch size.
    Returns:
      A [batch_size, num_action_dims] tensor representing the batch of actions.
    """
    actions = tf.concat(
        [
            tf.random_uniform(
                shape=(batch_size, 1),
                minval=self._action_spec.minimum[i],
                maxval=self._action_spec.maximum[i])
            for i in range(self._action_spec.shape[0].value)
        ],
        axis=1)
    return actions

  def clip_actions(self, actions):
    """Clip actions to spec.

    Args:
      actions: A [batch_size, num_action_dims] tensor representing
        the batch of actions.
    Returns:
      A [batch_size, num_action_dims] tensor representing the batch
      of clipped actions.
    """
    actions = tf.concat(
        [
            tf.clip_by_value(
                actions[:, i:i+1],
                self._action_spec.minimum[i],
                self._action_spec.maximum[i])
            for i in range(self._action_spec.shape[0].value)
        ],
        axis=1)
    return actions

  def mix_contexts(self, contexts, insert_contexts, indices):
    """Mix two contexts based on indices.

    Args:
      contexts: A list of [batch_size, num_context_dims] tensor representing
        the batch of contexts.
      insert_contexts: A list of [batch_size, num_context_dims] tensor
        representing the batch of contexts to be inserted.
      indices: A list of a list of integers denoting indices to replace.
        If None, no dimensions are replaced.
    Returns:
      A list of resulting contexts.
    """
    if indices is None: indices = [[]] * len(contexts)
    assert len(contexts) == len(indices)
    assert all([spec.shape.ndims == 1 for spec in self.context_specs])
    mix_contexts = []
    for contexts_, insert_contexts_, indices_, spec in zip(
        contexts, insert_contexts, indices, self.context_specs):
      mix_contexts.append(
          tf.concat(
              [
                  insert_contexts_[:, i:i + 1] if i in indices_ else
                  contexts_[:, i:i + 1] for i in range(spec.shape.as_list()[0])
              ],
              axis=1))
    return mix_contexts

  def begin_episode_ops(self, mode, action_fn=None, state=None):
    """Returns ops that reset agent at beginning of episodes.

    Every action variable is reassigned a random action, then the context's
    reset ops are appended.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      action_fn: Forwarded to `tf_context.reset`.
      state: Forwarded to `tf_context.reset`.
    Returns:
      A list of ops.
    """
    all_ops = []
    for _, action_var in sorted(self._action_vars.items()):
      sample_action = self.sample_random_actions(1)[0]
      all_ops.append(tf.assign(action_var, sample_action))
    all_ops += self.tf_context.reset(mode=mode, agent=self._meta_agent,
                                     action_fn=action_fn, state=state)
    return all_ops

  def cond_begin_episode_op(self, cond, input_vars, mode, meta_action_fn):
    """Returns op that resets agent at beginning of episodes.

    A new episode is begun if the cond op evaluates to `False`; otherwise the
    context is stepped and the transition's rewards are computed.

    Args:
      cond: a Boolean tensor variable.
      input_vars: A list of tensor variables: (state, action, reward,
        next_state, state_repr, next_state_repr).
      mode: a string representing the mode=[train, explore, eval].
      meta_action_fn: Action function forwarded to the context's step/reset.
    Returns:
      Conditional begin op: a pair (context_reward, meta_reward) of tensors
      with the dtype of `reward`. Both are zeros when a new episode begins.
    """
    (state, action, reward, next_state,
     state_repr, next_state_repr) = input_vars

    def continue_fn():
      """Continue op fn."""
      items = [state, action, reward, next_state,
               state_repr, next_state_repr] + list(self.context_vars)
      batch_items = [tf.expand_dims(item, 0) for item in items]
      (states, actions, rewards, next_states,
       state_reprs, next_state_reprs) = batch_items[:6]
      context_reward = self.compute_rewards(
          mode, state_reprs, actions, rewards, next_state_reprs,
          batch_items[6:])[0][0]
      context_reward = tf.cast(context_reward, dtype=reward.dtype)
      if self.meta_agent is not None:
        meta_action = tf.concat(self.context_vars, -1)
        items = [state, meta_action, reward, next_state,
                 state_repr, next_state_repr] + list(self.meta_agent.context_vars)
        batch_items = [tf.expand_dims(item, 0) for item in items]
        (states, meta_actions, rewards, next_states,
         state_reprs, next_state_reprs) = batch_items[:6]
        # NOTE(review): unlike the low-level reward above, the meta reward is
        # computed from raw states rather than state representations --
        # confirm this is intended.
        meta_reward = self.meta_agent.compute_rewards(
            mode, states, meta_actions, rewards,
            next_states, batch_items[6:])[0][0]
        meta_reward = tf.cast(meta_reward, dtype=reward.dtype)
      else:
        meta_reward = tf.constant(0, dtype=reward.dtype)

      # Step the context only after both rewards are computed.
      with tf.control_dependencies([context_reward, meta_reward]):
        step_ops = self.tf_context.step(mode=mode, agent=self._meta_agent,
                                        state=state,
                                        next_state=next_state,
                                        state_repr=state_repr,
                                        next_state_repr=next_state_repr,
                                        action_fn=meta_action_fn)
      # Make the returned rewards depend on the step ops.
      with tf.control_dependencies(step_ops):
        context_reward, meta_reward = map(tf.identity, [context_reward, meta_reward])
      return context_reward, meta_reward

    def begin_episode_fn():
      """Begin op fn."""
      begin_ops = self.begin_episode_ops(mode=mode, action_fn=meta_action_fn, state=state)
      with tf.control_dependencies(begin_ops):
        return tf.zeros_like(reward), tf.zeros_like(reward)

    with tf.control_dependencies(input_vars):
      cond_begin_episode_op = tf.cond(cond, continue_fn, begin_episode_fn)
    return cond_begin_episode_op

  def get_env_base_wrapper(self, env_base, **begin_kwargs):
    """Create a wrapper around env_base, with agent-specific begin/end_episode.

    Args:
      env_base: A python environment base.
      **begin_kwargs: Keyword args for begin_episode_ops.
    Returns:
      An object with begin_episode() and end_episode().
    """
    begin_ops = self.begin_episode_ops(**begin_kwargs)
    return uvf_utils.get_contextual_env_base(env_base, begin_ops)

  def init_action_vars(self, name, i=None):
    """Create and return a tensorflow Variable holding an action.

    Args:
      name: Name of the variables.
      i: Integer id, appended to `name` as '_<i>' when given.
    Returns:
      A [num_action_dims] tensor.
    """
    if i is not None:
      name += '_%d' % i
    assert name not in self._action_vars, ('Conflict! %s is already '
                                           'initialized.') % name
    self._action_vars[name] = tf.Variable(
        self.sample_random_actions(1)[0], name='%s_action' % (name))
    self._validate_actions(tf.expand_dims(self._action_vars[name], 0))
    return self._action_vars[name]

  @gin.configurable('uvf_critic_function')
  def critic_function(self, critic_vals, states, critic_fn=None):
    """Computes q values based on outputs from the critic net.

    Args:
      critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs
        from the critic net.
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      critic_fn: A callable that process outputs from critic_net and
        outputs a [batch_size] tensor representing q values.
    Returns:
      A tf.float32 [batch_size] tensor representing q values.
    """
    if critic_fn is not None:
      env_states, contexts = self.unmerged_states(states)
      critic_vals = critic_fn(critic_vals, env_states, contexts)
    critic_vals.shape.assert_has_rank(1)
    return critic_vals

  def get_action_vars(self, key):
    """Returns the action variable registered under `key`."""
    return self._action_vars[key]

  def get_context_vars(self, key):
    """Returns `tf_context.context_vars[key]`."""
    return self.tf_context.context_vars[key]

  def step_cond_fn(self, *args):
    """Calls the configured step condition function with this agent."""
    return self._step_cond_fn(self, *args)

  def reset_episode_cond_fn(self, *args):
    """Calls the configured episode-reset condition function with this agent."""
    return self._reset_episode_cond_fn(self, *args)

  def reset_env_cond_fn(self, *args):
    """Calls the configured env-reset condition function with this agent."""
    return self._reset_env_cond_fn(self, *args)

  @property
  def context_vars(self):
    """The context's variables (`tf_context.vars`)."""
    return self.tf_context.vars
@gin.configurable
class MetaAgentCore(UvfAgentCore):
  """Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent.

  Used as higher-level agent. Its actions are contexts for the lower-level
  agent, so the base agent's action spec comes from `sub_context`.
  """

  def __init__(self,
               observation_spec,
               action_spec,
               tf_env,
               tf_context,
               sub_context,
               step_cond_fn=cond_fn.env_transition,
               reset_episode_cond_fn=cond_fn.env_restart,
               reset_env_cond_fn=cond_fn.false_fn,
               metrics=None,
               actions_reg=0.,
               k=2,
               **base_agent_kwargs):
    """Constructs a Meta agent.

    Args:
      observation_spec: A sequence whose first element is the TensorSpec of
        the environment observations.
      action_spec: Accepted for interface compatibility but unused. The base
        agent's action spec is `sub_context.context_as_action_specs`.
      tf_env: A Tensorflow environment object.
      tf_context: A Context class for this agent; instantiated with `tf_env`.
      sub_context: A Context class whose `context_as_action_specs` defines
        this agent's actions; instantiated with `tf_env`.
      step_cond_fn: A function indicating whether to increment the num of steps.
      reset_episode_cond_fn: A function indicating whether to restart the
        episode, resampling the context.
      reset_env_cond_fn: A function indicating whether to perform a manual reset
        of the environment.
      metrics: A list of functions that evaluate metrics of the agent.
      actions_reg: Weight of the L1 penalty that `actor_loss` applies to
        action dimensions `k` and beyond.
      k: First action dimension included in the `actions_reg` penalty.
      **base_agent_kwargs: A dictionary of parameters for base RL Agent.
    Raises:
      ValueError: If 'dqda_clipping' is < 0 (presumably checked by the base
        agent's constructor, which receives `base_agent_kwargs`).
    """
    self._step_cond_fn = step_cond_fn
    self._reset_episode_cond_fn = reset_episode_cond_fn
    self._reset_env_cond_fn = reset_env_cond_fn
    self.metrics = metrics
    self._actions_reg = actions_reg
    self._k = k

    # expose tf_context methods
    self.tf_context = tf_context(tf_env=tf_env)
    self.sub_context = sub_context(tf_env=tf_env)
    self.set_replay = self.tf_context.set_replay
    self.sample_contexts = self.tf_context.sample_contexts
    self.compute_rewards = self.tf_context.compute_rewards
    self.gamma_index = self.tf_context.gamma_index
    self.context_specs = self.tf_context.context_specs
    self.context_as_action_specs = self.tf_context.context_as_action_specs
    self.sub_context_as_action_specs = self.sub_context.context_as_action_specs
    self.init_context_vars = self.tf_context.create_vars

    self.env_observation_spec = observation_spec[0]
    merged_observation_spec = (uvf_utils.merge_specs(
        (self.env_observation_spec,) + self.context_specs),)
    self._context_vars = dict()
    self._action_vars = dict()
    # This constructor does not chain to UvfAgentCore.__init__, so default the
    # meta agent here too; the inherited `meta_agent` and the episode ops read it.
    self._meta_agent = None

    assert len(self.context_as_action_specs) == 1
    self.BASE_AGENT_CLASS.__init__(
        self,
        observation_spec=merged_observation_spec,
        action_spec=self.sub_context_as_action_specs,
        **base_agent_kwargs
    )

  @gin.configurable('meta_add_noise_fn')
  def add_noise_fn(self, action_fn, stddev=1.0, debug=False,
                   global_step=None):
    """Returns `action_fn` with additive Gaussian noise, always clipped.

    Delegates to UvfAgentCore.add_noise_fn with `clip=True`.

    Args:
      action_fn: A callable(`state`, `context`) returning an action tensor.
      stddev: Standard deviation of the additive Gaussian noise.
      debug: Print debug messages (forwarded to the parent implementation).
      global_step: Optional step used to decay `stddev`.
    Returns:
      A callable(`state`, `context=None`) returning the noisy action tensor.
    """
    noisy_action_fn = super(MetaAgentCore, self).add_noise_fn(
        action_fn, stddev, debug=debug,
        clip=True, global_step=global_step)
    return noisy_action_fn

  def actor_loss(self, states, actions, rewards, discounts,
                 next_states):
    """Returns the base actor loss plus an L1 penalty on trailing actions.

    The penalty is `actions_reg` times the batch mean of
    sum(|a[:, k:]|), where `a` are the actor's actions for `states`, computed
    without stopping gradients. The `actions`, `rewards`, `discounts` and
    `next_states` arguments are ignored.

    Args:
      states: A batch of merged states (env state + contexts).
      actions: Unused.
      rewards: Unused.
      discounts: Unused.
      next_states: Unused.
    Returns:
      A scalar loss tensor.
    """
    actions = self.actor_net(states, stop_gradients=False)
    regularizer = self._actions_reg * tf.reduce_mean(
        tf.reduce_sum(tf.abs(actions[:, self._k:]), -1), 0)
    loss = self.BASE_AGENT_CLASS.actor_loss(self, states)
    return regularizer + loss
@gin.configurable
class UvfAgent(UvfAgentCore, ddpg_agent.TD3Agent):
  """Lower-level UVF agent whose learner is ddpg_agent.TD3Agent.
  """
  ACTION_TYPE = 'continuous'
  BASE_AGENT_CLASS = ddpg_agent.TD3Agent

  def __init__(self, *args, **kwargs):
    # The MRO resolves this to UvfAgentCore.__init__, which in turn
    # initializes the TD3 base agent through BASE_AGENT_CLASS.
    super(UvfAgent, self).__init__(*args, **kwargs)
@gin.configurable
class MetaAgent(MetaAgentCore, ddpg_agent.TD3Agent):
  """Higher-level (meta) UVF agent whose learner is ddpg_agent.TD3Agent.
  """
  ACTION_TYPE = 'continuous'
  BASE_AGENT_CLASS = ddpg_agent.TD3Agent

  def __init__(self, *args, **kwargs):
    # The MRO resolves this to MetaAgentCore.__init__, which in turn
    # initializes the TD3 base agent through BASE_AGENT_CLASS.
    super(MetaAgent, self).__init__(*args, **kwargs)
@gin.configurable()
def state_preprocess_net(
    states,
    num_output_dims=2,
    states_hidden_layers=(100,),
    normalizer_fn=None,
    activation_fn=tf.nn.relu,
    zero_time=True,
    images=False):
  """Creates a simple feed forward net for embedding states.

  Args:
    states: A tensor of states; the last axis is the feature axis and must
      have a statically known size.
    num_output_dims: Size of the output embedding.
    states_hidden_layers: Sizes of the hidden fully connected layers. An
      empty/falsy value skips the hidden stack.
    normalizer_fn: Normalizer for the hidden layers.
    activation_fn: Activation for the hidden layers.
    zero_time: If True, multiply the last state feature by zero before
      embedding (presumably a time feature, per the name -- TODO confirm).
    images: If True, zero out the first two state features (x-y).
  Returns:
    A tensor whose last axis has size `num_output_dims`, cast back to the
    dtype of `states`.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    states_dtype = states.dtype
    states = tf.to_float(states)
    # Plain int for Python list arithmetic (works for TF1 Dimension and int).
    num_dims = int(states.shape[-1])
    if images:  # Zero-out x-y
      states *= tf.constant([0.] * 2 + [1.] * (num_dims - 2), dtype=states.dtype)
    if zero_time:
      states *= tf.constant([1.] * (num_dims - 1) + [0.], dtype=states.dtype)
    embed = states
    if states_hidden_layers:
      embed = slim.stack(embed, slim.fully_connected, states_hidden_layers,
                         scope='states')

    # Small-uniform init and no regularizer for the linear output layer.
    with slim.arg_scope([slim.fully_connected],
                        weights_regularizer=None,
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      embed = slim.fully_connected(embed, num_output_dims,
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   scope='value')

    return tf.cast(embed, states_dtype)
@gin.configurable()
def action_embed_net(
    actions,
    states=None,
    num_output_dims=2,
    hidden_layers=(400, 300),
    normalizer_fn=None,
    activation_fn=tf.nn.relu,
    zero_time=True,
    images=False):
  """Creates a simple feed forward net for embedding actions.

  Args:
    actions: A tensor of actions; converted to float32.
    states: Optional tensor of states, concatenated to the actions on the last
      axis (after optional masking) before embedding.
    num_output_dims: Size of the output embedding.
    hidden_layers: Sizes of the hidden fully connected layers. An empty/falsy
      value skips the hidden stack.
    normalizer_fn: Normalizer for the hidden layers.
    activation_fn: Activation for the hidden layers.
    zero_time: If True, multiply the last state feature by zero (presumably a
      time feature, per the name -- TODO confirm).
    images: If True, zero out the first two state features (x-y).
  Returns:
    The embedding. When `num_output_dims == 1` the trailing size-1 output axis
    (axis 1) is dropped.
  """
  with slim.arg_scope(
      [slim.fully_connected],
      activation_fn=activation_fn,
      normalizer_fn=normalizer_fn,
      weights_initializer=slim.variance_scaling_initializer(
          factor=1.0/3.0, mode='FAN_IN', uniform=True)):

    actions = tf.to_float(actions)
    if states is not None:
      # Convert to float before masking, as state_preprocess_net does. The
      # 0./1. mask constants cannot be built with an integer dtype; for float
      # inputs the result is unchanged because the masks are exactly 0 or 1.
      states = tf.to_float(states)
      num_dims = int(states.shape[-1])
      if images:  # Zero-out x-y
        states *= tf.constant([0.] * 2 + [1.] * (num_dims - 2), dtype=states.dtype)
      if zero_time:
        states *= tf.constant([1.] * (num_dims - 1) + [0.], dtype=states.dtype)
      actions = tf.concat([actions, states], -1)

    embed = actions
    if hidden_layers:
      embed = slim.stack(embed, slim.fully_connected, hidden_layers,
                         scope='hidden')

    # Small-uniform init and no regularizer for the linear output layer.
    with slim.arg_scope([slim.fully_connected],
                        weights_regularizer=None,
                        weights_initializer=tf.random_uniform_initializer(
                            minval=-0.003, maxval=0.003)):
      embed = slim.fully_connected(embed, num_output_dims,
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   scope='value')
      if num_output_dims == 1:
        return embed[:, 0, ...]
      else:
        return embed
def huber(x, kappa=0.1):
  """Elementwise Huber loss of `x` with threshold `kappa`, divided by `kappa`.

  Equal to 0.5 * x**2 where |x| <= kappa and kappa * (|x| - 0.5 * kappa)
  elsewhere, with the whole result scaled by 1 / kappa. The branches are
  combined with float masks rather than a select.
  """
  abs_x = tf.abs(x)
  quadratic_part = 0.5 * tf.square(x) * tf.to_float(abs_x <= kappa)
  linear_part = kappa * (abs_x - 0.5 * kappa) * tf.to_float(abs_x > kappa)
  return (quadratic_part + linear_part) / kappa
@gin.configurable()
class StatePreprocess(object):
  """Learned state representation plus an action embedding, with their loss.

  Both networks are wrapped in `tf.make_template`, so repeated calls share
  variables.
  """
  STATE_PREPROCESS_NET_SCOPE = 'state_process_net'
  ACTION_EMBED_NET_SCOPE = 'action_embed_net'

  def __init__(self, trainable=False,
               state_preprocess_net=lambda states: states,
               action_embed_net=lambda actions, *args, **kwargs: actions,
               ndims=None):
    """Builds the templated networks.

    Args:
      trainable: Stored as `self.trainable` (not read elsewhere in this class).
      state_preprocess_net: Callable(states) -> embedding. Defaults to identity.
      action_embed_net: Callable(actions, *args, **kwargs) -> embedding.
        Defaults to identity on `actions`.
      ndims: If not None, `__call__` keeps only the first `ndims` features
        of the embedding.
    """
    self.trainable = trainable
    # Remember the enclosing variable scope so get_trainable_vars can find the
    # template variables later.
    self._scope = tf.get_variable_scope().name
    self._ndims = ndims
    self._state_preprocess_net = tf.make_template(
        self.STATE_PREPROCESS_NET_SCOPE, state_preprocess_net,
        create_scope_now_=True)
    self._action_embed_net = tf.make_template(
        self.ACTION_EMBED_NET_SCOPE, action_embed_net,
        create_scope_now_=True)

  def __call__(self, states):
    """Embeds `states`, accepting either a single state or a batch.

    A rank-1 input gets a temporary batch axis, which is removed again
    before returning.
    """
    batched = states.get_shape().ndims != 1
    if not batched:
      states = tf.expand_dims(states, 0)
    embedded = self._state_preprocess_net(states)
    if self._ndims is not None:
      embedded = embedded[..., :self._ndims]
    if not batched:
      return embedded[0]
    return embedded

  def loss(self, states, next_states, low_actions, low_states):
    """Computes the representation-learning loss for a batch of meta transitions.

    Args:
      states: A batch of starting states.
      next_states: A batch of states at the end of each meta transition.
      low_actions: Lower-level actions within each meta transition; flattened
        per example before being embedded.
      low_states: Lower-level states; axis 1 (static size `d`) indexes the
        steps within each meta transition.
    Returns:
      A tuple (loss, repr_log_probs, indices):
        loss: scalar `close + far`.
        repr_log_probs: a stop-gradient per-example term, scaled by 1 / tau.
        indices: the sampled step index for each example.
    """
    batch_size = tf.shape(states)[0]
    d = int(low_states.shape[1])
    # Sample indices into meta-transition to train on.
    # Weights are 0.99**i. The last index is scaled by 1 / (1 - 0.99) so it
    # carries the tail mass sum_{j >= d-1} 0.99**j of the geometric series.
    probs = 0.99 ** tf.range(d, dtype=tf.float32)
    probs *= tf.constant([1.0] * (d - 1) + [1.0 / (1 - 0.99)],
                         dtype=tf.float32)
    probs /= tf.reduce_sum(probs)
    index_dist = tf.distributions.Categorical(probs=probs, dtype=tf.int64)
    indices = index_dist.sample(batch_size)
    batch_size = tf.cast(batch_size, tf.int64)
    # (example, index + 1) pairs for gather_nd. The modulo only matters for
    # index d-1, whose gathered value is discarded by the tf.where below.
    next_indices = tf.concat(
        [tf.range(batch_size, dtype=tf.int64)[:, None],
         (1 + indices[:, None]) % d], -1)
    # Target state: the low-level state one step after the sampled index, or
    # the given next_states when the last index was sampled.
    new_next_states = tf.where(indices < d - 1,
                               tf.gather_nd(low_states, next_indices),
                               next_states)
    next_states = new_next_states

    embed1 = tf.to_float(self._state_preprocess_net(states))
    embed2 = tf.to_float(self._state_preprocess_net(next_states))
    action_embed = self._action_embed_net(
        tf.layers.flatten(low_actions), states=states)

    tau = 2.0
    # Scaled Huber distance summed over the embedding axis.
    fn = lambda z: tau * tf.reduce_sum(huber(z), -1)
    # Rolling buffer of 1024 past next-state embeddings: drop the first
    # `batch_size` rows and append the current embed2.
    all_embed = tf.get_variable('all_embed', [1024, int(embed1.shape[-1])],
                                initializer=tf.zeros_initializer())
    upd = all_embed.assign(tf.concat([all_embed[batch_size:], embed2], 0))
    with tf.control_dependencies([upd]):
      # Pull the predicted embedding (embed1 + action_embed) toward embed2.
      close = 1 * tf.reduce_mean(fn(embed1 + action_embed - embed2))
      # log mean_j exp(-fn(prediction - all_embed[j])) over the buffer.
      prior_log_probs = tf.reduce_logsumexp(
          -fn((embed1 + action_embed)[:, None, :] - all_embed[None, :, :]),
          axis=-1) - tf.log(tf.to_float(all_embed.shape[0]))
      # Push predictions away from embed2 of the neighbouring example (shifted
      # by one within the batch), normalized by the stop-gradient prior.
      far = tf.reduce_mean(tf.exp(-fn((embed1 + action_embed)[1:] - embed2[:-1])
                                  - tf.stop_gradient(prior_log_probs[1:])))
      repr_log_probs = tf.stop_gradient(
          -fn(embed1 + action_embed - embed2) - prior_log_probs) / tau
    return close + far, repr_log_probs, indices

  def get_trainable_vars(self):
    """Returns the trainable variables of both templated networks."""
    return (
        slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE)) +
        slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE)))
@gin.configurable()
class InverseDynamics(object):
  """Samples goal-sized candidates around the observed state change."""
  INVERSE_DYNAMICS_NET_SCOPE = 'inverse_dynamics'

  def __init__(self, spec):
    """Stores `spec`, whose minimum/maximum bound and scale the samples."""
    self._spec = spec

  def sample(self, states, next_states, num_samples, orig_goals, sc=0.5):
    """Samples candidates from a Normal centred on (next_states - states).

    The mean is the first `goal_dim` features of `next_states - states`. The
    stddev is `sc` times half the spec range.

    Args:
      states: A batch of starting states.
      next_states: A batch of end states.
      num_samples: Number of candidates. With 1, a single unclipped sample
        is returned. Otherwise `num_samples - 2` random draws are stacked
        with the mean and `orig_goals`, and the result is clipped to the spec.
      orig_goals: A batch of goals; its last axis defines `goal_dim`.
      sc: Multiplier on the half spec range used as the stddev.
    Returns:
      A sample tensor (num_samples == 1) or a stacked, clipped tensor of
      candidates along a new leading axis.
    """
    goal_dim = orig_goals.shape[-1]
    half_range = (self._spec.maximum - self._spec.minimum) / 2 * tf.ones([goal_dim])
    mean = tf.cast(next_states - states, tf.float32)[:, :goal_dim]
    half_range_row = tf.reshape(half_range, [1, goal_dim])
    stddev = sc * tf.tile(half_range_row, [tf.shape(states)[0], 1])
    dist = tf.distributions.Normal(mean, stddev)
    if num_samples == 1:
      return dist.sample()
    candidates = [
        dist.sample(num_samples - 2),
        tf.expand_dims(mean, 0),
        tf.expand_dims(orig_goals, 0),
    ]
    stacked = tf.concat(candidates, 0)
    return uvf_utils.clip_to_spec(stacked, self._spec)