import gradio as gr
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torchaudio
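# Note: gr.Audio(sources=[...]) is the Gradio 4.x API; Gradio 3.x used the
# singular source= parameter instead.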
# Define emotion labels and corresponding icons
emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]
emotion_icons = {
"angry": "π ", "calm": "π", "disgust": "π€’", "fearful": "π¨",
"happy": "π", "neutral": "π", "sad": "π’", "surprised": "π²"
}
# Load model and processor
model_name = "Dpngtm/wav2vec2-emotion-recognition"
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name, num_labels=len(emotion_labels))
processor = Wav2Vec2Processor.from_pretrained(model_name)
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
def recognize_emotion(audio):
    try:
        # Handle case where no audio is provided
        if audio is None:
            return {f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}

        # Load and preprocess the audio
        audio_path = audio if isinstance(audio, str) else audio.name
        speech_array, sampling_rate = torchaudio.load(audio_path)

        # Limit audio length to 1 minute (60 seconds)
        duration = speech_array.shape[1] / sampling_rate
        if duration > 60:
            return {
                "Error": "Audio too long (max 1 minute)",
                **{f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
            }

        # Resample audio if not at 16kHz
        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
            speech_array = resampler(speech_array)

        # Convert stereo to mono if necessary
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)

        # Peak-normalize audio (skip silent clips to avoid division by zero)
        peak = torch.max(torch.abs(speech_array))
        if peak > 0:
            speech_array = speech_array / peak
        speech_array = speech_array.squeeze().numpy()

        # Process audio with the model
        inputs = processor(speech_array, sampling_rate=16000, return_tensors='pt', padding=True)
        input_values = inputs.input_values.to(device)

        with torch.no_grad():
            outputs = model(input_values)
            logits = outputs.logits
            probs = F.softmax(logits, dim=-1)[0].cpu().numpy()

        # Map each label to its raw probability (0-1, not percentages)
        confidence_scores = {
            f"{emotion} {emotion_icons[emotion]}": float(prob)
            for emotion, prob in zip(emotion_labels, probs)
        }

        # Sort scores in descending order
        sorted_scores = dict(sorted(confidence_scores.items(), key=lambda x: x[1], reverse=True))
        return sorted_scores

    except Exception as e:
        # Return error message along with zeroed-out emotion scores
        return {
            "Error": str(e),
            **{f"{emotion} {emotion_icons[emotion]}": 0.0 for emotion in emotion_labels}
        }
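# Quick sanity check outside the UI (hypothetical path; assumes a short local
# WAV clip is available):
#   scores = recognize_emotion("sample.wav")
#   print(max(scores, key=scores.get))  # highest-probability emotion label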
# Supported emotions for display
supported_emotions = " | ".join([f"{emotion_icons[emotion]} {emotion}" for emotion in emotion_labels])
# Gradio Interface setup
interface = gr.Interface(
    fn=recognize_emotion,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Record or Upload Audio"
    ),
    outputs=gr.Label(
        num_top_classes=len(emotion_labels),
        label="Detected Emotion"
    ),
    title="Speech Emotion Recognition",
    description=f"""
### Supported Emotions:
{supported_emotions}

Maximum audio length: 1 minute""",
    theme=gr.themes.Soft(
        primary_hue="orange",
        secondary_hue="blue"
    ),
    css="""
    .gradio-container {max-width: 800px}
    .label {font-size: 18px}
    """
)
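# Note: share=True requests a temporary public gradio.live URL; remove it to
# serve on the local network only.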
if __name__ == "__main__":
    interface.launch(
        share=True,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860
    )