# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains code to create a Trainer for training and validation."""

from typing import Dict, Any, Text
import orbit
import tensorflow as tf

from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model import utils
from deeplab2.trainer import runner_utils


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               name=None):
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements linear warmup, i.e., if global_step < warmup_steps, the
      # learning rate will be `global_step / warmup_steps * init_lr`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = self.initial_learning_rate * warmup_percent_done
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self):
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'name': self.name
    }
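

# Illustrative usage of `WarmUp` (a hedged sketch; the concrete values below
# are assumptions for demonstration only, not taken from any DeepLab2 config):
#
#   decay_fn = tf.keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=1e-3, decay_steps=60000,
#       end_learning_rate=0.0, power=0.9)
#   lr_schedule = WarmUp(initial_learning_rate=1e-3, decay_schedule_fn=decay_fn,
#                        warmup_steps=2000, name='linear_warmup')
#   lr_schedule(1000)   # 0.5 * 1e-3, i.e. halfway through the linear warmup.
#   lr_schedule(30000)  # Follows the wrapped polynomial decay schedule.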


def _create_optimizer(
    solver_config: config_pb2.SolverOptions,
    learning_rate_multiplier: float = 1.0) -> tf.keras.optimizers.Optimizer:
  """Creates an Optimizer based on the configuration.

  Args:
    solver_config: A trainer_pb2.SolverOptions configuration.
    learning_rate_multiplier: A float, the learning rate multiplier applied on
      top of the base learning rate. Defaults to 1.0.

  Returns:
    A tf.keras.optimizers.Optimizer.

  Raises:
    ValueError: An error occurs when the desired optimizer or learning rate
      scheduler is not supported.
  """
  learning_rate = (solver_config.base_learning_rate * learning_rate_multiplier)
  if solver_config.learning_policy == 'poly':
    lr_scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=learning_rate,
        decay_steps=solver_config.training_number_of_steps,
        end_learning_rate=solver_config.poly_end_learning_rate,
        power=solver_config.poly_learning_power,
        cycle=False)
  elif solver_config.learning_policy == 'cosine':
    lr_scheduler = tf.keras.experimental.CosineDecay(
        initial_learning_rate=learning_rate,
        decay_steps=solver_config.training_number_of_steps,
        alpha=0.0)
  else:
    raise ValueError('Learning rate policy %s is not supported.' %
                     solver_config.learning_policy)

  if solver_config.warmup_steps:
    lr_scheduler = WarmUp(
        initial_learning_rate=learning_rate,
        decay_schedule_fn=lr_scheduler,
        warmup_steps=solver_config.warmup_steps,
        name='linear_warmup')

  if solver_config.optimizer == 'adam':
    return tf.keras.optimizers.Adam(learning_rate=lr_scheduler)
  elif solver_config.optimizer == 'sgd':
    # We use momentum = 0.9, the most commonly used value.
    return tf.keras.optimizers.SGD(learning_rate=lr_scheduler,
                                   momentum=0.9)

  raise ValueError('Optimizer %s is not supported.' % solver_config.optimizer)
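

# Usage note (a sketch of how this function is driven, based on the fields
# read above): the solver options come from the experiment config, e.g.
# `base_learning_rate`, `learning_policy` ('poly' or 'cosine'),
# `training_number_of_steps`, optional `warmup_steps`, and `optimizer`
# ('adam' or 'sgd'). The Trainer below creates its optimizers as:
#
#   solver_options = config.trainer_options.solver_options
#   optimizer = _create_optimizer(solver_options)
#   # Only when `backbone_learning_rate_multiplier` is set:
#   backbone_optimizer = _create_optimizer(
#       solver_options,
#       learning_rate_multiplier=(
#           solver_options.backbone_learning_rate_multiplier))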


class Trainer(orbit.StandardTrainer):
  """Implements a Trainer for training DeepLab models."""

  def __init__(self, config: config_pb2.ExperimentOptions,
               model: tf.keras.Model, loss: tf.keras.losses.Loss,
               global_step: tf.Variable):
    """Initializes the trainer.

    Args:
      config: A config_pb2.ExperimentOptions configuration.
      model: A tf.keras.Model.
      loss: A tf.keras.losses.Loss.
      global_step: A tf.Variable that records the global training step.
    """
    self._strategy = tf.distribute.get_strategy()

    support_panoptic = (common.TASK_PANOPTIC_SEGMENTATION in
                        utils.get_supported_tasks(config))
    train_dataset = runner_utils.create_dataset(
        config.train_dataset_options,
        is_training=True,
        only_semantic_annotations=not support_panoptic)
    train_dataset = orbit.utils.make_distributed_dataset(
        self.strategy, train_dataset)
    super(Trainer, self).__init__(train_dataset)

    self._config = config
    self._model = model
    self._loss = loss

    solver_options = config.trainer_options.solver_options
    self._optimizer = _create_optimizer(solver_options)
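    # When `backbone_learning_rate_multiplier` is set in the solver options, a
    # second optimizer with a scaled learning rate is created for the
    # pretrained backbone variables; `_apply_gradients_to_optimizers` routes
    # each gradient to the matching optimizer.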
    self._backbone_optimizer = None
    if solver_options.HasField('backbone_learning_rate_multiplier'):
      self._backbone_optimizer = _create_optimizer(
          solver_options, learning_rate_multiplier=(
              solver_options.backbone_learning_rate_multiplier))

    self._global_step = global_step
    self._use_gradient_clipping = solver_options.use_gradient_clipping
    self._clip_gradient_norm = solver_options.clip_gradient_norm

    self._train_loss_metric_dict = runner_utils.create_loss_metric_dict(
        loss.get_loss_names(), prefix='train_')

  def train_loop_begin(self):
    """Called once at the beginning of the training loop.

    This method is called before the dataset iterators are created.
    """
    for metric in self._train_loss_metric_dict.values():
      metric.reset_states()

  def _apply_gradients_to_optimizers(self, gradients_and_variables):
    """Applies gradients to their optimizers.

    This function divides all trainable variables (and their gradients) into
    two groups. One group contains backbone variables that have been pretrained,
    e.g., on ImageNet classification. The other group contains all other
    variables that are added specifically for the dense prediction task, e.g.,
    panoptic segmentation. Then, we apply two optimizers, optionally with two
    learning rates, to the variables and gradients.

    Args:
      gradients_and_variables: A list of (gradient, variable) tuples.
    """
    if self._backbone_optimizer is None:
      self._optimizer.apply_gradients(gradients_and_variables)
    else:
      optimizer_inputs = []
      backbone_optimizer_inputs = []

      encoder = self._model.checkpoint_items['encoder']
      encoder_variable_names = [x.name for x in encoder.trainable_variables]
      encoder_name = self._config.model_options.backbone.name

      for gradient, variable in gradients_and_variables:
        if runner_utils.check_if_variable_in_backbone(variable, encoder_name,
                                                      encoder_variable_names):
          backbone_optimizer_inputs.append((gradient, variable))
        else:
          optimizer_inputs.append((gradient, variable))
      self._optimizer.apply_gradients(optimizer_inputs)
      self._backbone_optimizer.apply_gradients(backbone_optimizer_inputs)

  def train_step(self, iterator):
    """Implements one step of training.

    Runs one step of training with respect to the chosen strategy. In case of
    a distributed strategy, the replica results are gathered and returned.

    Note that all operations within `_train_step` are tf.function compatible,
    as they will be traced with tf.function. Any other (e.g., numpy) operations
    should be put in the `train_loop_begin` or `train_loop_end` functions.

    Args:
      iterator: A tf.nest-compatible structure of tf.data Iterator or
        DistributedIterator.
    """

    def step_fn(inputs):
      self._train_step(inputs)
      self._global_step.assign_add(1)

    self._strategy.run(step_fn, args=(next(iterator),))

  def _train_step(self, inputs: Dict[Text, Any]):
    """Performs a forward and backward pass.

    Args:
      inputs: A dictionary to be consumed by the model.
    """
    with tf.GradientTape() as tape:
      outputs = self._model(inputs[common.IMAGE], training=True)
      # Get the average per-batch loss and scale it down by the number of
      # replicas. This ensures that we don't end up multiplying our loss by the
      # number of workers - gradients are summed, not averaged, across replicas
      # during the apply_gradients call.
      loss_dict = self._loss(inputs, outputs)
      # Average over the batch.
      average_loss_dict = {
          key: tf.reduce_mean(value) for key, value in loss_dict.items()}
      total_loss = average_loss_dict[common.TOTAL_LOSS]
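      # For example, with 8 replicas each replica contributes total_loss / 8,
      # so summing the per-replica gradients during apply_gradients yields the
      # gradient of the cross-replica mean loss rather than 8x that value.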
      scaled_loss = total_loss / self.strategy.num_replicas_in_sync

    training_vars = self._model.trainable_variables
    gradients = tape.gradient(scaled_loss, training_vars)

    # Apply gradient clipping.
    if self._clip_gradient_norm > 0.0 and self._use_gradient_clipping:
      gradients, _ = tf.clip_by_global_norm(gradients, self._clip_gradient_norm)

    self._apply_gradients_to_optimizers(list(zip(gradients, training_vars)))

    for name, value in average_loss_dict.items():
      self._train_loss_metric_dict[name].update_state(value)

  def train_loop_end(self) -> Dict[Text, tf.Tensor]:
    """Called at the end of the training loop.

    The value returned from this function will be returned as-is from the
    train() method.

    Returns:
      A dictionary of `Tensors`, which will be written to logs and exported as
      TensorBoard summaries.
    """
    train_logs = {}
    for loss_metric in self._train_loss_metric_dict.values():
      train_logs['losses/' + loss_metric.name] = loss_metric.result()

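    # `learning_rate` is a schedule (a callable such as `WarmUp`) when warmup
    # or decay is configured, so it is evaluated at the current global step
    # for logging; otherwise it is logged as a plain scalar.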
    if callable(self._optimizer.learning_rate):
      train_logs['learning_rate'] = self._optimizer.learning_rate(
          self._global_step)
    else:
      train_logs['learning_rate'] = self._optimizer.learning_rate
    return train_logs

  @property
  def optimizer(self):
    return self._optimizer

  @property
  def backbone_optimizer(self):
    return self._backbone_optimizer

  @property
  def strategy(self):
    return self._strategy

  @property
  def global_step(self):
    return self._global_step

  @property
  def model(self):
    return self._model