from transformers import AutoModelForAudioClassification
import gradio as gr
import librosa
import torch
import numpy as np


description_text = "Multi-attribute (arousal, dominance, valence) Odyssey 2024 Emotion Recognition competition baseline model.<br> \
        The model is trained on MSP-Podcast. \
        For more details visit the [HuggingFace model page](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes), \
        [paper/soon]() and [GitHub](https://github.com/MSP-UTD/MSP-Podcast_Challenge/tree/main). <br> <br>\
        Upload an audio file and hit the 'Submit' button to predict the emotion attributes."


# Load the model and its preprocessing config once at import time rather than
# on every request.
model = AutoModelForAudioClassification.from_pretrained(
    "3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes", trust_remote_code=True
)
mean, std = model.config.mean, model.config.std
model_sr = model.config.sampling_rate
id2label = model.config.id2label


def classify_audio(audio_file):
    
    # gr.Audio yields (sample_rate, integer-PCM ndarray); scale to float32 in [-1, 1].
    sr, raw_wav = audio_file
    y = raw_wav.astype(np.float32, order='C') / np.iinfo(raw_wav.dtype).max

    output = ''
    # Resample if the upload's rate differs from the model's training rate.
    if sr != model_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=model_sr)
        output += "Incompatible sampling rate {}; resampled to {} Hz, the rate the model was trained on.\n".format(sr, model_sr)

    
    # Z-normalize with the corpus statistics from the model config (the epsilon
    # guards against division by zero), then add a batch dimension and an
    # all-ones attention mask for the single unpadded utterance.
    norm_wav = (y - mean) / (std + 1e-6)
    mask = torch.ones(1, len(norm_wav))
    wavs = torch.tensor(norm_wav).unsqueeze(0)

    
    # Run inference; the model returns one continuous score per attribute.
    with torch.no_grad():
        pred = model(wavs, mask).numpy()
    for att_i, att_val in enumerate(pred[0]):
        output += "{}: \t{:0.4f}\n".format(id2label[att_i], att_val)
    
    return output
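
# A minimal sketch of exercising classify_audio outside Gradio, e.g. as a
# smoke test. "sample.wav" is a hypothetical local file; librosa.load returns
# float audio in [-1, 1], so it is scaled to int16 PCM to mirror what
# gr.Audio passes in:
#
#   wav, sr = librosa.load("sample.wav", sr=None)
#   pcm = (wav * np.iinfo(np.int16).max).astype(np.int16)
#   print(classify_audio((sr, pcm)))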


def main():
    
    iface = gr.Interface(fn=classify_audio, inputs=gr.Audio(sources=["upload", "microphone"], label="Audio file"), 
                         outputs=gr.Text(), title="Speech Emotion Recognition App",
                         description=description_text)
    
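    # launch() serves the app locally by default; Gradio also accepts options
    # such as share=True for a temporary public URL.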
    iface.launch()


if __name__ == '__main__':
    main()