# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math
from typing import Dict, List, Optional
import sys

import torch
import torch.nn as nn
from fairseq import search, utils
from fairseq.data import data_utils
from fairseq.models import FairseqIncrementalDecoder
from torch import Tensor
from fairseq.ngram_repeat_block import NGramRepeatBlock

from espnet.nets.ctc_prefix_score import CTCPrefixScore
import numpy

# Multiplier applied to the beam size to obtain the number of per-hypothesis
# candidates that are rescored with the CTC prefix scorer.
CTC_SCORING_RATIO = 7.0
# Kept for backward compatibility with code importing it; the generator itself
# now places tensors on the device of the tensors it combines them with.
device = "cuda" if torch.cuda.is_available() else "cpu"


class SequenceGenerator(nn.Module):
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        max_len=0,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        search_strategy=None,
        eos=None,
        symbols_to_strip_from_output=None,
        lm_model=None,
        lm_weight=1.0,
        ctc_weight=0.0,
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            tgt_dict: target dictionary; provides pad/unk/eos indices and the
                vocabulary size
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence)
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
            no_repeat_ngram_size (int, optional): when > 0, an
                ``NGramRepeatBlock`` of that size is applied to the
                log-probabilities at every step (default: 0)
            search_strategy (optional): search object to use; a
                ``search.BeamSearch`` is created when ``None``
            eos (int, optional): end-of-sentence index; defaults to
                ``tgt_dict.eos()``
            symbols_to_strip_from_output (set, optional): extra symbols stored
                together with ``eos`` in ``self.symbols_to_strip_from_output``
            lm_model (optional): language model whose log-probabilities are
                added (scaled by ``lm_weight``) to the decoder log-probabilities
            lm_weight (float, optional): scale of the LM scores (default: 1.0)
            ctc_weight (float, optional): interpolation weight of the CTC
                prefix scores; CTC rescoring is disabled when 0 (default: 0.0)
        """
        super().__init__()
        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.tgt_dict = tgt_dict
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos() if eos is None else eos
        # NOTE(review): the literals passed to ``index`` below are empty/bare in
        # this copy of the file, so ``blank`` and ``mask`` resolve to the same
        # index. They look like special-token names lost in extraction
        # (presumably an angle-bracketed blank token and mask token) -- confirm
        # against the upstream source before relying on them.
        self.blank = self.tgt_dict.index("")
        self.mask = self.tgt_dict.index("")
        # Collect the indices of numbered mask symbols until the first one that
        # the dictionary maps to ``unk``.
        self.mask_idxs = []
        if self.tgt_dict.index("0") != self.unk:
            count = 0
            while self.tgt_dict.index("" + str(count)) != self.unk:
                self.mask_idxs.append(self.tgt_dict.index("" + str(count)))
                count += 1
        # Always a tensor (possibly empty): ``_generate`` calls ``.size()`` on it.
        self.mask_idxs = torch.tensor(self.mask_idxs)
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )
        self.vocab_size = len(tgt_dict)
        # the max beam size is the dictionary size - 1, since we never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len
        self.max_len = max_len or self.model.max_decoder_positions()

        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.temperature = temperature
        self.match_source_len = match_source_len

        if no_repeat_ngram_size > 0:
            self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
        else:
            self.repeat_ngram_blocker = None

        assert temperature > 0, "--temperature must be greater than 0"

        self.search = (
            search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
        )
        # We only need to set src_lengths in LengthConstrainedBeamSearch.
        # As a module attribute, setting it would break in multithread
        # settings when the model is shared.
        self.should_set_src_lengths = (
            hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
        )

        self.model.eval()

        self.lm_model = lm_model
        self.lm_weight = lm_weight
        self.ctc_weight = ctc_weight
        if self.lm_model is not None:
            self.lm_model.eval()

    def cuda(self):
        """Move the wrapped ensemble to CUDA and return ``self``."""
        self.model.cuda()
        return self

    @torch.no_grad()
    def forward(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        """Generate a batch of translations.

        Args:
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._generate(sample, prefix_tokens, bos_token=bos_token)

    # TODO(myleott): unused, deprecate after pytorch-translate migration
    def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            data_itr: iterable of samples; samples without ``"net_input"`` are
                skipped
            beam_size: unused, kept for interface compatibility
            cuda (bool, optional): use GPU for generation
            timer (StopwatchMeter, optional): time generations

        Yields:
            tuples ``(id, src, ref, hypos)`` for every sentence of every batch,
            where ``src``/``ref`` have padding stripped and ``ref`` is ``None``
            when the sample has no target.
        """
        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if "net_input" not in s:
                continue
            net_input = s["net_input"]
            # model.forward normally channels prev_output_tokens into the decoder
            # separately, but SequenceGenerator directly calls model.encoder
            encoder_input = {
                k: v for k, v in net_input.items() if k != "prev_output_tokens"
            }
            if timer is not None:
                timer.start()
            with torch.no_grad():
                # ``generate`` takes (models, sample) and reads
                # ``sample["net_input"]``; the previous one-argument call could
                # never succeed.
                hypos = self.generate(
                    self.model.models, {**s, "net_input": encoder_input}
                )
            if timer is not None:
                timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
            for i, sample_id in enumerate(s["id"].data):
                # remove padding
                src = utils.strip_pad(net_input["src_tokens"].data[i, :], self.pad)
                ref = (
                    utils.strip_pad(s["target"].data[i, :], self.pad)
                    if s["target"] is not None
                    else None
                )
                yield sample_id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        """Generate translations. Match the api of other fairseq generators.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
                (ignored; the ensemble given to the constructor is used)
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            constraints (torch.LongTensor, optional): force decoder to include
                the list of constraints
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._generate(sample, **kwargs)

    @staticmethod
    def _lengths_from_padding_mask(net_input, src_tokens):
        """Return source lengths for inputs that carry a ``padding_mask``.

        Uses the number of non-masked positions when the mask is present,
        otherwise the size of the last dimension of ``src_tokens``.
        """
        padding_mask = net_input["padding_mask"]
        if padding_mask is not None:
            return padding_mask.size(-1) - padding_mask.sum(-1)
        return torch.tensor(src_tokens.size(-1)).to(src_tokens)

    def _apply_ctc_prefix_scores(
        self, lprobs, tokens, step: int, ctc_prefix_score, ctc_hyps, ctc_beam: int
    ):
        """Interpolate decoder log-probs with CTC prefix scores, in place.

        For every row of ``tokens`` the ``ctc_beam`` best non-blank/non-mask
        candidates are rescored with ``ctc_prefix_score``; their log-probs
        become ``(1 - ctc_weight) * lprob + ctc_weight * ctc_score_delta``.
        The CTC score/state of every extended hypothesis is stored in
        ``ctc_hyps`` under the space-joined token ids of the hypothesis.

        Raises:
            KeyError: if a current prefix has no entry in ``ctc_hyps``.
        """
        ctc_lprobs = lprobs.clone()
        ctc_lprobs[:, self.blank] = -math.inf  # never select blank
        if self.mask != self.unk:
            ctc_lprobs[:, self.mask] = -math.inf  # never select mask
        if self.mask_idxs.size(0) != 0:
            ctc_lprobs[:, self.mask_idxs] = -math.inf  # never select mask
        _, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)

        for b in range(tokens.size(0)):
            prefix = tokens[b, : step + 1]
            best_ids = local_best_ids[b]
            hyp_key = " ".join(str(x) for x in prefix.tolist())
            prev_hyp = ctc_hyps[hyp_key]
            ctc_scores, ctc_states = ctc_prefix_score(
                prefix.cpu(), best_ids.cpu(), prev_hyp["ctc_state_prev"]
            )
            # Only the score gained by the new token is fused. The result is
            # placed on the device of ``lprobs`` rather than a global device,
            # so CPU decoding on a CUDA-capable host works.
            ctc_delta = torch.from_numpy(
                ctc_scores - prev_hyp["ctc_score_prev"]
            ).to(device=lprobs.device)
            lprobs[b, best_ids] = (1 - self.ctc_weight) * (
                lprobs[b, best_ids]
            ) + self.ctc_weight * ctc_delta
            for j, token_id in enumerate(best_ids.tolist()):
                ctc_hyps[hyp_key + " " + str(token_id)] = {
                    "ctc_score_prev": ctc_scores[j],
                    "ctc_state_prev": ctc_states[j],
                }
        return lprobs

    def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        """Run the search and return the finalized hypotheses.

        Args:
            sample (dict): batch; ``sample["net_input"]`` must contain one of
                ``"src_tokens"``, ``"source"`` or ``"features"``
            prefix_tokens (Tensor, optional): tokens the decoder is forced to
                start with
            constraints (Tensor, optional): target-side constraints passed to
                the search object
            bos_token (int, optional): first decoder token (default: self.eos)

        Returns:
            one list per input sentence, sorted by descending score, of dicts
            with keys ``tokens``, ``score``, ``attention``, ``alignment`` and
            ``positional_scores``.

        Raises:
            Exception: if ``net_input`` has none of the expected source keys.
            NotImplementedError: if constraints are given but the search
                object does not support them.
        """
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )
        net_input = sample["net_input"]

        if "src_tokens" in net_input:
            src_tokens = net_input["src_tokens"]
            # length of the source text being the character length except EndOfSentence and pad
            src_lengths = (
                (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
            )
        elif "source" in net_input:
            src_tokens = net_input["source"]
            src_lengths = self._lengths_from_padding_mask(net_input, src_tokens)
        elif "features" in net_input:
            src_tokens = net_input["features"]
            src_lengths = self._lengths_from_padding_mask(net_input, src_tokens)
        else:
            raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))

        # bsz: total number of sentences in beam
        # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
        bsz, src_len = src_tokens.size()[:2]
        beam_size = self.beam_size

        if constraints is not None and not self.search.supports_constraints:
            raise NotImplementedError(
                "Target-side constraints were provided, but search method doesn't support them"
            )

        # Initialize constraints, when active
        self.search.init_constraints(constraints, beam_size)

        max_len: int = -1
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                self.max_len - 1,
            )
        assert (
            self.min_len <= max_len
        ), "min_len cannot be larger than max_len, please adjust these!"
        # compute the encoder output for each beam
        encoder_outs = self.model.forward_encoder(net_input)

        # First decoder token of every hypothesis.
        start_token = self.eos if bos_token is None else bos_token

        # Get CTC lprobs and prep ctc_scorer
        if self.ctc_weight > 0:
            ctc_lprobs = self.model.models[0].get_normalized_probs_for_ctc(
                encoder_outs[0], log_probs=True
            ).contiguous().transpose(0, 1)  # (B, T, C) from the encoder

            # Only the first sentence of the batch is handed to the scorer.
            ctc_prefix_score = CTCPrefixScore(
                ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy
            )
            initial_hyp = {
                "ctc_state_prev": ctc_prefix_score.initial_state(),
                "ctc_score_prev": 0.0,
            }
            ctc_beam = min(
                ctc_lprobs.shape[-1] - self.mask_idxs.size(-1),
                int(beam_size * CTC_SCORING_RATIO),
            )
            # Hypotheses are keyed by their space-joined token ids, so the
            # initial entry must use the actual start token (which differs from
            # eos when ``bos_token`` is given).
            ctc_hyps = {str(start_token): initial_hyp}

        # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
        # ensure encoder_outs is a List.
        assert encoder_outs is not None

        # initialize buffers
        scores = (
            torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
        )  # +1 for eos; pad is never chosen for scoring
        tokens = (
            torch.zeros(bsz * beam_size, max_len + 2)
            .to(src_tokens)
            .long()
            .fill_(self.pad)
        )  # +2 for eos and pad
        tokens[:, 0] = start_token
        attn: Optional[Tensor] = None

        # A list that indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then cands_to_ignore would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        cands_to_ignore = (
            torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
        )  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = torch.jit.annotate(
            List[List[Dict[str, Tensor]]],
            [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
        )  # contains lists of dictionaries of information about the hypothesis being finalized at each step

        # a boolean array indicating if the sentence at the index is finished or not
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz  # number of sentences remaining

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (
            (torch.arange(0, bsz) * beam_size)
            .unsqueeze(1)
            .type_as(tokens)
            .to(src_tokens.device)
        )
        cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)

        reorder_state: Optional[Tensor] = None
        batch_idxs: Optional[Tensor] = None

        original_batch_idxs: Optional[Tensor] = None
        if "id" in sample and isinstance(sample["id"], Tensor):
            original_batch_idxs = sample["id"]
        else:
            original_batch_idxs = torch.arange(0, bsz).type_as(tokens)

        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
                        batch_idxs
                    )
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size
                    )
                    original_batch_idxs = original_batch_idxs[batch_idxs]
                self.model.reorder_incremental_state(incremental_states, reorder_state)
                encoder_outs = self.model.reorder_encoder_out(
                    encoder_outs, reorder_state
                )

            lprobs, avg_attn_scores = self.model.forward_decoder(
                tokens[:, : step + 1],
                encoder_outs,
                incremental_states,
                self.temperature,
            )

            if self.ctc_weight > 0:
                lprobs = self._apply_ctc_prefix_scores(
                    lprobs, tokens, step, ctc_prefix_score, ctc_hyps, ctc_beam
                )

            if self.lm_model is not None:
                lm_out = self.lm_model(tokens[:, : step + 1])
                probs = self.lm_model.get_normalized_probs(
                    lm_out, log_probs=True, sample=None
                )
                probs = probs[:, -1, :] * self.lm_weight
                # The LM vocabulary may be smaller than the decoder's; only the
                # leading columns are fused.
                lprobs[:, :probs.size(1)] += probs

            # handle prefix tokens (possibly with different lengths)
            if (
                prefix_tokens is not None
                and step < prefix_tokens.size(1)
                and step < max_len
            ):
                lprobs, tokens, scores = self._prefix_tokens(
                    step, lprobs, scores, tokens, prefix_tokens, beam_size
                )
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf

            # replace NaNs (the only values for which x != x) by -inf
            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
            lprobs[:, self.blank] = -math.inf  # never select blank
            if self.mask != self.unk:
                lprobs[:, self.mask] = -math.inf  # never select mask
            if self.mask_idxs.size(0) != 0:
                lprobs[:, self.mask_idxs] = -math.inf  # never select mask

            # handle max length constraint
            if step >= max_len:
                lprobs[:, : self.eos] = -math.inf
                lprobs[:, self.eos + 1 :] = -math.inf

            # Record attention scores, only support avg_attn_scores is a Tensor
            if avg_attn_scores is not None:
                if attn is None:
                    attn = torch.empty(
                        bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                    ).to(scores)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            eos_bbsz_idx = torch.empty(0).to(
                tokens
            )  # indices of hypothesis ending with eos (finished sentences)
            eos_scores = torch.empty(0).to(
                scores
            )  # scores of hypothesis ending with eos (finished sentences)

            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.repeat_ngram_blocker is not None:
                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)

            # Shape: (batch, cand_size)
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
                tokens[:, : step + 1],
                original_batch_idxs,
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            # Shape of eos_mask: (batch size, beam size)
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

            # only consider eos when it's among the top beam_size indices
            # Now we know what beam item(s) to finish
            # Shape: 1d list of absolute-numbered
            eos_bbsz_idx = torch.masked_select(
                cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            finalized_sents: List[int] = []
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )

                finalized_sents = self.finalize_hypos(
                    step,
                    eos_bbsz_idx,
                    eos_scores,
                    tokens,
                    scores,
                    finalized,
                    finished,
                    beam_size,
                    attn,
                    src_lengths,
                    max_len,
                )
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            if self.search.stop_on_max_len and step >= max_len:
                break
            assert step < max_len, f"{step} < {max_len}"

            # Remove finalized sentences (ones for which {beam_size}
            # finished hypotheses have been generated) from the batch.
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = torch.ones(
                    bsz, dtype=torch.bool, device=cand_indices.device
                )
                batch_mask[finalized_sents] = False
                # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
                batch_idxs = torch.arange(
                    bsz, device=cand_indices.device
                ).masked_select(batch_mask)

                # Choose the subset of the hypothesized constraints that will continue
                self.search.prune_sentences(batch_idxs)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                cands_to_ignore = cands_to_ignore[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1
                    )
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos

            # Rewrite the operator since the element wise or is not supported in torchscript.
            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[: eos_mask.size(1)],
            )

            # get the top beam_size active hypotheses, which are just
            # the hypos with the smallest values in active_mask.
            # {active_hypos} indicates which {beam_size} hypotheses
            # from the list of {2 * beam_size} candidates were
            # selected. Shapes: (batch size, beam size)
            new_cands_to_ignore, active_hypos = torch.topk(
                active_mask, k=beam_size, dim=1, largest=False
            )

            # update cands_to_ignore to ignore any finalized hypos.
            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
            # Make sure there is at least one active item for each sentence in the batch.
            assert (~cands_to_ignore).any(dim=1).all()

            # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
            # can be selected more than once).
            active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
            active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses

            # Set the tokens for each beam (can select the same row more than once)
            tokens[:, : step + 1] = torch.index_select(
                tokens[:, : step + 1], dim=0, index=active_bbsz_idx
            )
            # Select the next token for each of them
            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
                cand_indices, dim=1, index=active_hypos
            )
            if step > 0:
                scores[:, :step] = torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx
                )
            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
                cand_scores, dim=1, index=active_hypos
            )

            # Update constraints based on which candidates were selected for the next beam
            self.search.update_constraints(active_hypos)

            # copy attention for active hypotheses
            if attn is not None:
                attn[:, :, : step + 2] = torch.index_select(
                    attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
                )

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            sent_scores = torch.tensor(
                [float(elem["score"].item()) for elem in finalized[sent]]
            )
            _, sorted_scores_indices = torch.sort(sent_scores, descending=True)
            finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
            finalized[sent] = torch.jit.annotate(
                List[Dict[str, Tensor]], finalized[sent]
            )
        return finalized

    def _prefix_tokens(
        self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
    ):
        """Handle prefix tokens.

        For every row whose prefix token at ``step`` is not pad, all
        log-probabilities are lowered below the smallest prefix log-prob and
        the prefix token gets its original log-prob back, which forces its
        selection. Returns the (possibly replaced) ``lprobs, tokens, scores``.
        """
        prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
        prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
        prefix_mask = prefix_toks.ne(self.pad)
        lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
        lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
            -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
        )
        # if prefix includes eos, then we should make sure tokens and
        # scores are the same across all beams
        eos_mask = prefix_toks.eq(self.eos)
        if eos_mask.any():
            # validate that the first beam matches the prefix
            first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
                :, 0, 1 : step + 1
            ]
            eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
            target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
            assert (first_beam == target_prefix).all()

            # copy tokens, scores and lprobs from the first beam to all beams
            tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
            scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
            lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
        return lprobs, tokens, scores

    def replicate_first_beam(self, tensor, mask, beam_size: int):
        """Copy the first beam's row onto all beams of the sentences in ``mask``.

        ``tensor`` is viewed as ``(-1, beam_size, last_dim)`` and returned
        flattened back to two dimensions.
        """
        tensor = tensor.view(-1, beam_size, tensor.size(-1))
        tensor[mask] = tensor[mask][:, :1, :]
        return tensor.view(-1, tensor.size(-1))

    def finalize_hypos(
        self,
        step: int,
        bbsz_idx,
        eos_scores,
        tokens,
        scores,
        finalized: List[List[Dict[str, Tensor]]],
        finished: List[bool],
        beam_size: int,
        attn: Optional[Tensor],
        src_lengths,
        max_len: int,
    ):
        """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.

        A sentence is finalized when {beam_size} finished items have been
        collected for it, or when ``step`` reaches ``max_len``.

        Args:
            step (int): current decoding step
            bbsz_idx (Tensor): flat (sentence * beam) indices of the
                hypotheses that just produced eos
            eos_scores (Tensor): scores of those hypotheses; normalized in
                place when ``self.normalize_scores`` is set
            tokens, scores: current token / cumulative score buffers
            finalized: per original sentence list of finished hypotheses,
                appended to in place
            finished: per original sentence flags, updated in place
            beam_size (int): beam width
            attn (Tensor, optional): attention buffer
            src_lengths: source lengths of the current (reduced) batch
            max_len (int): maximum number of decoding steps

        Returns:
            list of indices, in the current (possibly reduced) batch, of the
            sentences that became finished in this call. These will be removed
            from the batch and not processed further.
        """
        assert bbsz_idx.numel() == eos_scores.numel()

        # clone relevant token and attention tensors.
        # tokens is (batch * beam, max_len). So the index_select
        # gets the newly EOS rows, then selects cols 1..{step + 2}
        tokens_clone = tokens.index_select(0, bbsz_idx)[
            :, 1 : step + 2
        ]  # skip the first index, which is EOS

        tokens_clone[:, step] = self.eos
        attn_clone = (
            attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
            if attn is not None
            else None
        )

        # compute scores per token position
        pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
        pos_scores[:, step] = eos_scores
        # convert from cumulative to per-position scores
        pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

        # normalize sentence-level scores
        if self.normalize_scores:
            eos_scores /= (step + 1) ** self.len_penalty

        # cum_unfin records which sentences in the batch are finished.
        # It helps match indexing between (a) the original sentences
        # in the batch and (b) the current, possibly-reduced set of
        # sentences.
        cum_unfin: List[int] = []
        prev = 0
        for f in finished:
            if f:
                prev += 1
            else:
                cum_unfin.append(prev)
        cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)

        unfin_idx = bbsz_idx // beam_size
        sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)

        # Create a set of "{sent}{unfin_idx}", where
        # "unfin_idx" is the index in the current (possibly reduced)
        # list of sentences, and "sent" is the index in the original,
        # unreduced batch
        # For every finished beam item
        # sentence index in the current (possibly reduced) batch
        seen = (sent << 32) + unfin_idx
        unique_seen: List[int] = torch.unique(seen).tolist()

        if self.match_source_len:
            condition = step > torch.index_select(src_lengths, 0, unfin_idx)
            # Match dtype/device of eos_scores so torch.where gets consistent
            # operands.
            eos_scores = torch.where(
                condition, torch.tensor(-math.inf).to(eos_scores), eos_scores
            )
        sent_list: List[int] = sent.tolist()
        for i in range(bbsz_idx.size()[0]):
            # An input sentence (among those in a batch) is finished when
            # beam_size hypotheses have been collected for it
            if len(finalized[sent_list[i]]) < beam_size:
                if attn_clone is not None:
                    # remove padding tokens from attn scores
                    hypo_attn = attn_clone[i]
                else:
                    hypo_attn = torch.empty(0)

                finalized[sent_list[i]].append(
                    {
                        "tokens": tokens_clone[i],
                        "score": eos_scores[i],
                        "attention": hypo_attn,  # src_len x tgt_len
                        "alignment": torch.empty(0),
                        "positional_scores": pos_scores[i],
                    }
                )

        newly_finished: List[int] = []
        for unique_s in unique_seen:
            # check termination conditions for this sentence
            unique_sent: int = unique_s >> 32
            unique_unfin_idx: int = unique_s - (unique_sent << 32)

            if not finished[unique_sent] and self.is_finished(
                step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
            ):
                finished[unique_sent] = True
                newly_finished.append(unique_unfin_idx)

        return newly_finished

    def is_finished(
        self,
        step: int,
        unfin_idx: int,
        max_len: int,
        finalized_sent_len: int,
        beam_size: int,
    ):
        """
        Check whether decoding for a sentence is finished, which
        occurs when the list of finalized sentences has reached the
        beam size, or when we reach the maximum length.
        """
        assert finalized_sent_len <= beam_size
        return finalized_sent_len == beam_size or step == max_len


class EnsembleModel(nn.Module):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__()
        self.models_size = len(models)
        # method '__len__' is not supported in ModuleList for torch script
        self.single_model = models[0]
        self.models = nn.ModuleList(models)

        # Incremental decoding is used only if every model supports it.
        self.has_incremental: bool = False
        if all(
            hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
            for m in models
        ):
            self.has_incremental = True

    def forward(self):
        pass

    def has_encoder(self):
        """Return whether the first model has an ``encoder`` attribute."""
        return hasattr(self.single_model, "encoder")

    def is_t5_structure(self):
        """Return whether the first model exposes encoder pre-nets.

        True when it has both ``text_encoder_prenet`` and
        ``speech_encoder_prenet``, or when it has ``encoder_prenet``.
        """
        # NOTE(review): the original second clause tested "encoder_prenet"
        # twice; behaviour is kept, but a different second attribute may have
        # been intended -- confirm against the model definitions.
        has_modal_prenets = hasattr(
            self.single_model, "text_encoder_prenet"
        ) and hasattr(self.single_model, "speech_encoder_prenet")
        return has_modal_prenets or hasattr(self.single_model, "encoder_prenet")

    def has_incremental_states(self):
        return self.has_incremental

    def max_decoder_positions(self):
        """Return the smallest ``max_decoder_positions()`` over the ensemble.

        Models without that method are ignored; ``sys.maxsize`` is the
        fallback upper bound.
        """
        return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])

    @torch.jit.export
    def forward_encoder(self, net_input: Dict[str, Tensor]):
        """Run every model's encoder; returns ``None`` without an encoder."""
        if not self.has_encoder():
            return None
        elif self.is_t5_structure():
            return [model.forward_encoder_torchscript(net_input) for model in self.models]
        else:
            return [model.encoder.forward_torchscript(net_input) for model in self.models]

    @torch.jit.export
    def forward_decoder(
        self,
        tokens,
        encoder_outs: List[Dict[str, List[Tensor]]],
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        temperature: float = 1.0,
    ):
        """Decode one step with every model and average the results.

        Returns:
            ``(log_probs, attn)`` for the last position of ``tokens``. With a
            single model its outputs are returned directly; otherwise the
            log-probabilities are averaged in probability space and the
            attention is averaged arithmetically.
        """
        log_probs = []
        avg_attn: Optional[Tensor] = None
        encoder_out: Optional[Dict[str, List[Tensor]]] = None
        for i, model in enumerate(self.models):
            if self.has_encoder():
                encoder_out = encoder_outs[i]
            # decode each model
            if self.has_incremental_states():
                # The method must be called: the bare bound method is always
                # truthy and made the ``else`` branch unreachable.
                if self.is_t5_structure():
                    decoder_out = model.forward_decoder(
                        tokens,
                        encoder_out=encoder_out,
                        incremental_state=incremental_states[i]
                    )
                else:
                    decoder_out = model.decoder.forward(
                        tokens,
                        encoder_out=encoder_out,
                        incremental_state=incremental_states[i],
                    )
            else:
                if hasattr(model, "decoder"):
                    decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
                else:
                    decoder_out = model.forward(tokens)

            attn: Optional[Tensor] = None
            decoder_len = len(decoder_out)
            if decoder_len > 1 and decoder_out[1] is not None:
                if isinstance(decoder_out[1], Tensor):
                    attn = decoder_out[1]
                else:
                    attn_holder = decoder_out[1]["attn"]
                    if isinstance(attn_holder, Tensor):
                        attn = attn_holder
                    elif attn_holder is not None:
                        attn = attn_holder[0]
                if attn is not None:
                    # keep only the attention of the last target position
                    attn = attn[:, -1, :]

            decoder_out_tuple = (
                decoder_out[0][:, -1:, :].div_(temperature),
                None if decoder_len <= 1 else decoder_out[1],
            )
            probs = model.get_normalized_probs(
                decoder_out_tuple, log_probs=True, sample=None
            )
            probs = probs[:, -1, :]
            if self.models_size == 1:
                return probs, attn

            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)

        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
            self.models_size
        )

        if avg_attn is not None:
            avg_attn.div_(self.models_size)
        return avg_probs, avg_attn

    @torch.jit.export
    def reorder_encoder_out(
        self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
    ):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        new_outs: List[Dict[str, List[Tensor]]] = []
        if not self.has_encoder():
            return new_outs
        for i, model in enumerate(self.models):
            assert encoder_outs is not None
            new_outs.append(
                model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
            )
        return new_outs

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order,
    ):
        """Reorder every decoder's incremental state according to *new_order*.

        Does nothing when the ensemble does not decode incrementally.
        """
        if not self.has_incremental_states():
            return
        for i, model in enumerate(self.models):
            model.decoder.reorder_incremental_state_scripting(
                incremental_states[i], new_order
            )


class SequenceGeneratorWithAlignment(SequenceGenerator):
    def __init__(
        self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
    ):
        """Generates translations of a given source sentence.

        Produces alignments following "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            left_pad_target (bool, optional): Whether or not the
                hypothesis should be left padded or not when they are
                teacher forced for generating alignments.
            print_alignment (str, optional): ``"hard"`` or ``"soft"``; selects
                the alignment extraction function. No extraction function is
                set for any other value.
        """
        super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
        self.left_pad_target = left_pad_target

        if print_alignment == "hard":
            self.extract_alignment = utils.extract_hard_alignment
        elif print_alignment == "soft":
            self.extract_alignment = utils.extract_soft_alignment

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Generate hypotheses and fill in their ``"alignment"`` entries.

        ``models`` is ignored; the ensemble given to the constructor is used.
        """
        finalized = super()._generate(sample, **kwargs)

        src_tokens = sample["net_input"]["src_tokens"]
        bsz = src_tokens.shape[0]
        beam_size = self.beam_size
        (
            src_tokens,
            src_lengths,
            prev_output_tokens,
            tgt_tokens,
        ) = self._prepare_batch_for_alignment(sample, finalized)
        if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
            attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
        else:
            attn = [
                finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
                for i in range(bsz * beam_size)
            ]

        # Compare the device *type*: a ``torch.device`` is not reliably
        # comparable to a plain string.
        if src_tokens.device.type != "cpu":
            src_tokens = src_tokens.to("cpu")
            tgt_tokens = tgt_tokens.to("cpu")
            attn = [i.to("cpu") for i in attn]

        # Process the attn matrix to extract hard alignments.
        for i in range(bsz * beam_size):
            alignment = self.extract_alignment(
                attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
            )
            finalized[i // beam_size][i % beam_size]["alignment"] = alignment
        return finalized

    def _prepare_batch_for_alignment(self, sample, hypothesis):
        """Build beam-expanded source and collated target batches.

        Returns:
            ``(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)`` where
            the source tensors are repeated ``beam_size`` times per sentence
            and the target tensors are collated from the hypotheses' tokens
            (with eos moved to the beginning for ``prev_output_tokens``).
        """
        src_tokens = sample["net_input"]["src_tokens"]
        bsz = src_tokens.shape[0]
        src_tokens = (
            src_tokens[:, None, :]
            .expand(-1, self.beam_size, -1)
            .contiguous()
            .view(bsz * self.beam_size, -1)
        )
        src_lengths = sample["net_input"]["src_lengths"]
        src_lengths = (
            src_lengths[:, None]
            .expand(-1, self.beam_size)
            .contiguous()
            .view(bsz * self.beam_size)
        )
        prev_output_tokens = data_utils.collate_tokens(
            [beam["tokens"] for example in hypothesis for beam in example],
            self.pad,
            self.eos,
            self.left_pad_target,
            move_eos_to_beginning=True,
        )
        tgt_tokens = data_utils.collate_tokens(
            [beam["tokens"] for example in hypothesis for beam in example],
            self.pad,
            self.eos,
            self.left_pad_target,
            move_eos_to_beginning=False,
        )
        return src_tokens, src_lengths, prev_output_tokens, tgt_tokens


class EnsembleModelWithAlignment(EnsembleModel):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__(models)

    def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
        """Return the attention of a full forward pass, averaged over models."""
        avg_attn = None
        for model in self.models:
            decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
            attn = decoder_out[1]["attn"][0]
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
        if len(self.models) > 1:
            avg_attn.div_(len(self.models))
        return avg_attn