File size: 3,891 Bytes
57926d1
 
 
e3fed65
a05c869
0448aa2
17a49e1
a05c869
57926d1
a05c869
208ffe2
 
851eb15
 
208ffe2
8d69919
7806ecb
57926d1
3b8d409
57926d1
 
 
 
 
 
379fa33
0448aa2
57926d1
 
 
8d69919
 
3b8d409
 
0448aa2
d32240b
3b8d409
 
0448aa2
 
 
 
3b8d409
0448aa2
379fa33
b8af00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d69919
 
379fa33
 
 
 
 
 
 
 
bbbf923
379fa33
 
 
decaa84
b8af00e
379fa33
 
8a068ad
607a780
b8af00e
8d69919
7806ecb
0448aa2
be02097
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import nltk
import librosa
import torch
import kenlm
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, AutoModelForCTC

# Download the NLTK sentence tokenizer data needed by fix_transcription_casing.
nltk.download("punkt")

# NOTE(review): the four objects below are loaded at import time but never
# referenced anywhere else in this file — every prediction path loads its own
# processor/model via return_processor_and_model(). Consider removing these
# preloads to cut startup time and memory.
wav2vec2processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") 
wav2vec2model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
hubertprocessor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") 
hubertmodel = AutoModelForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

# Cache of already-loaded (processor, model) pairs keyed by model name, so
# repeated predictions don't re-download/re-load the checkpoint every call.
_processor_model_cache = {}

def return_processor_and_model(model_name):
    """Return a ``(Wav2Vec2Processor, AutoModelForCTC)`` pair for *model_name*.

    The pair is loaded once per model name and cached for subsequent calls;
    the first call for a given name may hit the network/disk.
    """
    if model_name not in _processor_model_cache:
        _processor_model_cache[model_name] = (
            Wav2Vec2Processor.from_pretrained(model_name),
            AutoModelForCTC.from_pretrained(model_name),
        )
    return _processor_model_cache[model_name]

def load_and_fix_data(input_file):
  """Load an audio file as a mono float32 waveform sampled at 16 kHz.

  Args:
    input_file: path to the audio file to load.

  Returns:
    1-D numpy array of samples at 16 kHz, suitable for Wav2Vec2/HuBERT.
  """
  # sr=16000 makes librosa resample during load, and mono=True averages the
  # channels. This replaces the previous manual channel *sum* (which doubled
  # amplitude and could clip) and the positional-argument resample call,
  # whose signature became keyword-only (orig_sr/target_sr) in librosa 0.10.
  speech, _ = librosa.load(input_file, sr=16000, mono=True)
  return speech

def fix_transcription_casing(input_sentence):
  """Capitalize the first letter of each sentence in *input_sentence*.

  Sentences are found with NLTK's Punkt tokenizer; the result is the
  sentences re-joined with single spaces.
  """
  sentences = nltk.sent_tokenize(input_sentence)
  # s[:1].upper() + s[1:] capitalizes in place and is safe on an empty
  # string, unlike the previous s.replace(s[0], ...) form which raised
  # IndexError for an empty sentence.
  return ' '.join(s[:1].upper() + s[1:] for s in sentences)
  
def predict_and_ctc_decode(input_file, model_name):
  """Transcribe an audio file using beam-search CTC decoding (no LM).

  Args:
    input_file: path to the audio file to transcribe.
    model_name: Hugging Face model id passed to return_processor_and_model.

  Returns:
    Transcription string with sentence-initial capitalization applied.
  """
  processor, model = return_processor_and_model(model_name)
  speech = load_and_fix_data(input_file)

  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
  # Inference only: no_grad skips building the autograd graph (saves memory).
  with torch.no_grad():
    logits = model(input_values).logits.cpu().numpy()[0]

  # Decoder labels come straight from the tokenizer vocabulary, which is
  # assumed to be ordered to match the logit columns.
  vocab_list = list(processor.tokenizer.get_vocab().keys())
  decoder = build_ctcdecoder(vocab_list)
  pred = decoder.decode(logits)

  transcribed_text = fix_transcription_casing(pred.lower())

  return transcribed_text

def predict_and_ctc_lm_decode(input_file, model_name):
  """Transcribe an audio file using beam-search CTC decoding with a KenLM LM.

  Args:
    input_file: path to the audio file to transcribe.
    model_name: Hugging Face model id passed to return_processor_and_model.

  Returns:
    Transcription string with sentence-initial capitalization applied.
  """
  processor, model = return_processor_and_model(model_name)
  speech = load_and_fix_data(input_file)

  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
  # Inference only: no_grad skips building the autograd graph (saves memory).
  with torch.no_grad():
    logits = model(input_values).logits.cpu().numpy()[0]

  # Lower-case the vocabulary and sort by token id so the decoder's label
  # order matches the logit columns (the LM expects lower-case tokens).
  vocab_dict = processor.tokenizer.get_vocab()
  sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

  # "4gram_small.arpa.gz" is the KenLM n-gram model file shipped with this app.
  decoder = build_ctcdecoder(
    list(sorted_dict.keys()),
    "4gram_small.arpa.gz",
    )

  pred = decoder.decode(logits)

  transcribed_text = fix_transcription_casing(pred.lower())

  return transcribed_text

def predict_and_greedy_decode(input_file, model_name):
  """Transcribe an audio file using greedy (argmax) CTC decoding.

  Args:
    input_file: path to the audio file to transcribe.
    model_name: Hugging Face model id passed to return_processor_and_model.

  Returns:
    Transcription string with sentence-initial capitalization applied.
  """
  processor, model = return_processor_and_model(model_name)
  speech = load_and_fix_data(input_file)

  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
  # Inference only: no_grad skips building the autograd graph (saves memory).
  with torch.no_grad():
    logits = model(input_values).logits

  # Greedy decoding: pick the highest-scoring token at every frame and let
  # the processor collapse repeats/blanks.
  predicted_ids = torch.argmax(logits, dim=-1)
  pred = processor.batch_decode(predicted_ids)

  transcribed_text = fix_transcription_casing(pred[0].lower())

  return transcribed_text

def return_all_predictions(input_file, model_name):
  """Run all three decoding strategies on one audio file.

  Returns a 3-tuple: (beam CTC, beam CTC + LM, greedy) transcriptions,
  in the order the Gradio output boxes expect.
  """
  decoders = (predict_and_ctc_decode, predict_and_ctc_lm_decode, predict_and_greedy_decode)
  return tuple(decode(input_file, model_name) for decode in decoders)


# Build and launch the Gradio demo: one audio input (microphone recording,
# passed to the handler as a file path) plus a model-name dropdown, mapped to
# three text boxes — one per decoding strategy returned by
# return_all_predictions. enable_queue=True serializes requests so the heavy
# model inference calls don't overlap.
# NOTE(review): gr.inputs/gr.outputs and several of these kwargs are the
# legacy Gradio 2.x/3.x API — confirm the pinned gradio version before upgrading.
gr.Interface(return_all_predictions,
             inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
             outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Beam CTC decoding w/ LM"), gr.outputs.Textbox(label="Greedy decoding")],
             title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
             description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
             layout = "horizontal",
             examples = [["test1.wav", "facebook/wav2vec2-base-960h"], ["test2.wav", "facebook/hubert-large-ls960-ft"]], 
             theme="huggingface",
             enable_queue=True).launch()