Vaibhav Srivastav commited on
Commit
b8af00e
1 Parent(s): 851eb15

adding decoding w lm

Browse files
Files changed (2) hide show
  1. 4gram_small.arpa.gz +3 -0
  2. app.py +24 -2
4gram_small.arpa.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4c4fe64751abecdeb7040fe6ed7f2440c2d3f36ed35c43e3510f7cf95578f2a
3
+ size 18358716
app.py CHANGED
@@ -42,6 +42,28 @@ def predict_and_ctc_decode(input_file, model_name):
42
 
43
  return transcribed_text
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def predict_and_greedy_decode(input_file, model_name):
46
  processor, model = return_processor_and_model(model_name)
47
  speech = load_and_fix_data(input_file)
@@ -57,12 +79,12 @@ def predict_and_greedy_decode(input_file, model_name):
57
  return transcribed_text
58
 
59
  def return_all_predictions(input_file, model_name):
60
- return predict_and_ctc_decode(input_file, model_name), predict_and_greedy_decode(input_file, model_name)
61
 
62
 
63
  gr.Interface(return_all_predictions,
64
  inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
65
- outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Greedy decoding")],
66
  title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
67
  description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
68
  layout = "horizontal",
42
 
43
  return transcribed_text
44
 
45
+ def predict_and_ctc_lm_decode(input_file, model_name):
46
+ processor, model = return_processor_and_model(model_name)
47
+ speech = load_and_fix_data(input_file)
48
+
49
+ input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
50
+ logits = model(input_values).logits.cpu().detach().numpy()[0]
51
+
52
+ vocab_list = list(processor.tokenizer.get_vocab().keys())
53
+ vocab_dict = processor.tokenizer.get_vocab()
54
+ sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
55
+
56
+ decoder = build_ctcdecoder(
57
+ list(sorted_dict.keys()),
58
+ "4gram_small.arpa.gz",
59
+ )
60
+
61
+ pred = decoder.decode(logits)
62
+
63
+ transcribed_text = fix_transcription_casing(pred.lower())
64
+
65
+ return transcribed_text
66
+
67
  def predict_and_greedy_decode(input_file, model_name):
68
  processor, model = return_processor_and_model(model_name)
69
  speech = load_and_fix_data(input_file)
79
  return transcribed_text
80
 
81
  def return_all_predictions(input_file, model_name):
82
+ return predict_and_ctc_decode(input_file, model_name), predict_and_ctc_lm_decode(input_file, model_name), predict_and_greedy_decode(input_file, model_name)
83
 
84
 
85
  gr.Interface(return_all_predictions,
86
  inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
87
+ outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Beam CTC decoding w/ LM"), gr.outputs.Textbox(label="Greedy decoding")],
88
  title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
89
  description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
90
  layout = "horizontal",