|
|
|
|
|
|
|
|
|
""" |
|
Mask CTC based non-autoregressive speech recognition model (pytorch). |
|
|
|
See https://arxiv.org/abs/2005.08700 for details.
|
|
|
""" |
|
|
|
from itertools import groupby |
|
import logging |
|
import math |
|
|
|
from distutils.util import strtobool |
|
import numpy |
|
import torch |
|
|
|
from espnet.nets.pytorch_backend.conformer.encoder import Encoder |
|
from espnet.nets.pytorch_backend.conformer.argument import ( |
|
add_arguments_conformer_common, |
|
) |
|
from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD |
|
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer |
|
from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform |
|
from espnet.nets.pytorch_backend.maskctc.mask import square_mask |
|
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask |
|
from espnet.nets.pytorch_backend.nets_utils import th_accuracy |
|
|
|
|
|
class E2E(E2ETransformer):
    """Mask CTC based non-autoregressive E2E module.

    Reuses the Transformer E2E (encoder/decoder/CTC) but trains the decoder
    on partially masked target sequences and, at inference time, refines a
    greedy CTC result by iteratively re-predicting low-confidence tokens
    (see https://arxiv.org/abs/2005.08700).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments.

        Registers all Transformer options from the parent class first,
        then the Mask CTC specific options.
        """
        E2ETransformer.add_arguments(parser)
        E2E.add_maskctc_arguments(parser)

        return parser

    @staticmethod
    def add_maskctc_arguments(parser):
        """Add arguments for maskctc model."""
        group = parser.add_argument_group("maskctc specific setting")

        # Whether to replace the Transformer encoder with a Conformer encoder.
        group.add_argument(
            "--maskctc-use-conformer-encoder",
            default=False,
            type=strtobool,
        )
        # Conformer options are always registered; they only take effect when
        # the flag above is enabled (see __init__).
        group = add_arguments_conformer_common(group)

        return parser

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: index used to mark padding positions in targets
        """
        # Reserve one extra output symbol for the <mask> token.
        odim += 1

        super().__init__(idim, odim, args, ignore_id)
        # Pure CTC (mtlalpha == 1.0) is disallowed: the masked-token decoder
        # must always be trained. mtlalpha == 0.0 (attention only) is allowed.
        assert 0.0 <= self.mtlalpha < 1.0, "mtlalpha should be [0.0, 1.0)"

        # <mask> takes the last index; <sos>/<eos> share the second-to-last.
        self.mask_token = odim - 1
        self.sos = odim - 2
        self.eos = odim - 2
        self.odim = odim

        if args.maskctc_use_conformer_encoder:
            if args.transformer_attn_dropout_rate is None:
                args.transformer_attn_dropout_rate = args.conformer_dropout_rate
            # Replace the Transformer encoder built by the parent constructor
            # with a Conformer encoder configured from the same Namespace.
            self.encoder = Encoder(
                idim=idim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.eunits,
                num_blocks=args.elayers,
                input_layer=args.transformer_input_layer,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                attention_dropout_rate=args.transformer_attn_dropout_rate,
                pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
                selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
                activation_type=args.transformer_encoder_activation_type,
                macaron_style=args.macaron_style,
                use_cnn_module=args.use_cnn_module,
                cnn_module_kernel=args.cnn_module_kernel,
            )
            # Re-initialize parameters since the encoder was swapped out.
            self.reset_parameters(args)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: combined multi-task loss (mtlalpha * ctc + (1 - mtlalpha) * att)
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        # Trim padding beyond the longest sequence in this (sub-)batch.
        xs_pad = xs_pad[:, : max(ilens)]
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad  # kept for attention visualization by the parent

        # 2. forward decoder on partially masked targets.
        # mask_uniform replaces part of ys_pad with self.mask_token and builds
        # the corresponding prediction targets (see maskctc.add_mask_token).
        ys_in_pad, ys_out_pad = mask_uniform(
            ys_pad, self.mask_token, self.eos, self.ignore_id
        )
        ys_mask = square_mask(ys_in_pad, self.eos)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. attention (masked-token prediction) loss and accuracy
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )

        # 4. CTC loss (skipped entirely when mtlalpha == 0)
        loss_ctc, cer_ctc = None, None
        if self.mtlalpha > 0:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

            # Compute (and cache inside the CTC module) softmax posteriors at
            # eval time; the return value is intentionally unused here —
            # presumably consumed later for visualization. TODO(review): confirm.
            if not self.training:
                self.ctc.softmax(hs_pad)

        # 5. CER/WER on decoder argmax outputs (evaluation only)
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 6. combine losses; alpha == 1.0 is excluded by the __init__ assert.
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        # Report only sane loss values; otherwise warn (diverged/NaN loss).
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
        """Recognize input speech with iterative mask prediction.

        Starts from the greedy CTC output, masks tokens whose CTC confidence
        is below ``recog_args.maskctc_probability_threshold``, then fills the
        masks over ``recog_args.maskctc_n_iterations`` decoder passes.

        :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module (unused here)
        :return: decoding result (single hypothesis)
        :rtype: list
        """

        # Build a formatter that renders token ids as text, showing any
        # remaining <mask> token as "_" (for the progress logs below).
        def num2str(char_list, mask_token, mask_char="_"):
            def f(yl):
                cl = [char_list[y] if y != mask_token else mask_char for y in yl]
                return "".join(cl).replace("<space>", " ")

            return f

        n2s = num2str(char_list, self.mask_token)

        self.eval()
        h = self.encode(x).unsqueeze(0)

        # Greedy CTC: per-frame argmax, collapse consecutive repeats,
        # then keep non-blank positions (index 0 is treated as the CTC blank).
        ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(h)).max(dim=-1)
        y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
        y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

        # Token-level CTC confidence: for each collapsed token, the maximum
        # frame probability over its run of identical consecutive frames.
        # `cnt` is a single cursor over frames shared across iterations.
        probs_hat = []
        cnt = 0
        for i, y in enumerate(y_hat.tolist()):
            probs_hat.append(-1)
            while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
                if probs_hat[i] < ctc_probs[0][cnt]:
                    probs_hat[i] = ctc_probs[0][cnt].item()
                cnt += 1
        probs_hat = torch.from_numpy(numpy.array(probs_hat))

        # Split tokens into masked (low confidence) and kept (confident).
        p_thres = recog_args.maskctc_probability_threshold
        mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
        confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
        mask_num = len(mask_idx)

        # Decoder input: all <mask>, with confident CTC tokens filled in.
        y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
        y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

        logging.info("ctc:{}".format(n2s(y_in[0].tolist())))

        # Iterative decoding: each pass commits the mask_num // num_iter most
        # confident decoder predictions; a final pass fills any leftovers.
        if not mask_num == 0:
            K = recog_args.maskctc_n_iterations
            # K <= 0 or K > mask_num degenerates to one token per iteration.
            num_iter = K if mask_num >= K and K > 0 else mask_num

            for t in range(num_iter - 1):
                pred, _ = self.decoder(y_in, None, h, None)
                pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
                cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
                # Commit the top candidates in place, then recompute which
                # positions are still masked for the next pass.
                y_in[0][mask_idx[cand]] = pred_id[cand]
                mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)

                logging.info("msk:{}".format(n2s(y_in[0].tolist())))

            # Final pass: fill all remaining masks with the decoder argmax.
            pred, pred_mask = self.decoder(y_in, None, h, None)
            y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)

            logging.info("msk:{}".format(n2s(y_in[0].tolist())))

        ret = y_in.tolist()[0]
        # Score is fixed at 0.0: this decoder produces a single hypothesis
        # without beam-search scoring.
        hyp = {"score": 0.0, "yseq": [self.sos] + ret + [self.eos]}

        return [hyp]
|
|