"""Gradient Accummlate for training TF2 custom training loop.
Copy from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py.
"""
import re
import tensorflow as tf
class GradientAccumulator(object):
    """Gradient accumulation utility.

    When used with a distribution strategy, the accumulator should be called in a
    replica context. Gradients will be accumulated locally on each replica,
    without synchronization. Users should then call ``.gradients``, scale the
    gradients if required, and pass the result to ``apply_gradients``.
    """

    # We use the ON_READ synchronization policy so that no synchronization is
    # performed on assignment. To get the value, we call .value(), which returns
    # the value on the current replica without synchronization.
    def __init__(self):
        """Initializes the accumulator."""
        self._gradients = []
        self._accum_steps = None

    @property
    def step(self):
        """Number of accumulated steps."""
        if self._accum_steps is None:
            self._accum_steps = tf.Variable(
                tf.constant(0, dtype=tf.int64),
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the gradients"
            )
        return list(
            gradient.value() if gradient is not None else gradient
            for gradient in self._gradients
        )

    def __call__(self, gradients):
        """Accumulates :obj:`gradients` on the current replica."""
        if not self._gradients:
            _ = self.step  # Create the step variable.
            self._gradients.extend(
                [
                    tf.Variable(
                        tf.zeros_like(gradient),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ,
                    )
                    if gradient is not None
                    else gradient
                    for gradient in gradients
                ]
            )
        if len(gradients) != len(self._gradients):
            raise ValueError(
                "Expected %d gradients, but got %d"
                % (len(self._gradients), len(gradients))
            )
        for accum_gradient, gradient in zip(self._gradients, gradients):
            if accum_gradient is not None and gradient is not None:
                accum_gradient.assign_add(gradient, read_value=False)
        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        if not self._gradients:
            return
        self._accum_steps.assign(0)
        for gradient in self._gradients:
            if gradient is not None:
                gradient.assign(tf.zeros_like(gradient), read_value=False)
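

# A minimal usage sketch, not part of the original file: it illustrates the
# accumulate / scale / apply_gradients / reset cycle described in the class
# docstring. The model, optimizer, data, and `accum_steps` below are
# hypothetical stand-ins, and the loop assumes eager execution outside a
# distribution strategy.
if __name__ == "__main__":
    accum_steps = 4  # hypothetical number of micro-batches per optimizer update

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.build(input_shape=(None, 8))
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    accumulator = GradientAccumulator()

    for _ in range(8):
        x = tf.random.uniform((2, 8))
        y = tf.random.uniform((2, 1))
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        # Accumulate this micro-batch's gradients locally.
        accumulator(tape.gradient(loss, model.trainable_variables))
        if accumulator.step % accum_steps == 0:
            # Scale the accumulated gradients, apply them, then reset.
            grads = [
                g / tf.cast(accum_steps, g.dtype) if g is not None else g
                for g in accumulator.gradients
            ]
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            accumulator.reset()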