Shiry committed on
Commit 050271c
1 Parent(s): fc9008d

Add application file

Files changed (2)
  1. app.py +390 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,390 @@
import torch
# os.system("pip install git+https://github.com/openai/whisper.git")
import gradio as gr
import whisper
import librosa
import plotly.express as px
from threading import Thread
from statistics import mode, mean
import time


# Load the Whisper speech-recognition model on CPU.
model = whisper.load_model("large", device='cpu')
print('loaded whisper')

# Load the Silero VAD model and its helper utilities from torch.hub.
vad, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                model='silero_vad',
                                force_reload=False,
                                onnx=False)
print('loaded silero')
(get_speech_timestamps,
 save_audio,
 read_audio,
 VADIterator,
 collect_chunks) = vad_utils
vad_iterator = VADIterator(vad)

# Shared state used across the Gradio callbacks.
global x, y, j, audio_vec, transcribe, STOP, languages, not_detected, main_lang, STARTED
x = []                        # plot x-axis: time in seconds
y = []                        # plot y-axis: VAD speech probabilities
j = 0                         # running time offset for the plot
STOP = False                  # stop flag for the background thread
audio_vec = torch.tensor([])  # all audio received so far
transcribe = ''               # accumulated transcription text
languages = []                # rolling window of recently detected languages
not_detected = True
main_lang = ''
STARTED = False

# CSS styling passed to the gr.Blocks UI below.
css = """
    .gradio-container {
        font-family: 'IBM Plex Sans', sans-serif;
    }
    .gr-button {
        color: white;
        border-color: black;
        background: black;
    }
    input[type='range'] {
        accent-color: black;
    }
    .dark input[type='range'] {
        accent-color: #dfdfdf;
    }
    .container {
        max-width: 730px;
        margin: auto;
        padding-top: 1.5rem;
    }
    .details:hover {
        text-decoration: underline;
    }
    .gr-button {
        white-space: nowrap;
    }
    .gr-button:focus {
        border-color: rgb(147 197 253 / var(--tw-border-opacity));
        outline: none;
        box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
        --tw-border-opacity: 1;
        --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
        --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
        --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
        --tw-ring-opacity: .5;
    }
    .footer {
        margin-bottom: 45px;
        margin-top: 35px;
        text-align: center;
        border-bottom: 1px solid #e5e5e5;
    }
    .footer>p {
        font-size: .8rem;
        display: inline-block;
        padding: 0 10px;
        transform: translateY(10px);
        background: white;
    }
    .dark .footer {
        border-color: #303030;
    }
    .dark .footer>p {
        background: #0b0f19;
    }
    .prompt h4 {
        margin: 1.25em 0 .25em 0;
        font-weight: bold;
        font-size: 115%;
    }
    .animate-spin {
        animation: spin 1s linear infinite;
    }
    @keyframes spin {
        from {
            transform: rotate(0deg);
        }
        to {
            transform: rotate(360deg);
        }
    }
    #share-btn-container {
        display: flex; margin-top: 1.5rem !important; padding-left: 0.5rem !important;
        padding-right: 0.5rem !important; background-color: #000000; justify-content: center;
        align-items: center; border-radius: 9999px !important; width: 13rem;
    }
    #share-btn {
        all: initial; color: #ffffff; font-weight: 600; cursor: pointer; font-family: 'IBM Plex Sans', sans-serif;
        margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
    }
    #share-btn * {
        all: unset;
    }
"""
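

# An earlier, thread-based version of transcribe_chunk, left commented out: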
# def transcribe_chunk():
#     print('********************************')
#     global audio_vec, transcribe, STOP
#     print('Enter trans chunk')
#     counter = 0
#     i = 0
#     while not STOP:
#         if audio_vec.size()[0] // 32000 > counter and audio_vec.size()[0] > 0:
#             print('audio_vec.size()[0] % 32000', audio_vec.size()[0] % 32000)
#             print('audio size', audio_vec.size()[0])
#             chunk = whisper.pad_or_trim(audio_vec[32000 * counter: 32000 * (counter + 1)])
#             mel_th = whisper.log_mel_spectrogram(chunk).to(model.device)
#             options = whisper.DecodingOptions(fp16=False)
#             result = whisper.decode(model, mel_th, options)
#             no_speech_prob = result.no_speech_prob
#             if no_speech_prob < 0.4:
#                 transcribe += result.text + ' '
#             counter += 1
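

# transcribe_chunk: transcribe one audio chunk with Whisper.
#   - whisper.pad_or_trim fits the chunk to Whisper's 30-second window,
#   - the log-Mel spectrogram is decoded and language-identified,
#   - text is kept only when no_speech_prob < 0.6, and the detected language is
#     appended to a short rolling window used for majority voting.
#   The vad_prob argument is only used for logging.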
def transcribe_chunk(audio, vad_prob):
    global languages
    trnscrb = ''
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    options = whisper.DecodingOptions(fp16=False, task='transcribe')
    result = whisper.decode(model, mel, options)
    no_speech_prob = result.no_speech_prob

    # Language identification on the same log-Mel spectrogram.
    _, probs = model.detect_language(mel)
    temp_lang = max(probs, key=probs.get)

    print(result.text, "no_speech_prob: ", no_speech_prob, 1 - vad_prob)
    if no_speech_prob < 0.6:
        trnscrb = result.text + ' '
        languages.append(temp_lang)
        if len(languages) > 3:
            languages.pop(0)
    return trnscrb


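# inference: streaming callback (not wired to the UI below, which uses inference_file).
# It runs Silero VAD over 0.1-second windows (1600 samples at 16 kHz) to extend the
# speech-probability plot, lets Whisper identify the chunk's language, and reports the
# mode of the recent detections ('iw' is mapped to 'he'). On its first call it starts
# the background thread created at the bottom of the file.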
def inference(audio):
    global x, y, j, audio_vec, transcribe, languages, not_detected, main_lang, STARTED
    print('enter inference')
    if j == 0:
        # Starts the background thread defined at the bottom of the file. Its target,
        # transcribe_chunk, expects arguments, so this path only works with the
        # commented-out zero-argument version above.
        thread.start()
        STARTED = True
    wav2 = whisper.load_audio(audio, sr=16000)
    wav = torch.from_numpy(librosa.load(audio, sr=16000)[0])
    audio_vec = torch.cat((audio_vec, wav))

    # Silero VAD speech probability for every 0.1-second window.
    speech_probs = []
    window_size_samples = 1600
    for i in range(0, len(wav), window_size_samples):
        chunk = wav[i: i + window_size_samples]
        if len(chunk) < window_size_samples:
            break
        speech_prob = vad(chunk, 16000).item()
        speech_probs.append(speech_prob)
    vad_iterator.reset_states()
    sample_per_sec = 16000 / window_size_samples
    x.extend([j + i / sample_per_sec for i in range(len(speech_probs))])
    y.extend(speech_probs)
    j = max(x)
    fig = px.line(x=x, y=y)

    # Whisper language identification on the padded chunk.
    whisper_audio = whisper.pad_or_trim(wav2)
    mel = whisper.log_mel_spectrogram(whisper_audio).to(model.device)
    _, probs = model.detect_language(mel)
    temp_lang = max(probs, key=probs.get)
    print(temp_lang)

    languages.append(temp_lang)
    if len(languages) > 5:
        languages.pop(0)

    curr_lang = mode(languages)
    print(curr_lang, languages)

    if curr_lang == 'iw':
        return 'he', fig, gr.update(visible=True), transcribe, gr.update(visible=True), gr.update(visible=True)
    return curr_lang, fig, gr.update(visible=True), transcribe, gr.update(visible=True), gr.update(visible=True)


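# clear: reset all shared state and hide the result widgets; joins the background
# thread if it was started and creates a fresh one for the next session.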
def clear():
    global x, y, j, audio_vec, transcribe, thread, STOP, languages, main_lang, not_detected, STARTED
    STOP = True
    if STARTED:
        thread.join()
    STARTED = False
    x = []
    y = []
    j = 0
    audio_vec = torch.tensor([])
    transcribe = ''
    STOP = False
    languages = []
    main_lang = ''
    not_detected = True
    thread = Thread(target=transcribe_chunk)
    print('clean:', x, y, j, transcribe, audio_vec)
    return '', gr.update(visible=False), gr.update(visible=False), '', gr.update(visible=False), gr.update(visible=False)


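# inference_file: main transcription callback used by the UI (file upload, microphone
# recording, and streaming all route here). It appends the new audio to audio_vec,
# extends the Silero VAD speech-probability plot, and transcribes the audio in
# 30-second chunks, accumulating text in the global `transcribe` string. The reported
# language is the mode of the rolling `languages` window ('iw' is mapped to 'he').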
def inference_file(audio):
    time.sleep(0.8)
    global x, y, j, audio_vec, transcribe, languages, not_detected, main_lang
    wav = torch.from_numpy(librosa.load(audio, sr=16000)[0])
    audio_vec = torch.cat((audio_vec, wav))

    # Silero VAD speech probability for every 0.1-second window, appended to the running plot.
    speech_probs = []
    window_size_samples = 1600
    for i in range(0, len(wav), window_size_samples):
        chunk = wav[i: i + window_size_samples]
        if len(chunk) < window_size_samples:
            break
        speech_prob = vad(chunk, 16000).item()
        speech_probs.append(speech_prob)
    vad_iterator.reset_states()
    sample_per_sec = 16000 / window_size_samples
    x.extend([j + i / sample_per_sec for i in range(len(speech_probs))])
    y.extend(speech_probs)
    j = max(x)
    fig = px.line(x=x, y=y)

    mean_speech_probs = mean(speech_probs)

    if wav.shape[0] > 16000 * 30:
        # Long audio: transcribe in consecutive 30-second chunks.
        start = 0
        end = 16000 * 30
        chunk = wav[start:end]
        chunk_idx = 0
        while end < wav.shape[0]:
            # NOTE: transcribe_chunk also expects a VAD probability; the overall mean
            # is passed here (it is only used for logging).
            transcribe += transcribe_chunk(chunk, mean_speech_probs)
            chunk_idx += 1
            start = chunk_idx * 30 * 16000
            if start >= wav.shape[0]:
                break
            end = (chunk_idx + 1) * 30 * 16000
            if end >= wav.shape[0]:
                end = wav.shape[0] - 1
            chunk = wav[start:end]
    else:
        transcribe += transcribe_chunk(wav, mean_speech_probs)

    curr_lang = ''
    if len(languages) > 0:
        curr_lang = mode(languages)
    print(curr_lang, languages)

    if curr_lang == 'iw':
        return 'he', fig, gr.update(visible=True), transcribe, gr.update(visible=True), gr.update(visible=True)
    return curr_lang, fig, gr.update(visible=True), transcribe, gr.update(visible=True), gr.update(visible=True)


block = gr.Blocks(css=css)


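# play_sound: write all audio received so far to a WAV file and play it with pygame.
# Note that this plays on the machine running the app, not in the browser.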
def play_sound():
    global audio_vec
    import soundfile as sf
    print(audio_vec)
    sf.write('uploaded.wav', data=audio_vec, samplerate=16000)
    from pygame import mixer
    mixer.init()
    mixer.music.load('uploaded.wav')
    mixer.music.play()


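# change_audio: toggle the visibility of the audio inputs and Transcribe buttons
# according to the radio selection. The Hebrew options mean, in order:
# 'סטרימינג' = streaming, 'הקלטה' = recording, 'קובץ' = file upload.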
def change_audio(string):
    # if string == 'סטרימינג':
    #     return gr.Audio.update(source="microphone",), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    # else:
    #     return gr.Audio.update(source='upload'), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
    if string == 'סטרימינג':  # streaming
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), \
               gr.update(visible=False), gr.update(visible=False)
    elif string == 'הקלטה':  # recording
        print('in mesholav')
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
               gr.update(visible=True), gr.update(visible=True)
    else:  # file upload
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), \
               gr.update(visible=False), gr.update(visible=False)


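# Build the Gradio UI: a title, the hidden VAD plot, a radio selector for the input
# mode, three audio inputs (streaming microphone, file upload, one-shot recording),
# two result text boxes, and Clear / Play buttons; event handlers are wired below.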
with block:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 650px; margin: 0 auto;">
          <div
            style="
              display: inline-flex;
              align-items: center;
              gap: 0.8rem;
              font-size: 1.75rem;
            "
          >
            <h1 style="font-weight: 900; margin-bottom: 7px;">
              Whisper
            </h1>
          </div>
        </div>
        """
    )
    with gr.Group():
        plot = gr.Plot(show_label=False, visible=False)
        with gr.Row(equal_height=True):
            with gr.Box():
                # Label (Hebrew): "How would you like to provide the audio?"
                radio = gr.Radio(["סטרימינג", "הקלטה", "קובץ"], label="?איך תרצה לספק את האודיו")
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    # Streaming microphone input (visible by default).
                    audio = gr.Audio(
                        show_label=False,
                        source="microphone",
                        type="filepath",
                        visible=True
                    )
                    # File-upload input.
                    audio2 = gr.Audio(
                        label="Input Audio",
                        show_label=False,
                        source="upload",
                        type="filepath",
                        visible=False
                    )
                    # One-shot microphone recording input.
                    audio3 = gr.Audio(
                        label="Input Audio",
                        show_label=False,
                        source="microphone",
                        type="filepath",
                        visible=False
                    )

                trans_btn = gr.Button("Transcribe", visible=False)
                trans_btn3 = gr.Button("Transcribe", visible=False)

        text = gr.Textbox(show_label=False, elem_id="result-textarea")
        text2 = gr.Textbox(show_label=False, elem_id="result-textarea")
        with gr.Row():
            clear_btn = gr.Button("Clear", visible=False)
            play_btn = gr.Button('Play audio', visible=False)

        radio.change(fn=change_audio, inputs=radio, outputs=[audio, trans_btn, audio2, trans_btn3, audio3])
        trans_btn.click(inference_file, audio2, [text, plot, plot, text2, clear_btn, play_btn])
        trans_btn3.click(inference_file, audio3, [text, plot, plot, text2, clear_btn, play_btn])
        audio.stream(inference_file, audio, [text, plot, plot, text2, clear_btn, play_btn])
        play_btn.click(play_sound)
        clear_btn.click(clear, inputs=[], outputs=[text, plot, plot, text2, clear_btn, play_btn])

    gr.HTML('''
        <div class="footer">
            <p>Model by Moses team - Whisper Demo
            </p>
        </div>
    ''')
    gr.HTML('''
        <img style="text-align: center; max-width: 650px; margin: 0 auto;" src="https://geekflare.com/wp-content/uploads/2022/02/speechrecognitionapi.png" alt="Speech recognition illustration" width="500" height="600">
    ''')

# Background thread for the legacy streaming path (see inference); transcribe_chunk
# expects arguments, so this thread is not started by the current UI wiring.
global thread
thread = Thread(target=transcribe_chunk)
block.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
git+https://github.com/huggingface/transformers
torch
git+https://github.com/openai/whisper.git
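# Note: app.py also imports gradio, librosa, plotly, soundfile, and pygame, which are
# not pinned here; the hosting environment is assumed to provide them.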