import torch
import torchaudio
from torch import nn
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, pipeline

# Feature extractor used to preprocess raw audio before classification
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
max_duration = 2.0  # seconds (not used below; clips are fixed to 16,000 samples)

# Run on GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Softmax over the class dimension turns the model logits into probabilities
softmax = nn.Softmax(dim=-1)


# The ten spoken digits 0-9 are the classification targets
label2id, id2label = dict(), dict()
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
num_labels = 10

for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label


def get_pipeline(model_name):
    # Only the Igbo ('ibo') checkpoint is exposed through the pipeline API here
    if model_name.split('-')[-1].strip() != 'ibo':
        return None
    return pipeline(task="audio-classification", model=model_name)


def load_model(model_checkpoint):
    #if model_checkpoint.split('-')[-1].strip()!='ibo': #This is for DEBUGGING
    #    return None, None

    # construct model and assign it to device
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    ).to(device)

    return model

# Map dropdown language names to the model suffixes used in the checkpoint names
language_dict = {
    "Igbo": 'ibo',
    "Oshiwambo": 'kua',
    "Yoruba": 'yor',
    "Oromo": 'gax',
    "Shona": 'sna',
    "Rundi": 'run',
    "Choose language": 'none',
    "MULTILINGUAL": 'all',
}

# Pre-load one fine-tuned classifier per language (plus a multilingual one) at startup
AUDIO_CLASSIFICATION_MODELS = {
    'ibo': load_model('chrisjay/afrospeech-wav2vec-ibo'),
    'kua': load_model('chrisjay/afrospeech-wav2vec-kua'),
    'sna': load_model('chrisjay/afrospeech-wav2vec-sna'),
    'yor': load_model('chrisjay/afrospeech-wav2vec-yor'),
    'gax': load_model('chrisjay/afrospeech-wav2vec-gax'),
    'run': load_model('chrisjay/afrospeech-wav2vec-run'),
    'all': load_model('chrisjay/afrospeech-wav2vec-all-6'),
}


def cut_if_necessary(signal, num_samples):
    # Truncate the signal to at most num_samples samples
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    return signal

def right_pad_if_necessary(signal, num_samples):
    # Zero-pad on the right so the signal is exactly num_samples long
    length_signal = signal.shape[1]
    if length_signal < num_samples:
        num_missing_samples = num_samples - length_signal
        last_dim_padding = (0, num_missing_samples)
        signal = torch.nn.functional.pad(signal, last_dim_padding)
    return signal

def resample_if_necessary(signal, sr, target_sample_rate, device):
    # Resample on the signal's own device: the waveform loaded by torchaudio is
    # on the CPU even when `device` is "cuda", so moving only the resampler to
    # `device` would cause a device mismatch.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(signal.device)
        signal = resampler(signal)
    return signal

def mix_down_if_necessary(signal):
    # Average multi-channel audio down to a single mono channel
    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)
    return signal



def preprocess_audio(waveform, sample_rate, feature_extractor):
    # Resample to 16 kHz, mix down to mono, and fix the length to 16,000 samples
    waveform = resample_if_necessary(waveform, sample_rate, 16000, device)
    waveform = mix_down_if_necessary(waveform)
    waveform = cut_if_necessary(waveform, 16000)
    waveform = right_pad_if_necessary(waveform, 16000)
    transformed = feature_extractor(
        waveform,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000,
        truncation=True,
    )
    return transformed



def make_inference(drop_down, audio):
    # Load the recording, preprocess it, and run the language-specific classifier
    waveform, sample_rate = torchaudio.load(audio)
    preprocessed_audio = preprocess_audio(waveform, sample_rate, feature_extractor)
    language_code_chosen = language_dict[drop_down]
    model = AUDIO_CLASSIFICATION_MODELS[language_code_chosen]
    model.eval()
    torch_preprocessed_audio = torch.from_numpy(preprocessed_audio.input_values[0]).to(device)

    # make prediction and convert logits to class probabilities
    with torch.no_grad():
        prediction = softmax(model(torch_preprocessed_audio).logits)

    # Map each digit class index to its confidence, highest first
    sorted_prediction = torch.sort(prediction, descending=True)
    confidences = {}
    for s, v in zip(sorted_prediction.indices.cpu().numpy().tolist()[0],
                    sorted_prediction.values.cpu().numpy().tolist()[0]):
        confidences[s] = v
    return confidences
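

# Illustrative local test (a minimal sketch, not part of the deployed app):
# running the module directly classifies one recording and prints the per-digit
# confidences. The path "sample_digit.wav" is a hypothetical placeholder; any
# audio file readable by torchaudio should work.
if __name__ == "__main__":
    example_scores = make_inference("Igbo", "sample_digit.wav")
    for digit_index, confidence in example_scores.items():
        print(f"digit {id2label[str(digit_index)]}: {confidence:.3f}")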