| | import os |
| | import time |
| | import numpy as np |
| | import torch |
| | import librosa |
| | from logger.saver import Saver |
| | from logger import utils |
| | from torch import autocast |
| | from torch.cuda.amp import GradScaler |
| |
|
| | def test(args, model, vocoder, loader_test, saver): |
| | print(' [*] testing...') |
| | model.eval() |
| |
|
| | |
| | test_loss = 0. |
| | |
| | |
| | num_batches = len(loader_test) |
| | rtf_all = [] |
| | |
| | |
| | with torch.no_grad(): |
| | for bidx, data in enumerate(loader_test): |
| | fn = data['name'][0] |
| | print('--------') |
| | print('{}/{} - {}'.format(bidx, num_batches, fn)) |
| |
|
| | |
| | for k in data.keys(): |
| | if k != 'name': |
| | data[k] = data[k].to(args.device) |
| | print('>>', data['name'][0]) |
| |
|
| | |
| | st_time = time.time() |
| | mel = model( |
| | data['units'], |
| | data['f0'], |
| | data['volume'], |
| | data['spk_id'], |
| | gt_spec=None, |
| | infer=True, |
| | infer_speedup=args.infer.speedup, |
| | method=args.infer.method) |
| | signal = vocoder.infer(mel, data['f0']) |
| | ed_time = time.time() |
| | |
| | |
| | run_time = ed_time - st_time |
| | song_time = signal.shape[-1] / args.data.sampling_rate |
| | rtf = run_time / song_time |
| | print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) |
| | rtf_all.append(rtf) |
| | |
| | |
| | for i in range(args.train.batch_size): |
| | loss = model( |
| | data['units'], |
| | data['f0'], |
| | data['volume'], |
| | data['spk_id'], |
| | gt_spec=data['mel'], |
| | infer=False) |
| | test_loss += loss.item() |
| | |
| | |
| | saver.log_spec(data['name'][0], data['mel'], mel) |
| | |
| | |
| | path_audio = os.path.join(args.data.valid_path, 'audio', data['name'][0]) + '.wav' |
| | audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate) |
| | if len(audio.shape) > 1: |
| | audio = librosa.to_mono(audio) |
| | audio = torch.from_numpy(audio).unsqueeze(0).to(signal) |
| | saver.log_audio({fn+'/gt.wav': audio, fn+'/pred.wav': signal}) |
| | |
| | |
| | test_loss /= args.train.batch_size |
| | test_loss /= num_batches |
| | |
| | |
| | print(' [test_loss] test_loss:', test_loss) |
| | print(' Real Time Factor', np.mean(rtf_all)) |
| | return test_loss |
| |
|
| |
|
def _resolve_amp_dtype(amp_dtype):
    """Map the ``train.amp_dtype`` config value ('fp32' / 'fp16' / 'bf16') to a torch dtype.

    Raises:
        ValueError: for any other value.
    """
    dtypes = {
        'fp32': torch.float32,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    try:
        return dtypes[amp_dtype]
    except (KeyError, TypeError):
        # str() so a non-string config value still yields a readable error.
        raise ValueError(' [x] Unknown amp_dtype: ' + str(amp_dtype)) from None


def _log_train_step(args, saver, epoch, batch_idx, num_batches, lr, loss_value):
    """Write the periodic training-progress line plus loss/lr scalars through ``saver``."""
    saver.log_info(
        'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
            epoch,
            batch_idx,
            num_batches,
            args.env.expdir,
            args.train.interval_log / saver.get_interval_time(),
            lr,
            loss_value,
            saver.get_total_time(),
            saver.global_step,
        )
    )
    saver.log_value({'train/loss': loss_value})
    saver.log_value({'train/lr': lr})


def _save_and_validate(args, model, optimizer, vocoder, loader_test, saver):
    """Checkpoint the model, prune the previous checkpoint, run ``test`` and log its loss."""
    optimizer_save = optimizer if args.train.save_opt else None
    saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')

    # Delete the checkpoint from the previous validation step unless it lies
    # on the ``interval_force_save`` grid (those are never deleted here).
    last_val_step = saver.global_step - args.train.interval_val
    if last_val_step % args.train.interval_force_save != 0:
        saver.delete_model(postfix=f'{last_val_step}')

    test_loss = test(args, model, vocoder, loader_test, saver)
    saver.log_info(' --- <validation> --- \nloss: {:.3f}. '.format(test_loss))
    saver.log_value({'validation/loss': test_loss})

    # test() puts the model in eval mode; switch back for further training.
    model.train()


def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
    """Run the training loop with periodic logging, validation and checkpointing.

    Args:
        args: experiment config; reads ``device``, ``env.expdir`` and
            ``train.{amp_dtype, epochs, interval_log, interval_val,
            interval_force_save, save_opt}`` (``test`` reads further fields).
        initial_global_step: step counter to resume from; passed to ``Saver``.
        model: called as ``model(units, f0, volume, spk_id, aug_shift=...,
            gt_spec=..., infer=False)`` and must return a scalar loss tensor.
        optimizer: stepped once per batch.
        scheduler: ``scheduler.step()`` is called once per batch (per step,
            not per epoch).
        vocoder, loader_test: forwarded to ``test`` for validation.
        loader_train: sized iterable of dicts containing at least ``'name'``,
            ``'units'``, ``'f0'``, ``'volume'``, ``'spk_id'``, ``'aug_shift'``
            and ``'mel'``; every value except ``'name'`` is moved to ``args.device``.

    Raises:
        ValueError: on an unknown ``train.amp_dtype`` or a NaN training loss.
    """
    saver = Saver(args, initial_global_step=initial_global_step)

    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')
    scaler = GradScaler()
    dtype = _resolve_amp_dtype(args.train.amp_dtype)
    # torch.autocast only accepts a bare device type ('cuda', 'cpu', ...);
    # an indexed device string such as 'cuda:0' would make it raise.
    device_type = torch.device(args.device).type

    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            for k in data.keys():
                if k != 'name':
                    data[k] = data[k].to(args.device)

            if dtype == torch.float32:
                loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                             aug_shift=data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
            else:
                with autocast(device_type=device_type, dtype=dtype):
                    loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                                 aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False)

            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')

            if dtype == torch.float32:
                loss.backward()
                optimizer.step()
            else:
                # Mixed precision: GradScaler scales the loss before backward
                # and unscales before the optimizer step.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            scheduler.step()

            if saver.global_step % args.train.interval_log == 0:
                _log_train_step(args, saver, epoch, batch_idx, num_batches,
                                optimizer.param_groups[0]['lr'], loss.item())

            if saver.global_step % args.train.interval_val == 0:
                _save_and_validate(args, model, optimizer, vocoder, loader_test, saver)
| |
|
| | |
| |
|