|
|
|
|
|
|
|
"""RNN sequence-to-sequence speech recognition model (chainer).""" |
|
|
|
import logging |
|
import math |
|
|
|
import chainer |
|
from chainer import reporter |
|
import numpy as np |
|
|
|
from espnet.nets.chainer_backend.asr_interface import ChainerASRInterface |
|
from espnet.nets.chainer_backend.ctc import ctc_for |
|
from espnet.nets.chainer_backend.rnn.attentions import att_for |
|
from espnet.nets.chainer_backend.rnn.decoders import decoder_for |
|
from espnet.nets.chainer_backend.rnn.encoders import encoder_for |
|
from espnet.nets.e2e_asr_common import label_smoothing_dist |
|
from espnet.nets.pytorch_backend.e2e_asr import E2E as E2E_pytorch |
|
from espnet.nets.pytorch_backend.nets_utils import get_subsample |
|
|
|
# Upper bound on the total training loss for it to be reported; losses at or
# above this value (or NaN) only produce a warning in E2E.forward().
CTC_LOSS_THRESHOLD = 10_000
|
|
|
|
|
class E2E(ChainerASRInterface):
    """E2E module for chainer backend.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (parser.args): Training config.
        flag_return (bool): If True, forward() returns the CTC loss, the
            attention loss and the accuracy in addition to the total loss.

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments.

        The options are shared with the PyTorch RNN E2E model, so this simply
        delegates to ``E2E_pytorch.add_arguments``.
        """
        return E2E_pytorch.add_arguments(parser)

    def get_total_subsampling_factor(self):
        """Get total subsampling factor.

        Returns:
            int: The encoder's ``conv_subsampling_factor`` multiplied by the
                product of the frame subsampling factors in ``self.subsample``.

        """
        return self.enc.conv_subsampling_factor * int(np.prod(self.subsample))

    def __init__(self, idim, odim, args, flag_return=True):
        """Construct an E2E object.

        Args:
            idim (int): Dimension of inputs.
            odim (int): Dimension of outputs. The last index (``odim - 1``) is
                used as both the <sos> and the <eos> id.
            args (Namespace): Argument namespace containing options. This
                method reads ``mtlalpha``, ``etype``, ``verbose``,
                ``char_list``, ``outdir``, ``lsm_type`` and ``train_json``;
                the whole namespace is also passed to the encoder, CTC,
                attention and decoder factories.
            flag_return (bool): If True, forward() returns
                ``(loss, loss_ctc, loss_att, acc)``; otherwise only ``loss``.

        """
        chainer.Chain.__init__(self)
        # Weight of the CTC loss in the multi-task objective.
        self.mtlalpha = args.mtlalpha
        assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir

        # The last output index serves as both the <sos> and the <eos> id.
        self.sos = odim - 1
        self.eos = odim - 1

        # Frame subsampling factors; subsample[0] is also applied directly to
        # the input frames in recognize().
        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # Optional label-smoothing distribution handed to the decoder.
        if args.lsm_type:
            logging.info("Use label smoothing with %s", args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        # Register the sub-networks inside init_scope() so that they become
        # child links of this Chain (their parameters are then tracked).
        with self.init_scope():
            self.enc = encoder_for(args, idim, self.subsample)
            self.ctc = ctc_for(args, odim)
            self.att = att_for(args)
            self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # Most recent accuracy / loss computed by forward().
        self.acc = None
        self.loss = None
        self.flag_return = flag_return

    def forward(self, xs, ilens, ys):
        """E2E forward propagation.

        Args:
            xs (chainer.Variable): Batch of padded input feature sequences
                fed to the encoder (presumably (B, Tmax, idim); TODO confirm
                against the converter).
            ilens (chainer.Variable): Batch of lengths of each input sequence. (B,)
            ys (chainer.Variable): Batch of target character-id sequences.

        Returns:
            chainer.Variable: Total loss,
                ``mtlalpha * loss_ctc + (1 - mtlalpha) * loss_att``.
            chainer.Variable (optional): CTC loss; None when ``mtlalpha == 0``.
            chainer.Variable (optional): Attention loss; None when
                ``mtlalpha == 1``.
            float (optional): Accuracy; None when ``mtlalpha == 1``.

            The three optional values are only returned when ``flag_return``
            is True.

        """
        hs, ilens = self.enc(xs, ilens)

        # Skip the CTC branch entirely for a pure attention model.
        if self.mtlalpha == 0:
            loss_ctc = None
        else:
            loss_ctc = self.ctc(hs, ys)

        # Skip the attention branch entirely for a pure CTC model.
        if self.mtlalpha == 1:
            loss_att = None
            acc = None
        else:
            loss_att, acc = self.dec(hs, ys)

        self.acc = acc
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
        elif alpha == 1:
            self.loss = loss_ctc
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att

        # Report statistics only when the total loss is below the threshold
        # and not NaN; otherwise emit a warning instead.
        if self.loss.data < CTC_LOSS_THRESHOLD and not math.isnan(self.loss.data):
            reporter.report({"loss_ctc": loss_ctc}, self)
            reporter.report({"loss_att": loss_att}, self)
            reporter.report({"acc": acc}, self)

            logging.info("mtl loss:%s", self.loss.data)
            reporter.report({"loss": self.loss}, self)
        else:
            logging.warning("loss (=%f) is not correct", self.loss.data)

        if self.flag_return:
            return self.loss, loss_ctc, loss_att, acc
        return self.loss

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E greedy/beam search.

        Args:
            x (numpy.ndarray): Input features of a single utterance, indexed
                as (frames, feature dim). Presumably a host-side array;
                it is converted with ``self.xp.array`` before encoding.
            recog_args (parser.args): Decoding options. ``ctc_weight`` is read
                here; the namespace is also passed to the decoder's beam search.
            char_list (List[str]): List of characters.
            rnnlm (Module): RNNLM module defined at `espnet.lm.chainer_backend.lm`.

        Returns:
            List[Dict[str, Any]]: Result of recognition.

        """
        # Apply the first subsampling factor directly to the input frames.
        x = x[:: self.subsample[0], :]
        ilen = self.xp.array(x.shape[0], dtype=np.int32)
        h = chainer.Variable(self.xp.array(x, dtype=np.float32))

        with chainer.no_backprop_mode(), chainer.using_config("train", False):
            # The encoder expects a batch: wrap the single utterance in a list.
            h, _ = self.enc([h], [ilen])

            # CTC log-probabilities are only computed when ctc_weight > 0.
            if recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(h).data[0]
            else:
                lpz = None

            # Beam search over the only utterance in the batch.
            y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list, rnnlm)

        return y

    def calculate_all_attentions(self, xs, ilens, ys):
        """E2E attention calculation.

        Args:
            xs (List): List of padded input sequences. [(T1, idim), (T2, idim), ...]
            ilens (np.ndarray): Batch of lengths of input sequences. (B)
            ys (List): List of character id sequence tensor. [(L1), (L2), (L3), ...]

        Returns:
            float np.ndarray: Attention weights. (B, Lmax, Tmax)

        """
        hs, ilens = self.enc(xs, ilens)
        att_ws = self.dec.calculate_all_attentions(hs, ys)

        return att_ws

    @staticmethod
    def custom_converter(subsampling_factor=0):
        """Get customconverter of the model."""
        from espnet.nets.chainer_backend.rnn.training import CustomConverter

        return CustomConverter(subsampling_factor=subsampling_factor)

    @staticmethod
    def custom_updater(iters, optimizer, converter, device=-1, accum_grad=1):
        """Get custom_updater of the model."""
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

        return CustomUpdater(
            iters, optimizer, converter=converter, device=device, accum_grad=accum_grad
        )

    @staticmethod
    def custom_parallel_updater(iters, optimizer, converter, devices, accum_grad=1):
        """Get custom_parallel_updater of the model."""
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater

        return CustomParallelUpdater(
            iters,
            optimizer,
            converter=converter,
            devices=devices,
            accum_grad=accum_grad,
        )
|
|