# luganda-asr / app.py
# Last change: commit 0d80de1, "add KenLM" (by cahya)
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from pyctcdecode import build_ctcdecoder
import gradio as gr
import librosa
import os
from multiprocessing import Pool
class KenLM:
    """Decode CTC logits with pyctcdecode beam search guided by a KenLM model."""

    def __init__(self, tokenizer, model_name, num_workers=8, beam_width=128):
        """Build the pyctcdecode decoder.

        Args:
            tokenizer: tokenizer whose ``get_vocab()`` maps token -> id; the
                decoder labels are its tokens ordered by id.
            model_name: path of the KenLM model file handed to
                ``build_ctcdecoder`` (e.g. ``"5gram.bin"``).
            num_workers: maximum number of worker processes used when a
                batch of more than one utterance is decoded.
            beam_width: beam size used by every decode call.
        """
        self.num_workers = num_workers
        self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        # Labels must be listed in id order so that label i lines up with
        # logit column i.
        self.vocabulary = sorted(vocab_dict, key=vocab_dict.get)
        # Workaround for wrong number of vocabularies: drop the two
        # highest-id tokens. NOTE(review): presumably tokens added to the
        # tokenizer that have no matching logit column - confirm against
        # the checkpoint.
        self.vocabulary = self.vocabulary[:-2]
        self.decoder = build_ctcdecoder(self.vocabulary, model_name)

    @staticmethod
    def lm_postprocess(text):
        """Drop single-character words from a decoded transcript.

        Remaining words are re-joined with exactly one space, so removed
        words leave no doubled spaces behind.
        """
        return ' '.join(word for word in text.split() if len(word) > 1)

    def decode(self, logits):
        """Beam-search decode a batch of logits into post-processed strings.

        Args:
            logits: torch tensor whose first axis is the batch; each element
                is handed to pyctcdecode (assumed (time, vocab) - TODO confirm).

        Returns:
            One transcript per batch element, in order (``[]`` for an empty
            batch).
        """
        probs = logits.detach().cpu().numpy()
        batch_size = len(probs)
        if batch_size == 0:
            return []
        if batch_size == 1:
            # A single utterance does not benefit from worker processes;
            # skip the cost of spawning a Pool on every request.
            texts = [self.decoder.decode(probs[0], beam_width=self.beam_width)]
        else:
            with Pool(min(self.num_workers, batch_size)) as pool:
                texts = self.decoder.decode_batch(
                    pool, probs, beam_width=self.beam_width
                )
        return [KenLM.lm_postprocess(text) for text in texts]
def convert(inputfile, outfile, target_sr=16000):
    """Decode an audio file and write it back out at ``target_sr`` Hz.

    Args:
        inputfile: path of an audio file readable by ``librosa.load``.
        outfile: destination path passed to ``soundfile.write``.
        target_sr: output sample rate in Hz (defaults to 16000).
    """
    # Passing sr=target_sr makes librosa resample exactly once while loading.
    # Loading with the default sr (22050 Hz) and then calling
    # librosa.resample would resample twice, which is slower and lossier.
    data, _ = librosa.load(inputfile, sr=target_sr)
    sf.write(outfile, data, target_sr)
# Optional Hugging Face access token taken from the environment (None if unset).
api_token = os.environ.get("API_TOKEN")
model_name = "indonesian-nlp/wav2vec2-luganda"

# Processor (feature extractor + tokenizer) first, then the CTC model.
processor = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=api_token)
model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=api_token)

# Beam-search decoder over the tokenizer's vocabulary, using the local
# KenLM model file.
kenlm = KenLM(tokenizer=processor.tokenizer, model_name="5gram.bin")
def parse_transcription(wav_file):
    """Transcribe one recorded audio clip to text.

    Args:
        wav_file: the Gradio audio input - either an object exposing the
            file path as ``.name`` or a plain path string. ``None`` (nothing
            recorded) yields an empty transcript.

    Returns:
        The KenLM-decoded transcript of the clip.
    """
    if wav_file is None:
        return ""
    source_path = getattr(wav_file, "name", wav_file)
    # splitext strips only the final extension, so dots in directory names
    # (common in temp dirs) do not truncate the path.
    resampled_path = os.path.splitext(source_path)[0] + "16k.wav"
    convert(source_path, resampled_path)
    try:
        speech, _ = sf.read(resampled_path)
    finally:
        # The 16 kHz copy is only an intermediate; delete it so a
        # long-running server does not accumulate files.
        os.remove(resampled_path)
    input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    return kenlm.decode(logits)[0]
# Gradio widgets: microphone recording in, transcript text out.
output = gr.outputs.Textbox(label="The transcript")
input_ = gr.inputs.Audio(source="microphone", type="file")

# Start the web server only when this file is run as a script. KenLM.decode
# uses multiprocessing.Pool; under the "spawn" start method every worker
# re-imports the main module (as "__mp_main__"), and an unguarded launch()
# would then try to start another server inside each worker.
if __name__ == "__main__":
    gr.Interface(
        parse_transcription,
        inputs=input_,
        outputs=[output],
        analytics_enabled=False,
        title="Automatic Speech Recognition for Luganda",
        description="Speech Recognition Live Demo for Luganda",
        article="This demo was built for the "
                "<a href='https://zindi.africa/competitions/mozilla-luganda-automatic-speech-recognition' target='_blank'>Mozilla Luganda Automatic Speech Recognition Competition</a>. "
                "It uses the <a href='https://huggingface.co/indonesian-nlp/wav2vec2-luganda' target='_blank'>indonesian-nlp/wav2vec2-luganda</a> model "
                "which was fine-tuned on Luganda Common Voice speech datasets.",
        enable_queue=True,
    ).launch(inline=False, server_name="0.0.0.0", show_tips=False, enable_queue=True)