Yuekai Zhang commited on
Commit
62e5a8a
·
1 Parent(s): d67a714

update examples

Browse files
Files changed (2) hide show
  1. app_local.py +443 -0
  2. examples.py +11 -11
app_local.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
4
+ # 2023 Nvidia. (authors: Yuekai Zhang)
5
+ #
6
+ # See LICENSE for clarification regarding multiple authors
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ # References:
21
+ # https://gradio.app/docs/#dropdown
22
+ # https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
23
+
24
+ import logging
25
+ import os
26
+ import tempfile
27
+ import time
28
+ from datetime import datetime
29
+
30
+ import gradio as gr
31
+ import numpy as np
32
+ import urllib.request
33
+ import tritonclient
34
+ import tritonclient.grpc as grpcclient
35
+ from tritonclient.utils import np_to_triton_dtype
36
+ import soundfile
37
+
38
+ from examples import examples
39
+
40
+ def convert_to_wav(in_filename: str) -> str:
41
+ """Convert the input audio file to a wave file"""
42
+ out_filename = in_filename + ".wav"
43
+ if '.mp3' in in_filename:
44
+ _ = os.system(f"ffmpeg -y -i '{in_filename}' -acodec pcm_s16le -ac 1 -ar 16000 '{out_filename}'")
45
+ else:
46
+ _ = os.system(f"ffmpeg -hide_banner -y -i '{in_filename}' -ar 16000 '{out_filename}'")
47
+ return out_filename
48
+
49
+
50
+ def build_html_output(s: str, style: str = "result_item_success"):
51
+ return f"""
52
+ <div class='result'>
53
+ <div class='result_item {style}'>
54
+ {s}
55
+ </div>
56
+ </div>
57
+ """
58
+
59
+ def process_url(
60
+ language: str,
61
+ repo_id: str,
62
+ decoding_method: str,
63
+ whisper_prompt_textbox: str,
64
+ url: str,
65
+ server_url_textbox: str,
66
+ ):
67
+ logging.info(f"Processing URL: {url}")
68
+ with tempfile.NamedTemporaryFile() as f:
69
+ try:
70
+ urllib.request.urlretrieve(url, f.name)
71
+
72
+ return process(
73
+ in_filename=f.name,
74
+ language=language,
75
+ repo_id=repo_id,
76
+ decoding_method=decoding_method,
77
+ whisper_prompt_textbox=whisper_prompt_textbox,
78
+ server_url=server_url_textbox,
79
+ )
80
+ except Exception as e:
81
+ logging.info(str(e))
82
+ return "", build_html_output(str(e), "result_item_error")
83
+
84
+ def process_uploaded_file(
85
+ language: str,
86
+ repo_id: str,
87
+ decoding_method: str,
88
+ whisper_prompt_textbox: int,
89
+ in_filename: str,
90
+ server_url_textbox: str,
91
+ ):
92
+ if in_filename is None or in_filename == "":
93
+ return "", build_html_output(
94
+ "Please first upload a file and then click "
95
+ 'the button "submit for recognition"',
96
+ "result_item_error",
97
+ )
98
+
99
+ logging.info(f"Processing uploaded file: {in_filename}")
100
+ try:
101
+ return process(
102
+ in_filename=in_filename,
103
+ language=language,
104
+ repo_id=repo_id,
105
+ decoding_method=decoding_method,
106
+ whisper_prompt_textbox=whisper_prompt_textbox,
107
+ server_url=server_url_textbox,
108
+ )
109
+ except Exception as e:
110
+ logging.info(str(e))
111
+ return "", build_html_output(str(e), "result_item_error")
112
+
113
+
114
+ def process_microphone(
115
+ language: str,
116
+ repo_id: str,
117
+ decoding_method: str,
118
+ whisper_prompt_textbox: str,
119
+ in_filename: str,
120
+ server_url_textbox: str,
121
+ ):
122
+ if in_filename is None or in_filename == "":
123
+ return "", build_html_output(
124
+ "Please first click 'Record from microphone', speak, "
125
+ "click 'Stop recording', and then "
126
+ "click the button 'submit for recognition'",
127
+ "result_item_error",
128
+ )
129
+
130
+ logging.info(f"Processing microphone: {in_filename}")
131
+ try:
132
+ return process(
133
+ in_filename=in_filename,
134
+ language=language,
135
+ repo_id=repo_id,
136
+ decoding_method=decoding_method,
137
+ whisper_prompt_textbox=whisper_prompt_textbox,
138
+ server_url=server_url_textbox,
139
+ )
140
+ except Exception as e:
141
+ logging.info(str(e))
142
+ return "", build_html_output(str(e), "result_item_error")
143
+
144
+ def send_whisper(whisper_prompt, wav_path, model_name, triton_client, protocol_client, padding_duration=10):
145
+ waveform, sample_rate = soundfile.read(wav_path)
146
+ assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}"
147
+ duration = int(len(waveform) / sample_rate)
148
+
149
+ # padding to nearset 10 seconds
150
+ samples = np.zeros(
151
+ (
152
+ 1,
153
+ padding_duration * sample_rate * ((duration // padding_duration) + 1),
154
+ ),
155
+ dtype=np.float32,
156
+ )
157
+
158
+ samples[0, : len(waveform)] = waveform
159
+
160
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
161
+
162
+ inputs = [
163
+ protocol_client.InferInput(
164
+ "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
165
+ ),
166
+ protocol_client.InferInput(
167
+ "TEXT_PREFIX", [1, 1], "BYTES"
168
+ ),
169
+ ]
170
+ inputs[0].set_data_from_numpy(samples)
171
+
172
+ input_data_numpy = np.array([whisper_prompt], dtype=object)
173
+ input_data_numpy = input_data_numpy.reshape((1, 1))
174
+ inputs[1].set_data_from_numpy(input_data_numpy)
175
+
176
+ outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
177
+ # generate a random sequence id
178
+ sequence_id = np.random.randint(0, 1000000)
179
+
180
+ response = triton_client.infer(
181
+ model_name, inputs, request_id=str(sequence_id), outputs=outputs
182
+ )
183
+
184
+ decoding_results = response.as_numpy("TRANSCRIPTS")[0]
185
+ if type(decoding_results) == np.ndarray:
186
+ decoding_results = b" ".join(decoding_results).decode("utf-8")
187
+ else:
188
+ # For wenet
189
+ decoding_results = decoding_results.decode("utf-8")
190
+ return decoding_results, duration
191
+
192
+ def process(
193
+ language: str,
194
+ repo_id: str,
195
+ decoding_method: str,
196
+ whisper_prompt_textbox: str,
197
+ in_filename: str,
198
+ server_url: str,
199
+ ):
200
+ logging.info(f"language: {language}")
201
+ logging.info(f"repo_id: {repo_id}")
202
+ logging.info(f"decoding_method: {decoding_method}")
203
+ logging.info(f"whisper_prompt_textbox: {whisper_prompt_textbox}")
204
+ logging.info(f"in_filename: {in_filename}")
205
+
206
+ model_name = "whisper"
207
+ triton_client = grpcclient.InferenceServerClient(url=server_url, verbose=False)
208
+ protocol_client = grpcclient
209
+
210
+ filename = convert_to_wav(in_filename)
211
+
212
+ now = datetime.now()
213
+ date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
214
+ logging.info(f"Started at {date_time}")
215
+
216
+ start = time.time()
217
+
218
+ text, duration = send_whisper(whisper_prompt_textbox, filename, model_name, triton_client, protocol_client)
219
+
220
+ date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
221
+ end = time.time()
222
+
223
+ #metadata = torchaudio.info(filename)
224
+ #duration = metadata.num_frames / sample_rate
225
+ rtf = (end - start) / duration
226
+
227
+ logging.info(f"Finished at {date_time} s. Elapsed: {end - start: .3f} s")
228
+
229
+ info = f"""
230
+ Wave duration : {duration: .3f} s <br/>
231
+ Processing time: {end - start: .3f} s <br/>
232
+ RTF: {end - start: .3f}/{duration: .3f} = {rtf:.3f} <br/>
233
+ """
234
+ if rtf > 1:
235
+ info += (
236
+ "<br/>We are loading the model for the first run. "
237
+ "Please run again to measure the real RTF.<br/>"
238
+ )
239
+
240
+ logging.info(info)
241
+ logging.info(f"\nrepo_id: {repo_id}\nhyp: {text}")
242
+
243
+ return text, build_html_output(info)
244
+
245
+
246
+ title = "# Speech Recognition and Translation with Whisper"
247
+ description = """
248
+ This space shows how to do speech recognition and translation with Nvidia **Triton**.
249
+
250
+ Please visit
251
+ <https://huggingface.co/yuekai/model_repo_whisper_large_v2>
252
+ for triton speech recognition.
253
+
254
+ The service is running on a GPU based on triton server.
255
+
256
+ See more information by visiting the following links:
257
+
258
+ - <https://github.com/triton-inference-server>
259
+ - <https://github.com/yuekaizhang/Triton-ASR-Client/tree/main>
260
+ - <https://github.com/k2-fsa/sherpa/tree/master/triton>
261
+ - <https://github.com/wenet-e2e/wenet/tree/main/runtime/gpu>
262
+ - <https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/triton_gpu>
263
+
264
+ """
265
+
266
+ # css style is copied from
267
+ # https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
268
+ css = """
269
+ .result {display:flex;flex-direction:column}
270
+ .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
271
+ .result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
272
+ .result_item_error {background-color:#ff7070;color:white;align-self:start}
273
+ """
274
+
275
+
276
+ # def update_model_dropdown(language: str):
277
+ # if language in language_to_models:
278
+ # choices = language_to_models[language]
279
+ # return gr.Dropdown.update(choices=choices, value=choices[0])
280
+
281
+ # raise ValueError(f"Unsupported language: {language}")
282
+
283
+
284
+ demo = gr.Blocks(css=css)
285
+
286
+
287
+ with demo:
288
+ gr.Markdown(title)
289
+ language_choices = ["Chinese", "English", "Chinese+English", "Korean", "Japanese", "Arabic", "German", "French", "Russian"]
290
+ server_url_textbox = gr.Textbox(
291
+ label='Triton Inference Server URL',
292
+ value='10.19.203.82:8001'
293
+ placeholder='e.g. localhost:8001',
294
+ max_lines=1,
295
+ )
296
+
297
+ whisper_prompt_textbox = gr.Textbox(
298
+ label='Whisper prompt',
299
+ placeholder='Whisper prompt e.g. <|startoftranscript|><zh><en><transcribe>',
300
+ max_lines=1,
301
+ )
302
+ language_radio = gr.Radio(
303
+ label="Language",
304
+ choices=language_choices,
305
+ value=language_choices[0],
306
+ )
307
+ model_dropdown = gr.Dropdown(
308
+ choices=["whisper-large-v2"],
309
+ label="Select a model",
310
+ value="whisper-large-v2",
311
+ )
312
+
313
+ # language_radio.change(
314
+ # update_model_dropdown,
315
+ # inputs=language_radio,
316
+ # outputs=model_dropdown,
317
+ # )
318
+
319
+ decoding_method_radio = gr.Radio(
320
+ label="Decoding method",
321
+ choices=["greedy_search"],
322
+ value="greedy_search",
323
+ )
324
+
325
+ # whisper_prompt_textbox_slider = gr.Slider(
326
+ # minimum=1,
327
+ # value=4,
328
+ # step=1,
329
+ # label="Number of active paths for modified_beam_search",
330
+ # )
331
+
332
+ with gr.Tabs():
333
+ with gr.TabItem("Upload from disk"):
334
+ uploaded_file = gr.Audio(
335
+ source="upload", # Choose between "microphone", "upload"
336
+ type="filepath",
337
+ optional=False,
338
+ label="Upload from disk",
339
+ )
340
+ upload_button = gr.Button("Submit for recognition")
341
+ uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
342
+ uploaded_html_info = gr.HTML(label="Info")
343
+
344
+ gr.Examples(
345
+ examples=examples,
346
+ inputs=[
347
+ language_radio,
348
+ model_dropdown,
349
+ decoding_method_radio,
350
+ whisper_prompt_textbox,
351
+ uploaded_file,
352
+ ],
353
+ outputs=[uploaded_output, uploaded_html_info],
354
+ fn=process_uploaded_file,
355
+ cache_examples=False,
356
+ )
357
+
358
+ with gr.TabItem("Record from microphone"):
359
+ microphone = gr.Audio(
360
+ source="microphone", # Choose between "microphone", "upload"
361
+ type="filepath",
362
+ optional=False,
363
+ label="Record from microphone",
364
+ )
365
+
366
+ record_button = gr.Button("Submit for recognition")
367
+ recorded_output = gr.Textbox(label="Recognized speech from recordings")
368
+ recorded_html_info = gr.HTML(label="Info")
369
+
370
+ gr.Examples(
371
+ examples=examples,
372
+ inputs=[
373
+ language_radio,
374
+ model_dropdown,
375
+ decoding_method_radio,
376
+ whisper_prompt_textbox,
377
+ microphone,
378
+ ],
379
+ outputs=[recorded_output, recorded_html_info],
380
+ fn=process_microphone,
381
+ cache_examples=False,
382
+ )
383
+
384
+ with gr.TabItem("From URL"):
385
+ url_textbox = gr.Textbox(
386
+ max_lines=1,
387
+ placeholder="URL to an audio file",
388
+ label="URL",
389
+ interactive=True,
390
+ )
391
+
392
+ url_button = gr.Button("Submit for recognition")
393
+ url_output = gr.Textbox(label="Recognized speech from URL")
394
+ url_html_info = gr.HTML(label="Info")
395
+
396
+ upload_button.click(
397
+ process_uploaded_file,
398
+ inputs=[
399
+ language_radio,
400
+ model_dropdown,
401
+ decoding_method_radio,
402
+ whisper_prompt_textbox,
403
+ uploaded_file,
404
+ server_url_textbox,
405
+ ],
406
+ outputs=[uploaded_output, uploaded_html_info],
407
+ )
408
+
409
+ record_button.click(
410
+ process_microphone,
411
+ inputs=[
412
+ language_radio,
413
+ model_dropdown,
414
+ decoding_method_radio,
415
+ whisper_prompt_textbox,
416
+ microphone,
417
+ server_url_textbox,
418
+ ],
419
+ outputs=[recorded_output, recorded_html_info],
420
+ )
421
+
422
+ url_button.click(
423
+ process_url,
424
+ inputs=[
425
+ language_radio,
426
+ model_dropdown,
427
+ decoding_method_radio,
428
+ whisper_prompt_textbox,
429
+ url_textbox,
430
+ server_url_textbox,
431
+ ],
432
+ outputs=[url_output, url_html_info],
433
+ )
434
+
435
+ gr.Markdown(description)
436
+
437
+
438
+ if __name__ == "__main__":
439
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
440
+
441
+ logging.basicConfig(format=formatter, level=logging.INFO)
442
+
443
+ demo.launch(share=True)
examples.py CHANGED
@@ -20,49 +20,49 @@ examples = [
20
  "Chinese+English",
21
  "whisper-large-v2",
22
  "greedy_search",
23
- "<|startoftranscript|><|zh|><|en|><|transcribe|><notimestamps>",
24
  "./test_wavs/tal_csasr/0.wav",
25
  ],
26
  [
27
  "Chinese",
28
  "whisper-large-v2",
29
  "greedy_search",
30
- "<|startofprev|>法律<|startoftranscript|><|zh|><|transcribe|><notimestamps>",
31
  "./test_wavs/mini_zh/mid.wav",
32
  ],
33
  [
34
  "Japanese",
35
  "whisper-large-v2",
36
  "greedy_search",
37
- "<|startoftranscript|><|jp|><|transcribe|><notimestamps>",
38
  "./test_wavs/fleurs/7760285811293653093.wav",
39
  ],
40
  [
41
  "Korean",
42
  "whisper-large-v2",
43
  "greedy_search",
44
- "<|startoftranscript|><|ko|><|translate|><notimestamps>",
45
  "./test_wavs/fleurs/15029788401146217023.wav",
46
  ],
47
  [
48
  "Korean",
49
  "whisper-large-v2",
50
  "greedy_search",
51
- "<|startoftranscript|><|ko|><|transcribe|><notimestamps>",
52
  "./test_wavs/fleurs/15029788401146217023.wav",
53
  ],
54
  [
55
  "Japanese",
56
  "whisper-large-v2",
57
  "greedy_search",
58
- "<|startoftranscript|><|en|><|transcribe|><notimestamps>",
59
  "./test_wavs/fleurs/7760285811293653093.wav",
60
  ],
61
  [
62
  "English",
63
  "whisper-large-v2",
64
  "greedy_search",
65
- "<|startoftranscript|><|en|><|transcribe|><notimestamps>",
66
  "./test_wavs/librispeech/1089-134686-0001.wav",
67
  ],
68
  # [
@@ -76,7 +76,7 @@ examples = [
76
  "Russian",
77
  "whisper-large-v2",
78
  "greedy_search",
79
- "<|startoftranscript|><|ru|><|transcribe|><notimestamps>",
80
  "./test_wavs/russian/russian-i-love-you.wav",
81
  ],
82
  # [
@@ -90,14 +90,14 @@ examples = [
90
  "German",
91
  "whisper-large-v2",
92
  "greedy_search",
93
- "<|startoftranscript|><|de|><|transcribe|><notimestamps>",
94
  "./test_wavs/german/20170517-0900-PLENARY-16-de_20170517.wav",
95
  ],
96
  [
97
  "Arabic",
98
  "whisper-large-v2",
99
  "greedy_search",
100
- "<|startoftranscript|><|ar|><|transcribe|><notimestamps>",
101
  "./test_wavs/arabic/a.wav",
102
  ],
103
  # [
@@ -111,7 +111,7 @@ examples = [
111
  "French",
112
  "whisper-large-v2",
113
  "greedy_search",
114
- "<|startoftranscript|><|fr|><|transcribe|><notimestamps>",
115
  "./test_wavs/french/common_voice_fr_19364697.wav",
116
  ],
117
  # [
 
20
  "Chinese+English",
21
  "whisper-large-v2",
22
  "greedy_search",
23
+ "<|startoftranscript|><|zh|><|en|><|transcribe|><|notimestamps|>",
24
  "./test_wavs/tal_csasr/0.wav",
25
  ],
26
  [
27
  "Chinese",
28
  "whisper-large-v2",
29
  "greedy_search",
30
+ "<|startofprev|>热词:获刑<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>",
31
  "./test_wavs/mini_zh/mid.wav",
32
  ],
33
  [
34
  "Japanese",
35
  "whisper-large-v2",
36
  "greedy_search",
37
+ "<|startoftranscript|><|jp|><|transcribe|><|notimestamps|>",
38
  "./test_wavs/fleurs/7760285811293653093.wav",
39
  ],
40
  [
41
  "Korean",
42
  "whisper-large-v2",
43
  "greedy_search",
44
+ "<|startoftranscript|><|ko|><|translate|><|notimestamps|>",
45
  "./test_wavs/fleurs/15029788401146217023.wav",
46
  ],
47
  [
48
  "Korean",
49
  "whisper-large-v2",
50
  "greedy_search",
51
+ "<|startoftranscript|><|ko|><|transcribe|><|notimestamps|>",
52
  "./test_wavs/fleurs/15029788401146217023.wav",
53
  ],
54
  [
55
  "Japanese",
56
  "whisper-large-v2",
57
  "greedy_search",
58
+ "<|startoftranscript|><|en|><|translate|><|notimestamps|>",
59
  "./test_wavs/fleurs/7760285811293653093.wav",
60
  ],
61
  [
62
  "English",
63
  "whisper-large-v2",
64
  "greedy_search",
65
+ "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
66
  "./test_wavs/librispeech/1089-134686-0001.wav",
67
  ],
68
  # [
 
76
  "Russian",
77
  "whisper-large-v2",
78
  "greedy_search",
79
+ "<|startoftranscript|><|ru|><|transcribe|><|notimestamps|>",
80
  "./test_wavs/russian/russian-i-love-you.wav",
81
  ],
82
  # [
 
90
  "German",
91
  "whisper-large-v2",
92
  "greedy_search",
93
+ "<|startoftranscript|><|de|><|transcribe|><|notimestamps|>",
94
  "./test_wavs/german/20170517-0900-PLENARY-16-de_20170517.wav",
95
  ],
96
  [
97
  "Arabic",
98
  "whisper-large-v2",
99
  "greedy_search",
100
+ "<|startoftranscript|><|ar|><|transcribe|><|notimestamps|>",
101
  "./test_wavs/arabic/a.wav",
102
  ],
103
  # [
 
111
  "French",
112
  "whisper-large-v2",
113
  "greedy_search",
114
+ "<|startoftranscript|><|fr|><|transcribe|><|notimestamps|>",
115
  "./test_wavs/french/common_voice_fr_19364697.wav",
116
  ],
117
  # [