This model is fine-tuned on the IEMOCAP_speaker_indpt_Ses05F_Ses05M.pickle dataset, which uses Ses05F as the validation speaker and Ses05M as the test speaker, so it is a speaker-independent model.

The initial pre-trained model is facebook/wav2vec2-base. The fine-tuning dataset contains only the four common IEMOCAP emotions (happy, angry, sad, neutral), **without frustration**. No audio augmentation is applied, and the fine-tuning audios are not padded or trimmed to a fixed length beforehand; length handling is done in the feature extractor with max_length = 8 sec when fine-tuning the transformer (see the sketch below).

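As a rough illustration of that length handling, the sketch below shows how a Wav2Vec2 feature extractor can pad or truncate waveforms to 8 seconds at 16 kHz. This is a hypothetical reconstruction with dummy data, not the exact fine-tuning code.

```
# Hypothetical sketch (not the original training code): enforcing the
# 8-second max_length described above via the feature extractor.
import numpy as np
from transformers import AutoFeatureExtractor

sampling_rate = 16000
max_length = 8 * sampling_rate  # 8 seconds at 16 kHz -> 128000 samples

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

# Two dummy waveforms of different lengths (3 s and 10 s).
waveforms = [np.random.randn(3 * sampling_rate), np.random.randn(10 * sampling_rate)]

inputs = feature_extractor(
    waveforms,
    sampling_rate=sampling_rate,
    max_length=max_length,
    truncation=True,       # clip audios longer than 8 s
    padding="max_length",  # pad shorter audios up to 8 s
    return_tensors="pt",
)
print(inputs.input_values.shape)  # torch.Size([2, 128000])
```
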
After **10** epochs of training, the validation accuracy is around **67%**.

To run inference with this model, use the following code in a Python script:

```
# Load the fine-tuned model and feature extractor, then classify one audio file.
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import librosa
import torch

target_sampling_rate = 16000
model_name = 'canlinzhang/Sorenson_fine_tune_wav2vec2-on_IEMOCAP_no_aug_no_fru_2'
my_token = 'YOUR_HF_ACCESS_TOKEN'      # replace with your Hugging Face access token
audio_path = 'path/to/your_audio.wav'  # replace with the path to your audio file

# Build id and label dicts.
id2label = {0: 'neu', 1: 'ang', 2: 'sad', 3: 'hap'}
label2id = {'neu': 0, 'ang': 1, 'sad': 2, 'hap': 3}

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, use_auth_token=my_token)

model = AutoModelForAudioClassification.from_pretrained(model_name, use_auth_token=my_token)
model.eval()

# Load the audio at 16 kHz, matching the sampling rate used for fine-tuning.
y_ini, sr_ini = librosa.load(audio_path, sr=target_sampling_rate)

inputs = feature_extractor(y_ini, sampling_rate=target_sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = torch.argmax(logits).item()
pred_class = id2label[predicted_class_id]

print(pred_class)
```
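
The script prints only the top label. If per-class probabilities are useful, a softmax over the logits can be appended to the script above (an optional extension, not part of the original code):

```
# Optional extension of the script above: per-class probabilities.
probs = torch.softmax(logits, dim=-1).squeeze()
for class_id, p in enumerate(probs.tolist()):
    print(f"{id2label[class_id]}: {p:.3f}")
```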