| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import logging |
| | import os |
| | import sys |
| | import time |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import scipy.io.wavfile as wav |
| | import torch |
| | from joblib import Parallel, delayed |
| | from tqdm import tqdm |
| | from utils import get_segments |
| |
|
| | import nemo.collections.asr as nemo_asr |
| | from nemo.collections.asr.models.ctc_models import EncDecCTCModel |
| | from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel |
| |
|
# Command-line interface for the CTC segmentation runner.
parser = argparse.ArgumentParser(description="CTC Segmentation")
parser.add_argument("--output_dir", default="output", type=str, help="Path to output directory")
parser.add_argument(
    "--data",
    type=str,
    required=True,
    # Fixed help text: the original was ungrammatical, missed a space between
    # "folder" and "as", and never closed its parenthesis.
    help="Path to a directory with audio files and associated transcripts (same base names, only the "
    "formats differ), or path to a single wav file (the transcript should have the same base name "
    "and be located in the same folder as the wav file).",
)
parser.add_argument("--window_len", type=int, default=8000, help="Window size for ctc segmentation algorithm")
parser.add_argument("--sample_rate", type=int, default=16000, help="Sampling rate, Hz")
parser.add_argument(
    "--model", type=str, default="QuartzNet15x5Base-En", help="Path to model checkpoint or pre-trained model name",
)
parser.add_argument("--debug", action="store_true", help="Flag to enable debugging messages")
parser.add_argument(
    "--num_jobs",
    default=-2,
    type=int,
    help="The maximum number of concurrently running jobs, `-2` - all CPUs but one are used",
)

# Module-level logger; the __main__ block below rebinds `logger` with its own
# file/stdout handler configuration.
logger = logging.getLogger("ctc_segmentation")
| |
|
| | if __name__ == "__main__": |
| |
|
| | args = parser.parse_args() |
| | logging.basicConfig(level=logging.INFO) |
| | |
| | log_dir = os.path.join(args.output_dir, "logs") |
| | os.makedirs(log_dir, exist_ok=True) |
| | log_file = os.path.join(log_dir, f"ctc_segmentation_{args.window_len}.log") |
| | if os.path.exists(log_file): |
| | os.remove(log_file) |
| | level = "DEBUG" if args.debug else "INFO" |
| |
|
| | logger = logging.getLogger("CTC") |
| | file_handler = logging.FileHandler(filename=log_file) |
| | stdout_handler = logging.StreamHandler(sys.stdout) |
| | handlers = [file_handler, stdout_handler] |
| | logging.basicConfig(handlers=handlers, level=level) |
| |
|
| | if os.path.exists(args.model): |
| | asr_model = nemo_asr.models.ASRModel.restore_from(args.model) |
| | else: |
| | asr_model = nemo_asr.models.ASRModel.from_pretrained(args.model, strict=False) |
| |
|
| | if not (isinstance(asr_model, EncDecCTCModel) or isinstance(asr_model, EncDecHybridRNNTCTCModel)): |
| | raise NotImplementedError( |
| | f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel." |
| | " Currently only instances of these models are supported" |
| | ) |
| |
|
| | bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE) or isinstance( |
| | asr_model, nemo_asr.models.EncDecHybridRNNTCTCBPEModel |
| | ) |
| |
|
| | |
| | if bpe_model: |
| | tokenizer = asr_model.tokenizer |
| | else: |
| | tokenizer = None |
| |
|
| | if isinstance(asr_model, EncDecHybridRNNTCTCModel): |
| | asr_model.change_decoding_strategy(decoder_type="ctc") |
| |
|
| | |
| | if hasattr(asr_model, 'tokenizer'): |
| | vocabulary = asr_model.tokenizer.vocab |
| | elif hasattr(asr_model.decoder, "vocabulary"): |
| | vocabulary = asr_model.cfg.decoder.vocabulary |
| | else: |
| | raise ValueError("Unexpected model type. Vocabulary list not found.") |
| |
|
| | vocabulary = ["ε"] + list(vocabulary) |
| | logging.debug(f"ASR Model vocabulary: {vocabulary}") |
| |
|
| | data = Path(args.data) |
| | output_dir = Path(args.output_dir) |
| |
|
| | if os.path.isdir(data): |
| | audio_paths = data.glob("*.wav") |
| | data_dir = data |
| | else: |
| | audio_paths = [Path(data)] |
| | data_dir = Path(os.path.dirname(data)) |
| |
|
| | all_log_probs = [] |
| | all_transcript_file = [] |
| | all_segment_file = [] |
| | all_wav_paths = [] |
| | segments_dir = os.path.join(args.output_dir, "segments") |
| | os.makedirs(segments_dir, exist_ok=True) |
| |
|
| | index_duration = None |
| | for path_audio in audio_paths: |
| | logging.info(f"Processing {path_audio.name}...") |
| | transcript_file = os.path.join(data_dir, path_audio.name.replace(".wav", ".txt")) |
| | segment_file = os.path.join( |
| | segments_dir, f"{args.window_len}_" + path_audio.name.replace(".wav", "_segments.txt") |
| | ) |
| | if not os.path.exists(transcript_file): |
| | logging.info(f"{transcript_file} not found. Skipping {path_audio.name}") |
| | continue |
| | try: |
| | sample_rate, signal = wav.read(path_audio) |
| | if len(signal) == 0: |
| | logging.error(f"Skipping {path_audio.name}") |
| | continue |
| |
|
| | assert ( |
| | sample_rate == args.sample_rate |
| | ), f"Sampling rate of the audio file {path_audio} doesn't match --sample_rate={args.sample_rate}" |
| |
|
| | original_duration = len(signal) / sample_rate |
| | logging.debug(f"len(signal): {len(signal)}, sr: {sample_rate}") |
| | logging.debug(f"Duration: {original_duration}s, file_name: {path_audio}") |
| |
|
| | hypotheses = asr_model.transcribe([str(path_audio)], batch_size=1, return_hypotheses=True) |
| | |
| | if type(hypotheses) == tuple and len(hypotheses) == 2: |
| | hypotheses = hypotheses[0] |
| | log_probs = hypotheses[ |
| | 0 |
| | ].alignments |
| |
|
| | |
| | blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1)) |
| | log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1) |
| |
|
| | all_log_probs.append(log_probs) |
| | all_segment_file.append(str(segment_file)) |
| | all_transcript_file.append(str(transcript_file)) |
| | all_wav_paths.append(path_audio) |
| |
|
| | if index_duration is None: |
| | index_duration = len(signal) / log_probs.shape[0] / sample_rate |
| |
|
| | except Exception as e: |
| | logging.error(e) |
| | logging.error(f"Skipping {path_audio.name}") |
| | continue |
| |
|
| | asr_model_type = type(asr_model) |
| | del asr_model |
| | torch.cuda.empty_cache() |
| |
|
| | if len(all_log_probs) > 0: |
| | start_time = time.time() |
| |
|
| | normalized_lines = Parallel(n_jobs=args.num_jobs)( |
| | delayed(get_segments)( |
| | all_log_probs[i], |
| | all_wav_paths[i], |
| | all_transcript_file[i], |
| | all_segment_file[i], |
| | vocabulary, |
| | tokenizer, |
| | bpe_model, |
| | index_duration, |
| | args.window_len, |
| | log_file=log_file, |
| | debug=args.debug, |
| | ) |
| | for i in tqdm(range(len(all_log_probs))) |
| | ) |
| |
|
| | total_time = time.time() - start_time |
| | logger.info(f"Total execution time: ~{round(total_time/60)}min") |
| | logger.info(f"Saving logs to {log_file}") |
| |
|
| | if os.path.exists(log_file): |
| | with open(log_file, "r") as f: |
| | lines = f.readlines() |
| |
|