Update app.py
app.py
CHANGED
@@ -1,70 +1,142 @@
import gradio as gr
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torchaudio

-# Define emotion labels
emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]

# Load model and processor
-model_name = "Dpngtm/wav2vec2-emotion-recognition"
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name, num_labels=len(emotion_labels))

-# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

-# Preprocessing and inference function
def recognize_emotion(audio):
    """
-    Predicts the emotion from an audio file
-
-    Args:
-        audio (str or file-like object): Path or file-like object for the audio file to predict emotion for.
-
-    Returns:
-        str: Predicted emotion label for the given audio file.
    """
    try:
-
        audio_path = audio if isinstance(audio, str) else audio.name
-        print(f'Received audio file:', audio_path)
-

-        # Load and resample audio
        speech_array, sampling_rate = torchaudio.load(audio_path)
-        print(f'Loaded audio with sampling rate:', sampling_rate)

        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
-            speech_array = resampler(speech_array)
-
-

-        # Process input
-        inputs = processor(
        input_values = inputs.input_values.to(device)

-        #
        with torch.no_grad():
-
-
-
-
-
-
    except Exception as e:
-        return

-# Gradio interface
interface = gr.Interface(
    fn=recognize_emotion,
-    inputs=gr.Audio(
-
-
-
-

# Launch the app
-interface.launch(
import gradio as gr
import torch
+import torch.nn.functional as F
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torchaudio
+import numpy as np

+# Define emotion labels
emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]

# Load model and processor
+model_name = "Dpngtm/wav2vec2-emotion-recognition"
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name, num_labels=len(emotion_labels))

+# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
+model.eval()  # Set model to evaluation mode

def recognize_emotion(audio):
    """
+    Predicts the emotion and confidence scores from an audio file.
+    Max duration: 60 seconds
    """
    try:
+        if audio is None:
+            return {emotion: 0.0 for emotion in emotion_labels}
+
+        # Handle audio input
        audio_path = audio if isinstance(audio, str) else audio.name

+        # Load and resample audio
        speech_array, sampling_rate = torchaudio.load(audio_path)

+        # Check audio duration
+        duration = speech_array.shape[1] / sampling_rate
+        if duration > 60:  # 60 seconds (1 minute) limit
+            return {
+                "Error": "Audio too long (max 1 minute)",
+                **{emotion: 0.0 for emotion in emotion_labels}
+            }
+
+        # Resample if needed
        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
+            speech_array = resampler(speech_array)
+
+        # Convert to mono if stereo
+        if speech_array.shape[0] > 1:
+            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
+
+        # Normalize audio
+        speech_array = speech_array / torch.max(torch.abs(speech_array))
+
+        # Convert to numpy and squeeze
+        speech_array = speech_array.squeeze().numpy()

+        # Process input
+        inputs = processor(
+            speech_array,
+            sampling_rate=16000,
+            return_tensors='pt',
+            padding=True
+        )
        input_values = inputs.input_values.to(device)

+        # Get predictions
        with torch.no_grad():
+            outputs = model(input_values)
+            logits = outputs.logits
+
+        # Get probabilities using softmax
+        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
+
+        # Get confidence scores for all emotions
+        confidence_scores = {
+            emotion: round(float(prob) * 100, 2)  # Convert to percentage with 2 decimal places
+            for emotion, prob in zip(emotion_labels, probs)
+        }
+
+        # Sort confidence scores by value
+        sorted_scores = dict(sorted(
+            confidence_scores.items(),
+            key=lambda x: x[1],
+            reverse=True
+        ))
+
+        return sorted_scores
+
    except Exception as e:
+        return {
+            "Error": str(e),
+            **{emotion: 0.0 for emotion in emotion_labels}
+        }

+# Create Gradio interface
interface = gr.Interface(
    fn=recognize_emotion,
+    inputs=gr.Audio(
+        sources=["microphone", "upload"],
+        type="filepath",
+        label="Upload audio or record from microphone",
+        max_length=60  # Set max length to 60 seconds in Gradio interface
+    ),
+    outputs=gr.Label(
+        num_top_classes=len(emotion_labels),
+        label="Emotion Predictions"
+    ),
+    title="Speech Emotion Recognition",
+    description="""
+    ## Speech Emotion Recognition using Wav2Vec2
+
+    This model recognizes emotions from speech audio in the following categories:
+    - Angry 😠
+    - Calm 😌
+    - Disgust 🤢
+    - Fearful 😨
+    - Happy 😊
+    - Neutral 😐
+    - Sad 😢
+    - Surprised 😲
+
+    ### Instructions:
+    1. Upload an audio file or record through the microphone
+    2. Wait for processing
+    3. View predicted emotions with confidence scores
+
+    ### Notes:
+    - Maximum audio length: 1 minute
+    - Best results with clear speech and minimal background noise
+    - Confidence scores are shown as percentages
+    """,
+)

# Launch the app
+interface.launch(
+    share=True,
+    debug=True,
+    server_name="0.0.0.0",
+    server_port=7860
+)
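One way to exercise the updated recognize_emotion endpoint after launch is through gradio_client. This is only a sketch, assuming gradio_client >= 1.0 is installed, the server started by interface.launch() is reachable at http://localhost:7860 (the server_port configured above), and sample.wav is a placeholder speech clip on disk.

# Hedged usage sketch: query the running app from Python.
# Assumes gradio_client >= 1.0, server at http://localhost:7860, and a local
# speech clip named sample.wav (placeholder, not part of this commit).
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")
result = client.predict(
    handle_file("sample.wav"),  # maps to the single gr.Audio input of recognize_emotion
    api_name="/predict",        # default endpoint name for a one-function gr.Interface
)
print(result)  # gr.Label payload: top emotion plus per-class confidence scores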