Yurii Paniv commited on
Commit
442ed91
1 Parent(s): 14485b0

Add proper wav2vec in demo

Browse files
Files changed (3) hide show
  1. app.py +18 -2
  2. requirements-torch.txt +2 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -8,7 +8,8 @@ import requests
8
  from os.path import exists
9
  from stt import Model
10
  from datetime import datetime
11
-
 
12
 
13
  # download model
14
  version = "v0.4"
@@ -18,6 +19,10 @@ scorer_name = "kenlm.scorer"
18
  model_link = f"{storage_url}/{model_name}"
19
  scorer_link = f"{storage_url}/{scorer_name}"
20
 
 
 
 
 
21
  def download(url, file_name):
22
  if not exists(file_name):
23
  print(f"Downloading {file_name}")
@@ -37,6 +42,16 @@ def deepspeech(audio: np.array, use_scorer=False):
37
 
38
  return result
39
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def inference(audio: Tuple[int, np.array]):
42
  print("=============================")
@@ -50,7 +65,8 @@ def inference(audio: Tuple[int, np.array]):
50
 
51
  transcripts = []
52
 
53
- transcripts.append("")
 
54
  transcripts.append(deepspeech(audio, use_scorer=True))
55
  print(f"Deepspeech with LM: `{transcripts[-1]}`")
56
  transcripts.append(deepspeech(audio))
 
8
  from os.path import exists
9
  from stt import Model
10
  from datetime import datetime
11
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
12
+ import torch
13
 
14
  # download model
15
  version = "v0.4"
 
19
  model_link = f"{storage_url}/{model_name}"
20
  scorer_link = f"{storage_url}/{scorer_name}"
21
 
22
+ model = Wav2Vec2ForCTC.from_pretrained("robinhad/wav2vec2-xls-r-300m-uk")#.to("cuda")
23
+ processor = Wav2Vec2Processor.from_pretrained("robinhad/wav2vec2-xls-r-300m-uk")
24
+ # TODO: download config.json, pytorch_model.bin, preprocessor_config.json, tokenizer_config.json, vocab.json, added_tokens.json, special_tokens.json
25
+
26
  def download(url, file_name):
27
  if not exists(file_name):
28
  print(f"Downloading {file_name}")
 
42
 
43
  return result
44
 
45
+ def wav2vec2(audio: np.array):
46
+ input_dict = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
47
+ with torch.no_grad():
48
+ output = model(input_dict.input_values.float())
49
+
50
+ logits = output.logits
51
+
52
+ pred_ids = torch.argmax(logits, dim=-1)[0]
53
+
54
+ return processor.decode(pred_ids)
55
 
56
  def inference(audio: Tuple[int, np.array]):
57
  print("=============================")
 
65
 
66
  transcripts = []
67
 
68
+ transcripts.append(wav2vec2(audio))
69
+ print(f"Wav2Vec2: `{transcripts[-1]}`")
70
  transcripts.append(deepspeech(audio, use_scorer=True))
71
  print(f"Deepspeech with LM: `{transcripts[-1]}`")
72
  transcripts.append(deepspeech(audio))
requirements-torch.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ -i https://download.pytorch.org/whl/cpu
2
+ torch==1.12
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  STT==1.3.0
2
- pydub==0.25.1
 
 
 
1
  STT==1.3.0
2
+ transformers==4.21.2
3
+ pydub==0.25.1
4
+ -r requirements-torch.txt