# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script serves three goals:
    (1) Demonstrates how to use NeMo models outside of PyTorch Lightning
    (2) Shows an example of batch ASR inference
    (3) Serves as a CI test for pre-trained checkpoints

python speech_to_text_buffered_infer_ctc.py \
    model_path=null \
    pretrained_name=null \
    audio_dir="" \
    dataset_manifest="" \
    output_filename="" \
    total_buffer_in_secs=4.0 \
    chunk_len_in_secs=1.6 \
    model_stride=4 \
    batch_size=32

# NOTE: You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
# predictions of the model, and the ground-truth text if present in the manifest.
"""
import contextlib
import copy
import glob
import math
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional

import torch
from omegaconf import OmegaConf

from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import (
    compute_output_filename,
    get_buffered_pred_feat,
    setup_model,
    write_transcription,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging

can_gpu = torch.cuda.is_available()


@dataclass
class TranscriptionConfig:
    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    audio_dir: Optional[str] = None  # Path to a directory which contains audio files
    dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False  # Sets mode of work: if True, it will add a new field with transcriptions.
    pred_name_postfix: Optional[str] = None  # If you need to use another model name, rather than the standard one.

    # Chunked configs
    chunk_len_in_secs: float = 1.6  # Chunk length in seconds
    total_buffer_in_secs: float = 4.0  # Length of buffer (chunk + left and right padding) in seconds
    model_stride: int = 8  # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models

    # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
    # device anyway, and do inference on CPU only if CUDA device is not found.
    # If `cuda` is a negative number, inference will be on CPU only.
    cuda: Optional[int] = None
    amp: bool = False
    audio_type: str = "wav"

    # Recompute model transcription, even if the output folder exists with scores.
    overwrite_transcripts: bool = True
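
# A rough, illustrative sketch of how the chunking parameters above interact,
# assuming a Conformer-style checkpoint whose preprocessor uses
# window_stride = 0.01 s and the docstring example overrides
# (chunk_len_in_secs=1.6, total_buffer_in_secs=4.0, model_stride=4):
#
#   model_stride_in_secs = 0.01 * 4                              = 0.04 s per output token
#   tokens_per_chunk     = ceil(1.6 / 0.04)                      = 40
#   mid_delay            = ceil((1.6 + (4.0 - 1.6) / 2) / 0.04)  = 70
#
# main() below recomputes these values from the loaded model's own config, so
# the numbers here are only an example, not fixed constants of the script.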

@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    filepaths = None
    manifest = cfg.dataset_manifest
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
        manifest = None  # ignore dataset_manifest if both audio_dir and dataset_manifest are present

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
    logging.info(f"Inference will be done on device : {device}")

    asr_model, model_name = setup_model(cfg, map_location)

    model_cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(model_cfg.preprocessor, False)
    # some changes for streaming scenario
    model_cfg.preprocessor.dither = 0.0
    model_cfg.preprocessor.pad_to = 0

    if model_cfg.preprocessor.normalize != "per_feature":
        logging.error("Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently")

    # Disable config overwriting
    OmegaConf.set_struct(model_cfg.preprocessor, True)

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # Compute output filename
    cfg = compute_output_filename(cfg, model_name)

    # if transcripts should not be overwritten, and they already exist, skip the re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    asr_model.eval()
    asr_model = asr_model.to(asr_model.device)

    feature_stride = model_cfg.preprocessor['window_stride']
    model_stride_in_secs = feature_stride * cfg.model_stride
    total_buffer = cfg.total_buffer_in_secs
    chunk_len = float(cfg.chunk_len_in_secs)

    tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
    mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
    logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

    frame_asr = FrameBatchASR(
        asr_model=asr_model,
        frame_len=chunk_len,
        total_buffer=cfg.total_buffer_in_secs,
        batch_size=cfg.batch_size,
    )

    hyps = get_buffered_pred_feat(
        frame_asr,
        chunk_len,
        tokens_per_chunk,
        mid_delay,
        model_cfg.preprocessor,
        model_stride_in_secs,
        asr_model.device,
        manifest,
        filepaths,
    )
    output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
    logging.info(f"Finished writing predictions to {output_filename}!")
    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter