OFA-Image_Caption / fairseq /fairseq /optim /dynamic_loss_scaler.py
JustinLin610
update
8437114
raw history blame
No virus
2.64 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
class DynamicLossScaler(object):
    """Maintain a multiplicative loss scale that adapts to gradient overflows.

    The scale is multiplied by ``scale_factor`` by :meth:`update` at regular
    intervals and divided by ``scale_factor`` by :meth:`check_overflow` when
    the supplied gradient norm is ``inf`` or ``nan``.

    Args:
        init_scale: initial value of ``loss_scale``.
        scale_factor: factor used both to grow and to shrink ``loss_scale``.
        scale_window: :meth:`update` grows the scale whenever the number of
            iterations since the last recorded overflow is a multiple of
            this value.
        tolerance: fraction of overflowing iterations (since the last
            rescale) that must be reached before the scale is decreased.
            With the default of ``0.0`` every overflow decreases the scale.
        threshold: optional lower bound applied to ``loss_scale`` each time
            it is decreased; ``None`` disables the bound.
        min_loss_scale: :meth:`check_overflow` raises ``FloatingPointError``
            when an overflow leaves the scale at or below this value.
    """

    def __init__(
        self,
        init_scale=2.0 ** 15,
        scale_factor=2.0,
        scale_window=2000,
        tolerance=0.0,
        threshold=None,
        min_loss_scale=1e-4,
    ):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold
        # Iteration counter; advanced by update() and by an overflow that
        # ends in OverflowError.
        self._iter = 0
        # Iteration at which the last overflow / last scale change happened.
        # Both start at -1 so the first iteration is one step away from them.
        self._last_overflow_iter = -1
        self._last_rescale_iter = -1
        # Overflows seen since the scale was last decreased.
        self._overflows_since_rescale = 0
        self.min_loss_scale = min_loss_scale

    def scale(self, outputs):
        """Return ``outputs`` multiplied by the current ``loss_scale``."""
        return self.loss_scale * outputs

    def update(self):
        """Advance one iteration, growing the scale when a window elapses.

        The scale is multiplied by ``scale_factor`` when the distance from
        the last recorded overflow is a multiple of ``scale_window``.
        """
        if (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self._last_rescale_iter = self._iter
        self._iter += 1

    def _decrease_loss_scale(self):
        """Divide the scale by ``scale_factor``, clamped to ``threshold``."""
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)

    def check_overflow(self, grad_norm):
        """Inspect ``grad_norm`` and react if it is ``inf`` or ``nan``.

        Returns ``None`` without touching any state when ``grad_norm`` is
        neither ``inf`` nor ``nan``.

        Raises:
            FloatingPointError: the overflow leaves the scale at or below
                ``min_loss_scale``. The scaler is restored to exactly the
                state it had before the call, so the check can be repeated
                or followed by :meth:`update` without side effects.
            OverflowError: any other overflow. The overflow has been
                recorded, the scale possibly decreased, and the iteration
                counter advanced; the message carries the current scale.
        """
        # detect inf and nan (nan is the only value not equal to itself)
        if grad_norm == float("inf") or grad_norm != grad_norm:
            # Overflow has occurred. Snapshot every field mutated below so
            # that the fatal path can undo all of them, not just the scale.
            prev_scale = self.loss_scale
            prev_last_overflow_iter = self._last_overflow_iter
            prev_last_rescale_iter = self._last_rescale_iter
            prev_overflows = self._overflows_since_rescale

            iter_since_rescale = self._iter - self._last_rescale_iter

            self._last_overflow_iter = self._iter
            self._overflows_since_rescale += 1
            pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
            if pct_overflow >= self.tolerance:
                self._decrease_loss_scale()
                self._last_rescale_iter = self._iter
                self._overflows_since_rescale = 0

            if self.loss_scale <= self.min_loss_scale:
                # Use FloatingPointError as an uncommon error that parent
                # functions can safely catch to stop training.
                #
                # _iter is not advanced on this path, so leaving
                # _last_rescale_iter == _iter would make the next overflow
                # check divide by zero, and leaving _last_overflow_iter ==
                # _iter would make the next update() grow the scale at once.
                # Roll everything back together with the scale.
                self.loss_scale = prev_scale
                self._last_overflow_iter = prev_last_overflow_iter
                self._last_rescale_iter = prev_last_rescale_iter
                self._overflows_since_rescale = prev_overflows
                raise FloatingPointError(
                    (
                        "Minimum loss scale reached ({}). Your loss is probably exploding. "
                        "Try lowering the learning rate, using gradient clipping or "
                        "increasing the batch size."
                    ).format(self.min_loss_scale)
                )

            self._iter += 1
            raise OverflowError("setting loss scale to: " + str(self.loss_scale))