# Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A UVF agent. """ import tensorflow as tf import gin.tf from agents import ddpg_agent # pylint: disable=unused-import import cond_fn from utils import utils as uvf_utils from context import gin_imports # pylint: enable=unused-import slim = tf.contrib.slim @gin.configurable class UvfAgentCore(object): """Defines basic functions for UVF agent. Must be inherited with an RL agent. Used as lower-level agent. """ def __init__(self, observation_spec, action_spec, tf_env, tf_context, step_cond_fn=cond_fn.env_transition, reset_episode_cond_fn=cond_fn.env_restart, reset_env_cond_fn=cond_fn.false_fn, metrics=None, **base_agent_kwargs): """Constructs a UVF agent. Args: observation_spec: A TensorSpec defining the observations. action_spec: A BoundedTensorSpec defining the actions. tf_env: A Tensorflow environment object. tf_context: A Context class. step_cond_fn: A function indicating whether to increment the num of steps. reset_episode_cond_fn: A function indicating whether to restart the episode, resampling the context. reset_env_cond_fn: A function indicating whether to perform a manual reset of the environment. metrics: A list of functions that evaluate metrics of the agent. **base_agent_kwargs: A dictionary of parameters for base RL Agent. Raises: ValueError: If 'dqda_clipping' is < 0. """ self._step_cond_fn = step_cond_fn self._reset_episode_cond_fn = reset_episode_cond_fn self._reset_env_cond_fn = reset_env_cond_fn self.metrics = metrics # expose tf_context methods self.tf_context = tf_context(tf_env=tf_env) self.set_replay = self.tf_context.set_replay self.sample_contexts = self.tf_context.sample_contexts self.compute_rewards = self.tf_context.compute_rewards self.gamma_index = self.tf_context.gamma_index self.context_specs = self.tf_context.context_specs self.context_as_action_specs = self.tf_context.context_as_action_specs self.init_context_vars = self.tf_context.create_vars self.env_observation_spec = observation_spec[0] merged_observation_spec = (uvf_utils.merge_specs( (self.env_observation_spec,) + self.context_specs),) self._context_vars = dict() self._action_vars = dict() self.BASE_AGENT_CLASS.__init__( self, observation_spec=merged_observation_spec, action_spec=action_spec, **base_agent_kwargs ) def set_meta_agent(self, agent=None): self._meta_agent = agent @property def meta_agent(self): return self._meta_agent def actor_loss(self, states, actions, rewards, discounts, next_states): """Returns the next action for the state. Args: state: A [num_state_dims] tensor representing a state. context: A list of [num_context_dims] tensor representing a context. Returns: A [num_action_dims] tensor representing the action. """ return self.BASE_AGENT_CLASS.actor_loss(self, states) def action(self, state, context=None): """Returns the next action for the state. Args: state: A [num_state_dims] tensor representing a state. context: A list of [num_context_dims] tensor representing a context. Returns: A [num_action_dims] tensor representing the action. """ merged_state = self.merged_state(state, context) return self.BASE_AGENT_CLASS.action(self, merged_state) def actions(self, state, context=None): """Returns the next action for the state. Args: state: A [-1, num_state_dims] tensor representing a state. context: A list of [-1, num_context_dims] tensor representing a context. Returns: A [-1, num_action_dims] tensor representing the action. """ merged_states = self.merged_states(state, context) return self.BASE_AGENT_CLASS.actor_net(self, merged_states) def log_probs(self, states, actions, state_reprs, contexts=None): assert contexts is not None batch_dims = [tf.shape(states)[0], tf.shape(states)[1]] contexts = self.tf_context.context_multi_transition_fn( contexts, states=tf.to_float(state_reprs)) flat_states = tf.reshape(states, [batch_dims[0] * batch_dims[1], states.shape[-1]]) flat_contexts = [tf.reshape(tf.cast(context, states.dtype), [batch_dims[0] * batch_dims[1], context.shape[-1]]) for context in contexts] flat_pred_actions = self.actions(flat_states, flat_contexts) pred_actions = tf.reshape(flat_pred_actions, batch_dims + [flat_pred_actions.shape[-1]]) error = tf.square(actions - pred_actions) spec_range = (self._action_spec.maximum - self._action_spec.minimum) / 2 normalized_error = tf.cast(error, tf.float64) / tf.constant(spec_range) ** 2 return -normalized_error @gin.configurable('uvf_add_noise_fn') def add_noise_fn(self, action_fn, stddev=1.0, debug=False, clip=True, global_step=None): """Returns the action_fn with additive Gaussian noise. Args: action_fn: A callable(`state`, `context`) which returns a [num_action_dims] tensor representing a action. stddev: stddev for the Ornstein-Uhlenbeck noise. debug: Print debug messages. Returns: A [num_action_dims] action tensor. """ if global_step is not None: stddev *= tf.maximum( # Decay exploration during training. tf.train.exponential_decay(1.0, global_step, 1e6, 0.8), 0.5) def noisy_action_fn(state, context=None): """Noisy action fn.""" action = action_fn(state, context) if debug: action = uvf_utils.tf_print( action, [action], message='[add_noise_fn] pre-noise action', first_n=100) noise_dist = tf.distributions.Normal(tf.zeros_like(action), tf.ones_like(action) * stddev) noise = noise_dist.sample() action += noise if debug: action = uvf_utils.tf_print( action, [action], message='[add_noise_fn] post-noise action', first_n=100) if clip: action = uvf_utils.clip_to_spec(action, self._action_spec) return action return noisy_action_fn def merged_state(self, state, context=None): """Returns the merged state from the environment state and contexts. Args: state: A [num_state_dims] tensor representing a state. context: A list of [num_context_dims] tensor representing a context. If None, use the internal context. Returns: A [num_merged_state_dims] tensor representing the merged state. """ if context is None: context = list(self.context_vars) state = tf.concat([state,] + context, axis=-1) self._validate_states(self._batch_state(state)) return state def merged_states(self, states, contexts=None): """Returns the batch merged state from the batch env state and contexts. Args: states: A [batch_size, num_state_dims] tensor representing a batch of states. contexts: A list of [batch_size, num_context_dims] tensor representing a batch of contexts. If None, use the internal context. Returns: A [batch_size, num_merged_state_dims] tensor representing the batch of merged states. """ if contexts is None: contexts = [tf.tile(tf.expand_dims(context, axis=0), (tf.shape(states)[0], 1)) for context in self.context_vars] states = tf.concat([states,] + contexts, axis=-1) self._validate_states(states) return states def unmerged_states(self, merged_states): """Returns the batch state and contexts from the batch merged state. Args: merged_states: A [batch_size, num_merged_state_dims] tensor representing a batch of merged states. Returns: A [batch_size, num_state_dims] tensor and a list of [batch_size, num_context_dims] tensors representing the batch state and contexts respectively. """ self._validate_states(merged_states) num_state_dims = self.env_observation_spec.shape.as_list()[0] num_context_dims_list = [c.shape.as_list()[0] for c in self.context_specs] states = merged_states[:, :num_state_dims] contexts = [] i = num_state_dims for num_context_dims in num_context_dims_list: contexts.append(merged_states[:, i: i+num_context_dims]) i += num_context_dims return states, contexts def sample_random_actions(self, batch_size=1): """Return random actions. Args: batch_size: Batch size. Returns: A [batch_size, num_action_dims] tensor representing the batch of actions. """ actions = tf.concat( [ tf.random_uniform( shape=(batch_size, 1), minval=self._action_spec.minimum[i], maxval=self._action_spec.maximum[i]) for i in range(self._action_spec.shape[0].value) ], axis=1) return actions def clip_actions(self, actions): """Clip actions to spec. Args: actions: A [batch_size, num_action_dims] tensor representing the batch of actions. Returns: A [batch_size, num_action_dims] tensor representing the batch of clipped actions. """ actions = tf.concat( [ tf.clip_by_value( actions[:, i:i+1], self._action_spec.minimum[i], self._action_spec.maximum[i]) for i in range(self._action_spec.shape[0].value) ], axis=1) return actions def mix_contexts(self, contexts, insert_contexts, indices): """Mix two contexts based on indices. Args: contexts: A list of [batch_size, num_context_dims] tensor representing the batch of contexts. insert_contexts: A list of [batch_size, num_context_dims] tensor representing the batch of contexts to be inserted. indices: A list of a list of integers denoting indices to replace. Returns: A list of resulting contexts. """ if indices is None: indices = [[]] * len(contexts) assert len(contexts) == len(indices) assert all([spec.shape.ndims == 1 for spec in self.context_specs]) mix_contexts = [] for contexts_, insert_contexts_, indices_, spec in zip( contexts, insert_contexts, indices, self.context_specs): mix_contexts.append( tf.concat( [ insert_contexts_[:, i:i + 1] if i in indices_ else contexts_[:, i:i + 1] for i in range(spec.shape.as_list()[0]) ], axis=1)) return mix_contexts def begin_episode_ops(self, mode, action_fn=None, state=None): """Returns ops that reset agent at beginning of episodes. Args: mode: a string representing the mode=[train, explore, eval]. Returns: A list of ops. """ all_ops = [] for _, action_var in sorted(self._action_vars.items()): sample_action = self.sample_random_actions(1)[0] all_ops.append(tf.assign(action_var, sample_action)) all_ops += self.tf_context.reset(mode=mode, agent=self._meta_agent, action_fn=action_fn, state=state) return all_ops def cond_begin_episode_op(self, cond, input_vars, mode, meta_action_fn): """Returns op that resets agent at beginning of episodes. A new episode is begun if the cond op evalues to `False`. Args: cond: a Boolean tensor variable. input_vars: A list of tensor variables. mode: a string representing the mode=[train, explore, eval]. Returns: Conditional begin op. """ (state, action, reward, next_state, state_repr, next_state_repr) = input_vars def continue_fn(): """Continue op fn.""" items = [state, action, reward, next_state, state_repr, next_state_repr] + list(self.context_vars) batch_items = [tf.expand_dims(item, 0) for item in items] (states, actions, rewards, next_states, state_reprs, next_state_reprs) = batch_items[:6] context_reward = self.compute_rewards( mode, state_reprs, actions, rewards, next_state_reprs, batch_items[6:])[0][0] context_reward = tf.cast(context_reward, dtype=reward.dtype) if self.meta_agent is not None: meta_action = tf.concat(self.context_vars, -1) items = [state, meta_action, reward, next_state, state_repr, next_state_repr] + list(self.meta_agent.context_vars) batch_items = [tf.expand_dims(item, 0) for item in items] (states, meta_actions, rewards, next_states, state_reprs, next_state_reprs) = batch_items[:6] meta_reward = self.meta_agent.compute_rewards( mode, states, meta_actions, rewards, next_states, batch_items[6:])[0][0] meta_reward = tf.cast(meta_reward, dtype=reward.dtype) else: meta_reward = tf.constant(0, dtype=reward.dtype) with tf.control_dependencies([context_reward, meta_reward]): step_ops = self.tf_context.step(mode=mode, agent=self._meta_agent, state=state, next_state=next_state, state_repr=state_repr, next_state_repr=next_state_repr, action_fn=meta_action_fn) with tf.control_dependencies(step_ops): context_reward, meta_reward = map(tf.identity, [context_reward, meta_reward]) return context_reward, meta_reward def begin_episode_fn(): """Begin op fn.""" begin_ops = self.begin_episode_ops(mode=mode, action_fn=meta_action_fn, state=state) with tf.control_dependencies(begin_ops): return tf.zeros_like(reward), tf.zeros_like(reward) with tf.control_dependencies(input_vars): cond_begin_episode_op = tf.cond(cond, continue_fn, begin_episode_fn) return cond_begin_episode_op def get_env_base_wrapper(self, env_base, **begin_kwargs): """Create a wrapper around env_base, with agent-specific begin/end_episode. Args: env_base: A python environment base. **begin_kwargs: Keyword args for begin_episode_ops. Returns: An object with begin_episode() and end_episode(). """ begin_ops = self.begin_episode_ops(**begin_kwargs) return uvf_utils.get_contextual_env_base(env_base, begin_ops) def init_action_vars(self, name, i=None): """Create and return a tensorflow Variable holding an action. Args: name: Name of the variables. i: Integer id. Returns: A [num_action_dims] tensor. """ if i is not None: name += '_%d' % i assert name not in self._action_vars, ('Conflict! %s is already ' 'initialized.') % name self._action_vars[name] = tf.Variable( self.sample_random_actions(1)[0], name='%s_action' % (name)) self._validate_actions(tf.expand_dims(self._action_vars[name], 0)) return self._action_vars[name] @gin.configurable('uvf_critic_function') def critic_function(self, critic_vals, states, critic_fn=None): """Computes q values based on outputs from the critic net. Args: critic_vals: A tf.float32 [batch_size, ...] tensor representing outputs from the critic net. states: A [batch_size, num_state_dims] tensor representing a batch of states. critic_fn: A callable that process outputs from critic_net and outputs a [batch_size] tensor representing q values. Returns: A tf.float32 [batch_size] tensor representing q values. """ if critic_fn is not None: env_states, contexts = self.unmerged_states(states) critic_vals = critic_fn(critic_vals, env_states, contexts) critic_vals.shape.assert_has_rank(1) return critic_vals def get_action_vars(self, key): return self._action_vars[key] def get_context_vars(self, key): return self.tf_context.context_vars[key] def step_cond_fn(self, *args): return self._step_cond_fn(self, *args) def reset_episode_cond_fn(self, *args): return self._reset_episode_cond_fn(self, *args) def reset_env_cond_fn(self, *args): return self._reset_env_cond_fn(self, *args) @property def context_vars(self): return self.tf_context.vars @gin.configurable class MetaAgentCore(UvfAgentCore): """Defines basic functions for UVF Meta-agent. Must be inherited with an RL agent. Used as higher-level agent. """ def __init__(self, observation_spec, action_spec, tf_env, tf_context, sub_context, step_cond_fn=cond_fn.env_transition, reset_episode_cond_fn=cond_fn.env_restart, reset_env_cond_fn=cond_fn.false_fn, metrics=None, actions_reg=0., k=2, **base_agent_kwargs): """Constructs a Meta agent. Args: observation_spec: A TensorSpec defining the observations. action_spec: A BoundedTensorSpec defining the actions. tf_env: A Tensorflow environment object. tf_context: A Context class. step_cond_fn: A function indicating whether to increment the num of steps. reset_episode_cond_fn: A function indicating whether to restart the episode, resampling the context. reset_env_cond_fn: A function indicating whether to perform a manual reset of the environment. metrics: A list of functions that evaluate metrics of the agent. **base_agent_kwargs: A dictionary of parameters for base RL Agent. Raises: ValueError: If 'dqda_clipping' is < 0. """ self._step_cond_fn = step_cond_fn self._reset_episode_cond_fn = reset_episode_cond_fn self._reset_env_cond_fn = reset_env_cond_fn self.metrics = metrics self._actions_reg = actions_reg self._k = k # expose tf_context methods self.tf_context = tf_context(tf_env=tf_env) self.sub_context = sub_context(tf_env=tf_env) self.set_replay = self.tf_context.set_replay self.sample_contexts = self.tf_context.sample_contexts self.compute_rewards = self.tf_context.compute_rewards self.gamma_index = self.tf_context.gamma_index self.context_specs = self.tf_context.context_specs self.context_as_action_specs = self.tf_context.context_as_action_specs self.sub_context_as_action_specs = self.sub_context.context_as_action_specs self.init_context_vars = self.tf_context.create_vars self.env_observation_spec = observation_spec[0] merged_observation_spec = (uvf_utils.merge_specs( (self.env_observation_spec,) + self.context_specs),) self._context_vars = dict() self._action_vars = dict() assert len(self.context_as_action_specs) == 1 self.BASE_AGENT_CLASS.__init__( self, observation_spec=merged_observation_spec, action_spec=self.sub_context_as_action_specs, **base_agent_kwargs ) @gin.configurable('meta_add_noise_fn') def add_noise_fn(self, action_fn, stddev=1.0, debug=False, global_step=None): noisy_action_fn = super(MetaAgentCore, self).add_noise_fn( action_fn, stddev, clip=True, global_step=global_step) return noisy_action_fn def actor_loss(self, states, actions, rewards, discounts, next_states): """Returns the next action for the state. Args: state: A [num_state_dims] tensor representing a state. context: A list of [num_context_dims] tensor representing a context. Returns: A [num_action_dims] tensor representing the action. """ actions = self.actor_net(states, stop_gradients=False) regularizer = self._actions_reg * tf.reduce_mean( tf.reduce_sum(tf.abs(actions[:, self._k:]), -1), 0) loss = self.BASE_AGENT_CLASS.actor_loss(self, states) return regularizer + loss @gin.configurable class UvfAgent(UvfAgentCore, ddpg_agent.TD3Agent): """A DDPG agent with UVF. """ BASE_AGENT_CLASS = ddpg_agent.TD3Agent ACTION_TYPE = 'continuous' def __init__(self, *args, **kwargs): UvfAgentCore.__init__(self, *args, **kwargs) @gin.configurable class MetaAgent(MetaAgentCore, ddpg_agent.TD3Agent): """A DDPG meta-agent. """ BASE_AGENT_CLASS = ddpg_agent.TD3Agent ACTION_TYPE = 'continuous' def __init__(self, *args, **kwargs): MetaAgentCore.__init__(self, *args, **kwargs) @gin.configurable() def state_preprocess_net( states, num_output_dims=2, states_hidden_layers=(100,), normalizer_fn=None, activation_fn=tf.nn.relu, zero_time=True, images=False): """Creates a simple feed forward net for embedding states. """ with slim.arg_scope( [slim.fully_connected], activation_fn=activation_fn, normalizer_fn=normalizer_fn, weights_initializer=slim.variance_scaling_initializer( factor=1.0/3.0, mode='FAN_IN', uniform=True)): states_shape = tf.shape(states) states_dtype = states.dtype states = tf.to_float(states) if images: # Zero-out x-y states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype) if zero_time: states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype) orig_states = states embed = states if states_hidden_layers: embed = slim.stack(embed, slim.fully_connected, states_hidden_layers, scope='states') with slim.arg_scope([slim.fully_connected], weights_regularizer=None, weights_initializer=tf.random_uniform_initializer( minval=-0.003, maxval=0.003)): embed = slim.fully_connected(embed, num_output_dims, activation_fn=None, normalizer_fn=None, scope='value') output = embed output = tf.cast(output, states_dtype) return output @gin.configurable() def action_embed_net( actions, states=None, num_output_dims=2, hidden_layers=(400, 300), normalizer_fn=None, activation_fn=tf.nn.relu, zero_time=True, images=False): """Creates a simple feed forward net for embedding actions. """ with slim.arg_scope( [slim.fully_connected], activation_fn=activation_fn, normalizer_fn=normalizer_fn, weights_initializer=slim.variance_scaling_initializer( factor=1.0/3.0, mode='FAN_IN', uniform=True)): actions = tf.to_float(actions) if states is not None: if images: # Zero-out x-y states *= tf.constant([0.] * 2 + [1.] * (states.shape[-1] - 2), dtype=states.dtype) if zero_time: states *= tf.constant([1.] * (states.shape[-1] - 1) + [0.], dtype=states.dtype) actions = tf.concat([actions, tf.to_float(states)], -1) embed = actions if hidden_layers: embed = slim.stack(embed, slim.fully_connected, hidden_layers, scope='hidden') with slim.arg_scope([slim.fully_connected], weights_regularizer=None, weights_initializer=tf.random_uniform_initializer( minval=-0.003, maxval=0.003)): embed = slim.fully_connected(embed, num_output_dims, activation_fn=None, normalizer_fn=None, scope='value') if num_output_dims == 1: return embed[:, 0, ...] else: return embed def huber(x, kappa=0.1): return (0.5 * tf.square(x) * tf.to_float(tf.abs(x) <= kappa) + kappa * (tf.abs(x) - 0.5 * kappa) * tf.to_float(tf.abs(x) > kappa) ) / kappa @gin.configurable() class StatePreprocess(object): STATE_PREPROCESS_NET_SCOPE = 'state_process_net' ACTION_EMBED_NET_SCOPE = 'action_embed_net' def __init__(self, trainable=False, state_preprocess_net=lambda states: states, action_embed_net=lambda actions, *args, **kwargs: actions, ndims=None): self.trainable = trainable self._scope = tf.get_variable_scope().name self._ndims = ndims self._state_preprocess_net = tf.make_template( self.STATE_PREPROCESS_NET_SCOPE, state_preprocess_net, create_scope_now_=True) self._action_embed_net = tf.make_template( self.ACTION_EMBED_NET_SCOPE, action_embed_net, create_scope_now_=True) def __call__(self, states): batched = states.get_shape().ndims != 1 if not batched: states = tf.expand_dims(states, 0) embedded = self._state_preprocess_net(states) if self._ndims is not None: embedded = embedded[..., :self._ndims] if not batched: return embedded[0] return embedded def loss(self, states, next_states, low_actions, low_states): batch_size = tf.shape(states)[0] d = int(low_states.shape[1]) # Sample indices into meta-transition to train on. probs = 0.99 ** tf.range(d, dtype=tf.float32) probs *= tf.constant([1.0] * (d - 1) + [1.0 / (1 - 0.99)], dtype=tf.float32) probs /= tf.reduce_sum(probs) index_dist = tf.distributions.Categorical(probs=probs, dtype=tf.int64) indices = index_dist.sample(batch_size) batch_size = tf.cast(batch_size, tf.int64) next_indices = tf.concat( [tf.range(batch_size, dtype=tf.int64)[:, None], (1 + indices[:, None]) % d], -1) new_next_states = tf.where(indices < d - 1, tf.gather_nd(low_states, next_indices), next_states) next_states = new_next_states embed1 = tf.to_float(self._state_preprocess_net(states)) embed2 = tf.to_float(self._state_preprocess_net(next_states)) action_embed = self._action_embed_net( tf.layers.flatten(low_actions), states=states) tau = 2.0 fn = lambda z: tau * tf.reduce_sum(huber(z), -1) all_embed = tf.get_variable('all_embed', [1024, int(embed1.shape[-1])], initializer=tf.zeros_initializer()) upd = all_embed.assign(tf.concat([all_embed[batch_size:], embed2], 0)) with tf.control_dependencies([upd]): close = 1 * tf.reduce_mean(fn(embed1 + action_embed - embed2)) prior_log_probs = tf.reduce_logsumexp( -fn((embed1 + action_embed)[:, None, :] - all_embed[None, :, :]), axis=-1) - tf.log(tf.to_float(all_embed.shape[0])) far = tf.reduce_mean(tf.exp(-fn((embed1 + action_embed)[1:] - embed2[:-1]) - tf.stop_gradient(prior_log_probs[1:]))) repr_log_probs = tf.stop_gradient( -fn(embed1 + action_embed - embed2) - prior_log_probs) / tau return close + far, repr_log_probs, indices def get_trainable_vars(self): return ( slim.get_trainable_variables( uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE)) + slim.get_trainable_variables( uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE))) @gin.configurable() class InverseDynamics(object): INVERSE_DYNAMICS_NET_SCOPE = 'inverse_dynamics' def __init__(self, spec): self._spec = spec def sample(self, states, next_states, num_samples, orig_goals, sc=0.5): goal_dim = orig_goals.shape[-1] spec_range = (self._spec.maximum - self._spec.minimum) / 2 * tf.ones([goal_dim]) loc = tf.cast(next_states - states, tf.float32)[:, :goal_dim] scale = sc * tf.tile(tf.reshape(spec_range, [1, goal_dim]), [tf.shape(states)[0], 1]) dist = tf.distributions.Normal(loc, scale) if num_samples == 1: return dist.sample() samples = tf.concat([dist.sample(num_samples - 2), tf.expand_dims(loc, 0), tf.expand_dims(orig_goals, 0)], 0) return uvf_utils.clip_to_spec(samples, self._spec)