# multimodal_emotion_recognition/src/audio_processor.py
import time
import torch
import librosa
import numpy as np
import gradio as gr
from .generate_graph import create_behaviour_gantt_plot
from transformers import Wav2Vec2Processor
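
# wav2vec2-base-960h was trained on 16 kHz audio, so every input is (re)sampled to this rate.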
SAMPLING_RATE = 16_000
class AudioProcessor:
def __init__(
self,
emotion_model,
segmentation_model,
device,
behaviour_model=None,
):
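        """
        Args:
            emotion_model: multimodal (audio + text) emotion classification model.
            segmentation_model: ASR pipeline used to segment and transcribe the audio.
            device: torch device the models are moved to.
            behaviour_model: optional audio-only behaviour classification model.
        """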
self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
self.emotion_model = emotion_model
self.behaviour_model = behaviour_model
self.device = device
self.audio_emotion_labels = {
0: "Neutralità",
1: "Rabbia",
2: "Paura",
3: "Gioia",
4: "Sorpresa",
5: "Tristezza",
6: "Disgusto",
}
self.emotion_translation = {
"neutrality": "Neutralità",
"anger": "Rabbia",
"fear": "Paura",
"joy": "Gioia",
"surprise": "Sorpresa",
"sadness": "Tristezza",
"disgust": "Disgusto"
}
self.behaviour_labels = {
0: "frustrated",
1: "delighted",
2: "dysregulated",
}
        self.behaviour_translation = {
            "frustrated": "frustrazione",
"delighted": "incantato",
"dysregulated": "disregolazione",
}
self.segmentation_model = segmentation_model
self._set_emotion_model()
if self.behaviour_model:
self._set_behaviour_model()
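        # Default confidence threshold applied to behaviour predictions.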
self.behaviour_confidence = 0.6
self.chart_generator = None
def _set_emotion_model(self):
self.emotion_model.to(self.device)
self.emotion_model.eval()
def _set_behaviour_model(self):
self.behaviour_model.to(self.device)
self.behaviour_model.eval()
def _prepare_transcribed_text(self, chunks):
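        """Format (start, end, label, text, ...) chunks into a single HTML/markdown transcript string."""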
        formatted_timestamps = []
        predictions = []
        for chunk in chunks:
            start = chunk[0] / SAMPLING_RATE
            end = chunk[1] / SAMPLING_RATE
            formatted_start = time.strftime('%H:%M:%S', time.gmtime(start))
            formatted_end = time.strftime('%H:%M:%S', time.gmtime(end))
            formatted_timestamps.append(f"**({formatted_start} - {formatted_end})**")
            predictions.append(f"**[{chunk[2]}]**")
        transcribed_texts = [chunk[3] for chunk in chunks]
        transcribed_text = "<br/>".join(
            [
                f"{formatted_timestamps[i]}: {transcribed_texts[i]} {predictions[i]}" for i in range(len(transcribed_texts))
            ]
        )
print(f"Transcribed text:\n{transcribed_text}")
return transcribed_text
def __call__(self, audio_path: str):
"""
        Segments an audio file, transcribes it, and predicts emotion and behaviour labels per chunk.

        Args:
            audio_path (str): Path to the audio file to be processed.

        Returns:
            tuple: A one-element tuple containing the behaviour Gantt timeline figure.
"""
try:
input_frames, _ = librosa.load(
audio_path,
sr=SAMPLING_RATE
)
        except Exception as e:
            # Surface the failure in the Gradio UI instead of continuing with an undefined `input_frames`.
            raise gr.Error(f"Error loading audio file: {e}.") from e
print("Segmenting audio...")
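        # Run the ASR/segmentation pipeline over the full recording: 30-second windows
        # with a 5-second stride, returning timestamps for each transcribed chunk.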
out = self.segmentation_model(
inputs={
"raw": input_frames,
"sampling_rate": SAMPLING_RATE,
},
chunk_length_s=30,
stride_length_s=5,
return_timestamps=True,
)
emotion_chunks = []
behaviour_chunks = []
timestamps = []
predicted_labels = []
all_probabilities = []
        print("Analyzing chunks...")
for chunk in out["chunks"]:
# trim audio from timestamps
start = int(chunk["timestamp"][0] * SAMPLING_RATE)
end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames))
audio = input_frames[start:end]
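            # Build joint audio + transcript inputs for the multimodal emotion model.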
inputs = self.emotion_processor(audio, chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE)
print(f"Inputs: {inputs}")
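            # The emotion model expects the audio tensor under "input_features", while
            # Wav2Vec2Processor returns it as "input_values"; rename the key accordingly.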
if "input_values" in inputs:
inputs["input_features"] = inputs.pop("input_values")
inputs['input_features'] = inputs['input_features'].to(self.device)
inputs['input_ids'] = inputs['input_ids'].to(self.device)
inputs['text_attention_mask'] = inputs['text_attention_mask'].to(self.device)
print("Predicting emotion for chunk...")
logits = self.emotion_model(**inputs).logits
logits = logits.detach().cpu()
softmax = torch.nn.Softmax(dim=1)
probabilities = softmax(logits).squeeze(0)
prediction = probabilities.argmax().item()
            # id2label lives on the model config (the Wav2Vec2Processor has no `config` attribute).
            predicted_label = self.emotion_model.config.id2label[prediction]
            label_translation = self.emotion_translation[predicted_label]
emotion_chunks.append(
(
start,
end,
label_translation,
chunk["text"],
np.round(probabilities[prediction].item(), 2)
)
)
timestamps.append((start, end))
predicted_labels.append(label_translation)
all_probabilities.append(probabilities[prediction].item())
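            # Second pass: audio-only features for the behaviour classifier.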
inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
if "input_values" in inputs:
inputs["input_features"] = inputs.pop("input_values")
inputs = inputs.input_features.to(self.device)
print("Predicting behaviour for chunk...")
logits = self.behaviour_model(inputs).logits
probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
behaviour_chunks.append(
(
start,
end,
chunk["text"],
                    np.round(probabilities[2].item(), 2),  # probability of the "dysregulated" class (index 2)
label_translation,
)
)
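        # Build a Gantt-style timeline of the behaviour predictions across all chunks.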
behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)
# transcribed_text = self._prepare_transcribed_text(emotion_chunks)
return (
behaviour_gantt,
# transcribed_text,
)
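

# Example usage (sketch; assumes the emotion, segmentation and behaviour models are
# loaded elsewhere by the application entry point):
#
#   processor = AudioProcessor(
#       emotion_model=emotion_model,
#       segmentation_model=segmentation_pipeline,
#       device="cuda" if torch.cuda.is_available() else "cpu",
#       behaviour_model=behaviour_model,
#   )
#   (behaviour_gantt,) = processor("path/to/audio.wav")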