# Source: OFA-OCR / fairseq / tests / test_noising.py
# (uploaded by JustinLin610 in commit ee21b96, "first commit"; 19.8 kB)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Dict, List
import tests.utils as test_utils
import torch
from fairseq import utils
from fairseq.data import (
Dictionary,
LanguagePairDataset,
TransformEosDataset,
data_utils,
noising,
)
class TestDataNoising(unittest.TestCase):
    """Tests for ``fairseq.data.noising``.

    Covers the WordDropout (dropout and blanking) and WordShuffle noisers, and
    ``NoisingDataset`` wrapped in a ``LanguagePairDataset``.
    """

    def _get_test_data_with_bpe_cont_marker(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.

        Returns:
            vocab: BPE vocab with continuation markers as suffixes to denote
                non-end of word tokens. This is the standard BPE format used in
                fairseq's preprocessing.
            x: input tensor of shape (T x B) containing numberized source
                tokens, with EOS at the end if append_eos is true
            src_lengths: source lengths
        """
        vocab = Dictionary()
        # The expected tensors in the NoisingDataset tests assume the symbols
        # are added in exactly this order.
        for symbol in (
            "he@@",
            "llo",
            "how",
            "are",
            "y@@",
            "ou",
            "n@@",
            "ew",
            "or@@",
            "k",
        ):
            vocab.add_symbol(symbol)

        src_tokens = [
            ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
            ["how", "are", "y@@", "ou"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _get_test_data_with_bpe_end_marker(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.

        Returns:
            vocab: BPE vocab with end-of-word markers as suffixes to denote
                tokens at the end of a word. This is an alternative to fairseq's
                standard preprocessing framework and is not generally supported
                within fairseq.
            x: input tensor of shape (T x B) containing numberized source
                tokens, with EOS at the end if append_eos is true
            src_lengths: source lengths
        """
        vocab = Dictionary()
        for symbol in (
            "he",
            "llo_EOW",
            "how_EOW",
            "are_EOW",
            "y",
            "ou_EOW",
            "n",
            "ew_EOW",
            "or",
            "k_EOW",
        ):
            vocab.add_symbol(symbol)

        src_tokens = [
            ["he", "llo_EOW", "n", "ew_EOW", "y", "or", "k_EOW"],
            ["how_EOW", "are_EOW", "y", "ou_EOW"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _get_test_data_with_word_vocab(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.

        Returns:
            vocab: word vocab
            x: input tensor of shape (T x B) containing numberized source
                tokens, with EOS at the end if append_eos is true
            src_lengths: source lengths
        """
        vocab = Dictionary()
        for symbol in ("hello", "how", "are", "you", "new", "york"):
            vocab.add_symbol(symbol)

        src_tokens = [
            ["hello", "new", "york", "you"],
            ["how", "are", "you", "new", "york"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _convert_src_tokens_to_tensor(
        self, vocab: Dictionary, src_tokens: List[List[str]], append_eos: bool
    ):
        """Numberize *src_tokens* into a padded (T x B) LongTensor.

        Args:
            vocab: dictionary used to map each token string to its id.
            src_tokens: one list of token strings per sentence.
            append_eos: if True, ``vocab.eos()`` is written after each sentence
                and counted in its length.

        Returns:
            x: LongTensor of shape (max_len x num_sentences). Positions past a
                sentence's end hold ``vocab.pad()``.
            src_lengths: LongTensor holding each sentence's length.
        """
        src_len = [len(sentence) for sentence in src_tokens]
        # If we have to append EOS, we include EOS in counting src length
        if append_eos:
            src_len = [length + 1 for length in src_len]

        x = torch.LongTensor(len(src_tokens), max(src_len, default=0)).fill_(
            vocab.pad()
        )
        for i, sentence in enumerate(src_tokens):
            for j, token in enumerate(sentence):
                x[i][j] = vocab.index(token)
            if append_eos:
                # Index by the sentence length rather than the inner loop
                # variable, which is unbound (or stale) for an empty sentence.
                x[i][len(sentence)] = vocab.eos()

        # Built as (B x T); the noising tests expect (T x B).
        x = x.transpose(1, 0)
        return x, torch.LongTensor(src_len)

    def assert_eos_at_end(self, x, x_len, eos):
        """Asserts that the last token of every sentence in x is EOS.

        x is (T x B); x_len holds the length of each of the B sentences.
        """
        for i, length in enumerate(x_len):
            last_token = x[length - 1][i]
            self.assertEqual(
                last_token,
                eos,
                (
                    "Expected eos (token id {eos}) at the end of sentence {i} "
                    "but got {other} instead"
                ).format(i=i, eos=eos, other=int(last_token)),
            )

    def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
        # Expect only the first word (2 bpe tokens) of the first example
        # was dropped out
        self.assertEqual(x_len[0] - 2, l_noised[0])
        for i in range(l_noised[0]):
            self.assertEqual(x_noised[i][0], x[i + 2][0])

    def test_word_dropout_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
        # Expect only the first word (2 bpe tokens) of the first example
        # was blanked out
        self.assertEqual(x_len[0], l_noised[0])
        for i in range(l_noised[0]):
            if i < 2:
                self.assertEqual(x_noised[i][0], unk)
            else:
                self.assertEqual(x_noised[i][0], x[i][0])

    def test_word_blank_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def generate_unchanged_shuffle_map(self, length):
        """Return the identity mapping {i: i} for positions 0..length-1."""
        return {i: i for i in range(length)}

    def assert_word_shuffle_matches_expected(
        self,
        x,
        x_len,
        max_shuffle_distance: int,
        vocab: Dictionary,
        expected_shufle_maps: List[Dict[int, int]],
        expect_eos_at_end: bool,
        bpe_end_marker=None,
    ):
        """
        This verifies that with a given x, x_len, max_shuffle_distance, and
        vocab, we get the expected shuffle result.

        Args:
            x: Tensor of shape (T x B) = (sequence_length, batch_size)
            x_len: Tensor of length B = batch_size
            max_shuffle_distance: arg to pass to noising
            expected_shufle_maps: List[mapping] where mapping is a
                Dict[old_index, new_index], mapping x's elements from their
                old positions in x to their new positions in x. (The
                parameter name's spelling is kept for backward compatibility
                with keyword callers.)
            expect_eos_at_end: if True, check the output to make sure there is
                an EOS at the end.
            bpe_end_marker: str denoting the BPE end token. If this is not None, we
                set the BPE cont token to None in the noising classes.
        """
        bpe_cont_marker = "@@" if bpe_end_marker is None else None

        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(
                vocab, bpe_cont_marker=bpe_cont_marker, bpe_end_marker=bpe_end_marker
            )
            x_noised, l_noised = word_shuffle.noising(
                x, x_len, max_shuffle_distance=max_shuffle_distance
            )

        # For every example, we have a different expected shuffle map. We check
        # that each example is shuffled as expected according to each
        # corresponding shuffle map.
        for i, shuffle_map in enumerate(expected_shufle_maps):
            for old_pos, new_pos in shuffle_map.items():
                self.assertEqual(x[old_pos][i], x_noised[new_pos][i])

        # Shuffling should not affect the length of each example. Check the
        # batch sizes first so that zip() cannot silently truncate.
        self.assertEqual(len(x_len), len(l_noised))
        for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
            self.assertEqual(pre_shuffle_length, post_shuffle_length)
        if expect_eos_at_end:
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_shuffle_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_with_eos_nonbpe(self):
        """The purpose of this is to test shuffling logic with word vocabs"""
        vocab, x, x_len = self._get_test_data_with_word_vocab(append_eos=True)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                {0: 0, 1: 1, 2: 3, 3: 2},
                {0: 0, 1: 2, 2: 1, 3: 3, 4: 4},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_without_eos(self):
        """Same result as word shuffle with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=False,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=False,
        )

    def test_word_shuffle_without_eos_with_bpe_end_marker(self):
        """Same result as word shuffle without eos except using BPE end token"""
        vocab, x, x_len = self._get_test_data_with_bpe_end_marker(append_eos=False)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=False,
            bpe_end_marker="_EOW",
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=False,
            bpe_end_marker="_EOW",
        )

    def assert_no_eos_at_end(self, x, x_len, eos):
        """Asserts that the last token of each sentence in x is not EOS.

        x is (T x B); x_len holds the length of each of the B sentences.
        """
        for i, length in enumerate(x_len):
            self.assertNotEqual(
                x[length - 1][i],
                eos,
                "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
                    eos=eos, i=i
                ),
            )

    def test_word_dropout_without_eos(self):
        """Same result as word dropout with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_blank_without_eos(self):
        """Same result as word blank with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def _strip_pad_from_batch(self, src_tokens, pad):
        """Split a padded (T x B) tensor into a list of B unpadded 1-D tensors."""
        return [
            utils.strip_pad(tensor=src_sentence, pad=pad)
            for src_sentence in torch.t(src_tokens)
        ]

    def _get_noising_dataset_batch(
        self,
        src_tokens_no_pad,
        src_dict,
        append_eos_to_tgt=False,
    ):
        """
        Constructs a NoisingDataset and the corresponding
        ``LanguagePairDataset(NoisingDataset(src), src)``, wraps it in
        :class:`TransformEosDataset`, and returns the first batch (of size 2).
        If *append_eos_to_tgt* is True, the TransformEosDataset appends EOS to
        the clean source when it is used as the target.
        """
        src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)
        noising_dataset = noising.NoisingDataset(
            src_dataset=src_dataset,
            src_dict=src_dict,
            seed=1234,
            max_word_shuffle_distance=3,
            word_dropout_prob=0.2,
            word_blanking_prob=0.2,
            noising_class=noising.UnsupervisedMTNoising,
        )
        tgt = src_dataset
        language_pair_dataset = LanguagePairDataset(
            src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
        )
        language_pair_dataset = TransformEosDataset(
            language_pair_dataset,
            src_dict.eos(),
            append_eos_to_tgt=append_eos_to_tgt,
        )

        dataloader = torch.utils.data.DataLoader(
            dataset=language_pair_dataset,
            batch_size=2,
            collate_fn=language_pair_dataset.collater,
        )
        denoising_batch_result = next(iter(dataloader))
        return denoising_batch_result

    def test_noising_dataset_with_eos(self):
        src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
            append_eos=True
        )

        # Format data for src_dataset
        src_tokens_no_pad = self._strip_pad_from_batch(src_tokens, src_dict.pad())

        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )
        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_noising_dataset_without_eos(self):
        """
        Similar to test noising dataset with eos except that we have to set
        *append_eos_to_tgt* to ``True``.
        """
        src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
            append_eos=False
        )

        # Format data for src_dataset
        src_tokens_no_pad = self._strip_pad_from_batch(src_tokens, src_dict.pad())

        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad,
            src_dict=src_dict,
            append_eos_to_tgt=True,
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )

        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def assertTensorEqual(self, t1, t2):
        """Assert two tensors have the same size and identical elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)
if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes
    # defined in this module (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")