# IWSLT-ast-w2v2-mbart / hyperparams.yaml
# Repository page header (not part of the configuration):
# HaNguyen's picture · minor update · 8aa4e50
# Base location of the pretrained checkpoints; the `pretrainer` paths are
# built by appending the checkpoint file names to this value.
pretrained_path: HaNguyen/IWSLT-ast-w2v2-mbart
lang: fr # language code used for BLEU score detokenization
target_lang: fr_XX # mBART language code; passed as `target_lang` to the mBART model
# NOTE(review): presumably the audio sampling rate in Hz — confirm against
# the code that consumes this file.
sample_rate: 16000
# HuggingFace Hub id (not a full URL) of the wav2vec 2.0 model to load
# (BASE-size model here); used as `source` of the `wav2vec2` entry.
wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only
# wav2vec 2.0 specific parameters
# Passed as `freeze` to the `wav2vec2` entry.
wav2vec2_frozen: False
# Feature parameters (wav2vec 2.0 etc.)
features_dim: 768 # output dimension of the base wav2vec 2.0 model; use 1024 for the large model
# Projection applied on top of the wav2vec 2.0 features
# (number of DNN blocks and neurons per block).
enc_dnn_layers: 1
enc_dnn_neurons: 1024
# Transformer
# NOTE(review): embedding_size and d_model are not referenced by any other
# entry visible in this file; presumably read by the calling script — verify.
embedding_size: 256
d_model: 1024
# `!name:` yields a reference to the GELU class rather than an instance;
# it is handed to the `enc` projection as its `activation` argument.
activation: !name:torch.nn.GELU
# Outputs
# NOTE(review): blank_index, label_smoothing and pad_index are not referenced
# by any other entry visible in this file; presumably read by the calling
# script — verify. blank_index and pad_index share the same value (1).
blank_index: 1
label_smoothing: 0.1
pad_index: 1 # padding index defined by the mBART model
bos_index: 250008 # index of the fr_XX token, used as BOS, as defined by the mBART model
eos_index: 2
# Decoding parameters
# Make sure the BOS and EOS indices match the ones used by the BPE tokenizer.
min_decode_ratio: 0.0
max_decode_ratio: 0.25
valid_beam_size: 5
test_beam_size: 5
############################## models ################################
# wav2vec 2.0 model
# `!new:` instantiates the class with the keyword arguments nested below.
# The arguments must be indented under the key: at column 0 they are parsed
# as unrelated top-level keys and the constructor receives nothing.
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
  source: !ref <wav2vec2_hub>
  output_norm: True
  freeze: !ref <wav2vec2_frozen>
  # NOTE(review): presumably the local directory where the downloaded model
  # is stored — confirm against the Wav2Vec2 class documentation.
  save_path: wav2vec2_checkpoint
# Linear projection applied to the wav2vec 2.0 output.
# Sizes are taken from the feature/projection hyperparameters declared above
# (features_dim, enc_dnn_layers, enc_dnn_neurons) instead of being repeated
# as literals, so switching to a larger wav2vec 2.0 model needs one change.
# The resolved values are the same as the previous literals (768, 1, 1024).
enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
  input_shape: [null, null, !ref <features_dim>]
  activation: !ref <activation>
  dnn_blocks: !ref <enc_dnn_layers>
  dnn_neurons: !ref <enc_dnn_neurons>
# mBART
# Hub id of the mBART model and whether its weights are frozen; both are
# forwarded to the `mBART` entry below.
mbart_path: facebook/mbart-large-50-many-to-many-mmt
mbart_frozen: False
# `!new:` instantiates the class with the keyword arguments nested below;
# they must be indented under the key to be passed to the constructor.
# NOTE(review): the `&id004` anchor has no matching `*id004` alias in the
# visible file; it is kept unchanged in case something else relies on it.
mBART: &id004 !new:speechbrain.lobes.models.huggingface_transformers.mbart.mBART
  source: !ref <mbart_path>
  freeze: !ref <mbart_frozen>
  save_path: mbart_checkpoint
  target_lang: !ref <target_lang>
# Softmax instantiated with `apply_log` enabled. The argument must be nested
# under the key to reach the constructor.
log_softmax: !new:speechbrain.nnet.activations.Softmax
  apply_log: True
# Identity module (returns its input unchanged), created with no arguments.
# NOTE(review): not referenced by any other entry visible in this file;
# presumably looked up by name by the calling code — confirm.
seq_lin: !new:torch.nn.Identity
# Named collection of the three networks defined above. The entries must be
# nested under `modules`: at column 0 they would redefine the top-level
# `wav2vec2`, `enc` and `mBART` keys as references to themselves.
modules:
  wav2vec2: !ref <wav2vec2>
  enc: !ref <enc>
  mBART: !ref <mBART>
# Modules stored together in the `model` checkpoint. With `!new:`, a sequence
# supplies positional arguments, so this builds ModuleList([enc]); wav2vec2
# and mBART are loaded from their own checkpoints by the pretrainer.
model: !new:torch.nn.ModuleList
  - [!ref <enc>]
# Beam searcher that decodes with the mBART model.
# Arguments are nested under the key so they reach the constructor, and the
# token indices, decode ratios and beam size reference the hyperparameters
# declared above instead of repeating their literals. The resolved values
# are the same as before: 250008, 2, 0.0, 0.25 and 5.
valid_search: !new:speechbrain.decoders.S2SHFTextBasedBeamSearcher
  modules: [!ref <mBART>, null, null]
  # NOTE(review): presumably the mBART vocabulary size — confirm.
  vocab_size: 250054
  bos_index: !ref <bos_index>
  eos_index: !ref <eos_index>
  min_decode_ratio: !ref <min_decode_ratio>
  max_decode_ratio: !ref <max_decode_ratio>
  beam_size: !ref <valid_beam_size>
  using_eos_threshold: True
  length_normalization: True
# Pretrainer: maps each loadable object to the checkpoint file it is loaded
# from. `loadables` and `paths` must share the same keys, and both mappings
# must be nested under `pretrainer` to be passed to its constructor.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    model: !ref <model>
    wav2vec2: !ref <wav2vec2>
    mBART: !ref <mBART>
  paths:
    # Each path is `pretrained_path` followed by the checkpoint file name.
    wav2vec2: !ref <pretrained_path>/wav2vec2.ckpt
    model: !ref <pretrained_path>/model.ckpt
    mBART: !ref <pretrained_path>/mBART.ckpt