# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from pathlib import Path
from typing import Any, Dict, Sequence

import fairseq.data.indexed_dataset as indexed_dataset
import fairseq.options
import fairseq.tasks.online_backtranslation as obt
import torch

from tests import utils


def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
    # Build a dummy training sample: a batch of `batch_size` identical sentences.
    batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_dataset(num_samples: int, max_len: int, output: Path):
    # Write a random indexed dataset of `num_samples` sentences to `output`.
    output.parent.mkdir(exist_ok=True)
    idx = indexed_dataset.IndexedDatasetBuilder(str(output))
    data = torch.randint(5, 100, (num_samples, max_len))
    lengths = torch.randint(3, max_len, (num_samples,))
    for d, l in zip(data, lengths):
        d[0] = 0  # fix the first token so every sentence starts the same way
        idx.add_item(d[:l])
    idx.finalize(output.with_suffix(".idx"))
    assert output.exists()
    assert output.with_suffix(".idx").exists()


class OnlineBacktranslationTest(unittest.TestCase):

    tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))

    @classmethod
    def obt_task(
        cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
    ):
        dict_path = cls.tmp_dir / "dict.txt"
        if not dict_path.exists():
            dictionary = utils.dummy_dictionary(100)
            dictionary.save(str(dict_path))
        if data is not None:
            (data / "dict.txt").write_text(dict_path.read_text())
        else:
            data = cls.tmp_dir
        assert len(languages) >= 2

        kwargs = {
            "arch": "transformer",
            # --max-sentences=1 for better predictability of batches
            "max_sentences": 1,
            # Use distinct, characteristic dimensions so shape mismatches surface.
            "encoder_layers": 3,
            "encoder_embed_dim": 12,
            "encoder_ffn_embed_dim": 14,
            "encoder_attention_heads": 4,
            "decoder_layers": 3,
            "decoder_embed_dim": 12,
            "decoder_output_dim": 12,
            "decoder_ffn_embed_dim": 14,
            "decoder_attention_heads": 4,
            # Disable dropout so we have comparable tests.
"dropout": 0, "attention_dropout": 0, "activation_dropout": 0, "encoder_layerdrop": 0, } args = fairseq.options.get_args( data, task="online_backtranslation", mono_langs=",".join(languages), valid_lang_pairs=f"{languages[0]}-{languages[1]}", tokens_per_sample=256, language_mapping=language_mapping, **kwargs, ) task = obt.OnlineBackTranslationTask.setup_task(args) # we need to build the model to have the correct dictionary model = task.build_model(task.args) return task, model def tmp_path(self, test_case: str) -> Path: return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir)) def test_lang_tokens(self): task, model = self.obt_task(["en", "ro", "zh"]) assert obt._lang_token("en") in task.dictionary assert obt._lang_token("ro") in task.dictionary assert obt._lang_token("zh") in task.dictionary en_bos = obt._lang_token_index(task.common_dict, "en") assert "en" == task.common_dict[en_bos].strip("_") zh_bos = obt._lang_token_index(task.common_dict, "zh") assert "zh" == task.common_dict[zh_bos].strip("_") zh_sample = mk_sample([zh_bos, 16, 14, 12, 10]) # we expect to receive the bos token for translation assert task.get_bos_token_from_sample(zh_sample) == en_bos def test_backtranslate_sample(self): task, model = self.obt_task(["en", "ro", "zh"]) en_bos = obt._lang_token_index(task.common_dict, "en") zh_bos = obt._lang_token_index(task.common_dict, "zh") sample = mk_sample([zh_bos, 16, 14, 12, 10]) task.backtranslate_sample(sample, "zh", "en") target_zh = list(sample["target"][0]) assert target_zh == [16, 14, 12, 10] # original zh sentence generated_en = sample["net_input"]["src_tokens"][0] assert generated_en[0] == en_bos def test_train_dataset(self): data = self.tmp_path("test_train_dataset") mk_dataset(20, 10, data / "en" / "train.bin") mk_dataset(10, 10, data / "zh" / "train.bin") task, model = self.obt_task(["en", "zh"], data) task.load_dataset("train") en_bos = obt._lang_token_index(task.common_dict, "en") zh_bos = obt._lang_token_index(task.common_dict, "zh") train = task.datasets["train"] train.ordered_indices() train.prefetch([0, 19]) sample_0 = train[0] sample_19 = train[19] self.assertEqual( set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"} ) for sample in (sample_0, sample_19): self.assertEqual(sample["en-BT"]["source"][0], en_bos) # bt target isn't ready to look at. self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos) # TODO What could we check on the target side ? for i in range(10): # Zh dataset is shorter, and is wrapped around En dataset. train.prefetch([i, i + 10]) self.assertEqual( list(train[i]["zh-DENOISE"]["source"]), list(train[i + 10]["zh-DENOISE"]["source"]), ) self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos) # Sorted by increasing len self.assertLess( len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"]) ) def test_valid_dataset(self): data = self.tmp_path("test_valid_dataset") mk_dataset(10, 21, data / "valid.en-zh.en.bin") mk_dataset(10, 21, data / "valid.en-zh.zh.bin") task, model = self.obt_task(["en", "zh"], data) valid = task.load_dataset("valid") en_bos = obt._lang_token_index(task.common_dict, "en") assert valid is not None valid.prefetch(range(10)) sample_0 = valid[0] sample_9 = valid[9] self.assertEqual(sample_0["id"], 0) self.assertEqual(sample_9["id"], 9) self.assertEqual(sample_0["source"][0], en_bos) self.assertEqual(sample_9["source"][0], en_bos) # TODO: could we test the target side ? 
    def assertFnMatch(self, fn, values):
        for x, y in values.items():
            fn_x = fn(x)
            self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")

    def test_piecewise_linear_fn(self):
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
            {0: 1, 500: 0.5, 1000: 0, 2000: 0},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
            {0: 0, 500: 0.5, 1000: 1, 2000: 1},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
            {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
        )
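
# A minimal sketch of how a PiecewiseLinearFn schedule could be consumed
# (illustrative only, not part of the tests above; in the task such schedules
# typically drive loss weights over training updates). The schedule is called
# with the current update number and returns a linearly interpolated weight,
# clamped to the first/last given values outside the specified range.
def example_piecewise_schedule() -> Dict[int, float]:
    schedule = obt.PiecewiseLinearFn.from_string("0:0,1000:1")
    # Expected: {0: 0.0, 250: 0.25, 500: 0.5, 1000: 1.0, 5000: 1.0}
    return {step: schedule(step) for step in (0, 250, 500, 1000, 5000)}


if __name__ == "__main__":
    unittest.main()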