OFA / fairseq /tests /test_online_backtranslation.py
root
init
93b9482
raw
history blame
7.65 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import tempfile
import unittest
from pathlib import Path
from typing import Any, Dict, Sequence
import fairseq.data.indexed_dataset as indexed_dataset
import fairseq.options
import fairseq.tasks.online_backtranslation as obt
import torch
from tests import utils
def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
    """Build a minimal fairseq-style sample that repeats one sentence.

    Args:
        tokens: token ids of the single sentence to replicate.
        batch_size: number of identical rows to put in the batch.

    Returns:
        A dict with a ``"net_input"`` entry (``src_tokens``,
        ``prev_output_tokens``, ``src_lengths``) and a ``"target"`` entry.
        ``src_tokens`` and ``prev_output_tokens`` are the *same* tensor object,
        and ``target`` is a slice of it with the first column dropped.
    """
    sentence = torch.tensor(tokens, dtype=torch.long)
    # One row per batch element, all rows identical.
    stacked = torch.stack([sentence for _ in range(batch_size)])
    src_lengths = torch.tensor(
        [len(tokens) for _ in range(batch_size)], dtype=torch.long
    )

    net_input = {
        "src_tokens": stacked,
        "prev_output_tokens": stacked,
        "src_lengths": src_lengths,
    }
    # The target is the input without its leading token.
    return {"net_input": net_input, "target": stacked[:, 1:]}
def mk_dataset(num_samples: int, max_len: int, output: Path):
    """Write a random indexed dataset to ``output`` (plus its ``.idx`` file).

    Args:
        num_samples: number of sentences to generate.
        max_len: exclusive upper bound on sentence length; it must be greater
            than 3 because lengths are drawn with ``torch.randint(3, max_len)``.
        output: path of the binary data file; the index file is written next
            to it with the suffix replaced by ``.idx``.

    Raises:
        AssertionError: if either output file is missing after finalization.
    """
    # parents=True so that nested, not-yet-existing directories also work.
    output.parent.mkdir(parents=True, exist_ok=True)
    index_path = output.with_suffix(".idx")

    builder = indexed_dataset.IndexedDatasetBuilder(str(output))
    # Token ids are drawn from [5, 100), sentence lengths from [3, max_len).
    data = torch.randint(5, 100, (num_samples, max_len))
    lengths = torch.randint(3, max_len, (num_samples,))
    for row, length in zip(data, lengths):
        # Every sentence starts with token 0.
        row[0] = 0
        builder.add_item(row[: int(length)])
    builder.finalize(index_path)

    assert output.exists()
    assert index_path.exists()
class OnlineBacktranslationTest(unittest.TestCase):
    """Tests for the ``online_backtranslation`` task and ``PiecewiseLinearFn``."""

    # Scratch directory shared by all tests of this class. It is created when
    # the class body runs and removed again in ``tearDownClass``.
    tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))

    @classmethod
    def setUpClass(cls):
        """Make sure the scratch directory exists (e.g. on a second run)."""
        super().setUpClass()
        cls.tmp_dir.mkdir(parents=True, exist_ok=True)

    @classmethod
    def tearDownClass(cls):
        """Delete the scratch directory so test runs do not leak temp files."""
        import shutil

        shutil.rmtree(cls.tmp_dir, ignore_errors=True)
        super().tearDownClass()

    @classmethod
    def obt_task(
        cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
    ):
        """Set up an online back-translation task and build its model.

        Args:
            languages: at least two language codes; the first two form the
                validation language pair.
            data: data directory. When given, the shared dummy dictionary is
                copied into it; otherwise the class scratch directory is used.
            language_mapping: forwarded unchanged to the task arguments.

        Returns:
            A ``(task, model)`` tuple.
        """
        dict_path = cls.tmp_dir / "dict.txt"
        if not dict_path.exists():
            # The dummy dictionary is written once and reused by later calls.
            dictionary = utils.dummy_dictionary(100)
            dictionary.save(str(dict_path))

        if data is not None:
            (data / "dict.txt").write_text(dict_path.read_text())
        else:
            data = cls.tmp_dir
        assert len(languages) >= 2

        kwargs = {
            "arch": "transformer",
            # --max-sentences=1 for better predictability of batches
            "max_sentences": 1,
            # Use characteristic (small, distinct) dimensions
            "encoder_layers": 3,
            "encoder_embed_dim": 12,
            "encoder_ffn_embed_dim": 14,
            "encoder_attention_heads": 4,
            "decoder_layers": 3,
            "decoder_embed_dim": 12,
            "decoder_output_dim": 12,
            "decoder_ffn_embed_dim": 14,
            "decoder_attention_heads": 4,
            # Disable dropout so we have comparable tests.
            "dropout": 0,
            "attention_dropout": 0,
            "activation_dropout": 0,
            "encoder_layerdrop": 0,
        }

        args = fairseq.options.get_args(
            data,
            task="online_backtranslation",
            mono_langs=",".join(languages),
            valid_lang_pairs=f"{languages[0]}-{languages[1]}",
            tokens_per_sample=256,
            language_mapping=language_mapping,
            **kwargs,
        )
        task = obt.OnlineBackTranslationTask.setup_task(args)
        # we need to build the model to have the correct dictionary
        model = task.build_model(task.args)
        return task, model

    def tmp_path(self, test_case: str) -> Path:
        """Create a fresh directory under ``tmp_dir`` whose name ends with ``test_case``."""
        return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))

    def test_lang_tokens(self):
        """Language tokens are in the dictionary and map back to their language."""
        task, _ = self.obt_task(["en", "ro", "zh"])
        self.assertIn(obt._lang_token("en"), task.dictionary)
        self.assertIn(obt._lang_token("ro"), task.dictionary)
        self.assertIn(obt._lang_token("zh"), task.dictionary)

        en_bos = obt._lang_token_index(task.common_dict, "en")
        self.assertEqual("en", task.common_dict[en_bos].strip("_"))
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        self.assertEqual("zh", task.common_dict[zh_bos].strip("_"))
        zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])

        # we expect to receive the bos token for translation
        self.assertEqual(task.get_bos_token_from_sample(zh_sample), en_bos)

    def test_backtranslate_sample(self):
        """Back-translating zh -> en keeps the zh target and yields an en source."""
        task, _ = self.obt_task(["en", "ro", "zh"])

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        sample = mk_sample([zh_bos, 16, 14, 12, 10])

        task.backtranslate_sample(sample, "zh", "en")
        target_zh = [int(token) for token in sample["target"][0]]
        self.assertEqual(target_zh, [16, 14, 12, 10])  # original zh sentence
        generated_en = sample["net_input"]["src_tokens"][0]
        self.assertEqual(generated_en[0].item(), en_bos)

    def test_train_dataset(self):
        """The train dataset exposes BT and DENOISE samples for each language."""
        data = self.tmp_path("test_train_dataset")
        mk_dataset(20, 10, data / "en" / "train.bin")
        mk_dataset(10, 10, data / "zh" / "train.bin")

        task, _ = self.obt_task(["en", "zh"], data)
        task.load_dataset("train")

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")

        train = task.datasets["train"]
        # Return value is unused; the call is kept from the original test.
        train.ordered_indices()
        train.prefetch([0, 19])
        sample_0 = train[0]
        sample_19 = train[19]
        self.assertEqual(
            set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
        )
        for sample in (sample_0, sample_19):
            self.assertEqual(sample["en-BT"]["source"][0], en_bos)
            # bt target isn't ready to look at.
            self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)
            # TODO What could we check on the target side ?

        for i in range(10):
            # Zh dataset is shorter, and is wrapped around En dataset.
            train.prefetch([i, i + 10])
            self.assertEqual(
                list(train[i]["zh-DENOISE"]["source"]),
                list(train[i + 10]["zh-DENOISE"]["source"]),
            )
            self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)

        # Sorted by increasing len
        self.assertLess(
            len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
        )

    def test_valid_dataset(self):
        """The valid dataset keeps sample ids and starts sources with the en token."""
        data = self.tmp_path("test_valid_dataset")
        mk_dataset(10, 21, data / "valid.en-zh.en.bin")
        mk_dataset(10, 21, data / "valid.en-zh.zh.bin")

        task, _ = self.obt_task(["en", "zh"], data)
        valid = task.load_dataset("valid")
        en_bos = obt._lang_token_index(task.common_dict, "en")

        self.assertIsNotNone(valid)
        valid.prefetch(range(10))
        sample_0 = valid[0]
        sample_9 = valid[9]
        self.assertEqual(sample_0["id"], 0)
        self.assertEqual(sample_9["id"], 9)
        self.assertEqual(sample_0["source"][0], en_bos)
        self.assertEqual(sample_9["source"][0], en_bos)
        # TODO: could we test the target side ?

    def assertFnMatch(self, fn, values):
        """Assert that ``fn(x) == y`` for every ``x: y`` pair in ``values``."""
        for x, y in values.items():
            fn_x = fn(x)
            self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")

    def test_piecewise_linear_fn(self):
        """``PiecewiseLinearFn.from_string`` parses constants and breakpoints."""
        # A bare number is a constant function.
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
        )
        # Past the last breakpoint the last value is held.
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
            {0: 1, 500: 0.5, 1000: 0, 2000: 0},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
            {0: 0, 500: 0.5, 1000: 1, 2000: 1},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
            {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
        )