#!/usr/bin/env python3

# import models/encoder/decoder to be tested
from examples.speech_recognition.models.vggtransformer import (
    TransformerDecoder,
    VGGTransformerEncoder,
    VGGTransformerModel,
    vggtransformer_1,
    vggtransformer_2,
    vggtransformer_base,
)

# import base test class
from .asr_test_base import (
    DEFAULT_TEST_VOCAB_SIZE,
    TestFairseqDecoderBase,
    TestFairseqEncoderBase,
    TestFairseqEncoderDecoderModelBase,
    get_dummy_dictionary,
    get_dummy_encoder_output,
    get_dummy_input,
)


class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder tests for the ``vggtransformer_1`` architecture."""

    def setUp(self):
        def override_config(args):
            """
            vggtransformer_1 uses 14 layers of transformer, which is too
            expensive for testing purposes. For a fast turn-around test,
            reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        # presumably applied in list order, so override_config runs after the
        # architecture defaults and its encoder config wins -- verify in
        # asr_test_base.setUpModel
        extra_args_setter = [vggtransformer_1, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))


class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder tests for the ``vggtransformer_2`` architecture."""

    def setUp(self):
        def override_config(args):
            """
            vggtransformer_2 uses 16 layers of transformer, which is too
            expensive for testing purposes. For a fast turn-around test,
            reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        # presumably applied in list order (see VGGTransformerModelTest_mid)
        extra_args_setter = [vggtransformer_2, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))


class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder tests for the ``vggtransformer_base`` architecture."""

    def setUp(self):
        def override_config(args):
            """
            vggtransformer_base uses 12 layers of transformer, which is too
            expensive for testing purposes. For a fast turn-around test,
            reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        # presumably applied in list order (see VGGTransformerModelTest_mid)
        extra_args_setter = [vggtransformer_base, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))


class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    """Forward-pass tests for ``VGGTransformerEncoder`` under several
    attention-context / sampling configurations."""

    # (description, extra VGGTransformerEncoder kwargs). transformer_context is
    # presumably (left, right) with -1 meaning unlimited, judging by the
    # descriptions -- confirm against VGGTransformerEncoder.
    _ENCODER_CONFIGS = (
        ("test standard vggtransformer", {}),
        (
            "test vggtransformer with limited right context",
            {"transformer_context": (-1, 5)},
        ),
        (
            "test vggtransformer with limited left context",
            {"transformer_context": (5, -1)},
        ),
        (
            "test vggtransformer with limited right context and sampling",
            {"transformer_context": (-1, 12), "transformer_sampling": (2, 2)},
        ),
        (
            "test vggtransformer with windowed context and sampling",
            {"transformer_context": (12, 12), "transformer_sampling": (2, 2)},
        ),
    )

    def setUp(self):
        super().setUp()
        # D=80 matches input_feat_per_channel=80 used for every encoder below
        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        """Run the base-class forward test once per encoder configuration.

        Driving every case from one table guarantees each configuration is
        actually exercised (previously the last case built an encoder but
        never ran the forward test on it).
        """
        for idx, (description, extra_kwargs) in enumerate(
            self._ENCODER_CONFIGS, start=1
        ):
            print(f"{idx}. {description}")
            self.setUpEncoder(
                VGGTransformerEncoder(input_feat_per_channel=80, **extra_kwargs)
            )
            super().test_forward()


class TransformerDecoderTest(TestFairseqDecoderBase):
    """Tests for the ``TransformerDecoder`` fed with a dummy encoder output."""

    def setUp(self):
        super().setUp()

        # named tgt_dict to avoid shadowing the builtin ``dict``
        tgt_dict = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder = TransformerDecoder(tgt_dict)
        # presumably (T, B, C) = (50 frames, batch 5, 256 channels) -- confirm
        # against get_dummy_encoder_output
        dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))

        self.setUpDecoder(decoder)
        self.setUpInput(dummy_encoder_output)
        self.setUpPrevOutputTokens()