import gradio as gr
import torch
import torch.nn as nn
import torchaudio

from utils import parse_finetune_args
from wavlm_plus import WavLMWrapper

args = parse_finetune_args()
device = "cpu"


def transform(audio_data):
    """Convert a Gradio (sample_rate, numpy_array) tuple into a (1, num_samples)
    float tensor resampled to 16 kHz, the rate WavLM expects."""
    sample_rate, samples = audio_data
    waveform = torch.from_numpy(samples).to(torch.float)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
    return waveform.reshape(1, -1)


def get_inference(model, audio_data, audio_filepath):
    """Run the binary anger classifier on the microphone recording or the uploaded clip."""
    model.to(device)

    if audio_data is None and audio_filepath is None:
        return {"error": "Please record or upload an audio clip."}

    # Prefer the microphone recording; fall back to the uploaded file.
    waveform = transform(audio_data if audio_data is not None else audio_filepath)

    # Set the last stage of the classification head to a softmax so the
    # two outputs are returned as probabilities.
    model.out_layer[3] = nn.Softmax(dim=1)
    with torch.no_grad():
        outputs = model(waveform.to(device))

    _, max_idx = torch.max(outputs, dim=1)
    actual_emo = "Not Angry" if max_idx.item() == 0 else "Angry"
    return {
        "emotion recognized": actual_emo,
        "Not Angry probability": str(outputs[0][0].item()),
        "Angry probability": str(outputs[0][1].item()),
    }


def get_data(finetune, audio_data, audio_filepath):
    """Load the checkpoint matching the selected fine-tuning method and run inference."""
    if finetune is None:
        finetune = "finetune"
    args.finetune_method = str(finetune)

    model = WavLMWrapper(args=args, output_class_num=2)
    if finetune == "finetune":
        state_dict = torch.load(
            "fold_1_finetune_wavlm_notpure_labels_binary.pt", map_location=torch.device("cpu")
        )
    elif finetune == "lora":
        # replace lora pretrained with default
        state_dict = torch.load("fold_lora_full1.pt", map_location=torch.device("cpu"))
    elif finetune == "adapter":
        state_dict = torch.load("fold_2_adapter_default.pt", map_location=torch.device("cpu"))
    else:
        raise ValueError(f"Unknown fine-tuning method: {finetune}")

    model.load_state_dict(state_dict)
    return get_inference(model, audio_data, audio_filepath)


def transcribe(finetune, audio_data, audio_filepath=None):
    return get_data(finetune, audio_data, audio_filepath)


gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Radio(["lora", "finetune", "adapter"], label="Fine-tuning method"),
        gr.Audio(source="microphone", label="Record audio"),
        gr.Audio(source="upload", label="Upload audio"),
    ],
    outputs="text",
).launch()