# Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """AutoAugment Train/Eval module. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import contextlib import os import time import custom_ops as ops import data_utils import helper_utils import numpy as np from shake_drop import build_shake_drop_model from shake_shake import build_shake_shake_model import tensorflow as tf from wrn import build_wrn_model tf.flags.DEFINE_string('model_name', 'wrn', 'wrn, shake_shake_32, shake_shake_96, shake_shake_112, ' 'pyramid_net') tf.flags.DEFINE_string('checkpoint_dir', '/tmp/training', 'Training Directory.') tf.flags.DEFINE_string('data_path', '/tmp/data', 'Directory where dataset is located.') tf.flags.DEFINE_string('dataset', 'cifar10', 'Dataset to train with. Either cifar10 or cifar100') tf.flags.DEFINE_integer('use_cpu', 1, '1 if use CPU, else GPU.') FLAGS = tf.flags.FLAGS arg_scope = tf.contrib.framework.arg_scope def setup_arg_scopes(is_training): """Sets up the argscopes that will be used when building an image model. Args: is_training: Is the model training or not. Returns: Arg scopes to be put around the model being constructed. """ batch_norm_decay = 0.9 batch_norm_epsilon = 1e-5 batch_norm_params = { # Decay for the moving averages. 'decay': batch_norm_decay, # epsilon to prevent 0s in variance. 'epsilon': batch_norm_epsilon, 'scale': True, # collection containing the moving mean and moving variance. 'is_training': is_training, } scopes = [] scopes.append(arg_scope([ops.batch_norm], **batch_norm_params)) return scopes def build_model(inputs, num_classes, is_training, hparams): """Constructs the vision model being trained/evaled. Args: inputs: input features/images being fed to the image model build built. num_classes: number of output classes being predicted. is_training: is the model training or not. hparams: additional hyperparameters associated with the image model. Returns: The logits of the image model. """ scopes = setup_arg_scopes(is_training) with contextlib.nested(*scopes): if hparams.model_name == 'pyramid_net': logits = build_shake_drop_model( inputs, num_classes, is_training) elif hparams.model_name == 'wrn': logits = build_wrn_model( inputs, num_classes, hparams.wrn_size) elif hparams.model_name == 'shake_shake': logits = build_shake_shake_model( inputs, num_classes, hparams, is_training) return logits class CifarModel(object): """Builds an image model for Cifar10/Cifar100.""" def __init__(self, hparams): self.hparams = hparams def build(self, mode): """Construct the cifar model.""" assert mode in ['train', 'eval'] self.mode = mode self._setup_misc(mode) self._setup_images_and_labels() self._build_graph(self.images, self.labels, mode) self.init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) def _setup_misc(self, mode): """Sets up miscellaneous in the cifar model constructor.""" self.lr_rate_ph = tf.Variable(0.0, name='lrn_rate', trainable=False) self.reuse = None if (mode == 'train') else True self.batch_size = self.hparams.batch_size if mode == 'eval': self.batch_size = 25 def _setup_images_and_labels(self): """Sets up image and label placeholders for the cifar model.""" if FLAGS.dataset == 'cifar10': self.num_classes = 10 else: self.num_classes = 100 self.images = tf.placeholder(tf.float32, [self.batch_size, 32, 32, 3]) self.labels = tf.placeholder(tf.float32, [self.batch_size, self.num_classes]) def assign_epoch(self, session, epoch_value): session.run(self._epoch_update, feed_dict={self._new_epoch: epoch_value}) def _build_graph(self, images, labels, mode): """Constructs the TF graph for the cifar model. Args: images: A 4-D image Tensor labels: A 2-D labels Tensor. mode: string indicating training mode ( e.g., 'train', 'valid', 'test'). """ is_training = 'train' in mode if is_training: self.global_step = tf.train.get_or_create_global_step() logits = build_model( images, self.num_classes, is_training, self.hparams) self.predictions, self.cost = helper_utils.setup_loss( logits, labels) self.accuracy, self.eval_op = tf.metrics.accuracy( tf.argmax(labels, 1), tf.argmax(self.predictions, 1)) self._calc_num_trainable_params() # Adds L2 weight decay to the cost self.cost = helper_utils.decay_weights(self.cost, self.hparams.weight_decay_rate) if is_training: self._build_train_op() # Setup checkpointing for this child model # Keep 2 or more checkpoints around during training. with tf.device('/cpu:0'): self.saver = tf.train.Saver(max_to_keep=2) self.init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) def _calc_num_trainable_params(self): self.num_trainable_params = np.sum([ np.prod(var.get_shape().as_list()) for var in tf.trainable_variables() ]) tf.logging.info('number of trainable params: {}'.format( self.num_trainable_params)) def _build_train_op(self): """Builds the train op for the cifar model.""" hparams = self.hparams tvars = tf.trainable_variables() grads = tf.gradients(self.cost, tvars) if hparams.gradient_clipping_by_global_norm > 0.0: grads, norm = tf.clip_by_global_norm( grads, hparams.gradient_clipping_by_global_norm) tf.summary.scalar('grad_norm', norm) # Setup the initial learning rate initial_lr = self.lr_rate_ph optimizer = tf.train.MomentumOptimizer( initial_lr, 0.9, use_nesterov=True) self.optimizer = optimizer apply_op = optimizer.apply_gradients( zip(grads, tvars), global_step=self.global_step, name='train_step') train_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies([apply_op]): self.train_op = tf.group(*train_ops) class CifarModelTrainer(object): """Trains an instance of the CifarModel class.""" def __init__(self, hparams): self._session = None self.hparams = hparams self.model_dir = os.path.join(FLAGS.checkpoint_dir, 'model') self.log_dir = os.path.join(FLAGS.checkpoint_dir, 'log') # Set the random seed to be sure the same validation set # is used for each model np.random.seed(0) self.data_loader = data_utils.DataSet(hparams) np.random.seed() # Put the random seed back to random self.data_loader.reset() def save_model(self, step=None): """Dumps model into the backup_dir. Args: step: If provided, creates a checkpoint with the given step number, instead of overwriting the existing checkpoints. """ model_save_name = os.path.join(self.model_dir, 'model.ckpt') if not tf.gfile.IsDirectory(self.model_dir): tf.gfile.MakeDirs(self.model_dir) self.saver.save(self.session, model_save_name, global_step=step) tf.logging.info('Saved child model') def extract_model_spec(self): """Loads a checkpoint with the architecture structure stored in the name.""" checkpoint_path = tf.train.latest_checkpoint(self.model_dir) if checkpoint_path is not None: self.saver.restore(self.session, checkpoint_path) tf.logging.info('Loaded child model checkpoint from %s', checkpoint_path) else: self.save_model(step=0) def eval_child_model(self, model, data_loader, mode): """Evaluate the child model. Args: model: image model that will be evaluated. data_loader: dataset object to extract eval data from. mode: will the model be evalled on train, val or test. Returns: Accuracy of the model on the specified dataset. """ tf.logging.info('Evaluating child model in mode %s', mode) while True: try: with self._new_session(model): accuracy = helper_utils.eval_child_model( self.session, model, data_loader, mode) tf.logging.info('Eval child model accuracy: {}'.format(accuracy)) # If epoch trained without raising the below errors, break # from loop. break except (tf.errors.AbortedError, tf.errors.UnavailableError) as e: tf.logging.info('Retryable error caught: %s. Retrying.', e) return accuracy @contextlib.contextmanager def _new_session(self, m): """Creates a new session for model m.""" # Create a new session for this model, initialize # variables, and save / restore from # checkpoint. self._session = tf.Session( '', config=tf.ConfigProto( allow_soft_placement=True, log_device_placement=False)) self.session.run(m.init) # Load in a previous checkpoint, or save this one self.extract_model_spec() try: yield finally: tf.Session.reset('') self._session = None def _build_models(self): """Builds the image models for train and eval.""" # Determine if we should build the train and eval model. When using # distributed training we only want to build one or the other and not both. with tf.variable_scope('model', use_resource=False): m = CifarModel(self.hparams) m.build('train') self._num_trainable_params = m.num_trainable_params self._saver = m.saver with tf.variable_scope('model', reuse=True, use_resource=False): meval = CifarModel(self.hparams) meval.build('eval') return m, meval def _calc_starting_epoch(self, m): """Calculates the starting epoch for model m based on global step.""" hparams = self.hparams batch_size = hparams.batch_size steps_per_epoch = int(hparams.train_size / batch_size) with self._new_session(m): curr_step = self.session.run(m.global_step) total_steps = steps_per_epoch * hparams.num_epochs epochs_left = (total_steps - curr_step) // steps_per_epoch starting_epoch = hparams.num_epochs - epochs_left return starting_epoch def _run_training_loop(self, m, curr_epoch): """Trains the cifar model `m` for one epoch.""" start_time = time.time() while True: try: with self._new_session(m): train_accuracy = helper_utils.run_epoch_training( self.session, m, self.data_loader, curr_epoch) tf.logging.info('Saving model after epoch') self.save_model(step=curr_epoch) break except (tf.errors.AbortedError, tf.errors.UnavailableError) as e: tf.logging.info('Retryable error caught: %s. Retrying.', e) tf.logging.info('Finished epoch: {}'.format(curr_epoch)) tf.logging.info('Epoch time(min): {}'.format( (time.time() - start_time) / 60.0)) return train_accuracy def _compute_final_accuracies(self, meval): """Run once training is finished to compute final val/test accuracies.""" valid_accuracy = self.eval_child_model(meval, self.data_loader, 'val') if self.hparams.eval_test: test_accuracy = self.eval_child_model(meval, self.data_loader, 'test') else: test_accuracy = 0 tf.logging.info('Test Accuracy: {}'.format(test_accuracy)) return valid_accuracy, test_accuracy def run_model(self): """Trains and evalutes the image model.""" hparams = self.hparams # Build the child graph with tf.Graph().as_default(), tf.device( '/cpu:0' if FLAGS.use_cpu else '/gpu:0'): m, meval = self._build_models() # Figure out what epoch we are on starting_epoch = self._calc_starting_epoch(m) # Run the validation error right at the beginning valid_accuracy = self.eval_child_model( meval, self.data_loader, 'val') tf.logging.info('Before Training Epoch: {} Val Acc: {}'.format( starting_epoch, valid_accuracy)) training_accuracy = None for curr_epoch in xrange(starting_epoch, hparams.num_epochs): # Run one training epoch training_accuracy = self._run_training_loop(m, curr_epoch) valid_accuracy = self.eval_child_model( meval, self.data_loader, 'val') tf.logging.info('Epoch: {} Valid Acc: {}'.format( curr_epoch, valid_accuracy)) valid_accuracy, test_accuracy = self._compute_final_accuracies( meval) tf.logging.info( 'Train Acc: {} Valid Acc: {} Test Acc: {}'.format( training_accuracy, valid_accuracy, test_accuracy)) @property def saver(self): return self._saver @property def session(self): return self._session @property def num_trainable_params(self): return self._num_trainable_params def main(_): if FLAGS.dataset not in ['cifar10', 'cifar100']: raise ValueError('Invalid dataset: %s' % FLAGS.dataset) hparams = tf.contrib.training.HParams( train_size=50000, validation_size=0, eval_test=1, dataset=FLAGS.dataset, data_path=FLAGS.data_path, batch_size=128, gradient_clipping_by_global_norm=5.0) if FLAGS.model_name == 'wrn': hparams.add_hparam('model_name', 'wrn') hparams.add_hparam('num_epochs', 200) hparams.add_hparam('wrn_size', 160) hparams.add_hparam('lr', 0.1) hparams.add_hparam('weight_decay_rate', 5e-4) elif FLAGS.model_name == 'shake_shake_32': hparams.add_hparam('model_name', 'shake_shake') hparams.add_hparam('num_epochs', 1800) hparams.add_hparam('shake_shake_widen_factor', 2) hparams.add_hparam('lr', 0.01) hparams.add_hparam('weight_decay_rate', 0.001) elif FLAGS.model_name == 'shake_shake_96': hparams.add_hparam('model_name', 'shake_shake') hparams.add_hparam('num_epochs', 1800) hparams.add_hparam('shake_shake_widen_factor', 6) hparams.add_hparam('lr', 0.01) hparams.add_hparam('weight_decay_rate', 0.001) elif FLAGS.model_name == 'shake_shake_112': hparams.add_hparam('model_name', 'shake_shake') hparams.add_hparam('num_epochs', 1800) hparams.add_hparam('shake_shake_widen_factor', 7) hparams.add_hparam('lr', 0.01) hparams.add_hparam('weight_decay_rate', 0.001) elif FLAGS.model_name == 'pyramid_net': hparams.add_hparam('model_name', 'pyramid_net') hparams.add_hparam('num_epochs', 1800) hparams.add_hparam('lr', 0.05) hparams.add_hparam('weight_decay_rate', 5e-5) hparams.batch_size = 64 else: raise ValueError('Not Valid Model Name: %s' % FLAGS.model_name) cifar_trainer = CifarModelTrainer(hparams) cifar_trainer.run_model() if __name__ == '__main__': tf.logging.set_verbosity(tf.logging.INFO) tf.app.run()