|
|
|
|
|
|
|
from examples.speech_recognition.models.vggtransformer import ( |
|
TransformerDecoder, |
|
VGGTransformerEncoder, |
|
VGGTransformerModel, |
|
vggtransformer_1, |
|
vggtransformer_2, |
|
vggtransformer_base, |
|
) |
|
|
|
|
|
from .asr_test_base import ( |
|
DEFAULT_TEST_VOCAB_SIZE, |
|
TestFairseqDecoderBase, |
|
TestFairseqEncoderBase, |
|
TestFairseqEncoderDecoderModelBase, |
|
get_dummy_dictionary, |
|
get_dummy_encoder_output, |
|
get_dummy_input, |
|
) |
|
|
|
|
|
class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder tests for the ``vggtransformer_1`` architecture (shrunk)."""

    def setUp(self):
        def override_config(args):
            """
            vggtransformer_1 uses 14 layers of transformer, which is too
            expensive for testing purposes. For a fast turn-around test,
            reduce the number of layers to 3.
            """
            # Three identical encoder layer specs; the string form is kept
            # because that is how the architecture function stores it.
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        # Setters run in order, so override_config replaces the encoder config
        # written by vggtransformer_1 (presumably applied to the same args).
        extra_args_setter = [vggtransformer_1, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
|
|
|
|
|
class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder tests for the ``vggtransformer_2`` architecture (shrunk)."""

    def setUp(self):
        def override_config(args):
            """
            vggtransformer_2 uses 16 layers of transformer, which is too
            expensive for testing purposes. For a fast turn-around test,
            reduce the number of layers to 3.
            """
            # Three identical encoder layer specs; the string form is kept
            # because that is how the architecture function stores it.
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        # Setters run in order, so override_config replaces the encoder config
        # written by vggtransformer_2 (presumably applied to the same args).
        extra_args_setter = [vggtransformer_2, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
|
|
|
|
|
class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder tests for the ``vggtransformer_base`` architecture (shrunk)."""

    def setUp(self):
        def override_config(args):
            """
            vggtransformer_base uses 12 layers of transformer, which is too
            expensive for testing purposes. For a fast turn-around test,
            reduce the number of layers to 3.
            """
            # Three identical encoder layer specs; the string form is kept
            # because that is how the architecture function stores it.
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        # Setters run in order, so override_config replaces the encoder config
        # written by vggtransformer_base (presumably applied to the same args).
        extra_args_setter = [vggtransformer_base, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
|
|
|
|
|
class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    """Forward-pass tests for VGGTransformerEncoder under several context /
    sampling configurations."""

    def setUp(self):
        super().setUp()

        # D=80 matches input_feat_per_channel=80 used by every encoder below.
        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        """Run the base-class forward test against each encoder configuration.

        Every case installs a fresh encoder via ``setUpEncoder`` and then
        delegates the actual check to ``TestFairseqEncoderBase.test_forward``.
        Judging by the printed labels, ``transformer_context`` appears to be
        a ``(left, right)`` pair where ``-1`` means unlimited — TODO confirm
        against VGGTransformerEncoder.
        """
        print("1. test standard vggtransformer")
        self.setUpEncoder(VGGTransformerEncoder(input_feat_per_channel=80))
        super().test_forward()

        print("2. test vggtransformer with limited right context")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80, transformer_context=(-1, 5)
            )
        )
        super().test_forward()

        print("3. test vggtransformer with limited left context")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80, transformer_context=(5, -1)
            )
        )
        super().test_forward()

        print("4. test vggtransformer with limited right context and sampling")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80,
                transformer_context=(-1, 12),
                transformer_sampling=(2, 2),
            )
        )
        super().test_forward()

        print("5. test vggtransformer with windowed context and sampling")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80,
                transformer_context=(12, 12),
                transformer_sampling=(2, 2),
            )
        )
        # Without this call the case-5 encoder is built but never exercised.
        super().test_forward()
|
|
|
|
|
class TransformerDecoderTest(TestFairseqDecoderBase):
    """Forward-pass tests for the ASR TransformerDecoder on a dummy encoder output."""

    def setUp(self):
        super().setUp()

        # Named tgt_dict rather than `dict` to avoid shadowing the builtin.
        tgt_dict = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder = TransformerDecoder(tgt_dict)
        # NOTE(review): shape presumably (T=50, B=5, C=256), mirroring T/B used
        # elsewhere in this file; C=256 is assumed to match the decoder's
        # default encoder-output dim — confirm against TransformerDecoder.
        dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))

        self.setUpDecoder(decoder)
        self.setUpInput(dummy_encoder_output)
        self.setUpPrevOutputTokens()
|
|