# Copyright 2017 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Model construction.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # Dependency imports import tensorflow as tf from models import bidirectional from models import bidirectional_vd from models import bidirectional_zaremba from models import cnn from models import critic_vd from models import feedforward from models import rnn from models import rnn_nas from models import rnn_vd from models import rnn_zaremba from models import seq2seq from models import seq2seq_nas from models import seq2seq_vd from models import seq2seq_zaremba FLAGS = tf.app.flags.FLAGS # TODO(adai): IMDB labels placeholder to model. def create_generator(hparams, inputs, targets, present, is_training, is_validating, reuse=None): """Create the Generator model specified by the FLAGS and hparams. Args; hparams: Hyperparameters for the MaskGAN. inputs: tf.int32 Tensor of the sequence input of shape [batch_size, sequence_length]. present: tf.bool Tensor indicating the presence or absence of the token of shape [batch_size, sequence_length]. is_training: Whether the model is training. is_validating: Whether the model is being run in validation mode for calculating the perplexity. reuse (Optional): Whether to reuse the model. Returns: Tuple of the (sequence, logits, log_probs) of the Generator. Sequence and logits have shape [batch_size, sequence_length, vocab_size]. The log_probs will have shape [batch_size, sequence_length]. Log_probs corresponds to the log probability of selecting the words. """ if FLAGS.generator_model == 'rnn': (sequence, logits, log_probs, initial_state, final_state) = rnn.generator( hparams, inputs, targets, present, is_training=is_training, is_validating=is_validating, reuse=reuse) elif FLAGS.generator_model == 'rnn_zaremba': (sequence, logits, log_probs, initial_state, final_state) = rnn_zaremba.generator( hparams, inputs, targets, present, is_training=is_training, is_validating=is_validating, reuse=reuse) elif FLAGS.generator_model == 'seq2seq': (sequence, logits, log_probs, initial_state, final_state) = seq2seq.generator( hparams, inputs, targets, present, is_training=is_training, is_validating=is_validating, reuse=reuse) elif FLAGS.generator_model == 'seq2seq_zaremba': (sequence, logits, log_probs, initial_state, final_state) = seq2seq_zaremba.generator( hparams, inputs, targets, present, is_training=is_training, is_validating=is_validating, reuse=reuse) elif FLAGS.generator_model == 'rnn_nas': (sequence, logits, log_probs, initial_state, final_state) = rnn_nas.generator( hparams, inputs, targets, present, is_training=is_training, is_validating=is_validating, reuse=reuse) elif FLAGS.generator_model == 'seq2seq_nas': (sequence, logits, log_probs, initial_state, final_state) = seq2seq_nas.generator( hparams, inputs, targets, present, is_training=is_training, is_validating=is_validating, reuse=reuse) elif FLAGS.generator_model == 'seq2seq_vd': (sequence, logits, log_probs, initial_state, final_state, encoder_states) = seq2seq_vd.generator( hparams, inputs, targets, present, is_training=is_training, is_validating=is_validating, reuse=reuse) else: raise NotImplementedError return (sequence, logits, log_probs, initial_state, final_state, encoder_states) def create_discriminator(hparams, sequence, is_training, reuse=None, initial_state=None, inputs=None, present=None): """Create the Discriminator model specified by the FLAGS and hparams. Args: hparams: Hyperparameters for the MaskGAN. sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length] is_training: Whether the model is training. reuse (Optional): Whether to reuse the model. Returns: predictions: tf.float32 Tensor of predictions of shape [batch_size, sequence_length] """ if FLAGS.discriminator_model == 'cnn': predictions = cnn.discriminator( hparams, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'fnn': predictions = feedforward.discriminator( hparams, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'rnn': predictions = rnn.discriminator( hparams, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'bidirectional': predictions = bidirectional.discriminator( hparams, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'bidirectional_zaremba': predictions = bidirectional_zaremba.discriminator( hparams, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'seq2seq_vd': predictions = seq2seq_vd.discriminator( hparams, inputs, present, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'rnn_zaremba': predictions = rnn_zaremba.discriminator( hparams, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'rnn_nas': predictions = rnn_nas.discriminator( hparams, sequence, is_training=is_training, reuse=reuse) elif FLAGS.discriminator_model == 'rnn_vd': predictions = rnn_vd.discriminator( hparams, sequence, is_training=is_training, reuse=reuse, initial_state=initial_state) elif FLAGS.discriminator_model == 'bidirectional_vd': predictions = bidirectional_vd.discriminator( hparams, sequence, is_training=is_training, reuse=reuse, initial_state=initial_state) else: raise NotImplementedError return predictions def create_critic(hparams, sequence, is_training, reuse=None): """Create the Critic model specified by the FLAGS and hparams. Args: hparams: Hyperparameters for the MaskGAN. sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length] is_training: Whether the model is training. reuse (Optional): Whether to reuse the model. Returns: values: tf.float32 Tensor of predictions of shape [batch_size, sequence_length] """ if FLAGS.baseline_method == 'critic': if FLAGS.discriminator_model == 'seq2seq_vd': values = critic_vd.critic_seq2seq_vd_derivative( hparams, sequence, is_training, reuse=reuse) else: raise NotImplementedError else: raise NotImplementedError return values