#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import tempfile
import unittest

import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import TransformerModel
from fairseq.modules import multihead_attention, sinusoidal_positional_embedding
from fairseq.tasks.fairseq_task import LegacyFairseqTask


DEFAULT_TEST_VOCAB_SIZE = 100

# torch.__version__ is a string, possibly with a local suffix such as
# "+cu117"; compare numeric components so that e.g. 1.10 is not treated
# as older than 1.6 the way a plain string comparison would be.
TORCH_VERSION = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])


class DummyTask(LegacyFairseqTask):
    """Minimal task exposing a shared source/target dictionary for tests."""

    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        self.src_dict = self.dictionary
        self.tgt_dict = self.dictionary

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.dictionary


def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    dummy_dict = Dictionary()
    # Add dummy symbols (each with an arbitrary count of 1000) to reach
    # the requested vocabulary size.
    for symbol_id in range(vocab_size):
        dummy_dict.add_symbol("{}".format(symbol_id), 1000)
    return dummy_dict


def get_dummy_task_and_parser():
    """
    Return a dummy task and argument parser, which can be used to
    create a model/criterion.
    """
    parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
    )
    DummyTask.add_args(parser)
    args = parser.parse_args([])
    task = DummyTask.setup_task(args)
    return task, parser


def _test_save_and_load(scripted_module):
    # Round-trip the scripted module through disk to verify that it can be
    # serialized and deserialized by TorchScript.
    with tempfile.NamedTemporaryFile() as f:
        scripted_module.save(f.name)
        torch.jit.load(f.name)


class TestExportModels(unittest.TestCase):
    def test_export_multihead_attention(self):
        module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        scripted = torch.jit.script(module)
        _test_save_and_load(scripted)

    def test_incremental_state_multihead_attention(self):
        module1 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        module1 = torch.jit.script(module1)
        module2 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
        module2 = torch.jit.script(module2)

        # Each module namespaces its entries in the shared incremental-state
        # dict, so writes under the same key must not collide across instances.
        state = {}
        state = module1.set_incremental_state(state, "key", {"a": torch.tensor([1])})
        state = module2.set_incremental_state(state, "key", {"a": torch.tensor([2])})
        v1 = module1.get_incremental_state(state, "key")["a"]
        v2 = module2.get_incremental_state(state, "key")["a"]

        self.assertEqual(v1, 1)
        self.assertEqual(v2, 2)

    def test_positional_embedding(self):
        module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding(
            embedding_dim=8, padding_idx=1
        )
        scripted = torch.jit.script(module)
        _test_save_and_load(scripted)

    @unittest.skipIf(
        TORCH_VERSION < (1, 6), "Targeting OSS scriptability for the 1.6 release"
    )
    def test_export_transformer(self):
        task, parser = get_dummy_task_and_parser()
        TransformerModel.add_args(parser)
        args = parser.parse_args([])
        model = TransformerModel.build_model(args, task)
        scripted = torch.jit.script(model)
        _test_save_and_load(scripted)

    @unittest.skipIf(
        TORCH_VERSION < (1, 6), "Targeting OSS scriptability for the 1.6 release"
    )
    def test_export_transformer_no_token_pos_emb(self):
        task, parser = get_dummy_task_and_parser()
        TransformerModel.add_args(parser)
        args = parser.parse_args([])
        args.no_token_positional_embeddings = True
        model = TransformerModel.build_model(args, task)
        scripted = torch.jit.script(model)
        _test_save_and_load(scripted)


if __name__ == "__main__":
    unittest.main()
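
# Usage note: these tests can be run directly (the unittest.main() hook above
# handles discovery) or through pytest. The path below is a guess at where
# this module lives in a fairseq checkout; adjust it to your tree:
#   python tests/test_export.py
#   python -m pytest tests/test_export.py -q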