Vaibhav Srivastav committed on
Commit
379fa33
1 Parent(s): 3b8d409

adding greedy decoding

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -22,8 +22,7 @@ def load_and_fix_data(input_file):
22
  if sample_rate !=16000:
23
  speech = librosa.resample(speech, sample_rate,16000)
24
  return speech
25
-
26
-
27
  def fix_transcription_casing(input_sentence):
28
  sentences = nltk.sent_tokenize(input_sentence)
29
  return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
@@ -41,10 +40,27 @@ def predict_and_decode(input_file):
41
  transcribed_text = fix_transcription_casing(pred.lower())
42
 
43
  return transcribed_text
44
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  gr.Interface(predict_and_decode,
46
  inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record/ Drop audio"),
47
- outputs = gr.outputs.Textbox(label="Output Text"),
48
  title="ASR using Wav2Vec 2.0 & pyctcdecode",
49
  description = "Extending HF ASR models with pyctcdecode decoder",
50
  layout = "horizontal",
 
22
  if sample_rate !=16000:
23
  speech = librosa.resample(speech, sample_rate,16000)
24
  return speech
25
+
 
26
  def fix_transcription_casing(input_sentence):
27
  sentences = nltk.sent_tokenize(input_sentence)
28
  return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
 
40
  transcribed_text = fix_transcription_casing(pred.lower())
41
 
42
  return transcribed_text
43
+
44
def predict_and_greedy_decode(input_file):
    """Transcribe an audio file using greedy (argmax) decoding.

    The audio is loaded via ``load_and_fix_data``, run through ``model`` to
    get per-frame logits, and the highest-scoring token id of each frame is
    decoded to text by ``processor``. The text is lower-cased and then
    sentence-cased with ``fix_transcription_casing``.

    Args:
        input_file: Path of the audio file to transcribe.

    Returns:
        The transcription as a single string.
    """
    speech = load_and_fix_data(input_file)

    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    decoded = processor.batch_decode(predicted_ids)

    # batch_decode is expected to return one string per batch item (a list),
    # which has no .lower(); one file is processed per call, so take the
    # first item. NOTE(review): processor is defined outside this view --
    # the isinstance guard keeps a plain-string return working too.
    pred = decoded[0] if isinstance(decoded, (list, tuple)) else decoded

    transcribed_text = fix_transcription_casing(pred.lower())

    return transcribed_text
56
+
57
def return_all_predictions(input_file):
    """Return both transcriptions of ``input_file`` as a 2-tuple.

    The first element comes from ``predict_and_decode`` and the second from
    ``predict_and_greedy_decode``, called in that order.
    """
    decoded_text = predict_and_decode(input_file)
    greedy_text = predict_and_greedy_decode(input_file)
    return decoded_text, greedy_text
59
+
60
+
61
  gr.Interface(predict_and_decode,
62
  inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record/ Drop audio"),
63
+ outputs = [gr.outputs.Textbox(label="Beam CTC Decoding"), gr.outputs.Textbox(label="Greedy Decoding")],
64
  title="ASR using Wav2Vec 2.0 & pyctcdecode",
65
  description = "Extending HF ASR models with pyctcdecode decoder",
66
  layout = "horizontal",