# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for training adversarial text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

# Dependency imports

import numpy as np
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'Master address.')
flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of parameter servers.')
flags.DEFINE_string('train_dir', '/tmp/text_train',
                    'Directory for logs and checkpoints.')
flags.DEFINE_integer('max_steps', 1000000, 'Number of batches to run.')
flags.DEFINE_boolean('log_device_placement', False,
                     'Whether to log device placement.')


def run_training(train_op,
                 loss,
                 global_step,
                 variables_to_restore=None,
                 pretrained_model_dir=None):
  """Sets up and runs training loop."""
  tf.gfile.MakeDirs(FLAGS.train_dir)

  # Create pretrain Saver
  if pretrained_model_dir:
    assert variables_to_restore
    tf.logging.info('Will attempt restore from %s: %s', pretrained_model_dir,
                    variables_to_restore)
    saver_for_restore = tf.train.Saver(variables_to_restore)

  # Init ops
  if FLAGS.sync_replicas:
    local_init_op = tf.get_collection('local_init_op')[0]
    ready_for_local_init_op = tf.get_collection('ready_for_local_init_op')[0]
  else:
    local_init_op = tf.train.Supervisor.USE_DEFAULT
    ready_for_local_init_op = tf.train.Supervisor.USE_DEFAULT

  is_chief = FLAGS.task == 0
  sv = tf.train.Supervisor(
      logdir=FLAGS.train_dir,
      is_chief=is_chief,
      save_summaries_secs=30,
      save_model_secs=30,
      local_init_op=local_init_op,
      ready_for_local_init_op=ready_for_local_init_op,
      global_step=global_step)

  # Delay starting standard services to allow possible pretrained model
  # restore.
  with sv.managed_session(
      master=FLAGS.master,
      config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement),
      start_standard_services=False) as sess:
    # Initialization
    if is_chief:
      if pretrained_model_dir:
        maybe_restore_pretrained_model(sess, saver_for_restore,
                                       pretrained_model_dir)
      if FLAGS.sync_replicas:
        sess.run(tf.get_collection('chief_init_op')[0])
      sv.start_standard_services(sess)

    sv.start_queue_runners(sess)

    # Training loop
    global_step_val = 0
    while not sv.should_stop() and global_step_val < FLAGS.max_steps:
      global_step_val = train_step(sess, train_op, loss, global_step)

    # Final checkpoint
    if is_chief and global_step_val >= FLAGS.max_steps:
      sv.saver.save(sess, sv.save_path, global_step=global_step)


def maybe_restore_pretrained_model(sess, saver_for_restore, model_dir):
  """Restores pretrained model if there is no ckpt model."""
  ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
  checkpoint_exists = ckpt and ckpt.model_checkpoint_path
  if checkpoint_exists:
    tf.logging.info('Checkpoint exists in FLAGS.train_dir; skipping '
                    'pretraining restore')
    return

  pretrain_ckpt = tf.train.get_checkpoint_state(model_dir)
  if not (pretrain_ckpt and pretrain_ckpt.model_checkpoint_path):
    raise ValueError(
        'Asked to restore model from %s but no checkpoint found.' % model_dir)
  saver_for_restore.restore(sess, pretrain_ckpt.model_checkpoint_path)


def train_step(sess, train_op, loss, global_step):
  """Runs a single training step."""
  start_time = time.time()
  _, loss_val, global_step_val = sess.run([train_op, loss, global_step])
  duration = time.time() - start_time

  # Logging
  if global_step_val % 10 == 0:
    examples_per_sec = FLAGS.batch_size / duration
    sec_per_batch = float(duration)
    format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                  'sec/batch)')
    tf.logging.info(format_str % (global_step_val, loss_val, examples_per_sec,
                                  sec_per_batch))

  if np.isnan(loss_val):
    raise OverflowError('Loss is nan')

  return global_step_val
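

# Example usage (illustrative sketch only, not part of this module): a
# training binary would build its graph and then hand the resulting training
# op, loss tensor, and global step to run_training(). The graphs.get_model()
# and model.classifier_training() names below stand in for whatever the
# calling code uses to construct the graph; flags read here but not defined in
# this file (e.g. sync_replicas, batch_size, pretrained_model_dir) are assumed
# to be defined by that calling code.
#
#   with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
#     model = graphs.get_model()
#     train_op, loss, global_step = model.classifier_training()
#     run_training(
#         train_op,
#         loss,
#         global_step,
#         variables_to_restore=model.pretrained_variables,
#         pretrained_model_dir=FLAGS.pretrained_model_dir)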