import time

import librosa
import numpy as np
import torch
from torch import autocast
from torch.cuda.amp import GradScaler

from diffusion.logger import utils
from diffusion.logger.saver import Saver


def test(args, model, vocoder, loader_test, saver):
    print(' [*] testing...')
    model.eval()

    # losses
    test_loss = 0.

    # initialization
    num_batches = len(loader_test)
    rtf_all = []

    # run
    with torch.no_grad():
        for bidx, data in enumerate(loader_test):
            fn = data['name'][0].split("/")[-1]
            speaker = data['name'][0].split("/")[-2]
            print('--------')
            print('{}/{} - {}'.format(bidx, num_batches, fn))

            # unpack data
            for k in data.keys():
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)
            print('>>', data['name'][0])

            # forward
            st_time = time.time()
            mel = model(
                data['units'],
                data['f0'],
                data['volume'],
                data['spk_id'],
                gt_spec=None if model.k_step_max == model.timesteps else data['mel'],
                infer=True,
                infer_speedup=args.infer.speedup,
                method=args.infer.method,
                k_step=model.k_step_max)
            signal = vocoder.infer(mel, data['f0'])
            ed_time = time.time()

            # RTF (real-time factor): inference time / generated-audio duration
            run_time = ed_time - st_time
            song_time = signal.shape[-1] / args.data.sampling_rate
            rtf = run_time / song_time
            print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
            rtf_all.append(rtf)

            # loss: evaluate the training objective args.train.batch_size times
            # on this batch and accumulate
            for i in range(args.train.batch_size):
                loss = model(
                    data['units'],
                    data['f0'],
                    data['volume'],
                    data['spk_id'],
                    gt_spec=data['mel'],
                    infer=False,
                    k_step=model.k_step_max)
                test_loss += loss.item()

            # log mel
            saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel)

            # log audio
            path_audio = data['name_ext'][0]
            audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
            if len(audio.shape) > 1:
                audio = librosa.to_mono(audio)
            audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
            saver.log_audio({
                f"{speaker}_{fn}_gt.wav": audio,
                f"{speaker}_{fn}_pred.wav": signal})

    # report: average the accumulated loss over evaluations and batches
    test_loss /= args.train.batch_size
    test_loss /= num_batches

    # check
    print(' [test_loss] test_loss:', test_loss)
    print(' Real Time Factor', np.mean(rtf_all))
    return test_loss


def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
    # saver
    saver = Saver(args, initial_global_step=initial_global_step)

    # model size
    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    # run
    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')
    scaler = GradScaler()

    # select autocast precision from config
    if args.train.amp_dtype == 'fp32':
        dtype = torch.float32
    elif args.train.amp_dtype == 'fp16':
        dtype = torch.float16
    elif args.train.amp_dtype == 'bf16':
        dtype = torch.bfloat16
    else:
        raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
    saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            # unpack data
            for k in data.keys():
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)

            # forward
            if dtype == torch.float32:
                loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                             aug_shift=data['aug_shift'], gt_spec=data['mel'].float(), infer=False,
                             k_step=model.k_step_max)
            else:
                with autocast(device_type=args.device, dtype=dtype):
                    loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                                 aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False,
                                 k_step=model.k_step_max)

            # handle nan loss
            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')
            else:
                # backpropagate (GradScaler handles loss scaling under fp16/bf16 autocast)
                if dtype == torch.float32:
                    loss.backward()
                    optimizer.step()
                else:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                scheduler.step()

            # log loss
            if saver.global_step % args.train.interval_log == 0:
                current_lr = optimizer.param_groups[0]['lr']
                saver.log_info(
                    'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
                        epoch,
                        batch_idx,
                        num_batches,
                        args.env.expdir,
                        args.train.interval_log / saver.get_interval_time(),
                        current_lr,
                        loss.item(),
                        saver.get_total_time(),
                        saver.global_step
                    )
                )

                saver.log_value({
                    'train/loss': loss.item()
                })

                saver.log_value({
                    'train/lr': current_lr
                })

            # validation
            if saver.global_step % args.train.interval_val == 0:
                optimizer_save = optimizer if args.train.save_opt else None

                # save the latest checkpoint, then prune the previous validation
                # checkpoint unless it falls on a force-save interval
                saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
                last_val_step = saver.global_step - args.train.interval_val
                if last_val_step % args.train.interval_force_save != 0:
                    saver.delete_model(postfix=f'{last_val_step}')

                # run testing set
                test_loss = test(args, model, vocoder, loader_test, saver)

                # log loss
                saver.log_info(
                    ' --- <validation> --- \nloss: {:.3f}. '.format(
                        test_loss,
                    )
                )

                saver.log_value({
                    'validation/loss': test_loss
                })

                model.train()
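

# Usage note: this module is driven by the project's entry script, which is
# expected to assemble everything `train()` reads: an `args` namespace built
# from the YAML config (args.device, args.data.sampling_rate, args.train.*,
# args.infer.*, args.env.expdir), the diffusion model, a vocoder exposing
# infer(mel, f0), an optimizer/scheduler pair, and the train/test loaders.
# A call then looks like the sketch below; constructing those objects is the
# caller's responsibility and is not defined in this file:
#
#     train(args, initial_global_step, model, optimizer, scheduler,
#           vocoder, loader_train, loader_test)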