import argparse
import importlib
import os
from argparse import RawTextHelpFormatter

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_checkpoint

if __name__ == "__main__":
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
        """Each attention mask is written next to the input wav file with a "_attn.npy" suffix.
(e.g. path/bla.wav (wav file) --> path/bla_attn.npy (attention mask))\n"""
        """
Example run:
    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda True
""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file.")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to Tacotron/Tacotron2 config file.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="",
        required=True,
        help="Target dataset formatter name from TTS.tts.datasets.formatters.",
    )
    parser.add_argument(
        "--dataset_metafile",
        type=str,
        default="",
        required=True,
        help="Dataset metafile including file paths and transcripts.",
    )
    parser.add_argument(
        "--data_path", type=str, default="", help="Defines the data path. It overrides the path in config.json."
    )
    # NOTE: argparse's bare `type=bool` treats any non-empty string (even "False") as True,
    # so parse the flag explicitly.
    parser.add_argument(
        "--use_cuda",
        type=lambda x: x.lower() in ("true", "1", "yes"),
        default=False,
        help="enable/disable cuda.",
    )
    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
    )
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if "characters" in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(C)
    model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)

    # data loader
    preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        model.decoder.r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )

    dataset.sort_and_filter_items(C.get("sort_by_audio_len", False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data (linear_input and stop_targets are part of the
            # batch layout but unused here)
            text_input = data[0]
            text_lengths = data[1]
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_targets = data[6]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            model_outputs = model.forward(text_input, text_lengths, mel_input)

            alignments = model_outputs["alignments"].detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # the decoder emits one alignment row per `r` frames; if r > 1,
                # interpolate along the time axis to recover a frame-level alignment
                alignment = (
                    torch.nn.functional.interpolate(
                        alignment.transpose(0, 1).unsqueeze(0),
                        size=None,
                        scale_factor=model.decoder.r,
                        mode="nearest",
                        align_corners=None,
                        recompute_scale_factor=None,
                    )
                    .squeeze(0)
                    .transpose(0, 1)
                )
                # remove paddings
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                # set file paths
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
                file_path = item_idx.replace(wav_file_name, align_file_name)
                # save output
                wav_file_abs_path = os.path.abspath(item_idx)
                file_abs_path = os.path.abspath(file_path)
                file_paths.append([wav_file_abs_path, file_abs_path])
                np.save(file_path, alignment)

    # output metafile mapping each wav file to its attention mask
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
    with open(metafile, "w", encoding="utf-8") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")
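
# A minimal sketch (not part of this script) of how the saved masks might be consumed,
# e.g. to extract per-character durations for a Duration Predictor. The file name is
# hypothetical, and the duration heuristic (argmax over the text axis) is one common
# choice, not the only one. Each saved array has shape [mel_len, text_len].
#
#   import numpy as np
#   attn = np.load("path/bla_attn.npy")
#   char_per_frame = attn.argmax(axis=1)  # most-attended character for each mel frame
#   durations = np.bincount(char_per_frame, minlength=attn.shape[1])  # frames per character
#   assert durations.sum() == attn.shape[0]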