"""Training script for the WaveRNN vocoder.

The functions in this module rely on a handful of module-level globals that
are created in the ``__main__`` section at the bottom of the file before
``main`` is called:

* ``c`` - the loaded training config,
* ``c_logger`` / ``tb_logger`` - console and Tensorboard loggers,
* ``OUT_PATH`` - the experiment output folder,
* ``train_data`` / ``eval_data`` - item lists populated by ``main``.
"""
import argparse
import os
import sys
import traceback
import time
import glob
import random

import torch
from torch.utils.data import DataLoader

# from torch.utils.data.distributed import DistributedSampler

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.radam import RAdam
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.training import setup_torch_training_env
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.generic_utils import (
    KeepAverage,
    count_parameters,
    create_experiment_folder,
    get_git_branch,
    remove_experiment_folder,
    set_init_dict,
)
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import (
    load_wav_data,
    load_wav_feat_data
)
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_wavernn
from TTS.vocoder.utils.io import save_best_model, save_checkpoint

use_cuda, num_gpus = setup_torch_training_env(True, True)


def setup_loader(ap, is_val=False, verbose=False):
    """Build the training or validation ``DataLoader``.

    Args:
        ap: audio processor; its ``hop_length`` is forwarded to the dataset.
        is_val: build the loader over ``eval_data`` instead of ``train_data``.
        verbose: forwarded to ``WaveRNNDataset``.

    Returns:
        A ``DataLoader``, or ``None`` when a validation loader is requested
        but ``c.run_eval`` is disabled.
    """
    if is_val and not c.run_eval:
        return None

    dataset = WaveRNNDataset(
        ap=ap,
        items=eval_data if is_val else train_data,
        seq_len=c.seq_len,
        hop_len=ap.hop_length,
        pad=c.padding,
        mode=c.mode,
        mulaw=c.mulaw,
        is_training=not is_val,
        verbose=verbose,
    )
    # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
    return DataLoader(
        dataset,
        shuffle=True,
        collate_fn=dataset.collate,
        batch_size=c.batch_size,
        num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
        pin_memory=True,
    )


def format_data(data):
    """Unpack the first three elements of a batch and move them to the GPU.

    Args:
        data: an indexable batch; elements 0, 1 and 2 are used as the model
            input, the mel features and the target, respectively.

    Returns:
        Tuple ``(x_input, mels, y_coarse)``; the tensors are moved to CUDA
        when ``use_cuda`` is set.
    """
    x_input, mels, y_coarse = data[0], data[1], data[2]

    # dispatch data to GPU
    if use_cuda:
        x_input = x_input.cuda(non_blocking=True)
        mels = mels.cuda(non_blocking=True)
        y_coarse = y_coarse.cuda(non_blocking=True)
    return x_input, mels, y_coarse


def _compute_loss(model, criterion, x_input, mels, y_coarse):
    """Run a forward pass and return the loss for one batch."""
    y_hat = model(x_input, mels)
    if isinstance(model.mode, int):
        # integer mode is trained with CrossEntropyLoss (see ``main``)
        y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
    else:
        y_coarse = y_coarse.float()
    y_coarse = y_coarse.unsqueeze(-1)
    return criterion(y_hat, y_coarse)


def _synthesize_sample(model, ap, items):
    """Generate audio with ``model`` from the mel of a randomly picked item.

    Items may be plain wav paths or tuples/lists whose first entry is the
    wav path.

    Returns:
        Tuple ``(ground_mel, predict_mel, sample_wav)``.
    """
    item = items[random.randrange(0, len(items))]
    wav_path = item[0] if isinstance(item, (tuple, list)) else item
    wav = ap.load_wav(wav_path)
    ground_mel = ap.melspectrogram(wav)
    sample_wav = model.generate(
        ground_mel,
        c.batched,
        c.target_samples,
        c.overlap_samples,
        use_cuda
    )
    predict_mel = ap.melspectrogram(sample_wav)
    return ground_mel, predict_mel, sample_wav


def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
    """Train ``model`` for one epoch.

    Args:
        model: the WaveRNN model.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss callable taking ``(y_hat, y_coarse)``.
        scheduler: LR scheduler stepped once per iteration, or ``None``.
        scaler: AMP ``GradScaler``; only used when ``c.mixed_precision``.
        ap: audio processor used for data loading and sample synthesis.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index.

    Returns:
        Tuple ``(avg_values, global_step)`` with the running averages kept by
        ``KeepAverage`` and the updated step counter.

    Raises:
        RuntimeError: if the loss becomes NaN in full precision training.
    """
    # create train loader
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()

    # train loop
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        optimizer.zero_grad()

        if c.mixed_precision:
            # mixed precision training
            with torch.cuda.amp.autocast():
                loss = _compute_loss(model, criterion, x_input, mels, y_coarse)
            scaler.scale(loss).backward()
            # unscale before clipping so the threshold applies to true grads
            scaler.unscale_(optimizer)
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            # full precision training
            loss = _compute_loss(model, criterion, x_input, mels, y_coarse)
            # ``loss.item()`` is never None, so test for NaN explicitly
            if torch.isnan(loss).item():
                raise RuntimeError(" [!] NaN loss. Exiting ...")
            loss.backward()
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), c.grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # get the current learning rate
        cur_lr = optimizer.param_groups[0]["lr"]

        step_time = time.time() - start_time
        epoch_time += step_time

        loss_dict = {"model_loss": loss.item()}
        update_train_values = {
            "avg_" + key: value for key, value in loss_dict.items()}
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr": cur_lr,
            }
            c_logger.print_train_step(
                batch_n_iter,
                num_iter,
                global_step,
                log_dict,
                loss_dict,
                keep_avg.avg_values,
            )

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                    scaler=scaler.state_dict() if c.mixed_precision else None
                )

            # synthesize a full voice
            ground_mel, predict_mel, sample_wav = _synthesize_sample(
                model, ap, train_data)

            # compute spectrograms
            figures = {
                "train/ground_truth": plot_spectrogram(ground_mel.T),
                "train/prediction": plot_spectrogram(predict_mel.T)
            }
            tb_logger.tb_train_figures(global_step, figures)

            # Sample audio
            tb_logger.tb_train_audios(
                global_step, {"train/audio": sample_wav}, c.audio["sample_rate"]
            )
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Evaluate ``model`` on the validation set.

    Args:
        model: the WaveRNN model.
        criterion: loss callable taking ``(y_hat, y_coarse)``.
        ap: audio processor used for data loading and sample synthesis.
        global_step: step counter used to tag Tensorboard entries; it is
            advanced locally once per evaluation batch.
        epoch: current epoch index.

    Returns:
        The running averages kept by ``KeepAverage``. When evaluation is
        disabled (``setup_loader`` returns ``None``) nothing is run and the
        averages of an untouched ``KeepAverage`` are returned.
    """
    # create eval loader
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    keep_avg = KeepAverage()
    if data_loader is None:
        # evaluation is switched off in the config; iterating None would raise
        return keep_avg.avg_values

    model.eval()
    epoch_time = 0
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        # format data
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        loss = _compute_loss(model, criterion, x_input, mels, y_coarse)
        # Compute avg loss
        # if num_gpus > 1:
        #     loss = reduce_tensor(loss.data, num_gpus)
        loss_dict = {"model_loss": loss.item()}

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = {
            "avg_" + key: value for key, value in loss_dict.items()}
        update_eval_values["avg_loader_time"] = loader_time
        update_eval_values["avg_step_time"] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(
                num_iter, loss_dict, keep_avg.avg_values)

    if epoch % c.test_every_epochs == 0 and epoch != 0:
        # synthesize a full voice
        ground_mel, predict_mel, sample_wav = _synthesize_sample(
            model, ap, eval_data)

        # Sample audio
        tb_logger.tb_eval_audios(
            global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"]
        )

        # compute spectrograms
        figures = {
            "eval/ground_truth": plot_spectrogram(ground_mel.T),
            "eval/prediction": plot_spectrogram(predict_mel.T)
        }
        tb_logger.tb_eval_figures(global_step, figures)

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    return keep_avg.avg_values


# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    """Load the data, build the model and run the train/eval epoch loop.

    Args:
        args: parsed command line namespace. ``args.restore_path`` is read
            and ``args.restore_step`` is set by this function.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # TODO: when c.feature_path is not set, optionally preprocess mel features
    # into OUT_PATH/mel and load them instead of raw wavs.
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(
            c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(
            c.data_path, c.eval_split_size)

    # setup model
    model_wavernn = setup_wavernn(c)

    # setup amp scaler
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # define train functions
    if c.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif c.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(c.mode, int):
        criterion = torch.nn.CrossEntropyLoss()

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        # NOTE: torch.load unpickles the file; only restore trusted checkpoints.
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            # The checkpoint may carry a scheduler state while the current
            # config defines no scheduler (or the other way around); only
            # restore when both sides are present.
            if scheduler is not None and checkpoint.get("scheduler") is not None:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            if c.mixed_precision and checkpoint.get("scaler") is not None:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" %
              checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    best_loss = float("inf")

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(
            model_wavernn, optimizer, criterion, scheduler, scaler,
            ap, global_step, epoch)
        if c.run_eval:
            epoch_loss_dict = evaluate(
                model_wavernn, criterion, ap, global_step, epoch)
            c_logger.print_epoch_end(epoch, epoch_loss_dict)
        else:
            # no validation pass: select the best model on the training loss
            epoch_loss_dict = train_avg_loss_dict
        target_loss = epoch_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=epoch_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None
        )


def _str2bool(value):
    """Parse a command line boolean.

    ``type=bool`` would turn every non-empty string (including "False")
    into ``True``, so the accepted spellings are listed explicitly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def _flag_given(flag):
    """Return True if ``flag`` is on the command line as ``flag`` or ``flag=...``."""
    return any(
        arg == flag or arg.startswith(flag + "=") for arg in sys.argv[1:])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default="",
        required=not _flag_given("--config_path"),
    )
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required=not _flag_given("--continue_path"),
    )
    parser.add_argument(
        "--debug",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Do not verify commit integrity to run training.",
    )

    # DISTRIBUTED
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.",
    )
    parser.add_argument(
        "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
    )
    args = parser.parse_args()

    if args.continue_path != "":
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        # every checkpoint saved in the experiment folder
        list_of_files = glob.glob(
            os.path.join(args.continue_path, "*.pth.tar"))
        if not list_of_files:
            parser.error(
                "no '*.pth.tar' checkpoint found in {!r}".format(
                    args.continue_path))
        # resume from the most recently changed checkpoint file
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)
    # check_config(c)

    OUT_PATH = args.continue_path
    if args.continue_path == "":
        OUT_PATH = create_experiment_folder(
            c.output_path, c.run_name, args.debug
        )

    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_model_files(
            c, args.config_path, OUT_PATH, new_fields
        )
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

    LOG_DIR = OUT_PATH
    tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

    # write model desc to tensorboard
    tb_logger.tb_add_text("model-description", c["run_description"], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)