Dpngtm committed on
Commit
87f6c9c
•
1 Parent(s): ebb590a

Update app.py

Files changed (1)
  1. app.py +108 -36
app.py CHANGED
@@ -1,70 +1,142 @@
import gradio as gr
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torchaudio

- # Define emotion labels (use the same order as during training)
emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]

# Load model and processor
- model_name = "Dpngtm/wav2vec2-emotion-recognition"  # Replace with your model's Hugging Face Hub path
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name, num_labels=len(emotion_labels))

- # Define device (use GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

- # Preprocessing and inference function
def recognize_emotion(audio):
    """
-     Predicts the emotion from an audio file using the fine-tuned Wav2Vec2 model.
-
-     Args:
-         audio (str or file-like object): Path or file-like object for the audio file to predict emotion for.
-
-     Returns:
-         str: Predicted emotion label for the given audio file.
    """
    try:
-         # Determine if input is a file path or file-like object
        audio_path = audio if isinstance(audio, str) else audio.name
-         print(f'Received audio file:', audio_path)

-         # Load and resample audio to 16kHz if necessary
        speech_array, sampling_rate = torchaudio.load(audio_path)
-         print(f'Loaded audio with sampling rate:', sampling_rate)

        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
-             speech_array = resampler(speech_array).squeeze().numpy()
-         else:
-             speech_array = speech_array.squeeze().numpy()

-         # Process input for the model
-         inputs = processor(speech_array, sampling_rate=16000, return_tensors='pt', padding=True)
        input_values = inputs.input_values.to(device)

-         # Make predictions
        with torch.no_grad():
-             logits = model(input_values).logits
-         predicted_label = torch.argmax(logits, dim=1).item()
-
-         # Map prediction to emotion label
-         emotion = emotion_labels[predicted_label]
-         return emotion
    except Exception as e:
-         return f'Error during prediction: {str(e)}'

- # Gradio interface with both microphone and file upload options
interface = gr.Interface(
    fn=recognize_emotion,
-     inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
-     outputs="text",
-     title="Emotion Recognition with Wav2Vec2",
-     description="Upload an audio file or record audio, and the model will predict the emotion."
- )

# Launch the app
- interface.launch()

import gradio as gr
import torch
+ import torch.nn.functional as F
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torchaudio
+ import numpy as np

+ # Define emotion labels
emotion_labels = ["angry", "calm", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]

# Load model and processor
+ model_name = "Dpngtm/wav2vec2-emotion-recognition"
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name, num_labels=len(emotion_labels))

+ # Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
+ model.eval()  # Set model to evaluation mode

def recognize_emotion(audio):
    """
+     Predicts the emotion and confidence scores from an audio file.
+     Max duration: 60 seconds
    """
    try:
+         if audio is None:
+             return {emotion: 0.0 for emotion in emotion_labels}
+
+         # Handle audio input
        audio_path = audio if isinstance(audio, str) else audio.name

+         # Load and resample audio
        speech_array, sampling_rate = torchaudio.load(audio_path)

+         # Check audio duration
+         duration = speech_array.shape[1] / sampling_rate
+         if duration > 60:  # 60 seconds (1 minute) limit
+             return {
+                 "Error": "Audio too long (max 1 minute)",
+                 **{emotion: 0.0 for emotion in emotion_labels}
+             }
+
+         # Resample if needed
        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
+             speech_array = resampler(speech_array)
+
+         # Convert to mono if stereo
+         if speech_array.shape[0] > 1:
+             speech_array = torch.mean(speech_array, dim=0, keepdim=True)
+
+         # Normalize audio
+         speech_array = speech_array / torch.max(torch.abs(speech_array))
+
+         # Convert to numpy and squeeze
+         speech_array = speech_array.squeeze().numpy()

+         # Process input
+         inputs = processor(
+             speech_array,
+             sampling_rate=16000,
+             return_tensors='pt',
+             padding=True
+         )
        input_values = inputs.input_values.to(device)

+         # Get predictions
        with torch.no_grad():
+             outputs = model(input_values)
+             logits = outputs.logits
+
+         # Get probabilities using softmax
+         probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
+
+         # Get confidence scores for all emotions
+         confidence_scores = {
+             emotion: round(float(prob) * 100, 2)  # Convert to percentage with 2 decimal places
+             for emotion, prob in zip(emotion_labels, probs)
+         }
+
+         # Sort confidence scores by value
+         sorted_scores = dict(sorted(
+             confidence_scores.items(),
+             key=lambda x: x[1],
+             reverse=True
+         ))
+
+         return sorted_scores
+
    except Exception as e:
+         return {
+             "Error": str(e),
+             **{emotion: 0.0 for emotion in emotion_labels}
+         }

+ # Create Gradio interface
interface = gr.Interface(
    fn=recognize_emotion,
+     inputs=gr.Audio(
+         sources=["microphone", "upload"],
+         type="filepath",
+         label="Upload audio or record from microphone",
+         max_length=60  # Set max length to 60 seconds in Gradio interface
+     ),
+     outputs=gr.Label(
+         num_top_classes=len(emotion_labels),
+         label="Emotion Predictions"
+     ),
+     title="Speech Emotion Recognition",
+     description="""
+     ## Speech Emotion Recognition using Wav2Vec2
+
+     This model recognizes emotions from speech audio in the following categories:
+     - Angry 😠
+     - Calm 😌
+     - Disgust 🤢
+     - Fearful 😨
+     - Happy 😊
+     - Neutral 😐
+     - Sad 😒
+     - Surprised 😲
+
+     ### Instructions:
+     1. Upload an audio file or record through the microphone
+     2. Wait for processing
+     3. View predicted emotions with confidence scores
+
+     ### Notes:
+     - Maximum audio length: 1 minute
+     - Best results with clear speech and minimal background noise
+     - Confidence scores are shown as percentages
+     """,
+ )

# Launch the app
+ interface.launch(
+     share=True,
+     debug=True,
+     server_name="0.0.0.0",
+     server_port=7860
+ )
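
Not part of the commit, but a minimal local smoke test of the updated recognize_emotion may help when reviewing: it assumes app.py has been run or imported so the function, model, and processor are in scope, and the 2-second 220 Hz tone plus the test_clip.wav filename are placeholders rather than anything from the repo.

import torch
import torchaudio

# Write a short synthetic mono clip at 16 kHz (placeholder input, not real speech)
sample_rate = 16000
t = torch.linspace(0, 2.0, int(2.0 * sample_rate))
tone = 0.5 * torch.sin(2 * torch.pi * 220.0 * t).unsqueeze(0)
torchaudio.save("test_clip.wav", tone, sample_rate)

# The Gradio input uses type="filepath", so passing a path string mirrors what the app does
scores = recognize_emotion("test_clip.wav")
print(scores)  # dict of emotion -> confidence percentage, sorted descending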