# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from collections import namedtuple import numpy as np import torch from fairseq import utils DecoderOut = namedtuple( "IterativeRefinementDecoderOut", ["output_tokens", "output_scores", "attn", "step", "max_step", "history"], ) class IterativeRefinementGenerator(object): def __init__( self, tgt_dict, models=None, eos_penalty=0.0, max_iter=10, max_ratio=2, beam_size=1, decoding_format=None, retain_dropout=False, adaptive=True, retain_history=False, reranking=False, ): """ Generates translations based on iterative refinement. Args: tgt_dict: target dictionary eos_penalty: if > 0.0, it penalized early-stopping in decoding max_iter: maximum number of refinement iterations max_ratio: generate sequences of maximum length ax, where x is the source length decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'} retain_dropout: retaining dropout in the inference adaptive: decoding with early stop """ self.bos = tgt_dict.bos() self.pad = tgt_dict.pad() self.unk = tgt_dict.unk() self.eos = tgt_dict.eos() self.vocab_size = len(tgt_dict) self.eos_penalty = eos_penalty self.max_iter = max_iter self.max_ratio = max_ratio self.beam_size = beam_size self.reranking = reranking self.decoding_format = decoding_format self.retain_dropout = retain_dropout self.retain_history = retain_history self.adaptive = adaptive self.models = models def generate_batched_itr( self, data_itr, maxlen_a=None, maxlen_b=None, cuda=False, timer=None, prefix_size=0, ): """Iterate over a batched dataset and yield individual translations. Args: maxlen_a/b: generate sequences of maximum length ax + b, where x is the source sentence length. cuda: use GPU for generation timer: StopwatchMeter for timing generations. """ for sample in data_itr: if "net_input" not in sample: continue if timer is not None: timer.start() with torch.no_grad(): hypos = self.generate( self.models, sample, prefix_tokens=sample["target"][:, :prefix_size] if prefix_size > 0 else None, ) if timer is not None: timer.stop(sample["ntokens"]) for i, id in enumerate(sample["id"]): # remove padding src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad) ref = utils.strip_pad(sample["target"][i, :], self.pad) yield id, src, ref, hypos[i] @torch.no_grad() def generate(self, models, sample, prefix_tokens=None, constraints=None): if constraints is not None: raise NotImplementedError( "Constrained decoding with the IterativeRefinementGenerator is not supported" ) # TODO: iterative refinement generator does not support ensemble for now. if not self.retain_dropout: for model in models: model.eval() model, reranker = models[0], None if self.reranking: assert len(models) > 1, "Assuming the last checkpoint is the reranker" assert ( self.beam_size > 1 ), "Reranking requires multiple translation for each example" reranker = models[-1] models = models[:-1] if len(models) > 1 and hasattr(model, "enable_ensemble"): assert model.allow_ensemble, "{} does not support ensembling".format( model.__class__.__name__ ) model.enable_ensemble(models) # TODO: better encoder inputs? src_tokens = sample["net_input"]["src_tokens"] src_lengths = sample["net_input"]["src_lengths"] bsz, src_len = src_tokens.size() # initialize encoder_out = model.forward_encoder([src_tokens, src_lengths]) prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens) if self.beam_size > 1: assert ( model.allow_length_beam ), "{} does not support decoding with length beam.".format( model.__class__.__name__ ) # regenerate data based on length-beam length_beam_order = ( utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1) ) encoder_out = model.encoder.reorder_encoder_out( encoder_out, length_beam_order ) prev_decoder_out = model.regenerate_length_beam( prev_decoder_out, self.beam_size ) bsz = bsz * self.beam_size sent_idxs = torch.arange(bsz) prev_output_tokens = prev_decoder_out.output_tokens.clone() if self.retain_history: prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens]) finalized = [[] for _ in range(bsz)] def is_a_loop(x, y, s, a): b, l_x, l_y = x.size(0), x.size(1), y.size(1) if l_x > l_y: y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1) s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1) if a is not None: a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1) elif l_x < l_y: x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1) return (x == y).all(1), y, s, a def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): cutoff = prev_out_token.ne(self.pad) tokens = prev_out_token[cutoff] if prev_out_score is None: scores, score = None, None else: scores = prev_out_score[cutoff] score = scores.mean() if prev_out_attn is None: hypo_attn, alignment = None, None else: hypo_attn = prev_out_attn[cutoff] alignment = hypo_attn.max(dim=1)[1] return { "steps": step, "tokens": tokens, "positional_scores": scores, "score": score, "hypo_attn": hypo_attn, "alignment": alignment, } for step in range(self.max_iter + 1): decoder_options = { "eos_penalty": self.eos_penalty, "max_ratio": self.max_ratio, "decoding_format": self.decoding_format, } prev_decoder_out = prev_decoder_out._replace( step=step, max_step=self.max_iter + 1, ) decoder_out = model.forward_decoder( prev_decoder_out, encoder_out, **decoder_options ) if self.adaptive: # terminate if there is a loop terminated, out_tokens, out_scores, out_attn = is_a_loop( prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn, ) decoder_out = decoder_out._replace( output_tokens=out_tokens, output_scores=out_scores, attn=out_attn, ) else: terminated = decoder_out.output_tokens.new_zeros( decoder_out.output_tokens.size(0) ).bool() if step == self.max_iter: # reach last iteration, terminate terminated.fill_(1) # collect finalized sentences finalized_idxs = sent_idxs[terminated] finalized_tokens = decoder_out.output_tokens[terminated] finalized_scores = decoder_out.output_scores[terminated] finalized_attn = ( None if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) else decoder_out.attn[terminated] ) if self.retain_history: finalized_history_tokens = [h[terminated] for h in decoder_out.history] for i in range(finalized_idxs.size(0)): finalized[finalized_idxs[i]] = [ finalized_hypos( step, finalized_tokens[i], finalized_scores[i], None if finalized_attn is None else finalized_attn[i], ) ] if self.retain_history: finalized[finalized_idxs[i]][0]["history"] = [] for j in range(len(finalized_history_tokens)): finalized[finalized_idxs[i]][0]["history"].append( finalized_hypos( step, finalized_history_tokens[j][i], None, None ) ) # check if all terminated if terminated.sum() == terminated.size(0): break # for next step not_terminated = ~terminated prev_decoder_out = decoder_out._replace( output_tokens=decoder_out.output_tokens[not_terminated], output_scores=decoder_out.output_scores[not_terminated], attn=decoder_out.attn[not_terminated] if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0) else None, history=[h[not_terminated] for h in decoder_out.history] if decoder_out.history is not None else None, ) encoder_out = model.encoder.reorder_encoder_out( encoder_out, not_terminated.nonzero(as_tuple=False).squeeze() ) sent_idxs = sent_idxs[not_terminated] prev_output_tokens = prev_decoder_out.output_tokens.clone() if self.beam_size > 1: if reranker is not None: finalized = self.rerank( reranker, finalized, [src_tokens, src_lengths], self.beam_size ) # aggregate information from length beam finalized = [ finalized[ np.argmax( [ finalized[self.beam_size * i + j][0]["score"] for j in range(self.beam_size) ] ) + self.beam_size * i ] for i in range(len(finalized) // self.beam_size) ] return finalized def rerank(self, reranker, finalized, encoder_input, beam_size): def rebuild_batch(finalized): finalized_tokens = [f[0]["tokens"] for f in finalized] finalized_maxlen = max(f.size(0) for f in finalized_tokens) final_output_tokens = ( finalized_tokens[0] .new_zeros(len(finalized_tokens), finalized_maxlen) .fill_(self.pad) ) for i, f in enumerate(finalized_tokens): final_output_tokens[i, : f.size(0)] = f return final_output_tokens final_output_tokens = rebuild_batch(finalized) final_output_tokens[ :, 0 ] = self.eos # autoregressive model assumes starting with EOS reranker_encoder_out = reranker.encoder(*encoder_input) length_beam_order = ( utils.new_arange( final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1) ) .t() .reshape(-1) ) reranker_encoder_out = reranker.encoder.reorder_encoder_out( reranker_encoder_out, length_beam_order ) reranking_scores = reranker.get_normalized_probs( reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out), True, None, ) reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None]) reranking_masks = final_output_tokens[:, 1:].ne(self.pad) reranking_scores = ( reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1) ) reranking_scores = reranking_scores / reranking_masks.sum(1).type_as( reranking_scores ) for i in range(len(finalized)): finalized[i][0]["score"] = reranking_scores[i] return finalized