						|  | """Transformer speech recognition model (pytorch).""" | 
					
						
						|  |  | 
					
						
from argparse import Namespace
import logging
import math
import numpy

import torch

from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator
from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD
from espnet.nets.pytorch_backend.e2e_st import Reporter
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.argument import (
    add_arguments_transformer_common,
)
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.initializer import initialize
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,
)
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.nets.st_interface import STInterface
from espnet.utils.fill_missing_args import fill_missing_args


class E2E(STInterface, torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        group = parser.add_argument_group("transformer model setting")
        group = add_arguments_transformer_common(group)
        return parser

    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        return self.encoder.conv_subsampling_factor * int(numpy.prod(self.subsample))

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.pad = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="st", arch="transformer")
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
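
        # Auxiliary ASR task: an extra attention decoder over the shared encoder
        # states, used when asr_weight > 0 and mtlalpha < 1; the CTC branch
        # (mtlalpha > 0) is set up further below.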
					
						
        self.mtlalpha = args.mtlalpha
        self.asr_weight = args.asr_weight
        if self.asr_weight > 0 and args.mtlalpha < 1:
            self.decoder_asr = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
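
        # Auxiliary MT task: a text encoder over source-language tokens that feeds
        # the shared ST decoder, used when mt_weight > 0.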
					
						
        self.mt_weight = args.mt_weight
        if self.mt_weight > 0:
            self.encoder_mt = Encoder(
                idim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                input_layer="embed",
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                attention_dropout_rate=args.transformer_attn_dropout_rate,
                padding_idx=0,
            )
        self.reset_parameters(args)  # NOTE: must run after the submodules are created
        self.adim = args.adim
        if self.asr_weight > 0 and args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        # translation (BLEU) error calculator
        self.error_calculator = MTErrorCalculator(
            args.char_list, args.sym_space, args.sym_blank, args.report_bleu
        )

        # recognition (CER/WER) error calculator
        self.error_calculator_asr = ASRErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
        self.rnnlm = None

        # multilingual E2E-ST related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)

    def reset_parameters(self, args):
        """Initialize parameters."""
        initialize(self, args.transformer_init)
        if self.mt_weight > 0:
            torch.nn.init.normal_(
                self.encoder_mt.embed[0].weight, mean=0, std=args.adim ** -0.5
            )
            torch.nn.init.constant_(self.encoder_mt.embed[0].weight[self.pad], 0)

    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src:
            batch of padded source-language token id sequences (B, Lmax)
        :return: loss value (weighted sum of the ST, ASR, and MT losses)
        :rtype: torch.Tensor
        """
        # 0. extract target language ID
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID from the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        if self.replace_sos:
            # replace <sos> with target language ID
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)

        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )

        # 4. compute corpus-level BLEU in a mini-batch (evaluation only)
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 5. compute auxiliary ASR loss (attention and/or CTC)
        loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
            hs_pad, hs_mask, ys_pad_src
        )

        # 6. compute auxiliary MT loss
        loss_mt, acc_mt = 0.0, None
        if self.mt_weight > 0:
            loss_mt, acc_mt = self.forward_mt(
                ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
            )
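
        # Total loss is a weighted sum of the three tasks:
        #   loss = (1 - asr_weight - mt_weight) * loss_st_att
        #          + asr_weight * (mtlalpha * loss_asr_ctc + (1 - mtlalpha) * loss_asr_att)
        #          + mt_weight * loss_mt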
					
						
        asr_ctc_weight = self.mtlalpha
        self.loss = (
            (1 - self.asr_weight - self.mt_weight) * loss_att
            + self.asr_weight
            * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
            + self.mt_weight * loss_mt
        )
        loss_asr_data = float(
            asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
        )
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def forward_asr(self, hs_pad, hs_mask, ys_pad):
        """Forward pass in the auxiliary ASR task.

        :param torch.Tensor hs_pad: batch of padded encoder outputs (B, Tmax, adim)
        :param torch.Tensor hs_mask: batch of encoder output masks (B, 1, Tmax)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ASR attention loss value
        :rtype: torch.Tensor
        :return: accuracy in ASR attention decoder
        :rtype: float
        :return: ASR CTC loss value
        :rtype: torch.Tensor
        :return: character error rate from CTC prediction
        :rtype: float
        :return: character error rate from attention decoder prediction
        :rtype: float
        :return: word error rate from attention decoder prediction
        :rtype: float
        """
        loss_att, loss_ctc = 0.0, 0.0
        acc = None
        cer, wer = None, None
        cer_ctc = None
        if self.asr_weight == 0:
            return loss_att, acc, loss_ctc, cer_ctc, cer, wer
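
        # attention-decoder branch of the auxiliary ASR task (weighted by 1 - mtlalpha)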
					
						
        if self.mtlalpha < 1:
            ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id
            )
            ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
            pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad, hs_mask)
            loss_att = self.criterion(pred_pad, ys_out_pad_asr)

            acc = th_accuracy(
                pred_pad.view(-1, self.odim),
                ys_out_pad_asr,
                ignore_label=self.ignore_id,
            )
            if not self.training:
                ys_hat_asr = pred_pad.argmax(dim=-1)
                cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(), ys_pad.cpu())
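
        # CTC branch of the auxiliary ASR task (weighted by mtlalpha)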
					
						
        if self.mtlalpha > 0:
            batch_size = hs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if not self.training:
                ys_hat_ctc = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)
                ).data
                cer_ctc = self.error_calculator_asr(
                    ys_hat_ctc.cpu(), ys_pad.cpu(), is_ctc=True
                )
                # for visualization
                self.ctc.softmax(hs_pad)
        return loss_att, acc, loss_ctc, cer_ctc, cer, wer

    def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask):
        """Forward pass in the auxiliary MT task.

        :param torch.Tensor xs_pad:
            batch of padded source-language token id sequences (B, Tmax)
        :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_out_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_mask: batch of input token mask (B, Lmax)
        :return: MT loss value
        :rtype: torch.Tensor
        :return: accuracy in MT decoder
        :rtype: float
        """
        loss, acc = 0.0, None
        if self.mt_weight == 0:
            return loss, acc
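
        # xs_pad arrives padded with ignore_id; recover the true lengths and
        # re-pad with the embedding padding index (self.pad) for the MT encoder.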
					
						
        ilens = torch.sum(xs_pad != self.ignore_id, dim=1).cpu().numpy()

        xs = [x[x != self.ignore_id] for x in xs_pad]
        xs_zero_pad = pad_list(xs, self.pad)
        xs_zero_pad = xs_zero_pad[:, : max(ilens)]
        src_mask = (
            make_non_pad_mask(ilens.tolist()).to(xs_zero_pad.device).unsqueeze(-2)
        )
        hs_pad, hs_mask = self.encoder_mt(xs_zero_pad, src_mask)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        loss = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
        return loss, acc

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder)

    def encode(self, x):
        """Encode source acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        enc_output, _ = self.encoder(x, None)
        return enc_output.squeeze(0)

    def translate(self, x, trans_args, char_list=None):
        """Translate input speech.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :return: N-best decoding results
        :rtype: list
        """
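        # prepare <sos>: use the target-language token index when replace_sos is
        # set, otherwise the shared <sos>/<eos> id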
					
						
        if getattr(trans_args, "tgt_lang", False):
            if self.replace_sos:
                y = char_list.index(trans_args.tgt_lang)
        else:
            y = self.sos
        logging.info("<sos> index: " + str(y))
        logging.info("<sos> mark: " + char_list[y])
        logging.info("input lengths: " + str(x.shape[0]))

        enc_output = self.encode(x).unsqueeze(0)

        h = enc_output

        logging.info("encoder output lengths: " + str(h.size(1)))

        # search parameters
        beam = trans_args.beam_size
        penalty = trans_args.penalty

        if trans_args.maxlenratio == 0:
            maxlen = h.size(1)
        else:
            # maxlen >= 1
            maxlen = max(1, int(trans_args.maxlenratio * h.size(1)))
        minlen = int(trans_args.minlenratio * h.size(1))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize the hypothesis list with <sos>
        hyp = {"score": 0.0, "yseq": [y]}
        hyps = [hyp]
        ended_hyps = []
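
        # Beam search over decoder prefixes: at each step every live hypothesis is
        # expanded with its top-`beam` next tokens, the pool is pruned back to
        # `beam` by score, and hypotheses that emit <eos> (and exceed minlen) are
        # moved to ended_hyps with a length penalty.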
					
						
        for i in range(maxlen):
            logging.debug("position " + str(i))

            # batchify the current hypotheses and run one decoder step
            ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64)
            for j, hyp in enumerate(hyps):
                ys[j, :] = torch.tensor(hyp["yseq"])
            ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device)

            local_scores = self.decoder.forward_one_step(
                ys, ys_mask, h.repeat([len(hyps), 1, 1])
            )[0]

            hyps_best_kept = []
            for j, hyp in enumerate(hyps):
                local_best_scores, local_best_ids = torch.topk(
                    local_scores[j : j + 1], beam, dim=1
                )

                for j in range(beam):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                    # append new hypotheses expanded from this prefix
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            # prune to the beam width
            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    "best hypo: "
                    + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
                )

            # add <eos> in the last position of the loop so that some hypotheses end
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # move ended hypotheses to the final list and keep the rest
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store sequences longer than minlen, with length penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remaining hypotheses: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                    )

            logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), trans_args.nbest)
        ]

        # retry with a smaller minlenratio if no hypothesis ended
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results, performing translation "
                "again with a smaller minlenratio."
            )
            # copy because the Namespace will be overwritten
            trans_args = Namespace(**vars(trans_args))
            trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
            return self.translate(x, trans_args, char_list)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src:
            batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights (B, H, Lmax, Tmax)
        :rtype: float ndarray
        """
        self.eval()
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad, ys_pad_src)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention) and m.attn is not None:
                ret[name] = m.attn.cpu().numpy()
        self.train()
        return ret

    def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E CTC probability calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src:
            batch of padded token id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab)
        :rtype: float ndarray
        """
        ret = None
        if self.asr_weight == 0 or self.mtlalpha == 0:
            return ret

        self.eval()
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad, ys_pad_src)
        ret = None
        for name, m in self.named_modules():
            if isinstance(m, CTC) and m.probs is not None:
                ret = m.probs.cpu().numpy()
        self.train()
        return ret
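

# A minimal usage sketch (illustrative only; `train_args`/`trans_args` contents,
# the feature dimension, and the tensor shapes are assumptions, not part of this
# module):
#
#   model = E2E(idim=83, odim=len(char_list), args=train_args)
#   loss = model(xs_pad, ilens, ys_pad, ys_pad_src)    # training: weighted ST/ASR/MT loss
#   loss.backward()
#
#   nbest = model.translate(x, trans_args, char_list)  # inference: beam search decoding
#   best_hyp = nbest[0]["yseq"]                        # token ids including <sos>/<eos>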