tts-rvc-autopst / solver_1.py
jonathanjordan21's picture
67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f
c021d8e verified
raw
history blame
5.55 kB
import os
import time
import pickle
import datetime
import itertools
import numpy as np
import torch
import torch.nn.functional as F
from onmt_modules.misc import sequence_mask
from model_autopst import Generator_1 as Predictor
class Solver(object):
def __init__(self, data_loader, config, hparams):
"""Initialize configurations."""
self.data_loader = data_loader
self.hparams = hparams
self.gate_threshold = hparams.gate_threshold
self.use_cuda = torch.cuda.is_available()
self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')
self.num_iters = config.num_iters
self.log_step = config.log_step
# Build the model
self.build_model()
def build_model(self):
self.P = Predictor(self.hparams)
self.optimizer = torch.optim.Adam(self.P.parameters(), 0.0001, [0.9, 0.999])
self.P.to(self.device)
self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)
def train(self):
# Set data loader
data_loader = self.data_loader
data_iter = iter(data_loader)
# Print logs in specified order
keys = ['P/loss_tx2sp', 'P/loss_stop_sp']
# Start training.
print('Start training...')
start_time = time.time()
for i in range(self.num_iters):
try:
sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
except:
data_iter = iter(data_loader)
sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
sp_real = sp_real.to(self.device)
cep_real = cep_real.to(self.device)
cd_real = cd_real.to(self.device)
len_real = len_real.to(self.device)
spk_emb = spk_emb.to(self.device)
num_rep_sync = num_rep_sync.to(self.device)
len_short_sync = len_short_sync.to(self.device)
# real spect masks
mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
mask_long = (~mask_sp_real).float()
len_real_mask = torch.min(len_real + 10,
torch.full_like(len_real, sp_real.size(1)))
loss_tx2sp_mask = sequence_mask(len_real_mask, sp_real.size(1)).float().unsqueeze(-1)
# text input masks
codes_mask = sequence_mask(len_short_sync, num_rep_sync.size(1)).float()
# =================================================================================== #
# 2. Train #
# =================================================================================== #
self.P = self.P.train()
sp_real_sft = torch.zeros_like(sp_real)
sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]
spect_pred, stop_pred_sp = self.P(cep_real.transpose(2,1),
mask_long,
codes_mask,
num_rep_sync,
len_short_sync+1,
sp_real_sft.transpose(1,0),
len_real+1,
spk_emb)
loss_tx2sp = (F.mse_loss(spect_pred.permute(1,0,2), sp_real, reduction='none')
* loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())
loss_total = loss_tx2sp + loss_stop_sp
# Backward and optimize
self.optimizer.zero_grad()
loss_total.backward()
self.optimizer.step()
# Logging
loss = {}
loss['P/loss_tx2sp'] = loss_tx2sp.item()
loss['P/loss_stop_sp'] = loss_stop_sp.item()
# =================================================================================== #
# 4. Miscellaneous #
# =================================================================================== #
# Print out training information
if (i+1) % self.log_step == 0:
et = time.time() - start_time
et = str(datetime.timedelta(seconds=et))[:-7]
log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
for tag in keys:
log += ", {}: {:.8f}".format(tag, loss[tag])
print(log)
# Save model checkpoints.
if (i+1) % 10000 == 0:
torch.save({'model': self.P.state_dict(),
'optimizer': self.optimizer.state_dict()}, f'./assets/{i+1}-A.ckpt')
print('Saved model checkpoints into assets ...')