# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model optimization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
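
# Command-line flags read by this module (defined elsewhere in the MaskGAN
# codebase): generator_optimizer, grad_clipping, dis_share_embedding,
# dis_update_share_embedding, critic_update_dis_vars, discriminator_model.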


def create_dis_pretrain_op(hparams, dis_loss, global_step):
  """Create a train op for pretraining the Discriminator."""
  with tf.name_scope('pretrain_discriminator'):
    optimizer = tf.train.AdamOptimizer(hparams.dis_pretrain_learning_rate)
    dis_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('dis')
    ]
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      # Optionally also update the embedding variable shared with the
      # Generator.
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)
    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
                                                  FLAGS.grad_clipping)
    dis_pretrain_op = optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_pretrain_op
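
# A minimal usage sketch for the pretraining ops, assuming `hparams` and
# `dis_loss` are built by the surrounding MaskGAN training script (they are
# not defined in this module):
#
#   global_step = tf.train.get_or_create_global_step()
#   dis_pretrain_op = create_dis_pretrain_op(hparams, dis_loss, global_step)
#   with tf.train.MonitoredTrainingSession() as sess:
#     sess.run(dis_pretrain_op)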


def create_gen_pretrain_op(hparams, cross_entropy_loss, global_step):
  """Create a train op for pretraining the Generator."""
  with tf.name_scope('pretrain_generator'):
    optimizer = tf.train.AdamOptimizer(hparams.gen_pretrain_learning_rate)
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    gen_grads = tf.gradients(cross_entropy_loss, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    gen_pretrain_op = optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_pretrain_op


def create_gen_train_op(hparams, learning_rate, gen_loss, global_step, mode):
  """Create the Generator train op."""
  del hparams  # Unused.
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    print('\nOptimizing Generator vars:')
    for v in gen_vars:
      print(v)
    if mode == 'MINIMIZE':
      gen_grads = tf.gradients(gen_loss, gen_vars)
    elif mode == 'MAXIMIZE':
      # Negate the loss so that applying gradients performs ascent.
      gen_grads = tf.gradients(-gen_loss, gen_vars)
    else:
      raise ValueError("Mode must be one of 'MINIMIZE' or 'MAXIMIZE'.")
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    gen_train_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_train_op, gen_grads_clipped, gen_vars
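
# A minimal usage sketch, assuming `hparams`, `learning_rate`, and `gen_loss`
# come from the surrounding training script. Passing mode='MAXIMIZE' performs
# gradient ascent by differentiating the negated loss:
#
#   gen_train_op, gen_grads_clipped, gen_vars = create_gen_train_op(
#       hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')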


def create_reinforce_gen_train_op(hparams, learning_rate, final_gen_reward,
                                  averages_op, global_step):
  """Create the Generator train op when using REINFORCE.

  Args:
    hparams: MaskGAN hyperparameters (unused).
    learning_rate: tf.Variable scalar learning rate.
    final_gen_reward: Scalar final REINFORCE reward for the sequence.
    averages_op: ExponentialMovingAverage apply-averages op used to
      maintain the baseline.
    global_step: global_step tf.Variable.

  Returns:
    gen_train_op: Generator training op.
    gen_grads: Unclipped Generator gradients.
    gen_vars: Generator variables.
  """
  del hparams  # Unused.
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    print('\nOptimizing Generator vars:')
    for v in gen_vars:
      print(v)
    # Maximize the reward by descending on its negation.
    gen_grads = tf.gradients(-final_gen_reward, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    maximize_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    # Group in the op that maintains the moving-average baseline, if any.
    if averages_op:
      gen_train_op = tf.group(maximize_op, averages_op)
    else:
      gen_train_op = maximize_op
    return [gen_train_op, gen_grads, gen_vars]
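
# A sketch of how the `averages_op` baseline could be produced, assuming a
# `cumulative_rewards` tensor from the surrounding REINFORCE code; the names
# and decay value here are illustrative only:
#
#   ema = tf.train.ExponentialMovingAverage(decay=0.99)
#   averages_op = ema.apply([cumulative_rewards])
#   gen_train_op, gen_grads, gen_vars = create_reinforce_gen_train_op(
#       hparams, learning_rate, final_gen_reward, averages_op, global_step)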


def create_dis_train_op(hparams, dis_loss, global_step):
  """Create the Discriminator train op."""
  with tf.name_scope('train_discriminator'):
    dis_optimizer = tf.train.AdamOptimizer(hparams.dis_learning_rate)
    dis_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('dis')
    ]
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      # Optionally also update the embedding variable shared with the
      # Generator.
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)
    print('\nOptimizing Discriminator vars:')
    for v in dis_vars:
      print(v)
    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
                                                  FLAGS.grad_clipping)
    dis_train_op = dis_optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_train_op, dis_grads_clipped, dis_vars


def create_critic_train_op(hparams, critic_loss, global_step):
  """Create the Critic train op."""
  with tf.name_scope('train_critic'):
    critic_optimizer = tf.train.AdamOptimizer(hparams.critic_learning_rate)
    output_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('critic')
    ]
    if FLAGS.critic_update_dis_vars:
      # Optionally also update the Discriminator variables the Critic is
      # built on top of.
      if FLAGS.discriminator_model == 'bidirectional_vd':
        critic_vars = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('dis/rnn')
        ]
      elif FLAGS.discriminator_model == 'seq2seq_vd':
        critic_vars = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('dis/decoder/rnn/multi_rnn_cell')
        ]
      critic_vars.extend(output_vars)
    else:
      critic_vars = output_vars
    print('\nOptimizing Critic vars:')
    for v in critic_vars:
      print(v)
    critic_grads = tf.gradients(critic_loss, critic_vars)
    critic_grads_clipped, _ = tf.clip_by_global_norm(critic_grads,
                                                     FLAGS.grad_clipping)
    critic_train_op = critic_optimizer.apply_gradients(
        zip(critic_grads_clipped, critic_vars), global_step=global_step)
    return critic_train_op, critic_grads_clipped, critic_vars
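
# A minimal sketch of running the adversarial ops together, assuming the loss
# tensors and `hparams` are built by the surrounding training script:
#
#   dis_train_op, _, _ = create_dis_train_op(hparams, dis_loss, global_step)
#   critic_train_op, _, _ = create_critic_train_op(hparams, critic_loss,
#                                                  global_step)
#   with tf.train.MonitoredTrainingSession() as sess:
#     sess.run(dis_train_op)
#     sess.run(critic_train_op)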