# File: OFA-Image_Caption/fairseq/fairseq/iterative_refinement_generator.py
# Repository-viewer metadata (not part of the program):
#   author: JustinLin610 · commit: 8437114 ("update") · size: 13.2 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import numpy as np
import torch
from fairseq import utils
# Immutable record passed between refinement steps. The generator below
# reads ``output_tokens``, ``output_scores``, ``attn`` and ``history`` from it
# and rewrites ``step`` / ``max_step`` (via ``_replace``) before every decoder
# call.
DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    "output_tokens output_scores attn step max_step history",
)
class IterativeRefinementGenerator(object):
    """Decode by repeatedly feeding a model's output back into its decoder.

    Each refinement step calls ``model.forward_decoder`` on the previous
    decoder output. A hypothesis is finalized when the decoder reproduces its
    previous output (only if ``adaptive``) or when ``max_iter`` is reached.
    """

    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary; must provide ``bos()``, ``pad()``,
                ``unk()``, ``eos()`` and ``len()``
            models: models used by :meth:`generate_batched_itr` (they are
                handed to :meth:`generate` unchanged)
            eos_penalty: if > 0.0, it penalizes early-stopping in decoding
                (forwarded to ``model.forward_decoder``)
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is
                the source length (forwarded to ``model.forward_decoder``)
            beam_size: length-beam size; values > 1 require a model with
                ``allow_length_beam``
            decoding_format: decoding mode in
                {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: retaining dropout in the inference (models are
                not switched to ``eval()`` when True)
            adaptive: decoding with early stop
            retain_history: attach the intermediate outputs of every
                refinement step to each finalized hypothesis
            reranking: treat the last model as a reranker that rescores the
                length-beam candidates
        """
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            data_itr: iterable of sample dicts; samples without a
                ``"net_input"`` entry are skipped
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length (accepted for
                interface compatibility; not used here).
            cuda: use GPU for generation (accepted for interface
                compatibility; not used here)
            timer: StopwatchMeter for timing generations.
            prefix_size: if > 0, the first ``prefix_size`` target columns
                are passed to :meth:`generate` as ``prefix_tokens``

        Yields:
            ``(id, src, ref, hypos)`` for every sentence of every sample,
            where ``src`` and ``ref`` have their padding stripped.
        """
        for sample in data_itr:
            if "net_input" not in sample:
                continue

            prefix_tokens = (
                sample["target"][:, :prefix_size] if prefix_size > 0 else None
            )

            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=prefix_tokens,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])

            # ``sample_id`` rather than ``id`` so the builtin is not shadowed.
            for i, sample_id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield sample_id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        """Run iterative refinement on one batch.

        Args:
            models: list of models. ``models[0]`` performs the refinement;
                when reranking is enabled the last entry is the reranker.
            sample: dict whose ``"net_input"`` holds ``"src_tokens"`` (a
                two-dimensional tensor) and ``"src_lengths"``.
            prefix_tokens: accepted for interface compatibility; this
                implementation does not use it.
            constraints: must be ``None``.

        Returns:
            A list with one entry per source sentence. Each entry is a
            one-element list holding a dict with the keys ``steps``,
            ``tokens``, ``positional_scores``, ``score``, ``hypo_attn`` and
            ``alignment`` (plus ``history`` when ``retain_history`` is set).

        Raises:
            NotImplementedError: if ``constraints`` is not ``None``.
        """
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translation for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, _ = src_tokens.size()

        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        # Maps each row of the (shrinking) active batch to its original index.
        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            """Flag rows whose new tokens ``y`` equal the previous tokens ``x``.

            The shorter of ``x`` / ``y`` is right-padded with ``self.pad`` so
            the two can be compared; when ``y`` is the one being padded its
            scores ``s`` and attention ``a`` are zero-padded to match.
            Returns ``(is_loop_mask, y, s, a)``.
            """
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a

        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            """Build the hypothesis dict for one sentence, dropping pad positions."""
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

        # Identical for every step, so build it once outside the loop.
        decoder_options = {
            "eos_penalty": self.eos_penalty,
            "max_ratio": self.max_ratio,
            "decoding_format": self.decoding_format,
        }

        for step in range(self.max_iter + 1):
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )
            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i, sent_idx in enumerate(finalized_idxs.tolist()):
                hypo = finalized_hypos(
                    step,
                    finalized_tokens[i],
                    finalized_scores[i],
                    None if finalized_attn is None else finalized_attn[i],
                )
                if self.retain_history:
                    # History entries carry tokens only (no scores/attention).
                    hypo["history"] = [
                        finalized_hypos(step, history_tokens[i], None, None)
                        for history_tokens in finalized_history_tokens
                    ]
                finalized[sent_idx] = [hypo]

            # check if all terminated
            if terminated.all():
                break

            # for next step
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            # ``nonzero`` yields a (k, 1) tensor; squeeze only dim 1 so the
            # order stays one-dimensional even when a single row survives
            # (a bare ``squeeze()`` would collapse it to a 0-dim scalar).
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze(1)
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam: for every group of
            # ``beam_size`` consecutive candidates keep the best-scoring one
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        """Overwrite each hypothesis' ``score`` with the reranker's score.

        The new score is the reranker's normalized log-probability of the
        hypothesis tokens, averaged over non-pad target positions.

        Args:
            reranker: model exposing ``encoder``, ``decoder`` and
                ``get_normalized_probs``.
            finalized: list of one-element hypothesis lists as built by
                :meth:`generate`; modified in place.
            encoder_input: positional arguments for ``reranker.encoder``.
            beam_size: number of candidates per source sentence.

        Returns:
            The same ``finalized`` list, with updated ``"score"`` entries.
        """

        def rebuild_batch(finalized):
            """Right-pad the variable-length hypotheses into one 2-D tensor."""
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[
            :, 0
        ] = self.eos  # autoregressive model assumes starting with EOS

        reranker_encoder_out = reranker.encoder(*encoder_input)
        # Repeat every encoder state ``beam_size`` times to line up with the
        # length-beam candidates.
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        # Pick the log-probability of each actual next token.
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i, hypo in enumerate(finalized):
            hypo[0]["score"] = reranking_scores[i]

        return finalized