from multiprocessing.pool import Pool

import matplotlib

from utils.pl_utils import data_loader
from utils.training_utils import RSQRTSchedule
from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
from modules.fastspeech.pe import PitchExtractor

# Select the non-interactive backend; kept at its original position relative
# to the surrounding imports so import-time behaviour is unchanged.
matplotlib.use('Agg')
import os
import numpy as np
from tqdm import tqdm
import torch.distributed as dist

from tasks.base_task import BaseTask
from utils.hparams import hparams
from utils.text_encoder import TokenTextEncoder
import json

import torch
import torch.optim
import torch.utils.data
import utils


class TtsTask(BaseTask):
    """Base task for text-to-speech models.

    Sets up the phone encoder, optimizer, scheduler and data loaders, and
    manages the worker pool / vocoder used while testing.
    """

    def __init__(self, *args, **kwargs):
        """Initialise TTS state, then delegate to ``BaseTask.__init__``.

        The phone encoder is loaded from ``hparams['binary_data_dir']`` and
        its pad/eos/seg ids are cached on the instance before the base-class
        constructor runs.
        """
        self.vocoder = None
        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        # Both are populated in test_start().
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)

    def build_scheduler(self, optimizer):
        """Return an ``RSQRTSchedule`` wrapping ``optimizer``."""
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        """Create an AdamW optimizer over ``model.parameters()``.

        The learning rate is read from ``hparams['lr']``. The optimizer is
        also stored on ``self.optimizer``.

        Returns:
            The newly created ``torch.optim.AdamW`` instance.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        """Build a ``DataLoader`` that iterates over precomputed batches.

        Args:
            dataset: Dataset exposing ``ordered_indices()``, ``num_tokens``,
                ``num_workers`` and ``collater``.
            shuffle: If true, shuffle the order of the batches with
                ``np.random.shuffle``.
            max_tokens: Per-device token limit; scaled by the device count.
            max_sentences: Per-device sentence limit; scaled by the device
                count. Must be given when ``batch_by_size`` is false, since
                it is used as the chunk size.
            required_batch_size_multiple: Passed to ``utils.batch_by_size``;
                ``-1`` means "use the device count".
            endless: If true, repeat the batch list 1000 times (re-shuffled
                on every repetition when ``shuffle`` is true).
            batch_by_size: If true, batches come from ``utils.batch_by_size``;
                otherwise indices are cut into consecutive chunks of
                ``max_sentences``.

        Returns:
            A ``torch.utils.data.DataLoader`` using the batches as its
            ``batch_sampler``.
        """
        # Treat a CPU-only machine as a single device.
        devices_cnt = torch.cuda.device_count() or 1
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt

        def shuffle_batches(batches):
            # Shuffles in place and returns the same list for convenience.
            np.random.shuffle(batches)
            return batches

        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt

        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            batch_sampler = [
                indices[i:i + max_sentences]
                for i in range(0, len(indices), max_sentences)
            ]

        if shuffle:
            # NOTE: when `endless` is set this first shuffle is discarded, but
            # it is kept so the global NumPy RNG is consumed exactly as before.
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]

        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            # Each rank takes every `num_replicas`-th item of a batch; batches
            # that cannot be split evenly across ranks are dropped.
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

    def build_phone_encoder(self, data_dir):
        """Load ``phone_set.json`` from ``data_dir`` and build a text encoder.

        Args:
            data_dir: Directory containing ``phone_set.json``.

        Returns:
            A ``TokenTextEncoder`` built from the loaded phone list, with
            ``replace_oov=','``.
        """
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        # Use a context manager so the file handle is closed deterministically.
        with open(phone_list_file) as f:
            phone_list = json.load(f)
        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')

    def test_start(self):
        """Prepare testing: create the saving pool, vocoder and optional pitch extractor."""
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        # A missing, None or falsy 'pe_enable' all disable the pitch extractor.
        if hparams.get('pe_enable'):
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

    def test_end(self, outputs):
        """Wait for all pending save jobs, shut the pool down and return ``{}``.

        Args:
            outputs: Unused here.
        """
        self.saving_result_pool.close()
        # .get() blocks until each job finishes and re-raises worker errors.
        for future in tqdm(self.saving_results_futures):
            future.get()
        self.saving_result_pool.join()
        return {}

    ##########
    # utils
    ##########
    def weights_nonzero_speech(self, target):
        """Return a float mask marking the non-zero frames of ``target``.

        Args:
            target: Tensor documented by the original author as
                ``B x T x mel``.

        Returns:
            A float tensor with the last dimension repeated to
            ``target.size(-1)``: 1.0 where the frame has any non-zero value,
            0.0 where the whole frame is zero (i.e. padding).
        """
        # target : B x T x mel
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)


if __name__ == '__main__':
    TtsTask.start()