File size: 9,436 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Waseda University (Yosuke Higuchi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Mask CTC based non-autoregressive speech recognition model (pytorch).

See https://arxiv.org/abs/2005.08700 for the detail.
"""
from distutils.util import strtobool
from itertools import groupby
import logging
import math

import numpy
import torch

from espnet.nets.pytorch_backend.conformer.argument import (
    add_arguments_conformer_common,  # noqa: H301
)
from espnet.nets.pytorch_backend.conformer.encoder import Encoder
from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer
from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform
from espnet.nets.pytorch_backend.maskctc.mask import square_mask
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
class E2E(E2ETransformer):
    """E2E module.

    Mask CTC based non-autoregressive speech recognition model built on top
    of the transformer ``E2E`` class it inherits from.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """
@staticmethod
def add_arguments(parser):
    """Add arguments.

    Delegates to the parent class's ``add_arguments`` and then to
    ``add_maskctc_arguments`` so that every option the constructor reads
    from ``args`` is registered on ``parser``.

    :param parser: argument parser to extend
    :return: the same ``parser`` object, for chaining
    """
    # The constructor forwards ``args`` to the parent constructor, so the
    # parent's options have to be registered here too; returning the parser
    # untouched would leave the model without any command-line options.
    E2ETransformer.add_arguments(parser)
    E2E.add_maskctc_arguments(parser)
    return parser
@staticmethod
def add_maskctc_arguments(parser):
    """Add arguments for maskctc model.

    Creates the "maskctc specific setting" argument group, registers the
    conformer-encoder switch in it and lets
    ``add_arguments_conformer_common`` add its options to the same group.

    :param parser: argument parser to extend
    :return: the same ``parser`` object, for chaining
    """
    group = parser.add_argument_group("maskctc specific setting")

    # The constructor reads ``args.maskctc_use_conformer_encoder``; register
    # the flag so the attribute always exists on the parsed namespace.
    # NOTE(review): default ``False`` keeps the parent's encoder unless the
    # user opts in -- confirm this matches the intended default.
    group.add_argument(
        "--maskctc-use-conformer-encoder",
        default=False,
        type=strtobool,
        help="Replace the encoder built by the parent class with the "
        "conformer encoder",
    )
    add_arguments_conformer_common(group)

    return parser
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    The output dimension is enlarged by one before the parent constructor
    runs, so that the last output index can serve as the mask token.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs (without the mask token)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: forwarded unchanged to the parent constructor
    """
    odim += 1  # for the mask token
    super().__init__(idim, odim, args, ignore_id)
    # Mask CTC always needs the decoder loss, hence mtlalpha == 1.0 is rejected.
    assert 0.0 <= self.mtlalpha < 1.0, "mtlalpha should be [0.0, 1.0)"
    # Last index of the enlarged vocabulary is the mask token.
    self.mask_token = odim - 1
    # sos and eos share the index just below the mask token.
    self.sos = odim - 2
    self.eos = odim - 2
    self.odim = odim
    if args.maskctc_use_conformer_encoder:
        # Fall back to the conformer dropout rate when no attention dropout
        # rate was given explicitly.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.conformer_dropout_rate
        # Replaces the encoder created by the parent constructor.
        # NOTE(review): this copy of the file is truncated here -- the
        # argument list and closing parenthesis of the Encoder(...) call, and
        # any statements that followed it, are missing and must be restored
        # from the upstream source before this file can be used.
        self.encoder = Encoder(
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Runs the encoder, feeds a randomly masked copy of the targets to the
    decoder, and combines the decoder (attention) loss with the CTC loss
    weighted by ``self.mtlalpha``. Intermediate results are kept on
    ``self.hs_pad``, ``self.pred_pad``, ``self.acc`` and ``self.loss``.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: loss value (``alpha * ctc + (1 - alpha) * attention``, or the
        attention loss alone when ``mtlalpha`` is 0)
    :rtype: torch.Tensor
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
    src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = mask_uniform(
        ys_pad, self.mask_token, self.eos, self.ignore_id
    )
    ys_mask = square_mask(ys_in_pad, self.eos)
    pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss_att = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # 4. compute ctc loss
    loss_ctc, cer_ctc = None, None
    if self.mtlalpha > 0:
        batch_size = xs_pad.size(0)
        hs_len = hs_mask.view(batch_size, -1).sum(1)
        hs_flat = hs_pad.view(batch_size, -1, self.adim)
        loss_ctc = self.ctc(hs_flat, hs_len, ys_pad)
        if self.error_calculator is not None:
            ys_hat = self.ctc.argmax(hs_flat).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        # for visualization
        # NOTE(review): only ``if not`` survived in the damaged source; the
        # condition and body below are restored by inference -- confirm
        # against upstream.
        if not self.training:
            self.ctc.softmax(hs_pad)

    # 5. compute cer/wer
    # NOTE(review): the first operand of this condition was lost in the
    # damaged source; ``self.training`` is inferred -- confirm.
    if self.training or self.error_calculator is None or self.decoder is None:
        cer, wer = None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    alpha = self.mtlalpha
    loss_att_data = float(loss_att)
    if alpha == 0:
        # loss_ctc is None in this case, so it must not enter the sum.
        self.loss = loss_att
        loss_ctc_data = None
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        # NOTE(review): the callee name was lost in the damaged source, only
        # the argument list survived; ``self.reporter.report`` is inferred.
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def recognize(self, x, recog_args, char_list=None, rnnlm=None):
    """Recognize input speech.

    Decoding starts from the greedy CTC output. Tokens whose CTC
    probability is below ``recog_args.maskctc_probability_threshold`` are
    replaced by the mask token and then filled in by the decoder over at
    most ``recog_args.maskctc_n_iterations`` passes.

    :param ndarray x: input acoustic feature of one utterance, handed to
        ``self.encode`` (originally documented as (B, T, D) or (T, D) --
        TODO confirm)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters; only used to render log
        messages, token ids are logged instead when it is ``None``
    :param torch.nn.Module rnnlm: language model module (not used here)
    :return: decoding result, a list holding one hypothesis dict with the
        keys ``"score"`` (always 0.0) and ``"yseq"`` (sos + tokens + eos)
    :rtype: list
    """

    def num2str(char_list, mask_token, mask_char="_"):
        def f(yl):
            if char_list is None:
                # No symbol table available: log the raw ids instead of
                # failing on ``None[y]``.
                return " ".join(
                    mask_char if y == mask_token else str(y) for y in yl
                )
            cl = [char_list[y] if y != mask_token else mask_char for y in yl]
            return "".join(cl).replace("<space>", " ")

        return f

    n2s = num2str(char_list, self.mask_token)

    h = self.encode(x).unsqueeze(0)

    # greedy ctc outputs
    ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(h)).max(dim=-1)
    frame_ids = ctc_ids[0].tolist()
    frame_probs = ctc_probs[0].tolist()

    # collapse runs of identical frame labels into one token each; the
    # token-level ctc probability is the maximum probability of the
    # consecutive frames that carry the same ctc symbol
    tokens, token_probs = [], []
    start = 0
    for token, run in groupby(frame_ids):
        end = start + sum(1 for _ in run)
        tokens.append(token)
        token_probs.append(max(frame_probs[start:end]))
        start = end
    y_hat = torch.tensor(tokens, dtype=torch.long)
    probs_hat = torch.tensor(token_probs, dtype=torch.float64)
    # positions of the tokens that are not index 0 (the symbol dropped
    # from the CTC output)
    y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

    # mask ctc outputs based on ctc probabilities
    p_thres = recog_args.maskctc_probability_threshold
    mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
    confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
    mask_num = len(mask_idx)

    y_in = torch.full((1, len(y_idx)), self.mask_token, dtype=torch.long)
    y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

    logging.info("ctc:{}".format(n2s(y_in[0].tolist())))

    # iterative decoding
    if mask_num != 0:
        K = recog_args.maskctc_n_iterations
        num_iter = K if mask_num >= K and K > 0 else mask_num

        for _ in range(num_iter - 1):
            pred, _ = self.decoder(y_in, None, h, None)
            pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
            # unmask the most confident predictions of this pass
            cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
            y_in[0][mask_idx[cand]] = pred_id[cand]
            mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)

            logging.info("msk:{}".format(n2s(y_in[0].tolist())))

        # predict leftover masks (|masks| < mask_num // num_iter)
        pred, _ = self.decoder(y_in, None, h, None)
        y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)

        logging.info("msk:{}".format(n2s(y_in[0].tolist())))

    ret = y_in.tolist()[0]
    hyp = {"score": 0.0, "yseq": [self.sos] + ret + [self.eos]}

    return [hyp]