# Scraped from the Hugging Face file view of app.py (author: 3loi).
# Latest commit: "Update app.py" (ff9ea0c, verified); malware scan: no virus; size: 1.58 kB.
from functools import lru_cache

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoModelForAudioClassification
from transformers import pipeline
# Fixed waveform statistics used to standardize input audio before inference
# (presumably measured on the model's training data — TODO confirm).
mean = -8.278621631819787e-05
std = 0.08485510250851999

# Model output index -> emotional attribute name (assumed to follow the
# model's output order — verify against the model card).
id2label = dict(enumerate(('arousal', 'dominance', 'valence')))
# Hugging Face Hub id of the pretrained multi-attribute SER model.
_MODEL_ID = "3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes"
# Sampling rate the model was trained on (as stated in the app's own message).
_TARGET_SR = 16000
# Returned instead of crashing when Submit is pressed without usable audio.
_NO_AUDIO_MSG = "No audio received. Upload a file or record from the microphone.\n"


@lru_cache(maxsize=1)
def _load_model():
    """Load the pretrained model once and reuse it for every request.

    Loading inside the request handler would re-fetch and re-build the model
    on each call.
    """
    return AutoModelForAudioClassification.from_pretrained(_MODEL_ID, trust_remote_code=True)


def _to_float_mono(raw_wav):
    """Return *raw_wav* as a contiguous 1-D float32 array.

    Integer samples are scaled by their dtype's maximum value. Floating-point
    samples are only cast; they are assumed to already lie in [-1, 1]
    (TODO confirm for all Gradio input paths). A 2-D input laid out as
    ``(samples, channels)``, which is Gradio's numpy layout for multi-channel
    audio, is averaged down to mono.
    """
    wav = np.asarray(raw_wav)
    if np.issubdtype(wav.dtype, np.integer):
        y = wav.astype(np.float32) / np.iinfo(wav.dtype).max
    else:
        # np.iinfo raises ValueError for float dtypes, so skip integer scaling.
        y = wav.astype(np.float32)
    if y.ndim == 2:
        y = y.mean(axis=1)
    return np.ascontiguousarray(y, dtype=np.float32)


def classify_audio(audio_file):
    """Predict arousal, dominance and valence scores for one audio clip.

    Args:
        audio_file: ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio`` in numpy mode, or ``None`` when nothing was submitted.

    Returns:
        Text with one ``attribute: <score>`` line (tab-separated, 4 decimals)
        per model output. If the clip was resampled, a note comes first. If no
        usable audio was given, a short message is returned instead.
    """
    if audio_file is None:
        return _NO_AUDIO_MSG
    sr, raw_wav = audio_file
    y = _to_float_mono(raw_wav)
    if y.size == 0:
        return _NO_AUDIO_MSG

    output = ''
    if sr != _TARGET_SR:
        # Predicting on audio at the wrong rate gives meaningless scores, so
        # resample to the training rate instead of only warning about it.
        y = librosa.resample(y, orig_sr=sr, target_sr=_TARGET_SR).astype(np.float32)
        output += "Audio resampled from {} Hz to {} Hz, the sampling rate the model was trained on.\n".format(
            sr, _TARGET_SR)

    # Standardize with the fixed module-level statistics; the epsilon guards
    # against division by zero.
    norm_wav = (y - mean) / (std + 0.000001)
    wavs = torch.tensor(norm_wav, dtype=torch.float32).unsqueeze(0)  # batch of one clip
    # All-ones mask over the single clip (presumably "every sample is valid").
    mask = torch.ones(1, len(norm_wav))
    with torch.no_grad():  # inference only; no autograd graph needed
        pred = _load_model()(wavs, mask).detach().numpy()

    output += "".join(
        "{}: \t{:0.4f}\n".format(id2label[att_i], att_val)
        for att_i, att_val in enumerate(pred[0])
    )
    return output
def main():
    """Build the Gradio interface around ``classify_audio`` and launch it."""
    iface = gr.Interface(
        fn=classify_audio,
        # type="numpy" (Gradio's default) delivers the (sample_rate, samples)
        # tuple that classify_audio unpacks; stated explicitly so the contract
        # between the two is visible here.
        inputs=gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio file"),
        outputs=gr.Text(),
        title="Speech Emotion Recognition App",
        # Plain literal: a backslash continuation inside the string would
        # splice the next line's indentation (or no space at all) into the UI text.
        description="Upload an audio file and hit the 'Submit' button",
    )
    iface.launch()
if __name__ == '__main__':
    # Start the web app only when executed as a script, not on import.
    main()