# File size: 6,633 Bytes
# Revision: c6e7238
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
def clip_by_global_norm(grads, clip_norm):
    """Rescale ``grads`` so that their joint L2 norm does not exceed ``clip_norm``.

    Args:
      grads: list of mtf Tensors; ``None`` entries are ignored when measuring
        the norm and are returned as ``None``.
      clip_norm: threshold for the global norm (divided by
        ``mtf.maximum(global_norm, clip_norm)``).

    Returns:
      A tuple ``(clipped_grads, global_norm)``, where ``global_norm`` is the norm
      measured before clipping.
    """
    present = [g for g in grads if g is not None]
    squared_norms = [mtf.reduce_sum(mtf.square(g)) for g in present]
    global_norm = mtf.sqrt(mtf.add_n(squared_norms))
    # 1 when global_norm <= clip_norm, otherwise clip_norm / global_norm.
    scale = clip_norm / mtf.maximum(global_norm, clip_norm)
    clipped = []
    for g in grads:
        clipped.append(g if g is None else g * scale)
    return clipped, global_norm
def _learning_rate_schedule(params, global_step, dtype):
    """Build the TF (not mtf) scalar learning-rate tensor for the current step."""
    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=dtype)

    lr_decay = params["lr_decay"]
    if lr_decay in ("linear", "cosine"):
        # Decrease LR to final lr (lr*0.1) by this step. "lr_decay_end" defaults
        # to "train_steps"; the fallback key is read only when it is needed.
        if "lr_decay_end" in params:
            end_step = params["lr_decay_end"]
        else:
            end_step = params["train_steps"]

        if lr_decay == "linear":
            learning_rate = tf.train.polynomial_decay(
                learning_rate,
                global_step,
                end_step,
                end_learning_rate=params["lr"] * 0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
                power=1.0,
                cycle=False)
        else:
            learning_rate = tf.train.cosine_decay(
                learning_rate,
                global_step,
                end_step,
                alpha=0.1  # Alpha is min lr value as a fraction of init lr.
            )
    # Any other lr_decay value leaves the learning rate undecayed.

    if params["warmup_steps"] > 0:
        # While global_step < warmup_steps, scale the (possibly decayed) rate by
        # global_step / warmup_steps. Afterwards, use that rate unchanged.
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)
        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)
        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)
    return learning_rate


def _create_optimizer(params, learning_rate, variable_dtype):
    """Instantiate the optimizer selected by ``params["opt_name"]``.

    "adam" (case-insensitive) selects AdamWeightDecayOptimizer, driven by the
    scheduled ``learning_rate``. Any other name selects mtf's AdafactorOptimizer.
    """
    if params["opt_name"].lower() == "adam":
        return AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype
        )
    # NOTE(review): Adafactor receives the constant params["lr"]. The warmup/decay
    # schedule (which get_optimizer returns and logs as "lr") therefore does not
    # reach it. params["weight_decay"] is also passed as Adafactor's `decay_rate`
    # argument. Both look unintended; kept unchanged pending confirmation.
    return mtf.optimize.AdafactorOptimizer(
        learning_rate=params["lr"],
        decay_rate=params["weight_decay"],
        beta1=params["beta1"],
        epsilon1=params["ada_epsilon1"],
        epsilon2=params["ada_epsilon2"]
    )


def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Create the optimizer update ops for every trainable variable in ``mesh``.

    Steps:
      1. Compute gradients, or use ``inp_var_grads``.
      2. Cast the gradients to ``variable_dtype.slice_dtype``.
      3. Build the learning-rate schedule: optional linear/cosine decay to 10%
         of "lr", then optional linear warmup.
      4. Optionally clip by global norm.
      5. Apply the gradients with the configured optimizer.

    Args:
      mesh: mtf Mesh whose graph holds the trainable variables.
      loss: mtf Tensor to differentiate. Only used when ``inp_var_grads`` is None.
      params: config dict. Keys read: "lr", "lr_decay", "warmup_steps",
        "gradient_clipping", "opt_name", "weight_decay", and "beta1".
        Also reads "lr_decay_end" or "train_steps" when lr_decay is
        "linear"/"cosine"; "beta2" and "epsilon" for adam; and "ada_epsilon1"
        and "ada_epsilon2" otherwise.
      variable_dtype: its ``slice_dtype`` is used for the learning rate, the
        clip value, and the gradient casts. It is also passed to
        AdamWeightDecayOptimizer.
      inp_var_grads: optional precomputed gradients, in the same order as
        ``mesh.graph.trainable_variables``.

    Returns:
      ``(learning_rate, update_ops, var_grads_fp)``:
        - ``learning_rate``: the scheduled learning rate imported as a scalar
          mtf Tensor.
        - ``update_ops``: the ops returned by ``optimizer.apply_grads``.
        - ``var_grads_fp``: the cast (and possibly clipped) gradients.
    """
    dtype = variable_dtype.slice_dtype
    global_step = tf.train.get_or_create_global_step()

    if inp_var_grads is None:
        var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision. Entries may be None (presumably for variables the
    # loss does not depend on). Keep them as None instead of passing them to
    # mtf.cast: clip_by_global_norm and AdamWeightDecayOptimizer.apply_grad
    # both skip None gradients.
    var_grads_fp = [None if g is None else mtf.cast(g, dtype) for g in var_grads]

    learning_rate = _learning_rate_schedule(params, global_step, dtype)
    learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    optimizer = _create_optimizer(params, learning_rate, variable_dtype)

    if params["gradient_clipping"] is not None:
        # Build the clip constant only when clipping is enabled, so a None
        # setting never reaches mtf.constant.
        clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=dtype)
        var_grads_fp, _ = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp
class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
    """Adam optimizer with decoupled ("correct") L2 weight decay.

    The weight-decay term is added directly to the Adam update direction
    instead of to the loss, so it never passes through the first/second moment
    estimates. The moment estimates are used without bias correction.
    """

    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.0,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 variable_dtype=None):
        """Store the optimizer hyperparameters.

        Args:
          learning_rate: scalar multiplied into every update.
          weight_decay_rate: coefficient on the parameter value added to the
            update. A falsy value disables weight decay entirely.
          beta_1: decay rate of the first-moment (mean) estimate.
          beta_2: decay rate of the second-moment (squared) estimate.
          epsilon: added to sqrt(second moment) before dividing.
          exclude_from_weight_decay: optional list of regex patterns. Variables
            whose name matches any of them (via ``re.search``) get no weight decay.
          variable_dtype: kept on the instance; not otherwise used by this class.
        """
        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay
        self.variable_dtype = variable_dtype

    def apply_grad(self, grad, var):
        """Return the assignment ops for one AdamW step on ``var``.

        If ``grad`` is None, logs a warning and returns an empty list. Otherwise
        obtains the non-trainable slot variables ``<var.name>/adam_m`` and
        ``<var.name>/adam_v``. It then returns assignments for, in order: the
        variable itself, the first moment, and the second moment.
        """
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []

        grad = mtf.to_float(grad)
        mesh, shape = var.mesh, var.shape

        # Slot variables. No dtype arguments are passed, so mtf.get_variable's
        # defaults apply rather than anything derived from self.variable_dtype.
        first_moment = mtf.get_variable(
            mesh, var.name + "/adam_m", shape,
            initializer=tf.zeros_initializer(),
            trainable=False)
        second_moment = mtf.get_variable(
            mesh, var.name + "/adam_v", shape,
            initializer=tf.zeros_initializer(),
            trainable=False)

        # Plain Adam moment updates and step direction.
        new_m = self.beta_1 * first_moment + (1.0 - self.beta_1) * grad
        new_v = self.beta_2 * second_moment + (1.0 - self.beta_2) * mtf.square(grad)
        step = new_m / (mtf.sqrt(new_v) + self.epsilon)

        # Adding squared weights to the loss would route the decay through m/v
        # and interact badly with Adam. Adding the decay term to the step
        # instead keeps it independent of the moment estimates. This matches
        # L2 regularisation under plain SGD.
        if self._do_use_weight_decay(var.name):
            step += mtf.to_float(var.value) * self.weight_decay_rate

        return [
            mtf.assign_sub(var, self.learning_rate * step),
            mtf.assign(first_moment, new_m),
            mtf.assign(second_moment, new_v),
        ]

    def _do_use_weight_decay(self, param_name):
        """Return True if weight decay should be applied to ``param_name``."""
        if not self.weight_decay_rate:
            return False
        patterns = self.exclude_from_weight_decay or []
        return not any(re.search(p, param_name) is not None for p in patterns)