Titouan committed on
Commit
9de803c
·
1 Parent(s): 6c63eca

EncoderDecoderASR update

Browse files
Files changed (1) hide show
  1. hyperparams.yaml +14 -6
hyperparams.yaml CHANGED
@@ -86,15 +86,19 @@ tokenizer: !new:sentencepiece.SentencePieceProcessor
86
  asr_model: !new:torch.nn.ModuleList
87
  - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
88
 
89
- modules:
90
- compute_features: !ref <compute_features>
91
- pre_transformer: !ref <CNN>
92
  transformer: !ref <Transformer>
93
- asr_model: !ref <asr_model>
 
 
 
 
94
  normalize: !ref <normalize>
95
- beam_searcher: !ref <beam_searcher>
 
96
 
97
- beam_searcher: !new:speechbrain.decoders.S2STransformerBeamSearch
98
  modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
99
  bos_index: !ref <bos_index>
100
  eos_index: !ref <eos_index>
@@ -106,6 +110,10 @@ beam_searcher: !new:speechbrain.decoders.S2STransformerBeamSearch
106
  using_eos_threshold: False
107
  length_normalization: True
108
 
 
 
 
 
109
  log_softmax: !new:torch.nn.LogSoftmax
110
  dim: -1
111
 
 
86
  asr_model: !new:torch.nn.ModuleList
87
  - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
88
 
89
+ # Here, we wrap the Transformer model so that only its encoder is run
90
+ Tencoder: !new:speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper
 
91
  transformer: !ref <Transformer>
92
+
93
+ # We compose the inference (encoder) pipeline.
94
+ encoder: !new:speechbrain.nnet.containers.LengthsCapableSequential
95
+ input_shape: [null, null, !ref <n_mels>]
96
+ compute_features: !ref <compute_features>
97
  normalize: !ref <normalize>
98
+ cnn: !ref <CNN>
99
+ transformer_encoder: !ref <Tencoder>
100
 
101
+ decoder: !new:speechbrain.decoders.S2STransformerBeamSearch
102
  modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
103
  bos_index: !ref <bos_index>
104
  eos_index: !ref <eos_index>
 
110
  using_eos_threshold: False
111
  length_normalization: True
112
 
113
+ modules:
114
+ encoder: !ref <encoder>
115
+ decoder: !ref <decoder>
116
+
117
  log_softmax: !new:torch.nn.LogSoftmax
118
  dim: -1
119