File size: 1,995 Bytes
35545b0
 
 
 
 
 
 
 
5642657
 
 
5c180c7
ad02703
8dda170
 
35545b0
 
bd1d60b
 
 
f9b0e0d
8dda170
ff9ea0c
8dda170
76fe58b
 
 
357c84b
35545b0
76fe58b
 
35545b0
 
d3aaef0
76fe58b
88ba5b7
d3aaef0
 
c3d029a
d3aaef0
35545b0
 
 
 
8dda170
 
5642657
35545b0
 
 
 
 
 
8dda170
d3aaef0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from transformers import pipeline
from transformers import AutoModelForAudioClassification
import gradio as gr
import librosa
import torch
import numpy as np


# Markdown/HTML blurb shown under the app title; main() passes it to
# gr.Interface as `description`.
# NOTE: the backslash continuations below are *inside* the string literal, so
# the newline is dropped but the leading indentation of every continued line is
# part of the runtime text. Re-indenting these lines changes the string.
# NOTE(review): the "[paper/soon]()" link has an empty target - placeholder
# until the paper URL is available.
description_text = "Multi-label (arousal, dominance, valence) Odyssey 2024 Emotion Recognition competition baseline model.<br> \
        The model is trained on MSP-Podcast. \
        For more details visit: [HuggingFace](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes), \
                            [paper/soon]() and [GitHub](https://github.com/MSP-UTD/MSP-Podcast_Challenge/tree/main). <br> <br>\
        Upload an audio file and hit the 'Submit' button to predict the emotion"


# Hugging Face Hub id of the checkpoint served by this app.
_MODEL_ID = "3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes"
# Filled in by _get_model() on first use so the checkpoint is loaded once per
# process instead of once per request.
_MODEL = None
# Returned when there is nothing to classify (no upload/recording, or 0 samples).
_NO_AUDIO_MESSAGE = "No audio received. Please upload or record an audio clip and try again.\n"


def _get_model():
    """Return the audio-classification model, loading it on the first call."""
    global _MODEL
    if _MODEL is None:
        _MODEL = AutoModelForAudioClassification.from_pretrained(_MODEL_ID, trust_remote_code=True)
    return _MODEL


def _to_mono_float32(raw_wav):
    """Convert raw samples to a 1-D float32 waveform.

    Integer PCM is scaled by the maximum value of its dtype (same scaling as
    before); floating-point input is only cast.
    """
    samples = np.asarray(raw_wav)
    if np.issubdtype(samples.dtype, np.integer):
        wav = samples.astype(np.float32, order='C') / np.iinfo(samples.dtype).max
    else:
        # np.iinfo() rejects float dtypes.
        # NOTE(review): assumes float audio is already scaled to [-1, 1] - confirm.
        wav = samples.astype(np.float32, order='C')
    if wav.ndim > 1:
        # NOTE(review): assumes a (samples, channels) layout, as gr.Audio
        # delivers for multi-channel clips - confirm. Downmix by averaging.
        wav = wav.mean(axis=1)
    return wav


def classify_audio(audio_file):
    """Predict the emotional attributes of one audio clip.

    Args:
        audio_file: ``(sample_rate, samples)`` tuple as produced by the
            ``gr.Audio`` input, or ``None`` when nothing was submitted.

    Returns:
        A text report with one ``"<label>: \\t<score>"`` line per model output,
        preceded by a notice when the clip had to be resampled. If there is no
        audio to classify, a short message asking for a clip is returned
        instead (the model is not loaded in that case).
    """
    # Submitting with no clip hands the callback None; answer with a message
    # rather than failing on the tuple unpacking below.
    if audio_file is None:
        return _NO_AUDIO_MESSAGE

    sr, raw_wav = audio_file
    y = _to_mono_float32(raw_wav)
    if y.size == 0:
        return _NO_AUDIO_MESSAGE

    model = _get_model()
    mean, std = model.config.mean, model.config.std
    model_sr = model.config.sampling_rate
    id2label = model.config.id2label

    lines = []
    # Compare against the rate stored in the model config, not a literal 16000,
    # so the check always matches the resampling target.
    if sr != model_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=model_sr)
        lines.append("{} sampling rate is uncompatible, converted to {} as the model was trained on {} sampling rate\n".format(sr, model_sr, model_sr))

    # Normalize with the statistics stored in the model config; the small
    # epsilon guards against a zero std.
    norm_wav = (y - mean) / (std + 0.000001)
    wavs = torch.as_tensor(norm_wav, dtype=torch.float32).unsqueeze(0)
    mask = torch.ones(1, wavs.shape[1])

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred = model(wavs, mask).detach().numpy()

    for att_i, att_val in enumerate(pred[0]):
        lines.append("{}: \t{:0.4f}\n".format(id2label[att_i], att_val))

    return "".join(lines)


def main():
    """Assemble the Gradio interface around classify_audio and launch it."""
    # Input widget: accepts an uploaded file or a microphone recording.
    audio_input = gr.Audio(sources=["upload", "microphone"], label="Audio file")
    # Output widget: plain text box for the prediction report.
    text_output = gr.Text()

    demo = gr.Interface(
        fn=classify_audio,
        inputs=audio_input,
        outputs=text_output,
        title="Speech Emotion Recognition App",
        description=description_text,
    )
    demo.launch()


# Start the app only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()