Cletrason committed on
Commit 5304b54
1 Parent(s): 8245979

Create optimization _tf.py

Files changed (1): optimization _tf.py (+371, -0)
optimization _tf.py ADDED
@@ -0,0 +1,371 @@
# Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""


import re
from typing import Callable, List, Optional, Union

import tensorflow as tf


try:
    from tensorflow.keras.optimizers.legacy import Adam
except ImportError:
    from tensorflow.keras.optimizers import Adam


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Applies a warmup schedule on a given learning rate decay schedule.

    Args:
        initial_learning_rate (`float`):
            The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end
            of the warmup).
        decay_schedule_fn (`Callable`):
            The schedule function to apply after the warmup for the rest of training.
        warmup_steps (`int`):
            The number of steps for the warmup part of training.
        power (`float`, *optional*, defaults to 1):
            The power to use for the polynomial warmup (the default is a linear warmup).
        name (`str`, *optional*):
            Optional name prefix for the returned tensors during the schedule.
    """

    def __init__(
        self,
        initial_learning_rate: float,
        decay_schedule_fn: Callable,
        warmup_steps: int,
        power: float = 1.0,
        name: str = None,
    ):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.power = power
        self.decay_schedule_fn = decay_schedule_fn
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "WarmUp") as name:
            # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
            # learning rate will be `global_step/num_warmup_steps * init_lr`.
            global_step_float = tf.cast(step, tf.float32)
            warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
            warmup_percent_done = global_step_float / warmup_steps_float
            warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
            return tf.cond(
                global_step_float < warmup_steps_float,
                lambda: warmup_learning_rate,
                lambda: self.decay_schedule_fn(step - self.warmup_steps),
                name=name,
            )

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_schedule_fn": self.decay_schedule_fn,
            "warmup_steps": self.warmup_steps,
            "power": self.power,
            "name": self.name,
        }


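# Illustrative sketch, not part of the original module: a hypothetical helper showing how
# `WarmUp` composes with a Keras decay schedule. For the first `warmup_steps` steps the
# learning rate ramps up polynomially toward `initial_learning_rate`; afterwards the wrapped
# decay schedule takes over, evaluated at `step - warmup_steps`. The hyperparameters below
# are arbitrary example values.
def _example_warmup_schedule() -> None:
    decay = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=5e-5, decay_steps=9_000, end_learning_rate=0.0
    )
    schedule = WarmUp(initial_learning_rate=5e-5, decay_schedule_fn=decay, warmup_steps=1_000)
    print(schedule(500))     # warmup phase: 5e-5 * (500 / 1000) ** 1.0 == 2.5e-5
    print(schedule(1_000))   # warmup finished: equals decay(0) == 5e-5
    print(schedule(10_000))  # end of the linear decay: equals decay(9_000) == 0.0

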
def create_optimizer(
    init_lr: float,
    num_train_steps: int,
    num_warmup_steps: int,
    min_lr_ratio: float = 0.0,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_epsilon: float = 1e-8,
    adam_clipnorm: Optional[float] = None,
    adam_global_clipnorm: Optional[float] = None,
    weight_decay_rate: float = 0.0,
    power: float = 1.0,
    include_in_weight_decay: Optional[List[str]] = None,
):
    """
    Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay.

    Args:
        init_lr (`float`):
            The desired learning rate at the end of the warmup phase.
        num_train_steps (`int`):
            The total number of training steps.
        num_warmup_steps (`int`):
            The number of warmup steps.
        min_lr_ratio (`float`, *optional*, defaults to 0):
            The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`.
        adam_beta1 (`float`, *optional*, defaults to 0.9):
            The beta1 to use in Adam.
        adam_beta2 (`float`, *optional*, defaults to 0.999):
            The beta2 to use in Adam.
        adam_epsilon (`float`, *optional*, defaults to 1e-8):
            The epsilon to use in Adam.
        adam_clipnorm (`float`, *optional*, defaults to `None`):
            If not `None`, clip the gradient norm for each weight tensor to this value.
        adam_global_clipnorm (`float`, *optional*, defaults to `None`):
            If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all
            weight tensors, as if they were concatenated into a single vector.
        weight_decay_rate (`float`, *optional*, defaults to 0):
            The weight decay to use.
        power (`float`, *optional*, defaults to 1.0):
            The power to use for PolynomialDecay.
        include_in_weight_decay (`List[str]`, *optional*):
            List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
            applied to all parameters except bias and layer norm parameters.
    """
    # Implements linear decay of the learning rate.
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=init_lr,
        decay_steps=num_train_steps - num_warmup_steps,
        end_learning_rate=init_lr * min_lr_ratio,
        power=power,
    )
    if num_warmup_steps:
        lr_schedule = WarmUp(
            initial_learning_rate=init_lr,
            decay_schedule_fn=lr_schedule,
            warmup_steps=num_warmup_steps,
        )
    if weight_decay_rate > 0.0:
        optimizer = AdamWeightDecay(
            learning_rate=lr_schedule,
            weight_decay_rate=weight_decay_rate,
            beta_1=adam_beta1,
            beta_2=adam_beta2,
            epsilon=adam_epsilon,
            clipnorm=adam_clipnorm,
            global_clipnorm=adam_global_clipnorm,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
            include_in_weight_decay=include_in_weight_decay,
        )
    else:
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=lr_schedule,
            beta_1=adam_beta1,
            beta_2=adam_beta2,
            epsilon=adam_epsilon,
            clipnorm=adam_clipnorm,
            global_clipnorm=adam_global_clipnorm,
        )
    # We return the optimizer and the LR scheduler in order to better track the
    # evolution of the LR independently of the optimizer.
    return optimizer, lr_schedule


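# Illustrative sketch, not part of the original module: a hypothetical helper showing the
# intended use of `create_optimizer`. Because `weight_decay_rate > 0`, the returned optimizer
# is an `AdamWeightDecay` instance; it can be passed straight to `model.compile`, while the
# returned schedule can be logged separately to track the learning rate. `model` and the step
# counts are assumptions for the example.
def _example_compile_with_create_optimizer(model: tf.keras.Model):
    optimizer, lr_schedule = create_optimizer(
        init_lr=5e-5,
        num_train_steps=10_000,
        num_warmup_steps=1_000,
        weight_decay_rate=0.01,
    )
    model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")
    # The schedule is returned as well so callers can log the learning rate per step.
    return model, lr_schedule

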
class AdamWeightDecay(Adam):
    """
    Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
    loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
    with the m and v parameters in strange ways as shown in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent
    to adding the square of the weights to the loss with plain (non-momentum) SGD.

    Args:
        learning_rate (`Union[float, tf.keras.optimizers.schedules.LearningRateSchedule]`, *optional*, defaults to 1e-3):
            The learning rate to use or a schedule.
        beta_1 (`float`, *optional*, defaults to 0.9):
            The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates.
        beta_2 (`float`, *optional*, defaults to 0.999):
            The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates.
        epsilon (`float`, *optional*, defaults to 1e-7):
            The epsilon parameter in Adam, which is a small constant for numerical stability.
        amsgrad (`bool`, *optional*, defaults to `False`):
            Whether to apply the AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and
            Beyond](https://arxiv.org/abs/1904.09237).
        weight_decay_rate (`float`, *optional*, defaults to 0):
            The weight decay to apply.
        include_in_weight_decay (`List[str]`, *optional*):
            List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
            applied to all parameters by default (unless they are in `exclude_from_weight_decay`).
        exclude_from_weight_decay (`List[str]`, *optional*):
            List of the parameter names (or re patterns) to exclude from applying weight decay to. If
            `include_in_weight_decay` is passed, the names in it will supersede this list.
        name (`str`, *optional*, defaults to 'AdamWeightDecay'):
            Optional name for the operations created when applying gradients.
        kwargs:
            Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients by
            norm; `clipvalue` clips gradients by value; `decay` is included for backward compatibility to allow time
            inverse decay of the learning rate. `lr` is included for backward compatibility; it is recommended to use
            `learning_rate` instead.
    """

    def __init__(
        self,
        learning_rate: Union[float, tf.keras.optimizers.schedules.LearningRateSchedule] = 0.001,
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        epsilon: float = 1e-7,
        amsgrad: bool = False,
        weight_decay_rate: float = 0.0,
        include_in_weight_decay: Optional[List[str]] = None,
        exclude_from_weight_decay: Optional[List[str]] = None,
        name: str = "AdamWeightDecay",
        **kwargs,
    ):
        super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
        self.weight_decay_rate = weight_decay_rate
        self._include_in_weight_decay = include_in_weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay

    @classmethod
    def from_config(cls, config):
        """Creates an optimizer from its config with WarmUp custom object."""
        custom_objects = {"WarmUp": WarmUp}
        return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects)

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
            self.weight_decay_rate, name="adam_weight_decay_rate"
        )

    def _decay_weights_op(self, var, learning_rate, apply_state):
        do_decay = self._do_use_weight_decay(var.name)
        if do_decay:
            return var.assign_sub(
                learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"],
                use_locking=self._use_locking,
            )
        return tf.no_op()

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        grads, tvars = list(zip(*grads_and_vars))
        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)

    def _get_lr(self, var_device, var_dtype, apply_state):
        """Retrieves the learning rate with the given state."""
        if apply_state is None:
            return self._decayed_lr_t[var_dtype], {}

        apply_state = apply_state or {}
        coefficients = apply_state.get((var_device, var_dtype))
        if coefficients is None:
            coefficients = self._fallback_apply_state(var_device, var_dtype)
            apply_state[(var_device, var_dtype)] = coefficients

        return coefficients["lr_t"], {"apply_state": apply_state}

    def _resource_apply_dense(self, grad, var, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs)

    def get_config(self):
        config = super().get_config()
        config.update({"weight_decay_rate": self.weight_decay_rate})
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.weight_decay_rate == 0:
            return False

        if self._include_in_weight_decay:
            for r in self._include_in_weight_decay:
                if re.search(r, param_name) is not None:
                    return True

        if self._exclude_from_weight_decay:
            for r in self._exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True


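# Illustrative sketch, not part of the original module: a hypothetical helper showing how the
# include/exclude regex patterns steer the decoupled weight decay. The variable names below are
# made-up examples of the pattern matching, not real model weights; `_do_use_weight_decay` is a
# private method and is only called here to make the behaviour visible.
def _example_weight_decay_patterns() -> None:
    opt = AdamWeightDecay(
        learning_rate=5e-5,
        weight_decay_rate=0.01,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
    )
    # Ordinary kernels are decayed...
    assert opt._do_use_weight_decay("encoder/layer_0/output/dense/kernel:0")
    # ...while anything matching an exclude pattern is left untouched.
    assert not opt._do_use_weight_decay("encoder/layer_0/output/LayerNorm/gamma:0")
    assert not opt._do_use_weight_decay("pooler/dense/bias:0")

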
# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
    """
    Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a
    replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should
    then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`.
    """

    # We use the ON_READ synchronization policy so that no synchronization is
    # performed on assignment. To get the value, we call .value() which returns the
    # value on the current replica without synchronization.

    def __init__(self):
        """Initializes the accumulator."""
        self._gradients = []
        self._accum_steps = None

    @property
    def step(self):
        """Number of accumulated steps."""
        if self._accum_steps is None:
            self._accum_steps = tf.Variable(
                tf.constant(0, dtype=tf.int64),
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )

        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError("The accumulator should be called first to initialize the gradients")
        return [gradient.value() if gradient is not None else gradient for gradient in self._gradients]

    def __call__(self, gradients):
        """Accumulates `gradients` on the current replica."""
        if not self._gradients:
            _ = self.step  # Create the step variable.
            self._gradients.extend(
                [
                    tf.Variable(
                        tf.zeros_like(gradient),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                    )
                    if gradient is not None
                    else gradient
                    for gradient in gradients
                ]
            )
        if len(gradients) != len(self._gradients):
            raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}")

        for accum_gradient, gradient in zip(self._gradients, gradients):
            if accum_gradient is not None and gradient is not None:
                accum_gradient.assign_add(gradient)

        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        if not self._gradients:
            return
        self._accum_steps.assign(0)
        for gradient in self._gradients:
            if gradient is not None:
                gradient.assign(tf.zeros_like(gradient))
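

# Illustrative sketch, not part of the original module: a hypothetical eager training loop
# showing the intended `GradientAccumulator` workflow. Gradients are summed over
# `accumulation_steps` micro-batches, then averaged, applied once, and the accumulator is
# reset. `model`, `optimizer`, `loss_fn`, and `batches` are assumptions for the example.
def _example_accumulated_training_loop(
    model: tf.keras.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    loss_fn: Callable,
    batches,
    accumulation_steps: int = 4,
) -> None:
    accumulator = GradientAccumulator()
    for features, labels in batches:
        with tf.GradientTape() as tape:
            loss = loss_fn(labels, model(features, training=True))
        # Accumulate this micro-batch's gradients locally (also bumps `accumulator.step`).
        accumulator(tape.gradient(loss, model.trainable_variables))
        if accumulator.step % accumulation_steps == 0:
            # Scale the summed gradients down before handing them to the optimizer.
            grads = [
                g / tf.cast(accumulation_steps, g.dtype) if g is not None else g
                for g in accumulator.gradients
            ]
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            accumulator.reset()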