import time

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import Wav2Vec2Processor

from .generate_graph import create_behaviour_gantt_plot

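# 16 kHz mono is the sampling rate wav2vec2-style checkpoints are trained on.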
SAMPLING_RATE = 16_000


class AudioProcessor:
    def __init__(
        self,
        emotion_model,
        segmentation_model,
        device,
        behaviour_model=None,
    ):
        self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.emotion_model = emotion_model
        self.behaviour_model = behaviour_model
        self.device = device
        # Italian display names for the emotion classes, keyed by class index.
        self.audio_emotion_labels = {
            0: "Neutralità",
            1: "Rabbia",
            2: "Paura",
            3: "Gioia",
            4: "Sorpresa",
            5: "Tristezza",
            6: "Disgusto",
        }
        # English model labels -> Italian display names.
        self.emotion_translation = {
            "neutrality": "Neutralità",
            "anger": "Rabbia",
            "fear": "Paura",
            "joy": "Gioia",
            "surprise": "Sorpresa",
            "sadness": "Tristezza",
            "disgust": "Disgusto",
        }
        self.behaviour_labels = {
            0: "frustrated",
            1: "delighted",
            2: "dysregulated",
        }
        self.behaviour_translation = {
            "frustrated": "frustrazione",
            "delighted": "incantato",
            "dysregulated": "disregolazione",
        }
        self.segmentation_model = segmentation_model

        self._set_emotion_model()
        if self.behaviour_model:
            self._set_behaviour_model()

        # Minimum probability for flagging a chunk as dysregulated.
        self.behaviour_confidence = 0.6

        self.chart_generator = None

    def _set_emotion_model(self):
        self.emotion_model.to(self.device)
        self.emotion_model.eval()

    def _set_behaviour_model(self):
        self.behaviour_model.to(self.device)
        self.behaviour_model.eval()

    def _prepare_transcribed_text(self, chunks):
        formatted_timestamps = []
        predictions = []

        for chunk in chunks:
            start = chunk[0] / SAMPLING_RATE
            end = chunk[1] / SAMPLING_RATE
            formatted_start = time.strftime('%H:%M:%S', time.gmtime(start))
            formatted_end = time.strftime('%H:%M:%S', time.gmtime(end))
            formatted_timestamps.append(f"**({formatted_start} - {formatted_end})**")

            predictions.append(f"**[{chunk[2]}]**")

        transcribed_texts = [chunk[3] for chunk in chunks]
        transcribed_text = "<br/>".join(
            f"{timestamp}: {text} {prediction}"
            for timestamp, text, prediction in zip(formatted_timestamps, transcribed_texts, predictions)
        )

        print(f"Transcribed text:\n{transcribed_text}")

        return transcribed_text

    def __call__(self, audio_path: str):
        """
        Segments an audio file, predicts an emotion label per chunk, and
        builds a behaviour Gantt chart.

        Args:
            audio_path (str): Path of the audio file to process.

        Returns:
            tuple: The behaviour Gantt plot for the processed audio.
        """
        try:
            input_frames, _ = librosa.load(
                audio_path,
                sr=SAMPLING_RATE,
            )
        except Exception as e:
            # gr.Error must be raised, not just instantiated, to surface in the UI.
            raise gr.Error(f"Error loading audio file: {e}.")

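        # The ASR pipeline splits long audio with a sliding window and returns
        # per-chunk timestamps and transcripts.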
        print("Segmenting audio...")
        out = self.segmentation_model(
            inputs={
                "raw": input_frames,
                "sampling_rate": SAMPLING_RATE,
            },
            chunk_length_s=30,
            stride_length_s=5,
            return_timestamps=True,
        )

        emotion_chunks = []
        behaviour_chunks = []
        timestamps = []
        predicted_labels = []
        all_probabilities = []

print("Analizing chunks...") |
|
for chunk in out["chunks"]: |
|
|
|
start = int(chunk["timestamp"][0] * SAMPLING_RATE) |
|
end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames)) |
|
|
|
audio = input_frames[start:end] |
|
|
|
inputs = self.emotion_processor(audio, chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE) |
|
|
|
print(f"Inputs: {inputs}") |
|
|
|
if "input_values" in inputs: |
|
inputs["input_features"] = inputs.pop("input_values") |
|
|
|
inputs['input_features'] = inputs['input_features'].to(self.device) |
|
inputs['input_ids'] = inputs['input_ids'].to(self.device) |
|
inputs['text_attention_mask'] = inputs['text_attention_mask'].to(self.device) |
|
|
|
print("Predicting emotion for chunk...") |
|
logits = self.emotion_model(**inputs).logits |
|
|
|
logits = logits.detach().cpu() |
|
softmax = torch.nn.Softmax(dim=1) |
|
probabilities = softmax(logits).squeeze(0) |
|
|
|
prediction = probabilities.argmax().item() |
|
|
|
predicted_label = self.emotion_processor.config.id2label[prediction] |
|
label_translation = self.emotion_translation[predicted_label] |
|
|
|
emotion_chunks.append( |
|
( |
|
start, |
|
end, |
|
label_translation, |
|
chunk["text"], |
|
np.round(probabilities[prediction].item(), 2) |
|
) |
|
) |
|
|
|
timestamps.append((start, end)) |
|
predicted_labels.append(label_translation) |
|
all_probabilities.append(probabilities[prediction].item()) |
|
|
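            # Second pass over the same chunk: audio-only features (no
            # transcript branch) feed the behaviour classifier.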
            if self.behaviour_model:
                inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
                if "input_values" in inputs:
                    inputs["input_features"] = inputs.pop("input_values")

                inputs = inputs["input_features"].to(self.device)
                print("Predicting behaviour for chunk...")
                logits = self.behaviour_model(inputs).logits
                probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
                behaviour_chunks.append(
                    (
                        start,
                        end,
                        chunk["text"],
                        # Probability of class 2 ("dysregulated").
                        np.round(probabilities[2].item(), 2),
                        # Emotion label predicted above for the same chunk.
                        label_translation,
                    )
                )

        behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)

        return (
            behaviour_gantt,
        )
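

# Usage sketch (not part of the module): a minimal, hypothetical example of
# wiring AudioProcessor into a Gradio app. The emotion_model, behaviour_model,
# and segmentation_model objects are stand-ins constructed elsewhere in this
# project; the segmentation model is assumed to be an ASR pipeline returning
# {"chunks": [{"timestamp": (start, end), "text": ...}, ...]}.
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     processor = AudioProcessor(
#         emotion_model=emotion_model,
#         segmentation_model=segmentation_model,
#         device=device,
#         behaviour_model=behaviour_model,
#     )
#     demo = gr.Interface(
#         fn=processor,
#         inputs=gr.Audio(type="filepath"),
#         outputs=gr.Plot(),
#     )
#     demo.launch()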