import unittest |
|
from typing import Dict, List |
|
|
|
import tests.utils as test_utils |
|
import torch |
|
from fairseq import utils |
|
from fairseq.data import ( |
|
Dictionary, |
|
LanguagePairDataset, |
|
TransformEosDataset, |
|
data_utils, |
|
noising, |
|
) |
|
|
|
|
|
class TestDataNoising(unittest.TestCase): |
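    """Tests for the word dropout, word blanking, and word shuffling noisers and
    for NoisingDataset in fairseq.data.noising."""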
|
def _get_test_data_with_bpe_cont_marker(self, append_eos=True): |
|
""" |
|
Args: |
|
append_eos: if True, each input sentence in the source tokens tensor |
|
will have an EOS appended to the end. |
|
|
|
Returns: |
|
            vocab: BPE vocab with continuation markers as suffixes to denote
                non-end-of-word tokens. This is the standard BPE format used in
                fairseq's preprocessing.
            x: input tensor containing numberized source tokens, with EOS at the
                end if append_eos is True
            src_lengths: source lengths (including EOS, if appended)
|
""" |
|
vocab = Dictionary() |
|
vocab.add_symbol("he@@") |
|
vocab.add_symbol("llo") |
|
vocab.add_symbol("how") |
|
vocab.add_symbol("are") |
|
vocab.add_symbol("y@@") |
|
vocab.add_symbol("ou") |
|
vocab.add_symbol("n@@") |
|
vocab.add_symbol("ew") |
|
vocab.add_symbol("or@@") |
|
vocab.add_symbol("k") |
|
|
|
src_tokens = [ |
|
["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"], |
|
["how", "are", "y@@", "ou"], |
|
] |
|
        x, src_lengths = self._convert_src_tokens_to_tensor(
|
vocab=vocab, src_tokens=src_tokens, append_eos=append_eos |
|
) |
|
return vocab, x, src_lengths |
|
|
|
def _get_test_data_with_bpe_end_marker(self, append_eos=True): |
|
""" |
|
Args: |
|
append_eos: if True, each input sentence in the source tokens tensor |
|
will have an EOS appended to the end. |
|
|
|
Returns: |
|
            vocab: BPE vocab with end-of-word markers as suffixes to denote
                tokens at the end of a word. This is an alternative to fairseq's
                standard preprocessing framework and is not generally supported
                within fairseq.
            x: input tensor containing numberized source tokens, with EOS at the
                end if append_eos is True
            src_lengths: source lengths (including EOS, if appended)
|
""" |
|
vocab = Dictionary() |
|
vocab.add_symbol("he") |
|
vocab.add_symbol("llo_EOW") |
|
vocab.add_symbol("how_EOW") |
|
vocab.add_symbol("are_EOW") |
|
vocab.add_symbol("y") |
|
vocab.add_symbol("ou_EOW") |
|
vocab.add_symbol("n") |
|
vocab.add_symbol("ew_EOW") |
|
vocab.add_symbol("or") |
|
vocab.add_symbol("k_EOW") |
|
|
|
src_tokens = [ |
|
["he", "llo_EOW", "n", "ew_EOW", "y", "or", "k_EOW"], |
|
["how_EOW", "are_EOW", "y", "ou_EOW"], |
|
] |
|
        x, src_lengths = self._convert_src_tokens_to_tensor(
|
vocab=vocab, src_tokens=src_tokens, append_eos=append_eos |
|
) |
|
return vocab, x, src_lengths |
|
|
|
def _get_test_data_with_word_vocab(self, append_eos=True): |
|
""" |
|
Args: |
|
append_eos: if True, each input sentence in the source tokens tensor |
|
will have an EOS appended to the end. |
|
|
|
Returns: |
|
            vocab: word vocab
            x: input tensor containing numberized source tokens, with EOS at the
                end if append_eos is True
            src_lengths: source lengths (including EOS, if appended)
|
""" |
|
vocab = Dictionary() |
|
|
|
vocab.add_symbol("hello") |
|
vocab.add_symbol("how") |
|
vocab.add_symbol("are") |
|
vocab.add_symbol("you") |
|
vocab.add_symbol("new") |
|
vocab.add_symbol("york") |
|
src_tokens = [ |
|
["hello", "new", "york", "you"], |
|
["how", "are", "you", "new", "york"], |
|
] |
|
x, src_lengths = self._convert_src_tokens_to_tensor( |
|
vocab=vocab, src_tokens=src_tokens, append_eos=append_eos |
|
) |
|
return vocab, x, src_lengths |
|
|
|
def _convert_src_tokens_to_tensor( |
|
self, vocab: Dictionary, src_tokens: List[List[str]], append_eos: bool |
|
): |
|
src_len = [len(x) for x in src_tokens] |
|
|
|
if append_eos: |
|
src_len = [length + 1 for length in src_len] |
|
|
|
x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad()) |
|
for i in range(len(src_tokens)): |
|
for j in range(len(src_tokens[i])): |
|
x[i][j] = vocab.index(src_tokens[i][j]) |
|
if append_eos: |
|
x[i][j + 1] = vocab.eos() |
|
|
|
x = x.transpose(1, 0) |
|
return x, torch.LongTensor(src_len) |
|
|
|
    def assert_eos_at_end(self, x, x_len, eos):

        """Asserts that the last token of every sentence in x is EOS."""
|
for i in range(len(x_len)): |
|
self.assertEqual( |
|
x[x_len[i] - 1][i], |
|
eos, |
|
( |
|
"Expected eos (token id {eos}) at the end of sentence {i} " |
|
"but got {other} instead" |
|
                ).format(i=i, eos=eos, other=x[x_len[i] - 1][i]),
|
) |
|
|
|
def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised): |
|
|
|
|
|
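        # Expect that only the first word (2 BPE tokens) of the first example
        # was dropped; the remaining tokens are shifted up by two positions.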
self.assertEqual(x_len[0] - 2, l_noised[0]) |
|
for i in range(l_noised[0]): |
|
self.assertEqual(x_noised[i][0], x[i + 2][0]) |
|
|
|
def test_word_dropout_with_eos(self): |
|
vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True) |
|
|
|
with data_utils.numpy_seed(1234): |
|
noising_gen = noising.WordDropout(vocab) |
|
x_noised, l_noised = noising_gen.noising(x, x_len, 0.2) |
|
self.assert_word_dropout_correct( |
|
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised |
|
) |
|
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos()) |
|
|
|
def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk): |
|
|
|
|
|
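        # Expect that only the first word (2 BPE tokens) of the first example
        # was blanked out (replaced by unk); lengths are unchanged.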
self.assertEqual(x_len[0], l_noised[0]) |
|
for i in range(l_noised[0]): |
|
if i < 2: |
|
self.assertEqual(x_noised[i][0], unk) |
|
else: |
|
self.assertEqual(x_noised[i][0], x[i][0]) |
|
|
|
def test_word_blank_with_eos(self): |
|
vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True) |
|
|
|
with data_utils.numpy_seed(1234): |
|
noising_gen = noising.WordDropout(vocab) |
|
x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk()) |
|
self.assert_word_blanking_correct( |
|
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk() |
|
) |
|
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos()) |
|
|
|
def generate_unchanged_shuffle_map(self, length): |
|
return {i: i for i in range(length)} |
|
|
|
def assert_word_shuffle_matches_expected( |
|
self, |
|
x, |
|
x_len, |
|
max_shuffle_distance: int, |
|
vocab: Dictionary, |
|
        expected_shuffle_maps: List[Dict[int, int]],
|
expect_eos_at_end: bool, |
|
bpe_end_marker=None, |
|
): |
|
""" |
|
This verifies that with a given x, x_len, max_shuffle_distance, and |
|
vocab, we get the expected shuffle result. |
|
|
|
Args: |
|
x: Tensor of shape (T x B) = (sequence_length, batch_size) |
|
x_len: Tensor of length B = batch_size |
|
max_shuffle_distance: arg to pass to noising |
|
expected_shuffle_maps: List[mapping] where mapping is a |
|
Dict[old_index, new_index], mapping x's elements from their |
|
old positions in x to their new positions in x. |
|
expect_eos_at_end: if True, check the output to make sure there is |
|
an EOS at the end. |
|
bpe_end_marker: str denoting the BPE end token. If this is not None, we |
|
set the BPE cont token to None in the noising classes. |
|
""" |
|
bpe_cont_marker = None |
|
if bpe_end_marker is None: |
|
bpe_cont_marker = "@@" |
|
|
|
with data_utils.numpy_seed(1234): |
|
word_shuffle = noising.WordShuffle( |
|
vocab, bpe_cont_marker=bpe_cont_marker, bpe_end_marker=bpe_end_marker |
|
) |
|
x_noised, l_noised = word_shuffle.noising( |
|
x, x_len, max_shuffle_distance=max_shuffle_distance |
|
) |
|
|
|
|
|
|
|
|
|
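        # Verify that each example was shuffled according to its expected shuffle map.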
        for i in range(len(expected_shuffle_maps)):
|
            shuffle_map = expected_shuffle_maps[i]
|
for k, v in shuffle_map.items(): |
|
self.assertEqual(x[k][i], x_noised[v][i]) |
|
|
|
|
|
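        # Shuffling should not change the length of any example.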
for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised): |
|
self.assertEqual(pre_shuffle_length, post_shuffle_length) |
|
if expect_eos_at_end: |
|
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos()) |
|
|
|
def test_word_shuffle_with_eos(self): |
|
vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True) |
|
|
|
|
|
|
|
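        # A max shuffle distance of 0 should leave every example unchanged.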
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
max_shuffle_distance=0, |
|
vocab=vocab, |
|
            expected_shuffle_maps=[
|
self.generate_unchanged_shuffle_map(example_len) |
|
for example_len in x_len |
|
], |
|
expect_eos_at_end=True, |
|
) |
|
|
|
|
|
|
|
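        # With max_shuffle_distance=3 and the fixed seed, the first example is
        # unchanged and the second example is reordered as mapped below.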
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
vocab=vocab, |
|
max_shuffle_distance=3, |
|
            expected_shuffle_maps=[
|
self.generate_unchanged_shuffle_map(x_len[0]), |
|
{0: 0, 1: 3, 2: 1, 3: 2}, |
|
], |
|
expect_eos_at_end=True, |
|
) |
|
|
|
def test_word_shuffle_with_eos_nonbpe(self): |
|
"""The purpose of this is to test shuffling logic with word vocabs""" |
|
vocab, x, x_len = self._get_test_data_with_word_vocab(append_eos=True) |
|
|
|
|
|
|
|
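        # A max shuffle distance of 0 should leave every example unchanged.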
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
max_shuffle_distance=0, |
|
vocab=vocab, |
|
            expected_shuffle_maps=[
|
self.generate_unchanged_shuffle_map(example_len) |
|
for example_len in x_len |
|
], |
|
expect_eos_at_end=True, |
|
) |
|
|
|
|
|
|
|
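        # With max_shuffle_distance=3 and the fixed seed, both examples are
        # reordered as mapped below.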
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
vocab=vocab, |
|
max_shuffle_distance=3, |
|
            expected_shuffle_maps=[
|
{0: 0, 1: 1, 2: 3, 3: 2}, |
|
{0: 0, 1: 2, 2: 1, 3: 3, 4: 4}, |
|
], |
|
expect_eos_at_end=True, |
|
) |
|
|
|
def test_word_shuffle_without_eos(self): |
|
"""Same result as word shuffle with eos except no EOS at end""" |
|
vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False) |
|
|
|
|
|
|
|
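        # A max shuffle distance of 0 should leave every example unchanged.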
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
max_shuffle_distance=0, |
|
vocab=vocab, |
|
            expected_shuffle_maps=[
|
self.generate_unchanged_shuffle_map(example_len) |
|
for example_len in x_len |
|
], |
|
expect_eos_at_end=False, |
|
) |
|
|
|
|
|
|
|
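        # With max_shuffle_distance=3 and the fixed seed, the first example is
        # unchanged and the second example is reordered as mapped below.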
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
vocab=vocab, |
|
max_shuffle_distance=3, |
|
            expected_shuffle_maps=[
|
self.generate_unchanged_shuffle_map(x_len[0]), |
|
{0: 0, 1: 3, 2: 1, 3: 2}, |
|
], |
|
expect_eos_at_end=False, |
|
) |
|
|
|
def test_word_shuffle_without_eos_with_bpe_end_marker(self): |
|
"""Same result as word shuffle without eos except using BPE end token""" |
|
vocab, x, x_len = self._get_test_data_with_bpe_end_marker(append_eos=False) |
|
|
|
|
|
|
|
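        # A max shuffle distance of 0 should leave every example unchanged.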
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
max_shuffle_distance=0, |
|
vocab=vocab, |
|
            expected_shuffle_maps=[
|
self.generate_unchanged_shuffle_map(example_len) |
|
for example_len in x_len |
|
], |
|
expect_eos_at_end=False, |
|
bpe_end_marker="_EOW", |
|
) |
|
|
|
|
|
|
|
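        # With max_shuffle_distance=3 and the fixed seed, the first example is
        # unchanged and the second example is reordered as mapped below.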
self.assert_word_shuffle_matches_expected( |
|
x=x, |
|
x_len=x_len, |
|
vocab=vocab, |
|
max_shuffle_distance=3, |
|
            expected_shuffle_maps=[
|
self.generate_unchanged_shuffle_map(x_len[0]), |
|
{0: 0, 1: 3, 2: 1, 3: 2}, |
|
], |
|
expect_eos_at_end=False, |
|
bpe_end_marker="_EOW", |
|
) |
|
|
|
    def assert_no_eos_at_end(self, x, x_len, eos):

        """Asserts that the last token of each sentence in x is not EOS."""
|
for i in range(len(x_len)): |
|
self.assertNotEqual( |
|
x[x_len[i] - 1][i], |
|
eos, |
|
"Expected no eos (token id {eos}) at the end of sentence {i}.".format( |
|
eos=eos, i=i |
|
), |
|
) |
|
|
|
def test_word_dropout_without_eos(self): |
|
"""Same result as word dropout with eos except no EOS at end""" |
|
vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False) |
|
|
|
with data_utils.numpy_seed(1234): |
|
noising_gen = noising.WordDropout(vocab) |
|
x_noised, l_noised = noising_gen.noising(x, x_len, 0.2) |
|
self.assert_word_dropout_correct( |
|
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised |
|
) |
|
self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos()) |
|
|
|
def test_word_blank_without_eos(self): |
|
"""Same result as word blank with eos except no EOS at end""" |
|
vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False) |
|
|
|
with data_utils.numpy_seed(1234): |
|
noising_gen = noising.WordDropout(vocab) |
|
x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk()) |
|
self.assert_word_blanking_correct( |
|
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk() |
|
) |
|
self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos()) |
|
|
|
def _get_noising_dataset_batch( |
|
self, |
|
src_tokens_no_pad, |
|
src_dict, |
|
append_eos_to_tgt=False, |
|
): |
|
""" |
|
        Constructs a NoisingDataset and the corresponding
        ``LanguagePairDataset(NoisingDataset(src), src)``. The pair dataset is
        wrapped in :class:`TransformEosDataset` so that, when
        *append_eos_to_tgt* is True, an EOS is appended to the clean source
        when it is used as the target.
|
""" |
|
src_dataset = test_utils.TestDataset(data=src_tokens_no_pad) |
|
|
|
noising_dataset = noising.NoisingDataset( |
|
src_dataset=src_dataset, |
|
src_dict=src_dict, |
|
seed=1234, |
|
max_word_shuffle_distance=3, |
|
word_dropout_prob=0.2, |
|
word_blanking_prob=0.2, |
|
noising_class=noising.UnsupervisedMTNoising, |
|
) |
|
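        # The clean (un-noised) source doubles as the target.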
tgt = src_dataset |
|
language_pair_dataset = LanguagePairDataset( |
|
src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict |
|
) |
|
language_pair_dataset = TransformEosDataset( |
|
language_pair_dataset, |
|
src_dict.eos(), |
|
append_eos_to_tgt=append_eos_to_tgt, |
|
) |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
dataset=language_pair_dataset, |
|
batch_size=2, |
|
collate_fn=language_pair_dataset.collater, |
|
) |
|
denoising_batch_result = next(iter(dataloader)) |
|
return denoising_batch_result |
|
|
|
def test_noising_dataset_with_eos(self): |
|
src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker( |
|
append_eos=True |
|
) |
|
|
|
|
|
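        # x is (T x B); transpose to (B x T) and strip padding so each example
        # is a 1D token sequence.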
src_tokens = torch.t(src_tokens) |
|
src_tokens_no_pad = [] |
|
for src_sentence in src_tokens: |
|
src_tokens_no_pad.append( |
|
utils.strip_pad(tensor=src_sentence, pad=src_dict.pad()) |
|
) |
|
denoising_batch_result = self._get_noising_dataset_batch( |
|
src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict |
|
) |
|
|
|
eos, pad = src_dict.eos(), src_dict.pad() |
|
|
|
|
|
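        # Expected noised source after collation (source is left-padded).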
expected_src = torch.LongTensor( |
|
[[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]] |
|
) |
|
|
|
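        # Expected clean target after collation (target is right-padded).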
expected_tgt = torch.LongTensor( |
|
[[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]] |
|
) |
|
generated_src = denoising_batch_result["net_input"]["src_tokens"] |
|
tgt_tokens = denoising_batch_result["target"] |
|
|
|
self.assertTensorEqual(expected_src, generated_src) |
|
self.assertTensorEqual(expected_tgt, tgt_tokens) |
|
|
|
def test_noising_dataset_without_eos(self): |
|
""" |
|
Similar to test noising dataset with eos except that we have to set |
|
*append_eos_to_tgt* to ``True``. |
|
""" |
|
|
|
src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker( |
|
append_eos=False |
|
) |
|
|
|
|
|
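        # x is (T x B); transpose to (B x T) and strip padding so each example
        # is a 1D token sequence.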
src_tokens = torch.t(src_tokens) |
|
src_tokens_no_pad = [] |
|
for src_sentence in src_tokens: |
|
src_tokens_no_pad.append( |
|
utils.strip_pad(tensor=src_sentence, pad=src_dict.pad()) |
|
) |
|
denoising_batch_result = self._get_noising_dataset_batch( |
|
src_tokens_no_pad=src_tokens_no_pad, |
|
src_dict=src_dict, |
|
append_eos_to_tgt=True, |
|
) |
|
|
|
eos, pad = src_dict.eos(), src_dict.pad() |
|
|
|
|
|
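        # Expected noised source; no EOS since the data was built with append_eos=False.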
expected_src = torch.LongTensor( |
|
[[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]] |
|
) |
|
|
|
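        # Expected clean target with EOS appended by TransformEosDataset
        # (append_eos_to_tgt=True).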
expected_tgt = torch.LongTensor( |
|
[[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]] |
|
) |
|
|
|
generated_src = denoising_batch_result["net_input"]["src_tokens"] |
|
tgt_tokens = denoising_batch_result["target"] |
|
|
|
self.assertTensorEqual(expected_src, generated_src) |
|
self.assertTensorEqual(expected_tgt, tgt_tokens) |
|
|
|
def assertTensorEqual(self, t1, t2): |
|
self.assertEqual(t1.size(), t2.size(), "size mismatch") |
|
self.assertEqual(t1.ne(t2).long().sum(), 0) |
|
|
|
|
|
if __name__ == "__main__": |
|
unittest.main() |
|
|