# models/research/maskgan/model_utils/model_construction.py
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model construction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from models import bidirectional
from models import bidirectional_vd
from models import bidirectional_zaremba
from models import cnn
from models import critic_vd
from models import feedforward
from models import rnn
from models import rnn_nas
from models import rnn_vd
from models import rnn_zaremba
from models import seq2seq
from models import seq2seq_nas
from models import seq2seq_vd
from models import seq2seq_zaremba
FLAGS = tf.app.flags.FLAGS
# TODO(adai): IMDB labels placeholder to model.
def create_generator(hparams,
                     inputs,
                     targets,
                     present,
                     is_training,
                     is_validating,
                     reuse=None):
  """Create the Generator model specified by the FLAGS and hparams.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    inputs: tf.int32 Tensor of the sequence input of shape [batch_size,
      sequence_length].
    targets: tf.int32 Tensor of the sequence targets of shape [batch_size,
      sequence_length].
    present: tf.bool Tensor indicating the presence or absence of the token
      of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    is_validating: Whether the model is being run in validation mode for
      calculating the perplexity.
    reuse (Optional): Whether to reuse the model.

  Returns:
    Tuple of the (sequence, logits, log_probs, initial_state, final_state,
    encoder_states) of the Generator. Sequence and logits have shape
    [batch_size, sequence_length, vocab_size]. The log_probs will have shape
    [batch_size, sequence_length]. Log_probs corresponds to the log
    probability of selecting the words. encoder_states is produced only by
    the 'seq2seq_vd' generator and is None for all other models.

  Raises:
    NotImplementedError: If FLAGS.generator_model is not a recognized model.
  """
  # Only the 'seq2seq_vd' branch binds encoder_states. Default it to None so
  # the final return does not raise UnboundLocalError for the other models.
  encoder_states = None

  if FLAGS.generator_model == 'rnn':
    (sequence, logits, log_probs, initial_state, final_state) = rnn.generator(
        hparams,
        inputs,
        targets,
        present,
        is_training=is_training,
        is_validating=is_validating,
        reuse=reuse)
  elif FLAGS.generator_model == 'rnn_zaremba':
    (sequence, logits, log_probs, initial_state,
     final_state) = rnn_zaremba.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_zaremba':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq_zaremba.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'rnn_nas':
    (sequence, logits, log_probs, initial_state,
     final_state) = rnn_nas.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_nas':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq_nas.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_vd':
    (sequence, logits, log_probs, initial_state, final_state,
     encoder_states) = seq2seq_vd.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  else:
    raise NotImplementedError(
        'Unknown generator model: %s' % FLAGS.generator_model)
  return (sequence, logits, log_probs, initial_state, final_state,
          encoder_states)
def create_discriminator(hparams,
                         sequence,
                         is_training,
                         reuse=None,
                         initial_state=None,
                         inputs=None,
                         present=None):
  """Create the Discriminator model specified by the FLAGS and hparams.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    reuse (Optional): Whether to reuse the model.
    initial_state (Optional): Initial RNN state; consumed only by the
      'rnn_vd' and 'bidirectional_vd' discriminators.
    inputs (Optional): tf.int32 Tensor of the original sequence input;
      consumed only by the 'seq2seq_vd' discriminator.
    present (Optional): tf.bool Tensor indicating the presence or absence of
      tokens; consumed only by the 'seq2seq_vd' discriminator.

  Returns:
    predictions: tf.float32 Tensor of predictions of shape [batch_size,
      sequence_length].

  Raises:
    NotImplementedError: If FLAGS.discriminator_model is not a recognized
      model.
  """
  if FLAGS.discriminator_model == 'cnn':
    predictions = cnn.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'fnn':
    predictions = feedforward.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn':
    predictions = rnn.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'bidirectional':
    predictions = bidirectional.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'bidirectional_zaremba':
    predictions = bidirectional_zaremba.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'seq2seq_vd':
    # seq2seq_vd also conditions on the original inputs and the mask.
    predictions = seq2seq_vd.discriminator(
        hparams,
        inputs,
        present,
        sequence,
        is_training=is_training,
        reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_zaremba':
    predictions = rnn_zaremba.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_nas':
    predictions = rnn_nas.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_vd':
    predictions = rnn_vd.discriminator(
        hparams,
        sequence,
        is_training=is_training,
        reuse=reuse,
        initial_state=initial_state)
  elif FLAGS.discriminator_model == 'bidirectional_vd':
    predictions = bidirectional_vd.discriminator(
        hparams,
        sequence,
        is_training=is_training,
        reuse=reuse,
        initial_state=initial_state)
  else:
    raise NotImplementedError(
        'Unknown discriminator model: %s' % FLAGS.discriminator_model)
  return predictions
def create_critic(hparams, sequence, is_training, reuse=None):
  """Create the Critic model specified by the FLAGS and hparams.

  The critic estimates per-timestep baselines for the REINFORCE-style
  generator updates; it is only implemented for the 'critic' baseline
  method paired with the 'seq2seq_vd' discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    reuse (Optional): Whether to reuse the model.

  Returns:
    values: tf.float32 Tensor of predictions of shape [batch_size,
      sequence_length].

  Raises:
    NotImplementedError: If FLAGS.baseline_method is not 'critic' or
      FLAGS.discriminator_model is not 'seq2seq_vd'.
  """
  # Guard clauses: fail fast with the offending flag value instead of a
  # bare NotImplementedError.
  if FLAGS.baseline_method != 'critic':
    raise NotImplementedError(
        'Critic creation requires baseline_method=critic, got: %s' %
        FLAGS.baseline_method)
  if FLAGS.discriminator_model != 'seq2seq_vd':
    raise NotImplementedError(
        'Critic is only implemented for discriminator_model=seq2seq_vd, '
        'got: %s' % FLAGS.discriminator_model)
  return critic_vd.critic_seq2seq_vd_derivative(
      hparams, sequence, is_training, reuse=reuse)