# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import params
from model import GradTTS
from data import TextMelDataset, TextMelBatchCollate
from utils import plot_tensor, save_plot
from text.symbols import symbols


train_filelist_path = params.train_filelist_path
valid_filelist_path = params.valid_filelist_path
cmudict_path = params.cmudict_path
add_blank = params.add_blank

log_dir = params.log_dir
n_epochs = params.n_epochs
batch_size = params.batch_size
out_size = params.out_size
learning_rate = params.learning_rate
random_seed = params.seed
n_workers = params.n_workers

nsymbols = len(symbols) + 1 if add_blank else len(symbols)
n_enc_channels = params.n_enc_channels
filter_channels = params.filter_channels
filter_channels_dp = params.filter_channels_dp
n_enc_layers = params.n_enc_layers
enc_kernel = params.enc_kernel
enc_dropout = params.enc_dropout
n_heads = params.n_heads
window_size = params.window_size

n_feats = params.n_feats
n_fft = params.n_fft
sample_rate = params.sample_rate
hop_length = params.hop_length
win_length = params.win_length
f_min = params.f_min
f_max = params.f_max

dec_dim = params.dec_dim
beta_min = params.beta_min
beta_max = params.beta_max
pe_scale = params.pe_scale

num_workers = params.num_workers


if __name__ == "__main__":
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    print('Initializing logger...')
    logger = SummaryWriter(log_dir=log_dir)

    print('Initializing data loaders...')
    train_dataset = TextMelDataset(train_filelist_path, cmudict_path, add_blank,
                                   n_fft, n_feats, sample_rate, hop_length,
                                   win_length, f_min, f_max)
    batch_collate = TextMelBatchCollate()
    loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                        collate_fn=batch_collate, drop_last=True,
                        num_workers=num_workers, shuffle=False)
    test_dataset = TextMelDataset(valid_filelist_path, cmudict_path, add_blank,
                                  n_fft, n_feats, sample_rate, hop_length,
                                  win_length, f_min, f_max)

    print('Initializing model...')
    model = GradTTS(nsymbols, 1, None, n_enc_channels, filter_channels,
                    filter_channels_dp, n_heads, n_enc_layers, enc_kernel,
                    enc_dropout, window_size, n_feats, dec_dim, beta_min,
                    beta_max, pe_scale).cuda()
    print('Number of encoder + duration predictor parameters: %.2fm' % (model.encoder.nparams/1e6))
    print('Number of decoder parameters: %.2fm' % (model.decoder.nparams/1e6))
    print('Total parameters: %.2fm' % (model.nparams/1e6))

    print('Initializing optimizer...')
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

    print('Logging test batch...')
    test_batch = test_dataset.sample_test_batch(size=params.test_size)
    for i, item in enumerate(test_batch):
        mel = item['y']
        logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()),
                         global_step=0, dataformats='HWC')
        save_plot(mel.squeeze(), f'{log_dir}/original_{i}.png')

    print('Start training...')
    iteration = 0
    for epoch in range(1, n_epochs + 1):
        model.train()
        dur_losses = []
        prior_losses = []
        diff_losses = []
        with tqdm(loader, total=len(train_dataset)//batch_size) as progress_bar:
            for batch_idx, batch in enumerate(progress_bar):
                model.zero_grad()
                x, x_lengths = batch['x'].cuda(), batch['x_lengths'].cuda()
                y, y_lengths = batch['y'].cuda(), batch['y_lengths'].cuda()
                dur_loss, prior_loss, diff_loss = model.compute_loss(x, x_lengths,
                                                                     y, y_lengths,
                                                                     out_size=out_size)
                loss = sum([dur_loss, prior_loss, diff_loss])
                loss.backward()

                # clip encoder and decoder gradients separately and log their norms
                enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(),
                                                               max_norm=1)
                dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(),
                                                               max_norm=1)
                optimizer.step()

                logger.add_scalar('training/duration_loss', dur_loss.item(),
                                  global_step=iteration)
                logger.add_scalar('training/prior_loss', prior_loss.item(),
                                  global_step=iteration)
                logger.add_scalar('training/diffusion_loss', diff_loss.item(),
                                  global_step=iteration)
                logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,
                                  global_step=iteration)
                logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,
                                  global_step=iteration)

                dur_losses.append(dur_loss.item())
                prior_losses.append(prior_loss.item())
                diff_losses.append(diff_loss.item())

                if batch_idx % 5 == 0:
                    msg = (f'Epoch: {epoch}, iteration: {iteration} | '
                           f'dur_loss: {dur_loss.item()}, prior_loss: {prior_loss.item()}, '
                           f'diff_loss: {diff_loss.item()}')
                    progress_bar.set_description(msg)

                iteration += 1

        # append per-epoch mean losses to the text log
        log_msg = 'Epoch %d: duration loss = %.3f ' % (epoch, np.mean(dur_losses))
        log_msg += '| prior loss = %.3f ' % np.mean(prior_losses)
        log_msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
        with open(f'{log_dir}/train.log', 'a') as f:
            f.write(log_msg)

        # synthesize the test batch and save a checkpoint every `save_every` epochs
        if epoch % params.save_every > 0:
            continue

        model.eval()
        print('Synthesis...')
        with torch.no_grad():
            for i, item in enumerate(test_batch):
                x = item['x'].to(torch.long).unsqueeze(0).cuda()
                x_lengths = torch.LongTensor([x.shape[-1]]).cuda()
                y_enc, y_dec, attn = model(x, x_lengths, n_timesteps=50)
                logger.add_image(f'image_{i}/generated_enc',
                                 plot_tensor(y_enc.squeeze().cpu()),
                                 global_step=iteration, dataformats='HWC')
                logger.add_image(f'image_{i}/generated_dec',
                                 plot_tensor(y_dec.squeeze().cpu()),
                                 global_step=iteration, dataformats='HWC')
                logger.add_image(f'image_{i}/alignment',
                                 plot_tensor(attn.squeeze().cpu()),
                                 global_step=iteration, dataformats='HWC')
                save_plot(y_enc.squeeze().cpu(), f'{log_dir}/generated_enc_{i}.png')
                save_plot(y_dec.squeeze().cpu(), f'{log_dir}/generated_dec_{i}.png')
                save_plot(attn.squeeze().cpu(), f'{log_dir}/alignment_{i}.png')

        ckpt = model.state_dict()
        torch.save(ckpt, f=f"{log_dir}/grad_{epoch}.pt")
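
# A minimal sketch (not part of the original training loop above) of how a
# checkpoint written by torch.save() above could be restored for inference.
# The helper name, the example epoch number and the CPU map_location are
# assumptions; the GradTTS arguments mirror the ones used for training.
def load_checkpoint_for_inference(ckpt_path=f'{log_dir}/grad_100.pt'):
    # rebuild GradTTS with the same hyperparameters used during training
    gen = GradTTS(nsymbols, 1, None, n_enc_channels, filter_channels,
                  filter_channels_dp, n_heads, n_enc_layers, enc_kernel,
                  enc_dropout, window_size, n_feats, dec_dim, beta_min,
                  beta_max, pe_scale)
    # the checkpoint holds a plain state_dict, as saved via model.state_dict() above
    gen.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    gen.eval()
    return gen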