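# NOTE: the lines below are Hugging Face file-viewer residue captured with this
# file (repo path, commit info, size); kept as comments so the YAML parses.
# wavlm-base-emo-fi / hyperparams.yalm
# Porjaz's picture
# Create hyperparams.yalm
# 4e28017 verified
# raw
# history blame
# 3.27 kB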
# Generated 2022-01-19 from:
# /scratch/elec/t405-puhe/p/porjazd1/Metadata_Classification/TCN/asr_topic_speechbrain/mgb_asr/hyperparams.yaml
# yamllint disable
# Seed needs to be set at the top of the yaml, before any objects with
# parameters are constructed, so that their initialisation is reproducible.
seed: 1234
# Reference <seed> instead of repeating the literal, so editing `seed`
# actually changes the value passed to torch.manual_seed.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Flag read by the recipe script (not referenced elsewhere in this file).
skip_training: True

# Output locations.
output_folder: output_folder_wavlm_base
label_encoder_file: !ref <output_folder>/label_encoder.txt
train_log: !ref <output_folder>/train_log.txt
# save_file must be nested under the logger: it is a constructor argument of
# FileTrainLogger, not a top-level key. Reuse <train_log> so the path lives in
# one place.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
save_folder: !ref <output_folder>/save

# Pretrained encoder checkpoint (a WavLM model, loaded through the wav2vec2
# wrapper configured below) and the local folder it is saved to.
wav2vec2_hub: microsoft/wavlm-base-plus-sv
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
# Feature parameters
# NOTE(review): presumably sample_rate is the rate of the input audio and
# new_sample_rate the rate it is resampled to before the encoder — confirm in
# the recipe script.
sample_rate: 22050
new_sample_rate: 16000
# window_size and n_mfcc are not referenced by any other key in this file;
# presumably consumed by (or left over from) the recipe script.
window_size: 25
n_mfcc: 23

# Training params
n_epochs: 28
# Not referenced by any other key in this file; presumably used by the recipe
# script for early stopping — confirm.
stopping_factor: 10
# Keyword arguments for the dataloaders; they must be nested under their
# parent key, otherwise the parent is null and the keys collide.
dataloader_options:
    batch_size: 10
    shuffle: false
test_dataloader_options:
    batch_size: 1
    shuffle: false
lr: 0.0001
lr_wav2vec2: 0.00001

# Freeze the entire wav2vec2 model (passed as `freeze` to the encoder).
freeze_wav2vec2: False
# Set to True to freeze only the convolutional feature extractor (CNN front
# end) of the wav2vec2 model. We see an improvement of 2% with frozen CNNs.
freeze_wav2vec2_conv: True
# Maps emotion labels to integer indices.
# NOTE(review): presumably saved to/loaded from <label_encoder_file> by the
# recipe script — confirm.
label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder

# Encoder output size; used as the input size of label_lin.
encoder_dims: 768
# Number of emotion classes; used as the output size of label_lin.
n_classes: 5

# Wav2vec2 encoder (the checkpoint in <wav2vec2_hub> is a WavLM model).
# All arguments must be nested under the constructor tag.
wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec2>
    freeze_feature_extractor: !ref <freeze_wav2vec2_conv>
    save_path: !ref <wav2vec2_folder>
    # NOTE(review): returns every hidden layer rather than only the last one;
    # the recipe script presumably reduces them before pooling — confirm.
    output_all_hiddens: True
# Classification head: pool over time, project to classes, log-softmax.

# Statistics pooling with the standard deviation disabled.
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False

# Linear projection from encoder features to class scores (no bias term).
label_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <encoder_dims>
    n_neurons: !ref <n_classes>
    bias: False

# apply_log makes this output log-probabilities, which is what nll_loss
# (compute_cost) expects.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True
# Optimizer factories (!name: with arguments yields a partial; the recipe
# supplies the parameters). The head and the encoder use separate learning
# rates.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>
wav2vec2_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec2>

# Counts training epochs up to n_epochs.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <n_epochs>
# Functions that compute the statistics to track during the validation step.
accuracy_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
# Negative log-likelihood; pairs with the log-probabilities from log_softmax.
compute_cost: !name:speechbrain.nnet.losses.nll_loss
# `reduction` is an argument of the metric function, so it is nested under
# `metric`, which is itself an argument of MetricStats.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch
# Modules handed to the Brain class. avg_pool and log_softmax are not listed
# (presumably because they hold no trainable weights).
modules:
    wav2vec2: !ref <wav2vec2>
    label_lin: !ref <label_lin>

# Head-only container, checkpointed separately from the encoder.
# The single positional argument to ModuleList is the list of modules.
model: !new:torch.nn.ModuleList
    - [!ref <label_lin>]
# NewBob learning-rate schedulers: one for the head (lr) and one for the
# encoder (lr_wav2vec2), with identical annealing settings.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

lr_annealing_wav2vec2: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec2>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    # Spelled out to match lr_annealing (0 is also the library default).
    patient: 0
# Saves and restores training state under <save_folder>. Each entry in
# recoverables is an object whose state is checkpointed under that name.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        wav2vec2: !ref <wav2vec2>
        lr_annealing_output: !ref <lr_annealing>
        lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
        counter: !ref <epoch_counter>