Vaibhav Srivastav commited on
Commit
bbbf923
1 Parent(s): 8a068ad

hotfix for greedy search

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -27,7 +27,7 @@ def fix_transcription_casing(input_sentence):
27
  sentences = nltk.sent_tokenize(input_sentence)
28
  return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
29
 
30
- def predict_and_decode(input_file):
31
  speech = load_and_fix_data(input_file)
32
 
33
  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
@@ -50,12 +50,12 @@ def predict_and_greedy_decode(input_file):
50
  predicted_ids = torch.argmax(logits, dim=-1)
51
  pred = processor.batch_decode(predicted_ids)
52
 
53
- transcribed_text = fix_transcription_casing(pred.lower())
54
 
55
  return transcribed_text
56
 
57
  def return_all_predictions(input_file):
58
- return predict_and_decode(input_file), predict_and_greedy_decode(input_file)
59
 
60
 
61
  gr.Interface(return_all_predictions,
 
27
  sentences = nltk.sent_tokenize(input_sentence)
28
  return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
29
 
30
+ def predict_and_ctc_decode(input_file):
31
  speech = load_and_fix_data(input_file)
32
 
33
  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
 
50
  predicted_ids = torch.argmax(logits, dim=-1)
51
  pred = processor.batch_decode(predicted_ids)
52
 
53
+ transcribed_text = fix_transcription_casing(pred[0].lower())
54
 
55
  return transcribed_text
56
 
57
  def return_all_predictions(input_file):
58
+ return predict_and_ctc_decode(input_file), predict_and_greedy_decode(input_file)
59
 
60
 
61
  gr.Interface(return_all_predictions,