File size: 2,986 Bytes
1d2ed5a
 
 
37c396e
1d2ed5a
0d80de1
1d2ed5a
37c396e
1d2ed5a
 
37c396e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d80de1
1d2ed5a
0d80de1
 
 
 
1d2ed5a
 
 
 
 
 
37c396e
1d2ed5a
 
 
 
 
 
 
37c396e
 
 
1d2ed5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d80de1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from pyctcdecode import build_ctcdecoder
import gradio as gr
import librosa
import os
from multiprocessing import Pool


class KenLM:
    """Beam-search CTC decoder backed by a KenLM n-gram language model.

    Wraps ``pyctcdecode.build_ctcdecoder`` and decodes batches of logits
    using a ``multiprocessing.Pool`` of ``num_workers`` processes.
    """

    def __init__(self, tokenizer, model_name, num_workers=8, beam_width=128):
        """Build the decoder from a tokenizer vocabulary and a KenLM model.

        Args:
            tokenizer: Object exposing ``get_vocab()`` (a token -> id mapping);
                its tokens, ordered by id, become the decoder's label list.
            model_name: KenLM model path passed to ``build_ctcdecoder``.
            num_workers: Number of worker processes used by ``decode``.
            beam_width: Beam width used for every ``decode`` call.
        """
        self.num_workers = num_workers
        self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        # Sort tokens by id so the label list follows the tokenizer's id order.
        self.vocabulary = [token for token, _ in sorted(vocab_dict.items(), key=lambda item: item[1])]
        # Workaround for wrong number of vocabularies: drop the last two tokens
        # (highest ids). NOTE(review): presumably these have no matching logit
        # column in the model output — confirm against the model's vocab size.
        self.vocabulary = self.vocabulary[:-2]
        self.decoder = build_ctcdecoder(self.vocabulary, model_name)

    @staticmethod
    def lm_postprocess(text):
        """Remove single-character words and normalize whitespace.

        Filtering (rather than blanking) the short words keeps exactly one
        space between the surviving words and none at either end.
        """
        return " ".join(word for word in text.split() if len(word) > 1)

    def decode(self, logits):
        """Beam-search decode a batch of logits into transcripts.

        Args:
            logits: Torch tensor of model logits; it is detached, moved to CPU
                and converted to NumPy before decoding. NOTE(review): assumed
                shape (batch, time, vocab) — confirm with callers.

        Returns:
            list[str]: One post-processed transcript per batch item.
        """
        # detach() keeps this safe even when called outside torch.no_grad().
        probs = logits.detach().cpu().numpy()
        with Pool(self.num_workers) as pool:
            # Pass the configured beam width explicitly; otherwise pyctcdecode
            # falls back to its own default and self.beam_width is ignored.
            texts = self.decoder.decode_batch(pool, probs, beam_width=self.beam_width)
        # Post-processing is cheap and needs no pool, so do it after closing it.
        return [KenLM.lm_postprocess(text) for text in texts]


def convert(inputfile, outfile, target_sr=16000):
    """Re-encode an audio file as a WAV resampled to ``target_sr`` Hz.

    Args:
        inputfile: Path of the source audio file (any format librosa reads).
        outfile: Destination path, written with soundfile.
        target_sr: Output sample rate in Hz (default 16000).
    """
    # Resample once, directly to the target rate, while loading. A plain
    # ``librosa.load`` first resamples to its 22050 Hz default, and a second
    # resample afterwards would degrade the audio twice.
    data, _ = librosa.load(inputfile, sr=target_sr)
    sf.write(outfile, data, target_sr)


# Optional Hugging Face token from the environment (None when API_TOKEN is unset).
api_token = os.environ.get("API_TOKEN")
model_name = "indonesian-nlp/wav2vec2-luganda"

# The processor and the CTC model come from the same hub repo with the same auth.
_hub_kwargs = {"use_auth_token": api_token}
processor = Wav2Vec2Processor.from_pretrained(model_name, **_hub_kwargs)
model = Wav2Vec2ForCTC.from_pretrained(model_name, **_hub_kwargs)

# LM-backed beam-search decoder over the tokenizer vocabulary, using "5gram.bin".
kenlm = KenLM(processor.tokenizer, "5gram.bin")


def parse_transcription(wav_file):
    """Transcribe one recorded audio file.

    The recording is resampled to a temporary 16 kHz WAV next to the input,
    read back, run through the wav2vec2 model, and decoded with the KenLM
    beam-search decoder. The temporary WAV is always removed afterwards.

    Args:
        wav_file: File object from the Gradio audio input; its ``.name`` is the
            on-disk path of the recording. NOTE(review): Gradio may pass None
            when nothing was recorded — handled by returning "".

    Returns:
        str: The decoded transcript.
    """
    if wav_file is None:
        return ""
    # splitext strips only the final extension, so dots in directory names
    # (e.g. temp dirs) do not truncate the path the way split('.') did.
    base, _ = os.path.splitext(wav_file.name)
    resampled_path = base + "16k.wav"
    try:
        convert(wav_file.name, resampled_path)
        speech, _ = sf.read(resampled_path)
    finally:
        # The 16 kHz copy is only needed long enough to read it back.
        if os.path.exists(resampled_path):
            os.remove(resampled_path)
    input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    # Batch of one: take the single transcript.
    return kenlm.decode(logits)[0]


# Gradio UI: microphone recording in, transcript text out.
output = gr.outputs.Textbox(label="The transcript")

input_ = gr.inputs.Audio(source="microphone", type="file")

_article = (
    "This demo was built for the "
    "<a href='https://zindi.africa/competitions/mozilla-luganda-automatic-speech-recognition' target='_blank'>Mozilla Luganda Automatic Speech Recognition Competition</a>. "
    "It uses the <a href='https://huggingface.co/indonesian-nlp/wav2vec2-luganda' target='_blank'>indonesian-nlp/wav2vec2-luganda</a> model "
    "which was fine-tuned on Luganda Common Voice speech datasets."
)

demo = gr.Interface(
    parse_transcription,
    inputs=input_,
    outputs=[output],
    analytics_enabled=False,
    title="Automatic Speech Recognition for Luganda",
    description="Speech Recognition Live Demo for Luganda",
    article=_article,
    enable_queue=True,
)
# server_name "0.0.0.0" listens on all network interfaces.
demo.launch(inline=False, server_name="0.0.0.0", show_tips=False, enable_queue=True)