Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import unittest | |
from typing import Dict, List | |
import tests.utils as test_utils | |
import torch | |
from fairseq import utils | |
from fairseq.data import ( | |
Dictionary, | |
LanguagePairDataset, | |
TransformEosDataset, | |
data_utils, | |
noising, | |
) | |
class TestDataNoising(unittest.TestCase):
    def _get_test_data_with_bpe_cont_marker(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.
        Returns:
            vocab: BPE vocab with continuation markers as suffixes to denote
                non-end of word tokens. This is the standard BPE format used in
                fairseq's preprocessing.
            x: (T x B) input tensor containing numberized source tokens, with
                EOS at the end if append_eos is true.
            src_lengths: source lengths (including EOS when appended).
        """
        vocab = Dictionary()
        # NOTE: symbol order matters; the noising-dataset tests assert literal
        # token ids (presumably assigned in insertion order).
        vocab.add_symbol("he@@")
        vocab.add_symbol("llo")
        vocab.add_symbol("how")
        vocab.add_symbol("are")
        vocab.add_symbol("y@@")
        vocab.add_symbol("ou")
        vocab.add_symbol("n@@")
        vocab.add_symbol("ew")
        vocab.add_symbol("or@@")
        vocab.add_symbol("k")
        src_tokens = [
            ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
            ["how", "are", "y@@", "ou"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _get_test_data_with_bpe_end_marker(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.
        Returns:
            vocab: BPE vocab with end-of-word markers as suffixes to denote
                tokens at the end of a word. This is an alternative to fairseq's
                standard preprocessing framework and is not generally supported
                within fairseq.
            x: (T x B) input tensor containing numberized source tokens, with
                EOS at the end if append_eos is true.
            src_lengths: source lengths (including EOS when appended).
        """
        vocab = Dictionary()
        vocab.add_symbol("he")
        vocab.add_symbol("llo_EOW")
        vocab.add_symbol("how_EOW")
        vocab.add_symbol("are_EOW")
        vocab.add_symbol("y")
        vocab.add_symbol("ou_EOW")
        vocab.add_symbol("n")
        vocab.add_symbol("ew_EOW")
        vocab.add_symbol("or")
        vocab.add_symbol("k_EOW")
        src_tokens = [
            ["he", "llo_EOW", "n", "ew_EOW", "y", "or", "k_EOW"],
            ["how_EOW", "are_EOW", "y", "ou_EOW"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _get_test_data_with_word_vocab(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.
        Returns:
            vocab: word vocab
            x: (T x B) input tensor containing numberized source tokens, with
                EOS at the end if append_eos is true.
            src_lengths: source lengths (including EOS when appended).
        """
        vocab = Dictionary()
        vocab.add_symbol("hello")
        vocab.add_symbol("how")
        vocab.add_symbol("are")
        vocab.add_symbol("you")
        vocab.add_symbol("new")
        vocab.add_symbol("york")
        src_tokens = [
            ["hello", "new", "york", "you"],
            ["how", "are", "you", "new", "york"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _convert_src_tokens_to_tensor(
        self, vocab: Dictionary, src_tokens: List[List[str]], append_eos: bool
    ):
        """
        Numberize *src_tokens* with *vocab* into a padded, time-major tensor.

        Args:
            vocab: dictionary used to map each token string to its id.
            src_tokens: one list of token strings per sentence.
            append_eos: if True, append ``vocab.eos()`` after each sentence
                and count it in that sentence's length.
        Returns:
            x: LongTensor of shape (max_len, num_sentences), filled with
                ``vocab.pad()`` past the end of each sentence.
            src_lengths: LongTensor with one length per sentence.
        """
        src_len = [len(sentence) for sentence in src_tokens]
        # If we have to append EOS, we include EOS in counting src length
        if append_eos:
            src_len = [length + 1 for length in src_len]
        x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
        for i, sentence in enumerate(src_tokens):
            for j, token in enumerate(sentence):
                x[i][j] = vocab.index(token)
            if append_eos:
                # Index by the sentence length instead of the leaked inner-loop
                # variable so an empty sentence still gets EOS at position 0.
                x[i][len(sentence)] = vocab.eos()
        # Built as (B x T); callers expect time-major (T x B).
        x = x.transpose(1, 0)
        return x, torch.LongTensor(src_len)

    def assert_eos_at_end(self, x, x_len, eos):
        """Asserts the last token of every sentence in (T x B) tensor x is EOS."""
        for i, length in enumerate(x_len):
            last_token = x[length - 1][i]
            self.assertEqual(
                last_token,
                eos,
                (
                    "Expected eos (token id {eos}) at the end of sentence {i} "
                    "but got {other} instead"
                ).format(i=i, eos=eos, other=last_token),
            )

    def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
        # Expect only the first word (2 bpe tokens) of the first example
        # was dropped out
        self.assertEqual(x_len[0] - 2, l_noised[0])
        for i in range(l_noised[0]):
            self.assertEqual(x_noised[i][0], x[i + 2][0])

    def test_word_dropout_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
        # Expect only the first word (2 bpe tokens) of the first example
        # was blanked out
        self.assertEqual(x_len[0], l_noised[0])
        for i in range(l_noised[0]):
            if i < 2:
                self.assertEqual(x_noised[i][0], unk)
            else:
                self.assertEqual(x_noised[i][0], x[i][0])

    def test_word_blank_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def generate_unchanged_shuffle_map(self, length):
        """Identity mapping {position: position} for a sequence of *length*."""
        return {i: i for i in range(length)}

    def assert_word_shuffle_matches_expected(
        self,
        x,
        x_len,
        max_shuffle_distance: int,
        vocab: Dictionary,
        expected_shufle_maps: List[Dict[int, int]],
        expect_eos_at_end: bool,
        bpe_end_marker=None,
    ):
        """
        This verifies that with a given x, x_len, max_shuffle_distance, and
        vocab, we get the expected shuffle result.

        Args:
            x: Tensor of shape (T x B) = (sequence_length, batch_size)
            x_len: Tensor of length B = batch_size
            max_shuffle_distance: arg to pass to noising
            vocab: dictionary passed to ``noising.WordShuffle``; its EOS id is
                used for the end-of-sentence check.
            expected_shufle_maps: List[mapping] where mapping is a
                Dict[old_index, new_index], mapping x's elements from their
                old positions in x to their new positions in x. One mapping
                per example in the batch.
            expect_eos_at_end: if True, check the output to make sure there is
                an EOS at the end.
            bpe_end_marker: str denoting the BPE end token. If this is None,
                the "@@" continuation marker is used; otherwise the BPE cont
                token is set to None in the noising classes.
        """
        bpe_cont_marker = "@@" if bpe_end_marker is None else None
        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(
                vocab, bpe_cont_marker=bpe_cont_marker, bpe_end_marker=bpe_end_marker
            )
            x_noised, l_noised = word_shuffle.noising(
                x, x_len, max_shuffle_distance=max_shuffle_distance
            )
        # For every example, we have a different expected shuffle map. We check
        # that each example is shuffled as expected according to each
        # corresponding shuffle map.
        for i, shuffle_map in enumerate(expected_shufle_maps):
            for old_pos, new_pos in shuffle_map.items():
                self.assertEqual(x[old_pos][i], x_noised[new_pos][i])
        # Shuffling should not affect the length of each example
        for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
            self.assertEqual(pre_shuffle_length, post_shuffle_length)
        if expect_eos_at_end:
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_shuffle_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )
        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_with_eos_nonbpe(self):
        """The purpose of this is to test shuffling logic with word vocabs"""
        vocab, x, x_len = self._get_test_data_with_word_vocab(append_eos=True)
        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )
        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                {0: 0, 1: 1, 2: 3, 3: 2},
                {0: 0, 1: 2, 2: 1, 3: 3, 4: 4},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_without_eos(self):
        """Same result as word shuffle with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=False,
        )
        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=False,
        )

    def test_word_shuffle_without_eos_with_bpe_end_marker(self):
        """Same result as word shuffle without eos except using BPE end token"""
        vocab, x, x_len = self._get_test_data_with_bpe_end_marker(append_eos=False)
        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=False,
            bpe_end_marker="_EOW",
        )
        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=False,
            bpe_end_marker="_EOW",
        )

    def assert_no_eos_at_end(self, x, x_len, eos):
        """Asserts the last token of each sentence in (T x B) tensor x is not EOS."""
        for i, length in enumerate(x_len):
            self.assertNotEqual(
                x[length - 1][i],
                eos,
                "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
                    eos=eos, i=i
                ),
            )

    def test_word_dropout_without_eos(self):
        """Same result as word dropout with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_blank_without_eos(self):
        """Same result as word blank with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def _get_noising_dataset_batch(
        self,
        src_tokens_no_pad,
        src_dict,
        append_eos_to_tgt=False,
    ):
        """
        Constructs a NoisingDataset and the corresponding
        ``LanguagePairDataset(NoisingDataset(src), src)``, always wrapped in
        :class:`TransformEosDataset`. *append_eos_to_tgt* is forwarded to that
        wrapper so that, when True, EOS is appended to the clean source when
        using it as the target.

        Returns:
            The first collated batch (batch_size=2) from a DataLoader over the
            wrapped dataset.
        """
        src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)
        noising_dataset = noising.NoisingDataset(
            src_dataset=src_dataset,
            src_dict=src_dict,
            seed=1234,
            max_word_shuffle_distance=3,
            word_dropout_prob=0.2,
            word_blanking_prob=0.2,
            noising_class=noising.UnsupervisedMTNoising,
        )
        # The clean source doubles as the target.
        language_pair_dataset = LanguagePairDataset(
            src=noising_dataset, tgt=src_dataset, src_sizes=None, src_dict=src_dict
        )
        language_pair_dataset = TransformEosDataset(
            language_pair_dataset,
            src_dict.eos(),
            append_eos_to_tgt=append_eos_to_tgt,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset=language_pair_dataset,
            batch_size=2,
            collate_fn=language_pair_dataset.collater,
        )
        return next(iter(dataloader))

    def _to_unpadded_sentences(self, src_tokens, pad):
        """Splits a padded (T x B) tensor into B 1-D tensors with padding stripped."""
        return [
            utils.strip_pad(tensor=src_sentence, pad=pad)
            for src_sentence in torch.t(src_tokens)
        ]

    def test_noising_dataset_with_eos(self):
        src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
            append_eos=True
        )
        # Format data for src_dataset
        src_tokens_no_pad = self._to_unpadded_sentences(src_tokens, src_dict.pad())
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict
        )
        eos, pad = src_dict.eos(), src_dict.pad()
        # Generated noisy source as source (left-padded)
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )
        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]
        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_noising_dataset_without_eos(self):
        """
        Similar to test noising dataset with eos except that we have to set
        *append_eos_to_tgt* to ``True``.
        """
        src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
            append_eos=False
        )
        # Format data for src_dataset
        src_tokens_no_pad = self._to_unpadded_sentences(src_tokens, src_dict.pad())
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad,
            src_dict=src_dict,
            append_eos_to_tgt=True,
        )
        eos, pad = src_dict.eos(), src_dict.pad()
        # Generated noisy source as source (left-padded)
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )
        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]
        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def assertTensorEqual(self, t1, t2):
        """Asserts two tensors have the same size and identical elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)
if __name__ == "__main__":
    # Run the test cases in this module when executed as a script.
    unittest.main()