File size: 3,533 Bytes
fe65b7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Global random seed; it also names the per-run output folder below.
seed: 1993
# Calls torch.manual_seed(<seed>) so model initialisation is reproducible.
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Root folder of the raw dataset. NOTE(review): the data-prep step presumably
# reads from (and may download into) this folder; confirm in the training
# script. Machine-specific absolute path: override it when running elsewhere.
data_original: 'D:/voice-emo/dat/'
# Per-seed experiment directory; checkpoints and the training log go inside.
output_folder: !ref results/train_with_wav2vec2/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Hugging Face hub identifier of the pretrained wav2vec2 encoder; change it to
# benchmark other checkpoints.
# Deliberately the self-supervised *base* model, not the ASR fine-tuned
# variant (the recipe reports ~4% better results this way; not re-measured
# for this dataset - verify).
wav2vec2_hub: 'facebook/wav2vec2-base'
# Local directory passed to the encoder as `save_path` for its model files.
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint

# JSON data manifests, stored inside the run folder (presumably written by the
# data-preparation step unless skip_prep is True).
train_annotation: !ref <output_folder>/train.json
valid_annotation: !ref <output_folder>/valid.json
test_annotation: !ref <output_folder>/test.json
# Train / valid / test split; the values sum to 100, i.e. percentages.
split_ratio:
    - 80
    - 10
    - 10
# Set to True to skip data preparation (e.g. when the manifests already exist).
skip_prep: False

# File-based training logger: writes per-epoch training statistics to
# <train_log>. (The original note says it also echoes to stdout - confirm
# against FileTrainLogger / the training script.)
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

ckpt_interval_minutes: 15  # save an intra-epoch checkpoint every N minutes (read by the training script)

####################### Training Parameters ####################################
# Upper bound for epoch_counter.
number_of_epochs: 30
# Mini-batch size used by dataloader_options.
batch_size: 4
# Learning rate for opt_class / lr_annealing (presumably the non-wav2vec2
# parameters, i.e. the classifier head). Written in decimal form rather than
# 1e-4 so YAML 1.1 parsers also read it as a float.
lr: 0.0001
# Smaller learning rate for fine-tuning wav2vec2 (wav2vec2_opt_class and
# lr_annealing_wav2vec2).
lr_wav2vec2: 0.00001

# If True, freeze every wav2vec2 weight (passed as `freeze` to the encoder).
freeze_wav2vec2: False
# If True, freeze only the convolutional feature extractor of wav2vec2
# (passed as `freeze_feature_extractor`). Reported ~2% improvement with the
# CNN frozen (not re-measured for this dataset).
freeze_wav2vec2_conv: True

####################### Model Parameters #######################################
# Size of the wav2vec2 output features fed to output_mlp. Must equal the hidden
# size of the model in wav2vec2_hub: 768 for wav2vec2-base (large variants use
# 1024, so update this if wav2vec2_hub changes).
encoder_dim: 768

# Number of emotion classes (size of the classifier output).
# NOTE(review): the index assigned to each emotion is decided by the label
# encoder in the training script, not by the order of the list below - confirm.
out_n_neurons: 7  # (anger, disgust, fear, happy, neutral, sad, surprise)

# Number of DataLoader worker processes; exposed as a top-level key so it can
# be overridden on its own. 2 suits Linux. On Windows, worker processes are
# spawned and may fail to start if the data pipeline is not picklable; use 0
# there (load data in the main process).
num_workers: 2

dataloader_options:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>
    drop_last: False

# Pretrained wav2vec2 encoder loaded from <wav2vec2_hub>.
# freeze / freeze_feature_extractor control which of its weights are trained;
# output_norm presumably normalises the output features (see the SpeechBrain
# Wav2Vec2 docs).
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec2>
    freeze_feature_extractor: !ref <freeze_wav2vec2_conv>
    save_path: !ref <wav2vec2_folder>

# Pools the encoder output over time. With return_std False only the mean is
# kept, so the pooled vector has encoder_dim features (= output_mlp input_size).
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False

# Classifier head: a single bias-free linear layer mapping the pooled
# encoder_dim vector to out_n_neurons class scores.
output_mlp: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <encoder_dim>
    n_neurons: !ref <out_n_neurons>
    bias: False

# Epoch counter bounded by number_of_epochs (also checkpointed as `counter`).
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Modules handed to the training loop (presumably the Brain `modules` dict).
# avg_pool is not listed. NOTE(review): confirm the training script reaches
# it through hparams.
modules:
    wav2vec2: !ref <wav2vec2>
    output_mlp: !ref <output_mlp>

# Non-wav2vec2 trainable modules grouped in a ModuleList. The single positional
# argument is the list of modules; checkpointed as `model`.
model: !new:torch.nn.ModuleList
    - [!ref <output_mlp>]

# Log-softmax over class scores; paired with nll_loss below, which expects
# log-probabilities.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# Negative log-likelihood loss (stored as a callable via !name:).
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Factory (!name:) for a MetricStats tracker of per-batch classification
# error; presumably instantiated once per stage by the training script.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Optimizer factories (!name: stores the constructor; the training script
# presumably supplies the parameters). Two optimizers let the classifier head
# and wav2vec2 train with different learning rates.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

wav2vec2_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec2>

# NewBob schedulers, one per optimizer. Each multiplies its learning rate by
# annealing_factor when the relative improvement of the monitored metric drops
# below improvement_threshold. patient = number of such epochs tolerated
# before annealing. Both schedulers now state every parameter explicitly so
# they read consistently.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

lr_annealing_wav2vec2: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec2>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0  # explicit; equals NewBobScheduler's default

# Saves and restores training state under <save_folder>: classifier modules
# (model), the wav2vec2 encoder, both LR schedulers and the epoch counter.
# NOTE(review): optimizer state is not listed here; confirm the training
# script registers the optimizers (e.g. via add_recoverable) if exact resume
# is needed.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        wav2vec2: !ref <wav2vec2>
        lr_annealing_output: !ref <lr_annealing>
        lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
        counter: !ref <epoch_counter>