Slava917 committed on
Commit
df48e6f
1 Parent(s): 688bae9

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +29 -0
predict.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+
4
# Switch off every TorchScript fuser before the scripted modules are loaded.
# Per the original author's note, this fixes a bug on the second prediction.
for _fuser_switch in (
    "_jit_override_can_fuse_on_cpu",
    "_jit_override_can_fuse_on_gpu",
    "_jit_set_texpr_fuser_enabled",
    "_jit_set_nvfuser_enabled",
):
    getattr(torch._C, _fuser_switch)(False)
del _fuser_switch

# Scripted modules, loaded once at import time from the working directory.
loader = torch.jit.load("audio_loader.pt")
model = torch.jit.load('QuartzNet_thunderspeech_3.pt')

# Index -> token-string table taken from the model's text transform.
vocab = model.text_transform.vocab.itos
# Blank out the last entry: convert_probs skips empty strings, so this index
# never contributes to decoded text (presumably the CTC blank -- TODO confirm).
vocab[-1] = ''
15
+
16
def convert_probs(probs):
    """Greedy-decode model scores into a '.'-joined token string.

    Takes the argmax of ``probs`` over dimension 1, keeps only the first
    item along dimension 0, collapses runs of consecutive identical indices,
    maps each remaining index through the module-level ``vocab`` and drops
    indices whose vocabulary entry is an empty string.

    Args:
        probs: Score tensor supporting ``argmax(1)``; assumed to be laid out
            as (batch, classes, time) -- TODO confirm against the model.
            Only the first batch item is decoded.

    Returns:
        The surviving tokens joined with ``'.'``; an empty string when there
        are no frames or every frame maps to an empty vocabulary entry.
    """
    # Convert once to plain Python ints so the loop does not build a 0-dim
    # tensor for every comparison and vocabulary lookup.
    ids = probs.argmax(1)[0].tolist()

    tokens = []
    previous = None  # sentinel: never equal to an integer index
    for index in ids:
        # Only the first frame of each run of equal indices emits a token.
        if index != previous:
            token = vocab[index]
            if token:
                tokens.append(token)
        previous = index
    return '.'.join(tokens)
25
+
26
def predict(path):
    """Transcribe the audio file at ``path`` and return the decoded string.

    The audio is loaded with the module-level ``loader``, passed through the
    module-level ``model`` together with a lengths tensor, and the first
    element of the model's output is decoded by ``convert_probs``.

    Args:
        path: Location of the audio file, passed unchanged to ``loader``.

    Returns:
        The string produced by ``convert_probs`` for the model output.
    """
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        audio = loader(path)
        # One entry per item along dim 0, each equal to the size of the last
        # dim (presumably the sample count of each waveform -- TODO confirm).
        lengths = torch.tensor(
            audio.shape[0] * [audio.shape[-1]], device=audio.device
        )
        probs = model(audio, lengths)[0]
    return convert_probs(probs)