File size: 5,376 Bytes
97b6013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper functions used for training AutoAugment models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def setup_loss(logits, labels):
  """Builds softmax predictions and the cross-entropy loss for `logits`.

  Args:
    logits: Unnormalized model outputs; fed to both `tf.nn.softmax` and
      `tf.losses.softmax_cross_entropy`.
    labels: One-hot targets, passed through as `onehot_labels`.

  Returns:
    A `(predictions, cost)` pair: the softmax of `logits` and the loss
    produced by `tf.losses.softmax_cross_entropy`.
  """
  # Tuple elements are evaluated left to right, so the softmax op is still
  # created before the loss op, exactly as before.
  return (
      tf.nn.softmax(logits),
      tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits),
  )


def decay_weights(cost, weight_decay_rate):
  """Returns `cost` plus an L2 weight-decay penalty on all trainable variables.

  Args:
    cost: Loss value the penalty is added to.
    weight_decay_rate: Scale factor applied to the summed `tf.nn.l2_loss` of
      every variable in `tf.trainable_variables()`.

  Returns:
    `cost + weight_decay_rate * add_n([l2_loss(v) for v in trainable vars])`.
  """
  # NOTE(review): tf.add_n presumably rejects an empty list, so this expects
  # at least one trainable variable in the graph -- confirm for new callers.
  l2_terms = [tf.nn.l2_loss(variable) for variable in tf.trainable_variables()]
  penalty = tf.multiply(weight_decay_rate, tf.add_n(l2_terms))
  return cost + penalty


def eval_child_model(session, model, data_loader, mode):
  """Runs `model.eval_op` over a held-out split and returns `model.accuracy`.

  Args:
    session: TensorFlow session used to run the model's ops.
    model: Model exposing `batch_size`, `images`, `labels`, `eval_op` and
      `accuracy`.
    data_loader: DataSet object providing `val_images`/`val_labels` and
      `test_images`/`test_labels`.
    mode: Which split to evaluate: 'val' or 'test'.

  Returns:
    The value of `model.accuracy` after every batch of the split has been fed
    through `model.eval_op`.

  Raises:
    ValueError: if `mode` is neither 'val' nor 'test'.
  """
  if mode == 'val':
    images, labels = data_loader.val_images, data_loader.val_labels
  elif mode == 'test':
    images, labels = data_loader.test_images, data_loader.test_labels
  else:
    raise ValueError('Not valid eval mode')

  batch_size = model.batch_size
  assert len(images) == len(labels)
  tf.logging.info('model.batch_size is {}'.format(batch_size))
  # The split must divide evenly into batches; no partial final batch is fed.
  assert len(images) % batch_size == 0
  num_batches = int(len(images) / batch_size)

  for batch_index in range(num_batches):
    start = batch_index * batch_size
    stop = start + batch_size
    batch_images = images[start:stop]
    batch_labels = labels[start:stop]
    session.run(
        model.eval_op,
        feed_dict={
            model.images: batch_images,
            model.labels: batch_labels,
        })
  return session.run(model.accuracy)


def cosine_lr(learning_rate, epoch, iteration, batches_per_epoch, total_epochs):
  """Cosine-annealed learning rate for a given point in training.

  The global position is `epoch * batches_per_epoch + iteration`, and the
  rate is `0.5 * learning_rate * (1 + cos(pi * position / total_steps))`,
  where `total_steps = total_epochs * batches_per_epoch`.

  Args:
    learning_rate: Initial (peak) learning rate.
    epoch: Number of epochs already completed, i.e. the index of the current
      epoch. NOTE(review): the previous docstring called this "one based",
      which does not match the formula above; confirm against callers.
    iteration: Current batch within this epoch.
    batches_per_epoch: Batches per epoch.
    total_epochs: Total epochs you are training for.

  Returns:
    The learning rate to be used for the current batch.
  """
  total_steps = total_epochs * batches_per_epoch
  current_step = float(epoch * batches_per_epoch + iteration)
  cosine = np.cos(np.pi * current_step / total_steps)
  return 0.5 * learning_rate * (1 + cosine)


def get_lr(curr_epoch, hparams, iteration=None):
  """Learning rate for batch `iteration` of epoch `curr_epoch`.

  Args:
    curr_epoch: Epoch index, forwarded to `cosine_lr` as `epoch`.
    hparams: Hyperparameters; `train_size`, `batch_size`, `lr` and
      `num_epochs` are read.
    iteration: Batch index within the epoch. Despite the `None` default this
      is required: `None` fails the assertion below.

  Returns:
    The cosine-schedule learning rate computed by `cosine_lr`.
  """
  assert iteration is not None
  batches_per_epoch = int(hparams.train_size / hparams.batch_size)
  return cosine_lr(
      learning_rate=hparams.lr,
      epoch=curr_epoch,
      iteration=iteration,
      batches_per_epoch=batches_per_epoch,
      total_epochs=hparams.num_epochs)


def run_epoch_training(session, model, data_loader, curr_epoch):
  """Runs one epoch of training for the model passed in.

  Args:
    session: TensorFlow session the model will be run with.
    model: TensorFlow model to train; `hparams`, `global_step`, `lr_rate_ph`,
      `train_op`, `eval_op`, `accuracy`, `images` and `labels` are used.
    data_loader: DataSet object whose `next_batch()` returns an
      `(images, labels)` pair for each training step.
    curr_epoch: How many epochs of training have been done so far.

  Returns:
    The value of `model.accuracy` after the epoch (the training accuracy).
  """
  steps_per_epoch = int(model.hparams.train_size / model.hparams.batch_size)
  tf.logging.info('steps per epoch: {}'.format(steps_per_epoch))
  curr_step = session.run(model.global_step)
  # Each call is expected to begin exactly on an epoch boundary.
  assert curr_step % steps_per_epoch == 0

  # Get the current learning rate for the model based on the current epoch
  curr_lr = get_lr(curr_epoch, model.hparams, iteration=0)
  tf.logging.info('lr of {} for epoch {}'.format(curr_lr, curr_epoch))

  # `range` (not `xrange`) so this runs on both Python 2 and Python 3; the
  # file already opts into 2/3 compatibility via its __future__ imports.
  for step in range(steps_per_epoch):
    # `iteration` is 1-based here (step + 1).
    curr_lr = get_lr(curr_epoch, model.hparams, iteration=(step + 1))
    # Update the lr rate variable to the current LR.
    model.lr_rate_ph.load(curr_lr, session=session)
    if step % 20 == 0:
      tf.logging.info('Training {}/{}'.format(step, steps_per_epoch))

    train_images, train_labels = data_loader.next_batch()
    # The fetched global step is unused; discard it rather than rebinding the
    # loop variable `step` as the previous code did.
    session.run(
        [model.train_op, model.global_step, model.eval_op],
        feed_dict={
            model.images: train_images,
            model.labels: train_labels,
        })

  train_accuracy = session.run(model.accuracy)
  tf.logging.info('Train accuracy: {}'.format(train_accuracy))
  return train_accuracy