3loi's picture
Update app.py
a4a1b9b verified
raw history blame
No virus
2.01 kB
import functools

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoModelForAudioClassification
from transformers import pipeline
# Markdown/HTML blurb passed as the Gradio interface description in main().
# Built from adjacent string literals, which Python joins at compile time.
description_text = (
    "Multi-label (arousal, dominance, valence) Odyssey 2024 Emotion Recognition competition baseline model.<br> "
    "The model is trained on MSP-Podcast. "
    "For more details visit: [HuggingFace model page](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes), "
    "[paper/soon]() and [GitHub](https://github.com/MSP-UTD/MSP-Podcast_Challenge/tree/main). <br> <br>"
    "Upload an audio file and hit the 'Submit' button to predict the emotion"
)
# Hugging Face Hub id of the model this app serves.
_MODEL_ID = "3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the emotion model once and reuse it for every later request."""
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # model repository; keep _MODEL_ID pinned to a repository you trust.
    return AutoModelForAudioClassification.from_pretrained(_MODEL_ID, trust_remote_code=True)


def _to_mono_float32(raw_wav):
    """Convert a Gradio waveform array into a C-contiguous 1-D float32 array.

    Integer PCM is divided by its dtype's maximum value, exactly as the
    original app did. Float input is assumed to be already scaled and is only
    cast. 2-D input laid out as (samples, channels) is down-mixed to mono by
    averaging the channels.
    """
    wav = np.asarray(raw_wav)
    if np.issubdtype(wav.dtype, np.integer):
        scale = np.iinfo(wav.dtype).max
        wav = wav.astype(np.float32, order='C') / scale
    else:
        # np.iinfo() rejects float dtypes, so floats are taken as-is.
        wav = wav.astype(np.float32, order='C')
    if wav.ndim == 2:
        # Gradio delivers multi-channel audio as (samples, channels); the model
        # and librosa.resample below expect a single 1-D channel.
        wav = wav.mean(axis=1).astype(np.float32, order='C')
    return wav


def classify_audio(audio_file):
    """Predict the model's emotion attribute scores for one audio clip.

    Args:
        audio_file: Value of the ``gr.Audio`` input: a ``(sample_rate, waveform)``
            tuple, or ``None`` when the user submits without any audio.

    Returns:
        Text with one ``"label: \\tscore"`` line per model output. If the clip
        was resampled, the text starts with a note saying so. If no usable
        audio was supplied, a short message is returned instead.
    """
    if audio_file is None:
        return "No audio provided. Please upload or record an audio clip."
    sr, raw_wav = audio_file
    y = _to_mono_float32(raw_wav)
    if y.size == 0:
        return "The provided audio clip is empty."

    model = _load_model()
    config = model.config
    mean, std = config.mean, config.std
    model_sr = config.sampling_rate
    id2label = config.id2label

    output = ''
    # Compare against the model's own rate rather than a hard-coded 16000.
    if sr != model_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=model_sr)
        output += "{} sampling rate is incompatible, converted to {} as the model was trained on {} sampling rate\n".format(sr, model_sr, model_sr)

    # The epsilon guards against division by a zero std. The cast keeps the
    # input float32 even if mean/std arrive as float64 NumPy scalars.
    norm_wav = ((y - mean) / (std + 0.000001)).astype(np.float32)
    mask = torch.ones(1, len(norm_wav))
    wavs = torch.tensor(norm_wav).unsqueeze(0)  # add a batch dimension of 1
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(wavs, mask).detach().numpy()
    for att_i, att_val in enumerate(pred[0]):
        output += "{}: \t{:0.4f}\n".format(id2label[att_i], att_val)
    return output
def main():
    """Build the Gradio UI around classify_audio and start serving it."""
    audio_input = gr.Audio(sources=["upload", "microphone"], label="Audio file")
    text_output = gr.Text()
    demo = gr.Interface(
        fn=classify_audio,
        inputs=audio_input,
        outputs=text_output,
        title="Speech Emotion Recognition App",
        description=description_text,
    )
    demo.launch()
# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()