# Path: OFA-Image_Caption/fairseq/fairseq/tasks/online_backtranslation.py
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import json
import logging
import math
import os
from argparse import ArgumentError
from argparse import Namespace
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import fairseq
from fairseq import metrics, options, utils
from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    NoisingDataset,
    PrependTokenDataset,
    RoundRobinZipDatasets,
    TransformEosLangPairDataset,
    data_utils,
    encoders,
)
from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class PiecewiseLinearFn:
    """Piecewise linear function. Can be configured with a string.

    ``pieces`` is a non-empty sequence of ``(x, y)`` breakpoints sorted by
    ``x``. Between two consecutive breakpoints the value is linearly
    interpolated; outside the covered range it is clamped to the value of the
    nearest endpoint (first breakpoint on the left, last one on the right).
    """

    def __init__(self, pieces: Sequence[Tuple[int, float]]):
        assert len(pieces) > 0, "PiecewiseLinearFn needs at least one (x, y) piece"
        # Compare as lists so that tuples of pieces are accepted too.
        assert list(pieces) == sorted(
            pieces
        ), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}"

        self.pieces = pieces

    def __call__(self, x: int) -> float:
        # Left of the first breakpoint: clamp to the first value.
        if x < self.pieces[0][0]:
            return self.pieces[0][1]

        for (x_a, y_a), (x_b, y_b) in zip(self.pieces[:-1], self.pieces[1:]):
            if x_a <= x <= x_b:
                if x_b == x_a:
                    # Two breakpoints share the same x (a vertical step):
                    # take the value after the step instead of dividing by 0.
                    return y_b
                return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a)

        # Right of the last breakpoint (or a single piece): clamp to the last value.
        return self.pieces[-1][1]

    @staticmethod
    def from_string(configuration: str) -> "PiecewiseLinearFn":
        """
        Parse the configuration of lambda coefficient (for scheduling).
        x = "3"                  # lambda will be a constant equal to x
        x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                                 # to 0 during the first 1000 iterations
        x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                                 # iterations, then will linearly increase to 1 until iteration 2000

        A plain number (``int`` or ``float``) is also accepted and yields a
        constant function.

        Raises:
            ValueError: if ``configuration`` cannot be parsed, or its
                breakpoints are not sorted.
        """
        if isinstance(configuration, (int, float)):
            return PiecewiseLinearFn([(0, float(configuration))])

        try:
            parts = configuration.split(",")
            if len(parts) == 1:
                v = float(configuration)
                return PiecewiseLinearFn([(0, v)])

            split = [s.split(":") for s in parts]
            pieces = [(int(t), float(v)) for t, v in split]
            return PiecewiseLinearFn(pieces)
        except (AttributeError, AssertionError, TypeError, ValueError) as err:
            # AssertionError comes from the sortedness check in __init__.
            raise ValueError(
                f"Invalid PiecewiseLinearFn configuration: {configuration!r}"
            ) from err

    @staticmethod
    def one() -> "PiecewiseLinearFn":
        """Return the constant function equal to 1.0."""
        return PiecewiseLinearFn([(0, 1.0)])
@register_task("online_backtranslation")
class OnlineBackTranslationTask(TranslationTask):
    """Unsupervised translation with online back-translation and denoising.

    Training batches come from a :class:`RoundRobinZipDatasets` holding two
    sub-datasets per monolingual language, keyed ``"{lang}-BT"`` and
    ``"{lang}-DENOISE"`` (see :meth:`load_train_dataset`); :meth:`train_step`
    runs one weighted forward/backward per sub-dataset. Validation/test use a
    regular parallel language pair given by ``--valid-lang-pairs``. Sentences
    are prefixed with a ``__{lang}__`` token that tells the model which
    language it reads or has to produce.
    """

    # Number of back-translated batches between two sample dumps in the logs.
    SHOW_SAMPLES_INTERVAL = 1000
    # Maximum number of sentences printed per sample dump.
    SHOW_SAMPLES_NUMBER = 5

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        # Generic translation args
        parser.add_argument('data', help='colon separated path to data directories list, '
                            'will be iterated upon during epochs in round-robin manner; '
                            'however, valid and test data are always in the first directory to '
                            'avoid the need for repeating them in all directories')
        parser.add_argument('--mono-langs', metavar='MONO_LANGS',
                            help='monolingual languages for training')
        parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS',
                            help='language pairs for validation')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        # Guard each flag separately: one of them being defined already must
        # not prevent the other from being added.
        for flag, help_msg in (
            ('--max-source-positions', 'max number of tokens in the source sequence'),
            ('--max-target-positions', 'max number of tokens in the target sequence'),
        ):
            try:
                parser.add_argument(flag, default=1024, type=int, metavar='N',
                                    help=help_msg)
            except ArgumentError:
                # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
                pass
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
                            help='if >0, then bucket source and target lengths into N '
                                 'buckets and pad accordingly; this is useful on TPUs '
                                 'to minimize the number of compilations')

        # Denoising args
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')

        # Backtranslation args
        parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N',
                            help='back-translation weight')
        parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N',
                            help='denoising auto-encoder weight')

        # Evaluation args
        parser.add_argument('--generate-one-by-one', action='store_true',
                            help='generate one sentence at a time for backtranslation')

        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            help='detokenize before computing BLEU (e.g., "moses"); '
                                 'required if using --eval-bleu; use "space" to '
                                 'disable detokenization; see fairseq.data.encoders '
                                 'for other options')
        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            help='compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            help='generation args for BLEU scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-bleu-print-samples', action='store_true',
                            help='print sample generations during validation')
        # fmt: on

    def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
        # A single dictionary is shared by source and target.
        super().__init__(args, common_dict, common_dict)
        self.common_dict = common_dict
        self.mono_langs = mono_langs
        self.valid_lang_pairs = valid_lang_pairs

        # Start by showing samples: the counter starts at the interval so the
        # very first back-translated batch is logged.
        self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
        self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
        self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)

        self.args = args
        self.data = utils.split_paths(self.args.data)
        if len(self.data) == 1:
            # Sorted so that the epoch -> shard mapping in load_dataset does
            # not depend on filesystem listing order.
            shards = sorted(Path(self.data[0]).glob("shard*"))
            if len(shards) > 0:
                # keep this as strings, since it can also be a manifold path
                old_data = self.data
                self.data = [str(shard) for shard in shards]
                logger.warning(f"Expanded data directory {old_data} to {self.data}")

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        assert args.mono_langs is not None
        # --valid-lang-pairs defaults to None but is split unconditionally below.
        assert args.valid_lang_pairs is not None, "--valid-lang-pairs is required"

        mono_langs = args.mono_langs.split(",")
        valid_lang_pairs = args.valid_lang_pairs.split(",")

        # load dictionary
        dict_path = os.path.join(paths[0], "dict.txt")
        common_dict = cls.load_dictionary(dict_path)

        return cls(args, common_dict, mono_langs, valid_lang_pairs)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset:
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch; selects the training data directory
                (shard) in round-robin order
            combine (bool): forwarded to the valid/test loader
        """
        if split == "train":
            data_path = self.data[(epoch - 1) % len(self.data)]
            dataset = self.load_train_dataset(data_path)
        else:
            # valid/test should always be the same.
            dataset = self.load_translation_dataset(
                split, self.data[0], combine=combine
            )

        self.datasets[split] = dataset
        return dataset

    def load_train_dataset(self, data_path: str) -> FairseqDataset:
        """The training dataset is made of backtranslation dataset and denoising dataset."""
        data = []
        for lang in self.mono_langs:
            train_path = os.path.join(data_path, lang, "train")
            # TODO: could we do the BT using denoise sample ?
            # this would half the data loading work
            data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang)))
            data.append(
                (f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang))
            )

        return RoundRobinZipDatasets(OrderedDict(data))

    def _langpair_dataset(
        self, src: FairseqDataset, tgt: FairseqDataset
    ) -> LanguagePairDataset:
        """Pair ``src`` and ``tgt`` using the shared dictionary and padding flags."""
        return LanguagePairDataset(
            src,
            src.sizes,
            self.dictionary,
            tgt=tgt,
            tgt_sizes=tgt.sizes,
            tgt_dict=self.dictionary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            # TODO: should we shuffle ? we are already sorting batch by sizes so ?
            # shuffle=True,
        )

    def _prepend_lang_bos_to_target(
        self, dataset: LanguagePairDataset, lang: str
    ) -> LanguagePairDataset:
        """Replace the target-side BOS (eos) with the ``__{lang}__`` token."""
        bos = _lang_token_index(self.dictionary, lang)
        return TransformEosLangPairDataset(
            dataset,
            src_eos=self.dictionary.eos(),
            new_src_eos=self.dictionary.eos(),
            tgt_bos=self.dictionary.eos(),
            new_tgt_bos=bos,
        )

    def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """The BT dataset is generated with (tgt, tgt) pairs.
        The actual translation to a (generated_src, tgt) pair
        is done on the fly during training.
        """
        mono_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        assert mono_dataset is not None, f"No dataset found for {lang}"

        mono_dataset_src = PrependTokenDataset(
            mono_dataset, _lang_token_index(self.dictionary, lang)
        )

        mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
        logger.info(
            f"mono_lang = {lang} "
            f"lang token index = {_lang_token_index(self.dictionary, lang)} "
            f"lang token = {_lang_token(lang)}"
        )

        mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
        return mono_dataset_bt

    def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """Classic denoising dataset: (noised sentence, clean sentence) pairs."""
        dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        noisy_dataset = NoisingDataset(
            dataset,
            self.dictionary,
            seed=1,
            max_word_shuffle_distance=self.args.max_word_shuffle_distance,
            word_dropout_prob=self.args.word_dropout_prob,
            word_blanking_prob=self.args.word_blanking_prob,
        )
        noisy_dataset = PrependTokenDataset(
            noisy_dataset, _lang_token_index(self.dictionary, lang)
        )

        # A separate, un-noised copy of the same data is used as target.
        clean_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
        denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)
        return denoising_dataset

    def load_translation_dataset(
        self, split: str, data_path: str, combine: bool = False
    ) -> FairseqDataset:
        """Load the parallel valid/test data for the single validation pair."""
        # only judging with one language pair for the moment,
        # since ConcatDataset doesn't work as expected
        assert len(self.valid_lang_pairs) == 1, "For now..."
        valid_lang_pair = self.valid_lang_pairs[0]
        src, tgt = valid_lang_pair.split("-")

        # use the same function than TranslationTask
        src_tgt_dt = load_langpair_dataset(
            data_path,
            split,
            src,
            self.common_dict,
            tgt,
            self.common_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            prepend_bos_src=_lang_token_index(self.dictionary, src),
        )

        src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt)
        src_tgt_eos_dt.args = self.args
        return src_tgt_eos_dt

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        raise NotImplementedError

    def build_model(self, args):
        """Build the model, add language tokens, and create the generators.

        One greedy (``beam_size=1``) :class:`SequenceGenerator` is created per
        monolingual language for online back-translation; a BLEU generator is
        built as well when ``--eval-bleu`` is set.
        """
        # torch.autograd.set_detect_anomaly(True)
        model = super().build_model(args)

        add_secial_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs)

        self.sequence_generators = {}
        for mono_lang in self.mono_langs:
            self.sequence_generators[mono_lang] = SequenceGenerator(
                [model],
                tgt_dict=self.dictionary,
                beam_size=1,
                max_len_a=1.3,
                max_len_b=5,
                min_len=5,
                # keep 1 to be able to prepend bos
                max_len=model.max_decoder_positions() - 1,
            )

        if getattr(args, "eval_bleu", False):
            assert getattr(args, "eval_bleu_detok", None) is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
            self.tokenizer = encoders.build_tokenizer(
                Namespace(
                    tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
                )
            )

            gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
            self.bleu_sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )

        return model

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.common_dict

    @staticmethod
    def _lang_from_token(token: str) -> str:
        """Recover the language code from a ``__{lang}__`` dictionary symbol.

        Only the surrounding double underscores are stripped, so codes that
        themselves contain ``_`` (e.g. ``en_XX``) survive. Symbols that do not
        have that shape fall back to removing every underscore.
        """
        token = token.replace(" ", "")
        if len(token) > 4 and token.startswith("__") and token.endswith("__"):
            return token[2:-2]
        return token.replace("_", "")

    def display_samples_once_in_a_while(self, smp, mono_lang, other_lang):
        """Log a few back-translated pairs every ``SHOW_SAMPLES_INTERVAL`` calls."""
        self._show_samples_ctr += 1
        if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL:
            return
        self._show_samples_ctr = 0

        ln = smp["net_input"]["src_tokens"].shape[0]

        logger.info(
            f"(r:{self.args.distributed_rank}) : "
            f"{other_lang} ---> {mono_lang} "
            f"({other_lang} was generated by back-translation.) {ln} samples"
        )

        for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)):
            src_tokens = smp["net_input"]["src_tokens"][i]
            tgt_tokens = smp["target"][i]

            src_str = self.dictionary.string(src_tokens, "sentencepiece")
            tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece")

            logger.info(
                f"\n{i}\t\t[{other_lang} generated]  {src_str}\n"
                f"\t\t[{mono_lang} original ]  {tgt_str}\n"
                f"\t\t[ src tokens]  {src_tokens}\n"
            )

    def backtranslate_sample(self, smp, orig_lang, other_lang) -> None:
        """
        * WARNING: smp is modified in place.
        * At the start of this function, `smp` has the same input and target:
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] |  smp['target']        |
          | (from data) __en__ hello world |  __en__ hello world   |
          |--------------------------------------------------------|

        * We call generator.generate(smp, bos_token = token("ro")),
        and copy the result as input
        * At the end, `smp` has the translation to other language.
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] |  smp['target']        |
          | (generated) __ro__ salut lume  |  __en__ hello world   |
          |--------------------------------------------------------|

        The generated sources are always right-padded.
        NOTE(review): this assumes --left-pad-source is False (its default).
        """
        bos_token = _lang_token_index(self.dictionary, other_lang)
        # models=[]: presumably the generator uses the model it was built with
        # in build_model — confirm against SequenceGenerator.generate.
        generated = self.sequence_generators[orig_lang].generate(
            models=[], sample=smp, bos_token=bos_token
        )

        max_len = max(gn[0]["tokens"].size(0) for gn in generated)
        net_input = smp["net_input"]
        device = net_input["src_tokens"].device
        # Allocated directly on the sample's device to avoid a per-row copy.
        n_src_tokens = torch.empty(
            size=(len(generated), max_len + 1),
            dtype=net_input["src_tokens"].dtype,
            device=device,
        )
        n_src_lengths = torch.empty(
            len(generated), dtype=net_input["src_lengths"].dtype, device=device
        )

        for i, gn in enumerate(generated):
            # Best (and, with beam_size=1, only) hypothesis.
            tokens = gn[0]["tokens"]
            tokens_size = tokens.size(0)
            padding_needed = max_len - tokens_size
            # Prefix with the other-language token, then right-pad.
            tokens = torch.cat([tokens.new([bos_token]), tokens])
            tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad())
            n_src_tokens[i] = tokens
            n_src_lengths[i] = tokens_size + 1

        # This seems to be important
        del net_input["src_tokens"]
        del net_input["src_lengths"]
        net_input["src_tokens"] = n_src_tokens.to(device)
        net_input["src_lengths"] = n_src_lengths.to(device)

    def generate(self, smp, model):
        """Translate ``smp`` with ``model`` in eval mode and without gradients.

        The source language is read from the first token of the first source
        sentence and the BOS from the first ``prev_output_tokens`` entry.
        """
        model.eval()
        orig_lang = self._lang_from_token(
            self.dictionary[smp["net_input"]["src_tokens"][0][0]]
        )
        bos_token = smp["net_input"]["prev_output_tokens"][0][0]
        with torch.no_grad():
            generated = self.sequence_generators[orig_lang].generate(
                models=[model], sample=smp, bos_token=bos_token
            )
        return generated

    def get_other_lang(self, lang):
        """Pick the language to back-translate ``lang`` into.

        Any language other than the first maps to the first one; the first
        language maps to the second (or to a random other one when there are
        more than two).
        """
        # TODO: allow more complex mapping
        if lang != self.mono_langs[0]:
            return self.mono_langs[0]
        if len(self.mono_langs) == 2:
            return self.mono_langs[1]
        return self.mono_langs[np.random.randint(1, len(self.mono_langs))]

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """Run one weighted forward/backward per sub-dataset of ``sample``.

        For each ``"{lang}-{BT|DENOISE}"`` key the loss is scaled by the
        scheduled weight (``--lambda-bt`` / ``--lambda-dae`` at ``update_num``)
        and back-propagated separately, so gradients accumulate across
        sub-tasks. Sub-tasks whose weight is 0 are skipped. BT samples are
        first translated in place by :meth:`backtranslate_sample`.

        Returns:
            tuple: (summed weighted loss as a float, summed sample size,
            logging outputs where every criterion key appears both as-is and
            with a ``bt_`` / ``dae_`` prefix)
        """
        model.train()
        model.set_num_updates(update_num)

        agg_loss, agg_sample_size = 0.0, 0.0
        agg_logging_output: Dict[str, float] = defaultdict(float)

        dataset_keys = self.datasets["train"].datasets.keys()

        weights = {
            "BT": self.lambda_bt(update_num),
            "DENOISE": self.lambda_dae(update_num),
        }
        log_keys = {"BT": "bt_", "DENOISE": "dae_"}

        for dataset_key in dataset_keys:
            smp = sample[dataset_key]
            # rsplit: the language code itself may contain "-".
            mono_lang, task_subtype = dataset_key.rsplit("-", 1)
            if weights[task_subtype] == 0:
                continue

            if task_subtype == "BT":
                with torch.autograd.profiler.record_function("backtranslation"):
                    # Generate in eval mode (no dropout); always switch back.
                    model.eval()
                    try:
                        # TODO: Could we translate to several language at once ?
                        # this would allow to share encoder_out and maximize GPU usage.
                        other_lang = self.get_other_lang(mono_lang)
                        self.backtranslate_sample(smp, mono_lang, other_lang)
                        self.display_samples_once_in_a_while(
                            smp, mono_lang, other_lang
                        )
                    finally:
                        model.train()

            # Like in FairseqTask.train_step
            with torch.autograd.profiler.record_function("forward"):
                loss, sample_size, logging_output = criterion(model, smp)
            loss *= weights[task_subtype]
            if ignore_grad:
                loss *= 0
            with torch.autograd.profiler.record_function("backward"):
                optimizer.backward(loss)

            agg_loss += loss.item()
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[log_keys[task_subtype] + k] += logging_output[k]
                agg_logging_output[k] += logging_output[k]

        return agg_loss, agg_sample_size, agg_logging_output

    def get_bos_token_from_sample(self, sample):
        """Return the index of the language token to generate into.

        The source language is read from the first column of ``src_tokens``;
        ``.item()`` fails unless all rows share the same first token.
        NOTE(review): assumes right-padded sources (the default here).
        """
        net_input = sample["net_input"]
        source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item()
        source_lang = self._lang_from_token(self.dictionary[source_lang_token_id])
        target_lang_token_id = _lang_token_index(
            self.dictionary, self.get_other_lang(source_lang)
        )

        return target_lang_token_id

    @staticmethod
    def _reduce_subtask_metrics(logging_outputs, prefix: str) -> None:
        """Log ``{prefix}_loss``, ``{prefix}_nll_loss`` and ``{prefix}_ppl``.

        ``prefix`` is "bt" or "dae", matching the keys written by
        :meth:`train_step`. Nothing is logged if no such sample was seen; the
        nll/ppl metrics are skipped if no ``{prefix}_ntokens`` were reported.
        Dividing by ln(2) converts natural-log losses to base 2.
        """
        sample_size = sum(x.get(f"{prefix}_sample_size", 0) for x in logging_outputs)
        if not sample_size:
            return
        loss_sum = sum(x.get(f"{prefix}_loss", 0) for x in logging_outputs)
        loss_sum *= 1 / sample_size / math.log(2)
        metrics.log_scalar(f"{prefix}_loss", loss_sum, sample_size, round=3)

        nll_key = f"{prefix}_nll_loss"
        nll_loss_sum = sum(x.get(nll_key, 0) for x in logging_outputs)
        ntokens = sum(x.get(f"{prefix}_ntokens", 0) for x in logging_outputs)
        if not ntokens:
            # Avoid a ZeroDivisionError when the criterion reports no tokens.
            return
        nll_loss_sum *= 1 / ntokens / math.log(2)
        metrics.log_scalar(nll_key, nll_loss_sum, ntokens, round=3)
        metrics.log_derived(
            f"{prefix}_ppl",
            lambda meters: utils.get_perplexity(meters[nll_key].avg),
        )

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs, adding per-sub-task bt_*/dae_* metrics."""
        super().reduce_metrics(logging_outputs, criterion)
        for prefix in ("bt", "dae"):
            self._reduce_subtask_metrics(logging_outputs, prefix)
@torch.no_grad()
def extend_embedding(
    emb: nn.Module, new_vocab_size: int, copy_from_token_id: int
) -> None:
    """Grow the vocabulary dimension of ``emb`` in place to ``new_vocab_size``.

    ``emb`` must expose a 2-D ``weight`` of shape ``(vocab, dim)``, e.g. an
    ``nn.Embedding`` or an ``nn.Linear`` output projection. New rows are
    copies of row ``copy_from_token_id``; existing rows are kept. The
    ``num_embeddings`` / ``out_features`` attribute, when present, is set to
    ``new_vocab_size``. A bias, if any, is extended with zeros. Dtype and
    device of the weight and bias are preserved.

    Args:
        emb: module to extend in place.
        new_vocab_size: target vocabulary size; must be >= the current size.
        copy_from_token_id: row used to initialise the new rows.
    """
    old_emb_data = emb.weight.data
    (old_vocab_size, dim) = old_emb_data.shape
    assert new_vocab_size >= old_vocab_size

    if new_vocab_size > old_vocab_size:
        # new_zeros keeps dtype/device (torch.zeros would give float32 on CPU).
        new_emb_data = old_emb_data.new_zeros((new_vocab_size, dim))
        new_emb_data[:old_vocab_size, :] = old_emb_data
        # initialize new embeddings
        new_emb_data[old_vocab_size:, :] = old_emb_data[copy_from_token_id]
        emb.weight.data = new_emb_data

    # Updated even when nothing was resized: with a shared weight matrix the
    # weight may already have been extended through another module, but this
    # module's size attribute still holds the old value.
    if hasattr(emb, "num_embeddings"):
        emb.num_embeddings = new_vocab_size
    if hasattr(emb, "out_features"):
        emb.out_features = new_vocab_size

    if getattr(emb, "bias", None) is None:
        return

    # Fix the bias.
    # Bias shape can be different from the previous vocab size
    # if the weight matrix was shared and already extended but not the bias.
    (old_vocab_size,) = emb.bias.shape
    assert new_vocab_size >= old_vocab_size
    if new_vocab_size > old_vocab_size:
        old_bias = emb.bias.data
        new_bias = old_bias.new_zeros((new_vocab_size,))
        new_bias[:old_vocab_size] = old_bias
        emb.bias.data = new_bias
def add_secial_tokens_to_dict_and_model(
    dictionary: "fairseq.data.Dictionary",
    model: nn.Module,
    mono_langs: Sequence[str],
) -> None:
    """Add ``<mask>`` and one ``__{lang}__`` token per language, growing the model.

    (The "secial" spelling of the name is kept for backward compatibility.)

    ``dictionary`` is modified in place. If it ends up larger than the encoder
    embedding matrix, the encoder embeddings, decoder embeddings and decoder
    output projection are extended with :func:`extend_embedding`, new rows
    being initialised from the ``bos`` embedding.

    Args:
        dictionary: dictionary shared by encoder and decoder.
        model: model exposing ``encoder.embed_tokens``,
            ``decoder.embed_tokens`` and ``decoder.output_projection``.
        mono_langs: language codes to add tokens for.
    """
    embs = model.encoder.embed_tokens
    vocab_size, embedding_dim = embs.weight.shape

    # The model may or may not have a '<mask>' embedding yet
    assert (
        len(dictionary) <= vocab_size <= len(dictionary) + 1
    ), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})"
    old_dict_len = len(dictionary)
    # TODO: we should reuse the pretrained model dict which already has <mask>
    dictionary.add_symbol("<mask>")

    for lang in mono_langs:
        lang_token = _lang_token(lang)
        dictionary.add_symbol(lang_token)
    logger.info(
        f"dictionary: {old_dict_len} -> {len(dictionary)} tokens "
        f"after adding {len(mono_langs)} lang tokens."
    )

    # Nothing to resize if the embeddings already cover the dictionary.
    if len(dictionary) <= vocab_size:
        return

    extend_embedding(embs, len(dictionary), dictionary.bos())
    dec_embs = model.decoder.embed_tokens
    extend_embedding(dec_embs, len(dictionary), dictionary.bos())
    lm_head = model.decoder.output_projection
    extend_embedding(lm_head, len(dictionary), dictionary.bos())
    assert lm_head.weight.shape == (len(dictionary), embedding_dim)
def _lang_token(lang: str) -> str:
return f"__{lang}__"
def _lang_token_index(dictionary, lang: str) -> int:
    """Return ``dictionary.index`` of the ``__{lang}__`` token for ``lang``."""
    token = _lang_token(lang)
    return dictionary.index(token)
@contextlib.contextmanager
def assert_weights_have_changed(model: nn.Module):
    """Debug helper: assert that ``model`` parameters changed inside the block.

    Compares the sum of all parameter values before and after the ``with``
    body. It is a cheap heuristic, so updates whose effects cancel out exactly
    in the sum go unnoticed. If the body raises, the exception propagates and
    no check is made.

    Usage::

        with assert_weights_have_changed(model):
            optimizer.step()
    """

    def checksum(model: nn.Module) -> float:
        return sum(p.sum().item() for p in model.parameters())

    initial_checksum = checksum(model)
    yield model
    final_checksum = checksum(model)
    logger.info(
        f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}"
    )
    assert initial_checksum != final_checksum, "Model hasn't changed !"