import time
from pathlib import Path
from os.path import exists

import numpy as np
import torch
import torch.nn.functional as F
from torch import no_grad, optim
from torch.utils.data import DataLoader

import vocoder.hparams as hp
from vocoder.display import stream, simple_table
from vocoder.distribution import discretized_mix_logistic_loss
from vocoder.gen_wavernn import gen_devset
from vocoder.models.fatchord_version import WaveRNN
from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder
from vocoder.utils import ValueWindow
from utils.profiler import Profiler


def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
          save_every: int, backup_every: int, force_restart: bool, use_tb: bool):
    if use_tb:
        print("Using TensorBoard")
        import tensorflow as tf
        import datetime
        # Hide the GPU from TensorFlow's visible devices so it does not reserve
        # memory needed by PyTorch; TF is only used here to write summaries
        tf.config.set_visible_devices([], "GPU")
        log_dir = "log/vc/vocoder/tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_summary_writer = tf.summary.create_file_writer(log_dir)

    train_syn_dir = syn_dir.joinpath("train")
    train_voc_dir = voc_dir.joinpath("train")
    dev_syn_dir = syn_dir.joinpath("dev")
    dev_voc_dir = voc_dir.joinpath("dev")

    # Check that the hop length is correctly factorised into the upsampling factors
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )
    if torch.cuda.is_available():
        model = model.cuda()

    # Initialize the optimizer. RAW mode classifies quantised waveform values with
    # cross-entropy; MOL mode fits a discretized mixture of logistics.
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss
    train_loss_window = ValueWindow(100)

    # Load the weights
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir / "vocoder.pt"
    # train_loss_file_path = "vocoder_loss/vocoder_train_loss.npy"
    # dev_loss_file_path = "vocoder_loss/vocoder_dev_loss.npy"
    # if not exists("vocoder_loss"):
    #     import os
    #     os.mkdir("vocoder_loss")
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
        # losses = []
        # dev_losses = []
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)
        # losses = list(np.load(train_loss_file_path)) if exists(train_loss_file_path) else []
        # dev_losses = list(np.load(dev_loss_file_path)) if exists(dev_loss_file_path) else []

    # Initialize the datasets. With ground_truth, train on the synthesizer's target
    # mels; otherwise train on GTA (ground-truth-aligned) mels it generated.
    train_metadata_fpath = train_syn_dir.joinpath("train.txt") if ground_truth else \
        train_voc_dir.joinpath("synthesized.txt")
    train_mel_dir = train_syn_dir.joinpath("mels") if ground_truth else train_voc_dir.joinpath("mels_gta")
    train_wav_dir = train_syn_dir.joinpath("audio")
    train_dataset = VocoderDataset(train_metadata_fpath, train_mel_dir, train_wav_dir)

    dev_metadata_fpath = dev_syn_dir.joinpath("dev.txt") if ground_truth else \
        dev_voc_dir.joinpath("synthesized.txt")
    dev_mel_dir = dev_syn_dir.joinpath("mels") if ground_truth else dev_voc_dir.joinpath("mels_gta")
    dev_wav_dir = dev_syn_dir.joinpath("audio")
    dev_dataset = VocoderDataset(dev_metadata_fpath, dev_mel_dir, dev_wav_dir)

    train_dataloader = DataLoader(train_dataset, hp.voc_batch_size, shuffle=True, num_workers=8,
                                  collate_fn=collate_vocoder, pin_memory=True)
    dev_dataloader = DataLoader(dev_dataset, hp.voc_batch_size, shuffle=True, num_workers=8,
                                collate_fn=collate_vocoder, pin_memory=True)
    # Batch-size-1 loader for generating audio at checkpoints (generation is disabled below)
    dev_dataloader_ = DataLoader(dev_dataset, 1, shuffle=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])
    # best_loss_file_path = "vocoder_loss/best_loss.npy"
    # best_loss = np.load(best_loss_file_path)[0] if exists(best_loss_file_path) else 1000
    # profiler = Profiler(summarize_every=10, disabled=False)
    for epoch in range(1, 3500):
        start = time.time()
        for i, (x, y, m) in enumerate(train_dataloader, 1):
            model.train()
            # profiler.tick("Blocking, waiting for batch (threaded)")
            if torch.cuda.is_available():
                x, m, y = x.cuda(), m.cuda(), y.cuda()
            # profiler.tick("Data to cuda")

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                # (batch, time, classes) -> (batch, classes, time, 1) for cross_entropy
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)
            # profiler.tick("Forward pass")

            # Backward pass
            loss = loss_func(y_hat, y)
            # profiler.tick("Loss")
            optimizer.zero_grad()
            loss.backward()
            # profiler.tick("Backward pass")
            optimizer.step()
            # profiler.tick("Parameter update")

            speed = i / (time.time() - start)
            train_loss_window.append(loss.item())
            step = model.get_step()
            k = step // 1000

            msg = f"| Epoch: {epoch} ({i}/{len(train_dataloader)}) | " \
                  f"Train Loss: {train_loss_window.average:.4f} | " \
                  f"{speed:.4f} steps/s | Step: {k}k | "
            stream(msg)
            if use_tb:
                with train_summary_writer.as_default():
                    tf.summary.scalar('train_loss', train_loss_window.average, step=step)

            torch.cuda.empty_cache()

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                dev_loss = validate(dev_dataloader, model, loss_func)
                msg = f"| Epoch: {epoch} ({i}/{len(train_dataloader)}) | " \
                      f"Train Loss: {train_loss_window.average:.4f} | Dev Loss: {dev_loss:.4f} | " \
                      f"{speed:.4f} steps/s | Step: {k}k | "
                stream(msg)
                if use_tb:
                    with train_summary_writer.as_default():
                        tf.summary.scalar('val_loss', dev_loss, step=step)
                # losses.append(train_loss_window.average)
                # np.save(train_loss_file_path, np.array(losses, dtype=float))
                # dev_losses.append(dev_loss)
                # np.save(dev_loss_file_path, np.array(dev_losses, dtype=float))
                # if dev_loss < best_loss:
                #     best_loss = dev_loss
                #     np.save(best_loss_file_path, np.array([best_loss]))
                model.save(weights_fpath, optimizer)
                # profiler.tick("Extra saving")
                # gen_devset(model, dev_dataloader_, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                #            hp.voc_target, hp.voc_overlap, model_dir)

        print("")


def validate(dataloader, model, loss_func):
    # Average the loss over the whole dev set, with gradients disabled
    model.eval()
    losses = []
    with no_grad():
        for x, y, m in dataloader:
            if torch.cuda.is_available():
                x, m, y = x.cuda(), m.cuda(), y.cuda()
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)
            loss = loss_func(y_hat, y).item()
            losses.append(loss)
            torch.cuda.empty_cache()
    return sum(losses) / len(losses)
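

# Minimal usage sketch: how train() might be wired to a command-line entry point.
# This block is illustrative, not part of the original module; the default paths
# and intervals below are assumptions, not values taken from this repository.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the WaveRNN vocoder")
    parser.add_argument("run_id",
                        help="Name of this run; weights are saved to <models_dir>/<run_id>")
    parser.add_argument("--syn_dir", type=Path, default=Path("datasets/SV2TTS/synthesizer"),
                        help="Assumed location of the synthesizer output (train/dev mels and audio)")
    parser.add_argument("--voc_dir", type=Path, default=Path("datasets/SV2TTS/vocoder"),
                        help="Assumed location of GTA mels when not training on ground truth")
    parser.add_argument("--models_dir", type=Path, default=Path("saved_models"))
    parser.add_argument("--ground_truth", action="store_true",
                        help="Train on ground-truth mels instead of GTA mels")
    parser.add_argument("--save_every", type=int, default=1000,
                        help="Steps between validation passes and weight saves (0 disables)")
    parser.add_argument("--backup_every", type=int, default=25000,
                        help="Steps between numbered checkpoints (0 disables)")
    parser.add_argument("--force_restart", action="store_true",
                        help="Ignore any existing weights and train from scratch")
    parser.add_argument("--use_tb", action="store_true",
                        help="Log train/dev losses to TensorBoard")

    # Argument names mirror train()'s signature, so the namespace can be splatted directly
    train(**vars(parser.parse_args()))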