import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from grad_extend.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
from grad_extend.utils import plot_tensor, save_plot, load_model, print_error
from grad.utils import fix_len_compatibility
from grad.model import GradTTS

# Segment length for the diffusion loss: 200 mel frames, rounded so the
# length is compatible with the decoder's downsampling factor.
out_size = fix_len_compatibility(200)


def train(hps, chkpt_path=None):
    print('Initializing logger...')
    logger = SummaryWriter(log_dir=hps.train.log_dir)

    print('Initializing data loaders...')
    train_dataset = TextMelSpeakerDataset(hps.train.train_files)
    batch_collate = TextMelSpeakerBatchCollate()
    loader = DataLoader(dataset=train_dataset,
                        batch_size=hps.train.batch_size,
                        collate_fn=batch_collate,
                        drop_last=True,
                        num_workers=8,
                        shuffle=True)
    test_dataset = TextMelSpeakerDataset(hps.train.valid_files)

    print('Initializing model...')
    model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits,
                    hps.grad.n_spks, hps.grad.n_embs,
                    hps.grad.n_enc_channels, hps.grad.filter_channels,
                    hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max,
                    hps.grad.pe_scale).cuda()
    print('Number of encoder parameters = %.2fm' % (model.encoder.nparams / 1e6))
    print('Number of decoder parameters = %.2fm' % (model.decoder.nparams / 1e6))

    # Load pretrained weights if available, then fine-tune with a reduced learning rate.
    if os.path.isfile(hps.train.pretrain):
        print("Start from Grad_SVC pretrain model: %s" % hps.train.pretrain)
        checkpoint = torch.load(hps.train.pretrain, map_location='cpu')
        load_model(model, checkpoint['model'])
        hps.train.learning_rate = 2e-5  # lower LR for fine-tuning
        model.fine_tune()
    else:
        print_error(10 * '~' + "No Pretrain Model" + 10 * '~')

    print('Initializing optimizer...')
    optim = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate)

    initepoch = 1
    iteration = 0

    # Resume full training state (model, optimizer, epoch, step) from a checkpoint.
    if chkpt_path is not None:
        print("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optim.load_state_dict(checkpoint['optim'])
        initepoch = checkpoint['epoch']
        iteration = checkpoint['steps']

    print('Logging test batch...')
    test_batch = test_dataset.sample_test_batch(size=hps.train.test_size)
    for i, item in enumerate(test_batch):
        mel = item['mel']
        logger.add_image(f'image_{i}/ground_truth',
                         plot_tensor(mel.squeeze()),
                         global_step=0, dataformats='HWC')
        save_plot(mel.squeeze(), f'{hps.train.log_dir}/original_{i}.png')

    print('Start training...')
    # During the first "fast" epochs only the encoder losses are trained;
    # the diffusion loss is switched on afterwards.
    skip_diff_train = True
    if initepoch >= hps.train.fast_epochs:
        skip_diff_train = False

    for epoch in range(initepoch, hps.train.full_epochs + 1):
        # Periodically synthesize the held-out test items for visual inspection.
        if epoch % hps.train.test_step == 0:
            model.eval()
            print('Synthesis...')
            with torch.no_grad():
                for i, item in enumerate(test_batch):
                    l_vec = item['vec'].shape[0]
                    d_vec = item['vec'].shape[1]
                    # Zero-pad content vectors and pitch to a decoder-compatible length.
                    lengths_fix = fix_len_compatibility(l_vec)
                    lengths = torch.LongTensor([l_vec]).cuda()
                    vec = torch.zeros((1, lengths_fix, d_vec), dtype=torch.float32).cuda()
                    pit = torch.zeros((1, lengths_fix), dtype=torch.float32).cuda()
                    spk = item['spk'].to(torch.float32).unsqueeze(0).cuda()
                    vec[0, :l_vec, :] = item['vec']
                    pit[0, :l_vec] = item['pit']

                    y_enc, y_dec = model(lengths, vec, pit, spk, n_timesteps=50)
                    logger.add_image(f'image_{i}/generated_enc',
                                     plot_tensor(y_enc.squeeze().cpu()),
                                     global_step=iteration, dataformats='HWC')
                    logger.add_image(f'image_{i}/generated_dec',
                                     plot_tensor(y_dec.squeeze().cpu()),
                                     global_step=iteration, dataformats='HWC')
                    save_plot(y_enc.squeeze().cpu(),
                              f'{hps.train.log_dir}/generated_enc_{i}.png')
                    save_plot(y_dec.squeeze().cpu(),
                              f'{hps.train.log_dir}/generated_dec_{i}.png')

        model.train()
        prior_losses = []
        diff_losses = []
        mel_losses = []
        spk_losses = []
        with tqdm(loader, total=len(train_dataset) // hps.train.batch_size) as progress_bar:
            for batch in progress_bar:
                model.zero_grad()
                lengths = batch['lengths'].cuda()
                vec = batch['vec'].cuda()
                pit = batch['pit'].cuda()
                spk = batch['spk'].cuda()
                mel = batch['mel'].cuda()

                prior_loss, diff_loss, mel_loss, spk_loss = model.compute_loss(
                    lengths, vec, pit, spk, mel,
                    out_size=out_size, skip_diff=skip_diff_train)
                loss = sum([prior_loss, diff_loss, mel_loss, spk_loss])
                loss.backward()

                # Clip encoder and decoder gradients separately; the returned
                # norms are logged to TensorBoard below.
                enc_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.encoder.parameters(), max_norm=1)
                dec_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.decoder.parameters(), max_norm=1)
                optim.step()

                logger.add_scalar('training/mel_loss', mel_loss, global_step=iteration)
                logger.add_scalar('training/prior_loss', prior_loss, global_step=iteration)
                logger.add_scalar('training/diffusion_loss', diff_loss, global_step=iteration)
                logger.add_scalar('training/encoder_grad_norm', enc_grad_norm, global_step=iteration)
                logger.add_scalar('training/decoder_grad_norm', dec_grad_norm, global_step=iteration)

                msg = f'Epoch: {epoch}, iteration: {iteration} | '
                msg += f'prior_loss: {prior_loss.item():.3f}, '
                msg += f'diff_loss: {diff_loss.item():.3f}, '
                msg += f'mel_loss: {mel_loss.item():.3f}, '
                msg += f'spk_loss: {spk_loss.item():.3f}'
                progress_bar.set_description(msg)

                prior_losses.append(prior_loss.item())
                diff_losses.append(diff_loss.item())
                mel_losses.append(mel_loss.item())
                spk_losses.append(spk_loss.item())
                iteration += 1

        # Append per-epoch mean losses to a plain-text log.
        msg = 'Epoch %d: ' % epoch
        msg += '| spk loss = %.3f ' % np.mean(spk_losses)
        msg += '| mel loss = %.3f ' % np.mean(mel_losses)
        msg += '| prior loss = %.3f ' % np.mean(prior_losses)
        msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
        with open(f'{hps.train.log_dir}/train.log', 'a') as f:
            f.write(msg)

        # if np.mean(prior_losses) < 1.05:
        #     skip_diff_train = False
        if epoch > hps.train.fast_epochs:
            skip_diff_train = False

        # Save a checkpoint every save_step epochs.
        if epoch % hps.train.save_step > 0:
            continue
        save_path = f"{hps.train.log_dir}/grad_svc_{epoch}.pt"
        torch.save({
            'model': model.state_dict(),
            'optim': optim.state_dict(),
            'epoch': epoch,
            'steps': iteration,
        }, save_path)
        print("Saved checkpoint to: %s" % save_path)
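

# Minimal usage sketch (an assumption, not part of the original script): the
# repo's real entry point may differ. This assumes hyperparameters are loaded
# from a YAML file via OmegaConf into an attribute-style `hps` object matching
# the fields accessed above; the default config path and CLI flags below are
# hypothetical.
if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', default='configs/base.yaml',
                        help='path to a YAML config file (hypothetical default)')
    parser.add_argument('-p', '--checkpoint', default=None,
                        help='optional checkpoint to resume training from')
    args = parser.parse_args()

    hps = OmegaConf.load(args.config)
    train(hps, chkpt_path=args.checkpoint)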