yingzhi commited on
Commit
0c764f2
1 Parent(s): ef4387c

initial commit

Browse files
Files changed (10) hide show
  1. README.md +130 -1
  2. config.json +3 -0
  3. custom_interface.py +215 -0
  4. example.wav +0 -0
  5. example_sad.wav +0 -0
  6. hyperparams.yaml +68 -0
  7. input_norm.ckpt +3 -0
  8. label_encoder.txt +6 -0
  9. model.ckpt +3 -0
  10. wav2vec2.ckpt +3 -0
README.md CHANGED
@@ -1,3 +1,132 @@
1
  ---
2
- license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: "en"
3
+ thumbnail:
4
+ tags:
5
+ - audio-classification
6
+ - speechbrain
7
+ - Emotion
8
+ - Diarization
9
+ - wavlm
10
+ - pytorch
11
+ license: "apache-2.0"
12
+ datasets:
13
+ - ZaionEmotionDataset
14
+ - iemocap
15
+ - ravdess
16
+ - jl-corpus
17
+ - esd
18
+ - emov-db
19
+ metrics:
20
+ - EDER
21
  ---
22
+
23
+ <iframe src="https://ghbtns.com/github-btn.html?user=speechbrain&repo=speechbrain&type=star&count=true&size=large&v=2" frameborder="0" scrolling="0" width="170" height="30" title="GitHub"></iframe>
24
+ <br/><br/>
25
+
26
+ # Emotion Diarization with WavLM Large on 5 popular emotional datasets.
27
+
28
+ This repository provides all the necessary tools to perform speech emotion diarization with a fine-tuned wavlm (large) model using SpeechBrain.
29
+
30
+ The model is trained on concatenated audios and tested on [ZaionEmotionDataset](https://zaion.ai/en/resources/zaion-lab-blog/zaion-emotion-dataset/). The metric is Emotion Diarization Error Rate (EDER). For more details please check the [paper link](https://arxiv.org/pdf/2306.12991.pdf).
31
+
32
+
33
+ For a better experience, we encourage you to learn more about [SpeechBrain](https://speechbrain.github.io). The model performance on ZED (test set) is:
34
+
35
+ | Release | EDER(%) |
36
+ |:-------------:|:--------------:|
37
+ | 19-10-21 | 29.7 (Avg: 30.2) |
38
+
39
+
40
+ ## Pipeline description
41
+
42
+ This system is composed of an wavlm model. It is a combination of convolutional and residual blocks. The task aimes to predict the correct emotion composants and their boundaries within an utterance. For now, the model was trained with audios that contain only 1 non-neutral emotion event.
43
+
44
+ The system is trained with recordings sampled at 16kHz (single channel).
45
+ The code will automatically normalize your audio (i.e., resampling + mono channel selection) when calling *diarize_file* if needed.
46
+
47
+
48
+ ## Install SpeechBrain
49
+
50
+ First of all, please install the **development** version of SpeechBrain with the following command:
51
+
52
+ ```
53
+ pip install speechbrain
54
+ ```
55
+
56
+ Please notice that we encourage you to read our tutorials and learn more about
57
+ [SpeechBrain](https://speechbrain.github.io).
58
+
59
+ ### Perform Speech Emotion Diarization
60
+
61
+ An external `py_module_file=custom.py` is used as an external Predictor class into this HF repos. We use `foreign_class` function from `speechbrain.pretrained.interfaces` that allow you to load you custom model.
62
+
63
+ ```python
64
+ from speechbrain.pretrained.interfaces import foreign_class
65
+ classifier = foreign_class(
66
+ source="speechbrain/emotion-diarization-wavlm-large",
67
+ pymodule_file="custom_interface.py",
68
+ classname="Speech_Emotion_Diarization"
69
+ )
70
+ diary = classifier.diarize_file("speechbrain/emotion-diarization-wavlm-large/example.wav")
71
+ print(diary)
72
+ ```
73
+ The output will contain a dictionary of emotion composants and their boundaries.
74
+
75
+ ### Inference on GPU
76
+ To perform inference on the GPU, add `run_opts={"device":"cuda"}` when calling the `from_hparams` method.
77
+
78
+ ### Training
79
+ The model was trained with SpeechBrain (aa018540).
80
+ To train it from scratch follows these steps:
81
+ 1. Clone SpeechBrain:
82
+ ```bash
83
+ git clone https://github.com/speechbrain/speechbrain/
84
+ ```
85
+ 2. Install it:
86
+ ```
87
+ cd speechbrain
88
+ pip install -r requirements.txt
89
+ pip install -e .
90
+ ```
91
+
92
+ 3. Run Training:
93
+ ```
94
+ cd recipes/ZaionEmotionDataset/emotion_diarization
95
+ python train.py hparams/train.yaml --zed_folder /path/to/ZED --emovdb_folder /path/to/EmoV-DB --esd_folder /path/to/ESD --iemocap_folder /path/to/IEMOCAP --jlcorpus_folder /path/to/JL_corpus --ravdess_folder /path/to/RAVDESS
96
+ ```
97
+
98
+ You can find our training results (models, logs, etc) [here](to be added).
99
+
100
+ ### Limitations
101
+ The SpeechBrain team does not provide any warranty on the performance achieved by this model when used on other datasets.
102
+
103
+ # **About Speech Emotion Diarization/Zaion Emotion Dataset**
104
+
105
+ ```bibtex
106
+ @article{wang2023speech,
107
+ title={Speech Emotion Diarization: Which Emotion Appears When?},
108
+ author={Wang, Yingzhi and Ravanelli, Mirco and Nfissi, Alaa and Yacoubi, Alya},
109
+ journal={arXiv preprint arXiv:2306.12991},
110
+ year={2023}
111
+ }
112
+ ```
113
+
114
+ # **Citing SpeechBrain**
115
+ Please, cite SpeechBrain if you use it for your research or business.
116
+
117
+ ```bibtex
118
+ @misc{speechbrain,
119
+ title={{SpeechBrain}: A General-Purpose Speech Toolkit},
120
+ author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
121
+ year={2021},
122
+ eprint={2106.04624},
123
+ archivePrefix={arXiv},
124
+ primaryClass={eess.AS},
125
+ note={arXiv:2106.04624}
126
+ }
127
+ ```
128
+
129
+ # **About SpeechBrain**
130
+ - Website: https://speechbrain.github.io/
131
+ - Code: https://github.com/speechbrain/speechbrain/
132
+ - HuggingFace: https://huggingface.co/speechbrain/
config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "speechbrain_interface": "EncoderWav2vecClassifier"
3
+ }
custom_interface.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.pretrained import Pretrained
3
+
4
+ class Speech_Emotion_Diarization(Pretrained):
5
+ """A ready-to-use SED interface (audio -> emotions and their durations)
6
+
7
+ Arguments
8
+ ---------
9
+ hparams
10
+ Hyperparameters (from HyperPyYAML)
11
+
12
+ Example
13
+ -------
14
+ >>> from speechbrain.pretrained import Speech_Emotion_Diarization
15
+ >>> tmpdir = getfixture("tmpdir")
16
+ >>> sed_model = Speech_Emotion_Diarization.from_hparams(source="speechbrain/emotion-diarization-wavlm-large", savedir=tmpdir,) # doctest: +SKIP
17
+ >>> sed_model.diarize_file("speechbrain/emotion-diarization-wavlm-large/example.wav") # doctest: +SKIP
18
+ """
19
+
20
+ MODULES_NEEDED = ["input_norm", "wav2vec", "output_mlp"]
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+
25
+ def diarize_file(self, path):
26
+ """Get emotion diarization of a spoken utterance.
27
+
28
+ Arguments
29
+ ---------
30
+ path : str
31
+ Path to audio file which to diarize.
32
+
33
+ Returns
34
+ -------
35
+ dict
36
+ The emotions and their boundaries.
37
+ """
38
+ waveform = self.load_audio(path)
39
+ # Fake a batch:
40
+ batch = waveform.unsqueeze(0)
41
+ rel_length = torch.tensor([1.0])
42
+ frame_class = self.diarize_batch(
43
+ batch, rel_length, [path]
44
+ )
45
+ return frame_class
46
+
47
+ def encode_batch(self, wavs, wav_lens):
48
+ """Encodes audios into fine-grained emotional embeddings
49
+
50
+ Arguments
51
+ ---------
52
+ wavs : torch.tensor
53
+ Batch of waveforms [batch, time, channels].
54
+ wav_lens : torch.tensor
55
+ Lengths of the waveforms relative to the longest one in the
56
+ batch, tensor of shape [batch]. The longest one should have
57
+ relative length 1.0 and others len(waveform) / max_length.
58
+ Used for ignoring padding.
59
+
60
+ Returns
61
+ -------
62
+ torch.tensor
63
+ The encoded batch
64
+ """
65
+ if len(wavs.shape) == 1:
66
+ wavs = wavs.unsqueeze(0)
67
+
68
+ # Assign full length if wav_lens is not assigned
69
+ if wav_lens is None:
70
+ wav_lens = torch.ones(wavs.shape[0], device=self.device)
71
+
72
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
73
+
74
+ wavs = self.mods.input_norm(wavs, wav_lens)
75
+ outputs = self.mods.wav2vec2(wavs)
76
+ return outputs
77
+
78
+
79
+ def diarize_batch(self, wavs, wav_lens, batch_id):
80
+ """Get emotion diarization of a batch of waveforms.
81
+
82
+ The waveforms should already be in the model's desired format.
83
+ You can call:
84
+ ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
85
+ to get a correctly converted signal in most cases.
86
+
87
+ Arguments
88
+ ---------
89
+ wavs : torch.tensor
90
+ Batch of waveforms [batch, time, channels].
91
+ wav_lens : torch.tensor
92
+ Lengths of the waveforms relative to the longest one in the
93
+ batch, tensor of shape [batch]. The longest one should have
94
+ relative length 1.0 and others len(waveform) / max_length.
95
+ Used for ignoring padding.
96
+
97
+ Returns
98
+ -------
99
+ torch.tensor
100
+ The frame-wise predictions
101
+ """
102
+ outputs = self.encode_batch(wavs, wav_lens)
103
+ averaged_out = self.hparams.avg_pool(outputs)
104
+ outputs = self.mods.output_mlp(averaged_out)
105
+ outputs = self.hparams.log_softmax(outputs)
106
+ score, index = torch.max(outputs, dim=-1)
107
+ preds = self.hparams.label_encoder.decode_torch(index)
108
+ results = self.preds_to_diarization(preds, batch_id)
109
+ return results
110
+
111
+ def preds_to_diarization(self, prediction, batch_id):
112
+ """Convert frame-wise predictions into a dictionary of
113
+ diarization results.
114
+
115
+ Returns
116
+ -------
117
+ dictionary
118
+ A dictionary with the start/end of each emotion
119
+ """
120
+ results = {}
121
+
122
+ for i in range(len(prediction)):
123
+ pred = prediction[i]
124
+ lol = []
125
+ for j in range(len(pred)):
126
+ start = round(self.hparams.stride * 0.02 * j, 2)
127
+ end = round(start + self.hparams.window_length * 0.02, 2)
128
+ lol.append([batch_id[i], start, end, pred[j]])
129
+
130
+ lol = merge_ssegs_same_emotion_adjacent(lol)
131
+ print(lol)
132
+ results[batch_id[i]] = [{"start": k[1], "end":k[2], "emotion": k[3]} for k in lol]
133
+ return results
134
+
135
+
136
+ def forward(self, wavs, wav_lens):
137
+ """Runs full transcription - note: no gradients through decoding"""
138
+ return self.transcribe_batch(wavs, wav_lens)
139
+
140
+
141
+ def is_overlapped(end1, start2):
142
+ """Returns True if segments are overlapping.
143
+
144
+ Arguments
145
+ ---------
146
+ end1 : float
147
+ End time of the first segment.
148
+ start2 : float
149
+ Start time of the second segment.
150
+
151
+ Returns
152
+ -------
153
+ overlapped : bool
154
+ True of segments overlapped else False.
155
+
156
+ Example
157
+ -------
158
+ >>> from speechbrain.processing import diarization as diar
159
+ >>> diar.is_overlapped(5.5, 3.4)
160
+ True
161
+ >>> diar.is_overlapped(5.5, 6.4)
162
+ False
163
+ """
164
+
165
+ if start2 > end1:
166
+ return False
167
+ else:
168
+ return True
169
+
170
+
171
+ def merge_ssegs_same_emotion_adjacent(lol):
172
+ """Merge adjacent sub-segs if they are the same emotion.
173
+ Arguments
174
+ ---------
175
+ lol : list of list
176
+ Each list contains [utt_id, sseg_start, sseg_end, emo_label].
177
+ Returns
178
+ -------
179
+ new_lol : list of list
180
+ new_lol contains adjacent segments merged from the same emotion ID.
181
+ Example
182
+ -------
183
+ >>> from speechbrain.utils.EDER import merge_ssegs_same_emotion_adjacent
184
+ >>> lol=[['u1', 0.0, 7.0, 'a'],
185
+ ... ['u1', 7.0, 9.0, 'a'],
186
+ ... ['u1', 9.0, 11.0, 'n'],
187
+ ... ['u1', 11.0, 13.0, 'n'],
188
+ ... ['u1', 13.0, 15.0, 'n'],
189
+ ... ['u1', 15.0, 16.0, 'a']]
190
+ >>> merge_ssegs_same_emotion_adjacent(lol)
191
+ [['u1', 0.0, 9.0, 'a'], ['u1', 9.0, 15.0, 'n'], ['u1', 15.0, 16.0, 'a']]
192
+ """
193
+ new_lol = []
194
+
195
+ # Start from the first sub-seg
196
+ sseg = lol[0]
197
+ flag = False
198
+ for i in range(1, len(lol)):
199
+ next_sseg = lol[i]
200
+ # IF sub-segments overlap AND has same emotion THEN merge
201
+ if is_overlapped(sseg[2], next_sseg[1]) and sseg[3] == next_sseg[3]:
202
+ sseg[2] = next_sseg[2] # just update the end time
203
+ # This is important. For the last sseg, if it is the same emotion then merge
204
+ # Make sure we don't append the last segment once more. Hence, set FLAG=True
205
+ if i == len(lol) - 1:
206
+ flag = True
207
+ new_lol.append(sseg)
208
+ else:
209
+ new_lol.append(sseg)
210
+ sseg = next_sseg
211
+ # Add last segment only when it was skipped earlier.
212
+ if flag is False:
213
+ new_lol.append(lol[-1])
214
+
215
+ return new_lol
example.wav ADDED
Binary file (144 kB). View file
 
example_sad.wav ADDED
Binary file (169 kB). View file
 
hyperparams.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ############################################################################
2
+ # Model: WavLM for Emotion Diarization
3
+ # ############################################################################
4
+
5
+
6
+ # Hparams NEEDED
7
+ HPARAMS_NEEDED: ["window_length", "stride", "encoder_dim", "out_n_neurons", "avg_pool", "label_encoder", "softmax"]
8
+ # Modules Needed
9
+ MODULES_NEEDED: ["wav2vec2", "output_mlp"]
10
+
11
+ # Feature parameters
12
+ wav2vec2_hub: "microsoft/wavlm-large"
13
+
14
+ # Pretrain folder (HuggingFace)
15
+ pretrained_path: /home/ywang/zed_pr/sed_hf
16
+
17
+ # parameters
18
+ window_length: 1 # win_len = 0.02 * 1 = 0.02s
19
+ stride: 1 # stride = 0.02 * 1 = 0.02s
20
+ encoder_dim: 1024
21
+ out_n_neurons: 4
22
+
23
+ input_norm: !new:speechbrain.processing.features.InputNormalization
24
+ norm_type: sentence
25
+ std_norm: False
26
+
27
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
28
+ source: !ref <wav2vec2_hub>
29
+ output_norm: True
30
+ freeze: False
31
+ freeze_feature_extractor: True
32
+ save_path: wav2vec2_checkpoint
33
+
34
+ avg_pool: !new:speechbrain.nnet.pooling.Pooling1d
35
+ pool_type: "avg"
36
+ kernel_size: !ref <window_length>
37
+ stride: !ref <stride>
38
+ ceil_mode: True
39
+
40
+ output_mlp: !new:speechbrain.nnet.linear.Linear
41
+ input_size: !ref <encoder_dim>
42
+ n_neurons: !ref <out_n_neurons>
43
+ bias: False
44
+
45
+ model: !new:torch.nn.ModuleList
46
+ - [!ref <output_mlp>]
47
+
48
+ modules:
49
+ input_norm: !ref <input_norm>
50
+ wav2vec2: !ref <wav2vec2>
51
+ output_mlp: !ref <output_mlp>
52
+
53
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
54
+ apply_log: True
55
+
56
+ label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder
57
+
58
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
59
+ loadables:
60
+ input_norm: !ref <input_norm>
61
+ wav2vec2: !ref <wav2vec2>
62
+ model: !ref <model>
63
+ label_encoder: !ref <label_encoder>
64
+ paths:
65
+ input_norm: !ref <pretrained_path>/input_norm.ckpt
66
+ wav2vec2: !ref <pretrained_path>/wav2vec2.ckpt
67
+ model: !ref <pretrained_path>/model.ckpt
68
+ label_encoder: !ref <pretrained_path>/label_encoder.txt
input_norm.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eddbd59b97a6456c5a81880065b785f731ca3b959abfa2c965658a591e53d31f
3
+ size 1075
label_encoder.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ 'a' => 0
2
+ 'n' => 1
3
+ 'h' => 2
4
+ 's' => 3
5
+ ================
6
+ 'starting_index' => 0
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23c5832103c64cb628e8e56ce5fc7061be323e435a294d34060172c10015208d
3
+ size 17189
wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60d8746058853c9ad8976c5630d8584959b74428f6bbe7458fe3d8bdf15d54b3
3
+ size 1262005979