# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse import tempfile import unittest import torch from fairseq.data.dictionary import Dictionary from fairseq.models.lstm import LSTMModel from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() if getattr(self.args, "ctc", False): self.dictionary.add_symbol("") self.src_dict = self.dictionary self.tgt_dict = self.dictionary @property def source_dictionary(self): return self.src_dict @property def target_dictionary(self): return self.dictionary def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE): dummy_dict = Dictionary() # add dummy symbol to satisfy vocab size for id, _ in enumerate(range(vocab_size)): dummy_dict.add_symbol("{}".format(id), 1000) return dummy_dict def get_dummy_task_and_parser(): """ to build a fariseq model, we need some dummy parse and task. This function is used to create dummy task and parser to faciliate model/criterion test Note: we use FbSpeechRecognitionTask as the dummy task. You may want to use other task by providing another function """ parser = argparse.ArgumentParser( description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS ) DummyTask.add_args(parser) args = parser.parse_args([]) task = DummyTask.setup_task(args) return task, parser class TestJitLSTMModel(unittest.TestCase): def _test_save_and_load(self, scripted_module): with tempfile.NamedTemporaryFile() as f: scripted_module.save(f.name) torch.jit.load(f.name) def assertTensorEqual(self, t1, t2): t1 = t1[~torch.isnan(t1)] # can cause size mismatch errors if there are NaNs t2 = t2[~torch.isnan(t2)] self.assertEqual(t1.size(), t2.size(), "size mismatch") self.assertEqual(t1.ne(t2).long().sum(), 0) def test_jit_and_export_lstm(self): task, parser = get_dummy_task_and_parser() LSTMModel.add_args(parser) args = parser.parse_args([]) args.criterion = "" model = LSTMModel.build_model(args, task) scripted_model = torch.jit.script(model) self._test_save_and_load(scripted_model) def test_assert_jit_vs_nonjit_(self): task, parser = get_dummy_task_and_parser() LSTMModel.add_args(parser) args = parser.parse_args([]) args.criterion = "" model = LSTMModel.build_model(args, task) model.eval() scripted_model = torch.jit.script(model) scripted_model.eval() idx = len(task.source_dictionary) iter = 100 # Inject random input and check output seq_len_tensor = torch.randint(1, 10, (iter,)) num_samples_tensor = torch.randint(1, 10, (iter,)) for i in range(iter): seq_len = seq_len_tensor[i] num_samples = num_samples_tensor[i] src_token = (torch.randint(0, idx, (num_samples, seq_len)),) src_lengths = torch.randint(1, seq_len + 1, (num_samples,)) src_lengths, _ = torch.sort(src_lengths, descending=True) # Force the first sample to have seq_len src_lengths[0] = seq_len prev_output_token = (torch.randint(0, idx, (num_samples, 1)),) result = model(src_token[0], src_lengths, prev_output_token[0], None) scripted_result = scripted_model( src_token[0], src_lengths, prev_output_token[0], None ) self.assertTensorEqual(result[0], scripted_result[0]) self.assertTensorEqual(result[1], scripted_result[1]) if __name__ == "__main__": unittest.main()