asr-pyctcdecode / app.py
Vaibhav Srivastav
adding kenlm to requirements
e3fed65
raw
history blame contribute delete
No virus
3.89 kB
import nltk
import librosa
import torch
import kenlm
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, AutoModelForCTC
# Sentence-tokenizer data required by `nltk.sent_tokenize` in fix_transcription_casing.
nltk.download("punkt")

# (processor, model) pairs keyed by checkpoint name, so each checkpoint's
# weights are loaded into memory at most once per process instead of on
# every request.
_PROCESSOR_MODEL_CACHE = {}


def return_processor_and_model(model_name):
    """Return the ``(processor, model)`` pair for a checkpoint, loading it once.

    On the first request for ``model_name`` the pair is built with
    ``Wav2Vec2Processor.from_pretrained`` and ``AutoModelForCTC.from_pretrained``
    and stored in ``_PROCESSOR_MODEL_CACHE``; later calls return the very same
    objects.

    Args:
        model_name: Hub model id or local path, e.g. "facebook/wav2vec2-base-960h".

    Returns:
        tuple: ``(processor, model)`` for ``model_name``.
    """
    pair = _PROCESSOR_MODEL_CACHE.get(model_name)
    if pair is None:
        pair = (
            Wav2Vec2Processor.from_pretrained(model_name),
            AutoModelForCTC.from_pretrained(model_name),
        )
        _PROCESSOR_MODEL_CACHE[model_name] = pair
    return pair


# Load both checkpoints offered in the UI at startup (same order as before) and
# keep them in the cache, so requests reuse these objects rather than reloading.
wav2vec2processor, wav2vec2model = return_processor_and_model("facebook/wav2vec2-base-960h")
hubertprocessor, hubertmodel = return_processor_and_model("facebook/hubert-large-ls960-ft")
def load_and_fix_data(input_file, target_sampling_rate=16000):
    """Load an audio file as a 1-D mono waveform at ``target_sampling_rate`` Hz.

    ``librosa.load`` decodes, downmixes (``mono=True``) and resamples (``sr=``)
    in one call, so the result is already in the form the processors below
    are fed (they are called with ``sampling_rate=16000``).

    Args:
        input_file: Path to the audio file (Gradio passes a filepath here).
        target_sampling_rate: Output sampling rate in Hz; defaults to 16000.

    Returns:
        numpy.ndarray: Mono waveform resampled to ``target_sampling_rate``.
    """
    # Request the target rate directly. Omitting ``sr`` makes librosa resample
    # to its 22050 Hz default, which would force a second, lossy resample, and
    # positional orig/target rates for ``librosa.resample`` are rejected by
    # librosa >= 0.10. ``mono=True`` also averages any extra channels, so no
    # manual channel mixing is needed.
    speech, _ = librosa.load(input_file, sr=target_sampling_rate, mono=True)
    return speech
def fix_transcription_casing(input_sentence):
    """Capitalise the first character of every sentence in ``input_sentence``.

    Sentences are found with ``nltk.sent_tokenize`` (which relies on the
    ``punkt`` data downloaded at import time) and re-joined with single spaces.
    """
    capitalised = []
    for sentence in nltk.sent_tokenize(input_sentence):
        # Only the leading character changes; the rest is kept verbatim.
        capitalised.append(sentence[0].capitalize() + sentence[1:])
    return " ".join(capitalised)
def predict_and_ctc_decode(input_file, model_name):
    """Transcribe ``input_file`` with pyctcdecode beam search (no language model).

    Args:
        input_file: Path to an audio file, loaded via ``load_and_fix_data``.
        model_name: Checkpoint name passed to ``return_processor_and_model``.

    Returns:
        str: Lower-cased transcription with sentence-initial capitals.
    """
    processor, model = return_processor_and_model(model_name)
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits
    # Take the single batch item as a NumPy array for pyctcdecode.
    logits = logits.cpu().numpy()[0]
    # pyctcdecode pairs label i with logits column i, so order the labels by
    # token id instead of relying on the vocab dict's insertion order.
    vocab = processor.tokenizer.get_vocab()
    labels = sorted(vocab, key=vocab.__getitem__)
    decoder = build_ctcdecoder(labels)
    pred = decoder.decode(logits)
    return fix_transcription_casing(pred.lower())
# Beam-search decoders with "4gram_small.arpa.gz" attached, keyed by their label
# tuple. Building one parses the n-gram file, so it is done once per distinct
# vocabulary rather than on every request; checkpoints whose vocabularies
# match share a decoder.
_LM_DECODER_CACHE = {}


def predict_and_ctc_lm_decode(input_file, model_name):
    """Transcribe ``input_file`` with pyctcdecode beam search plus an n-gram LM.

    Args:
        input_file: Path to an audio file, loaded via ``load_and_fix_data``.
        model_name: Checkpoint name passed to ``return_processor_and_model``.

    Returns:
        str: Lower-cased transcription with sentence-initial capitals.
    """
    processor, model = return_processor_and_model(model_name)
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits
    # Take the single batch item as a NumPy array for pyctcdecode.
    logits = logits.cpu().numpy()[0]
    # Labels ordered by token id (= logits column) and lower-cased, presumably
    # to match the LM's casing. Building a sequence (not a dict keyed by the
    # lowered token) keeps exactly one label per logits column.
    vocab = processor.tokenizer.get_vocab()
    labels = tuple(token.lower() for token in sorted(vocab, key=vocab.__getitem__))
    decoder = _LM_DECODER_CACHE.get(labels)
    if decoder is None:
        decoder = build_ctcdecoder(list(labels), kenlm_model_path="4gram_small.arpa.gz")
        _LM_DECODER_CACHE[labels] = decoder
    pred = decoder.decode(logits)
    return fix_transcription_casing(pred.lower())
def predict_and_greedy_decode(input_file, model_name):
    """Transcribe ``input_file`` by taking the most likely token at each frame.

    Args:
        input_file: Path to an audio file, loaded via ``load_and_fix_data``.
        model_name: Checkpoint name passed to ``return_processor_and_model``.

    Returns:
        str: Lower-cased transcription with sentence-initial capitals.
    """
    processor, model = return_processor_and_model(model_name)
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits
    # Argmax over the last (vocabulary) dimension gives one token id per frame;
    # the processor turns those ids back into text.
    predicted_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(predicted_ids)
    return fix_transcription_casing(pred[0].lower())
def return_all_predictions(input_file, model_name):
    """Run every decoding strategy on one clip for side-by-side comparison.

    Each strategy loads the audio and runs the model on its own.

    Returns:
        tuple[str, str, str]: (beam search, beam search with LM, greedy)
        transcriptions, matching the order of the Gradio output boxes.
    """
    strategies = (
        predict_and_ctc_decode,
        predict_and_ctc_lm_decode,
        predict_and_greedy_decode,
    )
    return tuple(strategy(input_file, model_name) for strategy in strategies)
# Gradio UI: an audio clip and a checkpoint choice go in; three transcriptions
# (beam search, beam search with LM, greedy) come out side by side.
model_choices = ["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"]
interface_inputs = [
    gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"),
    gr.inputs.Dropdown(model_choices, label="Model Name"),
]
interface_outputs = [
    gr.outputs.Textbox(label="Beam CTC decoding"),
    gr.outputs.Textbox(label="Beam CTC decoding w/ LM"),
    gr.outputs.Textbox(label="Greedy decoding"),
]
demo = gr.Interface(
    return_all_predictions,
    inputs=interface_inputs,
    outputs=interface_outputs,
    title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
    description="Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
    layout="horizontal",
    examples=[["test1.wav", "facebook/wav2vec2-base-960h"], ["test2.wav", "facebook/hubert-large-ls960-ft"]],
    theme="huggingface",
    enable_queue=True,
)
demo.launch()