# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from model_utils import variable_mapping
FLAGS = tf.app.flags.FLAGS
def generate_mask():
"""Generate the mask to be fed into the model."""
if FLAGS.mask_strategy == 'random':
p = np.random.choice(
[True, False],
size=[FLAGS.batch_size, FLAGS.sequence_length],
p=[FLAGS.is_present_rate, 1. - FLAGS.is_present_rate])
elif FLAGS.mask_strategy == 'contiguous':
masked_length = int((1 - FLAGS.is_present_rate) * FLAGS.sequence_length) - 1
# Determine location to start masking.
start_mask = np.random.randint(
1, FLAGS.sequence_length - masked_length + 1, size=FLAGS.batch_size)
p = np.full([FLAGS.batch_size, FLAGS.sequence_length], True, dtype=bool)
# Create contiguous masked section to be False.
for i, index in enumerate(start_mask):
p[i, index:index + masked_length] = False
else:
    raise NotImplementedError(
        'Unknown mask_strategy: %s' % FLAGS.mask_strategy)
return p
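
# Example usage (a minimal sketch; the flag values below are illustrative
# assumptions, not defaults taken from this code base):
#
#   FLAGS.mask_strategy = 'contiguous'
#   FLAGS.batch_size = 2
#   FLAGS.sequence_length = 10
#   FLAGS.is_present_rate = 0.5
#   mask = generate_mask()  # bool array of shape [2, 10]; False marks the
#                           # contiguous span of tokens the Generator imputes.
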
def assign_percent_real(session, percent_real_update, new_rate, current_rate):
  """Run the assign op that loads the current rate of real tokens into a
  TensorFlow variable.

  Args:
session: Current tf.Session.
percent_real_update: tf.assign operation.
new_rate: tf.placeholder for the new rate.
current_rate: Percent of tokens that are currently real. Fake tokens
are the ones being imputed by the Generator.
"""
  session.run(percent_real_update, feed_dict={new_rate: current_rate})


def assign_learning_rate(session, lr_update, lr_placeholder, new_lr):
  """Run the assign op that loads a new learning rate into a TensorFlow
  variable.

  Args:
session: Current tf.Session.
lr_update: tf.assign operation.
lr_placeholder: tf.placeholder for the new learning rate.
new_lr: New learning rate to use.
"""
session.run(lr_update, feed_dict={lr_placeholder: new_lr})
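
# Both assign helpers above share the same placeholder/assign pattern. A
# minimal sketch of the graph-side setup they expect (the variable and op
# names here are illustrative assumptions):
#
#   learning_rate = tf.Variable(0.001, trainable=False, name='learning_rate')
#   lr_placeholder = tf.placeholder(tf.float32, shape=[])
#   lr_update = tf.assign(learning_rate, lr_placeholder)
#   ...
#   assign_learning_rate(session, lr_update, lr_placeholder, new_lr=5e-4)
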
def clip_weights(variables, c_lower, c_upper):
"""Clip a list of weights to be within a certain range.
Args:
variables: List of tf.Variable weights.
c_lower: Lower bound for weights.
c_upper: Upper bound for weights.
"""
clip_ops = []
for var in variables:
clipped_var = tf.clip_by_value(var, c_lower, c_upper)
clip_ops.append(tf.assign(var, clipped_var))
return tf.group(*clip_ops)
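
# Example usage (a minimal sketch; the WGAN-style clipping bounds are
# illustrative assumptions):
#
#   dis_vars = [v for v in tf.trainable_variables()
#               if v.op.name.startswith('dis')]
#   clip_op = clip_weights(dis_vars, c_lower=-0.01, c_upper=0.01)
#   session.run(clip_op)  # typically run after each Discriminator update
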
def retrieve_init_savers(hparams):
"""Retrieve a dictionary of all the initial savers for the models.
Args:
hparams: MaskGAN hyperparameters.
"""
## Dictionary of init savers.
init_savers = {}
## Load Generator weights from MaskGAN checkpoint.
if FLAGS.maskgan_ckpt:
gen_vars = [
v for v in tf.trainable_variables() if v.op.name.startswith('gen')
]
init_saver = tf.train.Saver(var_list=gen_vars)
init_savers['init_saver'] = init_saver
## Load the Discriminator weights from the MaskGAN checkpoint if
# the weights are compatible.
if FLAGS.discriminator_model == 'seq2seq_vd':
dis_variable_maps = variable_mapping.dis_seq2seq_vd(hparams)
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
## Load weights from language model checkpoint.
if FLAGS.language_model_ckpt_dir:
if FLAGS.maskgan_ckpt is None:
## Generator Variables/Savers.
if FLAGS.generator_model == 'rnn_nas':
gen_variable_maps = variable_mapping.rnn_nas(hparams, model='gen')
gen_init_saver = tf.train.Saver(var_list=gen_variable_maps)
init_savers['gen_init_saver'] = gen_init_saver
elif FLAGS.generator_model == 'seq2seq_nas':
# Encoder.
gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq_nas(
hparams)
gen_encoder_init_saver = tf.train.Saver(
var_list=gen_encoder_variable_maps)
# Decoder.
gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq_nas(
hparams)
gen_decoder_init_saver = tf.train.Saver(
var_list=gen_decoder_variable_maps)
init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver
# seq2seq_vd derived from the same code base as seq2seq_zaremba.
elif (FLAGS.generator_model == 'seq2seq_zaremba' or
FLAGS.generator_model == 'seq2seq_vd'):
# Encoder.
gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq(
hparams)
gen_encoder_init_saver = tf.train.Saver(
var_list=gen_encoder_variable_maps)
# Decoder.
gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq(
hparams)
gen_decoder_init_saver = tf.train.Saver(
var_list=gen_decoder_variable_maps)
init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver
else:
        raise NotImplementedError(
            'Unknown generator_model: %s' % FLAGS.generator_model)
## Discriminator Variables/Savers.
if FLAGS.discriminator_model == 'rnn_nas':
dis_variable_maps = variable_mapping.rnn_nas(hparams, model='dis')
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
# rnn_vd derived from the same code base as rnn_zaremba.
elif (FLAGS.discriminator_model == 'rnn_zaremba' or
FLAGS.discriminator_model == 'rnn_vd'):
dis_variable_maps = variable_mapping.rnn_zaremba(hparams, model='dis')
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
FLAGS.discriminator_model == 'bidirectional_vd'):
dis_fwd_variable_maps = variable_mapping.dis_fwd_bidirectional(hparams)
dis_bwd_variable_maps = variable_mapping.dis_bwd_bidirectional(hparams)
# Savers for the forward/backward Discriminator components.
dis_fwd_init_saver = tf.train.Saver(var_list=dis_fwd_variable_maps)
dis_bwd_init_saver = tf.train.Saver(var_list=dis_bwd_variable_maps)
init_savers['dis_fwd_init_saver'] = dis_fwd_init_saver
init_savers['dis_bwd_init_saver'] = dis_bwd_init_saver
elif FLAGS.discriminator_model == 'cnn':
dis_variable_maps = variable_mapping.cnn()
dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
init_savers['dis_init_saver'] = dis_init_saver
elif FLAGS.discriminator_model == 'seq2seq_vd':
# Encoder.
dis_encoder_variable_maps = variable_mapping.dis_encoder_seq2seq(hparams)
dis_encoder_init_saver = tf.train.Saver(
var_list=dis_encoder_variable_maps)
# Decoder.
dis_decoder_variable_maps = variable_mapping.dis_decoder_seq2seq(hparams)
dis_decoder_init_saver = tf.train.Saver(
var_list=dis_decoder_variable_maps)
init_savers['dis_encoder_init_saver'] = dis_encoder_init_saver
init_savers['dis_decoder_init_saver'] = dis_decoder_init_saver
  return init_savers


def init_fn(init_savers, sess):
"""The init_fn to be passed to the Supervisor.
Args:
init_savers: Dictionary of init_savers. 'init_saver_name': init_saver.
sess: tf.Session.
"""
## Load Generator weights from MaskGAN checkpoint.
if FLAGS.maskgan_ckpt:
print('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
tf.logging.info('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
print('Asserting Generator is a seq2seq-variant.')
tf.logging.info('Asserting Generator is a seq2seq-variant.')
assert FLAGS.generator_model.startswith('seq2seq')
init_saver = init_savers['init_saver']
init_saver.restore(sess, FLAGS.maskgan_ckpt)
## Load the Discriminator weights from the MaskGAN checkpoint if
# the weights are compatible.
if FLAGS.discriminator_model == 'seq2seq_vd':
print('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
tf.logging.info('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
dis_init_saver = init_savers['dis_init_saver']
dis_init_saver.restore(sess, FLAGS.maskgan_ckpt)
## Load weights from language model checkpoint.
if FLAGS.language_model_ckpt_dir:
if FLAGS.maskgan_ckpt is None:
## Generator Models.
if FLAGS.generator_model == 'rnn_nas':
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Generator from %s.' % load_ckpt)
tf.logging.info('Restoring Generator from %s.' % load_ckpt)
gen_init_saver = init_savers['gen_init_saver']
gen_init_saver.restore(sess, load_ckpt)
elif FLAGS.generator_model.startswith('seq2seq'):
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Generator from %s.' % load_ckpt)
tf.logging.info('Restoring Generator from %s.' % load_ckpt)
gen_encoder_init_saver = init_savers['gen_encoder_init_saver']
gen_decoder_init_saver = init_savers['gen_decoder_init_saver']
gen_encoder_init_saver.restore(sess, load_ckpt)
gen_decoder_init_saver.restore(sess, load_ckpt)
## Discriminator Models.
if (FLAGS.discriminator_model == 'rnn_nas' or
FLAGS.discriminator_model == 'rnn_zaremba' or
FLAGS.discriminator_model == 'rnn_vd' or
FLAGS.discriminator_model == 'cnn'):
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Discriminator from %s.' % load_ckpt)
tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
dis_init_saver = init_savers['dis_init_saver']
dis_init_saver.restore(sess, load_ckpt)
elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
FLAGS.discriminator_model == 'bidirectional_vd'):
assert FLAGS.language_model_ckpt_dir_reversed is not None, (
'Need a reversed directory to fill in the backward components.')
load_fwd_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
load_bwd_ckpt = tf.train.latest_checkpoint(
FLAGS.language_model_ckpt_dir_reversed)
print('Restoring Discriminator from %s and %s.' % (load_fwd_ckpt,
load_bwd_ckpt))
tf.logging.info('Restoring Discriminator from %s and %s.' %
(load_fwd_ckpt, load_bwd_ckpt))
dis_fwd_init_saver = init_savers['dis_fwd_init_saver']
dis_bwd_init_saver = init_savers['dis_bwd_init_saver']
dis_fwd_init_saver.restore(sess, load_fwd_ckpt)
dis_bwd_init_saver.restore(sess, load_bwd_ckpt)
elif FLAGS.discriminator_model == 'seq2seq_vd':
load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
print('Restoring Discriminator from %s.' % load_ckpt)
tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
dis_encoder_init_saver = init_savers['dis_encoder_init_saver']
dis_decoder_init_saver = init_savers['dis_decoder_init_saver']
dis_encoder_init_saver.restore(sess, load_ckpt)
dis_decoder_init_saver.restore(sess, load_ckpt)
else:
return
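
# Typical wiring of the two helpers above (a sketch of how a training script
# might use them together; the Supervisor arguments are illustrative
# assumptions):
#
#   import functools
#   init_savers = retrieve_init_savers(hparams)
#   supervisor_init_fn = functools.partial(init_fn, init_savers)
#   sv = tf.train.Supervisor(logdir=FLAGS.base_directory,
#                            init_fn=supervisor_init_fn)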