|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Trains LSTM text classification model. |
|
|
|
Model trains with adversarial or virtual adversarial training. |
|
|
|
Computational time: |
|
1.8 hours to train 10000 steps without adversarial or virtual adversarial |
|
training, on a 1-layer LSTM with 1024 hidden units, 256-dimensional embeddings,

400-step truncated backpropagation, minibatch size 64, and a single GPU

(Pascal Titan X, cuDNNv5).
|
|
|
4 hours to train 10000 steps with adversarial or virtual adversarial |
|
training, with above condition. |
|
|
|
To initialize embedding and LSTM cell weights from a pretrained model, set |
|
FLAGS.pretrained_model_dir to the pretrained model's checkpoint directory. |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
import graphs |
|
import train_utils |
|
|
|
# Command-line flag handling. Other flags (e.g. FLAGS.ps_tasks, used in
# main()) are defined by the imported modules; only the pretrained-model
# flag is declared here.
flags = tf.app.flags

FLAGS = flags.FLAGS



# Checkpoint directory of a pretrained model. When set, embedding and LSTM
# cell weights are restored from it before training starts (see module
# docstring); when None, all variables are freshly initialized.
flags.DEFINE_string('pretrained_model_dir', None,

                    'Directory path to pretrained model to restore from')
|
|
|
|
|
def main(_):
  """Builds the LSTM classifier graph and runs the training loop.

  Graph construction is placed via a replica device setter so the same
  script works for both single-machine and distributed (parameter-server)
  training; FLAGS.ps_tasks == 0 degenerates to local placement.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  device_setter = tf.train.replica_device_setter(FLAGS.ps_tasks)
  with tf.device(device_setter):
    model = graphs.get_model()
    # classifier_training() returns the train op plus the tensors the
    # training loop logs and checkpoints against.
    train_op, loss, global_step = model.classifier_training()

    # Optionally restore embedding/LSTM weights from a pretrained model
    # (FLAGS.pretrained_model_dir may be None, meaning no restore).
    train_utils.run_training(
        train_op,
        loss,
        global_step,
        variables_to_restore=model.pretrained_variables,
        pretrained_model_dir=FLAGS.pretrained_model_dir)
|
|
|
|
|
# tf.app.run() parses command-line flags into FLAGS and then calls main().
if __name__ == '__main__':

  tf.app.run()
|
|