# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Fine-tune a multi-speaker Grad-TTS model from a pretrained checkpoint.

Loads a checkpoint specified in ``finetune_params``, then trains on the
configured filelist while logging losses, gradient norms, and synthesized
spectrogram previews to TensorBoard.  Checkpoints are written to ``log_dir``
every ``params.save_every`` epochs.  Requires a CUDA device.
"""

import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import finetune_params as params
from model import GradTTS
from data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
from utils import plot_tensor, save_plot
from text.symbols import symbols

# ---- Data / run configuration (all sourced from finetune_params) ----
train_filelist_path = params.train_filelist_path
valid_filelist_path = params.valid_filelist_path
cmudict_path = params.cmudict_path
add_blank = params.add_blank
n_spks = params.n_spks
spk_emb_dim = params.spk_emb_dim

log_dir = params.log_dir
n_epochs = params.n_epochs
batch_size = params.batch_size
out_size = params.out_size
learning_rate = params.learning_rate
random_seed = params.seed

# ---- Text encoder configuration ----
# One extra symbol is reserved for the blank token inserted between phonemes.
nsymbols = len(symbols) + 1 if add_blank else len(symbols)
n_enc_channels = params.n_enc_channels
filter_channels = params.filter_channels
filter_channels_dp = params.filter_channels_dp
n_enc_layers = params.n_enc_layers
enc_kernel = params.enc_kernel
enc_dropout = params.enc_dropout
n_heads = params.n_heads
window_size = params.window_size

# ---- Mel-spectrogram / audio configuration ----
n_feats = params.n_feats
n_fft = params.n_fft
sample_rate = params.sample_rate
hop_length = params.hop_length
win_length = params.win_length
f_min = params.f_min
f_max = params.f_max

# ---- Diffusion decoder configuration ----
dec_dim = params.dec_dim
beta_min = params.beta_min
beta_max = params.beta_max
pe_scale = params.pe_scale
num_workers = params.num_workers
checkpoint = params.checkpoint

if __name__ == "__main__":
    # Seed both RNGs for reproducible batching and sampling.
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    print("Initializing logger...")
    logger = SummaryWriter(log_dir=log_dir)

    print("Initializing data loaders...")
    train_dataset = TextMelSpeakerDataset(
        train_filelist_path,
        cmudict_path,
        add_blank,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
    )
    batch_collate = TextMelSpeakerBatchCollate()
    loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        collate_fn=batch_collate,
        drop_last=True,  # keep every step at a full batch
        num_workers=num_workers,
        shuffle=True,
    )
    test_dataset = TextMelSpeakerDataset(
        valid_filelist_path,
        cmudict_path,
        add_blank,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
    )

    print("Initializing model...")
    model = GradTTS(
        nsymbols,
        n_spks,
        spk_emb_dim,
        n_enc_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_enc_layers,
        enc_kernel,
        enc_dropout,
        window_size,
        n_feats,
        dec_dim,
        beta_min,
        beta_max,
        pe_scale,
    ).cuda()
    # Fine-tuning: resume weights from the pretrained checkpoint instead of
    # training from scratch.
    model.load_state_dict(torch.load(checkpoint, map_location=torch.device("cuda")))
    print("Number of encoder parameters = %.2fm" % (model.encoder.nparams / 1e6))
    print("Number of decoder parameters = %.2fm" % (model.decoder.nparams / 1e6))

    print("Initializing optimizer...")
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

    # Log the ground-truth mels once (global_step=0) so generated previews
    # can be compared against them in TensorBoard.
    print("Logging test batch...")
    test_batch = test_dataset.sample_test_batch(size=params.test_size)
    for item in test_batch:
        mel, spk = item["y"], item["spk"]
        i = int(spk.cpu())  # speaker id keys the TensorBoard image group
        logger.add_image(
            f"image_{i}/ground_truth",
            plot_tensor(mel.squeeze()),
            global_step=0,
            dataformats="HWC",
        )
        save_plot(mel.squeeze(), f"{log_dir}/original_{i}.png")

    print("Start training...")
    iteration = 0
    for epoch in range(1, n_epochs + 1):
        # ---- Synthesis preview: render the test batch before each epoch ----
        model.eval()
        print("Synthesis...")
        with torch.no_grad():
            for item in test_batch:
                x = item["x"].to(torch.long).unsqueeze(0).cuda()
                x_lengths = torch.LongTensor([x.shape[-1]]).cuda()
                spk = item["spk"].to(torch.long).cuda()
                i = int(spk.cpu())
                # y_enc: encoder (prior) mel, y_dec: diffusion-refined mel,
                # attn: text-to-mel alignment.
                y_enc, y_dec, attn = model(x, x_lengths, n_timesteps=50, spk=spk)
                logger.add_image(
                    f"image_{i}/generated_enc",
                    plot_tensor(y_enc.squeeze().cpu()),
                    global_step=iteration,
                    dataformats="HWC",
                )
                logger.add_image(
                    f"image_{i}/generated_dec",
                    plot_tensor(y_dec.squeeze().cpu()),
                    global_step=iteration,
                    dataformats="HWC",
                )
                logger.add_image(
                    f"image_{i}/alignment",
                    plot_tensor(attn.squeeze().cpu()),
                    global_step=iteration,
                    dataformats="HWC",
                )
                save_plot(y_enc.squeeze().cpu(), f"{log_dir}/generated_enc_{i}.png")
                save_plot(y_dec.squeeze().cpu(), f"{log_dir}/generated_dec_{i}.png")
                save_plot(attn.squeeze().cpu(), f"{log_dir}/alignment_{i}.png")

        # ---- Training pass over the full dataset ----
        model.train()
        dur_losses = []
        prior_losses = []
        diff_losses = []
        with tqdm(loader, total=len(train_dataset) // batch_size) as progress_bar:
            for batch in progress_bar:
                model.zero_grad()
                x, x_lengths = batch["x"].cuda(), batch["x_lengths"].cuda()
                y, y_lengths = batch["y"].cuda(), batch["y_lengths"].cuda()
                spk = batch["spk"].cuda()
                dur_loss, prior_loss, diff_loss = model.compute_loss(
                    x, x_lengths, y, y_lengths, spk=spk, out_size=out_size
                )
                # Total loss is the unweighted sum of the three components.
                loss = dur_loss + prior_loss + diff_loss
                loss.backward()
                # Clip encoder and decoder separately; the returned values are
                # the pre-clip norms, logged below for monitoring.
                enc_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.encoder.parameters(), max_norm=1
                )
                dec_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.decoder.parameters(), max_norm=1
                )
                optimizer.step()

                logger.add_scalar(
                    "training/duration_loss", dur_loss, global_step=iteration
                )
                logger.add_scalar(
                    "training/prior_loss", prior_loss, global_step=iteration
                )
                logger.add_scalar(
                    "training/diffusion_loss", diff_loss, global_step=iteration
                )
                logger.add_scalar(
                    "training/encoder_grad_norm", enc_grad_norm, global_step=iteration
                )
                logger.add_scalar(
                    "training/decoder_grad_norm", dec_grad_norm, global_step=iteration
                )

                msg = f"Epoch: {epoch}, iteration: {iteration} | dur_loss: {dur_loss.item()}, prior_loss: {prior_loss.item()}, diff_loss: {diff_loss.item()}"
                progress_bar.set_description(msg)

                dur_losses.append(dur_loss.item())
                prior_losses.append(prior_loss.item())
                diff_losses.append(diff_loss.item())
                iteration += 1

        # ---- Per-epoch summary appended to the plain-text training log ----
        msg = "Epoch %d: duration loss = %.3f " % (epoch, np.mean(dur_losses))
        msg += "| prior loss = %.3f " % np.mean(prior_losses)
        msg += "| diffusion loss = %.3f\n" % np.mean(diff_losses)
        with open(f"{log_dir}/train.log", "a") as f:
            f.write(msg)

        # Only checkpoint every `save_every` epochs.
        if epoch % params.save_every > 0:
            continue

        ckpt = model.state_dict()
        torch.save(ckpt, f=f"{log_dir}/grad_{epoch}.pt")