import os
from collections import Counter

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification

# All inference in this app runs on the CPU.
device = torch.device("cpu")

# Processor and backbone come from the same pretrained checkpoint; the
# sequence-classification head is created with two output labels and is then
# overwritten by the fine-tuned state dict loaded below.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)

# Fine-tuned weights; a relative path, so it is resolved against the current
# working directory.
model_path = "dysarthria_classifier12.pth"
# NOTE(review): torch.load unpickles the file, so model_path must only ever
# point at a trusted checkpoint. On torch>=1.13 consider weights_only=True.
model.load_state_dict(torch.load(model_path, map_location=device))

# Text shown at the top of the Gradio interface.
title = "Upload an mp3 file for Pseudobulbar Palsy (PP) detection! (Thai Language)"
description = """
The model was trained on Thai audio recordings with the following sentences so please use these sentences: \n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (ดังขึ้นเรื่อยๆ)\n
“อาา อาาา อาาาาา”\n
"""

# via GIPHY

# Fixed input length (in samples) fed to the classifier: clips are zero-padded
# or truncated to this size (100000 samples = 6.25 s at 16 kHz).
MAX_LENGTH = 100000
# Sampling rate passed to the wav2vec2 processor; audio is resampled to it.
# NOTE(review): assumes the training clips were 16 kHz — confirm against the
# training pipeline.
TARGET_SAMPLING_RATE = 16000


def _to_mono_target_rate(wav_data, sample_rate):
    """Downmix a waveform to mono and resample it to TARGET_SAMPLING_RATE.

    Resampling is plain linear interpolation (no anti-aliasing filter); a
    dedicated resampler such as torchaudio/librosa would be more accurate.
    """
    wav = np.asarray(wav_data, dtype=np.float64)
    if wav.ndim > 1:
        # Multi-channel files come back as (frames, channels); the processor
        # needs a single 1-D waveform per clip.
        wav = wav.mean(axis=1)
    if sample_rate == TARGET_SAMPLING_RATE or wav.size == 0:
        return wav
    n_out = int(round(wav.shape[0] * TARGET_SAMPLING_RATE / sample_rate))
    src_times = np.arange(wav.shape[0]) / sample_rate
    dst_times = np.arange(n_out) / TARGET_SAMPLING_RATE
    return np.interp(dst_times, src_times, wav)


def _pad_or_truncate(values, length):
    """Zero-pad (at the end) or truncate a 1-D tensor to exactly ``length``."""
    missing = length - values.shape[-1]
    if missing > 0:
        return torch.cat([values, torch.zeros((missing,), dtype=values.dtype)], dim=-1)
    return values[:length]


def predict(file_upload, microphone):
    """Classify a recording as PP / not PP and return a message for the UI.

    Args:
        file_upload: filepath from the upload widget, or None.
        microphone: filepath from the microphone widget, or None.

    Returns:
        A user-facing string. When both inputs are given the microphone
        recording is used and a warning is prepended. Error messages start
        with "ERROR:".
    """
    if microphone is None and file_upload is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = ""
    if microphone is not None and file_upload is not None:
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n\n"
        )
    # The microphone recording takes precedence over the uploaded file.
    file_path = microphone if microphone is not None else file_upload

    try:
        wav_data, sample_rate = sf.read(file_path)
    except RuntimeError as err:
        # soundfile reports unreadable/unsupported files as RuntimeError
        # subclasses; surface it like the other errors instead of crashing.
        return warn_output + f"ERROR: Could not read the audio file ({err})"
    wav_data = _to_mono_target_rate(wav_data, sample_rate)

    model.eval()
    with torch.no_grad():
        inputs = processor(
            wav_data, sampling_rate=TARGET_SAMPLING_RATE, return_tensors="pt", padding=True
        )
        input_values = _pad_or_truncate(inputs.input_values.squeeze(0), MAX_LENGTH)
        logits = model(input_values=input_values.unsqueeze(0).to(device)).logits

    # Label 1 is treated as "has PP".
    predicted_class_id = torch.argmax(logits.squeeze(), dim=-1).item()
    if predicted_class_id == 1:
        return warn_output + "You probably have PP"
    return warn_output + "You probably don't have PP"


# NOTE(review): `gr.inputs.*` / `source=` / `optional=` is the legacy Gradio
# (<=3.x) API and no longer exists in Gradio 4 — pin gradio accordingly.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Audio(source="upload", type="filepath", optional=True),
        gr.inputs.Audio(source="microphone", type="filepath", optional=True),
    ],
    outputs="text",
    title=title,
    description=description,
)
demo.launch()