# 94insane's picture
# 230817,V0.1
# 320e69e
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import hparams as hp
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=hp.train_visible_devices
import numpy as np
import argparse
import time
from fastspeech2 import FastSpeech2
from loss import FastSpeech2Loss
from dataset import Dataset
from optimizer import ScheduledOptim
from evaluate import evaluate
import utils
import audio as Audio
def _restore_checkpoint(model, optimizer, checkpoint_path, restore_step, device):
    """Try to load ``checkpoint_<restore_step>.pth.tar`` into model and optimizer.

    Returns True when the checkpoint was restored, False when training starts
    from the current (fresh) weights.
    """
    ckpt_file = os.path.join(
        checkpoint_path, 'checkpoint_{}.pth.tar'.format(restore_step))
    try:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host (and vice versa).
        checkpoint = torch.load(ckpt_file, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    except Exception as err:
        # Best-effort restore is kept, but KeyboardInterrupt/SystemExit are no
        # longer swallowed, and any failure other than "no such file" is
        # reported instead of being silently mistaken for a fresh run.
        if not isinstance(err, FileNotFoundError):
            print("Could not restore checkpoint {}: {!r}".format(ckpt_file, err))
        print("\n---Start New Training---\n")
        return False
    print("\n---Model Restored at Step {}---\n".format(restore_step))
    return True


def _load_stat(filename, device):
    """Load a two-entry ``.npy`` file from ``hp.preprocessed_path``.

    Returns both entries as float tensors on ``device``, reshaped to (1, -1).
    """
    first, second = torch.tensor(
        np.load(os.path.join(hp.preprocessed_path, filename)),
        dtype=torch.float).to(device)
    return first.reshape(1, -1), second.reshape(1, -1)


def _append_loss_files(log_path, named_values):
    """Append each value as one text line to ``<log_path>/<filename>``."""
    for filename, value in named_values:
        with open(os.path.join(log_path, filename), "a") as loss_file:
            loss_file.write(str(value) + "\n")


def main(args):
    """Train FastSpeech2, with periodic logging, checkpointing and evaluation.

    Args:
        args: namespace providing ``restore_step`` (int). It selects the
            checkpoint file to restore and offsets the global step counter.
    """
    from tqdm import tqdm

    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    # Each item yielded by the loader is itself iterated as a sequence of
    # sub-batches below; presumably dataset.collate_fn splits the
    # hp.batch_size**2 samples into hp.batch_size sub-batches -- confirm there.
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=True,
                        collate_fn=dataset.collate_fn, drop_last=True,
                        num_workers=0)
    print(len(loader))

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas,
                                 eps=hp.eps, weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(
        optimizer, hp.decoder_hidden, hp.n_warm_up_step, args.restore_step)
    criterion = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path
    _restore_checkpoint(model, optimizer, checkpoint_path,
                        args.restore_step, device)
    os.makedirs(checkpoint_path, exist_ok=True)

    # read params
    # NOTE(review): these statistics are not referenced again in this
    # function; they are still loaded so that behavior (including failing
    # early when a stat file is missing) is unchanged.
    mean_mel, std_mel = _load_stat("mel_stat.npy", device)
    mean_f0, std_f0 = _load_stat("f0_stat.npy", device)
    mean_energy, std_energy = _load_stat("energy_stat.npy", device)

    # Load vocoder
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
        vocoder.to(device)
    else:
        vocoder = None

    # Init logger
    log_path = hp.log_path
    train_log_dir = os.path.join(log_path, 'train')
    val_log_dir = os.path.join(log_path, 'validation')
    # exist_ok makes this safe when the directories (or only some of them)
    # already exist.
    os.makedirs(train_log_dir, exist_ok=True)
    os.makedirs(val_log_dir, exist_ok=True)
    train_logger = SummaryWriter(train_log_dir)
    val_logger = SummaryWriter(val_log_dir)

    # Define Some Information
    step_times = []  # wall-clock seconds of recent completed steps
    start = time.perf_counter()

    # Training
    model = model.train()
    steps_per_epoch = len(loader) * hp.batch_size
    # NOTE: total_step does not include args.restore_step, so the reported
    # step can exceed it (and the time estimate go negative) when resuming.
    total_step = hp.epochs * steps_per_epoch

    for epoch in tqdm(range(hp.epochs)):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = (i * hp.batch_size + j + args.restore_step
                                + epoch * steps_per_epoch + 1)

                # Get Data
                text = torch.from_numpy(data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                (mel_output, mel_postnet_output, log_duration_output,
                 f0_output, energy_output, src_mask, mel_mask, _) = model(
                    text, src_len, mel_len, D, f0, energy,
                    max_src_len, max_mel_len)

                # Cal Loss (the loss receives the inverted masks)
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = criterion(
                    log_duration_output, log_D, f0_output, f0,
                    energy_output, energy, mel_output, mel_postnet_output,
                    mel_target, ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger: every sub-batch appends one line per loss file.
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                _append_loss_files(log_path, (
                    ("total_loss.txt", t_l),
                    ("mel_loss.txt", m_l),
                    ("mel_postnet_loss.txt", m_p_l),
                    ("duration_loss.txt", d_l),
                    ("f0_loss.txt", f_l),
                    ("energy_loss.txt", e_l),
                ))

                # Backward: gradients accumulate over hp.acc_steps sub-batches,
                # so everything below runs only on steps that are a multiple
                # of hp.acc_steps.
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    now = time.perf_counter()
                    # No step has been timed yet on the very first log line;
                    # report nan without triggering numpy's empty-mean warning.
                    if step_times:
                        mean_step_time = float(np.mean(step_times))
                    else:
                        mean_step_time = float("nan")

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (now - start), (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    # NOTE(review): the original indentation was lost; scalars
                    # are written at the log_step cadence here -- confirm this
                    # matches the intended nesting.
                    train_logger.add_scalar('Loss/total_loss', t_l, current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l, current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l, current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        os.path.join(checkpoint_path,
                                     'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step, vocoder)
                    t_l = d_l + f_l + e_l + m_l + m_p_l

                    val_logger.add_scalar('Loss/total_loss', t_l, current_step)
                    val_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step)
                    val_logger.add_scalar('Loss/duration_loss', d_l, current_step)
                    val_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    val_logger.add_scalar('Loss/energy_loss', e_l, current_step)

                    model.train()

                end_time = time.perf_counter()
                step_times.append(end_time - start_time)
                if len(step_times) == hp.clear_Time:
                    # Collapse the buffer to its mean so it stays bounded.
                    step_times = [float(np.mean(step_times))]

    # Flush pending TensorBoard events before exiting.
    train_logger.close()
    val_logger.close()
if __name__ == "__main__":
    # Script entry point: read the step to resume from off the command line
    # and hand the parsed namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('--restore_step', default=99999, type=int)
    main(cli.parse_args())