# VQMIVC/train.py
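# Training script for VQMIVC: speech is disentangled into content, speaker and
# pitch representations and a decoder reconstructs the mel-spectrogram.
# It jointly trains:
#   - a content encoder with vector quantization (Encoder) plus a CPC loss (CPCLoss_sameSeq),
#   - a pitch encoder (Encoder_lf0) and a speaker encoder (SpeakerEncoder),
#   - an acoustic decoder (Decoder_ac) that reconstructs mel-spectrograms,
#   - CLUB-based mutual-information estimators (CLUBSample_*) used to reduce the MI
#     between the content, speaker and pitch representations.
# Mixed-precision training via apex.amp is optional (cfg.use_amp).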
import hydra
from hydra import utils
from itertools import chain
from pathlib import Path
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import CPCDataset_sameSeq as CPCDataset
from scheduler import WarmupScheduler
from model_encoder import Encoder, CPCLoss_sameSeq, Encoder_lf0
from model_decoder import Decoder_ac
from model_encoder import SpeakerEncoder as Encoder_spk
from mi_estimators import CLUBSample_group, CLUBSample_reshape
import apex.amp as amp
import os
import time
torch.manual_seed(137)
np.random.seed(137)
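
# Save the state of every module, optimizer and the LR scheduler (plus the apex amp
# state when mixed precision is enabled) to checkpoint_dir/model.ckpt-<epoch>.pt.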
def save_checkpoint(encoder, encoder_lf0, cpc, encoder_spk, \
cs_mi_net, ps_mi_net, cp_mi_net, decoder, \
optimizer, optimizer_cs_mi_net, optimizer_ps_mi_net, optimizer_cp_mi_net, scheduler, amp, epoch, checkpoint_dir, cfg):
if cfg.use_amp:
amp_state_dict = amp.state_dict()
else:
amp_state_dict = None
checkpoint_state = {
"encoder": encoder.state_dict(),
"encoder_lf0": encoder_lf0.state_dict(),
"cpc": cpc.state_dict(),
"encoder_spk": encoder_spk.state_dict(),
"ps_mi_net": ps_mi_net.state_dict(),
"cp_mi_net": cp_mi_net.state_dict(),
"cs_mi_net": cs_mi_net.state_dict(),
"decoder": decoder.state_dict(),
"optimizer": optimizer.state_dict(),
"optimizer_cs_mi_net": optimizer_cs_mi_net.state_dict(),
"optimizer_ps_mi_net": optimizer_ps_mi_net.state_dict(),
"optimizer_cp_mi_net": optimizer_cp_mi_net.state_dict(),
"scheduler": scheduler.state_dict(),
"amp": amp_state_dict,
"epoch": epoch
}
checkpoint_dir.mkdir(exist_ok=True, parents=True)
checkpoint_path = checkpoint_dir / "model.ckpt-{}.pt".format(epoch)
torch.save(checkpoint_state, checkpoint_path)
print("Saved checkpoint: {}".format(checkpoint_path.stem))
def mi_first_forward(mels, lf0, encoder, encoder_lf0, encoder_spk, cs_mi_net, optimizer_cs_mi_net,
ps_mi_net, optimizer_ps_mi_net, cp_mi_net, optimizer_cp_mi_net, cfg):
optimizer_cs_mi_net.zero_grad()
optimizer_ps_mi_net.zero_grad()
optimizer_cp_mi_net.zero_grad()
z, _, _, _, _ = encoder(mels)
z = z.detach()
lf0_embs = encoder_lf0(lf0).detach()
spk_embs = encoder_spk(mels).detach()
if cfg.use_CSMI:
lld_cs_loss = -cs_mi_net.loglikeli(spk_embs, z)
if cfg.use_amp:
with amp.scale_loss(lld_cs_loss, optimizer_cs_mi_net) as sl:
sl.backward()
else:
lld_cs_loss.backward()
optimizer_cs_mi_net.step()
else:
lld_cs_loss = torch.tensor(0.)
if cfg.use_CPMI:
lld_cp_loss = -cp_mi_net.loglikeli(lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0],-1,2,lf0_embs.shape[-1]).mean(2), z)
if cfg.use_amp:
with amp.scale_loss(lld_cp_loss, optimizer_cp_mi_net) as slll:
slll.backward()
else:
lld_cp_loss.backward()
torch.nn.utils.clip_grad_norm_(cp_mi_net.parameters(), 1)
optimizer_cp_mi_net.step()
else:
lld_cp_loss = torch.tensor(0.)
if cfg.use_PSMI:
lld_ps_loss = -ps_mi_net.loglikeli(spk_embs, lf0_embs)
if cfg.use_amp:
with amp.scale_loss(lld_ps_loss, optimizer_ps_mi_net) as sll:
sll.backward()
else:
lld_ps_loss.backward()
optimizer_ps_mi_net.step()
else:
lld_ps_loss = torch.tensor(0.)
return optimizer_cs_mi_net, lld_cs_loss, optimizer_ps_mi_net, lld_ps_loss, optimizer_cp_mi_net, lld_cp_loss
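
# MI step 2: update the main networks. The total loss is
#   reconstruction + CPC + VQ + mi_weight * (CLUB MI upper-bound estimates),
# where the MI terms come from the estimators updated in mi_first_forward;
# only the main optimizer is stepped here.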
def mi_second_forward(mels, lf0, encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net, cp_mi_net, decoder, cfg, optimizer, scheduler):
optimizer.zero_grad()
z, c, _, vq_loss, perplexity = encoder(mels)
cpc_loss, accuracy = cpc(z, c)
spk_embs = encoder_spk(mels)
lf0_embs = encoder_lf0(lf0)
recon_loss, pred_mels = decoder(z, lf0_embs, spk_embs, mels.transpose(1,2))
loss = recon_loss + cpc_loss + vq_loss
if cfg.use_CSMI:
mi_cs_loss = cfg.mi_weight*cs_mi_net.mi_est(spk_embs, z)
else:
mi_cs_loss = torch.tensor(0.).to(loss.device)
if cfg.use_CPMI:
mi_cp_loss = cfg.mi_weight*cp_mi_net.mi_est(lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0],-1,2,lf0_embs.shape[-1]).mean(2), z)
else:
mi_cp_loss = torch.tensor(0.).to(loss.device)
if cfg.use_PSMI:
mi_ps_loss = cfg.mi_weight*ps_mi_net.mi_est(spk_embs, lf0_embs)
else:
mi_ps_loss = torch.tensor(0.).to(loss.device)
loss = loss + mi_cs_loss + mi_ps_loss + mi_cp_loss
if cfg.use_amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
return optimizer, recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, mi_ps_loss, mi_cp_loss
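
# Compute all training losses on a validation batch inside torch.no_grad();
# no parameters are updated.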
def calculate_eval_loss(mels, lf0, \
encoder, encoder_lf0, cpc, \
encoder_spk, cs_mi_net, ps_mi_net, \
cp_mi_net, decoder, cfg):
with torch.no_grad():
z, c, z_beforeVQ, vq_loss, perplexity = encoder(mels)
lf0_embs = encoder_lf0(lf0)
spk_embs = encoder_spk(mels)
if cfg.use_CSMI:
lld_cs_loss = -cs_mi_net.loglikeli(spk_embs, z)
mi_cs_loss = cfg.mi_weight*cs_mi_net.mi_est(spk_embs, z)
else:
lld_cs_loss = torch.tensor(0.)
mi_cs_loss = torch.tensor(0.)
# z, c, z_beforeVQ, vq_loss, perplexity = encoder(mels)
cpc_loss, accuracy = cpc(z, c)
recon_loss, pred_mels = decoder(z, lf0_embs, spk_embs, mels.transpose(1,2))
if cfg.use_CPMI:
mi_cp_loss = cfg.mi_weight*cp_mi_net.mi_est(lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0],-1,2,lf0_embs.shape[-1]).mean(2), z)
lld_cp_loss = -cp_mi_net.loglikeli(lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0],-1,2,lf0_embs.shape[-1]).mean(2), z)
else:
mi_cp_loss = torch.tensor(0.)
lld_cp_loss = torch.tensor(0.)
if cfg.use_PSMI:
mi_ps_loss = cfg.mi_weight*ps_mi_net.mi_est(spk_embs, lf0_embs)
lld_ps_loss = -ps_mi_net.loglikeli(spk_embs, lf0_embs)
else:
mi_ps_loss = torch.tensor(0.)
lld_ps_loss = torch.tensor(0.)
return recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, lld_cs_loss, mi_ps_loss, lld_ps_loss, mi_cp_loss, lld_cp_loss
def to_eval(all_models):
for m in all_models:
m.eval()
def to_train(all_models):
for m in all_models:
m.train()
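
# Run validation over valid_dataloader, print running-mean losses and CPC accuracies,
# append them to <checkpoint_dir>/results.txt, then switch the models back to train mode.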
def eval_model(epoch, checkpoint_dir, device, valid_dataloader, encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net, cp_mi_net, decoder, cfg):
stime = time.time()
average_cpc_loss = average_vq_loss = average_perplexity = average_recon_loss = 0
average_accuracies = np.zeros(cfg.training.n_prediction_steps)
average_lld_cs_loss = average_mi_cs_loss = average_lld_ps_loss = average_mi_ps_loss = average_lld_cp_loss = average_mi_cp_loss = 0
all_models = [encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net, cp_mi_net, decoder]
to_eval(all_models)
for i, (mels, lf0, speakers) in enumerate(valid_dataloader, 1):
lf0 = lf0.to(device)
mels = mels.to(device) # (bs, 80, 128)
recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, lld_cs_loss, mi_ps_loss, lld_ps_loss, mi_cp_loss, lld_cp_loss = \
calculate_eval_loss(mels, lf0, \
encoder, encoder_lf0, cpc, \
encoder_spk, cs_mi_net, ps_mi_net, \
cp_mi_net, decoder, cfg)
average_recon_loss += (recon_loss.item() - average_recon_loss) / i
average_cpc_loss += (cpc_loss.item() - average_cpc_loss) / i
average_vq_loss += (vq_loss.item() - average_vq_loss) / i
average_perplexity += (perplexity.item() - average_perplexity) / i
average_accuracies += (np.array(accuracy) - average_accuracies) / i
average_lld_cs_loss += (lld_cs_loss.item() - average_lld_cs_loss) / i
average_mi_cs_loss += (mi_cs_loss.item() - average_mi_cs_loss) / i
average_lld_ps_loss += (lld_ps_loss.item() - average_lld_ps_loss) / i
average_mi_ps_loss += (mi_ps_loss.item() - average_mi_ps_loss) / i
average_lld_cp_loss += (lld_cp_loss.item() - average_lld_cp_loss) / i
average_mi_cp_loss += (mi_cp_loss.item() - average_mi_cp_loss) / i
ctime = time.time()
print("Eval | epoch:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perpexlity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
.format(epoch, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss, ctime-stime))
print(100 * average_accuracies)
results_txt = open(f'{str(checkpoint_dir)}/results.txt', 'a')
results_txt.write("Eval | epoch:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perpexlity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}"
.format(epoch, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss)+'\n')
results_txt.write(' '.join([str(cpc_acc) for cpc_acc in average_accuracies])+'\n')
results_txt.close()
to_train(all_models)
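
# Entry point. Hydra loads the configuration from config/train.yaml, and all cfg.*
# values below come from that file (or command-line overrides).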
@hydra.main(config_path="config/train.yaml")
def train_model(cfg):
cfg.checkpoint_dir = f'{cfg.checkpoint_dir}/useCSMI{cfg.use_CSMI}_useCPMI{cfg.use_CPMI}_usePSMI{cfg.use_PSMI}_useAmp{cfg.use_amp}'
if cfg.encoder_lf0_type == 'no_emb': # default
dim_lf0 = 1
else:
dim_lf0 = 64
checkpoint_dir = Path(utils.to_absolute_path(cfg.checkpoint_dir))
checkpoint_dir.mkdir(exist_ok=True, parents=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# define model
encoder = Encoder(**cfg.model.encoder)
encoder_lf0 = Encoder_lf0(cfg.encoder_lf0_type)
cpc = CPCLoss_sameSeq(**cfg.model.cpc)
encoder_spk = Encoder_spk()
cs_mi_net = CLUBSample_group(256, cfg.model.encoder.z_dim, 512)
ps_mi_net = CLUBSample_group(256, dim_lf0, 512)
cp_mi_net = CLUBSample_reshape(dim_lf0, cfg.model.encoder.z_dim, 512)
decoder = Decoder_ac(dim_neck=cfg.model.encoder.z_dim, dim_lf0=dim_lf0, use_l1_loss=True)
encoder.to(device)
cpc.to(device)
encoder_lf0.to(device)
encoder_spk.to(device)
cs_mi_net.to(device)
ps_mi_net.to(device)
cp_mi_net.to(device)
decoder.to(device)
optimizer = optim.Adam(
chain(encoder.parameters(), encoder_lf0.parameters(), cpc.parameters(), encoder_spk.parameters(), decoder.parameters()),
lr=cfg.training.scheduler.initial_lr)
optimizer_cs_mi_net = optim.Adam(cs_mi_net.parameters(), lr=cfg.mi_lr)
optimizer_ps_mi_net = optim.Adam(ps_mi_net.parameters(), lr=cfg.mi_lr)
optimizer_cp_mi_net = optim.Adam(cp_mi_net.parameters(), lr=cfg.mi_lr)
# TODO: use_amp is set default to True to speed up training; no-amp -> more stable training? => need to be verified
if cfg.use_amp:
[encoder, encoder_lf0, cpc, encoder_spk, decoder], optimizer = amp.initialize([encoder, encoder_lf0, cpc, encoder_spk, decoder], optimizer, opt_level='O1')
[cs_mi_net], optimizer_cs_mi_net = amp.initialize([cs_mi_net], optimizer_cs_mi_net, opt_level='O1')
[ps_mi_net], optimizer_ps_mi_net = amp.initialize([ps_mi_net], optimizer_ps_mi_net, opt_level='O1')
[cp_mi_net], optimizer_cp_mi_net = amp.initialize([cp_mi_net], optimizer_cp_mi_net, opt_level='O1')
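    # Datasets of fixed-length mel-spectrogram / log-F0 segments, drawn by
    # CPCDataset_sameSeq; the 'train' and 'valid' splits share the same root.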
root_path = Path(utils.to_absolute_path("data"))
dataset = CPCDataset(
root=root_path,
n_sample_frames=cfg.training.sample_frames, # 128
mode='train')
valid_dataset = CPCDataset(
root=root_path,
n_sample_frames=cfg.training.sample_frames, # 128
mode='valid')
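    # Warm the learning rate up over roughly the first 2000 optimizer steps,
    # converted to a number of epochs from the dataset size and batch size.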
warmup_epochs = 2000 // (len(dataset)//cfg.training.batch_size)
print('warmup_epochs:', warmup_epochs)
scheduler = WarmupScheduler(
optimizer,
warmup_epochs=warmup_epochs,
initial_lr=cfg.training.scheduler.initial_lr,
max_lr=cfg.training.scheduler.max_lr,
milestones=cfg.training.scheduler.milestones,
gamma=cfg.training.scheduler.gamma)
dataloader = DataLoader(
dataset,
batch_size=cfg.training.batch_size, # 256
shuffle=True,
num_workers=cfg.training.n_workers,
pin_memory=True,
drop_last=False)
valid_dataloader = DataLoader(
valid_dataset,
batch_size=cfg.training.batch_size, # 256
shuffle=False,
num_workers=cfg.training.n_workers,
pin_memory=True,
drop_last=False)
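    # Optionally resume: restore every module, optimizer, scheduler (and amp)
    # state from the checkpoint given in cfg.resume.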
if cfg.resume:
print("Resume checkpoint from: {}:".format(cfg.resume))
resume_path = utils.to_absolute_path(cfg.resume)
checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
encoder.load_state_dict(checkpoint["encoder"])
encoder_lf0.load_state_dict(checkpoint["encoder_lf0"])
cpc.load_state_dict(checkpoint["cpc"])
encoder_spk.load_state_dict(checkpoint["encoder_spk"])
cs_mi_net.load_state_dict(checkpoint["cs_mi_net"])
ps_mi_net.load_state_dict(checkpoint["ps_mi_net"])
if cfg.use_CPMI:
cp_mi_net.load_state_dict(checkpoint["cp_mi_net"])
decoder.load_state_dict(checkpoint["decoder"])
optimizer.load_state_dict(checkpoint["optimizer"])
optimizer_cs_mi_net.load_state_dict(checkpoint["optimizer_cs_mi_net"])
optimizer_ps_mi_net.load_state_dict(checkpoint["optimizer_ps_mi_net"])
optimizer_cp_mi_net.load_state_dict(checkpoint["optimizer_cp_mi_net"])
if cfg.use_amp:
amp.load_state_dict(checkpoint["amp"])
scheduler.load_state_dict(checkpoint["scheduler"])
start_epoch = checkpoint["epoch"]
else:
start_epoch = 1
if os.path.exists(f'{str(checkpoint_dir)}/results.txt'):
wmode = 'a'
else:
wmode = 'w'
results_txt = open(f'{str(checkpoint_dir)}/results.txt', wmode)
results_txt.write('save training info...\n')
results_txt.close()
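    # Main training loop. Per-epoch running means of every loss are kept via
    # incremental averaging: avg += (value - avg) / i.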
global_step = 0
stime = time.time()
for epoch in range(start_epoch, cfg.training.n_epochs + 1):
average_cpc_loss = average_vq_loss = average_perplexity = average_recon_loss = 0
average_accuracies = np.zeros(cfg.training.n_prediction_steps)
average_lld_cs_loss = average_mi_cs_loss = average_lld_ps_loss = average_mi_ps_loss = average_lld_cp_loss = average_mi_cp_loss = 0
for i, (mels, lf0, speakers) in enumerate(dataloader, 1):
lf0 = lf0.to(device)
mels = mels.to(device) # (bs, 80, 128)
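            # Step 1: update the CLUB MI estimators (cfg.mi_iters inner iterations)
            # on detached representations.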
if cfg.use_CSMI or cfg.use_CPMI or cfg.use_PSMI:
for j in range(cfg.mi_iters):
optimizer_cs_mi_net, lld_cs_loss, optimizer_ps_mi_net, lld_ps_loss, optimizer_cp_mi_net, lld_cp_loss = mi_first_forward(mels, lf0, encoder, encoder_lf0, encoder_spk, cs_mi_net, optimizer_cs_mi_net, \
ps_mi_net, optimizer_ps_mi_net, cp_mi_net, optimizer_cp_mi_net, cfg)
else:
lld_cs_loss = torch.tensor(0.)
lld_ps_loss = torch.tensor(0.)
lld_cp_loss = torch.tensor(0.)
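            # Step 2: update encoders/decoder with recon + CPC + VQ losses plus the
            # weighted MI upper-bound terms.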
optimizer, recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, mi_ps_loss, mi_cp_loss = mi_second_forward(mels, lf0, \
encoder, encoder_lf0, cpc, \
encoder_spk, cs_mi_net, ps_mi_net, \
cp_mi_net, decoder, cfg, \
optimizer, scheduler)
average_recon_loss += (recon_loss.item() - average_recon_loss) / i
average_cpc_loss += (cpc_loss.item() - average_cpc_loss) / i
average_vq_loss += (vq_loss.item() - average_vq_loss) / i
average_perplexity += (perplexity.item() - average_perplexity) / i
average_accuracies += (np.array(accuracy) - average_accuracies) / i
average_lld_cs_loss += (lld_cs_loss.item() - average_lld_cs_loss) / i
average_mi_cs_loss += (mi_cs_loss.item() - average_mi_cs_loss) / i
average_lld_ps_loss += (lld_ps_loss.item() - average_lld_ps_loss) / i
average_mi_ps_loss += (mi_ps_loss.item() - average_mi_ps_loss) / i
average_lld_cp_loss += (lld_cp_loss.item() - average_lld_cp_loss) / i
average_mi_cp_loss += (mi_cp_loss.item() - average_mi_cp_loss) / i
ctime = time.time()
print("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perpexlity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
.format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss, ctime-stime))
print(100 * average_accuracies)
stime = time.time()
global_step += 1
# scheduler.step()
results_txt = open(f'{str(checkpoint_dir)}/results.txt', 'a')
results_txt.write("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perpexlity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}"
.format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss)+'\n')
results_txt.write(' '.join([str(cpc_acc) for cpc_acc in average_accuracies])+'\n')
results_txt.close()
scheduler.step()
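        # Periodic validation and checkpointing.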
if epoch % cfg.training.log_interval == 0 and epoch != start_epoch:
eval_model(epoch, checkpoint_dir, device, valid_dataloader, encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net, cp_mi_net, decoder, cfg)
ctime = time.time()
print("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perpexlity:{:.3f}, lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
.format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity, average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss, average_lld_cp_loss, average_mi_cp_loss, ctime-stime))
print(100 * average_accuracies)
stime = time.time()
if epoch % cfg.training.checkpoint_interval == 0 and epoch != start_epoch:
save_checkpoint(encoder, encoder_lf0, cpc, encoder_spk, \
cs_mi_net, ps_mi_net, cp_mi_net, decoder, \
optimizer, optimizer_cs_mi_net, optimizer_ps_mi_net, optimizer_cp_mi_net, scheduler, amp, epoch, checkpoint_dir, cfg)
if __name__ == "__main__":
train_model()