Porjaz committed on
Commit 4e28017
1 Parent(s): 163f43a

Create hyperparams.yalm

Files changed (1)
  1. hyperparams.yalm +128 -0
hyperparams.yalm ADDED
@@ -0,0 +1,128 @@
+ # Generated 2022-01-19 from:
+ # /scratch/elec/t405-puhe/p/porjazd1/Metadata_Classification/TCN/asr_topic_speechbrain/mgb_asr/hyperparams.yaml
+ # yamllint disable
+ # Seed needs to be set at top of yaml, before objects with parameters are made
+ seed: 1234
+ __set_seed: !apply:torch.manual_seed [1234]
+
+ skip_training: True
+
+ output_folder: output_folder_wavlm_base
+ label_encoder_file: !ref <output_folder>/label_encoder.txt
+
+ train_log: !ref <output_folder>/train_log.txt
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+     save_file: !ref <output_folder>/train_log.txt
+ save_folder: !ref <output_folder>/save
+
+ wav2vec2_hub: microsoft/wavlm-base-plus-sv
+
+ wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
+
+ # Feature parameters
+ sample_rate: 22050
+ new_sample_rate: 16000
+ window_size: 25
+ n_mfcc: 23
+
+ # Training params
+ n_epochs: 28
+ stopping_factor: 10
+
+ dataloader_options:
+     batch_size: 10
+     shuffle: false
+
+ test_dataloader_options:
+     batch_size: 1
+     shuffle: false
+
+ lr: 0.0001
+ lr_wav2vec2: 0.00001
+
+ # freeze all wav2vec2
+ freeze_wav2vec2: False
+ # set to true to freeze the CONV part of the wav2vec2 model
+ # We see an improvement of 2% with freezing CNNs
+ freeze_wav2vec2_conv: True
+
+ label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder
+
+ encoder_dims: 768
+ n_classes: 5
+
+ # Wav2vec2 encoder
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
+     source: !ref <wav2vec2_hub>
+     output_norm: True
+     freeze: !ref <freeze_wav2vec2>
+     freeze_feature_extractor: !ref <freeze_wav2vec2_conv>
+     save_path: !ref <wav2vec2_folder>
+     output_all_hiddens: True
+
+
+ avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
+     return_std: False
+
+
+ label_lin: !new:speechbrain.nnet.linear.Linear
+     input_size: !ref <encoder_dims>
+     n_neurons: !ref <n_classes>
+     bias: False
+
+
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
+     apply_log: True
+
+
+ opt_class: !name:torch.optim.Adam
+     lr: !ref <lr>
+
+
+ wav2vec2_opt_class: !name:torch.optim.Adam
+     lr: !ref <lr_wav2vec2>
+
+ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+     limit: !ref <n_epochs>
+
+ # Functions that compute the statistics to track during the validation step.
+ accuracy_computer: !name:speechbrain.utils.Accuracy.AccuracyStats
+
+
+ compute_cost: !name:speechbrain.nnet.losses.nll_loss
+
+
+ error_stats: !name:speechbrain.utils.metric_stats.MetricStats
+     metric: !name:speechbrain.nnet.losses.classification_error
+     reduction: batch
+
+ modules:
+     wav2vec2: !ref <wav2vec2>
+     label_lin: !ref <label_lin>
+
+
+ model: !new:torch.nn.ModuleList
+     - [!ref <label_lin>]
+
+
+ lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
+     initial_value: !ref <lr>
+     improvement_threshold: 0.0025
+     annealing_factor: 0.9
+     patient: 0
+
+
+ lr_annealing_wav2vec2: !new:speechbrain.nnet.schedulers.NewBobScheduler
+     initial_value: !ref <lr_wav2vec2>
+     improvement_threshold: 0.0025
+     annealing_factor: 0.9
+
+
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+     checkpoints_dir: !ref <save_folder>
+     recoverables:
+         model: !ref <model>
+         wav2vec2: !ref <wav2vec2>
+         lr_annealing_output: !ref <lr_annealing>
+         lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
+         counter: !ref <epoch_counter>
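
For reference, a HyperPyYAML file like this is normally loaded by the accompanying training script with the hyperpyyaml package, which resolves the !ref placeholders and instantiates the objects declared with !new: and !name:. A minimal sketch, assuming the file is saved locally as hyperparams.yaml and that the override shown is only illustrative:

    from hyperpyyaml import load_hyperpyyaml

    # Load the hyperparameters; overrides replace values before objects are built.
    with open("hyperparams.yaml") as fin:
        hparams = load_hyperpyyaml(fin, overrides={"skip_training": False})

    # !ref values are resolved at load time, e.g. save_folder becomes
    # "output_folder_wavlm_base/save" because output_folder is "output_folder_wavlm_base".
    print(hparams["save_folder"])
    print(type(hparams["wav2vec2"]))  # the instantiated HuggingFaceWav2Vec2 module

The resulting dictionary is what the training script would pass to its Brain class (modules, opt_class, checkpointer, etc.), but that script is not part of this commit.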