import copy |
|
import math |
|
import os |
|
from dataclasses import dataclass, field, is_dataclass |
|
from pathlib import Path |
|
from typing import List, Optional |
|
|
|
import torch |
|
from omegaconf import OmegaConf |
|
from utils.data_prep import ( |
|
add_t_start_end_to_utt_obj, |
|
get_batch_starts_ends, |
|
get_batch_variables, |
|
get_manifest_lines_batch, |
|
is_entry_in_all_lines, |
|
is_entry_in_any_lines, |
|
) |
|
from utils.make_ass_files import make_ass_files |
|
from utils.make_ctm_files import make_ctm_files |
|
from utils.make_output_manifest import write_manifest_out_line |
|
from utils.viterbi_decoding import viterbi_decoding |
|
|
|
from nemo.collections.asr.models.ctc_models import EncDecCTCModel |
|
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel |
|
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR |
|
from nemo.collections.asr.parts.utils.transcribe_utils import setup_model |
|
from nemo.core.config import hydra_runner |
|
from nemo.utils import logging |
|
|
|
""" |
|
Align the utterances in manifest_filepath. |
|
Results are saved in ctm files in output_dir. |
|
|
|
Arguments: |
|
pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded |
|
from NGC and used for generating the log-probs which we will use to do alignment. |
|
Note: NFA can only use CTC models (not Transducer models) at the moment. |
|
model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the |
|
log-probs which we will use to do alignment. |
|
Note: NFA can only use CTC models (not Transducer models) at the moment. |
|
        Note: you must specify exactly one of model_path and pretrained_name (not both).
|
manifest_filepath: filepath to the manifest of the data you want to align, |
|
containing 'audio_filepath' and 'text' fields. |
|
output_dir: the folder where output CTM files and new JSON manifest will be saved. |
|
align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription |
|
as the reference text for the forced alignment. |
|
transcribe_device: None, or a string specifying the device that will be used for generating log-probs (i.e. "transcribing"). |
|
The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available |
|
(otherwise will set it to 'cpu'). |
|
viterbi_device: None, or string specifying the device that will be used for doing Viterbi decoding. |
|
The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available |
|
(otherwise will set it to 'cpu'). |
|
batch_size: int specifying batch size that will be used for generating log-probs and doing Viterbi decoding. |
|
use_local_attention: boolean flag specifying whether to try to use local attention for the ASR Model (will only |
|
work if the ASR Model is a Conformer model). If local attention is used, we will set the local attention context |
|
size to [64,64]. |
|
additional_segment_grouping_separator: an optional string used to separate the text into smaller segments. |
|
If this is not specified, then the whole text will be treated as a single segment. |
|
    ctm_file_config.remove_blank_tokens: a boolean denoting whether to remove <blank> tokens from token-level output CTMs.
|
audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath |
|
we will use (starting from the final part of the audio_filepath) to determine the |
|
utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath |
|
will be replaced with dashes, so as not to change the number of space-separated elements in the |
|
CTM files. |
|
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e1" |
|
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e1" |
|
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e1" |
|
    use_buffered_chunked_streaming: if set to True, will use buffered chunked streaming to get the logits

        for alignment. This flag is useful when aligning long audio files. However, chunked streaming

        inference currently does not support batching: even if you set batch_size > 1, utterances will be

        inferred one at a time rather than as a whole batch.

    chunk_len_in_secs: float specifying the chunk length in seconds.

    total_buffer_in_secs: float specifying the length of the buffer (chunk + left and right padding) in seconds.

    chunk_batch_size: int specifying the batch size for buffered chunk inference, which will cut one audio file

        into segments and run inference on chunk_batch_size segments at a time.

    model_downsample_factor: int specifying the encoder downsampling factor of the ASR model, used to compute

        the model stride for buffered chunked streaming.
|
|
|
    simulate_cache_aware_streaming: if set to True, will use cache-aware streaming to get the logits for alignment.
|
|
|
save_output_file_formats: List of strings specifying what type of output files to save (default: ["ctm", "ass"]) |
|
ctm_file_config: CTMFileConfig to specify the configuration of the output CTM files |
|
ass_file_config: ASSFileConfig to specify the configuration of the output ASS files |
|
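Example invocation (the filepaths below are hypothetical placeholders; any CTC or hybrid NeMo ASR model
name from NGC can be substituted for pretrained_name):

    python align.py \
        pretrained_name="stt_en_conformer_ctc_large" \
        manifest_filepath=/path/to/manifest.json \
        output_dir=/path/to/output_dir

where each line of the input manifest is a JSON object such as:

    {"audio_filepath": "/path/to/audio.wav", "text": "the transcript of the audio"}
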
""" |
|
|
|
|
|
@dataclass |
|
class CTMFileConfig: |
|
remove_blank_tokens: bool = False |
|
|
|
|
|
|
|
|
|
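    # minimum duration (in seconds) for timestamps in the output CTM. If any token/word/segment
    # has a duration lower than this, it will be enlarged to meet the minimum duration.
    # Note that this may cause the timestamps of neighboring items to overlap.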
minimum_timestamp_duration: float = 0 |
|
|
|
|
|
@dataclass |
|
class ASSFileConfig: |
|
fontsize: int = 20 |
|
vertical_alignment: str = "center" |
|
|
|
|
|
|
|
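    # if resegment_text_to_fill_space is True, the text will be re-split into segments so that
    # each segment takes up at most max_lines_per_segment lines when the ASS file is displayed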
resegment_text_to_fill_space: bool = False |
|
max_lines_per_segment: int = 2 |
|
text_already_spoken_rgb: List[int] = field(default_factory=lambda: [49, 46, 61]) |
|
text_being_spoken_rgb: List[int] = field(default_factory=lambda: [57, 171, 9]) |
|
text_not_yet_spoken_rgb: List[int] = field(default_factory=lambda: [194, 193, 199]) |
|
|
|
|
|
@dataclass |
|
class AlignmentConfig: |
|
|
|
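    # Required configs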
pretrained_name: Optional[str] = None |
|
model_path: Optional[str] = None |
|
manifest_filepath: Optional[str] = None |
|
output_dir: Optional[str] = None |
|
|
|
|
|
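    # General configs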
align_using_pred_text: bool = False |
|
transcribe_device: Optional[str] = None |
|
viterbi_device: Optional[str] = None |
|
batch_size: int = 1 |
|
use_local_attention: bool = True |
|
additional_segment_grouping_separator: Optional[str] = None |
|
audio_filepath_parts_in_utt_id: int = 1 |
|
|
|
|
|
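    # Buffered chunked streaming configs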
use_buffered_chunked_streaming: bool = False |
|
chunk_len_in_secs: float = 1.6 |
|
total_buffer_in_secs: float = 4.0 |
|
    chunk_batch_size: int = 32

    # encoder downsampling factor, used to compute the model stride for buffered chunked streaming;
    # referenced as cfg.model_downsample_factor below, so it must exist in the config
    # (assumption on the default: 8 for Citrinet/FastConformer models, 4 for Conformer models)
    model_downsample_factor: int = 8
|
|
|
|
|
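    # Cache aware streaming configs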
simulate_cache_aware_streaming: Optional[bool] = False |
|
|
|
|
|
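    # Output file configs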
save_output_file_formats: List[str] = field(default_factory=lambda: ["ctm", "ass"]) |
|
    ctm_file_config: CTMFileConfig = field(default_factory=CTMFileConfig)

    ass_file_config: ASSFileConfig = field(default_factory=ASSFileConfig)
|
|
|
|
|
@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig) |
|
def main(cfg: AlignmentConfig): |
|
|
|
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') |
|
|
|
if is_dataclass(cfg): |
|
cfg = OmegaConf.structured(cfg) |
|
|
|
|
|
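    # Validate config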
if cfg.model_path is None and cfg.pretrained_name is None: |
|
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None") |
|
|
|
if cfg.model_path is not None and cfg.pretrained_name is not None: |
|
raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None") |
|
|
|
if cfg.manifest_filepath is None: |
|
raise ValueError("cfg.manifest_filepath must be specified") |
|
|
|
if cfg.output_dir is None: |
|
raise ValueError("cfg.output_dir must be specified") |
|
|
|
if cfg.batch_size < 1: |
|
raise ValueError("cfg.batch_size cannot be zero or a negative number") |
|
|
|
if cfg.additional_segment_grouping_separator == "" or cfg.additional_segment_grouping_separator == " ": |
|
raise ValueError("cfg.additional_grouping_separator cannot be empty string or space character") |
|
|
|
if cfg.ctm_file_config.minimum_timestamp_duration < 0: |
|
raise ValueError("cfg.minimum_timestamp_duration cannot be a negative number") |
|
|
|
if cfg.ass_file_config.vertical_alignment not in ["top", "center", "bottom"]: |
|
raise ValueError("cfg.ass_file_config.vertical_alignment must be one of 'top', 'center' or 'bottom'") |
|
|
|
    for rgb_list in [

        cfg.ass_file_config.text_already_spoken_rgb,

        cfg.ass_file_config.text_being_spoken_rgb,

        cfg.ass_file_config.text_not_yet_spoken_rgb,

    ]:

        if len(rgb_list) != 3:

            raise ValueError(

                "cfg.ass_file_config.text_already_spoken_rgb,"

                " cfg.ass_file_config.text_being_spoken_rgb,"

                " and cfg.ass_file_config.text_not_yet_spoken_rgb all need to contain"

                " exactly 3 elements."

            )
|
|
|
|
|
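    # Validate manifest contents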
if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"): |
|
raise RuntimeError( |
|
"At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. " |
|
"All lines must contain an 'audio_filepath' entry." |
|
) |
|
|
|
if cfg.align_using_pred_text: |
|
if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"): |
|
raise RuntimeError( |
|
"Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath " |
|
"contains 'pred_text' entries. This is because the audio will be transcribed and may produce " |
|
"a different 'pred_text'. This may cause confusion." |
|
) |
|
else: |
|
if not is_entry_in_all_lines(cfg.manifest_filepath, "text"): |
|
raise RuntimeError( |
|
"At least one line in cfg.manifest_filepath does not contain a 'text' entry. " |
|
"NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=False." |
|
) |
|
|
|
|
|
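    # init devices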
if cfg.transcribe_device is None: |
|
transcribe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
else: |
|
transcribe_device = torch.device(cfg.transcribe_device) |
|
logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}") |
|
|
|
if cfg.viterbi_device is None: |
|
viterbi_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
else: |
|
viterbi_device = torch.device(cfg.viterbi_device) |
|
logging.info(f"Device to be used for viterbi step (`viterbi_device`) is {viterbi_device}") |
|
|
|
if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda': |
|
logging.warning( |
|
'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors ' |
|
'it may help to change both devices to be the CPU.' |
|
) |
|
|
|
|
|
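    # load model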
model, _ = setup_model(cfg, transcribe_device) |
|
model.eval() |
|
|
|
if isinstance(model, EncDecHybridRNNTCTCModel): |
|
model.change_decoding_strategy(decoder_type="ctc") |
|
|
|
if cfg.use_local_attention: |
|
logging.info( |
|
"Flag use_local_attention is set to True => will try to use local attention for model if it allows it" |
|
) |
|
model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[64, 64]) |
|
|
|
if not (isinstance(model, EncDecCTCModel) or isinstance(model, EncDecHybridRNNTCTCModel)): |
|
raise NotImplementedError( |
|
f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel." |
|
" Currently only instances of these models are supported" |
|
) |
|
|
|
if cfg.ctm_file_config.minimum_timestamp_duration > 0: |
|
logging.warning( |
|
f"cfg.ctm_file_config.minimum_timestamp_duration has been set to {cfg.ctm_file_config.minimum_timestamp_duration} seconds. " |
|
"This may cause the alignments for some tokens/words/additional segments to be overlapping." |
|
) |
|
|
|
buffered_chunk_params = {} |
|
if cfg.use_buffered_chunked_streaming: |
|
model_cfg = copy.deepcopy(model._cfg) |
|
|
|
OmegaConf.set_struct(model_cfg.preprocessor, False) |
|
|
|
model_cfg.preprocessor.dither = 0.0 |
|
model_cfg.preprocessor.pad_to = 0 |
|
|
|
if model_cfg.preprocessor.normalize != "per_feature": |
|
logging.error( |
|
"Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently" |
|
) |
|
|
|
OmegaConf.set_struct(model_cfg.preprocessor, True) |
|
|
|
feature_stride = model_cfg.preprocessor['window_stride'] |
|
model_stride_in_secs = feature_stride * cfg.model_downsample_factor |
|
total_buffer = cfg.total_buffer_in_secs |
|
chunk_len = float(cfg.chunk_len_in_secs) |
|
tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs) |
|
mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs) |
|
logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}") |
|
|
|
model = FrameBatchASR( |
|
asr_model=model, |
|
frame_len=chunk_len, |
|
total_buffer=cfg.total_buffer_in_secs, |
|
batch_size=cfg.chunk_batch_size, |
|
) |
|
buffered_chunk_params = { |
|
"delay": mid_delay, |
|
"model_stride_in_secs": model_stride_in_secs, |
|
"tokens_per_chunk": tokens_per_chunk, |
|
} |
|
|
|
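    # get start and end line IDs of manifest batches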
starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size) |
|
|
|
|
|
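    # initialize output_timestep_duration to None; it will be calculated during the first batch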
output_timestep_duration = None |
|
|
|
|
|
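    # create output_dir and open the output manifest file for writing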
os.makedirs(cfg.output_dir, exist_ok=True) |
|
tgt_manifest_name = str(Path(cfg.manifest_filepath).stem) + "_with_output_file_paths.json" |
|
tgt_manifest_filepath = str(Path(cfg.output_dir) / tgt_manifest_name) |
|
f_manifest_out = open(tgt_manifest_filepath, 'w') |
|
|
|
|
|
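    # get alignments and save output files batch-by-batch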
for start, end in zip(starts, ends): |
|
manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end) |
|
|
|
(log_probs_batch, y_batch, T_batch, U_batch, utt_obj_batch, output_timestep_duration,) = get_batch_variables( |
|
manifest_lines_batch, |
|
model, |
|
cfg.additional_segment_grouping_separator, |
|
cfg.align_using_pred_text, |
|
cfg.audio_filepath_parts_in_utt_id, |
|
output_timestep_duration, |
|
cfg.simulate_cache_aware_streaming, |
|
cfg.use_buffered_chunked_streaming, |
|
buffered_chunk_params, |
|
) |
|
|
|
alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device) |
|
|
|
for utt_obj, alignment_utt in zip(utt_obj_batch, alignments_batch): |
|
|
|
utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration) |
|
|
|
if "ctm" in cfg.save_output_file_formats: |
|
utt_obj = make_ctm_files(utt_obj, cfg.output_dir, cfg.ctm_file_config,) |
|
|
|
if "ass" in cfg.save_output_file_formats: |
|
utt_obj = make_ass_files(utt_obj, cfg.output_dir, cfg.ass_file_config) |
|
|
|
write_manifest_out_line( |
|
f_manifest_out, utt_obj, |
|
) |
|
|
|
f_manifest_out.close() |
|
|
|
return None |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|