import contextlib
import os
from dataclasses import dataclass, field, is_dataclass
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.utils.transcribe_utils import (
    compute_output_filename,
    prepare_audio_data,
    setup_model,
    transcribe_partial_audio,
    write_transcription,
)
from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
from nemo.core.config import hydra_runner
from nemo.utils import logging
|
""" |
|
Transcribe audio file on a single CPU/GPU. Useful for transcription of moderate amounts of audio data. |
|
|
|
# Arguments |
|
model_path: path to .nemo ASR checkpoint |
|
pretrained_name: name of pretrained ASR model (from NGC registry) |
|
audio_dir: path to directory with audio files |
|
dataset_manifest: path to dataset JSON manifest file (in NeMo format) |
|
|
|
compute_timestamps: Bool to request greedy time stamp information (if the model supports it) |
|
compute_langs: Bool to request language ID information (if the model supports it) |
|
|
|
(Optionally: You can limit the type of timestamp computations using below overrides) |
|
ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) |
|
rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) |
|
|
|
(Optionally: You can limit the type of timestamp computations using below overrides) |
|
ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) |
|
rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) |
|
|
|
output_filename: Output filename where the transcriptions will be written |
|
batch_size: batch size during inference |
|
|
|
cuda: Optional int to enable or disable execution of model on certain CUDA device. |
|
amp: Bool to decide if Automatic Mixed Precision should be used during inference |
|
audio_type: Str filetype of the audio. Supported = wav, flac, mp3 |
|
|
|
overwrite_transcripts: Bool which when set allows repeated transcriptions to overwrite previous results. |
|
|
|
ctc_decoding: Decoding sub-config for CTC. Refer to documentation for specific values. |
|
rnnt_decoding: Decoding sub-config for RNNT. Refer to documentation for specific values. |
|
|
|
# Usage |
|
ASR model can be specified by either "model_path" or "pretrained_name". |
|
Data for transcription can be defined with either "audio_dir" or "dataset_manifest". |
|
append_pred - optional. Allows you to add more than one prediction to an existing .json |
|
pred_name_postfix - optional. The name you want to be written for the current model |
|
Results are returned in a JSON manifest file. |
|
|
|
python transcribe_speech.py \ |
|
model_path=null \ |
|
pretrained_name=null \ |
|
audio_dir="<remove or path to folder of audio files>" \ |
|
dataset_manifest="<remove or path to manifest>" \ |
|
output_filename="<remove or specify output filename>" \ |
|
batch_size=32 \ |
|
compute_timestamps=False \ |
|
compute_langs=False \ |
|
cuda=0 \ |
|
amp=True \ |
|
append_pred=False \ |
|
pred_name_postfix="<remove or use another model name for output filename>" |
|
""" |
|
|
@dataclass
class ModelChangeConfig:

    # Sub-config for changes specific to the Conformer encoder
    conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig)
|
|
@dataclass
class TranscriptionConfig:
    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    audio_dir: Optional[str] = None  # Path to a directory which contains audio files
    dataset_manifest: Optional[str] = None  # Path to a dataset manifest describing the audio files
    channel_selector: Optional[int] = None  # Select a single channel from multi-channel audio
    audio_key: str = 'audio_filepath'  # Manifest key which points to the audio file
    eval_config_yaml: Optional[str] = None  # Path to an eval config yaml (e.g. containing an augmentor)

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False  # Set to True to append predictions to an existing output manifest
    pred_name_postfix: Optional[str] = None  # If provided, used instead of the model name when naming predictions
    random_seed: Optional[int] = None  # Seed passed to pl.seed_everything()

    # Set to True to output greedy timestamp information (if the model supports it)
    compute_timestamps: bool = False

    # Set to True to output language ID information (if the model supports it)
    compute_langs: bool = False

    # Device configs
    cuda: Optional[int] = None
    amp: bool = False
    audio_type: str = "wav"

    # Recompute the transcription even if the output file already exists
    overwrite_transcripts: bool = True

    # Decoding strategy for CTC models
    ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig)

    # Decoding strategy for RNNT models (fused batching is disabled for inference)
    rnnt_decoding: RNNTDecodingConfig = field(default_factory=lambda: RNNTDecodingConfig(fused_batch_size=-1))

    # Decoder type: "ctc" or "rnnt"; switches between the CTC and RNNT decoders of hybrid models
    decoder_type: Optional[str] = None

    # Model-specific changes to apply before transcription
    model_change: ModelChangeConfig = field(default_factory=ModelChangeConfig)
|
|
@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.random_seed:
        pl.seed_everything(cfg.random_seed)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("cfg.model_path and cfg.pretrained_name cannot both be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("cfg.audio_dir and cfg.dataset_manifest cannot both be None!")
|
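    # Load an optional augmentor config from an external eval yaml; it is applied on-the-fly during transcription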
    augmentor = None
    if cfg.eval_config_yaml:
        eval_config = OmegaConf.load(cfg.eval_config_yaml)
        augmentor = eval_config.test_ds.get("augmentor")
        logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor}")
|
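    # Select the inference device: the requested CUDA device, the first GPU if available, otherwise CPU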
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1  # use a single CPU process
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device(f'cuda:{device[0]}' if accelerator == 'gpu' else 'cpu')
    logging.info(f"Inference will be done on device: {device}")
|
    asr_model, model_name = setup_model(cfg, map_location)

    trainer = pl.Trainer(devices=device, accelerator=accelerator)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()
|
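    # Request full hypotheses so that extra information such as timestamps and language IDs can be extracted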
    return_hypotheses = True
|
    compute_timestamps = cfg.compute_timestamps
    compute_langs = cfg.compute_langs
|
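    # Setup the decoding strategy according to the requested timestamp/language options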
    if hasattr(asr_model, 'change_decoding_strategy'):
        if cfg.decoder_type is not None:
            if cfg.compute_langs and cfg.decoder_type == 'ctc':
                raise ValueError("CTC models do not support `compute_langs` at the moment")

            decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
            decoding_cfg.compute_timestamps = cfg.compute_timestamps
            if 'preserve_alignments' in decoding_cfg:
                decoding_cfg.preserve_alignments = cfg.compute_timestamps
            if 'compute_langs' in decoding_cfg:
                decoding_cfg.compute_langs = cfg.compute_langs

            asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)

        elif hasattr(asr_model, 'joint'):  # RNNT model
            cfg.rnnt_decoding.fused_batch_size = -1
            cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps
            cfg.rnnt_decoding.compute_langs = cfg.compute_langs

            if 'preserve_alignments' in cfg.rnnt_decoding:
                cfg.rnnt_decoding.preserve_alignments = cfg.compute_timestamps

            asr_model.change_decoding_strategy(cfg.rnnt_decoding)
        else:  # CTC model
            if cfg.compute_langs:
                raise ValueError("CTC models do not support `compute_langs` at the moment.")
            cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps

            asr_model.change_decoding_strategy(cfg.ctc_decoding)
|
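    # Prepare the list of audio filepaths and determine whether partial-audio transcription is requested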
    filepaths, partial_audio = prepare_audio_data(cfg)
|
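    # Setup AMP autocast if requested and available; otherwise fall back to a no-op context manager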
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield
|
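    # Compute the output filename if one was not explicitly provided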
    cfg = compute_output_filename(cfg, model_name)
|
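    # If transcripts should not be overwritten and the output file already exists, skip re-transcription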
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg
|
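    # Transcribe the audio, optionally under autocast, and without gradient tracking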
    with autocast():
        with torch.no_grad():
            if partial_audio:
                if isinstance(asr_model, EncDecCTCModel):
                    transcriptions = transcribe_partial_audio(
                        asr_model=asr_model,
                        path2manifest=cfg.dataset_manifest,
                        batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers,
                        return_hypotheses=return_hypotheses,
                        channel_selector=cfg.channel_selector,
                        augmentor=augmentor,
                    )
                else:
                    logging.warning(
                        "RNNT models do not support transcribing partial audio for now. Transcribing full audio."
                    )
                    transcriptions = asr_model.transcribe(
                        paths2audio_files=filepaths,
                        batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers,
                        return_hypotheses=return_hypotheses,
                        channel_selector=cfg.channel_selector,
                        augmentor=augmentor,
                    )
            else:
                transcriptions = asr_model.transcribe(
                    paths2audio_files=filepaths,
                    batch_size=cfg.batch_size,
                    num_workers=cfg.num_workers,
                    return_hypotheses=return_hypotheses,
                    channel_selector=cfg.channel_selector,
                    augmentor=augmentor,
                )
|
    logging.info(f"Finished transcribing {len(filepaths)} files!")
    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # If transcriptions form a (best hypotheses, all hypotheses) tuple, keep only the best hypotheses
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]
|
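    # Write the transcriptions (plus any timestamps / language IDs) to the output manifest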
    output_filename = write_transcription(
        transcriptions,
        cfg,
        model_name,
        filepaths=filepaths,
        compute_langs=compute_langs,
        compute_timestamps=compute_timestamps,
    )
    logging.info(f"Finished writing predictions to {output_filename}!")
|
    return cfg
|
|
if __name__ == '__main__':
    main()  # the @hydra_runner decorator supplies cfg from the CLI