# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model optimization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def create_dis_pretrain_op(hparams, dis_loss, global_step):
  """Create a train op for Discriminator pretraining."""
  with tf.name_scope('pretrain_discriminator'):
    optimizer = tf.train.AdamOptimizer(hparams.dis_pretrain_learning_rate)
    dis_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('dis')
    ]

    # Optionally update the embedding shared with the Generator.
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)

    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
                                                  FLAGS.grad_clipping)
    dis_pretrain_op = optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_pretrain_op


def create_gen_pretrain_op(hparams, cross_entropy_loss, global_step):
  """Create a train op for Generator pretraining."""
  with tf.name_scope('pretrain_generator'):
    optimizer = tf.train.AdamOptimizer(hparams.gen_pretrain_learning_rate)
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    gen_grads = tf.gradients(cross_entropy_loss, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    gen_pretrain_op = optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_pretrain_op


def create_gen_train_op(hparams, learning_rate, gen_loss, global_step, mode):
  """Create Generator train op."""
  del hparams
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    print('\nOptimizing Generator vars:')
    for v in gen_vars:
      print(v)
    if mode == 'MINIMIZE':
      gen_grads = tf.gradients(gen_loss, gen_vars)
    elif mode == 'MAXIMIZE':
      gen_grads = tf.gradients(-gen_loss, gen_vars)
    else:
      raise ValueError("Must be one of 'MINIMIZE' or 'MAXIMIZE'")
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    gen_train_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_train_op, gen_grads_clipped, gen_vars


def create_reinforce_gen_train_op(hparams, learning_rate, final_gen_reward,
                                  averages_op, global_step):
  """Create the Generator train_op when using REINFORCE.

  Args:
    hparams: MaskGAN hyperparameters.
    learning_rate: tf.Variable scalar learning rate.
    final_gen_reward: Scalar final REINFORCE reward for the sequence.
    averages_op: ExponentialMovingAverage apply-averages op used to maintain
      the baseline.
    global_step: global_step tf.Variable.

  Returns:
    gen_train_op: Generator training op, grouped with the averages op when
      one is provided.
  """
  del hparams
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    print('\nOptimizing Generator vars:')
    for v in gen_vars:
      print(v)

    # Maximize the reward by descending on its negation.
    gen_grads = tf.gradients(-final_gen_reward, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    maximize_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)

    # Group in the op that maintains the moving-average baseline.
    if averages_op:
      gen_train_op = tf.group(maximize_op, averages_op)
    else:
      gen_train_op = maximize_op

    return [gen_train_op, gen_grads, gen_vars]


def create_dis_train_op(hparams, dis_loss, global_step):
  """Create Discriminator train op."""
  with tf.name_scope('train_discriminator'):
    dis_optimizer = tf.train.AdamOptimizer(hparams.dis_learning_rate)
    dis_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('dis')
    ]

    # Optionally update the embedding shared with the Generator.
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)

    print('\nOptimizing Discriminator vars:')
    for v in dis_vars:
      print(v)
    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
                                                  FLAGS.grad_clipping)
    dis_train_op = dis_optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_train_op, dis_grads_clipped, dis_vars


def create_critic_train_op(hparams, critic_loss, global_step):
  """Create Critic train op."""
  with tf.name_scope('train_critic'):
    critic_optimizer = tf.train.AdamOptimizer(hparams.critic_learning_rate)
    output_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('critic')
    ]

    # Optionally update a subset of the Discriminator variables as well.
    if FLAGS.critic_update_dis_vars:
      if FLAGS.discriminator_model == 'bidirectional_vd':
        critic_vars = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('dis/rnn')
        ]
      elif FLAGS.discriminator_model == 'seq2seq_vd':
        critic_vars = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('dis/decoder/rnn/multi_rnn_cell')
        ]
      else:
        # Guard against critic_vars being undefined below.
        raise NotImplementedError
      critic_vars.extend(output_vars)
    else:
      critic_vars = output_vars
    print('\nOptimizing Critic vars:')
    for v in critic_vars:
      print(v)
    critic_grads = tf.gradients(critic_loss, critic_vars)
    critic_grads_clipped, _ = tf.clip_by_global_norm(critic_grads,
                                                     FLAGS.grad_clipping)
    critic_train_op = critic_optimizer.apply_gradients(
        zip(critic_grads_clipped, critic_vars), global_step=global_step)
    return critic_train_op, critic_grads_clipped, critic_vars
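

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It assumes
# an `hparams` object exposing the learning-rate attributes referenced above,
# that the FLAGS consumed here (generator_optimizer, grad_clipping,
# dis_share_embedding, etc.) are defined by the calling binary, and that the
# loss/reward tensors (`cross_entropy_loss`, `dis_loss`, `critic_loss`,
# `final_gen_reward`, `averages_op`) come from an already-built MaskGAN graph:
#
#   global_step = tf.train.get_or_create_global_step()
#
#   # Pretraining phase.
#   gen_pretrain_op = create_gen_pretrain_op(
#       hparams, cross_entropy_loss, global_step)
#   dis_pretrain_op = create_dis_pretrain_op(hparams, dis_loss, global_step)
#
#   # Adversarial phase with REINFORCE rewards.
#   gen_train_op, gen_grads, gen_vars = create_reinforce_gen_train_op(
#       hparams, learning_rate, final_gen_reward, averages_op, global_step)
#   dis_train_op, _, _ = create_dis_train_op(hparams, dis_loss, global_step)
#   critic_train_op, _, _ = create_critic_train_op(
#       hparams, critic_loss, global_step)
# ------------------------------------------------------------------------------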