# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import (
    TranslationConfig,
    TranslationTask,
    load_langpair_dataset,
)
from fairseq.utils import new_arange

NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"])

@dataclass
class TranslationLevenshteinConfig(TranslationConfig):
    noise: NOISE_CHOICES = field(
        default="random_delete",
        metadata={"help": "type of noise"},
    )
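
# With fairseq's dataclass-based config system, the `noise` field above is
# exposed on the command line as `--noise` (e.g. `--noise random_mask`); the
# four choices are dispatched in TranslationLevenshteinTask.inject_noise below.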

@register_task("translation_lev", dataclass=TranslationLevenshteinConfig)
class TranslationLevenshteinTask(TranslationTask):
    """
    Translation (Sequence Generation) task for Levenshtein Transformer
    See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_.
    """

    cfg: TranslationLevenshteinConfig

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.cfg.source_lang, self.cfg.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.cfg.dataset_impl,
            upsample_primary=self.cfg.upsample_primary,
            left_pad_source=self.cfg.left_pad_source,
            left_pad_target=self.cfg.left_pad_target,
            max_source_positions=self.cfg.max_source_positions,
            max_target_positions=self.cfg.max_target_positions,
            prepend_bos=True,
        )
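        # Note: prepend_bos=True here (and append_bos=True at inference time,
        # in build_dataset_for_inference below) presumably ensures every target
        # carries <bos> and <eos>, which inject_noise relies on and never corrupts.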

    def inject_noise(self, target_tokens):
        def _random_delete(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()

            max_len = target_tokens.size(1)
            target_mask = target_tokens.eq(pad)
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(
                target_tokens.eq(bos) | target_tokens.eq(eos), 0.0
            )
            target_score.masked_fill_(target_mask, 1)
            target_score, target_rank = target_score.sort(1)
            target_length = target_mask.size(1) - target_mask.float().sum(
                1, keepdim=True
            )

            # do not delete <bos> and <eos> (we assign 0 score for them)
            target_cutoff = (
                2
                + (
                    (target_length - 2)
                    * target_score.new_zeros(target_score.size(0), 1).uniform_()
                ).long()
            )
            target_cutoff = target_score.sort(1)[1] >= target_cutoff

            prev_target_tokens = (
                target_tokens.gather(1, target_rank)
                .masked_fill_(target_cutoff, pad)
                .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1])
            )
            prev_target_tokens = prev_target_tokens[
                :, : prev_target_tokens.ne(pad).sum(1).max()
            ]

            return prev_target_tokens
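
        # Illustrative sketch of _random_delete (assuming fairseq's default
        # special-token ids bos=0, pad=1, eos=2): ordinary tokens draw a random
        # score in [0, 1), <bos>/<eos> score 0 and padding scores 1, so after
        # sorting, the k lowest-scoring positions survive, with k drawn
        # uniformly from [2, length). Survivors keep their original order and
        # the batch is trimmed to the longest remaining sequence, e.g.
        #
        #   target        = [[0, 7, 8, 9, 2]]
        #   random_delete -> [[0, 7, 9, 2]]   # one possible outcome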

        def _random_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            target_masks = (
                target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos)
            )
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(~target_masks, 2.0)
            target_length = target_masks.sum(1).float()
            target_length = target_length * target_length.clone().uniform_()
            target_length = target_length + 1  # make sure to mask at least one token.

            _, target_rank = target_score.sort(1)
            target_cutoff = new_arange(target_rank) < target_length[:, None].long()
            prev_target_tokens = target_tokens.masked_fill(
                target_cutoff.scatter(1, target_rank, target_cutoff), unk
            )
            return prev_target_tokens
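
        # Illustrative sketch of _random_mask: per sentence, a uniformly drawn
        # number of the ordinary (non-special) tokens, at least one, is
        # replaced by <unk>; <bos>, <eos> and padding are never masked, since
        # their score of 2.0 sorts them past any cutoff, e.g. with unk=3:
        #
        #   target      = [[0, 7, 8, 9, 2]]
        #   random_mask -> [[0, 7, 3, 9, 2]]   # one possible outcome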

        def _full_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            target_mask = (
                target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad)
            )
            return target_tokens.masked_fill(~target_mask, unk)

        if self.cfg.noise == "random_delete":
            return _random_delete(target_tokens)
        elif self.cfg.noise == "random_mask":
            return _random_mask(target_tokens)
        elif self.cfg.noise == "full_mask":
            return _full_mask(target_tokens)
        elif self.cfg.noise == "no_noise":
            return target_tokens
        else:
            raise NotImplementedError

    def build_generator(self, models, args, **unused):
        # add models input to match the API for SequenceGenerator
        from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

        return IterativeRefinementGenerator(
            self.target_dictionary,
            eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0),
            max_iter=getattr(args, "iter_decode_max_iter", 10),
            beam_size=getattr(args, "iter_decode_with_beam", 1),
            reranking=getattr(args, "iter_decode_with_external_reranker", False),
            decoding_format=getattr(args, "decoding_format", None),
            adaptive=not getattr(args, "iter_decode_force_max_iter", False),
            retain_history=getattr(args, "retain_iter_history", False),
        )
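
    # Note: the attribute names read off `args` above correspond to fairseq's
    # generation flags (e.g. `--iter-decode-max-iter`, `--iter-decode-with-beam`,
    # `--iter-decode-eos-penalty`); invocations that omit a flag fall back to
    # the defaults given here.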

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
            raise NotImplementedError(
                "Constrained decoding with the translation_lev task is not supported"
            )

        return LanguagePairDataset(
            src_tokens, src_lengths, self.source_dictionary, append_bos=True
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        # Corrupt the gold target and expose it as the decoder input that the
        # criterion (e.g. fairseq's nat_loss) trains the model to repair.
        sample["prev_target"] = self.inject_noise(sample["target"])
        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            sample["prev_target"] = self.inject_noise(sample["target"])
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output
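

# Usage sketch (illustrative): training and decoding with this task through
# the standard fairseq CLI. The data path and hyperparameters are placeholders;
# see fairseq's nonautoregressive translation examples for full recipes.
#
#   fairseq-train data-bin/wmt14_en_de_distill \
#       --task translation_lev --noise random_delete \
#       --arch levenshtein_transformer --criterion nat_loss \
#       --optimizer adam --lr 0.0005 --max-tokens 8000
#
#   fairseq-generate data-bin/wmt14_en_de_distill \
#       --task translation_lev --gen-subset test \
#       --iter-decode-max-iter 9 --iter-decode-eos-penalty 0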