import torch
import torchaudio
from torch import nn
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, pipeline

# Preprocessing the data
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
max_duration = 2.0  # seconds

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

softmax = nn.Softmax(dim=-1)

# Map the ten spoken digits to label ids and back.
label2id, id2label = dict(), dict()
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
num_labels = 10
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label


def get_pipeline(model_name):
    if model_name.split('-')[-1].strip() != 'ibo':
        return None
    return pipeline(task="audio-classification", model=model_name)


def load_model(model_checkpoint):
    # Construct the classification model and move it to the target device.
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    ).to(device)
    return model


language_dict = {
    "Igbo": 'ibo',
    "Oshiwambo": 'kua',
    "Yoruba": 'yor',
    "Oromo": 'gax',
    "Shona": 'sna',
    "Rundi": 'run',
    "Choose language": 'none',
    "MULTILINGUAL": 'all',
}

AUDIO_CLASSIFICATION_MODELS = {
    'ibo': load_model('chrisjay/afrospeech-wav2vec-ibo'),
    'kua': load_model('chrisjay/afrospeech-wav2vec-kua'),
    'sna': load_model('chrisjay/afrospeech-wav2vec-sna'),
    'yor': load_model('chrisjay/afrospeech-wav2vec-yor'),
    'gax': load_model('chrisjay/afrospeech-wav2vec-gax'),
    'run': load_model('chrisjay/afrospeech-wav2vec-run'),
    'all': load_model('chrisjay/afrospeech-wav2vec-all-6'),
}


def cut_if_necessary(signal, num_samples):
    # Truncate the signal to at most num_samples samples.
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    return signal


def right_pad_if_necessary(signal, num_samples):
    # Zero-pad the signal on the right up to num_samples samples.
    length_signal = signal.shape[1]
    if length_signal < num_samples:
        num_missing_samples = num_samples - length_signal
        last_dim_padding = (0, num_missing_samples)
        signal = torch.nn.functional.pad(signal, last_dim_padding)
    return signal


def resample_if_necessary(signal, sr, target_sample_rate, device):
    # Resample the signal to target_sample_rate if it arrives at a different rate.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        signal = resampler(signal)
    return signal


def mix_down_if_necessary(signal):
    # Mix a multi-channel signal down to mono by averaging the channels.
    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)
    return signal


def preprocess_audio(waveform, sample_rate, feature_extractor):
    # Resample to 16 kHz, mix down to mono, and fix the length to 16000 samples
    # before running the wav2vec2 feature extractor.
    waveform = resample_if_necessary(waveform, sample_rate, 16000, device)
    waveform = mix_down_if_necessary(waveform)
    waveform = cut_if_necessary(waveform, 16000)
    waveform = right_pad_if_necessary(waveform, 16000)
    transformed = feature_extractor(
        waveform,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000,
        truncation=True,
    )
    return transformed


def make_inference(drop_down, audio):
    # Load the recording, preprocess it, and classify it with the model for the chosen language.
    waveform, sample_rate = torchaudio.load(audio)
    preprocessed_audio = preprocess_audio(waveform, sample_rate, feature_extractor)
    language_code_chosen = language_dict[drop_down]
    model = AUDIO_CLASSIFICATION_MODELS[language_code_chosen]
    model.eval()
    torch_preprocessed_audio = torch.from_numpy(preprocessed_audio.input_values[0]).to(device)
    # Make a prediction and turn the logits into class probabilities.
    with torch.no_grad():
        prediction = softmax(model(torch_preprocessed_audio).logits)
    sorted_prediction = torch.sort(prediction, descending=True)
    # Return a {class_index: probability} mapping, most confident class first.
    confidences = {}
    for s, v in zip(
        sorted_prediction.indices.detach().cpu().numpy().tolist()[0],
        sorted_prediction.values.detach().cpu().numpy().tolist()[0],
    ):
        confidences[s] = v
    return confidences
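

# Minimal usage sketch for make_inference. This is an illustration, not part of the
# original script: the file path "sample_digit.wav" is hypothetical, and "Igbo" is
# assumed to be the language picked from the dropdown.
if __name__ == "__main__":
    confidences = make_inference("Igbo", "sample_digit.wav")
    # Print each predicted digit with its probability, most confident first.
    for class_index, probability in confidences.items():
        print(f"{id2label[str(class_index)]}: {probability:.3f}")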