# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Class Declaration of Transformer's Training Subprocess."""
import collections
import logging
import math
import six
from chainer import cuda
from chainer import functions as F
from chainer import training
from chainer.training import extension
from chainer.training.updaters.multiprocess_parallel_updater import gather_grads
from chainer.training.updaters.multiprocess_parallel_updater import gather_params
from chainer.training.updaters.multiprocess_parallel_updater import scatter_grads
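# gather_grads / gather_params flatten a model's gradients / parameters into a
# single array for the NCCL collectives used below; scatter_grads writes the
# flattened gradients back into the model.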
import numpy as np
# copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py
def sum_sqnorm(arr):
"""Calculate the norm of the array.
Args:
arr (numpy.ndarray)
Returns:
Float: Sum of the norm calculated from the given array.
"""
sq_sum = collections.defaultdict(float)
for x in arr:
with cuda.get_device_from_array(x) as dev:
if x is not None:
x = x.ravel()
s = x.dot(x)
sq_sum[int(dev)] += s
return sum([float(i) for i in six.itervalues(sq_sum)])
class CustomUpdater(training.StandardUpdater):
"""Custom updater for chainer.
Args:
train_iter (iterator | dict[str, iterator]): Dataset iterator for the
training dataset. It can also be a dictionary that maps strings to
iterators. If this is just an iterator, then the iterator is
registered by the name ``'main'``.
optimizer (optimizer | dict[str, optimizer]): Optimizer to update
parameters. It can also be a dictionary that maps strings to
optimizers. If this is just an optimizer, then the optimizer is
registered by the name ``'main'``.
converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
function to build input arrays. Each batch extracted by the main
iterator and the ``device`` option are passed to this function.
:func:`chainer.dataset.concat_examples` is used by default.
        device (int or dict): Destination device to which variables are sent.
            Use ``-1`` for CPU and the GPU id (e.g. ``0``) for a single GPU.
            For multi-GPU training, pass a dict such as
            ``device={"main": 0, "sub_1": 1, ...}``.
        accum_grad (int): Number of gradient-accumulation steps. If set to 2,
            the parameters are updated once every two forward/backward passes,
            i.e. the effective batch size is doubled.
"""
def __init__(self, train_iter, optimizer, converter, device, accum_grad=1):
"""Initialize Custom Updater."""
super(CustomUpdater, self).__init__(
train_iter, optimizer, converter=converter, device=device
)
self.accum_grad = accum_grad
self.forward_count = 0
self.start = True
self.device = device
logging.debug("using custom converter for transformer")
# The core part of the update routine can be customized by overriding.
def update_core(self):
"""Process main update routine for Custom Updater."""
train_iter = self.get_iterator("main")
optimizer = self.get_optimizer("main")
# Get batch and convert into variables
batch = train_iter.next()
x = self.converter(batch, self.device)
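        # Clear stale gradients once before the first accumulation cycle begins.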
if self.start:
optimizer.target.cleargrads()
self.start = False
# Compute the loss at this time step and accumulate it
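        # (dividing by accum_grad makes the accumulated gradient equal that of
        # the loss averaged over accum_grad mini-batches)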
loss = optimizer.target(*x) / self.accum_grad
loss.backward() # Backprop
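        # Update the parameters only once every accum_grad backward passes;
        # until then, keep accumulating gradients.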
self.forward_count += 1
if self.forward_count != self.accum_grad:
return
self.forward_count = 0
# compute the gradient norm to check if it is normal or not
grad_norm = np.sqrt(
sum_sqnorm([p.grad for p in optimizer.target.params(False)])
)
logging.info("grad norm={}".format(grad_norm))
if math.isnan(grad_norm):
logging.warning("grad norm is nan. Do not update model.")
else:
optimizer.update()
optimizer.target.cleargrads() # Clear the parameter gradients
def update(self):
"""Update step for Custom Updater."""
self.update_core()
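        # Advance the iteration counter only when a full gradient-accumulation
        # cycle has completed.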
if self.forward_count == 0:
self.iteration += 1
class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater):
"""Custom Parallel Updater for chainer.
Defines the main update routine.
Args:
        train_iters (list of iterator): List of dataset iterators for the
            training dataset, one per worker device. The first iterator is
            used by the master process and registered by the name ``'main'``.
optimizer (optimizer | dict[str, optimizer]): Optimizer to update
parameters. It can also be a dictionary that maps strings to
optimizers. If this is just an optimizer, then the optimizer is
registered by the name ``'main'``.
converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
function to build input arrays. Each batch extracted by the main
iterator and the ``device`` option are passed to this function.
:func:`chainer.dataset.concat_examples` is used by default.
        devices (dict): Dictionary mapping process names to GPU device ids,
            e.g. ``devices={"main": 0, "sub_1": 1, ...}``. The ``"main"``
            device hosts the master model.
        accum_grad (int): Number of gradient-accumulation steps. If set to 2,
            the parameters are updated once every two forward/backward passes,
            i.e. the effective batch size is doubled.
"""
def __init__(self, train_iters, optimizer, converter, devices, accum_grad=1):
"""Initialize custom parallel updater."""
from cupy.cuda import nccl
super(CustomParallelUpdater, self).__init__(
train_iters, optimizer, converter=converter, devices=devices
)
self.accum_grad = accum_grad
self.forward_count = 0
self.nccl = nccl
logging.debug("using custom parallel updater for transformer")
# The core part of the update routine can be customized by overriding.
def update_core(self):
"""Process main update routine for Custom Parallel Updater."""
self.setup_workers()
self._send_message(("update", None))
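        # The worker processes react to this message by running forward/backward
        # on their own sub-batches; their gradients are summed with the master's
        # in the NCCL reduce below.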
with cuda.Device(self._devices[0]):
# For reducing memory
optimizer = self.get_optimizer("main")
batch = self.get_iterator("main").next()
x = self.converter(batch, self._devices[0])
loss = self._master(*x) / self.accum_grad
loss.backward()
# NCCL: reduce grads
null_stream = cuda.Stream.null
if self.comm is not None:
gg = gather_grads(self._master)
self.comm.reduce(
gg.data.ptr,
gg.data.ptr,
gg.size,
self.nccl.NCCL_FLOAT,
self.nccl.NCCL_SUM,
0,
null_stream.ptr,
)
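                # The reduce has summed the gradients of all devices into the
                # master (root 0); write them back into the master model.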
scatter_grads(self._master, gg)
del gg
# update parameters
self.forward_count += 1
if self.forward_count != self.accum_grad:
return
self.forward_count = 0
# check gradient value
grad_norm = np.sqrt(
sum_sqnorm([p.grad for p in optimizer.target.params(False)])
)
logging.info("grad norm={}".format(grad_norm))
# update
if math.isnan(grad_norm):
logging.warning("grad norm is nan. Do not update model.")
else:
optimizer.update()
self._master.cleargrads()
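            # Broadcast the updated master parameters to every worker device so
            # that all model replicas stay in sync.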
if self.comm is not None:
gp = gather_params(self._master)
self.comm.bcast(
gp.data.ptr, gp.size, self.nccl.NCCL_FLOAT, 0, null_stream.ptr
)
def update(self):
"""Update step for Custom Parallel Updater."""
self.update_core()
if self.forward_count == 0:
self.iteration += 1
class VaswaniRule(extension.Extension):
"""Trainer extension to shift an optimizer attribute magically by Vaswani.
Args:
attr (str): Name of the attribute to shift.
rate (float): Rate of the exponential shift. This value is multiplied
to the attribute at each call.
init (float): Initial value of the attribute. If it is ``None``, the
extension extracts the attribute at the first call and uses it as
the initial value.
target (float): Target value of the attribute. If the attribute reaches
this value, the shift stops.
optimizer (~chainer.Optimizer): Target optimizer to adjust the
attribute. If it is ``None``, the main optimizer of the updater is
used.
"""
def __init__(
self,
attr,
d,
warmup_steps=4000,
init=None,
target=None,
optimizer=None,
scale=1.0,
):
"""Initialize Vaswani rule extension."""
self._attr = attr
self._d_inv05 = d ** (-0.5) * scale
self._warmup_steps_inv15 = warmup_steps ** (-1.5)
self._init = init
self._target = target
self._optimizer = optimizer
self._t = 0
self._last_value = None
def initialize(self, trainer):
"""Initialize Optimizer values."""
optimizer = self._get_optimizer(trainer)
# ensure that _init is set
if self._init is None:
self._init = self._d_inv05 * (1.0 * self._warmup_steps_inv15)
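            # i.e. the schedule value at step t = 1 (the warmup branch of the min)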
if self._last_value is not None: # resuming from a snapshot
self._update_value(optimizer, self._last_value)
else:
self._update_value(optimizer, self._init)
def __call__(self, trainer):
"""Forward extension."""
self._t += 1
optimizer = self._get_optimizer(trainer)
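        # Linear warmup for the first warmup_steps calls, then decay
        # proportional to the inverse square root of the step count.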
value = self._d_inv05 * min(
self._t ** (-0.5), self._t * self._warmup_steps_inv15
)
self._update_value(optimizer, value)
def serialize(self, serializer):
"""Serialize extension."""
self._t = serializer("_t", self._t)
self._last_value = serializer("_last_value", self._last_value)
def _get_optimizer(self, trainer):
"""Obtain optimizer from trainer."""
return self._optimizer or trainer.updater.get_optimizer("main")
def _update_value(self, optimizer, value):
"""Update requested variable values."""
setattr(optimizer, self._attr, value)
self._last_value = value
class CustomConverter(object):
"""Custom Converter.
Args:
subsampling_factor (int): The subsampling factor.
"""
def __init__(self):
"""Initialize subsampling."""
pass
def __call__(self, batch, device):
"""Perform subsampling.
Args:
batch (list): Batch that will be sabsampled.
device (chainer.backend.Device): CPU or GPU device.
Returns:
chainer.Variable: xp.array that are padded and subsampled from batch.
xp.array: xp.array of the length of the mini-batches.
chainer.Variable: xp.array that are padded and subsampled from batch.
"""
# For transformer, data is processed in CPU.
# batch should be located in list
assert len(batch) == 1
xs, ys = batch[0]
        # get batch of lengths of the original (unpadded) input sequences
        ilens = np.array([x.shape[0] for x in xs], dtype=np.int32)
        # pad the input sequences to a common length with the padding value -1
        xs = F.pad_sequence(xs, padding=-1).data
return xs, ilens, ys