import time

import librosa
import numpy as np
import torch
import gradio as gr

from .generate_graph import create_behaviour_gantt_plot
from transformers import Wav2Vec2Processor

SAMPLING_RATE = 16_000


class AudioProcessor:
    def __init__(
        self,
        emotion_model,
        segmentation_model,
        device,
        behaviour_model=None,
    ):
        self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.emotion_model = emotion_model
        self.behaviour_model = behaviour_model
        self.device = device
        self.audio_emotion_labels = {
            0: "Neutralità",
            1: "Rabbia",
            2: "Paura",
            3: "Gioia",
            4: "Sorpresa",
            5: "Tristezza",
            6: "Disgusto",
        }
        self.emotion_translation = {
            "neutrality": "Neutralità",
            "anger": "Rabbia",
            "fear": "Paura",
            "joy": "Gioia",
            "surprise": "Sorpresa",
            "sadness": "Tristezza",
            "disgust": "Disgusto",
        }
        self.behaviour_labels = {
            0: "frustrated",
            1: "delighted",
            2: "dysregulated",
        }
        self.behaviour_translation = {
            "frustrated": "frustrazione",
            "delighted": "incantato",
            "dysregulated": "disregolazione",
        }
        self.segmentation_model = segmentation_model
        self._set_emotion_model()
        if self.behaviour_model:
            self._set_behaviour_model()
        self.behaviour_confidence = 0.6
        self.chart_generator = None

    def _set_emotion_model(self):
        self.emotion_model.to(self.device)
        self.emotion_model.eval()

    def _set_behaviour_model(self):
        self.behaviour_model.to(self.device)
        self.behaviour_model.eval()

    def _prepare_transcribed_text(self, chunks):
        formatted_timestamps = []
        predictions = []
        for chunk in chunks:
            start = chunk[0] / SAMPLING_RATE
            end = chunk[1] / SAMPLING_RATE
            formatted_start = time.strftime("%H:%M:%S", time.gmtime(start))
            formatted_end = time.strftime("%H:%M:%S", time.gmtime(end))
            formatted_timestamps.append(f"**({formatted_start} - {formatted_end})**")
            predictions.append(f"**[{chunk[2]}]**")
        transcribed_texts = [chunk[3] for chunk in chunks]
        transcribed_text = "\n".join(
            [
                f"{formatted_timestamps[i]}: {transcribed_texts[i]} {predictions[i]}"
                for i in range(len(transcribed_texts))
            ]
        )
        print(f"Transcribed text:\n{transcribed_text}")
        return transcribed_text

    def __call__(self, audio_path: str):
        """
        Segments an audio file and predicts emotion and behaviour labels for each chunk.

        Args:
            audio_path (str): Path to the audio file to be processed.

        Returns:
            tuple: The behaviour Gantt chart built from the per-chunk predictions.
        """
        try:
            input_frames, _ = librosa.load(
                audio_path,
                sr=SAMPLING_RATE,
            )
        except Exception as e:
            raise gr.Error(f"Error loading audio file: {e}.")

        print("Segmenting audio...")
        out = self.segmentation_model(
            inputs={
                "raw": input_frames,
                "sampling_rate": SAMPLING_RATE,
            },
            chunk_length_s=30,
            stride_length_s=5,
            return_timestamps=True,
        )

        emotion_chunks = []
        behaviour_chunks = []
        timestamps = []
        predicted_labels = []
        all_probabilities = []

        print("Analyzing chunks...")
        for chunk in out["chunks"]:
            # Trim the audio to the chunk boundaries; the last chunk may have no end timestamp.
            start = int(chunk["timestamp"][0] * SAMPLING_RATE)
            end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames))
            audio = input_frames[start:end]

            # Emotion prediction uses both the audio and the transcribed text of the chunk.
            inputs = self.emotion_processor(audio, chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE)
            print(f"Inputs: {inputs}")
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs["input_features"] = inputs["input_features"].to(self.device)
            inputs["input_ids"] = inputs["input_ids"].to(self.device)
            inputs["text_attention_mask"] = inputs["text_attention_mask"].to(self.device)

            print("Predicting emotion for chunk...")
            logits = self.emotion_model(**inputs).logits
            logits = logits.detach().cpu()
            softmax = torch.nn.Softmax(dim=1)
            probabilities = softmax(logits).squeeze(0)
            prediction = probabilities.argmax().item()
            predicted_label = self.emotion_model.config.id2label[prediction]
            label_translation = self.emotion_translation[predicted_label]
            emotion_chunks.append(
                (
                    start,
                    end,
                    label_translation,
                    chunk["text"],
                    np.round(probabilities[prediction].item(), 2),
                )
            )
            timestamps.append((start, end))
            predicted_labels.append(label_translation)
            all_probabilities.append(probabilities[prediction].item())

            # Behaviour prediction uses the audio features only.
            inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")
            inputs = inputs.input_features.to(self.device)

            print("Predicting behaviour for chunk...")
            logits = self.behaviour_model(inputs).logits
            probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
            behaviour_chunks.append(
                (
                    start,
                    end,
                    chunk["text"],
                    np.round(probabilities[2].item(), 2),  # probability of "dysregulated" (index 2)
                    label_translation,
                )
            )

        behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)
        # transcribed_text = self._prepare_transcribed_text(emotion_chunks)
        return (
            behaviour_gantt,
            # transcribed_text,
        )