# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
import string
import tempfile
import unittest

import torch
from fairseq import tokenizer
from fairseq.data import Dictionary


class TestDictionary(unittest.TestCase):
    """Unit tests for ``fairseq.data.Dictionary``.

    The expected indices below assume that a freshly constructed
    ``Dictionary`` reserves indices 0-3 for its special symbols, so the first
    user symbol lands on index 4 (this is what every assertion in this class
    relies on).
    """

    def test_finalize(self):
        """Check symbol ids before ``finalize()``, after it, and after a reload.

        "D" occurs in every line and "A" in only one, so the test expects
        insertion-order ids (A=4 .. D=7) before ``finalize()`` and ids ordered
        by descending frequency (D=4 .. A=7) afterwards. Every encoded line is
        expected to end with index 2.
        """
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
        # Expected ids while the dictionary is still in insertion order.
        ref_ids1 = list(
            map(
                torch.IntTensor,
                [
                    [4, 5, 6, 7, 2],
                    [5, 6, 7, 2],
                    [6, 7, 2],
                    [7, 2],
                ],
            )
        )
        # Expected ids once the dictionary has been finalized.
        ref_ids2 = list(
            map(
                torch.IntTensor,
                [
                    [7, 6, 5, 4, 2],
                    [6, 5, 4, 2],
                    [5, 4, 2],
                    [4, 2],
                ],
            )
        )

        # build dictionary
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            """Encode every line of ``txt`` without adding new symbols."""
            return [
                dictionary.encode_line(line, add_if_not_exist=False)
                for line in txt
            ]

        def assertMatch(ids, ref_ids):
            """Assert that two lists of tensors are element-wise identical."""
            # zip() stops at the shorter input, so compare the lengths first
            # to avoid silently skipping unmatched entries.
            self.assertEqual(len(ids), len(ref_ids))
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)

    def test_overwrite(self):
        """Special symbols flagged ``#fairseq:overwrite`` may be redefined.

        After loading, the redefined symbols are expected at the new indices
        4-6, ``<pad>`` stays at 1, and an unknown word ("foo") still maps to
        index 3.
        """
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index("<pad>"), 1)
        self.assertEqual(d.index("foo"), 3)
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        """Redefining special symbols without the overwrite flag must fail."""
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999\n"
            "<s> 999\n"
            "</s> 999\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        """A single space can be stored as a regular dictionary symbol."""
        # for example, character models treat space as a symbol
        # The first line is the space symbol itself, then the separating
        # space, then the count -- hence two consecutive spaces.
        dict_file = io.StringIO("  999\n" "a 999\n" "b 999\n")
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)

    def test_add_file_to_dict(self):
        """Counts collected from a text file match what was written to it.

        Each ASCII letter is written ``per_line`` times on each of
        ``num_lines`` lines, with ``per_line`` growing by 5 for every letter so
        that all letters end up with distinct counts.
        """
        counts = {}
        num_lines = 100
        per_line = 10
        with tempfile.TemporaryDirectory("test_sampling") as data_dir:
            filename = os.path.join(data_dir, "dummy.txt")
            with open(filename, "w", encoding="utf-8") as data:
                for c in string.ascii_letters:
                    line = f"{c} " * per_line
                    for _ in range(num_lines):
                        data.write(f"{line}\n")
                    counts[c] = per_line * num_lines
                    per_line += 5

            dictionary = Dictionary()
            # The trailing 10 is presumably the number of worker processes --
            # TODO confirm against Dictionary.add_file_to_dictionary.
            Dictionary.add_file_to_dictionary(
                filename, dictionary, tokenizer.tokenize_line, 10
            )
            dictionary.finalize(threshold=0, nwords=-1, padding_factor=8)

            for c in string.ascii_letters:
                count = dictionary.get_count(dictionary.index(c))
                self.assertEqual(
                    counts[c], count, f"{c} count is {count} but should be {counts[c]}"
                )


if __name__ == "__main__":
    unittest.main()