import torchaudio
import torch
from model import M11
import gradio as gr

# Preprocessing parameters; these must match what the model was trained on.
SAMPLE_RATE = 8_000  # training sample rate in Hz
NUM_SAMPLES = 400_000  # fixed model input length in samples (50 s at 8 kHz)


def _cut_if_necessary(signal, num_samples=NUM_SAMPLES):
    """Truncate ``signal`` to at most ``num_samples`` samples along dim 1.

    Assumes ``signal`` is 2-D, (channels, samples) -- as produced by
    ``preprocess`` after it adds the channel dimension.
    """
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    return signal


def _right_pad_if_necessary(signal, num_samples=NUM_SAMPLES):
    """Zero-pad ``signal`` on the right up to ``num_samples`` samples.

    Assumes ``signal`` is 2-D, (channels, samples); padding is applied to the
    last dimension, which is the sample axis for such a tensor.
    """
    signal_length = signal.shape[1]
    if signal_length < num_samples:
        num_missing_samples = num_samples - signal_length
        # (left, right) padding for the last dim: nothing on the left,
        # ``num_missing_samples`` zeros on the right.
        last_dim_padding = (0, num_missing_samples)
        signal = torch.nn.functional.pad(signal, last_dim_padding)
    return signal


def preprocess(signal, sr, device, target_sr=SAMPLE_RATE, num_samples=NUM_SAMPLES):
    """Convert a raw waveform into the fixed-size mono input the model expects.

    Args:
        signal: waveform tensor, either 1-D (samples) or 2-D
            (channels, samples).
        sr: sample rate of ``signal`` in Hz.
        device: device the resampler is moved to; ``signal`` is expected to
            already live there.
        target_sr: sample rate to resample to (defaults to the training rate).
        num_samples: output length in samples (defaults to ``NUM_SAMPLES``).

    Returns:
        Tensor of shape (1, num_samples).
    """
    # Add a channel dimension for 1-D inputs.
    if signal.dim() == 1:
        signal = signal.unsqueeze(0)
    # Resample to the rate the model was trained on.
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(sr, target_sr).to(device)
        signal = resampler(signal)
    # Down-mix multi-channel audio to mono by averaging channels.
    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)
    signal = _cut_if_necessary(signal, num_samples)  # truncate longer signals
    signal = _right_pad_if_necessary(signal, num_samples)  # extend shorter signals
    return signal


def pipeline(audio_file):
    """Classify an uploaded audio clip and return per-label probabilities.

    Args:
        audio_file: either a file-like object exposing ``.name`` (what the
            ``type="file"`` Audio input below supplies) or a plain path string.

    Returns:
        dict mapping each entry of ``labels`` to its probability, in the format
        consumed by the ``Label`` output.
    """
    # Accept a plain path as well, so this keeps working if the input is
    # switched to a path-returning Audio type.
    audio_path = getattr(audio_file, "name", audio_file)
    audio, sample_rate = torchaudio.load(audio_path)
    processed_audio = preprocess(audio.to(DEVICE), sample_rate, DEVICE)
    with torch.no_grad():
        # Single forward pass: add a batch dim, then drop size-1 dims.
        scores = classifier(processed_audio.unsqueeze(0)).squeeze()
    # exp() turns log-softmax scores into probabilities.
    # NOTE(review): assumes M11 ends in log_softmax -- confirm in model.py.
    probabilities = torch.exp(scores)
    return dict(zip(labels, probabilities.tolist()))


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_PATH = "./model.ckpt"
labels = ["Threat", "Normal", "Sarcastic"]

classifier = M11.load_from_checkpoint(model_PATH).to(DEVICE)
classifier.eval()  # switch the model to evaluation mode for inference

inputs = gr.inputs.Audio(label="Input Audio", type="file")
outputs = gr.outputs.Label(num_top_classes=3)

title = "Threat Detection From Bengali Voice Calls"
description = (
    "Gradio demo for Audio Classification, simply upload your audio, "
    "or click one of the examples to load them. Read more at the links below."
)
# NOTE(review): "Github Repo" carries no URL here; the link target appears to
# have been lost -- restore it.
article = "\n\nGithub Repo\n\n"
examples = [
    ["sample_audio.wav"],
]

demo = gr.Interface(
    pipeline,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()