JanDalhuysen and smajumdar committed
Commit d76d8fd (0 parents)

Duplicate from smajumdar/nemo_multilingual_language_id

Co-authored-by: Somshubra Majumdar <smajumdar@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,33 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Nemo Multilingual Language Id
+ emoji: 🐠
+ colorFrom: blue
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 3.17.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: smajumdar/nemo_multilingual_language_id
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,641 @@
+ import os
+ import json
+ import shutil
+ import uuid
+ import tempfile
+ import subprocess
+ import re
+ import time
+ import traceback
+
+ import gradio as gr
+ import pytube as pt
+
+ import nemo.collections.asr as nemo_asr
+ import torch
+
+ import speech_to_text_buffered_infer_ctc as buffered_ctc
+ import speech_to_text_buffered_infer_rnnt as buffered_rnnt
+ from nemo.utils import logging
+
+ # Set NeMo cache dir as /tmp
+ from nemo import constants
+
+ os.environ[constants.NEMO_ENV_CACHE_DIR] = "/tmp/nemo/"
+
+
+ SAMPLE_RATE = 16000  # Default sample rate for ASR
+ BUFFERED_INFERENCE_DURATION_THRESHOLD = 60.0  # 60 seconds and above will require chunked inference.
+ CHUNK_LEN_IN_SEC = 20.0  # Chunk size
+ BUFFER_LEN_IN_SEC = 30.0  # Total buffer size
+
+ TITLE = "NeMo ASR Inference on Hugging Face"
+ DESCRIPTION = "Demo of all languages supported by NeMo ASR"
+ DEFAULT_EN_MODEL = "nvidia/stt_en_conformer_transducer_xlarge"
+ DEFAULT_BUFFERED_EN_MODEL = "nvidia/stt_en_conformer_transducer_large"
+
+ # Pre-download and cache the model in disk space
+ logging.setLevel(logging.ERROR)
+ tmp_model = nemo_asr.models.ASRModel.from_pretrained(DEFAULT_BUFFERED_EN_MODEL, map_location='cpu')
+ del tmp_model
+ logging.setLevel(logging.INFO)
+
+ MARKDOWN = f"""
+ # {TITLE}
+
+ ## {DESCRIPTION}
+ """
+
+ CSS = """
+ p.big {
+     font-size: 20px;
+ }
+
+ /* From https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition/blob/main/app.py */
+
+ .result {display:flex;flex-direction:column}
+ .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%;font-size:20px;}
+ .result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
+ .result_item_error {background-color:#ff7070;color:white;align-self:start}
+ """
+
+ ARTICLE = """
+ <br><br>
+ <p class='big' style='text-align: center'>
+ <a href='https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/intro.html' target='_blank'>NeMo ASR</a>
+ |
+ <a href='https://github.com/NVIDIA/NeMo#nvidia-nemo' target='_blank'>Github Repo</a>
+ </p>
+ """
+
+ SUPPORTED_LANGUAGES = set([])
+ SUPPORTED_MODEL_NAMES = set([])
+
+ # HF models, grouped by language identifier
+ hf_filter = nemo_asr.models.ASRModel.get_hf_model_filter()
+ hf_filter.task = "automatic-speech-recognition"
+
+ hf_infos = nemo_asr.models.ASRModel.search_huggingface_models(model_filter=hf_filter)
+ for info in hf_infos:
+     print("Model ID:", info.modelId)
+     try:
+         lang_id = info.modelId.split("_")[1]  # obtains lang id as str
+     except Exception:
+         print("WARNING: Skipping model id -", info)
+         continue
+
+     SUPPORTED_LANGUAGES.add(lang_id)
+     SUPPORTED_MODEL_NAMES.add(info.modelId)
+
+ SUPPORTED_MODEL_NAMES = sorted(list(SUPPORTED_MODEL_NAMES))
+
+ # DEBUG FILTER
+ # SUPPORTED_MODEL_NAMES = list(filter(lambda x: "en" in x and "conformer_transducer_large" in x, SUPPORTED_MODEL_NAMES))
+
+ model_dict = {}
+ for model_name in SUPPORTED_MODEL_NAMES:
+     try:
+         iface = gr.Interface.load(f'models/{model_name}')
+         model_dict[model_name] = iface
+
+         # model_dict[model_name] = None
+     except Exception:
+         pass
+
+ if DEFAULT_EN_MODEL in model_dict:
+     # Preemptively load the default EN model
+     if model_dict[DEFAULT_EN_MODEL] is None:
+         model_dict[DEFAULT_EN_MODEL] = gr.Interface.load(f'models/{DEFAULT_EN_MODEL}')
+
+ SUPPORTED_LANG_MODEL_DICT = {}
+ for lang in SUPPORTED_LANGUAGES:
+     for model_id in SUPPORTED_MODEL_NAMES:
+         if ("_" + lang + "_") in model_id:
+             # create new lang in dict
+             if lang not in SUPPORTED_LANG_MODEL_DICT:
+                 SUPPORTED_LANG_MODEL_DICT[lang] = [model_id]
+             else:
+                 SUPPORTED_LANG_MODEL_DICT[lang].append(model_id)
+
+ # Sort model names
+ for lang in SUPPORTED_LANG_MODEL_DICT.keys():
+     model_ids = SUPPORTED_LANG_MODEL_DICT[lang]
+     model_ids = sorted(model_ids)
+     SUPPORTED_LANG_MODEL_DICT[lang] = model_ids
+
+
+ def get_device():
+     gpu_available = torch.cuda.is_available()
+     if gpu_available:
+         return torch.cuda.get_device_name()
+     else:
+         return "CPU"
+
+
+ def parse_duration(audio_file):
+     """
+     Use ffmpeg to calculate durations. Libraries can do it too, but filetypes cause different libraries to behave differently.
+     """
+     process = subprocess.Popen(['ffmpeg', '-i', audio_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+     stdout, stderr = process.communicate()
+     matches = re.search(
+         r"Duration:\s{1}(?P<hours>\d+?):(?P<minutes>\d+?):(?P<seconds>\d+\.\d+?),", stdout.decode(), re.DOTALL
+     ).groupdict()
+
+     duration = 0.0
+     duration += float(matches['hours']) * 60.0 * 60.0
+     duration += float(matches['minutes']) * 60.0
+     duration += float(matches['seconds']) * 1.0
+     return duration
+
+
+ def resolve_model_type(model_name: str) -> str:
+     """
+     Map model name to a class type, without loading the model. Has some hardcoded assumptions in
+     semantics of model naming.
+     """
+     # Loss specific maps
+     if 'hybrid' in model_name or 'hybrid_ctc' in model_name or 'hybrid_transducer' in model_name:
+         return 'hybrid'
+     elif 'transducer' in model_name or 'rnnt' in model_name:
+         return 'transducer'
+     elif 'ctc' in model_name:
+         return 'ctc'
+
+     # Model specific maps
+     if 'jasper' in model_name:
+         return 'ctc'
+     elif 'quartznet' in model_name:
+         return 'ctc'
+     elif 'citrinet' in model_name:
+         return 'ctc'
+     elif 'contextnet' in model_name:
+         return 'transducer'
+
+     return None
+
+
+ def resolve_model_stride(model_name) -> int:
+     """
+     Model specific pre-calc of stride levels.
+     Don't load the model just to get such info.
+     """
+     if 'jasper' in model_name:
+         return 2
+     if 'quartznet' in model_name:
+         return 2
+     if 'conformer' in model_name:
+         return 4
+     if 'squeezeformer' in model_name:
+         return 4
+     if 'citrinet' in model_name:
+         return 8
+     if 'contextnet' in model_name:
+         return 8
+
+     return -1
+
+
+ def convert_audio(audio_filepath):
+     """
+     Transcode all mp3 files to monochannel 16 kHz wav files.
+     """
+     filedir = os.path.split(audio_filepath)[0]
+     filename, ext = os.path.splitext(audio_filepath)
+
+     if ext == '.wav':  # os.path.splitext keeps the leading dot
+         return audio_filepath
+
+     out_filename = os.path.join(filedir, filename + '.wav')
+
+     process = subprocess.Popen(
+         ['ffmpeg', '-y', '-i', audio_filepath, '-ac', '1', '-ar', str(SAMPLE_RATE), out_filename],
+         stdout=subprocess.PIPE,
+         stderr=subprocess.STDOUT,
+         close_fds=True,
+     )
+
+     stdout, stderr = process.communicate()
+
+     if os.path.exists(out_filename):
+         return out_filename
+     else:
+         return None
+
+
+ def extract_result_from_manifest(filepath, model_name) -> (bool, str):
+     """
+     Parse the written manifest which is the result of the buffered inference process.
+     """
+     data = []
+     with open(filepath, 'r', encoding='utf-8') as f:
+         for line in f:
+             try:
+                 line = json.loads(line)
+                 data.append(line['pred_text'])
+             except Exception as e:
+                 pass
+
+     if len(data) > 0:
+         return True, data[0]
+     else:
+         return False, f"Could not perform inference on model with name : {model_name}"
+
+
+ def build_html_output(s: str, style: str = "result_item_success"):
+     return f"""
+     <div class='result'>
+         <div class='result_item {style}'>
+             {s}
+         </div>
+     </div>
+     """
+
+
+ def infer_audio(model_name: str, audio_file: str) -> str:
+     """
+     Main method that switches from HF inference for small audio files to Buffered CTC/RNNT mode for long audio files.
+
+     Args:
+         model_name: Str name of the model (potentially with / to denote HF models)
+         audio_file: Path to an audio file (mp3 or wav)
+
+     Returns:
+         str which is the transcription if successful, or an HTML error log otherwise.
+     """
+     # Parse the duration of the audio file
+     duration = parse_duration(audio_file)
+
+     if duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:  # Longer than one minute; use buffered mode
+         # Process audio to be of wav type (possible youtube audio)
+         audio_file = convert_audio(audio_file)
+
+         # If audio file transcoding failed, let user know
+         if audio_file is None:
+             return "Error:- Failed to convert audio file to wav."
+
+         # Extract audio dir from resolved audio filepath
+         audio_dir = os.path.split(audio_file)[0]
+
+         # Next calculate the stride of each model
+         model_stride = resolve_model_stride(model_name)
+
+         if model_stride < 0:
+             return f"Error:- Failed to compute the model stride for model with name : {model_name}"
+
+         # Process model type (CTC/RNNT/Hybrid)
+         model_type = resolve_model_type(model_name)
+
+         if model_type is None:
+
+             # Model type could not be inferred.
+             # Try all feasible options
+             RESULT = None
+
+             try:
+                 ctc_config = buffered_ctc.TranscriptionConfig(
+                     pretrained_name=model_name,
+                     audio_dir=audio_dir,
+                     output_filename="output.json",
+                     audio_type="wav",
+                     overwrite_transcripts=True,
+                     model_stride=model_stride,
+                     chunk_len_in_secs=20.0,
+                     total_buffer_in_secs=30.0,
+                 )
+
+                 buffered_ctc.main(ctc_config)
+                 result = extract_result_from_manifest('output.json', model_name)
+                 if result[0]:
+                     RESULT = result[1]
+
+             except Exception as e:
+                 pass
+
+             try:
+                 rnnt_config = buffered_rnnt.TranscriptionConfig(
+                     pretrained_name=model_name,
+                     audio_dir=audio_dir,
+                     output_filename="output.json",
+                     audio_type="wav",
+                     overwrite_transcripts=True,
+                     model_stride=model_stride,
+                     chunk_len_in_secs=20.0,
+                     total_buffer_in_secs=30.0,
+                 )
+
+                 buffered_rnnt.main(rnnt_config)
+                 result = extract_result_from_manifest('output.json', model_name)
+
+                 if result[0]:
+                     RESULT = result[1]
+             except Exception as e:
+                 pass
+
+             if RESULT is None:
+                 return f"Error:- Could not parse model type; failed to perform inference with model {model_name}!"
+             return RESULT
+
+         elif model_type == 'ctc':
+
+             # CTC Buffered Inference
+             ctc_config = buffered_ctc.TranscriptionConfig(
+                 pretrained_name=model_name,
+                 audio_dir=audio_dir,
+                 output_filename="output.json",
+                 audio_type="wav",
+                 overwrite_transcripts=True,
+                 model_stride=model_stride,
+                 chunk_len_in_secs=20.0,
+                 total_buffer_in_secs=30.0,
+             )
+
+             buffered_ctc.main(ctc_config)
+             return extract_result_from_manifest('output.json', model_name)[-1]
+
+         elif model_type == 'transducer':
+
+             # RNNT Buffered Inference
+             rnnt_config = buffered_rnnt.TranscriptionConfig(
+                 pretrained_name=model_name,
+                 audio_dir=audio_dir,
+                 output_filename="output.json",
+                 audio_type="wav",
+                 overwrite_transcripts=True,
+                 model_stride=model_stride,
+                 chunk_len_in_secs=20.0,
+                 total_buffer_in_secs=30.0,
+             )
+
+             buffered_rnnt.main(rnnt_config)
+             return extract_result_from_manifest('output.json', model_name)[-1]
+
+         else:
+             return f"Error:- Could not parse model type; failed to perform inference with model {model_name}!"
+
+     else:
+         # Obtain Gradio Model function from cache of models
+         if model_name in model_dict:
+             model = model_dict[model_name]
+
+             if model is None:
+                 # Load the gradio interface
+                 # try:
+                 iface = gr.Interface.load(f'models/{model_name}')
+                 print(iface)
+                 # except:
+                 #     iface = None
+
+                 if iface is not None:
+                     # Update model cache and use the freshly loaded interface
+                     model = model_dict[model_name] = iface
+                 else:
+                     model = None
+
+             if model is not None:
+                 # Use HF API for transcription
+                 try:
+                     transcriptions = model(audio_file)
+                     return transcriptions
+                 except Exception as e:
+                     transcriptions = ""
+                     error = ""
+
+                     error += (
+                         f"The model `{model_name}` is currently loading and cannot be used "
+                         f"for transcription.<br>"
+                         f"Please try another model or wait a few minutes."
+                     )
+
+                     return error
+
+         else:
+             error = (
+                 f"Error:- Could not find model {model_name} in list of available models : "
+                 f"{list([k for k in model_dict.keys()])}"
+             )
+             return error
+
+
+ def transcribe(microphone, audio_file, model_name):
+
+     audio_data = None
+     warn_output = ""
+     if (microphone is not None) and (audio_file is not None):
+         warn_output = (
+             "WARNING: You've uploaded an audio file and used the microphone. "
+             "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
+         )
+         audio_data = microphone
+
+     elif (microphone is None) and (audio_file is None):
+         warn_output = "ERROR: You have to either use the microphone or upload an audio file"
+
+     elif microphone is not None:
+         audio_data = microphone
+     else:
+         audio_data = audio_file
+
+     if audio_data is not None:
+         audio_duration = parse_duration(audio_data)
+     else:
+         audio_duration = None
+
+     time_diff = None
+     try:
+         with tempfile.TemporaryDirectory() as tempdir:
+             filename = os.path.split(audio_data)[-1]
+             new_audio_data = os.path.join(tempdir, filename)
+             shutil.copy2(audio_data, new_audio_data)
+
+             if os.path.exists(audio_data):
+                 os.remove(audio_data)
+
+             audio_data = new_audio_data
+
+             # Use HF API for transcription
+             start = time.time()
+             transcriptions = infer_audio(model_name, audio_data)
+             end = time.time()
+             time_diff = end - start
+
+     except Exception as e:
+         transcriptions = ""
+
+         if warn_output != "":
+             warn_output += "<br><br>"
+
+         warn_output += (
+             f"The model `{model_name}` is currently loading and cannot be used "
+             f"for transcription.<br>"
+             f"Please try another model or wait a few minutes."
+         )
+
+     # Build HTML output
+     if warn_output != "":
+         html_output = build_html_output(warn_output, style="result_item_error")
+     else:
+         if transcriptions.startswith("Error:-"):
+             html_output = build_html_output(transcriptions, style="result_item_error")
+         else:
+             output = f"Successfully transcribed on {get_device()} ! <br>" f"Transcription Time : {time_diff: 0.3f} s"
+
+             if audio_duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:
+                 output += f""" <br><br>
+                 Note: Audio duration was {audio_duration: 0.3f} s, so the model had to be downloaded and initialized,
+                 and then buffered inference was used. <br>
+                 """
+
+             html_output = build_html_output(output)
+
+     return transcriptions, html_output
+
+
+ def _return_yt_html_embed(yt_url):
+     """ Obtained from https://huggingface.co/spaces/whisper-event/whisper-demo """
+     video_id = yt_url.split("?v=")[-1]
+     HTML_str = (
+         f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
+         " </center>"
+     )
+     return HTML_str
+
+
+ def yt_transcribe(yt_url: str, model_name: str):
+     """ Modified from https://huggingface.co/spaces/whisper-event/whisper-demo """
+     if yt_url == "":
+         text = ""
+         html_embed_str = ""
+         html_output = build_html_output(f"""
+         Error:- No YouTube URL was provided!
+         """, style='result_item_error')
+         return text, html_embed_str, html_output
+
+     yt = pt.YouTube(yt_url)
+     html_embed_str = _return_yt_html_embed(yt_url)
+
+     with tempfile.TemporaryDirectory() as tempdir:
+         file_uuid = str(uuid.uuid4().hex)
+         file_uuid = f"{tempdir}/{file_uuid}.mp3"
+
+         # Download YT Audio temporarily
+         download_time_start = time.time()
+
+         stream = yt.streams.filter(only_audio=True)[0]
+         stream.download(filename=file_uuid)
+
+         download_time_end = time.time()
+
+         # Get audio duration
+         audio_duration = parse_duration(file_uuid)
+
+         # Perform transcription
+         infer_time_start = time.time()
+
+         text = infer_audio(model_name, file_uuid)
+
+         infer_time_end = time.time()
+
+         if text.startswith("Error:-"):
+             html_output = build_html_output(text, style='result_item_error')
+         else:
+             html_output = f"""
+             Successfully transcribed on {get_device()} ! <br>
+             Audio Download Time : {download_time_end - download_time_start: 0.3f} s <br>
+             Transcription Time : {infer_time_end - infer_time_start: 0.3f} s <br>
+             """
+
+             if audio_duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:
+                 html_output += f""" <br>
+                 Note: Audio duration was {audio_duration: 0.3f} s, so the model had to be downloaded and initialized,
+                 and then buffered inference was used. <br>
+                 """
+
+             html_output = build_html_output(html_output)
+
+     return text, html_embed_str, html_output
+
+
+ def create_lang_selector_component(default_en_model=DEFAULT_EN_MODEL):
+     """
+     Utility function to select a language from a dropdown menu, and simultaneously update another dropdown
+     containing the corresponding model checkpoints for that language.
+
+     Args:
+         default_en_model: str name of a default English model that should be set as the default.
+
+     Returns:
+         Gradio components for lang_selector (Dropdown menu) and models_in_lang (Dropdown menu)
+     """
+     lang_selector = gr.components.Dropdown(
+         choices=sorted(list(SUPPORTED_LANGUAGES)), value="en", type="value", label="Languages", interactive=True,
+     )
+     models_in_lang = gr.components.Dropdown(
+         choices=sorted(list(SUPPORTED_LANG_MODEL_DICT["en"])),
+         value=default_en_model,
+         label="Models",
+         interactive=True,
+     )
+
+     def update_models_with_lang(lang):
+         models_names = sorted(list(SUPPORTED_LANG_MODEL_DICT[lang]))
+         default = models_names[0]
+
+         if lang == 'en':
+             default = default_en_model
+         return models_in_lang.update(choices=models_names, value=default)
+
+     lang_selector.change(update_models_with_lang, inputs=[lang_selector], outputs=[models_in_lang])
+
+     return lang_selector, models_in_lang
+
+
+ """
+ Define the GUI
+ """
+ demo = gr.Blocks(title=TITLE, css=CSS)
+
+ with demo:
+     header = gr.Markdown(MARKDOWN)
+
+     with gr.Tab("Transcribe Audio"):
+         with gr.Row() as row:
+             file_upload = gr.components.Audio(source="upload", type='filepath', label='Upload File')
+             microphone = gr.components.Audio(source="microphone", type='filepath', label='Microphone')
+
+         lang_selector, models_in_lang = create_lang_selector_component()
+
+         run = gr.components.Button('Transcribe')
+
+         transcript = gr.components.Label(label='Transcript')
+         audio_html_output = gr.components.HTML()
+
+         run.click(
+             transcribe, inputs=[microphone, file_upload, models_in_lang], outputs=[transcript, audio_html_output]
+         )
+
+     with gr.Tab("Transcribe Youtube"):
+         yt_url = gr.components.Textbox(
+             lines=1, label="Youtube URL", placeholder="Paste the URL to a YouTube video here"
+         )
+
+         lang_selector_yt, models_in_lang_yt = create_lang_selector_component(
+             default_en_model=DEFAULT_BUFFERED_EN_MODEL
+         )
+
+         with gr.Row():
+             run = gr.components.Button('Transcribe YouTube')
+             embedded_video = gr.components.HTML()
+
+         transcript = gr.components.Label(label='Transcript')
+         yt_html_output = gr.components.HTML()
+
+         run.click(
+             yt_transcribe, inputs=[yt_url, models_in_lang_yt], outputs=[transcript, embedded_video, yt_html_output]
+         )
+
+     gr.components.HTML(ARTICLE)
+
+ demo.queue(concurrency_count=1)
+ demo.launch(enable_queue=True)
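
Note on the routing above: infer_audio sends clips up to BUFFERED_INFERENCE_DURATION_THRESHOLD (60 s) to the hosted Hugging Face inference widget for the selected checkpoint, while anything longer is transcoded to 16 kHz mono wav and handed to the buffered scripts added below. The long-audio path can be reproduced outside Gradio with a minimal sketch like the following (the checkpoint name, the long_audio_dir folder and the output path are placeholders, not part of this commit):

# Sketch of the buffered (long-audio) path that infer_audio() takes for a CTC model.
# Assumes long_audio_dir contains one 16 kHz mono wav file; any CTC checkpoint can be substituted.
import json

import speech_to_text_buffered_infer_ctc as buffered_ctc

cfg = buffered_ctc.TranscriptionConfig(
    pretrained_name="nvidia/stt_en_conformer_ctc_large",  # placeholder checkpoint
    audio_dir="long_audio_dir",                           # placeholder folder of converted wav files
    output_filename="output.json",
    audio_type="wav",
    overwrite_transcripts=True,
    model_stride=4,              # 4 for Conformer, 8 for Citrinet (see resolve_model_stride above)
    chunk_len_in_secs=20.0,      # same values app.py hard-codes (CHUNK_LEN_IN_SEC)
    total_buffer_in_secs=30.0,   # BUFFER_LEN_IN_SEC
)
buffered_ctc.main(cfg)

# The script writes one JSON line per audio file; 'pred_text' holds the transcript,
# which is what extract_result_from_manifest() reads back.
with open("output.json", "r", encoding="utf-8") as f:
    print(json.loads(f.readline())["pred_text"])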
packages.txt ADDED
@@ -0,0 +1,2 @@
+ ffmpeg
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ nemo_toolkit[all]
+ pytube
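
These two pinned packages are all app.py needs at runtime (ffmpeg and libsndfile1 come from packages.txt). As a quick environment sanity check, a hedged, minimal offline transcription sketch (the checkpoint name and sample.wav are placeholders) would be:

# Load one NeMo ASR checkpoint and transcribe a short local file; not part of the Space itself.
import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/stt_en_conformer_ctc_large")
transcripts = asr_model.transcribe(["sample.wav"], batch_size=1)  # expects a 16 kHz mono wav
print(transcripts[0])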
speech_to_text_buffered_infer_ctc.py ADDED
@@ -0,0 +1,193 @@
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ This script serves three goals:
+     (1) Demonstrate how to use NeMo Models outside of PyTorch Lightning
+     (2) Shows example of batch ASR inference
+     (3) Serves as CI test for pre-trained checkpoint
+
+ python speech_to_text_buffered_infer_ctc.py \
+     model_path=null \
+     pretrained_name=null \
+     audio_dir="<remove or path to folder of audio files>" \
+     dataset_manifest="<remove or path to manifest>" \
+     output_filename="<remove or specify output filename>" \
+     total_buffer_in_secs=4.0 \
+     chunk_len_in_secs=1.6 \
+     model_stride=4 \
+     batch_size=32
+
+ # NOTE:
+     You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
+     predictions of the model, and ground-truth text if present in the manifest.
+ """
+ import contextlib
+ import copy
+ import glob
+ import math
+ import os
+ from dataclasses import dataclass, is_dataclass
+ from typing import Optional
+
+ import torch
+ from omegaconf import OmegaConf
+
+ from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
+ from nemo.collections.asr.parts.utils.transcribe_utils import (
+     compute_output_filename,
+     get_buffered_pred_feat,
+     setup_model,
+     write_transcription,
+ )
+ from nemo.core.config import hydra_runner
+ from nemo.utils import logging
+
+ can_gpu = torch.cuda.is_available()
+
+
+ @dataclass
+ class TranscriptionConfig:
+     # Required configs
+     model_path: Optional[str] = None  # Path to a .nemo file
+     pretrained_name: Optional[str] = None  # Name of a pretrained model
+     audio_dir: Optional[str] = None  # Path to a directory which contains audio files
+     dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest
+
+     # General configs
+     output_filename: Optional[str] = None
+     batch_size: int = 32
+     num_workers: int = 0
+     append_pred: bool = False  # Sets mode of work, if True it will add new field transcriptions.
+     pred_name_postfix: Optional[str] = None  # If you need to use another model name, rather than standard one.
+
+     # Chunked configs
+     chunk_len_in_secs: float = 1.6  # Chunk length in seconds
+     total_buffer_in_secs: float = 4.0  # Length of buffer (chunk + left and right padding) in seconds
+     model_stride: int = 8  # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
+
+     # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
+     # device anyway, and do inference on CPU only if CUDA device is not found.
+     # If `cuda` is a negative number, inference will be on CPU only.
+     cuda: Optional[int] = None
+     amp: bool = False
+     audio_type: str = "wav"
+
+     # Recompute model transcription, even if the output folder exists with scores.
+     overwrite_transcripts: bool = True
+
+
+ @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
+ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
+     logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
+     torch.set_grad_enabled(False)
+
+     if is_dataclass(cfg):
+         cfg = OmegaConf.structured(cfg)
+
+     if cfg.model_path is None and cfg.pretrained_name is None:
+         raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
+     if cfg.audio_dir is None and cfg.dataset_manifest is None:
+         raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
+
+     filepaths = None
+     manifest = cfg.dataset_manifest
+     if cfg.audio_dir is not None:
+         filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
+         manifest = None  # ignore dataset_manifest if audio_dir and dataset_manifest are both present
+
+     # setup GPU
+     if cfg.cuda is None:
+         if torch.cuda.is_available():
+             device = [0]  # use 0th CUDA device
+             accelerator = 'gpu'
+         else:
+             device = 1
+             accelerator = 'cpu'
+     else:
+         device = [cfg.cuda]
+         accelerator = 'gpu'
+     map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
+     logging.info(f"Inference will be done on device : {device}")
+
+     asr_model, model_name = setup_model(cfg, map_location)
+
+     model_cfg = copy.deepcopy(asr_model._cfg)
+     OmegaConf.set_struct(model_cfg.preprocessor, False)
+     # some changes for streaming scenario
+     model_cfg.preprocessor.dither = 0.0
+     model_cfg.preprocessor.pad_to = 0
+
+     if model_cfg.preprocessor.normalize != "per_feature":
+         logging.error("Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently")
+
+     # Disable config overwriting
+     OmegaConf.set_struct(model_cfg.preprocessor, True)
+
+     # setup AMP (optional)
+     if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
+         logging.info("AMP enabled!\n")
+         autocast = torch.cuda.amp.autocast
+     else:
+
+         @contextlib.contextmanager
+         def autocast():
+             yield
+
+     # Compute output filename
+     cfg = compute_output_filename(cfg, model_name)
+
+     # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
+     if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
+         logging.info(
+             f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
+             f" is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
+         )
+         return cfg
+
+     asr_model.eval()
+     asr_model = asr_model.to(asr_model.device)
+
+     feature_stride = model_cfg.preprocessor['window_stride']
+     model_stride_in_secs = feature_stride * cfg.model_stride
+     total_buffer = cfg.total_buffer_in_secs
+     chunk_len = float(cfg.chunk_len_in_secs)
+
+     tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
+     mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
+     logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")
+
+     frame_asr = FrameBatchASR(
+         asr_model=asr_model, frame_len=chunk_len, total_buffer=cfg.total_buffer_in_secs, batch_size=cfg.batch_size,
+     )
+
+     hyps = get_buffered_pred_feat(
+         frame_asr,
+         chunk_len,
+         tokens_per_chunk,
+         mid_delay,
+         model_cfg.preprocessor,
+         model_stride_in_secs,
+         asr_model.device,
+         manifest,
+         filepaths,
+     )
+     output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
+     logging.info(f"Finished writing predictions to {output_filename}!")
+
+     return cfg
+
+
+ if __name__ == '__main__':
+     main()  # noqa pylint: disable=no-value-for-parameter
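
To make the buffering arithmetic in main() concrete: with the defaults chunk_len_in_secs=1.6, total_buffer_in_secs=4.0, model_stride=4 and a typical Conformer preprocessor window_stride of 0.01 s (the 0.01 value is an assumption here; the script reads the real value from the model config), the derived quantities work out as in this small sketch:

import math

window_stride = 0.01   # assumed seconds per feature frame (model_cfg.preprocessor['window_stride'] in the script)
model_stride = 4       # Conformer downsampling factor
chunk_len = 1.6        # chunk_len_in_secs
total_buffer = 4.0     # total_buffer_in_secs

model_stride_in_secs = window_stride * model_stride  # 0.04 s of audio per output token
tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)  # 40 tokens kept per chunk
mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)  # 70 tokens of delay
print(tokens_per_chunk, mid_delay)  # 40 70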
speech_to_text_buffered_infer_rnnt.py ADDED
@@ -0,0 +1,247 @@
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Script to perform buffered inference using RNNT models.
+
+ Buffered inference is the primary form of audio transcription when the audio segment is longer than 20-30 seconds.
+ This is especially useful for models such as Conformers, which have quadratic time and memory scaling with
+ audio duration.
+
+ The difference between streaming and buffered inference is the chunk size (or the latency of inference).
+ Buffered inference will use large chunk sizes (5-10 seconds) + some additional buffer for context.
+ Streaming inference will use small chunk sizes (0.1 to 0.25 seconds) + some additional buffer for context.
+
+ # Middle Token merge algorithm
+
+ python speech_to_text_buffered_infer_rnnt.py \
+     model_path=null \
+     pretrained_name=null \
+     audio_dir="<remove or path to folder of audio files>" \
+     dataset_manifest="<remove or path to manifest>" \
+     output_filename="<remove or specify output filename>" \
+     total_buffer_in_secs=4.0 \
+     chunk_len_in_secs=1.6 \
+     model_stride=4 \
+     batch_size=32
+
+ # Longest Common Subsequence (LCS) Merge algorithm
+
+ python speech_to_text_buffered_infer_rnnt.py \
+     model_path=null \
+     pretrained_name=null \
+     audio_dir="<remove or path to folder of audio files>" \
+     dataset_manifest="<remove or path to manifest>" \
+     output_filename="<remove or specify output filename>" \
+     total_buffer_in_secs=4.0 \
+     chunk_len_in_secs=1.6 \
+     model_stride=4 \
+     batch_size=32 \
+     merge_algo="lcs" \
+     lcs_alignment_dir=<OPTIONAL: Some path to store the LCS alignments>
+
+ # NOTE:
+     You can use `DEBUG=1 python speech_to_text_buffered_infer_rnnt.py ...` to print out the
+     predictions of the model, and ground-truth text if present in the manifest.
+ """
+ import copy
+ import glob
+ import math
+ import os
+ from dataclasses import dataclass, is_dataclass
+ from typing import Optional
+
+ import torch
+ from omegaconf import OmegaConf, open_dict
+
+ from nemo.collections.asr.parts.utils.streaming_utils import (
+     BatchedFrameASRRNNT,
+     LongestCommonSubsequenceBatchedFrameASRRNNT,
+ )
+ from nemo.collections.asr.parts.utils.transcribe_utils import (
+     compute_output_filename,
+     get_buffered_pred_feat_rnnt,
+     setup_model,
+     write_transcription,
+ )
+ from nemo.core.config import hydra_runner
+ from nemo.utils import logging
+
+ can_gpu = torch.cuda.is_available()
+
+
+ @dataclass
+ class TranscriptionConfig:
+     # Required configs
+     model_path: Optional[str] = None  # Path to a .nemo file
+     pretrained_name: Optional[str] = None  # Name of a pretrained model
+     audio_dir: Optional[str] = None  # Path to a directory which contains audio files
+     dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest
+
+     # General configs
+     output_filename: Optional[str] = None
+     batch_size: int = 32
+     num_workers: int = 0
+     append_pred: bool = False  # Sets mode of work, if True it will add new field transcriptions.
+     pred_name_postfix: Optional[str] = None  # If you need to use another model name, rather than standard one.
+
+     # Chunked configs
+     chunk_len_in_secs: float = 1.6  # Chunk length in seconds
+     total_buffer_in_secs: float = 4.0  # Length of buffer (chunk + left and right padding) in seconds
+     model_stride: int = 8  # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
+
+     # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
+     # device anyway, and do inference on CPU only if CUDA device is not found.
+     # If `cuda` is a negative number, inference will be on CPU only.
+     cuda: Optional[int] = None
+     audio_type: str = "wav"
+
+     # Recompute model transcription, even if the output folder exists with scores.
+     overwrite_transcripts: bool = True
+
+     # Decoding configs
+     max_steps_per_timestep: int = 5  # Maximum number of tokens decoded per acoustic timestep
+     stateful_decoding: bool = False  # Whether to perform stateful decoding
+
+     # Merge algorithm for transducers
+     merge_algo: Optional[str] = 'middle'  # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
+     lcs_alignment_dir: Optional[str] = None  # Path to a directory to store LCS algo alignments
+
+
+ @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
+ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
+     logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
+     torch.set_grad_enabled(False)
+
+     if is_dataclass(cfg):
+         cfg = OmegaConf.structured(cfg)
+
+     if cfg.model_path is None and cfg.pretrained_name is None:
+         raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
+     if cfg.audio_dir is None and cfg.dataset_manifest is None:
+         raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
+
+     filepaths = None
+     manifest = cfg.dataset_manifest
+     if cfg.audio_dir is not None:
+         filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
+         manifest = None  # ignore dataset_manifest if audio_dir and dataset_manifest are both present
+
+     # setup GPU
+     if cfg.cuda is None:
+         if torch.cuda.is_available():
+             device = [0]  # use 0th CUDA device
+             accelerator = 'gpu'
+         else:
+             device = 1
+             accelerator = 'cpu'
+     else:
+         device = [cfg.cuda]
+         accelerator = 'gpu'
+     map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
+     logging.info(f"Inference will be done on device : {device}")
+
+     asr_model, model_name = setup_model(cfg, map_location)
+
+     model_cfg = copy.deepcopy(asr_model._cfg)
+     OmegaConf.set_struct(model_cfg.preprocessor, False)
+     # some changes for streaming scenario
+     model_cfg.preprocessor.dither = 0.0
+     model_cfg.preprocessor.pad_to = 0
+
+     if model_cfg.preprocessor.normalize != "per_feature":
+         logging.error("Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently")
+
+     # Disable config overwriting
+     OmegaConf.set_struct(model_cfg.preprocessor, True)
+
+     # Compute output filename
+     cfg = compute_output_filename(cfg, model_name)
+
+     # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
+     if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
+         logging.info(
+             f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
+             f" is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
+         )
+         return cfg
+
+     asr_model.freeze()
+     asr_model = asr_model.to(asr_model.device)
+
+     # Change Decoding Config
+     decoding_cfg = asr_model.cfg.decoding
+     with open_dict(decoding_cfg):
+         if cfg.stateful_decoding:
+             decoding_cfg.strategy = "greedy"
+         else:
+             decoding_cfg.strategy = "greedy_batch"
+         decoding_cfg.preserve_alignments = True  # required to compute the middle token for transducers.
+         decoding_cfg.fused_batch_size = -1  # temporarily stop fused batch during inference.
+
+     asr_model.change_decoding_strategy(decoding_cfg)
+
+     feature_stride = model_cfg.preprocessor['window_stride']
+     model_stride_in_secs = feature_stride * cfg.model_stride
+     total_buffer = cfg.total_buffer_in_secs
+     chunk_len = float(cfg.chunk_len_in_secs)
+
+     tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
+     mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
+     logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")
+
+     if cfg.merge_algo == 'middle':
+         frame_asr = BatchedFrameASRRNNT(
+             asr_model=asr_model,
+             frame_len=chunk_len,
+             total_buffer=cfg.total_buffer_in_secs,
+             batch_size=cfg.batch_size,
+             max_steps_per_timestep=cfg.max_steps_per_timestep,
+             stateful_decoding=cfg.stateful_decoding,
+         )
+
+     elif cfg.merge_algo == 'lcs':
+         frame_asr = LongestCommonSubsequenceBatchedFrameASRRNNT(
+             asr_model=asr_model,
+             frame_len=chunk_len,
+             total_buffer=cfg.total_buffer_in_secs,
+             batch_size=cfg.batch_size,
+             max_steps_per_timestep=cfg.max_steps_per_timestep,
+             stateful_decoding=cfg.stateful_decoding,
+             alignment_basepath=cfg.lcs_alignment_dir,
+         )
+         # Set the LCS algorithm delay.
+         frame_asr.lcs_delay = math.floor((total_buffer - chunk_len) / model_stride_in_secs)
+
+     else:
+         raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")
+
+     hyps = get_buffered_pred_feat_rnnt(
+         asr=frame_asr,
+         tokens_per_chunk=tokens_per_chunk,
+         delay=mid_delay,
+         model_stride_in_secs=model_stride_in_secs,
+         batch_size=cfg.batch_size,
+         manifest=manifest,
+         filepaths=filepaths,
+     )
+
+     output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
+     logging.info(f"Finished writing predictions to {output_filename}!")
+
+     return cfg
+
+
+ if __name__ == '__main__':
+     main()  # noqa pylint: disable=no-value-for-parameter
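
app.py always drives this script with the default merge_algo='middle'; the LCS merge is only reachable when the config is built explicitly. A hedged sketch of a programmatic LCS run, mirroring how app.py constructs the RNNT config (the checkpoint name, audio folder and alignment directory are placeholders):

# Programmatic LCS-merge run, following the same pattern app.py uses for buffered RNNT inference.
import speech_to_text_buffered_infer_rnnt as buffered_rnnt

cfg = buffered_rnnt.TranscriptionConfig(
    pretrained_name="nvidia/stt_en_conformer_transducer_large",  # placeholder transducer checkpoint
    audio_dir="long_audio_dir",                                  # placeholder folder of 16 kHz mono wavs
    output_filename="output_rnnt.json",
    audio_type="wav",
    model_stride=4,
    chunk_len_in_secs=20.0,
    total_buffer_in_secs=30.0,
    merge_algo="lcs",                         # switch from the default 'middle' token merge to LCS
    lcs_alignment_dir="/tmp/lcs_alignments",  # optional; where LCS alignments are written for inspection
)
buffered_rnnt.main(cfg)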