# Spaces: Runtime error
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
import numpy as np | |
from tqdm import tqdm | |
import torch | |
from torch.utils.data import DataLoader | |
from torch.utils.tensorboard import SummaryWriter | |
import params | |
from model import GradTTS | |
from data import TextMelDataset, TextMelBatchCollate | |
from utils import plot_tensor, save_plot | |
from text.symbols import symbols | |
# ---------------------------------------------------------------------------
# Hyperparameters, copied out of the project-local ``params`` module so the
# training entry point below can refer to them by short names.  Attributes
# are read from ``params`` in the same order as before.
# ---------------------------------------------------------------------------

# Data: train/validation file lists, CMU dictionary path, blank-token flag.
train_filelist_path, valid_filelist_path, cmudict_path, add_blank = (
    params.train_filelist_path,
    params.valid_filelist_path,
    params.cmudict_path,
    params.add_blank,
)

# Optimisation and bookkeeping.
log_dir, n_epochs, batch_size, out_size, learning_rate, random_seed = (
    params.log_dir,
    params.n_epochs,
    params.batch_size,
    params.out_size,
    params.learning_rate,
    params.seed,
)
# NOTE(review): ``n_workers`` is read here but never used below; the
# DataLoader uses ``num_workers`` (read at the end of this block). Confirm
# which of the two ``params`` is meant to expose.
n_workers = params.n_workers

# Number of input symbols: one extra id when ``add_blank`` is set
# (presumably for the interleaved blank token -- confirm in ``data``).
nsymbols = len(symbols) + (1 if add_blank else 0)

# Text encoder / duration predictor.
(n_enc_channels, filter_channels, filter_channels_dp, n_enc_layers,
 enc_kernel, enc_dropout, n_heads, window_size) = (
    params.n_enc_channels,
    params.filter_channels,
    params.filter_channels_dp,
    params.n_enc_layers,
    params.enc_kernel,
    params.enc_dropout,
    params.n_heads,
    params.window_size,
)

# Mel-spectrogram extraction.
n_feats, n_fft, sample_rate, hop_length, win_length, f_min, f_max = (
    params.n_feats,
    params.n_fft,
    params.sample_rate,
    params.hop_length,
    params.win_length,
    params.f_min,
    params.f_max,
)

# Diffusion decoder.
dec_dim, beta_min, beta_max, pe_scale = (
    params.dec_dim,
    params.beta_min,
    params.beta_max,
    params.pe_scale,
)

# DataLoader worker processes.
num_workers = params.num_workers
def _train_one_epoch(model, optimizer, loader, n_batches, logger, epoch,
                     iteration, out_size, device):
    """Train ``model`` for one pass over ``loader``.

    For every batch this computes the three losses returned by
    ``model.compute_loss``, back-propagates their sum, clips the encoder and
    decoder gradients separately to norm 1, takes an optimizer step, and logs
    the losses and gradient norms to TensorBoard under ``training/*``.

    Args:
        model: the GradTTS model being trained.
        optimizer: optimizer over ``model.parameters()``.
        loader: training DataLoader yielding dicts with ``x``, ``x_lengths``,
            ``y`` and ``y_lengths`` tensors.
        n_batches: batch count shown by the progress bar.
        logger: TensorBoard ``SummaryWriter``.
        epoch: 1-based epoch number (used only for the progress message).
        iteration: global step counter at the start of the epoch.
        out_size: forwarded unchanged to ``model.compute_loss``.
        device: device that batch tensors are moved to.

    Returns:
        ``(iteration, dur_losses, prior_losses, diff_losses)``: the global step
        counter after this epoch and the per-batch loss values as floats.
    """
    model.train()
    dur_losses, prior_losses, diff_losses = [], [], []
    with tqdm(loader, total=n_batches) as progress_bar:
        for batch_idx, batch in enumerate(progress_bar):
            model.zero_grad()
            x, x_lengths = batch['x'].to(device), batch['x_lengths'].to(device)
            y, y_lengths = batch['y'].to(device), batch['y_lengths'].to(device)
            dur_loss, prior_loss, diff_loss = model.compute_loss(x, x_lengths,
                                                                 y, y_lengths,
                                                                 out_size=out_size)
            loss = sum([dur_loss, prior_loss, diff_loss])
            loss.backward()
            enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(),
                                                           max_norm=1)
            dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(),
                                                           max_norm=1)
            optimizer.step()

            # Copy each loss to the host once; every use below reuses the float.
            dur_value = dur_loss.item()
            prior_value = prior_loss.item()
            diff_value = diff_loss.item()

            logger.add_scalar('training/duration_loss', dur_value,
                              global_step=iteration)
            logger.add_scalar('training/prior_loss', prior_value,
                              global_step=iteration)
            logger.add_scalar('training/diffusion_loss', diff_value,
                              global_step=iteration)
            logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,
                              global_step=iteration)
            logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,
                              global_step=iteration)

            dur_losses.append(dur_value)
            prior_losses.append(prior_value)
            diff_losses.append(diff_value)

            # Refresh the progress-bar text every 5 batches.
            if batch_idx % 5 == 0:
                msg = f'Epoch: {epoch}, iteration: {iteration} | dur_loss: {dur_value}, prior_loss: {prior_value}, diff_loss: {diff_value}'
                progress_bar.set_description(msg)
            iteration += 1
    return iteration, dur_losses, prior_losses, diff_losses


def _log_synthesis(model, test_batch, logger, iteration, log_dir, device):
    """Run inference on every test item and log and save the results.

    For each item, the encoder output, decoder output and alignment returned
    by ``model(x, x_lengths, n_timesteps=50)`` are added to TensorBoard as
    ``image_{i}/<name>`` and saved as ``{log_dir}/<name>_{i}.png``.
    """
    model.eval()
    print('Synthesis...')
    with torch.no_grad():
        for i, item in enumerate(test_batch):
            x = item['x'].to(torch.long).unsqueeze(0).to(device)
            x_lengths = torch.LongTensor([x.shape[-1]]).to(device)
            y_enc, y_dec, attn = model(x, x_lengths, n_timesteps=50)
            # Insertion order fixes the logging/saving order (enc, dec, alignment).
            outputs = {
                'generated_enc': y_enc.squeeze().cpu(),
                'generated_dec': y_dec.squeeze().cpu(),
                'alignment': attn.squeeze().cpu(),
            }
            for name, tensor in outputs.items():
                logger.add_image(f'image_{i}/{name}', plot_tensor(tensor),
                                 global_step=iteration, dataformats='HWC')
            for name, tensor in outputs.items():
                save_plot(tensor, f'{log_dir}/{name}_{i}.png')


if __name__ == "__main__":
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    # Use the GPU when one is present (same as the former ``.cuda()`` calls)
    # and fall back to CPU instead of crashing when none is available.
    # NOTE(review): CPU training assumes GradTTS creates its internal tensors on
    # the device of its inputs/parameters -- confirm before relying on it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Initializing logger...')
    logger = SummaryWriter(log_dir=log_dir)
    try:
        print('Initializing data loaders...')
        train_dataset = TextMelDataset(train_filelist_path, cmudict_path, add_blank,
                                       n_fft, n_feats, sample_rate, hop_length,
                                       win_length, f_min, f_max)
        batch_collate = TextMelBatchCollate()
        # NOTE(review): with shuffle=False the loader yields batches in the
        # dataset's own order every epoch; confirm that this is intended.
        loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                            collate_fn=batch_collate, drop_last=True,
                            num_workers=num_workers, shuffle=False)
        test_dataset = TextMelDataset(valid_filelist_path, cmudict_path, add_blank,
                                      n_fft, n_feats, sample_rate, hop_length,
                                      win_length, f_min, f_max)

        print('Initializing model...')
        model = GradTTS(nsymbols, 1, None, n_enc_channels, filter_channels,
                        filter_channels_dp, n_heads, n_enc_layers, enc_kernel,
                        enc_dropout, window_size, n_feats, dec_dim, beta_min,
                        beta_max, pe_scale).to(device)
        print('Number of encoder + duration predictor parameters: %.2fm' % (model.encoder.nparams/1e6))
        print('Number of decoder parameters: %.2fm' % (model.decoder.nparams/1e6))
        print('Total parameters: %.2fm' % (model.nparams/1e6))

        print('Initializing optimizer...')
        optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

        print('Logging test batch...')
        test_batch = test_dataset.sample_test_batch(size=params.test_size)
        for i, item in enumerate(test_batch):
            mel = item['y'].squeeze()
            logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel),
                             global_step=0, dataformats='HWC')
            save_plot(mel, f'{log_dir}/original_{i}.png')

        print('Start training...')
        # Number of full batches; the loader drops the last partial batch
        # (drop_last=True).
        n_batches = len(train_dataset) // batch_size
        iteration = 0
        for epoch in range(1, n_epochs + 1):
            iteration, dur_losses, prior_losses, diff_losses = _train_one_epoch(
                model, optimizer, loader, n_batches, logger, epoch, iteration,
                out_size, device)

            log_msg = 'Epoch %d: duration loss = %.3f ' % (epoch, np.mean(dur_losses))
            log_msg += '| prior loss = %.3f ' % np.mean(prior_losses)
            log_msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
            with open(f'{log_dir}/train.log', 'a') as f:
                f.write(log_msg)

            # Synthesis and checkpointing run only on every save_every-th epoch.
            if epoch % params.save_every > 0:
                continue

            _log_synthesis(model, test_batch, logger, iteration, log_dir, device)
            ckpt = model.state_dict()
            torch.save(ckpt, f=f"{log_dir}/grad_{epoch}.pt")
    finally:
        # Flush pending TensorBoard events to disk and release the event file;
        # previously the writer was never closed, so buffered events could be lost.
        logger.close()