smajumdar committed
Commit
d40d29c
1 Parent(s): 39cf8cc

Add support for YT transcription

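For context, the new YouTube tab added below downloads the audio stream with pytube and routes it through the same `infer_audio` helper used for file uploads and the microphone. A minimal sketch of that flow outside of Gradio, assuming the helper names from the app.py diff below (the URL is a placeholder):

```python
# Minimal sketch, not part of this commit: exercises the new YouTube path
# via the helpers defined in app.py below. The URL is a placeholder.
import tempfile
import uuid

import pytube as pt

from app import infer_audio  # assumes app.py is importable in this environment


def transcribe_youtube(yt_url: str, model_name: str) -> str:
    yt = pt.YouTube(yt_url)
    with tempfile.TemporaryDirectory() as tempdir:
        # Download the audio-only stream to a temporary mp3 file
        audio_path = f"{tempdir}/{uuid.uuid4().hex}.mp3"
        yt.streams.filter(only_audio=True)[0].download(filename=audio_path)
        # Short clips go through the HF inference API; audio longer than
        # 60 seconds is handled by the buffered CTC/RNNT scripts added below.
        return infer_audio(model_name, audio_path)
```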
app.py CHANGED
@@ -1,7 +1,21 @@
1
  import gradio as gr
2
- import torch
3
 
4
  import nemo.collections.asr as nemo_asr
5
 
6
  SAMPLE_RATE = 16000
7
  TITLE = "NeMo ASR Inference on Hugging Face"
@@ -32,7 +46,7 @@ ARTICLE = """
32
  SUPPORTED_LANGUAGES = set([])
33
  SUPPORTED_MODEL_NAMES = set([])
34
 
35
- # HF models
36
  hf_filter = nemo_asr.models.ASRModel.get_hf_model_filter()
37
  hf_filter.task = "automatic-speech-recognition"
38
 
@@ -44,6 +58,8 @@ for info in hf_infos:
44
 
45
  SUPPORTED_MODEL_NAMES = sorted(list(SUPPORTED_MODEL_NAMES))
46
 
47
  model_dict = {model_name: gr.Interface.load(f'models/{model_name}') for model_name in SUPPORTED_MODEL_NAMES}
48
 
49
  SUPPORTED_LANG_MODEL_DICT = {}
@@ -63,8 +79,253 @@ for lang in SUPPORTED_LANG_MODEL_DICT.keys():
63
  SUPPORTED_LANG_MODEL_DICT[lang] = model_ids
64
 
65
66
  def transcribe(microphone, audio_file, model_name):
67
- model = model_dict[model_name]
68
 
69
  warn_output = ""
70
  if (microphone is not None) and (audio_file is not None):
@@ -84,7 +345,7 @@ def transcribe(microphone, audio_file, model_name):
84
 
85
  try:
86
  # Use HF API for transcription
87
- transcriptions = model(audio_data)
88
 
89
  except Exception as e:
90
  transcriptions = ""
@@ -98,21 +359,38 @@ def transcribe(microphone, audio_file, model_name):
98
  return warn_output + transcriptions
99
 
100
 
101
- demo = gr.Blocks(title=TITLE, css=CSS)
102
 
103
- with demo:
104
- header = gr.Markdown(MARKDOWN)
105
 
106
- with gr.Row() as row:
107
- file_upload = gr.components.Audio(source="upload", type='filepath', label='Upload File')
108
- microphone = gr.components.Audio(source="microphone", type='filepath', label='Microphone')
109
 
110
  lang_selector = gr.components.Dropdown(
111
  choices=sorted(list(SUPPORTED_LANGUAGES)), value="en", type="value", label="Languages", interactive=True,
112
  )
113
  models_in_lang = gr.components.Dropdown(
114
  choices=sorted(list(SUPPORTED_LANG_MODEL_DICT["en"])),
115
- value=DEFAULT_EN_MODEL,
116
  label="Models",
117
  interactive=True,
118
  )
@@ -122,17 +400,47 @@ with demo:
122
  default = models_names[0]
123
 
124
  if lang == 'en':
125
- default = DEFAULT_EN_MODEL
126
  return models_in_lang.update(choices=models_names, value=default)
127
 
128
  lang_selector.change(update_models_with_lang, inputs=[lang_selector], outputs=[models_in_lang])
129
 
130
- transcript = gr.components.Label(label='Transcript')
131
 
132
- run = gr.components.Button('Transcribe')
133
- run.click(transcribe, inputs=[microphone, file_upload, models_in_lang], outputs=[transcript])
134
 
135
  gr.components.HTML(ARTICLE)
136
 
137
  demo.queue(concurrency_count=1)
138
- demo.launch()
 
1
+ import os
2
+ import json
3
+ import uuid
4
+ import tempfile
5
+ import subprocess
6
+ import re
7
+
8
  import gradio as gr
9
+ import pytube as pt
10
 
11
  import nemo.collections.asr as nemo_asr
12
+ import speech_to_text_buffered_infer_ctc as buffered_ctc
13
+ import speech_to_text_buffered_infer_rnnt as buffered_rnnt
14
+
15
+ # Set NeMo cache dir as /tmp
16
+ from nemo import constants
17
+ os.environ[constants.NEMO_ENV_CACHE_DIR] = "/tmp/nemo"
18
+
19
 
20
  SAMPLE_RATE = 16000
21
  TITLE = "NeMo ASR Inference on Hugging Face"
 
46
  SUPPORTED_LANGUAGES = set([])
47
  SUPPORTED_MODEL_NAMES = set([])
48
 
49
+ # HF models, grouped by language identifier
50
  hf_filter = nemo_asr.models.ASRModel.get_hf_model_filter()
51
  hf_filter.task = "automatic-speech-recognition"
52
 
 
58
 
59
  SUPPORTED_MODEL_NAMES = sorted(list(SUPPORTED_MODEL_NAMES))
60
 
61
+ SUPPORTED_MODEL_NAMES = list(filter(lambda x: 'en' in x and 'conformer_transducer_large' in x, SUPPORTED_MODEL_NAMES))
62
+
63
  model_dict = {model_name: gr.Interface.load(f'models/{model_name}') for model_name in SUPPORTED_MODEL_NAMES}
64
 
65
  SUPPORTED_LANG_MODEL_DICT = {}
 
79
  SUPPORTED_LANG_MODEL_DICT[lang] = model_ids
80
 
81
 
82
+ def parse_duration(audio_file):
83
+ """
84
+ FFMPEG to calculate durations. Libraries can do it too, but filetypes cause different libraries to behave differently.
85
+ """
86
+ process = subprocess.Popen(['ffmpeg', '-i', audio_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
87
+ stdout, stderr = process.communicate()
88
+ matches = re.search(
89
+ r"Duration:\s{1}(?P<hours>\d+?):(?P<minutes>\d+?):(?P<seconds>\d+\.\d+?),", stdout.decode(), re.DOTALL
90
+ ).groupdict()
91
+
92
+ duration = 0.0
93
+ duration += float(matches['hours']) * 60.0 * 60.0
94
+ duration += float(matches['minutes']) * 60.0
95
+ duration += float(matches['seconds']) * 1.0
96
+ return duration
97
+
98
+
99
+ def resolve_model_type(model_name: str) -> str:
100
+ """
101
+ Map model name to a class type, without loading the model. Has some hardcoded assumptions in
102
+ semantics of model naming.
103
+ """
104
+ # Loss specific maps
105
+ if 'hybrid' in model_name or 'hybrid_ctc' in model_name or 'hybrid_transducer' in model_name:
106
+ return 'hybrid'
107
+ elif 'transducer' in model_name or 'rnnt' in model_name:
108
+ return 'transducer'
109
+ elif 'ctc' in model_name:
110
+ return 'ctc'
111
+
112
+ # Model specific maps
113
+ elif 'jasper' in model_name:
114
+ return 'ctc'
115
+ elif 'quartznet' in model_name:
116
+ return 'ctc'
117
+ elif 'citrinet' in model_name:
118
+ return 'ctc'
119
+ elif 'contextnet' in model_name:
120
+ return 'ctc'
121
+ else:
122
+ # Unknown model type
123
+ return None
124
+
125
+
126
+ def resolve_model_stride(model_name) -> int:
127
+ """
128
+ Model specific pre-calc of stride levels.
129
+ Don't load the model just to get this info.
130
+ """
131
+ if 'jasper' in model_name:
132
+ return 2
133
+ if 'quartznet' in model_name:
134
+ return 2
135
+ if 'conformer' in model_name:
136
+ return 4
137
+ if 'squeezeformer' in model_name:
138
+ return 4
139
+ if 'citrinet' in model_name:
140
+ return 8
141
+ if 'contextnet' in model_name:
142
+ return 8
143
+
144
+ return -1
145
+
146
+
147
+ def convert_audio(audio_filepath):
148
+ """
149
+ Transcode all mp3 files to monochannel 16 kHz wav files.
150
+ """
151
+ filedir = os.path.split(audio_filepath)[0]
152
+ filename, ext = os.path.splitext(audio_filepath)
153
+
154
+ if ext == '.wav':
155
+ return audio_filepath
156
+
157
+ out_filename = os.path.join(filedir, filename + '.wav')
158
+ process = subprocess.Popen(
159
+ ['ffmpeg', '-i', audio_filepath, '-ac', '1', '-ar', str(SAMPLE_RATE), out_filename],
160
+ stdout=subprocess.PIPE,
161
+ stderr=subprocess.STDOUT,
162
+ )
163
+ stdout, stderr = process.communicate()
164
+
165
+ if os.path.exists(out_filename):
166
+ return out_filename
167
+ else:
168
+ return None
169
+
170
+
171
+ def extract_result_from_manifest(filepath, model_name) -> (bool, str):
172
+ """
173
+ Parse the manifest written by the buffered inference process.
174
+ """
175
+ data = []
176
+ with open(filepath, 'r', encoding='utf-8') as f:
177
+ for line in f:
178
+ try:
179
+ line = json.loads(line)
180
+ data.append(line['pred_text'])
181
+ except Exception as e:
182
+ pass
183
+
184
+ if len(data) > 0:
185
+ return True, data[0]
186
+ else:
187
+ return False, f"Could not perform inference on model with name : {model_name}"
188
+
189
+
190
+ def infer_audio(model_name: str, audio_file: str) -> str:
191
+ """
192
+ Main method that switches from HF inference for small audio files to Buffered CTC/RNNT mode for long audio files.
193
+
194
+ Args:
195
+ model_name: Str name of the model (potentially with / to denote HF models)
196
+ audio_file: Path to an audio file (mp3 or wav)
197
+
198
+ Returns:
199
+ str which is the transcription if successful.
200
+ """
201
+ # Parse the duration of the audio file
202
+ duration = parse_duration(audio_file)
203
+
204
+ if duration > 60.0: # Longer than one minute; use buffered mode
205
+ # Process audio to be of wav type (possible youtube audio)
206
+ audio_file = convert_audio(audio_file)
207
+
208
+ # If audio file transcoding failed, let user know
209
+ if audio_file is None:
210
+ return "Failed to convert audio file to wav."
211
+
212
+ # Extract audio dir from resolved audio filepath
213
+ audio_dir = os.path.split(audio_file)[0]
214
+
215
+ # Next calculate the stride of each model
216
+ model_stride = resolve_model_stride(model_name)
217
+
218
+ if model_stride < 0:
219
+ return f"Failed to compute the model stride for model with name : {model_name}"
220
+
221
+ # Process model type (CTC/RNNT/Hybrid)
222
+ model_type = resolve_model_type(model_name)
223
+
224
+ if model_type is None:
225
+
226
+ # Model type could not be inferred.
227
+ # Try all feasible options
228
+ RESULT = None
229
+
230
+ try:
231
+ ctc_config = buffered_ctc.TranscriptionConfig(
232
+ pretrained_name=model_name,
233
+ audio_dir=audio_dir,
234
+ output_filename="output.json",
235
+ audio_type="wav",
236
+ overwrite_transcripts=True,
237
+ model_stride=model_stride,
238
+ chunk_len_in_secs=20.0,
239
+ total_buffer_in_secs=30.0,
240
+ )
241
+
242
+ buffered_ctc.main(ctc_config)
243
+ result = extract_result_from_manifest('output.json', model_name)
244
+ if result[0]:
245
+ RESULT = result[1]
246
+
247
+ except Exception as e:
248
+ pass
249
+
250
+ try:
251
+ rnnt_config = buffered_rnnt.TranscriptionConfig(
252
+ pretrained_name=model_name,
253
+ audio_dir=audio_dir,
254
+ output_filename="output.json",
255
+ audio_type="wav",
256
+ overwrite_transcripts=True,
257
+ model_stride=model_stride,
258
+ chunk_len_in_secs=20.0,
259
+ total_buffer_in_secs=30.0,
260
+ )
261
+
262
+ buffered_rnnt.main(rnnt_config)
263
+ result = extract_result_from_manifest('output.json', model_name)
264
+
265
+ if result[0]:
266
+ RESULT = result[1]
267
+ except Exception as e:
268
+ pass
269
+
270
+ if RESULT is None:
271
+ return f"Could not parse model type; failed to perform inference with model {model_name}!"
+
+ return RESULT
272
+
273
+ elif model_type == 'ctc':
274
+
275
+ # CTC Buffered Inference
276
+ ctc_config = buffered_ctc.TranscriptionConfig(
277
+ pretrained_name=model_name,
278
+ audio_dir=audio_dir,
279
+ output_filename="output.json",
280
+ audio_type="wav",
281
+ overwrite_transcripts=True,
282
+ model_stride=model_stride,
283
+ chunk_len_in_secs=20.0,
284
+ total_buffer_in_secs=30.0,
285
+ )
286
+
287
+ buffered_ctc.main(ctc_config)
288
+ return extract_result_from_manifest('output.json', model_name)[-1]
289
+
290
+ elif model_type == 'transducer':
291
+
292
+ # RNNT Buffered Inference
293
+ rnnt_config = buffered_rnnt.TranscriptionConfig(
294
+ pretrained_name=model_name,
295
+ audio_dir=audio_dir,
296
+ output_filename="output.json",
297
+ audio_type="wav",
298
+ overwrite_transcripts=True,
299
+ model_stride=model_stride,
300
+ chunk_len_in_secs=20.0,
301
+ total_buffer_in_secs=30.0,
302
+ )
303
+
304
+ buffered_rnnt.main(rnnt_config)
305
+ return extract_result_from_manifest('output.json', model_name)[-1]
306
+
307
+ else:
308
+ return f"Could not parse model type; failed to perform inference with model {model_name}!"
309
+
310
+ else:
311
+ if model_name in model_dict:
312
+ model = model_dict[model_name]
313
+ else:
314
+ model = None
315
+
316
+ if model is not None:
317
+ # Use HF API for transcription
318
+ transcriptions = model(audio_file)
319
+ return transcriptions
320
+ else:
321
+ error = (
322
+ f"Could not find model {model_name} in list of available models : "
323
+ f"{list([k for k in model_dict.keys()])}"
324
+ )
325
+ return error
326
+
327
+
328
  def transcribe(microphone, audio_file, model_name):
 
329
 
330
  warn_output = ""
331
  if (microphone is not None) and (audio_file is not None):
 
345
 
346
  try:
347
  # Use HF API for transcription
348
+ transcriptions = infer_audio(model_name, audio_data)
349
 
350
  except Exception as e:
351
  transcriptions = ""
 
359
  return warn_output + transcriptions
360
 
361
 
362
+ def _return_yt_html_embed(yt_url):
363
+ video_id = yt_url.split("?v=")[-1]
364
+ HTML_str = (
365
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
366
+ " </center>"
367
+ )
368
+ return HTML_str
369
 
 
 
370
 
371
+ def yt_transcribe(yt_url, model_name):
372
+ yt = pt.YouTube(yt_url)
373
+ html_embed_str = _return_yt_html_embed(yt_url)
374
+
375
+ with tempfile.TemporaryDirectory() as tempdir:
376
+ file_uuid = str(uuid.uuid4().hex)
377
+ file_uuid = f"{tempdir}/{file_uuid}.mp3"
378
+
379
+ stream = yt.streams.filter(only_audio=True)[0]
380
+ stream.download(filename=file_uuid)
381
 
382
+ text = infer_audio(model_name, file_uuid)
383
+
384
+ return html_embed_str, text
385
+
386
+
387
+ def create_lang_selector_component(default_en_model=DEFAULT_EN_MODEL):
388
  lang_selector = gr.components.Dropdown(
389
  choices=sorted(list(SUPPORTED_LANGUAGES)), value="en", type="value", label="Languages", interactive=True,
390
  )
391
  models_in_lang = gr.components.Dropdown(
392
  choices=sorted(list(SUPPORTED_LANG_MODEL_DICT["en"])),
393
+ value=default_en_model,
394
  label="Models",
395
  interactive=True,
396
  )
 
400
  default = models_names[0]
401
 
402
  if lang == 'en':
403
+ default = default_en_model
404
  return models_in_lang.update(choices=models_names, value=default)
405
 
406
  lang_selector.change(update_models_with_lang, inputs=[lang_selector], outputs=[models_in_lang])
407
 
408
+ return lang_selector, models_in_lang
409
+
410
+
411
+ demo = gr.Blocks(title=TITLE, css=CSS)
412
+
413
+ with demo:
414
+ header = gr.Markdown(MARKDOWN)
415
+
416
+ with gr.Tab("Transcribe Audio"):
417
+ with gr.Row() as row:
418
+ file_upload = gr.components.Audio(source="upload", type='filepath', label='Upload File')
419
+ microphone = gr.components.Audio(source="microphone", type='filepath', label='Microphone')
420
+
421
+ lang_selector, models_in_lang = create_lang_selector_component()
422
+
423
+ transcript = gr.components.Label(label='Transcript')
424
+
425
+ run = gr.components.Button('Transcribe')
426
+ run.click(transcribe, inputs=[microphone, file_upload, models_in_lang], outputs=[transcript])
427
+
428
+ with gr.Tab("Transcribe Youtube"):
429
+ yt_url = gr.components.Textbox(
430
+ lines=1, label="Youtube URL", placeholder="Paste the URL to a YouTube video here"
431
+ )
432
+
433
+ lang_selector_yt, models_in_lang_yt = create_lang_selector_component(
434
+ default_en_model='nvidia/stt_en_conformer_transducer_large'
435
+ )
436
+
437
+ embedded_video = gr.components.HTML()
438
+ transcript = gr.components.Label(label='Transcript')
439
 
440
+ run = gr.components.Button('Transcribe YouTube')
441
+ run.click(yt_transcribe, inputs=[yt_url, models_in_lang_yt], outputs=[embedded_video, transcript])
442
 
443
  gr.components.HTML(ARTICLE)
444
 
445
  demo.queue(concurrency_count=1)
446
+ demo.launch(enable_queue=True)
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- nemo_toolkit[asr]
1
+ git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]
2
+ pytube
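Note that `{BRANCH}` in the NeMo requirement above is an unresolved placeholder for a git ref; installation only works once it is substituted. Purely for illustration, with a hypothetical branch name (the actual ref used by the Space is not specified in this diff):

```
# requirements.txt with the {BRANCH} placeholder filled in (branch name is hypothetical)
git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[all]
pytube
```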
speech_to_text_buffered_infer_ctc.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This script serves three goals:
17
+ (1) Demonstrate how to use NeMo Models outside of PytorchLightning
18
+ (2) Shows example of batch ASR inference
19
+ (3) Serves as CI test for pre-trained checkpoint
20
+
21
+ python speech_to_text_buffered_infer_ctc.py \
22
+ model_path=null \
23
+ pretrained_name=null \
24
+ audio_dir="<remove or path to folder of audio files>" \
25
+ dataset_manifest="<remove or path to manifest>" \
26
+ output_filename="<remove or specify output filename>" \
27
+ total_buffer_in_secs=4.0 \
28
+ chunk_len_in_secs=1.6 \
29
+ model_stride=4 \
30
+ batch_size=32
31
+
32
+ # NOTE:
33
+ You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
34
+ predictions of the model, and the ground-truth text if present in the manifest.
35
+ """
36
+ import contextlib
37
+ import copy
38
+ import glob
39
+ import math
40
+ import os
41
+ from dataclasses import dataclass, is_dataclass
42
+ from typing import Optional
43
+
44
+ import torch
45
+ from omegaconf import OmegaConf
46
+
47
+ from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
48
+ from nemo.collections.asr.parts.utils.transcribe_utils import (
49
+ compute_output_filename,
50
+ get_buffered_pred_feat,
51
+ setup_model,
52
+ write_transcription,
53
+ )
54
+ from nemo.core.config import hydra_runner
55
+ from nemo.utils import logging
56
+
57
+ can_gpu = torch.cuda.is_available()
58
+
59
+
60
+ @dataclass
61
+ class TranscriptionConfig:
62
+ # Required configs
63
+ model_path: Optional[str] = None # Path to a .nemo file
64
+ pretrained_name: Optional[str] = None # Name of a pretrained model
65
+ audio_dir: Optional[str] = None # Path to a directory which contains audio files
66
+ dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
67
+
68
+ # General configs
69
+ output_filename: Optional[str] = None
70
+ batch_size: int = 32
71
+ num_workers: int = 0
72
+ append_pred: bool = False # Sets mode of work, if True it will add new field transcriptions.
73
+ pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
74
+
75
+ # Chunked configs
76
+ chunk_len_in_secs: float = 1.6 # Chunk length in seconds
77
+ total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
78
+ model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
79
+
80
+ # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
81
+ # device anyway, and do inference on CPU only if CUDA device is not found.
82
+ # If `cuda` is a negative number, inference will be on CPU only.
83
+ cuda: Optional[int] = None
84
+ amp: bool = False
85
+ audio_type: str = "wav"
86
+
87
+ # Recompute model transcription, even if the output folder exists with scores.
88
+ overwrite_transcripts: bool = True
89
+
90
+
91
+ @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
92
+ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
93
+ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
94
+ torch.set_grad_enabled(False)
95
+
96
+ if is_dataclass(cfg):
97
+ cfg = OmegaConf.structured(cfg)
98
+
99
+ if cfg.model_path is None and cfg.pretrained_name is None:
100
+ raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
101
+ if cfg.audio_dir is None and cfg.dataset_manifest is None:
102
+ raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
103
+
104
+ filepaths = None
105
+ manifest = cfg.dataset_manifest
106
+ if cfg.audio_dir is not None:
107
+ filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
108
+ manifest = None # ignore dataset_manifest if audio_dir and dataset_manifest are both present
109
+
110
+ # setup GPU
111
+ if cfg.cuda is None:
112
+ if torch.cuda.is_available():
113
+ device = [0] # use 0th CUDA device
114
+ accelerator = 'gpu'
115
+ else:
116
+ device = 1
117
+ accelerator = 'cpu'
118
+ else:
119
+ device = [cfg.cuda]
120
+ accelerator = 'gpu'
121
+ map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
122
+ logging.info(f"Inference will be done on device : {device}")
123
+
124
+ asr_model, model_name = setup_model(cfg, map_location)
125
+
126
+ model_cfg = copy.deepcopy(asr_model._cfg)
127
+ OmegaConf.set_struct(model_cfg.preprocessor, False)
128
+ # some changes for streaming scenario
129
+ model_cfg.preprocessor.dither = 0.0
130
+ model_cfg.preprocessor.pad_to = 0
131
+
132
+ if model_cfg.preprocessor.normalize != "per_feature":
133
+ logging.error("Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently")
134
+
135
+ # Disable config overwriting
136
+ OmegaConf.set_struct(model_cfg.preprocessor, True)
137
+
138
+ # setup AMP (optional)
139
+ if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
140
+ logging.info("AMP enabled!\n")
141
+ autocast = torch.cuda.amp.autocast
142
+ else:
143
+
144
+ @contextlib.contextmanager
145
+ def autocast():
146
+ yield
147
+
148
+ # Compute output filename
149
+ cfg = compute_output_filename(cfg, model_name)
150
+
151
+ # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
152
+ if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
153
+ logging.info(
154
+ f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
155
+ f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
156
+ )
157
+ return cfg
158
+
159
+ asr_model.eval()
160
+ asr_model = asr_model.to(asr_model.device)
161
+
162
+ feature_stride = model_cfg.preprocessor['window_stride']
163
+ model_stride_in_secs = feature_stride * cfg.model_stride
164
+ total_buffer = cfg.total_buffer_in_secs
165
+ chunk_len = float(cfg.chunk_len_in_secs)
166
+
167
+ tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
168
+ mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
169
+ logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")
170
+
171
+ frame_asr = FrameBatchASR(
172
+ asr_model=asr_model, frame_len=chunk_len, total_buffer=cfg.total_buffer_in_secs, batch_size=cfg.batch_size,
173
+ )
174
+
175
+ hyps = get_buffered_pred_feat(
176
+ frame_asr,
177
+ chunk_len,
178
+ tokens_per_chunk,
179
+ mid_delay,
180
+ model_cfg.preprocessor,
181
+ model_stride_in_secs,
182
+ asr_model.device,
183
+ manifest,
184
+ filepaths,
185
+ )
186
+ output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
187
+ logging.info(f"Finished writing predictions to {output_filename}!")
188
+
189
+ return cfg
190
+
191
+
192
+ if __name__ == '__main__':
193
+ main() # noqa pylint: disable=no-value-for-parameter
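Although this script is written as a Hydra CLI entry point, app.py above invokes it programmatically by constructing the `TranscriptionConfig` dataclass and passing it straight to `main()`. A condensed sketch of that call pattern, with the same chunking values used in `infer_audio` (the model name and audio directory are placeholders):

```python
# Sketch of the programmatic call made from app.py; paths and model name are placeholders.
import speech_to_text_buffered_infer_ctc as buffered_ctc

cfg = buffered_ctc.TranscriptionConfig(
    pretrained_name="nvidia/stt_en_conformer_ctc_large",  # hypothetical pretrained model
    audio_dir="/tmp/audio",          # directory of 16 kHz mono wav files to transcribe
    output_filename="output.json",
    audio_type="wav",
    overwrite_transcripts=True,
    model_stride=4,                  # 4 for Conformer models, 8 for Citrinet
    chunk_len_in_secs=20.0,
    total_buffer_in_secs=30.0,
)
buffered_ctc.main(cfg)               # writes one JSON line with 'pred_text' per audio file
```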
speech_to_text_buffered_infer_rnnt.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Script to perform buffered inference using RNNT models.
17
+
18
+ Buffered inference is the primary form of audio transcription when the audio segment is longer than 20-30 seconds.
19
+ This is especially useful for models such as Conformers, which have quadratic time and memory scaling with
20
+ audio duration.
21
+
22
+ The difference between streaming and buffered inference is the chunk size (or the latency of inference).
23
+ Buffered inference will use large chunk sizes (5-10 seconds) + some additional buffer for context.
24
+ Streaming inference will use small chunk sizes (0.1 to 0.25 seconds) + some additional buffer for context.
25
+
26
+ # Middle Token merge algorithm
27
+
28
+ python speech_to_text_buffered_infer_rnnt.py \
29
+ model_path=null \
30
+ pretrained_name=null \
31
+ audio_dir="<remove or path to folder of audio files>" \
32
+ dataset_manifest="<remove or path to manifest>" \
33
+ output_filename="<remove or specify output filename>" \
34
+ total_buffer_in_secs=4.0 \
35
+ chunk_len_in_secs=1.6 \
36
+ model_stride=4 \
37
+ batch_size=32
38
+
39
+ # Longer Common Subsequence (LCS) Merge algorithm
40
+
41
+ python speech_to_text_buffered_infer_rnnt.py \
42
+ model_path=null \
43
+ pretrained_name=null \
44
+ audio_dir="<remove or path to folder of audio files>" \
45
+ dataset_manifest="<remove or path to manifest>" \
46
+ output_filename="<remove or specify output filename>" \
47
+ total_buffer_in_secs=4.0 \
48
+ chunk_len_in_secs=1.6 \
49
+ model_stride=4 \
50
+ batch_size=32 \
51
+ merge_algo="lcs" \
52
+ lcs_alignment_dir=<OPTIONAL: Some path to store the LCS alignments>
53
+
54
+ # NOTE:
55
+ You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
56
+ predictions of the model, and the ground-truth text if present in the manifest.
57
+ """
58
+ import copy
59
+ import glob
60
+ import math
61
+ import os
62
+ from dataclasses import dataclass, is_dataclass
63
+ from typing import Optional
64
+
65
+ import torch
66
+ from omegaconf import OmegaConf, open_dict
67
+
68
+ from nemo.collections.asr.parts.utils.streaming_utils import (
69
+ BatchedFrameASRRNNT,
70
+ LongestCommonSubsequenceBatchedFrameASRRNNT,
71
+ )
72
+ from nemo.collections.asr.parts.utils.transcribe_utils import (
73
+ compute_output_filename,
74
+ get_buffered_pred_feat_rnnt,
75
+ setup_model,
76
+ write_transcription,
77
+ )
78
+ from nemo.core.config import hydra_runner
79
+ from nemo.utils import logging
80
+
81
+ can_gpu = torch.cuda.is_available()
82
+
83
+
84
+ @dataclass
85
+ class TranscriptionConfig:
86
+ # Required configs
87
+ model_path: Optional[str] = None # Path to a .nemo file
88
+ pretrained_name: Optional[str] = None # Name of a pretrained model
89
+ audio_dir: Optional[str] = None # Path to a directory which contains audio files
90
+ dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
91
+
92
+ # General configs
93
+ output_filename: Optional[str] = None
94
+ batch_size: int = 32
95
+ num_workers: int = 0
96
+ append_pred: bool = False # Sets mode of work, if True it will add new field transcriptions.
97
+ pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one.
98
+
99
+ # Chunked configs
100
+ chunk_len_in_secs: float = 1.6 # Chunk length in seconds
101
+ total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
102
+ model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models
103
+
104
+ # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
105
+ # device anyway, and do inference on CPU only if CUDA device is not found.
106
+ # If `cuda` is a negative number, inference will be on CPU only.
107
+ cuda: Optional[int] = None
108
+ audio_type: str = "wav"
109
+
110
+ # Recompute model transcription, even if the output folder exists with scores.
111
+ overwrite_transcripts: bool = True
112
+
113
+ # Decoding configs
114
+ max_steps_per_timestep: int = 5 #'Maximum number of tokens decoded per acoustic timestep'
115
+ stateful_decoding: bool = False # Whether to perform stateful decoding
116
+
117
+ # Merge algorithm for transducers
118
+ merge_algo: Optional[str] = 'middle' # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
119
+ lcs_alignment_dir: Optional[str] = None # Path to a directory to store LCS algo alignments
120
+
121
+
122
+ @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
123
+ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
124
+ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
125
+ torch.set_grad_enabled(False)
126
+
127
+ if is_dataclass(cfg):
128
+ cfg = OmegaConf.structured(cfg)
129
+
130
+ if cfg.model_path is None and cfg.pretrained_name is None:
131
+ raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
132
+ if cfg.audio_dir is None and cfg.dataset_manifest is None:
133
+ raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
134
+
135
+ filepaths = None
136
+ manifest = cfg.dataset_manifest
137
+ if cfg.audio_dir is not None:
138
+ filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
139
+ manifest = None # ignore dataset_manifest if audio_dir and dataset_manifest are both present
140
+
141
+ # setup GPU
142
+ if cfg.cuda is None:
143
+ if torch.cuda.is_available():
144
+ device = [0] # use 0th CUDA device
145
+ accelerator = 'gpu'
146
+ else:
147
+ device = 1
148
+ accelerator = 'cpu'
149
+ else:
150
+ device = [cfg.cuda]
151
+ accelerator = 'gpu'
152
+ map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
153
+ logging.info(f"Inference will be done on device : {device}")
154
+
155
+ asr_model, model_name = setup_model(cfg, map_location)
156
+
157
+ model_cfg = copy.deepcopy(asr_model._cfg)
158
+ OmegaConf.set_struct(model_cfg.preprocessor, False)
159
+ # some changes for streaming scenario
160
+ model_cfg.preprocessor.dither = 0.0
161
+ model_cfg.preprocessor.pad_to = 0
162
+
163
+ if model_cfg.preprocessor.normalize != "per_feature":
164
+ logging.error("Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently")
165
+
166
+ # Disable config overwriting
167
+ OmegaConf.set_struct(model_cfg.preprocessor, True)
168
+
169
+ # Compute output filename
170
+ cfg = compute_output_filename(cfg, model_name)
171
+
172
+ # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
173
+ if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
174
+ logging.info(
175
+ f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
176
+ f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
177
+ )
178
+ return cfg
179
+
180
+ asr_model.freeze()
181
+ asr_model = asr_model.to(asr_model.device)
182
+
183
+ # Change Decoding Config
184
+ decoding_cfg = asr_model.cfg.decoding
185
+ with open_dict(decoding_cfg):
186
+ if cfg.stateful_decoding:
187
+ decoding_cfg.strategy = "greedy"
188
+ else:
189
+ decoding_cfg.strategy = "greedy_batch"
190
+ decoding_cfg.preserve_alignments = True # required to compute the middle token for transducers.
191
+ decoding_cfg.fused_batch_size = -1 # temporarily stop fused batch during inference.
192
+
193
+ asr_model.change_decoding_strategy(decoding_cfg)
194
+
195
+ feature_stride = model_cfg.preprocessor['window_stride']
196
+ model_stride_in_secs = feature_stride * cfg.model_stride
197
+ total_buffer = cfg.total_buffer_in_secs
198
+ chunk_len = float(cfg.chunk_len_in_secs)
199
+
200
+ tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
201
+ mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
202
+ logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")
203
+
204
+ if cfg.merge_algo == 'middle':
205
+ frame_asr = BatchedFrameASRRNNT(
206
+ asr_model=asr_model,
207
+ frame_len=chunk_len,
208
+ total_buffer=cfg.total_buffer_in_secs,
209
+ batch_size=cfg.batch_size,
210
+ max_steps_per_timestep=cfg.max_steps_per_timestep,
211
+ stateful_decoding=cfg.stateful_decoding,
212
+ )
213
+
214
+ elif cfg.merge_algo == 'lcs':
215
+ frame_asr = LongestCommonSubsequenceBatchedFrameASRRNNT(
216
+ asr_model=asr_model,
217
+ frame_len=chunk_len,
218
+ total_buffer=cfg.total_buffer_in_secs,
219
+ batch_size=cfg.batch_size,
220
+ max_steps_per_timestep=cfg.max_steps_per_timestep,
221
+ stateful_decoding=cfg.stateful_decoding,
222
+ alignment_basepath=cfg.lcs_alignment_dir,
223
+ )
224
+ # Set the LCS algorithm delay.
225
+ frame_asr.lcs_delay = math.floor(((total_buffer - chunk_len)) / model_stride_in_secs)
226
+
227
+ else:
228
+ raise ValueError("Invalid choice of merge algorithm for transducer buffered inference.")
229
+
230
+ hyps = get_buffered_pred_feat_rnnt(
231
+ asr=frame_asr,
232
+ tokens_per_chunk=tokens_per_chunk,
233
+ delay=mid_delay,
234
+ model_stride_in_secs=model_stride_in_secs,
235
+ batch_size=cfg.batch_size,
236
+ manifest=manifest,
237
+ filepaths=filepaths,
238
+ )
239
+
240
+ output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
241
+ logging.info(f"Finished writing predictions to {output_filename}!")
242
+
243
+ return cfg
244
+
245
+
246
+ if __name__ == '__main__':
247
+ main() # noqa pylint: disable=no-value-for-parameter