Yehor committed on
Commit
2e42bb8
1 Parent(s): a208abb

Show full text

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -8,7 +8,7 @@ model_name = "Yehor/wav2vec2-xls-r-1b-uk-with-lm"
8
  tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
9
  processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
10
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
11
- model.to("cpu")
12
 
13
 
14
  # define function to read in sound file
@@ -45,6 +45,7 @@ def inference(audio):
45
  stride_length_s=(4, 2),
46
  return_tensors="pt",
47
  ).input_values
 
48
 
49
  with torch.no_grad():
50
  logits = model(input_values).logits
@@ -55,6 +56,7 @@ def inference(audio):
55
  time_offset = 320 / sample_rate
56
 
57
  total_prediction = []
 
58
  for item in prediction.word_offsets:
59
  r = item
60
 
@@ -62,8 +64,11 @@ def inference(audio):
62
  e = round(r['end_offset'] * time_offset, 2)
63
 
64
  total_prediction.append(f"{s} - {e}: {r['word']}")
 
 
65
  print(prediction[0])
66
- return "\n".join(total_prediction)
 
67
 
68
 
69
  inputs = gr.inputs.Audio(label="Input Audio", type="file")
 
8
  tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
9
  processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
10
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
11
+ model.to("cuda")
12
 
13
 
14
  # define function to read in sound file
 
45
  stride_length_s=(4, 2),
46
  return_tensors="pt",
47
  ).input_values
48
+ input_values = input_values.cuda()
49
 
50
  with torch.no_grad():
51
  logits = model(input_values).logits
 
56
  time_offset = 320 / sample_rate
57
 
58
  total_prediction = []
59
+ words = []
60
  for item in prediction.word_offsets:
61
  r = item
62
 
 
64
  e = round(r['end_offset'] * time_offset, 2)
65
 
66
  total_prediction.append(f"{s} - {e}: {r['word']}")
67
+ words.append(r['word'])
68
+
69
  print(prediction[0])
70
+
71
+ return "\n".join(total_prediction) + "\n\n" + ' '.join(words)
72
 
73
 
74
  inputs = gr.inputs.Audio(label="Input Audio", type="file")