Slava917 commited on
Commit
ed47f0e
1 Parent(s): 827d41a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py CHANGED
@@ -2,4 +2,51 @@ import pandas as pd
2
  import gradio as gr
3
  import torch
4
  import torchaudio
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import torch
4
  import torchaudio
5
+ import warnings
6
+ from cryptography.utils import CryptographyDeprecationWarning
7
+ with warnings.catch_warnings():
8
+ warnings.filterwarnings('ignore', category=CryptographyDeprecationWarning)
9
+ import paramiko
10
+ torch._C._jit_override_can_fuse_on_cpu(False)
11
+ torch._C._jit_override_can_fuse_on_gpu(False)
12
+ torch._C._jit_set_texpr_fuser_enabled(False)
13
+ torch._C._jit_set_nvfuser_enabled(False)
14
 
15
+ loader = torch.jit.load("audio_loader.pt")
16
+ model = torch.jit.load('QuartzNet_thunderspeech_3.pt')
17
+
18
+ vocab = model.text_transform.vocab.itos
19
+ vocab[-1] = ''
20
+
21
+ def convert_probs(probs):
22
+ ids = probs.argmax(1)[0]
23
+ s = []
24
+ if vocab[ids[0]]: s.append(vocab[ids[0]])
25
+ for i in range(1,len(ids)):
26
+ if ids[i-1] != ids[i]:
27
+ new = vocab[ids[i]]
28
+ if new: s.append(new)
29
+ #return '.'.join(s)
30
+ return s
31
+
32
+
33
+ def predict(path):
34
+ audio = loader(path)
35
+ probs = model(audio, torch.tensor(audio.shape[0] * [audio.shape[-1]], device=audio.device))[0]
36
+ return convert_probs(probs)
37
+
38
+
39
+ from difflib import SequenceMatcher
40
+
41
+ def similar(a, b):
42
+ return SequenceMatcher(None, a, b).ratio()
43
+
44
+ def compare (word_choice, path):
45
+ etalon = df.loc[df['replica'] == word_choice, 'transcription'].values[0]
46
+ user = predict(path)
47
+ similar(user, etalon)
48
+
49
+
50
+ word_choice = gr.inputs.Dropdown(list(df['replica'].unique()), label="Choose a word")
51
+
52
+ gr.Interface(fn=compare, inputs=[gr.inputs.Audio(source='microphone', type='filepath', optional=True), word_choice], outputs= 'text').launch(debug=True)