Slava917 commited on
Commit
1554ec8
1 Parent(s): daeb975

Update gradio_interface.py

Browse files
Files changed (1) hide show
  1. gradio_interface.py +32 -5
gradio_interface.py CHANGED
@@ -1,7 +1,34 @@
1
- import gradio as gr
 
2
 
 
 
 
 
 
3
 
4
- gr.Interface(
5
- fn=transcribe,
6
- inputs=gr.inputs.Audio(source="microphone", type="filepath"),
7
- outputs="text").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
 
4
+ #fixes second prediction bug
5
+ torch._C._jit_override_can_fuse_on_cpu(False)
6
+ torch._C._jit_override_can_fuse_on_gpu(False)
7
+ torch._C._jit_set_texpr_fuser_enabled(False)
8
+ torch._C._jit_set_nvfuser_enabled(False)
9
 
10
+ loader = torch.jit.load("audio_loader.pt")
11
+ model = torch.jit.load('QuartzNet_thunderspeech_3.pt')
12
+
13
+ vocab = model.text_transform.vocab.itos
14
+ vocab[-1] = ''
15
+
16
+ def convert_probs(probs):
17
+ ids = probs.argmax(1)[0]
18
+ s = []
19
+ if vocab[ids[0]]: s.append(vocab[ids[0]])
20
+ for i in range(1,len(ids)):
21
+ if ids[i-1] != ids[i]:
22
+ new = vocab[ids[i]]
23
+ if new: s.append(new)
24
+ #return '.'.join(s)
25
+ return s
26
+
27
+ def predict(path):
28
+ audio = loader(path)
29
+ probs = model(audio, torch.tensor(audio.shape[0] * [audio.shape[-1]], device=audio.device))[0]
30
+ return convert_probs(probs)
31
+
32
+ gr.Interface(fn=predict,
33
+ inputs=[gr.inputs.Audio(source='microphone', type='filepath', optional=True)],
34
+ outputs= 'text').launch(debug=Tru