# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
This script serves three goals: | |
(1) Demonstrate how to use NeMo Models outside of PytorchLightning | |
(2) Shows example of batch ASR inference | |
(3) Serves as CI test for pre-trained checkpoint | |
python speech_to_text_buffered_infer_ctc.py \ | |
model_path=null \ | |
pretrained_name=null \ | |
audio_dir="<remove or path to folder of audio files>" \ | |
dataset_manifest="<remove or path to manifest>" \ | |
output_filename="<remove or specify output filename>" \ | |
total_buffer_in_secs=4.0 \ | |
chunk_len_in_secs=1.6 \ | |
model_stride=4 \ | |
batch_size=32 | |
# NOTE: | |
You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the | |
predictions of the model, and ground-truth text if presents in manifest. | |
""" | |
import contextlib
import copy
import glob
import math
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional

import torch
from omegaconf import OmegaConf

from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import (
    compute_output_filename,
    get_buffered_pred_feat,
    setup_model,
    write_transcription,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging

can_gpu = torch.cuda.is_available()
@dataclass
class TranscriptionConfig:
    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    audio_dir: Optional[str] = None  # Path to a directory which contains audio files
    dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False  # If True, add a new prediction field to the existing manifest entries instead of overwriting them.
    pred_name_postfix: Optional[str] = None  # Postfix for the prediction field name, if a name other than the default one is needed.

    # Chunked configs
    chunk_len_in_secs: float = 1.6  # Chunk length in seconds
    total_buffer_in_secs: float = 4.0  # Length of buffer (chunk + left and right padding) in seconds
    model_stride: int = 8  # Model downsampling factor: 8 for Citrinet models and 4 for Conformer models

    # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
    # device anyway, and do inference on CPU only if CUDA device is not found.
    # If `cuda` is a negative number, inference will be on CPU only.
    cuda: Optional[int] = None
    amp: bool = False
    audio_type: str = "wav"

    # Recompute model transcription, even if the output folder exists with scores.
    overwrite_transcripts: bool = True
@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    filepaths = None
    manifest = cfg.dataset_manifest
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
        manifest = None  # ignore dataset_manifest if audio_dir and dataset_manifest are both provided
    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1  # use 1 CPU device
            accelerator = 'cpu'
    elif cfg.cuda < 0:
        device = 1  # a negative `cuda` value forces CPU-only inference
        accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
logging.info(f"Inference will be done on device : {device}") | |
asr_model, model_name = setup_model(cfg, map_location) | |
model_cfg = copy.deepcopy(asr_model._cfg) | |
OmegaConf.set_struct(model_cfg.preprocessor, False) | |
# some changes for streaming scenario | |
model_cfg.preprocessor.dither = 0.0 | |
model_cfg.preprocessor.pad_to = 0 | |
if model_cfg.preprocessor.normalize != "per_feature": | |
logging.error("Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently") | |
# Disable config overwriting | |
OmegaConf.set_struct(model_cfg.preprocessor, True) | |
    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:
        # fallback: a no-op context manager, so `with autocast():` works uniformly on CPU
        @contextlib.contextmanager
        def autocast():
            yield
    # Compute output filename
    cfg = compute_output_filename(cfg, model_name)

    # if transcripts should not be overwritten, and already exist, skip the re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg
    asr_model.eval()
    asr_model = asr_model.to(asr_model.device)

    feature_stride = model_cfg.preprocessor['window_stride']
    model_stride_in_secs = feature_stride * cfg.model_stride
    total_buffer = cfg.total_buffer_in_secs
    chunk_len = float(cfg.chunk_len_in_secs)

    tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
    mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
    logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")
    frame_asr = FrameBatchASR(
        asr_model=asr_model, frame_len=chunk_len, total_buffer=cfg.total_buffer_in_secs, batch_size=cfg.batch_size,
    )

    hyps = get_buffered_pred_feat(
        frame_asr,
        chunk_len,
        tokens_per_chunk,
        mid_delay,
        model_cfg.preprocessor,
        model_stride_in_secs,
        asr_model.device,
        manifest,
        filepaths,
    )
    output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
    logging.info(f"Finished writing predictions to {output_filename}!")
    return cfg

if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
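
# Example: force CPU-only inference via a negative `cuda` value. A sketch; the pretrained
# model name and manifest path are illustrative, not part of this script:
#
#   python speech_to_text_buffered_infer_ctc.py \
#       pretrained_name="stt_en_conformer_ctc_large" \
#       dataset_manifest="/data/test_manifest.json" \
#       model_stride=4 \
#       cuda=-1 \
#       batch_size=8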