import hydra
from hydra import utils
from itertools import chain
from pathlib import Path
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import CPCDataset_sameSeq as CPCDataset
from scheduler import WarmupScheduler
from model_encoder import Encoder, CPCLoss_sameSeq, Encoder_lf0
from model_decoder import Decoder_ac
from model_encoder import SpeakerEncoder as Encoder_spk
from mi_estimators import CLUBSample_group, CLUBSample_reshape

import apex.amp as amp
import os
import time

torch.manual_seed(137)
np.random.seed(137)


def save_checkpoint(encoder, encoder_lf0, cpc, encoder_spk,
                    cs_mi_net, ps_mi_net, cp_mi_net, decoder,
                    optimizer, optimizer_cs_mi_net, optimizer_ps_mi_net, optimizer_cp_mi_net,
                    scheduler, amp, epoch, checkpoint_dir, cfg):
    if cfg.use_amp:
        amp_state_dict = amp.state_dict()
    else:
        amp_state_dict = None
    checkpoint_state = {
        "encoder": encoder.state_dict(),
        "encoder_lf0": encoder_lf0.state_dict(),
        "cpc": cpc.state_dict(),
        "encoder_spk": encoder_spk.state_dict(),
        "ps_mi_net": ps_mi_net.state_dict(),
        "cp_mi_net": cp_mi_net.state_dict(),
        "cs_mi_net": cs_mi_net.state_dict(),
        "decoder": decoder.state_dict(),
        "optimizer": optimizer.state_dict(),
        "optimizer_cs_mi_net": optimizer_cs_mi_net.state_dict(),
        "optimizer_ps_mi_net": optimizer_ps_mi_net.state_dict(),
        "optimizer_cp_mi_net": optimizer_cp_mi_net.state_dict(),
        "scheduler": scheduler.state_dict(),
        "amp": amp_state_dict,
        "epoch": epoch
    }
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    checkpoint_path = checkpoint_dir / "model.ckpt-{}.pt".format(epoch)
    torch.save(checkpoint_state, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path.stem))


def mi_first_forward(mels, lf0, encoder, encoder_lf0, encoder_spk, cs_mi_net, optimizer_cs_mi_net,
                     ps_mi_net, optimizer_ps_mi_net, cp_mi_net, optimizer_cp_mi_net, cfg):
    # First forward pass: update only the MI estimators. All embeddings are
    # detached, so no gradient reaches the encoders here.
    optimizer_cs_mi_net.zero_grad()
    optimizer_ps_mi_net.zero_grad()
    optimizer_cp_mi_net.zero_grad()
    z, _, _, _, _ = encoder(mels)
    z = z.detach()
    lf0_embs = encoder_lf0(lf0).detach()
    spk_embs = encoder_spk(mels).detach()

    if cfg.use_CSMI:
        # Fit the content-speaker CLUB estimator by maximizing its
        # variational log-likelihood (minimizing the negative).
        lld_cs_loss = -cs_mi_net.loglikeli(spk_embs, z)
        if cfg.use_amp:
            with amp.scale_loss(lld_cs_loss, optimizer_cs_mi_net) as sl:
                sl.backward()
        else:
            lld_cs_loss.backward()
        optimizer_cs_mi_net.step()
    else:
        lld_cs_loss = torch.tensor(0.)

    if cfg.use_CPMI:
        # Average every two consecutive lf0 frames so the pitch sequence
        # matches the time resolution of the content sequence z.
        lld_cp_loss = -cp_mi_net.loglikeli(
            lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0], -1, 2, lf0_embs.shape[-1]).mean(2), z)
        if cfg.use_amp:
            with amp.scale_loss(lld_cp_loss, optimizer_cp_mi_net) as slll:
                slll.backward()
        else:
            lld_cp_loss.backward()
        torch.nn.utils.clip_grad_norm_(cp_mi_net.parameters(), 1)
        optimizer_cp_mi_net.step()
    else:
        lld_cp_loss = torch.tensor(0.)

    if cfg.use_PSMI:
        lld_ps_loss = -ps_mi_net.loglikeli(spk_embs, lf0_embs)
        if cfg.use_amp:
            with amp.scale_loss(lld_ps_loss, optimizer_ps_mi_net) as sll:
                sll.backward()
        else:
            lld_ps_loss.backward()
        optimizer_ps_mi_net.step()
    else:
        lld_ps_loss = torch.tensor(0.)

    return optimizer_cs_mi_net, lld_cs_loss, optimizer_ps_mi_net, lld_ps_loss, optimizer_cp_mi_net, lld_cp_loss
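# A minimal shape sketch of the frame pairing used for cp_mi_net above (an
# assumption inferred from the reshape: the content encoder halves the time
# resolution, so z is (B, T/2, z_dim) while lf0_embs is (B, T, D)):
#
#     B, T, D = lf0_embs.shape
#     pooled = lf0_embs.reshape(B, T // 2, 2, D).mean(2)   # (B, T/2, D)
#
# i.e. every two consecutive pitch frames are averaged so that cp_mi_net
# sees pitch and content sequences of equal length.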
def mi_second_forward(mels, lf0, encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net,
                      cp_mi_net, decoder, cfg, optimizer, scheduler):
    # Second forward pass: update the encoders, CPC module and decoder.
    # The MI estimators are frozen here; their mi_est() outputs enter the
    # loss as penalty terms.
    optimizer.zero_grad()
    z, c, _, vq_loss, perplexity = encoder(mels)
    cpc_loss, accuracy = cpc(z, c)
    spk_embs = encoder_spk(mels)
    lf0_embs = encoder_lf0(lf0)
    recon_loss, pred_mels = decoder(z, lf0_embs, spk_embs, mels.transpose(1, 2))
    loss = recon_loss + cpc_loss + vq_loss

    if cfg.use_CSMI:
        mi_cs_loss = cfg.mi_weight * cs_mi_net.mi_est(spk_embs, z)
    else:
        mi_cs_loss = torch.tensor(0.).to(loss.device)

    if cfg.use_CPMI:
        mi_cp_loss = cfg.mi_weight * cp_mi_net.mi_est(
            lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0], -1, 2, lf0_embs.shape[-1]).mean(2), z)
    else:
        mi_cp_loss = torch.tensor(0.).to(loss.device)

    if cfg.use_PSMI:
        mi_ps_loss = cfg.mi_weight * ps_mi_net.mi_est(spk_embs, lf0_embs)
    else:
        mi_ps_loss = torch.tensor(0.).to(loss.device)

    loss = loss + mi_cs_loss + mi_ps_loss + mi_cp_loss

    if cfg.use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()

    return optimizer, recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, mi_ps_loss, mi_cp_loss


def calculate_eval_loss(mels, lf0,
                        encoder, encoder_lf0, cpc,
                        encoder_spk, cs_mi_net, ps_mi_net,
                        cp_mi_net, decoder, cfg):
    # Compute all training losses on a validation batch, without gradients.
    with torch.no_grad():
        z, c, z_beforeVQ, vq_loss, perplexity = encoder(mels)
        lf0_embs = encoder_lf0(lf0)
        spk_embs = encoder_spk(mels)

        if cfg.use_CSMI:
            lld_cs_loss = -cs_mi_net.loglikeli(spk_embs, z)
            mi_cs_loss = cfg.mi_weight * cs_mi_net.mi_est(spk_embs, z)
        else:
            lld_cs_loss = torch.tensor(0.)
            mi_cs_loss = torch.tensor(0.)

        cpc_loss, accuracy = cpc(z, c)
        recon_loss, pred_mels = decoder(z, lf0_embs, spk_embs, mels.transpose(1, 2))

        if cfg.use_CPMI:
            lf0_embs_pooled = lf0_embs.unsqueeze(1).reshape(
                lf0_embs.shape[0], -1, 2, lf0_embs.shape[-1]).mean(2)
            mi_cp_loss = cfg.mi_weight * cp_mi_net.mi_est(lf0_embs_pooled, z)
            lld_cp_loss = -cp_mi_net.loglikeli(lf0_embs_pooled, z)
        else:
            mi_cp_loss = torch.tensor(0.)
            lld_cp_loss = torch.tensor(0.)

        if cfg.use_PSMI:
            mi_ps_loss = cfg.mi_weight * ps_mi_net.mi_est(spk_embs, lf0_embs)
            lld_ps_loss = -ps_mi_net.loglikeli(spk_embs, lf0_embs)
        else:
            mi_ps_loss = torch.tensor(0.)
            lld_ps_loss = torch.tensor(0.)

    return recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, lld_cs_loss, mi_ps_loss, lld_ps_loss, mi_cp_loss, lld_cp_loss
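# Schematically, the objective minimized in mi_second_forward is
#
#     L = L_recon + L_cpc + L_vq
#         + mi_weight * ( I_CLUB(spk, z) + I_CLUB(pitch, z) + I_CLUB(spk, pitch) )
#
# where each I_CLUB term is the sample-based CLUB upper bound on mutual
# information produced by the corresponding *_mi_net. Minimizing an upper
# bound pushes the speaker, content and pitch representations toward mutual
# independence, while mi_first_forward keeps the bounds tight by refitting
# the estimators between encoder updates.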
def to_eval(all_models):
    for m in all_models:
        m.eval()


def to_train(all_models):
    for m in all_models:
        m.train()


def eval_model(epoch, checkpoint_dir, device, valid_dataloader, encoder, encoder_lf0, cpc,
               encoder_spk, cs_mi_net, ps_mi_net, cp_mi_net, decoder, cfg):
    stime = time.time()
    average_cpc_loss = average_vq_loss = average_perplexity = average_recon_loss = 0
    average_accuracies = np.zeros(cfg.training.n_prediction_steps)
    average_lld_cs_loss = average_mi_cs_loss = average_lld_ps_loss = average_mi_ps_loss = average_lld_cp_loss = average_mi_cp_loss = 0
    all_models = [encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net, cp_mi_net, decoder]
    to_eval(all_models)

    for i, (mels, lf0, speakers) in enumerate(valid_dataloader, 1):
        lf0 = lf0.to(device)
        mels = mels.to(device)  # (bs, 80, 128)
        recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, lld_cs_loss, mi_ps_loss, lld_ps_loss, mi_cp_loss, lld_cp_loss = \
            calculate_eval_loss(mels, lf0,
                                encoder, encoder_lf0, cpc,
                                encoder_spk, cs_mi_net, ps_mi_net,
                                cp_mi_net, decoder, cfg)

        # Incremental (running) means over the validation batches.
        average_recon_loss += (recon_loss.item() - average_recon_loss) / i
        average_cpc_loss += (cpc_loss.item() - average_cpc_loss) / i
        average_vq_loss += (vq_loss.item() - average_vq_loss) / i
        average_perplexity += (perplexity.item() - average_perplexity) / i
        average_accuracies += (np.array(accuracy) - average_accuracies) / i
        average_lld_cs_loss += (lld_cs_loss.item() - average_lld_cs_loss) / i
        average_mi_cs_loss += (mi_cs_loss.item() - average_mi_cs_loss) / i
        average_lld_ps_loss += (lld_ps_loss.item() - average_lld_ps_loss) / i
        average_mi_ps_loss += (mi_ps_loss.item() - average_mi_ps_loss) / i
        average_lld_cp_loss += (lld_cp_loss.item() - average_lld_cp_loss) / i
        average_mi_cp_loss += (mi_cp_loss.item() - average_mi_cp_loss) / i

    ctime = time.time()
    print("Eval | epoch:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
          .format(epoch, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity,
                  average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss,
                  average_lld_cp_loss, average_mi_cp_loss, ctime - stime))
    print(100 * average_accuracies)

    with open(f'{str(checkpoint_dir)}/results.txt', 'a') as results_txt:
        results_txt.write("Eval | epoch:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}"
                          .format(epoch, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity,
                                  average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss,
                                  average_lld_cp_loss, average_mi_cp_loss) + '\n')
        results_txt.write(' '.join([str(cpc_acc) for cpc_acc in average_accuracies]) + '\n')

    to_train(all_models)
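# Note on the running averages used above and in the training loop: with i
# starting at 1,
#
#     avg += (x_i - avg) / i
#
# is the standard incremental-mean update, equal to (x_1 + ... + x_i) / i
# without storing the history (e.g. for x = 2, 4, 6 it yields 2, 3, 4 in turn).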
@hydra.main(config_path="config/train.yaml")
def train_model(cfg):
    cfg.checkpoint_dir = f'{cfg.checkpoint_dir}/useCSMI{cfg.use_CSMI}_useCPMI{cfg.use_CPMI}_usePSMI{cfg.use_PSMI}_useAmp{cfg.use_amp}'
    if cfg.encoder_lf0_type == 'no_emb':  # default
        dim_lf0 = 1
    else:
        dim_lf0 = 64

    checkpoint_dir = Path(utils.to_absolute_path(cfg.checkpoint_dir))
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # define model
    encoder = Encoder(**cfg.model.encoder)
    encoder_lf0 = Encoder_lf0(cfg.encoder_lf0_type)
    cpc = CPCLoss_sameSeq(**cfg.model.cpc)
    encoder_spk = Encoder_spk()
    cs_mi_net = CLUBSample_group(256, cfg.model.encoder.z_dim, 512)
    ps_mi_net = CLUBSample_group(256, dim_lf0, 512)
    cp_mi_net = CLUBSample_reshape(dim_lf0, cfg.model.encoder.z_dim, 512)
    decoder = Decoder_ac(dim_neck=cfg.model.encoder.z_dim, dim_lf0=dim_lf0, use_l1_loss=True)
    encoder.to(device)
    cpc.to(device)
    encoder_lf0.to(device)
    encoder_spk.to(device)
    cs_mi_net.to(device)
    ps_mi_net.to(device)
    cp_mi_net.to(device)
    decoder.to(device)

    optimizer = optim.Adam(
        chain(encoder.parameters(), encoder_lf0.parameters(), cpc.parameters(),
              encoder_spk.parameters(), decoder.parameters()),
        lr=cfg.training.scheduler.initial_lr)
    optimizer_cs_mi_net = optim.Adam(cs_mi_net.parameters(), lr=cfg.mi_lr)
    optimizer_ps_mi_net = optim.Adam(ps_mi_net.parameters(), lr=cfg.mi_lr)
    optimizer_cp_mi_net = optim.Adam(cp_mi_net.parameters(), lr=cfg.mi_lr)

    # TODO: use_amp is set default to True to speed up training; no-amp -> more stable training? => need to be verified
    if cfg.use_amp:
        [encoder, encoder_lf0, cpc, encoder_spk, decoder], optimizer = amp.initialize(
            [encoder, encoder_lf0, cpc, encoder_spk, decoder], optimizer, opt_level='O1')
        [cs_mi_net], optimizer_cs_mi_net = amp.initialize([cs_mi_net], optimizer_cs_mi_net, opt_level='O1')
        [ps_mi_net], optimizer_ps_mi_net = amp.initialize([ps_mi_net], optimizer_ps_mi_net, opt_level='O1')
        [cp_mi_net], optimizer_cp_mi_net = amp.initialize([cp_mi_net], optimizer_cp_mi_net, opt_level='O1')

    root_path = Path(utils.to_absolute_path("data"))
    dataset = CPCDataset(
        root=root_path,
        n_sample_frames=cfg.training.sample_frames,  # 128
        mode='train')
    valid_dataset = CPCDataset(
        root=root_path,
        n_sample_frames=cfg.training.sample_frames,  # 128
        mode='valid')

    # Warm up for roughly 2000 optimizer steps, expressed in epochs.
    warmup_epochs = 2000 // (len(dataset) // cfg.training.batch_size)
    print('warmup_epochs:', warmup_epochs)
    scheduler = WarmupScheduler(
        optimizer,
        warmup_epochs=warmup_epochs,
        initial_lr=cfg.training.scheduler.initial_lr,
        max_lr=cfg.training.scheduler.max_lr,
        milestones=cfg.training.scheduler.milestones,
        gamma=cfg.training.scheduler.gamma)

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.training.batch_size,  # 256
        shuffle=True,
        num_workers=cfg.training.n_workers,
        pin_memory=True,
        drop_last=False)
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=cfg.training.batch_size,  # 256
        shuffle=False,
        num_workers=cfg.training.n_workers,
        pin_memory=True,
        drop_last=False)

    if cfg.resume:
        print("Resume checkpoint from: {}:".format(cfg.resume))
        resume_path = utils.to_absolute_path(cfg.resume)
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(checkpoint["encoder"])
        encoder_lf0.load_state_dict(checkpoint["encoder_lf0"])
        cpc.load_state_dict(checkpoint["cpc"])
        encoder_spk.load_state_dict(checkpoint["encoder_spk"])
        cs_mi_net.load_state_dict(checkpoint["cs_mi_net"])
        ps_mi_net.load_state_dict(checkpoint["ps_mi_net"])
        if cfg.use_CPMI:
            cp_mi_net.load_state_dict(checkpoint["cp_mi_net"])
        decoder.load_state_dict(checkpoint["decoder"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        optimizer_cs_mi_net.load_state_dict(checkpoint["optimizer_cs_mi_net"])
        optimizer_ps_mi_net.load_state_dict(checkpoint["optimizer_ps_mi_net"])
        optimizer_cp_mi_net.load_state_dict(checkpoint["optimizer_cp_mi_net"])
        if cfg.use_amp:
            amp.load_state_dict(checkpoint["amp"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        start_epoch = checkpoint["epoch"]
    else:
        start_epoch = 1
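    # For reference, the warmup_epochs arithmetic computed above (numbers are
    # hypothetical): with 25,600 training samples and batch_size 256 there are
    # 100 steps per epoch, so warmup_epochs = 2000 // 100 = 20, i.e. the
    # learning rate ramps from initial_lr to max_lr over the first 20 epochs.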
    if os.path.exists(f'{str(checkpoint_dir)}/results.txt'):
        wmode = 'a'
    else:
        wmode = 'w'
    with open(f'{str(checkpoint_dir)}/results.txt', wmode) as results_txt:
        results_txt.write('save training info...\n')

    global_step = 0
    stime = time.time()
    for epoch in range(start_epoch, cfg.training.n_epochs + 1):
        average_cpc_loss = average_vq_loss = average_perplexity = average_recon_loss = 0
        average_accuracies = np.zeros(cfg.training.n_prediction_steps)
        average_lld_cs_loss = average_mi_cs_loss = average_lld_ps_loss = average_mi_ps_loss = average_lld_cp_loss = average_mi_cp_loss = 0

        for i, (mels, lf0, speakers) in enumerate(dataloader, 1):
            lf0 = lf0.to(device)
            mels = mels.to(device)  # (bs, 80, 128)

            if cfg.use_CSMI or cfg.use_CPMI or cfg.use_PSMI:
                # Train the MI estimators for mi_iters steps per batch.
                for j in range(cfg.mi_iters):
                    optimizer_cs_mi_net, lld_cs_loss, optimizer_ps_mi_net, lld_ps_loss, optimizer_cp_mi_net, lld_cp_loss = \
                        mi_first_forward(mels, lf0, encoder, encoder_lf0, encoder_spk,
                                         cs_mi_net, optimizer_cs_mi_net,
                                         ps_mi_net, optimizer_ps_mi_net,
                                         cp_mi_net, optimizer_cp_mi_net, cfg)
            else:
                lld_cs_loss = torch.tensor(0.)
                lld_ps_loss = torch.tensor(0.)
                lld_cp_loss = torch.tensor(0.)

            optimizer, recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, mi_ps_loss, mi_cp_loss = \
                mi_second_forward(mels, lf0,
                                  encoder, encoder_lf0, cpc,
                                  encoder_spk, cs_mi_net, ps_mi_net,
                                  cp_mi_net, decoder, cfg,
                                  optimizer, scheduler)

            average_recon_loss += (recon_loss.item() - average_recon_loss) / i
            average_cpc_loss += (cpc_loss.item() - average_cpc_loss) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i
            average_accuracies += (np.array(accuracy) - average_accuracies) / i
            average_lld_cs_loss += (lld_cs_loss.item() - average_lld_cs_loss) / i
            average_mi_cs_loss += (mi_cs_loss.item() - average_mi_cs_loss) / i
            average_lld_ps_loss += (lld_ps_loss.item() - average_lld_ps_loss) / i
            average_mi_ps_loss += (mi_ps_loss.item() - average_mi_ps_loss) / i
            average_lld_cp_loss += (lld_cp_loss.item() - average_lld_cp_loss) / i
            average_mi_cp_loss += (mi_cp_loss.item() - average_mi_cp_loss) / i

            ctime = time.time()
            print("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
                  .format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss,
                          average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss,
                          average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss, ctime - stime))
            print(100 * average_accuracies)
            stime = time.time()
            global_step += 1
            # scheduler.step()

        with open(f'{str(checkpoint_dir)}/results.txt', 'a') as results_txt:
            results_txt.write("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}"
                              .format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss,
                                      average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss,
                                      average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss) + '\n')
            results_txt.write(' '.join([str(cpc_acc) for cpc_acc in average_accuracies]) + '\n')

        # Step the LR scheduler once per epoch.
        scheduler.step()
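        # Evaluation and checkpointing below run every log_interval and
        # checkpoint_interval epochs respectively; both skip the first
        # (or resumed) epoch via the `epoch != start_epoch` guard.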
        if epoch % cfg.training.log_interval == 0 and epoch != start_epoch:
            eval_model(epoch, checkpoint_dir, device, valid_dataloader,
                       encoder, encoder_lf0, cpc, encoder_spk,
                       cs_mi_net, ps_mi_net, cp_mi_net, decoder, cfg)

            ctime = time.time()
            print("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
                  .format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss,
                          average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss,
                          average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss, ctime - stime))
            print(100 * average_accuracies)
            stime = time.time()

        if epoch % cfg.training.checkpoint_interval == 0 and epoch != start_epoch:
            save_checkpoint(encoder, encoder_lf0, cpc, encoder_spk,
                            cs_mi_net, ps_mi_net, cp_mi_net, decoder,
                            optimizer, optimizer_cs_mi_net, optimizer_ps_mi_net, optimizer_cp_mi_net,
                            scheduler, amp, epoch, checkpoint_dir, cfg)


if __name__ == "__main__":
    train_model()
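# Hydra reads config/train.yaml, and any config entry can be overridden on
# the command line. A hypothetical invocation (the script name and the exact
# config keys are assumptions based on the fields accessed above):
#
#     python train.py use_CSMI=true use_CPMI=true use_PSMI=false \
#         training.batch_size=256 checkpoint_dir=checkpoints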