Slava917 committed on
Commit
b150df1
1 Parent(s): 1554ec8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+
4
# Switch off the CPU, GPU, TensorExpr and nvFuser JIT fusion toggles before
# the scripted modules are loaded; the original author added this to fix a
# bug that showed up on the second prediction.
for _fuser_switch in (
    "_jit_override_can_fuse_on_cpu",
    "_jit_override_can_fuse_on_gpu",
    "_jit_set_texpr_fuser_enabled",
    "_jit_set_nvfuser_enabled",
):
    getattr(torch._C, _fuser_switch)(False)
del _fuser_switch

# TorchScript archives: one turns a file path into an audio tensor, the
# other is the recognition model itself.
loader = torch.jit.load("audio_loader.pt")
model = torch.jit.load('QuartzNet_thunderspeech_3.pt')

# Index-to-string table used for decoding. The last entry is blanked so that
# convert_probs skips it (presumably the CTC blank symbol -- confirm).
vocab = model.text_transform.vocab.itos
vocab[-1] = ''
15
+
16
def convert_probs(probs):
    """Greedy-decode model output into a list of vocabulary tokens.

    Takes the highest-scoring class at every step of the first batch item,
    collapses each run of identical consecutive ids into a single id, and
    drops ids whose ``vocab`` entry is empty (the module blanks
    ``vocab[-1]``).

    Args:
        probs: Tensor for which ``probs.argmax(1)[0]`` yields the per-step
            class ids of the first item -- presumably laid out as
            (batch, classes, steps); confirm against the model.

    Returns:
        list: The decoded tokens in order; empty when there are no steps.
    """
    from itertools import groupby

    # Plain Python ints make list indexing and comparisons unambiguous.
    ids = probs.argmax(1)[0].tolist()
    # groupby yields one key per run of equal ids, which is exactly the
    # "emit only when the id changes" rule. It also copes with an empty
    # sequence, where indexing ids[0] would raise IndexError.
    tokens = (vocab[idx] for idx, _ in groupby(ids))
    return [token for token in tokens if token]
26
+
27
def predict(path):
    """Transcribe the audio file at *path* into a list of tokens.

    Args:
        path: File path of the recording supplied by the Gradio audio
            input, or ``None`` when nothing was recorded (the input is
            declared with ``optional=True``).

    Returns:
        list: Tokens produced by ``convert_probs``; empty when no path
        was supplied.
    """
    if not path:
        # Nothing to load; do not hand an empty value to the scripted loader.
        return []
    audio = loader(path)
    # One length per batch item, each equal to the size of the last audio axis.
    lengths = torch.tensor(
        audio.shape[0] * [audio.shape[-1]], device=audio.device
    )
    # Inference only, so no autograd graph needs to be recorded.
    with torch.no_grad():
        probs = model(audio, lengths)[0]
    return convert_probs(probs)
31
+
32
+ gr.Interface(fn=predict,
33
+ inputs=[gr.inputs.Audio(source='microphone', type='filepath', optional=True)],
34
+ outputs= 'text').launch(debug=Tru